Skip to content

Commit

Permalink
Tweak things around to ease usage of RENAME with multiple members.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Jan 27, 2025
1 parent 14af104 commit b02dc91
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 74 deletions.
196 changes: 135 additions & 61 deletions syzygy/operations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import contextmanager

from django.db import connection
from django.db.backends.utils import truncate_name
from django.db.migrations import operations
from django.db.models.fields import NOT_PROVIDED
from django.utils.functional import cached_property
Expand Down Expand Up @@ -363,124 +365,196 @@ def __init__(self, model_name, name, field, stage, preserve_default=True):

class AliasOperationMixin:
@staticmethod
def _create_instead_of_triggers(schema_editor, view_db_name, new_model):
def _create_instead_of_triggers(schema_editor, view_name, model):
"""
SQLite requires INSTEAD OF triggers to be created for the view to
direct DML statements to the referenced table.
"""
quote = schema_editor.quote_name
max_name_length = schema_editor.connection.ops.max_name_length()
opts = model._meta
view = quote(view_name)
table = quote(opts.db_table)
columns = [quote(field.column) for field in opts.local_fields]
pk = quote(opts.pk.column)
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF INSERT ON {view_db_name}\n"
"CREATE TRIGGER {trigger} INSTEAD OF INSERT ON {view}\n"
"BEGIN\n"
"INSERT INTO {new_table}({fields}) VALUES({values});\n"
"INSERT INTO {table}({fields}) VALUES({values});\n"
"END"
).format(
trigger_name=f"{view_db_name}_insert",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
fields=", ".join(
quote(field.column) for field in new_model._meta.local_fields
),
values=", ".join(
f"NEW.{quote(field.column)}"
for field in new_model._meta.local_fields
),
trigger=quote(truncate_name(f"{view_name}_insert", max_name_length)),
view=view,
table=table,
fields=", ".join(columns),
values=", ".join(f"NEW.{column}" for column in columns),
)
)
for field in new_model._meta.local_fields:
for field, column in zip(opts.local_fields, columns):
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF UPDATE OF {column} ON {view_db_name}\n"
"CREATE TRIGGER {trigger} INSTEAD OF UPDATE OF {column} ON {view}\n"
"BEGIN\n"
"UPDATE {new_table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n"
"UPDATE {table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_update_{field.column}",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
column=quote(field.column),
pk=quote(new_model._meta.pk.column),
trigger=quote(
truncate_name(
f"{view_name}_update_{field.name}", max_name_length
)
),
view=view,
table=table,
column=column,
pk=pk,
)
)
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF DELETE ON {view_db_name}\n"
"CREATE TRIGGER {trigger} INSTEAD OF DELETE ON {view}\n"
"BEGIN\n"
"DELETE FROM {new_table} WHERE {pk}=OLD.{pk};\n"
"DELETE FROM {table} WHERE {pk}=OLD.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_delete",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
pk=quote(new_model._meta.pk.column),
trigger=quote(truncate_name(f"{view_name}_delete", max_name_length)),
view=view,
table=table,
pk=pk,
)
)

@staticmethod
def _get_view_name(app_label, alias_name):
return truncate_name(
"%s_%s" % (app_label, alias_name.lower()),
connection.ops.max_name_length(),
)

@classmethod
def create_view(cls, schema_editor, view_db_name, new_model):
def _create_view(cls, schema_editor, model, alias_name):
# XXX: Explicitly use connection to retrieve ops.max_name_length() and
# not schema_editor.connection as Django systematically use the default
# connection (see https://code.djangoproject.com/ticket/13528).
view_name = cls._get_view_name(model._meta.app_label, alias_name)
quote = schema_editor.quote_name
schema_editor.execute(
"CREATE VIEW {} AS SELECT * FROM {}".format(
quote(view_db_name), quote(new_model._meta.db_table)
quote(view_name), quote(model._meta.db_table)
)
)
if schema_editor.connection.vendor == "sqlite":
cls._create_instead_of_triggers(schema_editor, view_db_name, new_model)
cls._create_instead_of_triggers(schema_editor, view_name, model)

