ranges.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import datetime
  2. import json
  3. from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
  4. from django.contrib.postgres import forms, lookups
  5. from django.db import models
  6. from .utils import AttributeSetter
  7. __all__ = [
  8. 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
  9. 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
  10. 'FloatRangeField',
  11. 'RangeBoundary', 'RangeOperators',
  12. ]
  13. class RangeBoundary(models.Expression):
  14. """A class that represents range boundaries."""
  15. def __init__(self, inclusive_lower=True, inclusive_upper=False):
  16. self.lower = '[' if inclusive_lower else '('
  17. self.upper = ']' if inclusive_upper else ')'
  18. def as_sql(self, compiler, connection):
  19. return "'%s%s'" % (self.lower, self.upper), []
  20. class RangeOperators:
  21. # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
  22. EQUAL = '='
  23. NOT_EQUAL = '<>'
  24. CONTAINS = '@>'
  25. CONTAINED_BY = '<@'
  26. OVERLAPS = '&&'
  27. FULLY_LT = '<<'
  28. FULLY_GT = '>>'
  29. NOT_LT = '&>'
  30. NOT_GT = '&<'
  31. ADJACENT_TO = '-|-'
  32. class RangeField(models.Field):
  33. empty_strings_allowed = False
  34. def __init__(self, *args, **kwargs):
  35. # Initializing base_field here ensures that its model matches the model for self.
  36. if hasattr(self, 'base_field'):
  37. self.base_field = self.base_field()
  38. super().__init__(*args, **kwargs)
  39. @property
  40. def model(self):
  41. try:
  42. return self.__dict__['model']
  43. except KeyError:
  44. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  45. @model.setter
  46. def model(self, model):
  47. self.__dict__['model'] = model
  48. self.base_field.model = model
  49. def get_prep_value(self, value):
  50. if value is None:
  51. return None
  52. elif isinstance(value, Range):
  53. return value
  54. elif isinstance(value, (list, tuple)):
  55. return self.range_type(value[0], value[1])
  56. return value
  57. def to_python(self, value):
  58. if isinstance(value, str):
  59. # Assume we're deserializing
  60. vals = json.loads(value)
  61. for end in ('lower', 'upper'):
  62. if end in vals:
  63. vals[end] = self.base_field.to_python(vals[end])
  64. value = self.range_type(**vals)
  65. elif isinstance(value, (list, tuple)):
  66. value = self.range_type(value[0], value[1])
  67. return value
  68. def set_attributes_from_name(self, name):
  69. super().set_attributes_from_name(name)
  70. self.base_field.set_attributes_from_name(name)
  71. def value_to_string(self, obj):
  72. value = self.value_from_object(obj)
  73. if value is None:
  74. return None
  75. if value.isempty:
  76. return json.dumps({"empty": True})
  77. base_field = self.base_field
  78. result = {"bounds": value._bounds}
  79. for end in ('lower', 'upper'):
  80. val = getattr(value, end)
  81. if val is None:
  82. result[end] = None
  83. else:
  84. obj = AttributeSetter(base_field.attname, val)
  85. result[end] = base_field.value_to_string(obj)
  86. return json.dumps(result)
  87. def formfield(self, **kwargs):
  88. kwargs.setdefault('form_class', self.form_field)
  89. return super().formfield(**kwargs)
  90. class IntegerRangeField(RangeField):
  91. base_field = models.IntegerField
  92. range_type = NumericRange
  93. form_field = forms.IntegerRangeField
  94. def db_type(self, connection):
  95. return 'int4range'
  96. class BigIntegerRangeField(RangeField):
  97. base_field = models.BigIntegerField
  98. range_type = NumericRange
  99. form_field = forms.IntegerRangeField
  100. def db_type(self, connection):
  101. return 'int8range'
  102. class DecimalRangeField(RangeField):
  103. base_field = models.DecimalField
  104. range_type = NumericRange
  105. form_field = forms.DecimalRangeField
  106. def db_type(self, connection):
  107. return 'numrange'
  108. class FloatRangeField(RangeField):
  109. system_check_deprecated_details = {
  110. 'msg': (
  111. 'FloatRangeField is deprecated and will be removed in Django 3.1.'
  112. ),
  113. 'hint': 'Use DecimalRangeField instead.',
  114. 'id': 'fields.W902',
  115. }
  116. base_field = models.FloatField
  117. range_type = NumericRange
  118. form_field = forms.FloatRangeField
  119. def db_type(self, connection):
  120. return 'numrange'
  121. class DateTimeRangeField(RangeField):
  122. base_field = models.DateTimeField
  123. range_type = DateTimeTZRange
  124. form_field = forms.DateTimeRangeField
  125. def db_type(self, connection):
  126. return 'tstzrange'
  127. class DateRangeField(RangeField):
  128. base_field = models.DateField
  129. range_type = DateRange
  130. form_field = forms.DateRangeField
  131. def db_type(self, connection):
  132. return 'daterange'
  133. RangeField.register_lookup(lookups.DataContains)
  134. RangeField.register_lookup(lookups.ContainedBy)
  135. RangeField.register_lookup(lookups.Overlap)
  136. class DateTimeRangeContains(lookups.PostgresSimpleLookup):
  137. """
  138. Lookup for Date/DateTimeRange containment to cast the rhs to the correct
  139. type.
  140. """
  141. lookup_name = 'contains'
  142. operator = RangeOperators.CONTAINS
  143. def process_rhs(self, compiler, connection):
  144. # Transform rhs value for db lookup.
  145. if isinstance(self.rhs, datetime.date):
  146. output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
  147. value = models.Value(self.rhs, output_field=output_field)
  148. self.rhs = value.resolve_expression(compiler.query)
  149. return super().process_rhs(compiler, connection)
  150. def as_sql(self, compiler, connection):
  151. sql, params = super().as_sql(compiler, connection)
  152. # Cast the rhs if needed.
  153. cast_sql = ''
  154. if (
  155. isinstance(self.rhs, models.Expression) and
  156. self.rhs._output_field_or_none and
  157. # Skip cast if rhs has a matching range type.
  158. not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
  159. ):
  160. cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
  161. cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
  162. return '%s%s' % (sql, cast_sql), params
  163. DateRangeField.register_lookup(DateTimeRangeContains)
  164. DateTimeRangeField.register_lookup(DateTimeRangeContains)
  165. class RangeContainedBy(lookups.PostgresSimpleLookup):
  166. lookup_name = 'contained_by'
  167. type_mapping = {
  168. 'integer': 'int4range',
  169. 'bigint': 'int8range',
  170. 'double precision': 'numrange',
  171. 'date': 'daterange',
  172. 'timestamp with time zone': 'tstzrange',
  173. }
  174. operator = RangeOperators.CONTAINED_BY
  175. def process_rhs(self, compiler, connection):
  176. rhs, rhs_params = super().process_rhs(compiler, connection)
  177. cast_type = self.type_mapping[self.lhs.output_field.db_type(connection)]
  178. return '%s::%s' % (rhs, cast_type), rhs_params
  179. def process_lhs(self, compiler, connection):
  180. lhs, lhs_params = super().process_lhs(compiler, connection)
  181. if isinstance(self.lhs.output_field, models.FloatField):
  182. lhs = '%s::numeric' % lhs
  183. return lhs, lhs_params
  184. def get_prep_lookup(self):
  185. return RangeField().get_prep_value(self.rhs)
  186. models.DateField.register_lookup(RangeContainedBy)
  187. models.DateTimeField.register_lookup(RangeContainedBy)
  188. models.IntegerField.register_lookup(RangeContainedBy)
  189. models.BigIntegerField.register_lookup(RangeContainedBy)
  190. models.FloatField.register_lookup(RangeContainedBy)
  191. @RangeField.register_lookup
  192. class FullyLessThan(lookups.PostgresSimpleLookup):
  193. lookup_name = 'fully_lt'
  194. operator = RangeOperators.FULLY_LT
  195. @RangeField.register_lookup
  196. class FullGreaterThan(lookups.PostgresSimpleLookup):
  197. lookup_name = 'fully_gt'
  198. operator = RangeOperators.FULLY_GT
  199. @RangeField.register_lookup
  200. class NotLessThan(lookups.PostgresSimpleLookup):
  201. lookup_name = 'not_lt'
  202. operator = RangeOperators.NOT_LT
  203. @RangeField.register_lookup
  204. class NotGreaterThan(lookups.PostgresSimpleLookup):
  205. lookup_name = 'not_gt'
  206. operator = RangeOperators.NOT_GT
  207. @RangeField.register_lookup
  208. class AdjacentToLookup(lookups.PostgresSimpleLookup):
  209. lookup_name = 'adjacent_to'
  210. operator = RangeOperators.ADJACENT_TO
  211. @RangeField.register_lookup
  212. class RangeStartsWith(models.Transform):
  213. lookup_name = 'startswith'
  214. function = 'lower'
  215. @property
  216. def output_field(self):
  217. return self.lhs.output_field.base_field
  218. @RangeField.register_lookup
  219. class RangeEndsWith(models.Transform):
  220. lookup_name = 'endswith'
  221. function = 'upper'
  222. @property
  223. def output_field(self):
  224. return self.lhs.output_field.base_field
  225. @RangeField.register_lookup
  226. class IsEmpty(models.Transform):
  227. lookup_name = 'isempty'
  228. function = 'isempty'
  229. output_field = models.BooleanField()