Skip to content

Commit

Permalink
Handle creation of m2m field in model creation (#36)
Browse files Browse the repository at this point in the history
* Handle creation of m2m object in model creation

* Create a generic handler for model fields

* Add m2m field to test

* Cover new model fields handler in import tests

* Use generator to check m2m fields

* Update m2m fields tests

* Improve readability of m2m status check
  • Loading branch information
NabilMostafa authored Jun 22, 2022
1 parent 5ee9743 commit 1f250ae
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 37 deletions.
25 changes: 25 additions & 0 deletions tests/migrations/0004_auto_20220526_2254.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 3.2 on 2022-05-26 20:54

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('tests', '0003_similartoadvert'),
]

operations = [
migrations.CreateModel(
name='Publication',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(max_length=30)),
],
),
migrations.AddField(
model_name='advert',
name='publications',
field=models.ManyToManyField(blank=True, null=True, to='tests.Publication'),
),
]
10 changes: 8 additions & 2 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
class SimplePage(Page):
intro = models.TextField()


class Publication(models.Model):
title = models.CharField(max_length=30)


@register_snippet
class Advert(AirtableMixin, models.Model):

STAR_RATINGS = (
(1.0, "1"),
(1.5, "1.5"),
Expand All @@ -35,6 +39,7 @@ class Advert(AirtableMixin, models.Model):
long_description = RichTextField(blank=True, null=True)
points = models.IntegerField(null=True, blank=True)
slug = models.SlugField(max_length=100, unique=True, editable=True)
publications = models.ManyToManyField(Publication, null=True, blank=True)

@classmethod
def map_import_fields(cls):
Expand All @@ -48,11 +53,11 @@ def map_import_fields(cls):
"long_description": "long_description",
"points": "points",
"slug": "slug",
"publications": "publications",
}
return mappings

def get_export_fields(self):

return {
"title": self.title,
"description": self.description,
Expand All @@ -62,6 +67,7 @@ def get_export_fields(self):
"long_description": self.long_description,
"points": self.points,
"slug": self.slug,
"publications": self.publications,
}

class Meta:
Expand Down
30 changes: 29 additions & 1 deletion tests/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,36 @@
from wagtail_airtable.serializers import AirtableSerializer


class AdvertSerializer(AirtableSerializer):
class PublicationsObjectsSerializer(serializers.RelatedField):
"""
Let's assume there's a "bank_name" column in Airtable but it stores a string.
When importing from Airtable you'll need to find a model object based on that name.
That's what this serializer is doing.
Usage:
class YourModelSerializer(AirtableSerializer):
...
bank_name = BankNameSerializer(required=False)
...
"""

def to_internal_value(self, data):
from .models import Publication
publications = []
if data:
for publication in data:
publication_obj, _ = Publication.objects.get_or_create(title=publication["title"])
publications.append(publication_obj)
return publications
return data

def get_queryset(self):
pass


class AdvertSerializer(AirtableSerializer):
slug = serializers.CharField(max_length=100, required=True)
title = serializers.CharField(max_length=255)
external_link = serializers.URLField(required=False)
publications = PublicationsObjectsSerializer(required=False)
31 changes: 31 additions & 0 deletions tests/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,37 @@ def test_get_or_set_cached_records(self):
cached_records["second_cached_entry"] = all_records
self.assertEqual(importer.cached_records, cached_records)

def test_check_field_is_m2m(self):
importer = Importer()

client = MockAirtable()
records = client.get_all()
client.get_all.assert_called()
for i, record in enumerate(records):
for field_name, value, m2m in importer.get_fields_and_m2m_status(Advert, record["fields"]):
if field_name == "publications":
self.assertEqual(m2m, True)
else:
self.assertEqual(m2m, False)

def test_update_m2m_fields(self):
importer = Importer()

client = MockAirtable()
records = client.get_all()
client.get_all.assert_called()
advert = Advert.objects.first()
self.assertEqual(len(advert.publications.all()), 0)

advert_serializer = AdvertSerializer(data=records[0]["fields"])
self.assertEqual(advert_serializer.is_valid(), True)

publications_dict = advert_serializer.validated_data["publications"]

importer.update_model_m2m_fields(advert, "publications", publications_dict)

self.assertEqual(len(advert.publications.all()), 3)

def test_convert_mapped_fields(self):
importer = Importer()
record_fields_dict = self.get_valid_record_fields()
Expand Down
81 changes: 47 additions & 34 deletions wagtail_airtable/management/commands/import_airtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

logger = getLogger(__name__)


DEFAULT_OPTIONS = {
"verbosity": 1,
}
Expand Down Expand Up @@ -116,6 +115,27 @@ def convert_mapped_fields(self, record_fields_dict, mapped_fields_dict) -> dict:
}
return mapped_fields_dict