@staticmethod
def drop_view(schema_editor, db_table):
schema_editor.execute("DROP VIEW {}".format(schema_editor.quote_name(db_table)))
def _drop_view(schema_editor, view_name: str):
schema_editor.execute(
"DROP VIEW {}".format(schema_editor.quote_name(view_name))
)

@classmethod
def alias_model(cls, schema_editor, model, alias_name):
for many_to_many in model._meta.local_many_to_many:
through = many_to_many.remote_field.through
if through._meta.auto_created:
raise NotImplementedError(
"Aliasing of models containing many-to-many fields "
"without an explicit intermediary model (through) is not "
"implemented."
)
cls._create_view(schema_editor, model, alias_name)

def unalias_model(cls, schema_editor, model, alias_name):
view_name = cls._get_view_name(model._meta.app_label, alias_name)
for many_to_many in model._meta.local_many_to_many:
through = many_to_many.remote_field.through
if through._meta.auto_created:
raise NotImplementedError(
"Un-aliasing of models containing many-to-many fields "
"without an explicit intermediary model (through) is not "
"implemented."
)
cls._drop_view(schema_editor, view_name)


class AliasedRenameModel(AliasOperationMixin, operations.RenameModel):
class AliasModel(AliasOperationMixin, operations.models.ModelOperation):
"""
First stage of alias-based `RenameModel` replacement.
Implements the alias creation for the table meant to be renamed.
"""

stage = Stage.PRE_DEPLOY
alias_name: str

def __init__(self, name: str, alias_name: str):
self.alias_name = alias_name
super().__init__(name)

@cached_property
def alias_name_lower(self):
return self.alias_name.lower()

def state_forwards(self, app_label, state):
# Create an un-managed model to represent the alias.
aliased_model = state.models[app_label, self.name_lower].clone()
aliased_model.name = self.alias_name
aliased_model.options["managed"] = False
state.add_model(aliased_model)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name)
model = to_state.apps.get_model(app_label, self.name)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
if not self.allow_migrate_model(alias, model):
return
old_model = from_state.apps.get_model(app_label, self.old_name)
view_db_name = old_model._meta.db_table
super().database_forwards(app_label, schema_editor, from_state, to_state)
self.create_view(schema_editor, view_db_name, new_model)
self.alias_model(schema_editor, model, self.alias_name)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.old_name_lower)
model = from_state.apps.get_model(app_label, self.name)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
if not self.allow_migrate_model(alias, model):
return
self.drop_view(schema_editor, new_model._meta.db_table)
super().database_backwards(app_label, schema_editor, from_state, to_state)
self.unalias_model(schema_editor, model, self.alias_name)

def describe(self):
return "Rename model %s to %s while creating an alias for %s" % (
self.old_name,
self.new_name,
self.old_name,
)
return f"Alias model {self.name} to {self.alias_name}"

def reduce(self, operation, app_label):
if (
isinstance(operation, UnaliasModel)
and operation.name_lower == self.new_name_lower
isinstance(operation, RenameAliasedModel)
and operation.name_lower == operation.old_name_lower
and operation.new_name_lower == self.alias_name_lower
):
return [operations.RenameModel(self.old_name, self.new_name)]
return [operations.RenameModel(operation.old_name, operation.new_name)]
return super().reduce(operation, app_label)


class UnaliasModel(AliasOperationMixin, operations.models.ModelOperation):
class RenameAliasedModel(AliasOperationMixin, operations.RenameModel):
"""
Second stage of alias-based `RenameModel` replacements.
Implements the table renaming and alias removal.
"""

stage = Stage.POST_DEPLOY

