Note
Click here to download the full example code
Automatically use a function at insert or select
Sometimes the application needs to apply a function automatically in an insert or in a select.
For example, the application might need geometries in lat/lon coordinates while they are
stored projected in the DB. To avoid having to tweak every query with an ST_Transform()
call, it is possible to define a TypeDecorator that adds the call automatically:
11 from sqlalchemy import create_engine
12 from sqlalchemy import MetaData
13 from sqlalchemy import Column
14 from sqlalchemy import Integer
15 from sqlalchemy import func
16 from sqlalchemy.ext.declarative import declarative_base
17 from sqlalchemy.orm import sessionmaker
18 from sqlalchemy.types import TypeDecorator
19
20 from geoalchemy2 import Geometry
21 from geoalchemy2 import shape
22
23
# Engine for the local PostgreSQL database "gis"; echo=True logs every SQL
# statement that is emitted.
engine = create_engine('postgresql://gis:gis@localhost/gis', echo=True)

# Metadata bound to the engine, so create_all()/drop_all() may be called on it
# without an explicit bind.
metadata = MetaData(bind=engine)

# Declarative base registering its mapped tables on the metadata above.
Base = declarative_base(metadata=metadata)
28
29
class TransformedGeometry(TypeDecorator):
    """This class is used to insert a ST_Transform() in each insert or select.

    Geometries are stored in the database in ``db_srid`` and exposed to the
    application in ``app_srid``:

    * bound values are wrapped in ``ST_Transform(<value>, db_srid)``;
    * selected columns are wrapped in ``ST_Transform(<column>, app_srid)``.

    Args:
        db_srid: SRID of the geometry column in the database.
        app_srid: SRID in which the application reads and writes geometries.
        **kwargs: Extra arguments forwarded to ``Geometry``; its ``srid`` is
            always forced to ``db_srid``.
    """
    impl = Geometry

    # Opt in to SQLAlchemy (>= 1.4) statement caching. Without it SQLAlchemy
    # warns and disables caching for this type. The cache key is built from
    # the instance attributes named after the __init__ parameters
    # (db_srid, app_srid), which fully determine the emitted SQL.
    cache_ok = True

    def __init__(self, db_srid, app_srid, **kwargs):
        kwargs["srid"] = db_srid
        # Let TypeDecorator.__init__ build ``self.impl`` from ``impl`` and
        # kwargs instead of re-implementing it here. The explicit super() form
        # keeps Python 2 compatibility.
        super(TransformedGeometry, self).__init__(**kwargs)
        self.app_srid = app_srid
        self.db_srid = db_srid

    def column_expression(self, col):
        """Transform selected values into ``app_srid``.

        The column_expression() method is overridden to ensure that the
        SRID of the resulting WKBElement is correct: the binary output is
        typed as a Geometry with ``srid=app_srid``.
        """
        return getattr(func, self.impl.as_binary)(
            func.ST_Transform(col, self.app_srid),
            type_=self.__class__.impl(srid=self.app_srid)
            # srid could also be -1 so that the SRID is deduced from the
            # WKB data
        )

    def bind_expression(self, bindvalue):
        """Transform bound values into ``db_srid`` before they are stored."""
        return func.ST_Transform(
            self.impl.bind_expression(bindvalue), self.db_srid)
53
54
class ThreeDGeometry(TypeDecorator):
    """This class is used to insert a ST_Force3D() in each insert.

    Bound values are wrapped in ``ST_Force3D()``, so that 2D geometries can be
    written into a 3D column. This class does not override the select side.
    """
    impl = Geometry

    # Opt in to SQLAlchemy (>= 1.4) statement caching. Without it SQLAlchemy
    # warns and disables caching for this TypeDecorator subclass.
    cache_ok = True

    def bind_expression(self, bindvalue):
        """Force the bound geometry to 3D before it is stored."""
        return func.ST_Force3D(self.impl.bind_expression(bindvalue))
