operations.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from django.contrib.postgres.signals import (
  2. get_citext_oids, get_hstore_oids, register_type_handlers,
  3. )
  4. from django.db.migrations import AddIndex, RemoveIndex
  5. from django.db.migrations.operations.base import Operation
  6. from django.db.utils import NotSupportedError
  7. class CreateExtension(Operation):
  8. reversible = True
  9. def __init__(self, name):
  10. self.name = name
  11. def state_forwards(self, app_label, state):
  12. pass
  13. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  14. if schema_editor.connection.vendor != 'postgresql':
  15. return
  16. schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name))
  17. # Clear cached, stale oids.
  18. get_hstore_oids.cache_clear()
  19. get_citext_oids.cache_clear()
  20. # Registering new type handlers cannot be done before the extension is
  21. # installed, otherwise a subsequent data migration would use the same
  22. # connection.
  23. register_type_handlers(schema_editor.connection)
  24. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  25. schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name))
  26. # Clear cached, stale oids.
  27. get_hstore_oids.cache_clear()
  28. get_citext_oids.cache_clear()
  29. def describe(self):
  30. return "Creates extension %s" % self.name
  31. class BtreeGinExtension(CreateExtension):
  32. def __init__(self):
  33. self.name = 'btree_gin'
  34. class BtreeGistExtension(CreateExtension):
  35. def __init__(self):
  36. self.name = 'btree_gist'
  37. class CITextExtension(CreateExtension):
  38. def __init__(self):
  39. self.name = 'citext'
  40. class CryptoExtension(CreateExtension):
  41. def __init__(self):
  42. self.name = 'pgcrypto'
  43. class HStoreExtension(CreateExtension):
  44. def __init__(self):
  45. self.name = 'hstore'
  46. class TrigramExtension(CreateExtension):
  47. def __init__(self):
  48. self.name = 'pg_trgm'
  49. class UnaccentExtension(CreateExtension):
  50. def __init__(self):
  51. self.name = 'unaccent'
  52. class NotInTransactionMixin:
  53. def _ensure_not_in_transaction(self, schema_editor):
  54. if schema_editor.connection.in_atomic_block:
  55. raise NotSupportedError(
  56. 'The %s operation cannot be executed inside a transaction '
  57. '(set atomic = False on the migration).'
  58. % self.__class__.__name__
  59. )
  60. class AddIndexConcurrently(NotInTransactionMixin, AddIndex):
  61. """Create an index using PostgreSQL's CREATE INDEX CONCURRENTLY syntax."""
  62. atomic = False
  63. def describe(self):
  64. return 'Concurrently create index %s on field(s) %s of model %s' % (
  65. self.index.name,
  66. ', '.join(self.index.fields),
  67. self.model_name,
  68. )
  69. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  70. self._ensure_not_in_transaction(schema_editor)
  71. model = to_state.apps.get_model(app_label, self.model_name)
  72. if self.allow_migrate_model(schema_editor.connection.alias, model):
  73. schema_editor.add_index(model, self.index, concurrently=True)
  74. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  75. self._ensure_not_in_transaction(schema_editor)
  76. model = from_state.apps.get_model(app_label, self.model_name)
  77. if self.allow_migrate_model(schema_editor.connection.alias, model):
  78. schema_editor.remove_index(model, self.index, concurrently=True)
  79. class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex):
  80. """Remove an index using PostgreSQL's DROP INDEX CONCURRENTLY syntax."""
  81. atomic = False
  82. def describe(self):
  83. return 'Concurrently remove index %s from %s' % (self.name, self.model_name)
  84. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  85. self._ensure_not_in_transaction(schema_editor)
  86. model = from_state.apps.get_model(app_label, self.model_name)
  87. if self.allow_migrate_model(schema_editor.connection.alias, model):
  88. from_model_state = from_state.models[app_label, self.model_name_lower]
  89. index = from_model_state.get_index_by_name(self.name)
  90. schema_editor.remove_index(model, index, concurrently=True)
  91. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  92. self._ensure_not_in_transaction(schema_editor)
  93. model = to_state.apps.get_model(app_label, self.model_name)
  94. if self.allow_migrate_model(schema_editor.connection.alias, model):
  95. to_model_state = to_state.models[app_label, self.model_name_lower]
  96. index = to_model_state.get_index_by_name(self.name)
  97. schema_editor.add_index(model, index, concurrently=True)