lookups.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. import re
  2. from django.contrib.gis.db.models.fields import BaseSpatialField
  3. from django.contrib.gis.measure import Distance
  4. from django.db import NotSupportedError
  5. from django.db.models.expressions import Expression
  6. from django.db.models.lookups import Lookup, Transform
  7. from django.db.models.sql.query import Query
  8. class RasterBandTransform(Transform):
  9. def as_sql(self, compiler, connection):
  10. return compiler.compile(self.lhs)
  11. class GISLookup(Lookup):
  12. sql_template = None
  13. transform_func = None
  14. distance = False
  15. band_rhs = None
  16. band_lhs = None
  17. def __init__(self, lhs, rhs):
  18. rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
  19. super().__init__(lhs, rhs)
  20. self.template_params = {}
  21. self.process_rhs_params()
  22. def process_rhs_params(self):
  23. if self.rhs_params:
  24. # Check if a band index was passed in the query argument.
  25. if len(self.rhs_params) == (2 if self.lookup_name == 'relate' else 1):
  26. self.process_band_indices()
  27. elif len(self.rhs_params) > 1:
  28. raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
  29. elif isinstance(self.lhs, RasterBandTransform):
  30. self.process_band_indices(only_lhs=True)
  31. def process_band_indices(self, only_lhs=False):
  32. """
  33. Extract the lhs band index from the band transform class and the rhs
  34. band index from the input tuple.
  35. """
  36. # PostGIS band indices are 1-based, so the band index needs to be
  37. # increased to be consistent with the GDALRaster band indices.
  38. if only_lhs:
  39. self.band_rhs = 1
  40. self.band_lhs = self.lhs.band_index + 1
  41. return
  42. if isinstance(self.lhs, RasterBandTransform):
  43. self.band_lhs = self.lhs.band_index + 1
  44. else:
  45. self.band_lhs = 1
  46. self.band_rhs, *self.rhs_params = self.rhs_params
  47. def get_db_prep_lookup(self, value, connection):
  48. # get_db_prep_lookup is called by process_rhs from super class
  49. return ('%s', [connection.ops.Adapter(value)])
  50. def process_rhs(self, compiler, connection):
  51. if isinstance(self.rhs, Query):
  52. # If rhs is some Query, don't touch it.
  53. return super().process_rhs(compiler, connection)
  54. if isinstance(self.rhs, Expression):
  55. self.rhs = self.rhs.resolve_expression(compiler.query)
  56. rhs, rhs_params = super().process_rhs(compiler, connection)
  57. placeholder = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
  58. return placeholder % rhs, rhs_params
  59. def get_rhs_op(self, connection, rhs):
  60. # Unlike BuiltinLookup, the GIS get_rhs_op() implementation should return
  61. # an object (SpatialOperator) with an as_sql() method to allow for more
  62. # complex computations (where the lhs part can be mixed in).
  63. return connection.ops.gis_operators[self.lookup_name]
  64. def as_sql(self, compiler, connection):
  65. lhs_sql, sql_params = self.process_lhs(compiler, connection)
  66. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  67. sql_params.extend(rhs_params)
  68. template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s', **self.template_params}
  69. rhs_op = self.get_rhs_op(connection, rhs_sql)
  70. return rhs_op.as_sql(connection, self, template_params, sql_params)