61
62
class Point(Base):
    """Mapped ``point`` table storing the same point in differently-typed columns.

    - ``raw_geom``: plain Geometry POINT column in SRID 4326.
    - ``geom``: stored in SRID 2154, read and written in SRID 4326 through
      TransformedGeometry.
    - ``three_d_geom``: 3D POINTZ column in SRID 4326, whose inserted values
      go through ST_Force3D() via ThreeDGeometry.
    """
    __tablename__ = "point"

    id = Column(Integer, primary_key=True)
    raw_geom = Column(Geometry(geometry_type="POINT", srid=4326))
    geom = Column(
        TransformedGeometry(
            app_srid=4326,
            db_srid=2154,
            geometry_type="POINT",
        )
    )
    three_d_geom = Column(
        ThreeDGeometry(dimension=3, geometry_type="POINTZ", srid=4326)
    )
72
73
# Module-level ORM session bound to ``engine`` and shared by the tests.
session = sessionmaker(engine)()
75
76
def check_wkb(wkb, x, y):
    """Assert that the point geometry in ``wkb`` lies at ``(x, y)``.

    The element is converted with ``shape.to_shape``. Each coordinate is
    rounded to 5 decimal places and then compared for equality, x first and
    y second.
    """
    point = shape.to_shape(wkb)
    for actual, expected in ((point.x, x), (point.y, y)):
        assert round(actual, 5) == expected
81
82
class TestTypeDecorator():
    """End-to-end checks of TransformedGeometry and ThreeDGeometry.

    Runs against the database reachable through ``engine``. Each test starts
    from freshly created tables and drops them afterwards.
    """

    def setup_method(self, method):
        # xunit-style pytest hook. The nose-style ``setup`` name is deprecated
        # since pytest 7.2 and is no longer called at all by pytest 8.
        metadata.drop_all(bind=engine, checkfirst=True)
        metadata.create_all(bind=engine)

    def teardown_method(self, method):
        # Roll back the session's open transaction before dropping the tables.
        session.rollback()
        metadata.drop_all(bind=engine)

    def _create_one_point(self):
        """Insert one Point at POINT(5 45) in SRID 4326 and return its id."""
        # Create new point instance
        p = Point()
        p.raw_geom = "SRID=4326;POINT(5 45)"
        p.geom = "SRID=4326;POINT(5 45)"
        p.three_d_geom = "SRID=4326;POINT(5 45)"  # Insert 2D geometry into 3D column

        # Insert point
        session.add(p)
        session.flush()
        # Expire the instance so its attributes are reloaded from the DB on
        # next access.
        session.expire(p)

        return p.id

    def test_transform(self):
        # Raw SQL strings must be wrapped in text() (deprecated as plain
        # strings in SQLAlchemy 1.4, rejected in 2.0).
        from sqlalchemy import text

        self._create_one_point()

        # Query the point and check the result
        pt = session.query(Point).one()
        assert pt.id == 1
        assert pt.raw_geom.srid == 4326
        check_wkb(pt.raw_geom, 5, 45)

        assert pt.geom.srid == 4326
        check_wkb(pt.geom, 5, 45)

        # Check that the data is correct in DB using raw query
        q = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
        res_q = session.execute(q).fetchone()
        assert res_q.id == 1
        assert res_q.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"

        # Compare geom, raw_geom with auto transform and explicit transform
        pt_trans = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans")
        ).one()

        assert pt_trans[0].id == 1

        assert pt_trans[0].geom.srid == 4326
        check_wkb(pt_trans[0].geom, 5, 45)

        assert pt_trans[0].raw_geom.srid == 4326
        check_wkb(pt_trans[0].raw_geom, 5, 45)

        assert pt_trans[1].srid == 4326
        check_wkb(pt_trans[1], 5, 45)

        assert pt_trans[2].srid == 2154
        check_wkb(pt_trans[2], 857581.89932, 6435414.74784)

    def test_force_3d(self):
        self._create_one_point()

        # Query the point and check the result
        pt = session.query(Point).one()

        assert pt.id == 1
        assert pt.three_d_geom.srid == 4326
        # Expected EWKB of POINT Z (5 45 0) with SRID 4326.
        assert pt.three_d_geom.desc.lower() == (
            '01010000a0e6100000000000000000144000000000008046400000000000000000')
Total running time of the script: (0 minutes 0.000 seconds)