diff --git a/syzygy/operations.py b/syzygy/operations.py index 9479a0b..a30e720 100644 --- a/syzygy/operations.py +++ b/syzygy/operations.py @@ -1,5 +1,7 @@ from contextlib import contextmanager +from django.db import connection +from django.db.backends.utils import truncate_name from django.db.migrations import operations from django.db.models.fields import NOT_PROVIDED from django.utils.functional import cached_property @@ -363,124 +365,196 @@ def __init__(self, model_name, name, field, stage, preserve_default=True): class AliasOperationMixin: @staticmethod - def _create_instead_of_triggers(schema_editor, view_db_name, new_model): + def _create_instead_of_triggers(schema_editor, view_name, model): + """ + SQLite requires INSTEAD OF triggers to be created for the view to + direct DML statements to the referenced table. + """ quote = schema_editor.quote_name + max_name_length = schema_editor.connection.ops.max_name_length() + opts = model._meta + view = quote(view_name) + table = quote(opts.db_table) + columns = [quote(field.column) for field in opts.local_fields] + pk = quote(opts.pk.column) schema_editor.execute( ( - "CREATE TRIGGER {trigger_name} INSTEAD OF INSERT ON {view_db_name}\n" + "CREATE TRIGGER {trigger} INSTEAD OF INSERT ON {view}\n" "BEGIN\n" - "INSERT INTO {new_table}({fields}) VALUES({values});\n" + "INSERT INTO {table}({fields}) VALUES({values});\n" "END" ).format( - trigger_name=f"{view_db_name}_insert", - view_db_name=quote(view_db_name), - new_table=quote(new_model._meta.db_table), - fields=", ".join( - quote(field.column) for field in new_model._meta.local_fields - ), - values=", ".join( - f"NEW.{quote(field.column)}" - for field in new_model._meta.local_fields - ), + trigger=quote(truncate_name(f"{view_name}_insert", max_name_length)), + view=view, + table=table, + fields=", ".join(columns), + values=", ".join(f"NEW.{column}" for column in columns), ) ) - for field in new_model._meta.local_fields: + for field, column in zip(opts.local_fields, columns): schema_editor.execute( ( - "CREATE TRIGGER {trigger_name} INSTEAD OF UPDATE OF {column} ON {view_db_name}\n" + "CREATE TRIGGER {trigger} INSTEAD OF UPDATE OF {column} ON {view}\n" "BEGIN\n" - "UPDATE {new_table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n" + "UPDATE {table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n" "END" ).format( - trigger_name=f"{view_db_name}_update_{field.column}", - view_db_name=quote(view_db_name), - new_table=quote(new_model._meta.db_table), - column=quote(field.column), - pk=quote(new_model._meta.pk.column), + trigger=quote( + truncate_name( + f"{view_name}_update_{field.name}", max_name_length + ) + ), + view=view, + table=table, + column=column, + pk=pk, ) ) schema_editor.execute( ( - "CREATE TRIGGER {trigger_name} INSTEAD OF DELETE ON {view_db_name}\n" + "CREATE TRIGGER {trigger} INSTEAD OF DELETE ON {view}\n" "BEGIN\n" - "DELETE FROM {new_table} WHERE {pk}=OLD.{pk};\n" + "DELETE FROM {table} WHERE {pk}=OLD.{pk};\n" "END" ).format( - trigger_name=f"{view_db_name}_delete", - view_db_name=quote(view_db_name), - new_table=quote(new_model._meta.db_table), - pk=quote(new_model._meta.pk.column), + trigger=quote(truncate_name(f"{view_name}_delete", max_name_length)), + view=view, + table=table, + pk=pk, ) ) + @staticmethod + def _get_view_name(app_label, alias_name): + return truncate_name( + "%s_%s" % (app_label, alias_name.lower()), + connection.ops.max_name_length(), + ) + @classmethod - def create_view(cls, schema_editor, view_db_name, new_model): + def _create_view(cls, schema_editor, model, alias_name): + # XXX: Explicitly use connection to retrieve ops.max_name_length() and + # not schema_editor.connection as Django systematically use the default + # connection (see https://code.djangoproject.com/ticket/13528). + view_name = cls._get_view_name(model._meta.app_label, alias_name) quote = schema_editor.quote_name schema_editor.execute( "CREATE VIEW {} AS SELECT * FROM {}".format( - quote(view_db_name), quote(new_model._meta.db_table) + quote(view_name), quote(model._meta.db_table) ) ) if schema_editor.connection.vendor == "sqlite": - cls._create_instead_of_triggers(schema_editor, view_db_name, new_model) + cls._create_instead_of_triggers(schema_editor, view_name, model) @staticmethod - def drop_view(schema_editor, db_table): - schema_editor.execute("DROP VIEW {}".format(schema_editor.quote_name(db_table))) + def _drop_view(schema_editor, view_name: str): + schema_editor.execute( + "DROP VIEW {}".format(schema_editor.quote_name(view_name)) + ) + + @classmethod + def alias_model(cls, schema_editor, model, alias_name): + for many_to_many in model._meta.local_many_to_many: + through = many_to_many.remote_field.through + if through._meta.auto_created: + raise NotImplementedError( + "Aliasing of models containing many-to-many fields " + "without an explicit intermediary model (through) is not " + "implemented." + ) + cls._create_view(schema_editor, model, alias_name) + + def unalias_model(cls, schema_editor, model, alias_name): + view_name = cls._get_view_name(model._meta.app_label, alias_name) + for many_to_many in model._meta.local_many_to_many: + through = many_to_many.remote_field.through + if through._meta.auto_created: + raise NotImplementedError( + "Un-aliasing of models containing many-to-many fields " + "without an explicit intermediary model (through) is not " + "implemented." + ) + cls._drop_view(schema_editor, view_name) -class AliasedRenameModel(AliasOperationMixin, operations.RenameModel): +class AliasModel(AliasOperationMixin, operations.models.ModelOperation): + """ + First stage of alias-based `RenameModel` replacement. + + Implements the alias creation for the table meant to be renamed. + """ + stage = Stage.PRE_DEPLOY + alias_name: str + + def __init__(self, name: str, alias_name: str): + self.alias_name = alias_name + super().__init__(name) + + @cached_property + def alias_name_lower(self): + return self.alias_name.lower() + + def state_forwards(self, app_label, state): + # Create an un-managed model to represent the alias. + aliased_model = state.models[app_label, self.name_lower].clone() + aliased_model.name = self.alias_name + aliased_model.options["managed"] = False + state.add_model(aliased_model) def database_forwards(self, app_label, schema_editor, from_state, to_state): - new_model = to_state.apps.get_model(app_label, self.new_name) + model = to_state.apps.get_model(app_label, self.name) alias = schema_editor.connection.alias - if not self.allow_migrate_model(alias, new_model): + if not self.allow_migrate_model(alias, model): return - old_model = from_state.apps.get_model(app_label, self.old_name) - view_db_name = old_model._meta.db_table - super().database_forwards(app_label, schema_editor, from_state, to_state) - self.create_view(schema_editor, view_db_name, new_model) + self.alias_model(schema_editor, model, self.alias_name) def database_backwards(self, app_label, schema_editor, from_state, to_state): - new_model = to_state.apps.get_model(app_label, self.old_name_lower) + model = from_state.apps.get_model(app_label, self.name) alias = schema_editor.connection.alias - if not self.allow_migrate_model(alias, new_model): + if not self.allow_migrate_model(alias, model): return - self.drop_view(schema_editor, new_model._meta.db_table) - super().database_backwards(app_label, schema_editor, from_state, to_state) + self.unalias_model(schema_editor, model, self.alias_name) def describe(self): - return "Rename model %s to %s while creating an alias for %s" % ( - self.old_name, - self.new_name, - self.old_name, - ) + return f"Alias model {self.name} to {self.alias_name}" def reduce(self, operation, app_label): if ( - isinstance(operation, UnaliasModel) - and operation.name_lower == self.new_name_lower + isinstance(operation, RenameAliasedModel) + and operation.name_lower == operation.old_name_lower + and operation.new_name_lower == self.alias_name_lower ): - return [operations.RenameModel(self.old_name, self.new_name)] + return [operations.RenameModel(operation.old_name, operation.new_name)] return super().reduce(operation, app_label) -class UnaliasModel(AliasOperationMixin, operations.models.ModelOperation): +class RenameAliasedModel(AliasOperationMixin, operations.RenameModel): + """ + Second stage of alias-based `RenameModel` replacements. + + Implements the table renaming and alias removal. + """ + stage = Stage.POST_DEPLOY - def __init__(self, name, view_db_name): - self.view_db_name = view_db_name - super().__init__(name) + def state_forwards(self, app_label, state): + state.remove_model(app_label, self.new_name) + super().state_forwards(app_label, state) def database_forwards(self, app_label, schema_editor, from_state, to_state): - model = to_state.apps.get_model(app_label, self.name) - if not self.allow_migrate_model(schema_editor.connection.alias, model): + to_model = to_state.apps.get_model(app_label, self.new_name) + if not self.allow_migrate_model(schema_editor.connection.alias, to_model): return - self.drop_view(self.view_db_name) + self.unalias_model(schema_editor, to_model, self.alias_name) + super().database_forwards(app_label, schema_editor, from_state, to_state) def database_backwards(self, app_label, schema_editor, from_state, to_state): - model = to_state.apps.get_model(app_label, self.name) - if not self.allow_migrate_model(schema_editor.connection.alias, model): + from_model = from_state.apps.get_model(app_label, self.new_name) + if not self.allow_migrate_model(schema_editor.connection.alias, from_model): return - self.create_view(schema_editor, self.view_db_name, model) + super().database_backwards(app_label, schema_editor, from_state, to_state) + self.alias_model(schema_editor, from_model, self.alias_name) + + def describe(self): + return f"Rename model {self.name} to {self.alias_name}" diff --git a/tests/test_operations.py b/tests/test_operations.py index 7ecc463..8df668e 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -14,11 +14,11 @@ from syzygy.compat import field_db_default_supported from syzygy.constants import Stage from syzygy.operations import ( - AliasedRenameModel, + AliasModel, AlterField, + RenameAliasedModel, RenameField, RenameModel, - UnaliasModel, get_post_add_field_operation, get_pre_add_field_operation, get_pre_remove_field_operation, @@ -496,11 +496,11 @@ def test_serialization(self): self.assertEqual(operation.stage, Stage.PRE_DEPLOY) -class AliasedRenameModelTests(OperationTestCase): +class AliasModelTests(OperationTestCase): def test_describe(self): self.assertEqual( - AliasedRenameModel("OldName", "NewName").describe(), - "Rename model OldName to NewName while creating an alias for OldName", + AliasModel("OldName", "NewName").describe(), + "Alias model OldName to NewName", ) def _apply_forwards(self): @@ -514,7 +514,7 @@ def _apply_forwards(self): new_model_name = "NewTestModel" post_state = self.apply_operations( [ - AliasedRenameModel(model_name, new_model_name), + AliasModel(model_name, new_model_name), ], pre_state, ) @@ -524,15 +524,16 @@ def test_database_forwards(self): (pre_state, _), (post_state, new_model_name) = self._apply_forwards() pre_model = pre_state.apps.get_model("tests", "testmodel") pre_obj = pre_model.objects.create(foo=1) - if connection.vendor == "sqlite": - # SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT - # triggers and thus the object has to be refetched. - pre_obj = pre_model.objects.latest("pk") self.assertEqual(pre_model.objects.get(), pre_obj) post_model = post_state.apps.get_model("tests", new_model_name) self.assertEqual(post_model.objects.get().pk, pre_obj.pk) pre_model.objects.all().delete() post_obj = post_model.objects.create(foo=2) + # XXX: Does that make the option non-viable on SQLite? + if connection.vendor == "sqlite": + # SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT + # triggers and thus the object has to be re-fetched. + post_obj = post_model.objects.latest("pk") self.assertEqual(post_model.objects.get(), post_obj) self.assertEqual(pre_model.objects.get().pk, post_obj.pk) pre_model.objects.update(foo=3) @@ -541,7 +542,7 @@ def test_database_forwards(self): def test_database_backwards(self): (pre_state, model_name), (post_state, new_model_name) = self._apply_forwards() with connection.schema_editor() as schema_editor: - AliasedRenameModel(model_name, new_model_name).database_backwards( + AliasModel(model_name, new_model_name).database_backwards( "tests", schema_editor, post_state, pre_state ) @@ -549,11 +550,11 @@ def test_elidable(self): model_name = "TestModel" new_model_name = "NewTestModel" operations = [ - AliasedRenameModel( + AliasModel( model_name, new_model_name, ), - UnaliasModel(new_model_name, "tests_testmodel"), + RenameAliasedModel(model_name, new_model_name), ] self.assert_optimizes_to( operations, [migrations.RenameModel(model_name, new_model_name)]