Skip to content

Commit

Permalink
feat(models): (re)introduce some convenience methods for working with…
Browse files Browse the repository at this point in the history
… scope trees

The Django-MPTT module provided some useful methods that are not available anymore
with django-tree-queries. Luckily, it's relatively easy to provide workarounds.

Note that they might not have the same performance/efficiency as the
MPTT variants, and could possibly be built in a better way. However,
let's keep it to the motto "first make it right, then fast, then
pretty"
  • Loading branch information
winged committed Jun 11, 2024
1 parent e0c80f2 commit a136a5f
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 1 deletion.
59 changes: 58 additions & 1 deletion emeis/core/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import operator
import unicodedata
import uuid
from functools import reduce

from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser, UserManager
Expand Down Expand Up @@ -158,6 +160,58 @@ def is_authenticated(self):
return True


class ScopeQuerySet(TreeQuerySet):
# django-tree-queries sadly does not (yet?) support ancestors query
# for QS - only for single nodes. So we're providing all_descendants()
# and all_ancestors() queryset methods.

def all_descendants(self, include_self=False):
"""Return a QS that contains all descendants of the given QS.
This is a workaround for django-tree-queries, which currently does
not support this query (it can only do it on single nodes).
This is in contrast to .descendants(), which can only give the descendants
of one model instance.
"""
descendants_q = reduce(
operator.or_,
[
models.Q(
pk__in=entry.descendants(include_self=include_self).values("pk")
)
for entry in self
],
models.Q(),
)
return self.model.objects.filter(descendants_q)

def all_ancestors(self, include_self=False):
"""Return a QS that contains all ancestors of the given QS.
This is a workaround for django-tree-queries, which currently does
not support this query (it can only do it on single nodes).
This is in contrast to .ancestors(), which can only give the descendants
of one model instance.
"""

descendants_q = reduce(
operator.or_,
[
models.Q(pk__in=entry.ancestors(include_self=include_self).values("pk"))
for entry in self
],
models.Q(),
)
return self.model.objects.filter(descendants_q)

def all_roots(self):
return Scope.objects.all().filter(
pk__in=[scope.ancestors(include_self=True).first() for scope in self]
)


class Scope(TreeNode, UUIDModel):
name = LocalizedCharField(_("scope name"), blank=False, null=False, required=False)

Expand All @@ -170,7 +224,10 @@ class Scope(TreeNode, UUIDModel):
)
is_active = models.BooleanField(default=True)

objects = TreeQuerySet.as_manager(with_tree_fields=True)
objects = ScopeQuerySet.as_manager(with_tree_fields=True)

def get_root(self):
return self.ancestors(include_self=True).first()

def save(self, *args, **kwargs):
# django-tree-queries does validation in TreeNode.clean(), which is not
Expand Down
149 changes: 149 additions & 0 deletions emeis/core/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,152 @@ def test_update_full_name_of_child(db, scope_factory):

grandchild.refresh_from_db()
assert str(grandchild.full_name) == "r » s » c » g"


@pytest.fixture
def simple_tree_structure(db, scope_factory):
# root1
# - sub1sub1
# - sub1sub1sub1
# - sub1sub1sub2
# - sub1sub2
# root2
# - sub2sub1
# - sub2sub2
root1 = scope_factory(name="root1")
root2 = scope_factory(name="root2")
sub1sub1 = scope_factory(parent=root1, name="sub1sub1")
sub1sub2 = scope_factory(parent=root1, name="sub1sub2")
sub1sub1sub1 = scope_factory(parent=sub1sub1, name="sub1sub1sub1")
sub1sub1sub2 = scope_factory(parent=sub1sub1, name="sub1sub1sub2")

sub2sub1 = scope_factory(parent=root2, name="sub2sub1")
sub2sub2 = scope_factory(parent=root2, name="sub2sub2")
return (
root1,
root2,
sub1sub1,
sub1sub2,
sub1sub1sub1,
sub1sub1sub2,
sub2sub1,
sub2sub2,
)


@pytest.mark.parametrize(
"include_self, expect_count",
[
(True, 5),
(False, 3),
],
)
def test_scope_ancestors(db, simple_tree_structure, include_self, expect_count):
(
root1,
root2,
sub1sub1,
sub1sub2,
sub1sub1sub1,
sub1sub1sub2,
sub2sub1,
sub2sub2,
) = simple_tree_structure

qs = Scope.objects.filter(pk__in=[sub2sub2.pk, sub1sub1sub2.pk])

ancestors_qs = qs.all_ancestors(include_self=include_self)
# the direct and indirect ancestors must be there
assert root2 in ancestors_qs
assert root1 in ancestors_qs
assert sub1sub1 in ancestors_qs

if include_self:
assert sub2sub2 in ancestors_qs
assert sub1sub1sub2 in ancestors_qs
else:
assert sub2sub2 not in ancestors_qs
assert sub1sub1sub2 not in ancestors_qs

# ... and nothing else
assert ancestors_qs.count() == expect_count


@pytest.mark.parametrize(
"include_self, expect_count",
[
(True, 6),
(False, 4),
],
)
def test_scope_descendants(db, simple_tree_structure, include_self, expect_count):
(
root1,
root2,
sub1sub1,
sub1sub2,
sub1sub1sub1,
sub1sub1sub2,
sub2sub1,
sub2sub2,
) = simple_tree_structure

qs = Scope.objects.filter(pk__in=[sub1sub1.pk, root2.pk])

descendants_qs = qs.all_descendants(include_self=include_self)
# the direct and indirect descendants must be there
assert sub1sub1sub1 in descendants_qs
assert sub1sub1sub2 in descendants_qs
assert sub2sub1 in descendants_qs
assert sub2sub2 in descendants_qs

if include_self:
assert sub1sub1 in descendants_qs
assert root2 in descendants_qs
else:
assert sub1sub1 not in descendants_qs
assert root2 not in descendants_qs

# ... and nothing else
assert descendants_qs.count() == expect_count


def test_get_root(db, simple_tree_structure):
(
root1,
root2,
sub1sub1,
sub1sub2,
sub1sub1sub1,
sub1sub1sub2,
sub2sub1,
sub2sub2,
) = simple_tree_structure

assert sub1sub2.get_root() == root1
assert sub1sub1.get_root() == root1
assert sub2sub2.get_root() == root2
assert sub2sub1.get_root() == root2
assert sub1sub1sub2.get_root() == root1


def test_all_roots(db, simple_tree_structure):
(
root1,
root2,
sub1sub1,
sub1sub2,
sub1sub1sub1,
sub1sub1sub2,
sub2sub1,
sub2sub2,
) = simple_tree_structure

qs1 = Scope.objects.filter(pk__in=[sub1sub1sub1, sub1sub2]).all_roots()
assert qs1.count() == 1
assert root1 in qs1

qs2 = Scope.objects.filter(pk__in=[sub1sub1sub1, sub2sub2]).all_roots()
assert qs2.count() == 2
assert root1 in qs2
assert root2 in qs2

0 comments on commit a136a5f

Please sign in to comment.