Automatically use a function at insert or select

Sometimes the application needs to apply a function on every insert or select. For example, the application might need geometries in lat/lon coordinates while they are stored projected in the database. To avoid having to wrap every query in an explicit ST_Transform() call, it is possible to define a TypeDecorator that applies the function automatically.

 11 from sqlalchemy import Column
 12 from sqlalchemy import Integer
 13 from sqlalchemy import MetaData
 14 from sqlalchemy import func
 15 from sqlalchemy import text
 16 from sqlalchemy.ext.declarative import declarative_base
 17 from sqlalchemy.types import TypeDecorator
 18
 19 from geoalchemy2 import Geometry
 20 from geoalchemy2 import shape
 21
 22 # Tests imports
 23 from tests import test_only_with_dialects
 24
 25 metadata = MetaData()
 26
 27 Base = declarative_base(metadata=metadata)
 28
 29
 30 class TransformedGeometry(TypeDecorator):
 31     """This class is used to insert a ST_Transform() in each insert or select."""
 32     impl = Geometry
 33
 34     def __init__(self, db_srid, app_srid, **kwargs):
 35         kwargs["srid"] = db_srid
 36         self.impl = self.__class__.impl(**kwargs)
 37         self.app_srid = app_srid
 38         self.db_srid = db_srid
 39
 40     def column_expression(self, col):
 41         """The column_expression() method is overridden to ensure that the
 42         SRID of the resulting WKBElement is correct"""
 43         return getattr(func, self.impl.as_binary)(
 44             func.ST_Transform(col, self.app_srid),
 45             type_=self.__class__.impl(srid=self.app_srid)
 46             # srid could also be -1 so that the SRID is deduced from the
 47             # WKB data
 48         )
 49
 50     def bind_expression(self, bindvalue):
 51         return func.ST_Transform(
 52             self.impl.bind_expression(bindvalue), self.db_srid)
 53
 54
 55 class ThreeDGeometry(TypeDecorator):
 56     """This class is used to insert a ST_Force3D() in each insert."""
 57     impl = Geometry
 58
 59     def bind_expression(self, bindvalue):
 60         return func.ST_Force3D(self.impl.bind_expression(bindvalue))
 61
 62
 63 class Point(Base):
 64     __tablename__ = "point"
 65     id = Column(Integer, primary_key=True)
 66     raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
 67     geom = Column(
 68         TransformedGeometry(
 69             db_srid=2154, app_srid=4326, geometry_type="POINT"))
 70     three_d_geom = Column(
 71         ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
 72
 73
 74 def check_wkb(wkb, x, y):
 75     pt = shape.to_shape(wkb)
 76     assert round(pt.x, 5) == x
 77     assert round(pt.y, 5) == y
 78
 79
 80 @test_only_with_dialects("postgresql")
 81 class TestTypeDecorator():
 82
 83     def _create_one_point(self, session, conn):
 84         metadata.drop_all(conn, checkfirst=True)
 85         metadata.create_all(conn)
 86
 87         # Create new point instance
 88         p = Point()
 89         p.raw_geom = "SRID=4326;POINT(5 45)"
 90         p.geom = "SRID=4326;POINT(5 45)"
 91         p.three_d_geom = "SRID=4326;POINT(5 45)"  # Insert 2D geometry into 3D column
 92
 93         # Insert point
 94         session.add(p)
 95         session.flush()
 96         session.expire(p)
 97
 98         return p.id
 99
100     def test_transform(self, session, conn):
101         self._create_one_point(session, conn)
102
103         # Query the point and check the result
104         pt = session.query(Point).one()
105         assert pt.id == 1
106         assert pt.raw_geom.srid == 4326
107         check_wkb(pt.raw_geom, 5, 45)
108
109         assert pt.geom.srid == 4326
110         check_wkb(pt.geom, 5, 45)
111
112         # Check that the data is correct in DB using raw query
113         q = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
114         res_q = session.execute(q).fetchone()
115         assert res_q.id == 1
116         assert res_q.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"
117
118         # Compare geom, raw_geom with auto transform and explicit transform
119         pt_trans = session.query(
120             Point,
121             Point.raw_geom,
122             func.ST_Transform(Point.raw_geom, 2154).label("trans")
123         ).one()
124
125         assert pt_trans[0].id == 1
126
127         assert pt_trans[0].geom.srid == 4326
128         check_wkb(pt_trans[0].geom, 5, 45)
129
130         assert pt_trans[0].raw_geom.srid == 4326
131         check_wkb(pt_trans[0].raw_geom, 5, 45)
132
133         assert pt_trans[1].srid == 4326
134         check_wkb(pt_trans[1], 5, 45)
135
136         assert pt_trans[2].srid == 2154
137         check_wkb(pt_trans[2], 857581.89932, 6435414.74784)
138
139     def test_force_3d(self, session, conn):
140         self._create_one_point(session, conn)
141
142         # Query the point and check the result
143         pt = session.query(Point).one()
144
145         assert pt.id == 1
146         assert pt.three_d_geom.srid == 4326
147         assert pt.three_d_geom.desc.lower() == (
148             '01010000a0e6100000000000000000144000000000008046400000000000000000')

Gallery generated by Sphinx-Gallery