diff --git a/docs/changelog.rst b/docs/changelog.rst index 835a85245..76de9d6cf 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -12,6 +12,10 @@ Development - Fix validate() not being called when inheritance is used in EmbeddedDocument and validate is overriden #2784 - Add support for readPreferenceTags in connection parameters #2644 - Use estimated_documents_count OR documents_count when count is called, based on the query #2529 +- Fix no_dereference context manager which wasn't turning off auto-dereferencing correctly in some cases #2788 +- BREAKING CHANGE: no_dereference context manager no longer returns the class in __enter__ #2788 + as it was useless and making it look like it was returning a different class although it was the same. + Thus, it must be called like `with no_dereference(User):` and no longer `with no_dereference(User) as ...:` Changes in 0.27.0 ================= diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index b9afb60e1..23beb5a30 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -522,7 +522,7 @@ data. To turn off dereferencing of the results of a query use You can also turn off all dereferencing for a fixed period by using the :class:`~mongoengine.context_managers.no_dereference` context manager:: - with no_dereference(Post) as Post: + with no_dereference(Post): post = Post.objects.first() assert(isinstance(post.author, DBRef)) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index eb9c99622..c864f4539 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,3 +1,4 @@ +import threading from contextlib import contextmanager from pymongo.read_concern import ReadConcern @@ -18,6 +19,25 @@ ) +thread_locals = threading.local() +thread_locals.no_dereferencing_class = {} + + +def no_dereferencing_active_for_class(cls): + return cls in thread_locals.no_dereferencing_class + + +def _register_no_dereferencing_for_class(cls): + thread_locals.no_dereferencing_class.setdefault(cls, 0) + thread_locals.no_dereferencing_class[cls] += 1 + + +def _unregister_no_dereferencing_for_class(cls): + thread_locals.no_dereferencing_class[cls] -= 1 + if thread_locals.no_dereferencing_class[cls] == 0: + thread_locals.no_dereferencing_class.pop(cls) + + class switch_db: """switch_db alias context manager. @@ -107,7 +127,7 @@ class no_dereference: Turns off all dereferencing in Documents for the duration of the context manager:: - with no_dereference(Group) as Group: + with no_dereference(Group): Group.objects.find() """ @@ -130,15 +150,17 @@ def __init__(self, cls): def __enter__(self): """Change the objects default and _auto_dereference values.""" + _register_no_dereferencing_for_class(self.cls) + for field in self.deref_fields: self.cls._fields[field]._auto_dereference = False - return self.cls def __exit__(self, t, value, traceback): """Reset the default and _auto_dereference values.""" + _unregister_no_dereferencing_for_class(self.cls) + for field in self.deref_fields: self.cls._fields[field]._auto_dereference = True - return self.cls class no_sub_classes: diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index d33e7b1e3..ad018f468 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -17,6 +17,7 @@ from mongoengine.common import _import_class from mongoengine.connection import get_db from mongoengine.context_managers import ( + no_dereferencing_active_for_class, set_read_write_concern, set_write_concern, switch_db, @@ -51,9 +52,6 @@ class BaseQuerySet: providing :class:`~mongoengine.Document` objects as the results. """ - __dereference = False - _auto_dereference = True - def __init__(self, document, collection): self._document = document self._collection_obj = collection @@ -74,6 +72,9 @@ def __init__(self, document, collection): self._as_pymongo = False self._search_text = None + self.__dereference = False + self.__auto_dereference = True + # If inheritance is allowed, only return instances and instances of # subclasses of the class being used if document._meta.get("allow_inheritance") is True: @@ -795,7 +796,7 @@ def clone(self): return self._clone_into(self.__class__(self._document, self._collection_obj)) def _clone_into(self, new_qs): - """Copy all of the relevant properties of this queryset to + """Copy all the relevant properties of this queryset to a new queryset (which has to be an instance of :class:`~mongoengine.queryset.base.BaseQuerySet`). """ @@ -825,7 +826,6 @@ def _clone_into(self, new_qs): "_empty", "_hint", "_collation", - "_auto_dereference", "_search_text", "_max_time_ms", "_comment", @@ -836,6 +836,8 @@ def _clone_into(self, new_qs): val = getattr(self, prop) setattr(new_qs, prop, copy.copy(val)) + new_qs.__auto_dereference = self._BaseQuerySet__auto_dereference + if self._cursor_obj: new_qs._cursor_obj = self._cursor_obj.clone() @@ -1741,10 +1743,15 @@ def _dereference(self): self.__dereference = _import_class("DeReference")() return self.__dereference + @property + def _auto_dereference(self): + should_deref = not no_dereferencing_active_for_class(self._document) + return should_deref and self.__auto_dereference + def no_dereference(self): """Turn off any dereferencing for the results of this queryset.""" queryset = self.clone() - queryset._auto_dereference = False + queryset.__auto_dereference = False return queryset # Helper Functions diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 8a0486412..7ba18d588 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -9,7 +9,6 @@ from mongoengine.connection import get_db from mongoengine.mongodb_support import ( MONGODB_42, - MONGODB_70, get_mongodb_version, ) from mongoengine.pymongo_support import PYMONGO_VERSION @@ -451,89 +450,29 @@ class Test(Document): # the documents returned might have more keys in that here. query_plan = Test.objects(id=obj.id).exclude("a").explain() assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("inputStage") - .get("stage") - == "IDHACK" + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK" ) query_plan = Test.objects(id=obj.id).only("id").explain() assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("inputStage") - .get("stage") - == "IDHACK" + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK" ) mongo_db = get_mongodb_version() query_plan = Test.objects(a=1).only("a").exclude("id").explain() - if mongo_db < MONGODB_70: - assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("inputStage") - .get("stage") - == "IXSCAN" - ) - else: - assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("queryPlan") - .get("inputStage") - .get("stage") - == "IXSCAN" - ) + assert ( + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN" + ) PROJECTION_STR = "PROJECTION" if mongo_db < MONGODB_42 else "PROJECTION_COVERED" - if mongo_db < MONGODB_70: - assert ( - query_plan.get("queryPlanner").get("winningPlan").get("stage") - == PROJECTION_STR - ) - else: - assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("queryPlan") - .get("stage") - == PROJECTION_STR - ) + assert query_plan["queryPlanner"]["winningPlan"]["stage"] == PROJECTION_STR query_plan = Test.objects(a=1).explain() - if mongo_db < MONGODB_70: - assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("inputStage") - .get("stage") - == "IXSCAN" - ) - else: - assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("queryPlan") - .get("inputStage") - .get("stage") - == "IXSCAN" - ) - - if mongo_db < MONGODB_70: - assert ( - query_plan.get("queryPlanner").get("winningPlan").get("stage") - == "FETCH" - ) - else: - assert ( - query_plan.get("queryPlanner") - .get("winningPlan") - .get("queryPlan") - .get("stage") - == "FETCH" - ) + assert ( + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN" + ) + + assert query_plan.get("queryPlanner").get("winningPlan").get("stage") == "FETCH" def test_index_on_id(self): class BlogPost(Document): diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index daf9a2c18..ac9ded729 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -1,6 +1,7 @@ import unittest import pytest +from bson import DBRef from mongoengine import * from mongoengine.connection import get_db @@ -19,8 +20,6 @@ class TestContextManagers(MongoDBTestCase): def test_set_write_concern(self): - connect("mongoenginetest") - class User(Document): name = StringField() @@ -39,8 +38,6 @@ class User(Document): assert original_write_concern.document == collection.write_concern.document def test_set_read_write_concern(self): - connect("mongoenginetest") - class User(Document): name = StringField() @@ -65,7 +62,6 @@ class User(Document): assert original_write_concern.document == collection.write_concern.document def test_switch_db_context_manager(self): - connect("mongoenginetest") register_connection("testdb-1", "mongoenginetest2") class Group(Document): @@ -89,7 +85,6 @@ class Group(Document): assert 1 == Group.objects.count() def test_switch_collection_context_manager(self): - connect("mongoenginetest") register_connection(alias="testdb-1", db="mongoenginetest2") class Group(Document): @@ -117,7 +112,6 @@ class Group(Document): def test_no_dereference_context_manager_object_id(self): """Ensure that DBRef items in ListFields aren't dereferenced.""" - connect("mongoenginetest") class User(Document): name = StringField() @@ -136,25 +130,57 @@ class Group(Document): user = User.objects.first() Group(ref=user, members=User.objects, generic=user).save() - with no_dereference(Group) as NoDeRefGroup: - assert Group._fields["members"]._auto_dereference - assert not NoDeRefGroup._fields["members"]._auto_dereference + with no_dereference(Group): + assert not Group._fields["members"]._auto_dereference - with no_dereference(Group) as Group: + with no_dereference(Group): group = Group.objects.first() for m in group.members: - assert not isinstance(m, User) - assert not isinstance(group.ref, User) - assert not isinstance(group.generic, User) + assert isinstance(m, DBRef) + assert isinstance(group.ref, DBRef) + assert isinstance(group.generic, dict) + group = Group.objects.first() for m in group.members: assert isinstance(m, User) assert isinstance(group.ref, User) assert isinstance(group.generic, User) - def test_no_dereference_context_manager_dbref(self): + def test_no_dereference_context_manager_nested(self): """Ensure that DBRef items in ListFields aren't dereferenced.""" - connect("mongoenginetest") + + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=False) + + User.drop_collection() + Group.drop_collection() + + for i in range(1, 51): + User(name="user %s" % i).save() + + user = User.objects.first() + Group(ref=user).save() + + with no_dereference(Group): + group = Group.objects.first() + assert isinstance(group.ref, DBRef) + + with no_dereference(Group): + group = Group.objects.first() + assert isinstance(group.ref, DBRef) + + # make sure its still off here + group = Group.objects.first() + assert isinstance(group.ref, DBRef) + + group = Group.objects.first() + assert isinstance(group.ref, User) + + def test_no_dereference_context_manager_dbref(self): + """Ensure that DBRef items in ListFields aren't dereferenced""" class User(Document): name = StringField() @@ -173,16 +199,19 @@ class Group(Document): user = User.objects.first() Group(ref=user, members=User.objects, generic=user).save() - with no_dereference(Group) as NoDeRefGroup: - assert Group._fields["members"]._auto_dereference - assert not NoDeRefGroup._fields["members"]._auto_dereference + with no_dereference(Group): + assert not Group._fields["members"]._auto_dereference - with no_dereference(Group) as Group: - group = Group.objects.first() + with no_dereference(Group): + qs = Group.objects + assert qs._auto_dereference is False + group = qs.first() + assert not group._fields["members"]._auto_dereference assert all(not isinstance(m, User) for m in group.members) assert not isinstance(group.ref, User) assert not isinstance(group.generic, User) + group = Group.objects.first() assert all(isinstance(m, User) for m in group.members) assert isinstance(group.ref, User) assert isinstance(group.generic, User) @@ -265,7 +294,6 @@ def test_query_counter_does_not_swallow_exception(self): raise TypeError() def test_query_counter_temporarily_modifies_profiling_level(self): - connect("mongoenginetest") db = get_db() def _current_profiling_level(): @@ -290,7 +318,6 @@ def _set_profiling_level(lvl): raise def test_query_counter(self): - connect("mongoenginetest") db = get_db() collection = db.query_counter @@ -380,7 +407,6 @@ class B(Document): assert q == 3 def test_query_counter_counts_getmore_queries(self): - connect("mongoenginetest") db = get_db() collection = db.query_counter @@ -397,7 +423,6 @@ def test_query_counter_counts_getmore_queries(self): assert q == 2 # 1st select + 1 getmore def test_query_counter_ignores_particular_queries(self): - connect("mongoenginetest") db = get_db() collection = db.query_counter