# aggregates.py (2.8 KB)
  1. from django.contrib.gis.db.models.fields import (
  2. ExtentField, GeometryCollectionField, GeometryField, LineStringField,
  3. )
  4. from django.db.models import Value
  5. from django.db.models.aggregates import Aggregate
  6. from django.utils.functional import cached_property
# Names exported by ``from <this module> import *``: the concrete aggregate classes.
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union']
  8. class GeoAggregate(Aggregate):
  9. function = None
  10. is_extent = False
  11. @cached_property
  12. def output_field(self):
  13. return self.output_field_class(self.source_expressions[0].output_field.srid)
  14. def as_sql(self, compiler, connection, function=None, **extra_context):
  15. # this will be called again in parent, but it's needed now - before
  16. # we get the spatial_aggregate_name
  17. connection.ops.check_expression_support(self)
  18. return super().as_sql(
  19. compiler,
  20. connection,
  21. function=function or connection.ops.spatial_aggregate_name(self.name),
  22. **extra_context
  23. )
  24. def as_oracle(self, compiler, connection, **extra_context):
  25. if not self.is_extent:
  26. tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05)
  27. clone = self.copy()
  28. clone.set_source_expressions([
  29. *self.get_source_expressions(),
  30. Value(tolerance),
  31. ])
  32. template = '%(function)s(SDOAGGRTYPE(%(expressions)s))'
  33. return clone.as_sql(compiler, connection, template=template, **extra_context)
  34. return self.as_sql(compiler, connection, **extra_context)
  35. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  36. c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  37. for expr in c.get_source_expressions():
  38. if not hasattr(expr.field, 'geom_type'):
  39. raise ValueError('Geospatial aggregates only allowed on geometry fields.')
  40. return c
  41. class Collect(GeoAggregate):
  42. name = 'Collect'
  43. output_field_class = GeometryCollectionField
  44. class Extent(GeoAggregate):
  45. name = 'Extent'
  46. is_extent = '2D'
  47. def __init__(self, expression, **extra):
  48. super().__init__(expression, output_field=ExtentField(), **extra)
  49. def convert_value(self, value, expression, connection):
  50. return connection.ops.convert_extent(value)
  51. class Extent3D(GeoAggregate):
  52. name = 'Extent3D'
  53. is_extent = '3D'
  54. def __init__(self, expression, **extra):
  55. super().__init__(expression, output_field=ExtentField(), **extra)
  56. def convert_value(self, value, expression, connection):
  57. return connection.ops.convert_extent3d(value)
  58. class MakeLine(GeoAggregate):
  59. name = 'MakeLine'
  60. output_field_class = LineStringField
  61. class Union(GeoAggregate):
  62. name = 'Union'
  63. output_field_class = GeometryField