Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement aliased model rename. #26

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions syzygy/operations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager

from django.db import transaction
from django.db.migrations import operations
from django.db.models.fields import NOT_PROVIDED

Expand Down Expand Up @@ -255,3 +256,130 @@ class AlterField(StagedOperation, operations.AlterField):
"""
Subclass of ``AlterField`` that allows explicitly defining a stage.
"""


class AliasOperationMixin:
@staticmethod
def _create_instead_of_triggers(schema_editor, view_db_name, new_model):
quote = schema_editor.quote_name
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF INSERT ON {view_db_name}\n"
"BEGIN\n"
"INSERT INTO {new_table}({fields}) VALUES({values});\n"
"END"
).format(
trigger_name=f"{view_db_name}_insert",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
fields=", ".join(
quote(field.column) for field in new_model._meta.local_fields
),
values=", ".join(
f"NEW.{quote(field.column)}"
for field in new_model._meta.local_fields
),
)
)
for field in new_model._meta.local_fields:
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF UPDATE OF {column} ON {view_db_name}\n"
"BEGIN\n"
"UPDATE {new_table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_update_{field.column}",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
column=quote(field.column),
pk=quote(new_model._meta.pk.column),
)
)
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF DELETE ON {view_db_name}\n"
"BEGIN\n"
"DELETE FROM {new_table} WHERE {pk}=OLD.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_delete",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
pk=quote(new_model._meta.pk.column),
)
)

@classmethod
def create_view(cls, schema_editor, view_db_name, new_model):
quote = schema_editor.quote_name
schema_editor.execute(
"CREATE VIEW {} AS SELECT * FROM {}".format(
quote(view_db_name), quote(new_model._meta.db_table)
)
)
if schema_editor.connection.vendor == "sqlite":
cls._create_instead_of_triggers(schema_editor, view_db_name, new_model)

@staticmethod
def drop_view(schema_editor, db_table):
schema_editor.execute("DROP VIEW {}".format(schema_editor.quote_name(db_table)))


class AliasedRenameModel(AliasOperationMixin, operations.RenameModel):
stage = Stage.PRE_DEPLOY

def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
return
old_model = from_state.apps.get_model(app_label, self.old_name)
view_db_name = old_model._meta.db_table
with transaction.atomic(alias):
super().database_forwards(app_label, schema_editor, from_state, to_state)
self.create_view(schema_editor, view_db_name, new_model)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.old_name_lower)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
return
with transaction.atomic(alias):
self.drop_view(schema_editor, new_model._meta.db_table)
super().database_backwards(app_label, schema_editor, from_state, to_state)

def describe(self):
return "Rename model %s to %s while creating an alias for %s" % (
self.old_name,
self.new_name,
self.old_name,
)

def reduce(self, operation, app_label):
if (
isinstance(operation, UnaliasModel)
and operation.name_lower == self.new_name_lower
):
return [operations.RenameModel(self.old_name, self.new_name)]
return super().reduce(operation, app_label)


class UnaliasModel(AliasOperationMixin, operations.models.ModelOperation):
stage = Stage.POST_DEPLOY

def __init__(self, name, view_db_name):
self.view_db_name = view_db_name
super().__init__(name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
return
self.drop_view(self.view_db_name)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
return
self.create_view(schema_editor, self.view_db_name, model)
71 changes: 70 additions & 1 deletion tests/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional, Tuple

from django.db import connection, migrations, models
from django.db.migrations.operations import RenameModel
from django.db.migrations.operations.base import Operation
from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.serializer import OperationSerializer
Expand All @@ -10,7 +11,13 @@

from syzygy.autodetector import MigrationAutodetector
from syzygy.constants import Stage
from syzygy.operations import AddField, PostAddField, PreRemoveField
from syzygy.operations import (
AddField,
AliasedRenameModel,
PostAddField,
PreRemoveField,
UnaliasModel,
)
from syzygy.plan import get_operation_stage


Expand Down Expand Up @@ -313,3 +320,65 @@ def test_elidable(self):
migrations.RemoveField(model_name, field_name, field),
]
self.assert_optimizes_to(operations, [operations[-1]])


class AliasedRenameModelTests(OperationTestCase):
def test_describe(self):
self.assertEqual(
AliasedRenameModel("OldName", "NewName").describe(),
"Rename model OldName to NewName while creating an alias for OldName",
)

def _apply_forwards(self):
model_name = "TestModel"
field = models.IntegerField()
pre_state = self.apply_operations(
[
migrations.CreateModel(model_name, [("foo", field)]),
]
)
new_model_name = "NewTestModel"
post_state = self.apply_operations(
[
AliasedRenameModel(model_name, new_model_name),
],
pre_state,
)
return (pre_state, model_name), (post_state, new_model_name)

def test_database_forwards(self):
(pre_state, _), (post_state, new_model_name) = self._apply_forwards()
pre_model = pre_state.apps.get_model("tests", "testmodel")
pre_obj = pre_model.objects.create(foo=1)
if connection.vendor == "sqlite":
# SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT
# triggers and thus the object has to be refetched.
pre_obj = pre_model.objects.latest("pk")
self.assertEqual(pre_model.objects.get(), pre_obj)
post_model = post_state.apps.get_model("tests", new_model_name)
self.assertEqual(post_model.objects.get().pk, pre_obj.pk)
pre_model.objects.all().delete()
post_obj = post_model.objects.create(foo=2)
self.assertEqual(post_model.objects.get(), post_obj)
self.assertEqual(pre_model.objects.get().pk, post_obj.pk)
pre_model.objects.update(foo=3)
self.assertEqual(post_model.objects.get().foo, 3)

def test_database_backwards(self):
(pre_state, model_name), (post_state, new_model_name) = self._apply_forwards()
with connection.schema_editor() as schema_editor:
AliasedRenameModel(model_name, new_model_name).database_backwards(
"tests", schema_editor, post_state, pre_state
)

def test_elidable(self):
model_name = "TestModel"
new_model_name = "NewTestModel"
operations = [
AliasedRenameModel(
model_name,
new_model_name,
),
UnaliasModel(new_model_name, "tests_testmodel"),
]
self.assert_optimizes_to(operations, [RenameModel(model_name, new_model_name)])