# ------------------
# Geometry operators
# ------------------
  74. @BaseSpatialField.register_lookup
  75. class OverlapsLeftLookup(GISLookup):
  76. """
  77. The overlaps_left operator returns true if A's bounding box overlaps or is to the
  78. left of B's bounding box.
  79. """
  80. lookup_name = 'overlaps_left'
  81. @BaseSpatialField.register_lookup
  82. class OverlapsRightLookup(GISLookup):
  83. """
  84. The 'overlaps_right' operator returns true if A's bounding box overlaps or is to the
  85. right of B's bounding box.
  86. """
  87. lookup_name = 'overlaps_right'
  88. @BaseSpatialField.register_lookup
  89. class OverlapsBelowLookup(GISLookup):
  90. """
  91. The 'overlaps_below' operator returns true if A's bounding box overlaps or is below
  92. B's bounding box.
  93. """
  94. lookup_name = 'overlaps_below'
  95. @BaseSpatialField.register_lookup
  96. class OverlapsAboveLookup(GISLookup):
  97. """
  98. The 'overlaps_above' operator returns true if A's bounding box overlaps or is above
  99. B's bounding box.
  100. """
  101. lookup_name = 'overlaps_above'
  102. @BaseSpatialField.register_lookup
  103. class LeftLookup(GISLookup):
  104. """
  105. The 'left' operator returns true if A's bounding box is strictly to the left
  106. of B's bounding box.
  107. """
  108. lookup_name = 'left'
  109. @BaseSpatialField.register_lookup
  110. class RightLookup(GISLookup):
  111. """
  112. The 'right' operator returns true if A's bounding box is strictly to the right
  113. of B's bounding box.
  114. """
  115. lookup_name = 'right'
  116. @BaseSpatialField.register_lookup
  117. class StrictlyBelowLookup(GISLookup):
  118. """
  119. The 'strictly_below' operator returns true if A's bounding box is strictly below B's
  120. bounding box.
  121. """
  122. lookup_name = 'strictly_below'
  123. @BaseSpatialField.register_lookup
  124. class StrictlyAboveLookup(GISLookup):
  125. """
  126. The 'strictly_above' operator returns true if A's bounding box is strictly above B's
  127. bounding box.
  128. """
  129. lookup_name = 'strictly_above'
  130. @BaseSpatialField.register_lookup
  131. class SameAsLookup(GISLookup):
  132. """
  133. The "~=" operator is the "same as" operator. It tests actual geometric
  134. equality of two features. So if A and B are the same feature,
  135. vertex-by-vertex, the operator returns true.
  136. """
  137. lookup_name = 'same_as'
  138. BaseSpatialField.register_lookup(SameAsLookup, 'exact')
  139. @BaseSpatialField.register_lookup
  140. class BBContainsLookup(GISLookup):
  141. """
  142. The 'bbcontains' operator returns true if A's bounding box completely contains
  143. by B's bounding box.
  144. """
  145. lookup_name = 'bbcontains'
  146. @BaseSpatialField.register_lookup
  147. class BBOverlapsLookup(GISLookup):
  148. """
  149. The 'bboverlaps' operator returns true if A's bounding box overlaps B's bounding box.
  150. """
  151. lookup_name = 'bboverlaps'
  152. @BaseSpatialField.register_lookup
  153. class ContainedLookup(GISLookup):
  154. """
  155. The 'contained' operator returns true if A's bounding box is completely contained
  156. by B's bounding box.
  157. """
  158. lookup_name = 'contained'
