From a136a5ff4440d72844d366e500b291dbeef301dc Mon Sep 17 00:00:00 2001 From: David Vogt Date: Mon, 10 Jun 2024 16:20:04 +0200 Subject: [PATCH] feat(models): (re)introduce some convenience methods for working with scope trees The Django-MPTT module provided some useful methods that are not available anymore with django-tree-queries. Luckily, it's relatively easy to provide workarounds. Note that they might not have the same performance/efficiency as the MPTT variants, and could possibly be built in a better way. However, let's keep it to the motto "first make it right, then fast, then pretty" --- emeis/core/models.py | 59 ++++++++++++- emeis/core/tests/test_models.py | 149 ++++++++++++++++++++++++++++++++ 2 files changed, 207 insertions(+), 1 deletion(-) diff --git a/emeis/core/models.py b/emeis/core/models.py index 00360be..b82efc6 100644 --- a/emeis/core/models.py +++ b/emeis/core/models.py @@ -1,5 +1,7 @@ +import operator import unicodedata import uuid +from functools import reduce from django.conf import settings from django.contrib.auth.models import AbstractBaseUser, UserManager @@ -158,6 +160,58 @@ def is_authenticated(self): return True +class ScopeQuerySet(TreeQuerySet): + # django-tree-queries sadly does not (yet?) support ancestors query + # for QS - only for single nodes. So we're providing all_descendants() + # and all_ancestors() queryset methods. + + def all_descendants(self, include_self=False): + """Return a QS that contains all descendants of the given QS. + + This is a workaround for django-tree-queries, which currently does + not support this query (it can only do it on single nodes). + + This is in contrast to .descendants(), which can only give the descendants + of one model instance. + """ + descendants_q = reduce( + operator.or_, + [ + models.Q( + pk__in=entry.descendants(include_self=include_self).values("pk") + ) + for entry in self + ], + models.Q(), + ) + return self.model.objects.filter(descendants_q) + + def all_ancestors(self, include_self=False): + """Return a QS that contains all ancestors of the given QS. + + This is a workaround for django-tree-queries, which currently does + not support this query (it can only do it on single nodes). + + This is in contrast to .ancestors(), which can only give the descendants + of one model instance. + """ + + descendants_q = reduce( + operator.or_, + [ + models.Q(pk__in=entry.ancestors(include_self=include_self).values("pk")) + for entry in self + ], + models.Q(), + ) + return self.model.objects.filter(descendants_q) + + def all_roots(self): + return Scope.objects.all().filter( + pk__in=[scope.ancestors(include_self=True).first() for scope in self] + ) + + class Scope(TreeNode, UUIDModel): name = LocalizedCharField(_("scope name"), blank=False, null=False, required=False) @@ -170,7 +224,10 @@ class Scope(TreeNode, UUIDModel): ) is_active = models.BooleanField(default=True) - objects = TreeQuerySet.as_manager(with_tree_fields=True) + objects = ScopeQuerySet.as_manager(with_tree_fields=True) + + def get_root(self): + return self.ancestors(include_self=True).first() def save(self, *args, **kwargs): # django-tree-queries does validation in TreeNode.clean(), which is not diff --git a/emeis/core/tests/test_models.py b/emeis/core/tests/test_models.py index 8e9fb82..d408fc3 100644 --- a/emeis/core/tests/test_models.py +++ b/emeis/core/tests/test_models.py @@ -149,3 +149,152 @@ def test_update_full_name_of_child(db, scope_factory): grandchild.refresh_from_db() assert str(grandchild.full_name) == "r » s » c » g" + + +@pytest.fixture +def simple_tree_structure(db, scope_factory): + # root1 + # - sub1sub1 + # - sub1sub1sub1 + # - sub1sub1sub2 + # - sub1sub2 + # root2 + # - sub2sub1 + # - sub2sub2 + root1 = scope_factory(name="root1") + root2 = scope_factory(name="root2") + sub1sub1 = scope_factory(parent=root1, name="sub1sub1") + sub1sub2 = scope_factory(parent=root1, name="sub1sub2") + sub1sub1sub1 = scope_factory(parent=sub1sub1, name="sub1sub1sub1") + sub1sub1sub2 = scope_factory(parent=sub1sub1, name="sub1sub1sub2") + + sub2sub1 = scope_factory(parent=root2, name="sub2sub1") + sub2sub2 = scope_factory(parent=root2, name="sub2sub2") + return ( + root1, + root2, + sub1sub1, + sub1sub2, + sub1sub1sub1, + sub1sub1sub2, + sub2sub1, + sub2sub2, + ) + + +@pytest.mark.parametrize( + "include_self, expect_count", + [ + (True, 5), + (False, 3), + ], +) +def test_scope_ancestors(db, simple_tree_structure, include_self, expect_count): + ( + root1, + root2, + sub1sub1, + sub1sub2, + sub1sub1sub1, + sub1sub1sub2, + sub2sub1, + sub2sub2, + ) = simple_tree_structure + + qs = Scope.objects.filter(pk__in=[sub2sub2.pk, sub1sub1sub2.pk]) + + ancestors_qs = qs.all_ancestors(include_self=include_self) + # the direct and indirect ancestors must be there + assert root2 in ancestors_qs + assert root1 in ancestors_qs + assert sub1sub1 in ancestors_qs + + if include_self: + assert sub2sub2 in ancestors_qs + assert sub1sub1sub2 in ancestors_qs + else: + assert sub2sub2 not in ancestors_qs + assert sub1sub1sub2 not in ancestors_qs + + # ... and nothing else + assert ancestors_qs.count() == expect_count + + +@pytest.mark.parametrize( + "include_self, expect_count", + [ + (True, 6), + (False, 4), + ], +) +def test_scope_descendants(db, simple_tree_structure, include_self, expect_count): + ( + root1, + root2, + sub1sub1, + sub1sub2, + sub1sub1sub1, + sub1sub1sub2, + sub2sub1, + sub2sub2, + ) = simple_tree_structure + + qs = Scope.objects.filter(pk__in=[sub1sub1.pk, root2.pk]) + + descendants_qs = qs.all_descendants(include_self=include_self) + # the direct and indirect descendants must be there + assert sub1sub1sub1 in descendants_qs + assert sub1sub1sub2 in descendants_qs + assert sub2sub1 in descendants_qs + assert sub2sub2 in descendants_qs + + if include_self: + assert sub1sub1 in descendants_qs + assert root2 in descendants_qs + else: + assert sub1sub1 not in descendants_qs + assert root2 not in descendants_qs + + # ... and nothing else + assert descendants_qs.count() == expect_count + + +def test_get_root(db, simple_tree_structure): + ( + root1, + root2, + sub1sub1, + sub1sub2, + sub1sub1sub1, + sub1sub1sub2, + sub2sub1, + sub2sub2, + ) = simple_tree_structure + + assert sub1sub2.get_root() == root1 + assert sub1sub1.get_root() == root1 + assert sub2sub2.get_root() == root2 + assert sub2sub1.get_root() == root2 + assert sub1sub1sub2.get_root() == root1 + + +def test_all_roots(db, simple_tree_structure): + ( + root1, + root2, + sub1sub1, + sub1sub2, + sub1sub1sub1, + sub1sub1sub2, + sub2sub1, + sub2sub2, + ) = simple_tree_structure + + qs1 = Scope.objects.filter(pk__in=[sub1sub1sub1, sub1sub2]).all_roots() + assert qs1.count() == 1 + assert root1 in qs1 + + qs2 = Scope.objects.filter(pk__in=[sub1sub1sub1, sub2sub2]).all_roots() + assert qs2.count() == 2 + assert root1 in qs2 + assert root2 in qs2