diff --git a/syzygy/operations.py b/syzygy/operations.py index 5398485..391df10 100644 --- a/syzygy/operations.py +++ b/syzygy/operations.py @@ -1,5 +1,6 @@ from contextlib import contextmanager +from django.db import transaction from django.db.migrations import operations from django.db.models.fields import NOT_PROVIDED @@ -255,3 +256,130 @@ class AlterField(StagedOperation, operations.AlterField): """ Subclass of ``AlterField`` that allows explicitly defining a stage. """ + + +class AliasOperationMixin: + @staticmethod + def _create_instead_of_triggers(schema_editor, view_db_name, new_model): + quote = schema_editor.quote_name + schema_editor.execute( + ( + "CREATE TRIGGER {trigger_name} INSTEAD OF INSERT ON {view_db_name}\n" + "BEGIN\n" + "INSERT INTO {new_table}({fields}) VALUES({values});\n" + "END" + ).format( + trigger_name=f"{view_db_name}_insert", + view_db_name=quote(view_db_name), + new_table=quote(new_model._meta.db_table), + fields=", ".join( + quote(field.column) for field in new_model._meta.local_fields + ), + values=", ".join( + f"NEW.{quote(field.column)}" + for field in new_model._meta.local_fields + ), + ) + ) + for field in new_model._meta.local_fields: + schema_editor.execute( + ( + "CREATE TRIGGER {trigger_name} INSTEAD OF UPDATE OF {column} ON {view_db_name}\n" + "BEGIN\n" + "UPDATE {new_table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n" + "END" + ).format( + trigger_name=f"{view_db_name}_update_{field.column}", + view_db_name=quote(view_db_name), + new_table=quote(new_model._meta.db_table), + column=quote(field.column), + pk=quote(new_model._meta.pk.column), + ) + ) + schema_editor.execute( + ( + "CREATE TRIGGER {trigger_name} INSTEAD OF DELETE ON {view_db_name}\n" + "BEGIN\n" + "DELETE FROM {new_table} WHERE {pk}=OLD.{pk};\n" + "END" + ).format( + trigger_name=f"{view_db_name}_delete", + view_db_name=quote(view_db_name), + new_table=quote(new_model._meta.db_table), + pk=quote(new_model._meta.pk.column), + ) + ) + + @classmethod + def create_view(cls, schema_editor, view_db_name, new_model): + quote = schema_editor.quote_name + schema_editor.execute( + "CREATE VIEW {} AS SELECT * FROM {}".format( + quote(view_db_name), quote(new_model._meta.db_table) + ) + ) + if schema_editor.connection.vendor == "sqlite": + cls._create_instead_of_triggers(schema_editor, view_db_name, new_model) + + @staticmethod + def drop_view(schema_editor, db_table): + schema_editor.execute("DROP VIEW {}".format(schema_editor.quote_name(db_table))) + + +class AliasedRenameModel(AliasOperationMixin, operations.RenameModel): + stage = Stage.PRE_DEPLOY + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + new_model = to_state.apps.get_model(app_label, self.new_name) + alias = schema_editor.connection.alias + if not self.allow_migrate_model(alias, new_model): + return + old_model = from_state.apps.get_model(app_label, self.old_name) + view_db_name = old_model._meta.db_table + with transaction.atomic(alias): + super().database_forwards(app_label, schema_editor, from_state, to_state) + self.create_view(schema_editor, view_db_name, new_model) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + new_model = to_state.apps.get_model(app_label, self.old_name_lower) + alias = schema_editor.connection.alias + if not self.allow_migrate_model(alias, new_model): + return + with transaction.atomic(alias): + self.drop_view(schema_editor, new_model._meta.db_table) + super().database_backwards(app_label, schema_editor, from_state, to_state) + + def describe(self): + return "Rename model %s to %s while creating an alias for %s" % ( + self.old_name, + self.new_name, + self.old_name, + ) + + def reduce(self, operation, app_label): + if ( + isinstance(operation, UnaliasModel) + and operation.name_lower == self.new_name_lower + ): + return [operations.RenameModel(self.old_name, self.new_name)] + return super().reduce(operation, app_label) + + +class UnaliasModel(AliasOperationMixin, operations.models.ModelOperation): + stage = Stage.POST_DEPLOY + + def __init__(self, name, view_db_name): + self.view_db_name = view_db_name + super().__init__(name) + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + model = to_state.apps.get_model(app_label, self.name) + if not self.allow_migrate_model(schema_editor.connection.alias, model): + return + self.drop_view(self.view_db_name) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + model = to_state.apps.get_model(app_label, self.name) + if not self.allow_migrate_model(schema_editor.connection.alias, model): + return + self.create_view(schema_editor, self.view_db_name, model) diff --git a/tests/test_operations.py b/tests/test_operations.py index 39c46e1..3a4018d 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple from django.db import connection, migrations, models +from django.db.migrations.operations import RenameModel from django.db.migrations.operations.base import Operation from django.db.migrations.optimizer import MigrationOptimizer from django.db.migrations.serializer import OperationSerializer @@ -10,7 +11,13 @@ from syzygy.autodetector import MigrationAutodetector from syzygy.constants import Stage -from syzygy.operations import AddField, PostAddField, PreRemoveField +from syzygy.operations import ( + AddField, + AliasedRenameModel, + PostAddField, + PreRemoveField, + UnaliasModel, +) from syzygy.plan import get_operation_stage @@ -313,3 +320,65 @@ def test_elidable(self): migrations.RemoveField(model_name, field_name, field), ] self.assert_optimizes_to(operations, [operations[-1]]) + + +class AliasedRenameModelTests(OperationTestCase): + def test_describe(self): + self.assertEqual( + AliasedRenameModel("OldName", "NewName").describe(), + "Rename model OldName to NewName while creating an alias for OldName", + ) + + def _apply_forwards(self): + model_name = "TestModel" + field = models.IntegerField() + pre_state = self.apply_operations( + [ + migrations.CreateModel(model_name, [("foo", field)]), + ] + ) + new_model_name = "NewTestModel" + post_state = self.apply_operations( + [ + AliasedRenameModel(model_name, new_model_name), + ], + pre_state, + ) + return (pre_state, model_name), (post_state, new_model_name) + + def test_database_forwards(self): + (pre_state, _), (post_state, new_model_name) = self._apply_forwards() + pre_model = pre_state.apps.get_model("tests", "testmodel") + pre_obj = pre_model.objects.create(foo=1) + if connection.vendor == "sqlite": + # SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT + # triggers and thus the object has to be refetched. + pre_obj = pre_model.objects.latest("pk") + self.assertEqual(pre_model.objects.get(), pre_obj) + post_model = post_state.apps.get_model("tests", new_model_name) + self.assertEqual(post_model.objects.get().pk, pre_obj.pk) + pre_model.objects.all().delete() + post_obj = post_model.objects.create(foo=2) + self.assertEqual(post_model.objects.get(), post_obj) + self.assertEqual(pre_model.objects.get().pk, post_obj.pk) + pre_model.objects.update(foo=3) + self.assertEqual(post_model.objects.get().foo, 3) + + def test_database_backwards(self): + (pre_state, model_name), (post_state, new_model_name) = self._apply_forwards() + with connection.schema_editor() as schema_editor: + AliasedRenameModel(model_name, new_model_name).database_backwards( + "tests", schema_editor, post_state, pre_state + ) + + def test_elidable(self): + model_name = "TestModel" + new_model_name = "NewTestModel" + operations = [ + AliasedRenameModel( + model_name, + new_model_name, + ), + UnaliasModel(new_model_name, "tests_testmodel"), + ] + self.assert_optimizes_to(operations, [RenameModel(model_name, new_model_name)])