ranges.py 8.8 KB


  1. import datetime
  2. import json
  3. from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
  4. from django.contrib.postgres import forms, lookups
  5. from django.db import models
  6. from .utils import AttributeSetter
# Names exported by ``from <this module> import *``.
__all__ = [
    'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
    'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
    'FloatRangeField',
    'RangeBoundary', 'RangeOperators',
]
  13. class RangeBoundary(models.Expression):
  14. """A class that represents range boundaries."""
  15. def __init__(self, inclusive_lower=True, inclusive_upper=False):
  16. self.lower = '[' if inclusive_lower else '('
  17. self.upper = ']' if inclusive_upper else ')'
  18. def as_sql(self, compiler, connection):
  19. return "'%s%s'" % (self.lower, self.upper), []
  20. class RangeOperators:
  21. # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
  22. EQUAL = '='
  23. NOT_EQUAL = '<>'
  24. CONTAINS = '@>'
  25. CONTAINED_BY = '<@'
  26. OVERLAPS = '&&'
  27. FULLY_LT = '<<'
  28. FULLY_GT = '>>'
  29. NOT_LT = '&>'
  30. NOT_GT = '&<'
  31. ADJACENT_TO = '-|-'
  32. class RangeField(models.Field):
  33. empty_strings_allowed = False
  34. def __init__(self, *args, **kwargs):
  35. # Initializing base_field here ensures that its model matches the model for self.
  36. if hasattr(self, 'base_field'):
  37. self.base_field = self.base_field()
  38. super().__init__(*args, **kwargs)
  39. @property
  40. def model(self):
  41. try:
  42. return self.__dict__['model']
  43. except KeyError:
  44. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  45. @model.setter
  46. def model(self, model):
  47. self.__dict__['model'] = model
  48. self.base_field.model = model
  49. def get_prep_value(self, value):
  50. if value is None:
  51. return None
  52. elif isinstance(value, Range):
  53. return value
  54. elif isinstance(value, (list, tuple)):
  55. return self.range_type(value[0], value[1])
  56. return value
  57. def to_python(self, value):
  58. if isinstance(value, str):
  59. # Assume we're deserializing
  60. vals = json.loads(value)
  61. for end in ('lower', 'upper'):
  62. if end in vals:
  63. vals[end] = self.base_field.to_python(vals[end])
  64. value = self.range_type(**vals)
  65. elif isinstance(value, (list, tuple)):
  66. value = self.range_type(value[0], value[1])
  67. return value
  68. def set_attributes_from_name(self, name):
  69. super().set_attributes_from_name(name)
  70. self.base_field.set_attributes_from_name(name)
  71. def value_to_string(self, obj):
  72. value = self.value_from_object(obj)
  73. if value is None:
  74. return None
  75. if value.isempty:
  76. return json.dumps({"empty": True})
  77. base_field = self.base_field
  78. result = {"bounds": value._bounds}
  79. for end in ('lower', 'upper'):
  80. val = getattr(value, end)
  81. if val is None:
  82. result[end] = None
  83. else:
  84. obj = AttributeSetter(base_field.attname, val)
  85. result[end] = base_field.value_to_string(obj)
  86. return json.dumps(result)
  87. def formfield(self, **kwargs):
  88. kwargs.setdefault('form_class', self.form_field)
  89. return super().formfield(**kwargs)
class IntegerRangeField(RangeField):
    """Range field with IntegerField bounds, stored as ``int4range``."""
    base_field = models.IntegerField
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        # The column type is fixed; ``connection`` is not consulted.
        return 'int4range'
class BigIntegerRangeField(RangeField):
    """Range field with BigIntegerField bounds, stored as ``int8range``."""
    base_field = models.BigIntegerField
    range_type = NumericRange
    # Uses the same form field as IntegerRangeField.
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        # The column type is fixed; ``connection`` is not consulted.
        return 'int8range'
class DecimalRangeField(RangeField):
    """Range field with DecimalField bounds, stored as ``numrange``."""
    base_field = models.DecimalField
    range_type = NumericRange
    form_field = forms.DecimalRangeField

    def db_type(self, connection):
        # The column type is fixed; ``connection`` is not consulted.
        return 'numrange'
class FloatRangeField(RangeField):
    """
    Range field with FloatField bounds, stored as ``numrange``.

    Marked as deprecated via ``system_check_deprecated_details``; the hint
    points to DecimalRangeField as the replacement.
    """
    # Deprecation details (message, hint and check id) for this field.
    system_check_deprecated_details = {
        'msg': (
            'FloatRangeField is deprecated and will be removed in Django 3.1.'
        ),
        'hint': 'Use DecimalRangeField instead.',
        'id': 'fields.W902',
    }
    base_field = models.FloatField
    range_type = NumericRange
    form_field = forms.FloatRangeField

    def db_type(self, connection):
        # Same column type as DecimalRangeField; ``connection`` is not consulted.
        return 'numrange'
