lookups.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. import itertools
  2. import math
  3. from copy import copy
  4. from django.core.exceptions import EmptyResultSet
  5. from django.db.models.expressions import Case, Exists, Func, Value, When
  6. from django.db.models.fields import (
  7. BooleanField, DateTimeField, Field, IntegerField,
  8. )
  9. from django.db.models.query_utils import RegisterLookupMixin
  10. from django.utils.datastructures import OrderedSet
  11. from django.utils.functional import cached_property
  12. class Lookup:
  13. lookup_name = None
  14. prepare_rhs = True
  15. can_use_none_as_rhs = False
  16. def __init__(self, lhs, rhs):
  17. self.lhs, self.rhs = lhs, rhs
  18. self.rhs = self.get_prep_lookup()
  19. if hasattr(self.lhs, 'get_bilateral_transforms'):
  20. bilateral_transforms = self.lhs.get_bilateral_transforms()
  21. else:
  22. bilateral_transforms = []
  23. if bilateral_transforms:
  24. # Warn the user as soon as possible if they are trying to apply
  25. # a bilateral transformation on a nested QuerySet: that won't work.
  26. from django.db.models.sql.query import Query # avoid circular import
  27. if isinstance(rhs, Query):
  28. raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
  29. self.bilateral_transforms = bilateral_transforms
  30. def apply_bilateral_transforms(self, value):
  31. for transform in self.bilateral_transforms:
  32. value = transform(value)
  33. return value
  34. def batch_process_rhs(self, compiler, connection, rhs=None):
  35. if rhs is None:
  36. rhs = self.rhs
  37. if self.bilateral_transforms:
  38. sqls, sqls_params = [], []
  39. for p in rhs:
  40. value = Value(p, output_field=self.lhs.output_field)
  41. value = self.apply_bilateral_transforms(value)
  42. value = value.resolve_expression(compiler.query)
  43. sql, sql_params = compiler.compile(value)
  44. sqls.append(sql)
  45. sqls_params.extend(sql_params)
  46. else:
  47. _, params = self.get_db_prep_lookup(rhs, connection)
  48. sqls, sqls_params = ['%s'] * len(params), params
  49. return sqls, sqls_params
  50. def get_source_expressions(self):
  51. if self.rhs_is_direct_value():
  52. return [self.lhs]
  53. return [self.lhs, self.rhs]
  54. def set_source_expressions(self, new_exprs):
  55. if len(new_exprs) == 1:
  56. self.lhs = new_exprs[0]
  57. else:
  58. self.lhs, self.rhs = new_exprs
  59. def get_prep_lookup(self):
  60. if hasattr(self.rhs, 'resolve_expression'):
  61. return self.rhs
  62. if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
  63. return self.lhs.output_field.get_prep_value(self.rhs)
  64. return self.rhs
  65. def get_db_prep_lookup(self, value, connection):
  66. return ('%s', [value])
  67. def process_lhs(self, compiler, connection, lhs=None):
  68. lhs = lhs or self.lhs
  69. if hasattr(lhs, 'resolve_expression'):
  70. lhs = lhs.resolve_expression(compiler.query)
  71. return compiler.compile(lhs)
  72. def process_rhs(self, compiler, connection):
  73. value = self.rhs
  74. if self.bilateral_transforms:
  75. if self.rhs_is_direct_value():
  76. # Do not call get_db_prep_lookup here as the value will be
  77. # transformed before being used for lookup
  78. value = Value(value, output_field=self.lhs.output_field)
  79. value = self.apply_bilateral_transforms(value)
  80. value = value.resolve_expression(compiler.query)
  81. if hasattr(value, 'as_sql'):
  82. return compiler.compile(value)
  83. else:
  84. return self.get_db_prep_lookup(value, connection)
  85. def rhs_is_direct_value(self):
  86. return not hasattr(self.rhs, 'as_sql')
  87. def relabeled_clone(self, relabels):
  88. new = copy(self)
  89. new.lhs = new.lhs.relabeled_clone(relabels)
  90. if hasattr(new.rhs, 'relabeled_clone'):
  91. new.rhs = new.rhs.relabeled_clone(relabels)
  92. return new
  93. def get_group_by_cols(self, alias=None):
  94. cols = self.lhs.get_group_by_cols()
  95. if hasattr(self.rhs, 'get_group_by_cols'):
  96. cols.extend(self.rhs.get_group_by_cols())
  97. return cols
  98. def as_sql(self, compiler, connection):
  99. raise NotImplementedError
  100. def as_oracle(self, compiler, connection):
  101. # Oracle doesn't allow EXISTS() to be compared to another expression
  102. # unless it's wrapped in a CASE WHEN.
  103. wrapped = False
  104. exprs = []
  105. for expr in (self.lhs, self.rhs):
  106. if isinstance(expr, Exists):
  107. expr = Case(When(expr, then=True), default=False, output_field=BooleanField())
  108. wrapped = True
  109. exprs.append(expr)
  110. lookup = type(self)(*exprs) if wrapped else self
  111. return lookup.as_sql(compiler, connection)
  112. @cached_property
  113. def contains_aggregate(self):
  114. return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
  115. @cached_property
  116. def contains_over_clause(self):
  117. return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
  118. @property
  119. def is_summary(self):
  120. return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
  121. class Transform(RegisterLookupMixin, Func):
  122. """
  123. RegisterLookupMixin() is first so that get_lookup() and get_transform()
  124. first examine self and then check output_field.
  125. """
  126. bilateral = False
  127. arity = 1
  128. @property
  129. def lhs(self):
  130. return self.get_source_expressions()[0]
  131. def get_bilateral_transforms(self):
  132. if hasattr(self.lhs, 'get_bilateral_transforms'):
  133. bilateral_transforms = self.lhs.get_bilateral_transforms()
  134. else:
  135. bilateral_transforms = []
  136. if self.bilateral:
  137. bilateral_transforms.append(self.__class__)
  138. return bilateral_transforms
  139. class BuiltinLookup(Lookup):
  140. def process_lhs(self, compiler, connection, lhs=None):
  141. lhs_sql, params = super().process_lhs(compiler, connection, lhs)
  142. field_internal_type = self.lhs.output_field.get_internal_type()
  143. db_type = self.lhs.output_field.db_type(connection=connection)
  144. lhs_sql = connection.ops.field_cast_sql(
  145. db_type, field_internal_type) % lhs_sql
  146. lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
  147. return lhs_sql, list(params)
  148. def as_sql(self, compiler, connection):
  149. lhs_sql, params = self.process_lhs(compiler, connection)
  150. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  151. params.extend(rhs_params)
  152. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  153. return '%s %s' % (lhs_sql, rhs_sql), params
  154. def get_rhs_op(self, connection, rhs):
  155. return connection.operators[self.lookup_name] % rhs
  156. class FieldGetDbPrepValueMixin:
  157. """
  158. Some lookups require Field.get_db_prep_value() to be called on their
  159. inputs.
  160. """
  161. get_db_prep_lookup_value_is_iterable = False
  162. def get_db_prep_lookup(self, value, connection):
  163. # For relational fields, use the 'target_field' attribute of the
  164. # output_field.
  165. field = getattr(self.lhs.output_field, 'target_field', None)
  166. get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
  167. return (
  168. '%s',
  169. [get_db_prep_value(v, connection, prepared=True) for v in value]
  170. if self.get_db_prep_lookup_value_is_iterable else
  171. [get_db_prep_value(value, connection, prepared=True)]
  172. )
  173. class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
  174. """
  175. Some lookups require Field.get_db_prep_value() to be called on each value
  176. in an iterable.
  177. """
  178. get_db_prep_lookup_value_is_iterable = True
  179. def get_prep_lookup(self):
  180. if hasattr(self.rhs, 'resolve_expression'):
  181. return self.rhs
  182. prepared_values = []
  183. for rhs_value in self.rhs:
  184. if hasattr(rhs_value, 'resolve_expression'):
  185. # An expression will be handled by the database but can coexist
  186. # alongside real values.
  187. pass
  188. elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
  189. rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
  190. prepared_values.append(rhs_value)
  191. return prepared_values
  192. def process_rhs(self, compiler, connection):
  193. if self.rhs_is_direct_value():
  194. # rhs should be an iterable of values. Use batch_process_rhs()
  195. # to prepare/transform those values.
  196. return self.batch_process_rhs(compiler, connection)
  197. else:
  198. return super().process_rhs(compiler, connection)
  199. def resolve_expression_parameter(self, compiler, connection, sql, param):
  200. params = [param]
  201. if hasattr(param, 'resolve_expression'):
  202. param = param.resolve_expression(compiler.query)
  203. if hasattr(param, 'as_sql'):
  204. sql, params = param.as_sql(compiler, connection)
  205. return sql, params
  206. def batch_process_rhs(self, compiler, connection, rhs=None):
  207. pre_processed = super().batch_process_rhs(compiler, connection, rhs)
  208. # The params list may contain expressions which compile to a
  209. # sql/param pair. Zip them to get sql and param pairs that refer to the
  210. # same argument and attempt to replace them with the result of
  211. # compiling the param step.
  212. sql, params = zip(*(
  213. self.resolve_expression_parameter(compiler, connection, sql, param)
  214. for sql, param in zip(*pre_processed)
  215. ))
  216. params = itertools.chain.from_iterable(params)
  217. return sql, tuple(params)
  218. @Field.register_lookup
  219. class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
  220. lookup_name = 'exact'
  221. def process_rhs(self, compiler, connection):
  222. from django.db.models.sql.query import Query
  223. if isinstance(self.rhs, Query):
  224. if self.rhs.has_limit_one():
  225. if not self.rhs.has_select_fields:
  226. self.rhs.clear_select_clause()
  227. self.rhs.add_fields(['pk'])
  228. else:
  229. raise ValueError(
  230. 'The QuerySet value for an exact lookup must be limited to '
  231. 'one result using slicing.'
  232. )
  233. return super().process_rhs(compiler, connection)
  234. @Field.register_lookup
  235. class IExact(BuiltinLookup):
  236. lookup_name = 'iexact'
  237. prepare_rhs = False
  238. def process_rhs(self, qn, connection):
  239. rhs, params = super().process_rhs(qn, connection)
  240. if params:
  241. params[0] = connection.ops.prep_for_iexact_query(params[0])
  242. return rhs, params
  243. @Field.register_lookup
  244. class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
  245. lookup_name = 'gt'
  246. @Field.register_lookup
  247. class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
  248. lookup_name = 'gte'
  249. @Field.register_lookup
  250. class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
  251. lookup_name = 'lt'
  252. @Field.register_lookup
  253. class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
  254. lookup_name = 'lte'
  255. class IntegerFieldFloatRounding:
  256. """
  257. Allow floats to work as query values for IntegerField. Without this, the
  258. decimal portion of the float would always be discarded.
  259. """
  260. def get_prep_lookup(self):
  261. if isinstance(self.rhs, float):
  262. self.rhs = math.ceil(self.rhs)
  263. return super().get_prep_lookup()
  264. @IntegerField.register_lookup
  265. class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
  266. pass
  267. @IntegerField.register_lookup
  268. class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
  269. pass
  270. @Field.register_lookup
  271. class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
  272. lookup_name = 'in'
  273. def process_rhs(self, compiler, connection):
  274. db_rhs = getattr(self.rhs, '_db', None)
  275. if db_rhs is not None and db_rhs != connection.alias:
  276. raise ValueError(
  277. "Subqueries aren't allowed across different databases. Force "
  278. "the inner query to be evaluated using `list(inner_query)`."
  279. )
  280. if self.rhs_is_direct_value():
  281. try:
  282. rhs = OrderedSet(self.rhs)
  283. except TypeError: # Unhashable items in self.rhs
  284. rhs = self.rhs
  285. if not rhs:
  286. raise EmptyResultSet
  287. # rhs should be an iterable; use batch_process_rhs() to
  288. # prepare/transform those values.
  289. sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
  290. placeholder = '(' + ', '.join(sqls) + ')'
  291. return (placeholder, sqls_params)
  292. else:
  293. if not getattr(self.rhs, 'has_select_fields', True):
  294. self.rhs.clear_select_clause()
  295. self.rhs.add_fields(['pk'])
  296. return super().process_rhs(compiler, connection)
  297. def get_rhs_op(self, connection, rhs):
  298. return 'IN %s' % rhs
  299. def as_sql(self, compiler, connection):
  300. max_in_list_size = connection.ops.max_in_list_size()
  301. if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
  302. return self.split_parameter_list_as_sql(compiler, connection)
  303. return super().as_sql(compiler, connection)
  304. def split_parameter_list_as_sql(self, compiler, connection):
  305. # This is a special case for databases which limit the number of
  306. # elements which can appear in an 'IN' clause.
  307. max_in_list_size = connection.ops.max_in_list_size()
  308. lhs, lhs_params = self.process_lhs(compiler, connection)
  309. rhs, rhs_params = self.batch_process_rhs(compiler, connection)
  310. in_clause_elements = ['(']
  311. params = []
  312. for offset in range(0, len(rhs_params), max_in_list_size):
  313. if offset > 0:
  314. in_clause_elements.append(' OR ')
  315. in_clause_elements.append('%s IN (' % lhs)
  316. params.extend(lhs_params)
  317. sqls = rhs[offset: offset + max_in_list_size]
  318. sqls_params = rhs_params[offset: offset + max_in_list_size]
  319. param_group = ', '.join(sqls)
  320. in_clause_elements.append(param_group)
  321. in_clause_elements.append(')')
  322. params.extend(sqls_params)
  323. in_clause_elements.append(')')
  324. return ''.join(in_clause_elements), params
  325. class PatternLookup(BuiltinLookup):
  326. param_pattern = '%%%s%%'
  327. prepare_rhs = False
  328. def get_rhs_op(self, connection, rhs):
  329. # Assume we are in startswith. We need to produce SQL like:
  330. # col LIKE %s, ['thevalue%']
  331. # For python values we can (and should) do that directly in Python,
  332. # but if the value is for example reference to other column, then
  333. # we need to add the % pattern match to the lookup by something like
  334. # col LIKE othercol || '%%'
  335. # So, for Python values we don't need any special pattern, but for
  336. # SQL reference values or SQL transformations we need the correct
  337. # pattern added.
  338. if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
  339. pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
  340. return pattern.format(rhs)
  341. else:
  342. return super().get_rhs_op(connection, rhs)
  343. def process_rhs(self, qn, connection):
  344. rhs, params = super().process_rhs(qn, connection)
  345. if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
  346. params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
  347. return rhs, params
  348. @Field.register_lookup
  349. class Contains(PatternLookup):
  350. lookup_name = 'contains'
  351. @Field.register_lookup
  352. class IContains(Contains):
  353. lookup_name = 'icontains'
  354. @Field.register_lookup
  355. class StartsWith(PatternLookup):
  356. lookup_name = 'startswith'
  357. param_pattern = '%s%%'
  358. @Field.register_lookup
  359. class IStartsWith(StartsWith):
  360. lookup_name = 'istartswith'
  361. @Field.register_lookup
  362. class EndsWith(PatternLookup):
  363. lookup_name = 'endswith'
  364. param_pattern = '%%%s'
  365. @Field.register_lookup
  366. class IEndsWith(EndsWith):
  367. lookup_name = 'iendswith'
  368. @Field.register_lookup
  369. class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
  370. lookup_name = 'range'
  371. def get_rhs_op(self, connection, rhs):
  372. return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
  373. @Field.register_lookup
  374. class IsNull(BuiltinLookup):
  375. lookup_name = 'isnull'
  376. prepare_rhs = False
  377. def as_sql(self, compiler, connection):
  378. sql, params = compiler.compile(self.lhs)
  379. if self.rhs:
  380. return "%s IS NULL" % sql, params
  381. else:
  382. return "%s IS NOT NULL" % sql, params
  383. @Field.register_lookup
  384. class Regex(BuiltinLookup):
  385. lookup_name = 'regex'
  386. prepare_rhs = False
  387. def as_sql(self, compiler, connection):
  388. if self.lookup_name in connection.operators:
  389. return super().as_sql(compiler, connection)
  390. else:
  391. lhs, lhs_params = self.process_lhs(compiler, connection)
  392. rhs, rhs_params = self.process_rhs(compiler, connection)
  393. sql_template = connection.ops.regex_lookup(self.lookup_name)
  394. return sql_template % (lhs, rhs), lhs_params + rhs_params
  395. @Field.register_lookup
  396. class IRegex(Regex):
  397. lookup_name = 'iregex'
  398. class YearLookup(Lookup):
  399. def year_lookup_bounds(self, connection, year):
  400. output_field = self.lhs.lhs.output_field
  401. if isinstance(output_field, DateTimeField):
  402. bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
  403. else:
  404. bounds = connection.ops.year_lookup_bounds_for_date_field(year)
  405. return bounds
  406. def as_sql(self, compiler, connection):
  407. # Avoid the extract operation if the rhs is a direct value to allow
  408. # indexes to be used.
  409. if self.rhs_is_direct_value():
  410. # Skip the extract part by directly using the originating field,
  411. # that is self.lhs.lhs.
  412. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
  413. rhs_sql, _ = self.process_rhs(compiler, connection)
  414. rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
  415. start, finish = self.year_lookup_bounds(connection, self.rhs)
  416. params.extend(self.get_bound_params(start, finish))
  417. return '%s %s' % (lhs_sql, rhs_sql), params
  418. return super().as_sql(compiler, connection)
  419. def get_direct_rhs_sql(self, connection, rhs):
  420. return connection.operators[self.lookup_name] % rhs
  421. def get_bound_params(self, start, finish):
  422. raise NotImplementedError(
  423. 'subclasses of YearLookup must provide a get_bound_params() method'
  424. )
  425. class YearExact(YearLookup, Exact):
  426. def get_direct_rhs_sql(self, connection, rhs):
  427. return 'BETWEEN %s AND %s'
  428. def get_bound_params(self, start, finish):
  429. return (start, finish)
  430. class YearGt(YearLookup, GreaterThan):
  431. def get_bound_params(self, start, finish):
  432. return (finish,)
  433. class YearGte(YearLookup, GreaterThanOrEqual):
  434. def get_bound_params(self, start, finish):
  435. return (start,)
  436. class YearLt(YearLookup, LessThan):
  437. def get_bound_params(self, start, finish):
  438. return (start,)
  439. class YearLte(YearLookup, LessThanOrEqual):
  440. def get_bound_params(self, start, finish):
  441. return (finish,)