Skip to content

Commit

Permalink
Implement aliased model rename.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Feb 20, 2023
1 parent 6cc95c3 commit 3be3709
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 1 deletion.
128 changes: 128 additions & 0 deletions syzygy/operations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager

from django.db import transaction
from django.db.migrations import operations
from django.db.models.fields import NOT_PROVIDED

Expand Down Expand Up @@ -255,3 +256,130 @@ class AlterField(StagedOperation, operations.AlterField):
"""
Subclass of ``AlterField`` that allows explicitly defining a stage.
"""


class AliasOperationMixin:
@staticmethod
def _create_instead_of_triggers(schema_editor, view_db_name, new_model):
quote = schema_editor.quote_name
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF INSERT ON {view_db_name}\n"
"BEGIN\n"
"INSERT INTO {new_table}({fields}) VALUES({values});\n"
"END"
).format(
trigger_name=f"{view_db_name}_insert",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
fields=", ".join(
quote(field.column) for field in new_model._meta.local_fields
),
values=", ".join(
f"NEW.{quote(field.column)}"
for field in new_model._meta.local_fields
),
)
)
for field in new_model._meta.local_fields:
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF UPDATE OF {column} ON {view_db_name}\n"
"BEGIN\n"
"UPDATE {new_table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_update_{field.column}",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
column=quote(field.column),
pk=quote(new_model._meta.pk.column),
)
)
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF DELETE ON {view_db_name}\n"
"BEGIN\n"
"DELETE FROM {new_table} WHERE {pk}=OLD.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_delete",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
pk=quote(new_model._meta.pk.column),
)
)

@classmethod
def create_view(cls, schema_editor, view_db_name, new_model):
quote = schema_editor.quote_name
schema_editor.execute(
"CREATE VIEW {} AS SELECT * FROM {}".format(
quote(view_db_name), quote(new_model._meta.db_table)
)
)
if schema_editor.connection.vendor == "sqlite":
cls._create_instead_of_triggers(schema_editor, view_db_name, new_model)

@staticmethod
def drop_view(schema_editor, db_table):
schema_editor.execute("DROP VIEW {}".format(schema_editor.quote_name(db_table)))


class AliasedRenameModel(AliasOperationMixin, operations.RenameModel):
stage = Stage.PRE_DEPLOY

def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
return
old_model = from_state.apps.get_model(app_label, self.old_name)
view_db_name = old_model._meta.db_table
with transaction.atomic(alias):
super().database_forwards(app_label, schema_editor, from_state, to_state)
self.create_view(schema_editor, view_db_name, new_model)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.old_name_lower)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
return
with transaction.atomic(alias):
self.drop_view(schema_editor, new_model._meta.db_table)
super().database_backwards(app_label, schema_editor, from_state, to_state)

def describe(self):
return "Rename model %s to %s while creating an alias for %s" % (
self.old_name,
self.new_name,
self.old_name,
)

def reduce(self, operation, app_label):
if (
isinstance(operation, UnaliasModel)
and operation.name_lower == self.new_name_lower
):
return [operations.RenameModel(self.old_name, self.new_name)]
return super().reduce(operation, app_label)


class UnaliasModel(AliasOperationMixin, operations.models.ModelOperation):
stage = Stage.POST_DEPLOY

def __init__(self, name, view_db_name):
self.view_db_name = view_db_name
super().__init__(name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
return
self.drop_view(self.view_db_name)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
return
self.create_view(schema_editor, self.view_db_name, model)
72 changes: 71 additions & 1 deletion tests/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
from typing import List, Optional, Tuple

from django.db import connection, migrations, models
from django.db.migrations.operations import RenameModel
from django.db.migrations.operations.base import Operation
from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.serializer import OperationSerializer
from django.db.migrations.state import ProjectState
from django.db.models.fields import NOT_PROVIDED
from django.db.utils import DatabaseError
from django.test import TestCase

from syzygy.autodetector import MigrationAutodetector
from syzygy.constants import Stage
from syzygy.operations import AddField, PostAddField, PreRemoveField
from syzygy.operations import (
AddField,
AliasedRenameModel,
PostAddField,
PreRemoveField,
UnaliasModel,
)
from syzygy.plan import get_operation_stage


Expand Down Expand Up @@ -313,3 +321,65 @@ def test_elidable(self):
migrations.RemoveField(model_name, field_name, field),
]
self.assert_optimizes_to(operations, [operations[-1]])


class AliasedRenameModelTests(OperationTestCase):
def test_describe(self):
self.assertEqual(
AliasedRenameModel("OldName", "NewName").describe(),
"Rename model OldName to NewName while creating an alias for OldName",
)

def _apply_forwards(self):
model_name = "TestModel"
field = models.IntegerField()
pre_state = self.apply_operations(
[
migrations.CreateModel(model_name, [("foo", field)]),
]
)
new_model_name = "NewTestModel"
post_state = self.apply_operations(
[
AliasedRenameModel(model_name, new_model_name),
],
pre_state,
)
return (pre_state, model_name), (post_state, new_model_name)

def test_database_forwards(self):
(pre_state, _), (post_state, new_model_name) = self._apply_forwards()
pre_model = pre_state.apps.get_model("tests", "testmodel")
pre_obj = pre_model.objects.create(foo=1)
if connection.vendor == "sqlite":
# SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT
# triggers and thus the object has to be refetched.
pre_obj = pre_model.objects.latest("pk")
self.assertEqual(pre_model.objects.get(), pre_obj)
post_model = post_state.apps.get_model("tests", new_model_name)
self.assertEqual(post_model.objects.get().pk, pre_obj.pk)
pre_model.objects.all().delete()
post_obj = post_model.objects.create(foo=2)
self.assertEqual(post_model.objects.get(), post_obj)
self.assertEqual(pre_model.objects.get().pk, post_obj.pk)
pre_model.objects.update(foo=3)
self.assertEqual(post_model.objects.get().foo, 3)

def test_database_backwards(self):
(pre_state, model_name), (post_state, new_model_name) = self._apply_forwards()
with connection.schema_editor() as schema_editor:
AliasedRenameModel(model_name, new_model_name).database_backwards(
"tests", schema_editor, post_state, pre_state
)

def test_elidable(self):
model_name = "TestModel"
new_model_name = "NewTestModel"
operations = [
AliasedRenameModel(
model_name,
new_model_name,
),
UnaliasModel(new_model_name, "tests_testmodel"),
]
self.assert_optimizes_to(operations, [RenameModel(model_name, new_model_name)])

0 comments on commit 3be3709

Please sign in to comment.