def update_model_m2m_fields(self, instance, field_name, value) -> None:
m2m_field = getattr(instance, field_name)
for m2m_value in value:
m2m_field.add(m2m_value)

def get_fields_and_m2m_status(self, model, data: dict):
for field_name, value in data.items():
field_type = type(
model._meta.get_field(field_name)
) # ie. django.db.models.fields.CharField
is_m2m = issubclass(
field_type,
(
TaggableManager,
ClusterTaggableManager,
models.ManyToManyField,
),
)

yield field_name, value, is_m2m

def update_object(
self, instance, record_id, serialized_data, is_wagtail_model=False
) -> bool:
Expand All @@ -129,23 +149,11 @@ def update_object(
"\t\t Serializer data was valid. Setting attrs on model..."
)
model = type(instance)

for field_name, value in serialized_data.validated_data.items():
field_type = type(
model._meta.get_field(field_name)
) # ie. django.db.models.fields.CharField
# If this field type is a subclass of a known Wagtail Tag, or a Django m2m field
# We need to loop through all the values and add them to the m2m-style field.
if issubclass(
field_type,
(TaggableManager, ClusterTaggableManager, models.ManyToManyField,),
):
m2m_field = getattr(instance, field_name)
for m2m_value in value:
m2m_field.add(m2m_value)
for field_name, value, is_m2m in self.get_fields_and_m2m_status(model, serialized_data.validated_data):
if is_m2m:
self.update_model_m2m_fields(instance, field_name, value)
else:
setattr(instance, field_name, value)

try:
if instance.revisions.count():
before = instance.revisions.last().content_json
Expand Down Expand Up @@ -232,23 +240,9 @@ def update_object_by_uniq_col_name(

if instance:
# A local model object was found by a unique identifier.
for field_name, value in serialized_data.validated_data.items():
field_type = type(
model._meta.get_field(field_name)
) # ie. django.db.models.fields.CharField
# If this field type is a subclass of a known Wagtail Tag, or a Django m2m field
# We need to loop through all the values and add them to the m2m-style field.
if issubclass(
field_type,
(
TaggableManager,
ClusterTaggableManager,
models.ManyToManyField,
),
):
m2m_field = getattr(instance, field_name)
for m2m_value in value:
m2m_field.add(m2m_value)
for field_name, value, is_m2m in self.get_fields_and_m2m_status(model, serialized_data.validated_data):
if is_m2m:
self.update_model_m2m_fields(instance, field_name, value)
else:
setattr(instance, field_name, value)
# When an object is saved it should NOT push its newly saved data back to Airtable.
Expand Down Expand Up @@ -529,7 +523,15 @@ def run(self):
data_for_new_model = self.get_data_for_new_model(
serialized_data, mapped_import_fields, record_id
)

# extract m2m fields to avoid getting the error
# direct assignment to the forward side of a many-to-many set is prohibited
m2m_fields = {}
temp_data = data_for_new_model.copy()
for field_name, value, is_m2m in self.get_fields_and_m2m_status(model, data_for_new_model):
if is_m2m:
m2m_fields[field_name] = value
temp_data.pop(field_name)
data_for_new_model = temp_data
# If there is no match whatsoever, try to create a new `model` instance.
# Note: this may fail if there isn't enough data in the Airtable record.
try:
Expand Down Expand Up @@ -562,6 +564,17 @@ def run(self):
self.debug_message("\t\t Page created")
else:
new_model.save()
# create m2m relationship objects
if m2m_fields:
try:
for field_name, value in m2m_fields.items():
self.update_model_m2m_fields(new_model, field_name, value)
except Exception as e:
logger.info(
f"Could not create new model m2m relationship. Error: {e}"
)
self.debug_message(f"\tCannot create m2m relationship with the model. Error: {e}")
continue
self.debug_message("\t\t Object created")
import_successful = True
self.created = self.created + 1
Expand Down
10 changes: 10 additions & 0 deletions wagtail_airtable/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ class MockAirtable(mock.Mock):
"long_description": "<p>Lorem ipsum dolor sit amet, consectetur adipisicing elit. Veniam laboriosam consequatur saepe. Repellat itaque dolores neque, impedit reprehenderit eum culpa voluptates harum sapiente nesciunt ratione.</p>",
"points": 95,
"slug": "red-its-new-blue",
"publications": [
{"title": "Record 1 publication 1"},
{"title": "Record 1 publication 2"},
{"title": "Record 1 publication 3"},
]
},
},
{
Expand Down Expand Up @@ -138,6 +143,11 @@ class MockAirtable(mock.Mock):
"long_description": "",
"points": 1,
"slug": "record-4",
"publications": [
{"title": "Record 4 publication 1"},
{"title": "Record 4 publication 2"},
{"title": "Record 4 publication 3"},
]
},
},
]

0 comments on commit 1f250ae

Please sign in to comment.