def __init__(self, name, view_db_name):
self.view_db_name = view_db_name
super().__init__(name)
def state_forwards(self, app_label, state):
state.remove_model(app_label, self.new_name)
super().state_forwards(app_label, state)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
to_model = to_state.apps.get_model(app_label, self.new_name)
if not self.allow_migrate_model(schema_editor.connection.alias, to_model):
return
self.drop_view(self.view_db_name)
self.unalias_model(schema_editor, to_model, self.alias_name)
super().database_forwards(app_label, schema_editor, from_state, to_state)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
from_model = from_state.apps.get_model(app_label, self.new_name)
if not self.allow_migrate_model(schema_editor.connection.alias, from_model):
return
self.create_view(schema_editor, self.view_db_name, model)
super().database_backwards(app_label, schema_editor, from_state, to_state)
self.alias_model(schema_editor, from_model, self.alias_name)

def describe(self):
return f"Rename model {self.name} to {self.alias_name}"
27 changes: 14 additions & 13 deletions tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from syzygy.compat import field_db_default_supported
from syzygy.constants import Stage
from syzygy.operations import (
AliasedRenameModel,
AliasModel,
AlterField,
RenameAliasedModel,
RenameField,
RenameModel,
UnaliasModel,
get_post_add_field_operation,
get_pre_add_field_operation,
get_pre_remove_field_operation,
Expand Down Expand Up @@ -496,11 +496,11 @@ def test_serialization(self):
self.assertEqual(operation.stage, Stage.PRE_DEPLOY)


class AliasedRenameModelTests(OperationTestCase):
class AliasModelTests(OperationTestCase):
def test_describe(self):
self.assertEqual(
AliasedRenameModel("OldName", "NewName").describe(),
"Rename model OldName to NewName while creating an alias for OldName",
AliasModel("OldName", "NewName").describe(),
"Alias model OldName to NewName",
)

def _apply_forwards(self):
Expand All @@ -514,7 +514,7 @@ def _apply_forwards(self):
new_model_name = "NewTestModel"
post_state = self.apply_operations(
[
AliasedRenameModel(model_name, new_model_name),
AliasModel(model_name, new_model_name),
],
pre_state,
)
Expand All @@ -524,15 +524,16 @@ def test_database_forwards(self):
(pre_state, _), (post_state, new_model_name) = self._apply_forwards()
pre_model = pre_state.apps.get_model("tests", "testmodel")
pre_obj = pre_model.objects.create(foo=1)
if connection.vendor == "sqlite":
# SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT
# triggers and thus the object has to be refetched.
pre_obj = pre_model.objects.latest("pk")
self.assertEqual(pre_model.objects.get(), pre_obj)
post_model = post_state.apps.get_model("tests", new_model_name)
self.assertEqual(post_model.objects.get().pk, pre_obj.pk)
pre_model.objects.all().delete()
post_obj = post_model.objects.create(foo=2)
# XXX: Does that make the option non-viable on SQLite?
if connection.vendor == "sqlite":
# SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT
# triggers and thus the object has to be re-fetched.
post_obj = post_model.objects.latest("pk")
self.assertEqual(post_model.objects.get(), post_obj)
self.assertEqual(pre_model.objects.get().pk, post_obj.pk)
pre_model.objects.update(foo=3)
Expand All @@ -541,19 +542,19 @@ def test_database_forwards(self):
def test_database_backwards(self):
(pre_state, model_name), (post_state, new_model_name) = self._apply_forwards()
with connection.schema_editor() as schema_editor:
AliasedRenameModel(model_name, new_model_name).database_backwards(
AliasModel(model_name, new_model_name).database_backwards(
"tests", schema_editor, post_state, pre_state
)

def test_elidable(self):
model_name = "TestModel"
new_model_name = "NewTestModel"
operations = [
AliasedRenameModel(
AliasModel(
model_name,
new_model_name,
),
UnaliasModel(new_model_name, "tests_testmodel"),
RenameAliasedModel(model_name, new_model_name),
]
self.assert_optimizes_to(
operations, [migrations.RenameModel(model_name, new_model_name)]
Expand Down

0 comments on commit b02dc91

Please sign in to comment.