Skip to content

Commit

Permalink
Merge branch 'main' into mdb-8-auth-ssl-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jibola authored Jan 23, 2025
2 parents 1ec9ce9 + 37af20e commit 3c5936c
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 14 deletions.
4 changes: 2 additions & 2 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_column_from_expression(self, expr, alias):
Create a column named `alias` from the given expression to hold the
aggregate value.
"""
column_target = expr.output_field.__class__()
column_target = expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
return Col(self.collection_name, column_target)
Expand Down Expand Up @@ -81,7 +81,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
alias = (
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
)
column_target = sub_expr.output_field.__class__()
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
Expand Down
40 changes: 36 additions & 4 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import difflib

from django.core import checks
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform
Expand Down Expand Up @@ -123,7 +126,8 @@ def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name)
field = self.embedded_model._meta.get_field(name)
return KeyTransformFactory(name, field)

def validate(self, value, model_instance):
super().validate(value, model_instance)
Expand All @@ -145,9 +149,36 @@ def formfield(self, **kwargs):


class KeyTransform(Transform):
def __init__(self, key_name, *args, **kwargs):
def __init__(self, key_name, ref_field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)
self.ref_field = ref_field

def get_transform(self, name):
"""
Validate that `name` is either a field of an embedded model or a
lookup on an embedded model's field.
"""
result = None
if isinstance(self.ref_field, EmbeddedModelField):
opts = self.ref_field.embedded_model._meta
new_field = opts.get_field(name)
result = KeyTransformFactory(name, new_field)
else:
if self.ref_field.get_transform(name) is None:
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
if suggested_lookups:
suggested_lookups = " or ".join(suggested_lookups)
suggestion = f", perhaps you meant {suggested_lookups}?"
else:
suggestion = "."
raise FieldDoesNotExist(
f"Unsupported lookup '{name}' for "
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
f"{suggestion}"
)
result = KeyTransformFactory(name, self.ref_field)
return result

def preprocess_lhs(self, compiler, connection):
key_transforms = [self.key_name]
Expand All @@ -165,8 +196,9 @@ def as_mql(self, compiler, connection):


class KeyTransformFactory:
def __init__(self, key_name):
def __init__(self, key_name, ref_field):
self.key_name = key_name
self.ref_field = ref_field

def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, *args, **kwargs)
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
92 changes: 84 additions & 8 deletions tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from django.core.exceptions import ValidationError
import operator

from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.db import models
from django.db.models import ExpressionWrapper, F, Max, Sum
from django.test import SimpleTestCase, TestCase
from django.test.utils import isolate_apps

Expand All @@ -13,6 +16,7 @@
Data,
Holder,
)
from .utils import truncate_ms


class MethodTests(SimpleTestCase):
Expand All @@ -38,10 +42,6 @@ def test_validate(self):


class ModelTests(TestCase):
def truncate_ms(self, value):
"""Truncate microseconds to milliseconds as supported by MongoDB."""
return value.replace(microsecond=(value.microsecond // 1000) * 1000)

def test_save_load(self):
Holder.objects.create(data=Data(integer="5"))
obj = Holder.objects.get()
Expand All @@ -64,12 +64,12 @@ def test_save_load_null(self):
def test_pre_save(self):
"""Field.pre_save() is called on embedded model fields."""
obj = Holder.objects.create(data=Data())
auto_now = self.truncate_ms(obj.data.auto_now)
auto_now_add = self.truncate_ms(obj.data.auto_now_add)
auto_now = truncate_ms(obj.data.auto_now)
auto_now_add = truncate_ms(obj.data.auto_now_add)
self.assertEqual(auto_now, auto_now_add)
# save() updates auto_now but not auto_now_add.
obj.save()
self.assertEqual(self.truncate_ms(obj.data.auto_now_add), auto_now_add)
self.assertEqual(truncate_ms(obj.data.auto_now_add), auto_now_add)
auto_now_two = obj.data.auto_now
self.assertGreater(auto_now_two, obj.data.auto_now_add)
# And again, save() updates auto_now but not auto_now_add.
Expand Down Expand Up @@ -99,13 +99,89 @@ def test_gt(self):
def test_gte(self):
self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:])

def test_order_by_embedded_field(self):
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))

def test_order_and_group_by_embedded_field(self):
# Create and sort test data by `data__integer`.
expected_objs = sorted(
(Holder.objects.create(data=Data(integer=x)) for x in range(6)),
key=lambda x: x.data.integer,
)
# Group by `data__integer + 5` and get the latest `data__auto_now`
# datetime.
qs = (
Holder.objects.annotate(
group=ExpressionWrapper(F("data__integer") + 5, output_field=models.IntegerField()),
)
.values("group")
.annotate(max_auto_now=Max("data__auto_now"))
.order_by("data__integer")
)
# Each unique `data__integer` is correctly grouped and annotated.
self.assertSequenceEqual(
[{**e, "max_auto_now": e["max_auto_now"]} for e in qs],
[
{"group": e.data.integer + 5, "max_auto_now": truncate_ms(e.data.auto_now)}
for e in expected_objs
],
)

def test_order_and_group_by_embedded_field_annotation(self):
# Create repeated `data__integer` values.
[Holder.objects.create(data=Data(integer=x)) for x in range(6)]
# Group by `data__integer` and compute the sum of occurrences.
qs = (
Holder.objects.values("data__integer")
.annotate(sum=Sum("data__integer"))
.order_by("sum")
)
# The sum is twice the integer values since each appears twice.
self.assertQuerySetEqual(qs, [0, 2, 4, 6, 8, 10], operator.itemgetter("sum"))

def test_nested(self):
obj = Book.objects.create(
author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
)
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])


class InvalidLookupTests(SimpleTestCase):
def test_invalid_field(self):
msg = "Author has no field named 'first_name'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Book.objects.filter(author__first_name="Bob")

def test_invalid_field_nested(self):
msg = "Address has no field named 'floor'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Book.objects.filter(author__address__floor="NYC")

def test_invalid_lookup(self):
msg = "Unsupported lookup 'foo' for CharField 'city'."
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Book.objects.filter(author__address__city__foo="NYC")

def test_invalid_lookup_with_suggestions(self):
msg = (
"Unsupported lookup '{lookup}' for CharField 'name', "
"perhaps you meant {suggested_lookups}?"
)
with self.assertRaisesMessage(
FieldDoesNotExist, msg.format(lookup="exactly", suggested_lookups="exact or iexact")
):
Book.objects.filter(author__name__exactly="NYC")
with self.assertRaisesMessage(
FieldDoesNotExist, msg.format(lookup="gti", suggested_lookups="gt or gte")
):
Book.objects.filter(author__name__gti="NYC")
with self.assertRaisesMessage(
FieldDoesNotExist, msg.format(lookup="is_null", suggested_lookups="isnull")
):
Book.objects.filter(author__name__is_null="NYC")


@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
def test_no_relational_fields(self):
Expand Down
3 changes: 3 additions & 0 deletions tests/model_fields_/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def truncate_ms(value):
"""Truncate microseconds to milliseconds as supported by MongoDB."""
return value.replace(microsecond=(value.microsecond // 1000) * 1000)

0 comments on commit 3c5936c

Please sign in to comment.