class DateTimeRangeField(RangeField):
    """Range field with DateTimeField bounds, stored as ``tstzrange``."""
    base_field = models.DateTimeField
    range_type = DateTimeTZRange
    form_field = forms.DateTimeRangeField

    def db_type(self, connection):
        # The column type is fixed; ``connection`` is not consulted.
        return 'tstzrange'
class DateRangeField(RangeField):
    """Range field with DateField bounds, stored as ``daterange``."""
    base_field = models.DateField
    range_type = DateRange
    form_field = forms.DateRangeField

    def db_type(self, connection):
        # The column type is fixed; ``connection`` is not consulted.
        return 'daterange'
# Register the generic lookups from ``lookups`` on every RangeField.
RangeField.register_lookup(lookups.DataContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)
  136. class DateTimeRangeContains(lookups.PostgresSimpleLookup):
  137. """
  138. Lookup for Date/DateTimeRange containment to cast the rhs to the correct
  139. type.
  140. """
  141. lookup_name = 'contains'
  142. operator = RangeOperators.CONTAINS
  143. def process_rhs(self, compiler, connection):
  144. # Transform rhs value for db lookup.
  145. if isinstance(self.rhs, datetime.date):
  146. output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
  147. value = models.Value(self.rhs, output_field=output_field)
  148. self.rhs = value.resolve_expression(compiler.query)
  149. return super().process_rhs(compiler, connection)
  150. def as_sql(self, compiler, connection):
  151. sql, params = super().as_sql(compiler, connection)
  152. # Cast the rhs if needed.
  153. cast_sql = ''
  154. if (
  155. isinstance(self.rhs, models.Expression) and
  156. self.rhs._output_field_or_none and
  157. # Skip cast if rhs has a matching range type.
  158. not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
  159. ):
  160. cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
  161. cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
  162. return '%s%s' % (sql, cast_sql), params
# Use the casting ``contains`` lookup for both date-based range fields.
DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)
  165. class RangeContainedBy(lookups.PostgresSimpleLookup):
  166. lookup_name = 'contained_by'
  167. type_mapping = {
  168. 'integer': 'int4range',
  169. 'bigint': 'int8range',
  170. 'double precision': 'numrange',
  171. 'date': 'daterange',
  172. 'timestamp with time zone': 'tstzrange',
  173. }
  174. operator = RangeOperators.CONTAINED_BY
  175. def process_rhs(self, compiler, connection):
  176. rhs, rhs_params = super().process_rhs(compiler, connection)
  177. cast_type = self.type_mapping[self.lhs.output_field.db_type(connection)]
  178. return '%s::%s' % (rhs, cast_type), rhs_params
  179. def process_lhs(self, compiler, connection):
  180. lhs, lhs_params = super().process_lhs(compiler, connection)
  181. if isinstance(self.lhs.output_field, models.FloatField):
  182. lhs = '%s::numeric' % lhs
  183. return lhs, lhs_params
  184. def get_prep_lookup(self):
  185. return RangeField().get_prep_value(self.rhs)
# Make ``contained_by`` available on the scalar field types that have a
# corresponding range type.
models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.BigIntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    """``fully_lt`` lookup on range fields; uses the ``<<`` operator."""
    lookup_name = 'fully_lt'
    operator = RangeOperators.FULLY_LT
@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    """``fully_gt`` lookup on range fields; uses the ``>>`` operator."""
    lookup_name = 'fully_gt'
    operator = RangeOperators.FULLY_GT
@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    """``not_lt`` lookup on range fields; uses the ``&>`` operator."""
    lookup_name = 'not_lt'
    operator = RangeOperators.NOT_LT
@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    """``not_gt`` lookup on range fields; uses the ``&<`` operator."""
    lookup_name = 'not_gt'
    operator = RangeOperators.NOT_GT
@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    """``adjacent_to`` lookup on range fields; uses the ``-|-`` operator."""
    lookup_name = 'adjacent_to'
    operator = RangeOperators.ADJACENT_TO
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform on range fields, using the ``lower`` function."""
    lookup_name = 'startswith'
    function = 'lower'

    @property
    def output_field(self):
        # The result is typed as the range's bound field, not as the range itself.
        return self.lhs.output_field.base_field
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform on range fields, using the ``upper`` function."""
    lookup_name = 'endswith'
    function = 'upper'

    @property
    def output_field(self):
        # The result is typed as the range's bound field, not as the range itself.
        return self.lhs.output_field.base_field
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform on range fields, using the ``isempty`` function."""
    lookup_name = 'isempty'
    function = 'isempty'
    # The transform's result is typed as a boolean.
    output_field = models.BooleanField()