# ------------------
# Geometry functions
# ------------------
  162. @BaseSpatialField.register_lookup
  163. class ContainsLookup(GISLookup):
  164. lookup_name = 'contains'
  165. @BaseSpatialField.register_lookup
  166. class ContainsProperlyLookup(GISLookup):
  167. lookup_name = 'contains_properly'
  168. @BaseSpatialField.register_lookup
  169. class CoveredByLookup(GISLookup):
  170. lookup_name = 'coveredby'
  171. @BaseSpatialField.register_lookup
  172. class CoversLookup(GISLookup):
  173. lookup_name = 'covers'
  174. @BaseSpatialField.register_lookup
  175. class CrossesLookup(GISLookup):
  176. lookup_name = 'crosses'
  177. @BaseSpatialField.register_lookup
  178. class DisjointLookup(GISLookup):
  179. lookup_name = 'disjoint'
  180. @BaseSpatialField.register_lookup
  181. class EqualsLookup(GISLookup):
  182. lookup_name = 'equals'
  183. @BaseSpatialField.register_lookup
  184. class IntersectsLookup(GISLookup):
  185. lookup_name = 'intersects'
  186. @BaseSpatialField.register_lookup
  187. class OverlapsLookup(GISLookup):
  188. lookup_name = 'overlaps'
  189. @BaseSpatialField.register_lookup
  190. class RelateLookup(GISLookup):
  191. lookup_name = 'relate'
  192. sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
  193. pattern_regex = re.compile(r'^[012TF\*]{9}$')
  194. def process_rhs(self, compiler, connection):
  195. # Check the pattern argument
  196. pattern = self.rhs_params[0]
  197. backend_op = connection.ops.gis_operators[self.lookup_name]
  198. if hasattr(backend_op, 'check_relate_argument'):
  199. backend_op.check_relate_argument(pattern)
  200. elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
  201. raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
  202. sql, params = super().process_rhs(compiler, connection)
  203. return sql, params + [pattern]
  204. @BaseSpatialField.register_lookup
  205. class TouchesLookup(GISLookup):
  206. lookup_name = 'touches'
  207. @BaseSpatialField.register_lookup
  208. class WithinLookup(GISLookup):
  209. lookup_name = 'within'
  210. class DistanceLookupBase(GISLookup):
  211. distance = True
  212. sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
  213. def process_rhs_params(self):
  214. if not 1 <= len(self.rhs_params) <= 3:
  215. raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
  216. elif len(self.rhs_params) == 3 and self.rhs_params[2] != 'spheroid':
  217. raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.")
  218. # Check if the second parameter is a band index.
  219. if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
  220. self.process_band_indices()
  221. def process_distance(self, compiler, connection):
  222. dist_param = self.rhs_params[0]
  223. return (
  224. compiler.compile(dist_param.resolve_expression(compiler.query))
  225. if hasattr(dist_param, 'resolve_expression') else
  226. ('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name))
  227. )
  228. @BaseSpatialField.register_lookup
  229. class DWithinLookup(DistanceLookupBase):
  230. lookup_name = 'dwithin'
  231. sql_template = '%(func)s(%(lhs)s, %(rhs)s, %(value)s)'
  232. def process_distance(self, compiler, connection):
  233. dist_param = self.rhs_params[0]
  234. if (
  235. not connection.features.supports_dwithin_distance_expr and
  236. hasattr(dist_param, 'resolve_expression') and
  237. not isinstance(dist_param, Distance)
  238. ):
  239. raise NotSupportedError(
  240. 'This backend does not support expressions for specifying '
  241. 'distance in the dwithin lookup.'
  242. )
  243. return super().process_distance(compiler, connection)
  244. def process_rhs(self, compiler, connection):
  245. dist_sql, dist_params = self.process_distance(compiler, connection)
  246. self.template_params['value'] = dist_sql
  247. rhs_sql, params = super().process_rhs(compiler, connection)
  248. return rhs_sql, params + dist_params
  249. class DistanceLookupFromFunction(DistanceLookupBase):
  250. def as_sql(self, compiler, connection):
  251. spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None
  252. distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
  253. sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
  254. dist_sql, dist_params = self.process_distance(compiler, connection)
  255. return (
  256. '%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql},
  257. params + dist_params,
  258. )
  259. @BaseSpatialField.register_lookup
  260. class DistanceGTLookup(DistanceLookupFromFunction):
  261. lookup_name = 'distance_gt'
  262. op = '>'
  263. @BaseSpatialField.register_lookup
  264. class DistanceGTELookup(DistanceLookupFromFunction):
  265. lookup_name = 'distance_gte'
  266. op = '>='
  267. @BaseSpatialField.register_lookup
  268. class DistanceLTLookup(DistanceLookupFromFunction):
  269. lookup_name = 'distance_lt'
  270. op = '<'
  271. @BaseSpatialField.register_lookup
  272. class DistanceLTELookup(DistanceLookupFromFunction):
  273. lookup_name = 'distance_lte'
  274. op = '<='