array.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, IntegerField, Transform
  7. from django.db.models.lookups import Exact, In
  8. from django.utils.translation import gettext_lazy as _
  9. from ..utils import prefix_validation_error
  10. from .mixins import CheckFieldDefaultMixin
  11. from .utils import AttributeSetter
  12. __all__ = ['ArrayField']
  13. class ArrayField(CheckFieldDefaultMixin, Field):
  14. empty_strings_allowed = False
  15. default_error_messages = {
  16. 'item_invalid': _('Item %(nth)s in the array did not validate:'),
  17. 'nested_array_mismatch': _('Nested arrays must have the same length.'),
  18. }
  19. _default_hint = ('list', '[]')
  20. def __init__(self, base_field, size=None, **kwargs):
  21. self.base_field = base_field
  22. self.size = size
  23. if self.size:
  24. self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
  25. # For performance, only add a from_db_value() method if the base field
  26. # implements it.
  27. if hasattr(self.base_field, 'from_db_value'):
  28. self.from_db_value = self._from_db_value
  29. super().__init__(**kwargs)
  30. @property
  31. def model(self):
  32. try:
  33. return self.__dict__['model']
  34. except KeyError:
  35. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  36. @model.setter
  37. def model(self, model):
  38. self.__dict__['model'] = model
  39. self.base_field.model = model
  40. def check(self, **kwargs):
  41. errors = super().check(**kwargs)
  42. if self.base_field.remote_field:
  43. errors.append(
  44. checks.Error(
  45. 'Base field for array cannot be a related field.',
  46. obj=self,
  47. id='postgres.E002'
  48. )
  49. )
  50. else:
  51. # Remove the field name checks as they are not needed here.
  52. base_errors = self.base_field.check()
  53. if base_errors:
  54. messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
  55. errors.append(
  56. checks.Error(
  57. 'Base field for array has errors:\n %s' % messages,
  58. obj=self,
  59. id='postgres.E001'
  60. )
  61. )
  62. return errors
  63. def set_attributes_from_name(self, name):
  64. super().set_attributes_from_name(name)
  65. self.base_field.set_attributes_from_name(name)
  66. @property
  67. def description(self):
  68. return 'Array of %s' % self.base_field.description
  69. def db_type(self, connection):
  70. size = self.size or ''
  71. return '%s[%s]' % (self.base_field.db_type(connection), size)
  72. def cast_db_type(self, connection):
  73. size = self.size or ''
  74. return '%s[%s]' % (self.base_field.cast_db_type(connection), size)
  75. def get_placeholder(self, value, compiler, connection):
  76. return '%s::{}'.format(self.db_type(connection))
  77. def get_db_prep_value(self, value, connection, prepared=False):
  78. if isinstance(value, (list, tuple)):
  79. return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
  80. return value
  81. def deconstruct(self):
  82. name, path, args, kwargs = super().deconstruct()
  83. if path == 'django.contrib.postgres.fields.array.ArrayField':
  84. path = 'django.contrib.postgres.fields.ArrayField'
  85. kwargs.update({
  86. 'base_field': self.base_field.clone(),
  87. 'size': self.size,
  88. })
  89. return name, path, args, kwargs
  90. def to_python(self, value):
  91. if isinstance(value, str):
  92. # Assume we're deserializing
  93. vals = json.loads(value)
  94. value = [self.base_field.to_python(val) for val in vals]
  95. return value
  96. def _from_db_value(self, value, expression, connection):
  97. if value is None:
  98. return value
  99. return [
  100. self.base_field.from_db_value(item, expression, connection)
  101. for item in value
  102. ]
  103. def value_to_string(self, obj):
  104. values = []
  105. vals = self.value_from_object(obj)
  106. base_field = self.base_field
  107. for val in vals:
  108. if val is None:
  109. values.append(None)
  110. else:
  111. obj = AttributeSetter(base_field.attname, val)
  112. values.append(base_field.value_to_string(obj))
  113. return json.dumps(values)
  114. def get_transform(self, name):
  115. transform = super().get_transform(name)
  116. if transform:
  117. return transform
  118. if '_' not in name:
  119. try:
  120. index = int(name)
  121. except ValueError:
  122. pass
  123. else:
  124. index += 1 # postgres uses 1-indexing
  125. return IndexTransformFactory(index, self.base_field)
  126. try:
  127. start, end = name.split('_')
  128. start = int(start) + 1
  129. end = int(end) # don't add one here because postgres slices are weird
  130. except ValueError:
  131. pass
  132. else:
  133. return SliceTransformFactory(start, end)
  134. def validate(self, value, model_instance):
  135. super().validate(value, model_instance)
  136. for index, part in enumerate(value):
  137. try:
  138. self.base_field.validate(part, model_instance)
  139. except exceptions.ValidationError as error:
  140. raise prefix_validation_error(
  141. error,
  142. prefix=self.error_messages['item_invalid'],
  143. code='item_invalid',
  144. params={'nth': index + 1},
  145. )
  146. if isinstance(self.base_field, ArrayField):
  147. if len({len(i) for i in value}) > 1:
  148. raise exceptions.ValidationError(
  149. self.error_messages['nested_array_mismatch'],
  150. code='nested_array_mismatch',
  151. )
  152. def run_validators(self, value):
  153. super().run_validators(value)
  154. for index, part in enumerate(value):
  155. try:
  156. self.base_field.run_validators(part)
  157. except exceptions.ValidationError as error:
  158. raise prefix_validation_error(
  159. error,
  160. prefix=self.error_messages['item_invalid'],
  161. code='item_invalid',
  162. params={'nth': index + 1},
  163. )
  164. def formfield(self, **kwargs):
  165. return super().formfield(**{
  166. 'form_class': SimpleArrayField,
  167. 'base_field': self.base_field.formfield(),
  168. 'max_length': self.size,
  169. **kwargs,
  170. })
  171. class ArrayCastRHSMixin:
  172. def process_rhs(self, compiler, connection):
  173. rhs, rhs_params = super().process_rhs(compiler, connection)
  174. cast_type = self.lhs.output_field.cast_db_type(connection)
  175. return '%s::%s' % (rhs, cast_type), rhs_params
  176. @ArrayField.register_lookup
  177. class ArrayContains(ArrayCastRHSMixin, lookups.DataContains):
  178. pass
  179. @ArrayField.register_lookup
  180. class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy):
  181. pass
  182. @ArrayField.register_lookup
  183. class ArrayExact(ArrayCastRHSMixin, Exact):
  184. pass
  185. @ArrayField.register_lookup
  186. class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap):
  187. pass
  188. @ArrayField.register_lookup
  189. class ArrayLenTransform(Transform):
  190. lookup_name = 'len'
  191. output_field = IntegerField()
  192. def as_sql(self, compiler, connection):
  193. lhs, params = compiler.compile(self.lhs)
  194. # Distinguish NULL and empty arrays
  195. return (
  196. 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
  197. 'coalesce(array_length(%(lhs)s, 1), 0) END'
  198. ) % {'lhs': lhs}, params
  199. @ArrayField.register_lookup
  200. class ArrayInLookup(In):
  201. def get_prep_lookup(self):
  202. values = super().get_prep_lookup()
  203. if hasattr(values, 'resolve_expression'):
  204. return values
  205. # In.process_rhs() expects values to be hashable, so convert lists
  206. # to tuples.
  207. prepared_values = []
  208. for value in values:
  209. if hasattr(value, 'resolve_expression'):
  210. prepared_values.append(value)
  211. else:
  212. prepared_values.append(tuple(value))
  213. return prepared_values
  214. class IndexTransform(Transform):
  215. def __init__(self, index, base_field, *args, **kwargs):
  216. super().__init__(*args, **kwargs)
  217. self.index = index
  218. self.base_field = base_field
  219. def as_sql(self, compiler, connection):
  220. lhs, params = compiler.compile(self.lhs)
  221. return '%s[%%s]' % lhs, params + [self.index]
  222. @property
  223. def output_field(self):
  224. return self.base_field
  225. class IndexTransformFactory:
  226. def __init__(self, index, base_field):
  227. self.index = index
  228. self.base_field = base_field
  229. def __call__(self, *args, **kwargs):
  230. return IndexTransform(self.index, self.base_field, *args, **kwargs)
  231. class SliceTransform(Transform):
  232. def __init__(self, start, end, *args, **kwargs):
  233. super().__init__(*args, **kwargs)
  234. self.start = start
  235. self.end = end
  236. def as_sql(self, compiler, connection):
  237. lhs, params = compiler.compile(self.lhs)
  238. return '%s[%%s:%%s]' % lhs, params + [self.start, self.end]
  239. class SliceTransformFactory:
  240. def __init__(self, start, end):
  241. self.start = start
  242. self.end = end
  243. def __call__(self, *args, **kwargs):
  244. return SliceTransform(self.start, self.end, *args, **kwargs)