
[DNM/RFC] Track dependents without weakrefs #831

Draft · wants to merge 13 commits into base: main
87 changes: 61 additions & 26 deletions dask_expr/_core.py
@@ -10,19 +10,20 @@
 import pandas as pd
 import toolz
 from dask.dataframe.core import is_dataframe_like, is_index_like, is_series_like
+from dask.delayed import Delayed
 from dask.utils import funcname, import_required, is_arraylike
+from toolz.dicttoolz import merge
 
 from dask_expr._util import _BackendData, _tokenize_deterministic
 
 
 def _unpack_collections(o):
     if isinstance(o, Expr):
-        return o
-
-    if hasattr(o, "expr"):
-        return o.expr
-    return o
+        return o, o._name
+    elif hasattr(o, "expr") and not isinstance(o, Delayed):
+        return o.expr, o.expr._name
+    else:
+        return o, None
 
 
 class Expr:
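
A note on the new `_unpack_collections` contract: the helper now returns a `(expr, name)` pair instead of a bare object, and `Delayed` is excluded because it wraps a task graph rather than an `Expr`. Roughly, for the three cases (illustrative, assuming `e` is an `Expr` and `df` is a dask-expr collection):

    # _unpack_collections(e)   -> (e, e._name)
    # _unpack_collections(df)  -> (df.expr, df.expr._name)  # has .expr, not a Delayed
    # _unpack_collections(42)  -> (42, None)                # plain operand, no child edge

`__new__` below uses the `name` half to decide which operands contribute edges to the dependency graph.
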
@@ -38,15 +39,50 @@ def __new__(cls, *args, **kwargs):
         except KeyError:
             operands.append(cls._defaults[parameter])
         assert not kwargs, kwargs
+
+        parsed_operands = []
+        children = set()
+        _subgraphs = []
+        _subgraph_instances = []
+        _graph_instances = {}
+        for o in operands:
+            expr, name = _unpack_collections(o)
+            parsed_operands.append(expr)
+            if name is not None:
+                children.add(name)
+                _subgraphs.append(expr._graph)
+                _subgraph_instances.append(expr._graph_instances)
+                _graph_instances[name] = expr
+
         inst = object.__new__(cls)
-        inst.operands = [_unpack_collections(o) for o in operands]
+        inst.operands = parsed_operands
         _name = inst._name
+
+        # Graph instances is a mapping name -> Expr instance
+        # Graph itself is a mapping of dependencies mapping names to a set of names
+
         if _name in Expr._instances:
-            return Expr._instances[_name]
+            inst = Expr._instances[_name]
+            inst._graph_instances.update(merge(_graph_instances, *_subgraph_instances))
+            inst._graph.update(merge(*_subgraphs))
+            inst._graph[_name].update(children)
+            # Probably a bad idea to have a self ref
+            inst._graph_instances[_name] = inst
+            return inst
 
         Expr._instances[_name] = inst
+        inst._graph_instances = merge(_graph_instances, *_subgraph_instances)
+        inst._graph = merge(*_subgraphs)
+        inst._graph[_name] = children
+        # Probably a bad idea to have a self ref
+        inst._graph_instances[_name] = inst
         return inst
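
Two things happen in `__new__` now: every `Expr` is interned in `Expr._instances` under its tokenized `_name`, and each instance eagerly carries a merged dependency graph (`_graph`: name -> set of dependency names) plus a name-to-instance table (`_graph_instances`). When a name is seen again, the cached instance absorbs any newly discovered edges instead of a second object being built. A minimal standalone sketch of the pattern, using a simplified `Node` class (hypothetical, not dask-expr's API):

    from collections import defaultdict

    class Node:
        _instances = {}  # name -> interned Node

        def __new__(cls, name, *children):
            # Merge the children's graphs: name -> set of dependency names.
            graph = defaultdict(set)
            instances = {}
            for child in children:
                for k, deps in child._graph.items():
                    graph[k] |= deps
                instances.update(child._graph_instances)
            graph[name] |= {c.name for c in children}

            if name in cls._instances:
                # Reuse the interned node; fold newly seen edges in place.
                inst = cls._instances[name]
                for k, deps in graph.items():
                    inst._graph.setdefault(k, set()).update(deps)
                inst._graph_instances.update(instances)
                inst._graph_instances[name] = inst
                return inst

            inst = super().__new__(cls)
            inst.name = name
            inst._graph = dict(graph)
            inst._graph_instances = {**instances, name: inst}
            cls._instances[name] = inst
            return inst

    a = Node("a")
    b = Node("b", a)
    c = Node("c", a, b)
    assert Node("b", a) is b  # interned: one object per name
    assert c._graph == {"a": set(), "b": {"a"}, "c": {"a", "b"}}
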

+    def __hash__(self):
+        raise TypeError(
+            "Expr objects can't be used in sets or dicts or similar, use the _name instead"
+        )
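
Since identity is managed through interning, hashing an `Expr` directly is forbidden to flag accidental set/dict membership; callers key on `_name` instead, e.g. (illustrative only):

    seen = set()
    for dep in expr.dependencies():
        if dep._name in seen:  # names are plain strings, safe to hash
            continue
        seen.add(dep._name)
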

def _tune_down(self):
return None

@@ -150,6 +186,21 @@ def dependencies(self):
         # Dependencies are `Expr` operands only
         return [operand for operand in self.operands if isinstance(operand, Expr)]
 
+    @functools.cached_property
+    def _dependent_graph(self):
+        rv = defaultdict(set)
+        # This should be O(E)
+        for expr, dependencies in self._graph.items():
+            rv[expr]
+            for dep in dependencies:
+                rv[dep].add(expr)
+        for name, exprs in rv.items():
+            rv[name] = {self._graph_instances[e] for e in exprs}
+        return rv
+
+    def dependents(self):
+        return self._dependent_graph
+
     def _task(self, index: int):
         """The task for the i'th partition
 
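`_dependent_graph` inverts the child-pointing `_graph` in a single pass over every edge, so building the dependents mapping is O(E); the bare `rv[expr]` lookup only forces the defaultdict to materialize an empty entry for expressions nothing depends on. A standalone sketch of the same inversion over plain names:

    from collections import defaultdict

    def invert(graph):
        """Turn {name: dependencies} into {name: dependents} in one O(E) pass."""
        dependents = defaultdict(set)
        for name, deps in graph.items():
            dependents[name]  # materialize roots that have no dependents
            for dep in deps:
                dependents[dep].add(name)
        return dependents

    assert invert({"c": {"a", "b"}, "b": {"a"}, "a": set()}) == {
        "a": {"b", "c"},
        "b": {"c"},
        "c": set(),
    }
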
@@ -318,8 +369,8 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
         changed = False
         for operand in expr.operands:
             if isinstance(operand, Expr):
-                # Bandaid for now, waiting for Singleton
-                dependents[operand._name].append(weakref.ref(expr))
+                # # Bandaid for now, waiting for Singleton
+                dependents[operand._name].add(expr)
                 new = operand.simplify_once(
                     dependents=dependents, simplified=simplified
                 )
@@ -340,7 +391,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
     def simplify(self) -> Expr:
         expr = self
         while True:
-            dependents = collect_dependents(expr)
+            dependents = expr.dependents()
             new = expr.simplify_once(dependents=dependents, simplified={})
             if new._name == expr._name:
                 break
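
`simplify` keeps its shape: a fixed-point loop that applies one rewrite pass and stops once the expression name stabilizes; only the source of `dependents` moved from a per-iteration graph walk to the cached `_dependent_graph`. Schematically (simplified, not the real signature):

    def fixed_point(expr, step):
        # Repeat one rewrite pass until the name stops changing.
        while True:
            new = step(expr)
            if new._name == expr._name:
                return new
            expr = new
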
@@ -712,19 +763,3 @@ def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
             or issubclass(operation, Expr)
         ), "`operation` must be`Expr` subclass)"
         return (expr for expr in self.walk() if isinstance(expr, operation))
-
-
-def collect_dependents(expr) -> defaultdict:
-    dependents = defaultdict(list)
-    stack = [expr]
-    seen = set()
-    while stack:
-        node = stack.pop()
-        if node._name in seen:
-            continue
-        seen.add(node._name)
-
-        for dep in node.dependencies():
-            stack.append(dep)
-            dependents[dep._name].append(weakref.ref(node))
-    return dependents
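
For contrast, the deleted helper re-walked the expression on every call and stored `weakref.ref` callables so the side table would not keep otherwise-dead expressions alive; every consumer then had to dereference and None-check, which is the `x()` pattern removed from `_expr.py` below. Old-style consumption, reconstructed from the removed code:

    import weakref

    def live_parents(dependents, name):
        # Each entry is a weakref.ref: call it, and drop entries whose
        # referent has already been garbage collected.
        return [ref() for ref in dependents[name] if ref() is not None]

Tracking dependents directly on interned instances removes that dance, at the cost of the strong references the PR title alludes to.
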
18 changes: 6 additions & 12 deletions dask_expr/_expr.py
@@ -1150,7 +1150,7 @@ def _simplify_up(self, parent, dependents):
         ):
             predicate = None
             if self.frame.ndim == 1 and self.ndim == 2:
-                name = self.frame._meta.name
+                name = self._meta.columns[0]
                 # Avoid Projection since we are already a Series
                 subs = Projection(self, name)
                 predicate = parent.predicate.substitute(subs, self.frame)
@@ -2076,9 +2076,7 @@ def _simplify_up(self, parent, dependents):
             parent, dependents
         ):
             parents = [
-                p().columns
-                for p in dependents[self._name]
-                if p() is not None and not isinstance(p(), Filter)
+                p.columns for p in dependents[self._name] if not isinstance(p, Filter)
             ]
             predicate = None
             if not set(flatten(parents, list)).issubset(set(self.frame.columns)):
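
The pushdown test itself is unchanged: collect the columns every non-`Filter` dependent still needs and only rewrite if they all exist on the input frame. Roughly (sketch; `parents` is the set of dependent `Expr` instances, `frame` stands in for `self.frame`):

    needed = set()
    for p in parents:
        if not isinstance(p, Filter):
            needed.update(p.columns)
    pushdown_safe = needed.issubset(set(frame.columns))
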
@@ -2107,7 +2105,7 @@ def _simplify_up(self, parent, dependents):
             if col in (self.name, "index", self.frame._meta.index.name):
                 return
             if all(
-                isinstance(d(), Projection) and d().operand("columns") == col
+                isinstance(d, Projection) and d.operand("columns") == col
                 for d in dependents[self._name]
             ):
                 return type(self)(self.frame, True, self.name)
@@ -2715,10 +2713,6 @@ class _DelayedExpr(Expr):
     # TODO
     _parameters = ["obj"]
 
-    def __init__(self, obj):
-        self.obj = obj
-        self.operands = [obj]
-
     def __str__(self):
         return f"{type(self).__name__}({str(self.obj)})"

@@ -3451,7 +3445,7 @@ def determine_column_projection(expr, parent, dependents, additional_columns=Non
         column_union = []
     else:
         column_union = parent.columns.copy()
-    parents = [x() for x in dependents[expr._name] if x() is not None]
+    parents = dependents[expr._name]
 
     seen = set()
     for p in parents:
@@ -3511,7 +3505,7 @@ def plain_column_projection(expr, parent, dependents, additional_columns=None):


 def is_filter_pushdown_available(expr, parent, dependents, allow_reduction=True):
-    parents = [x() for x in dependents[expr._name] if x() is not None]
+    parents = dependents[expr._name]
     filters = {e._name for e in parents if isinstance(e, Filter)}
     if len(filters) != 1:
         # Don't push down if not exactly one Filter
@@ -3618,7 +3612,7 @@ def _check_dependents_are_predicates(
             continue
         seen.add(e._name)
 
-        e_dependents = {x()._name for x in dependents[e._name] if x() is not None}
+        e_dependents = {x._name for x in dependents[e._name]}
 
         if not allow_reduction:
             if isinstance(e, Reduction):
1 change: 1 addition & 0 deletions dask_expr/tests/test_collection.py
@@ -2392,6 +2392,7 @@ def test_predicate_pushdown_ndim_change(df, pdf):
     expected = expected[expected[0] > 1]
     assert_eq(result, expected)
     assert isinstance(result.simplify().expr.frame, Filter)
+    result.simplify().pprint()
     assert isinstance(result.simplify().expr, ToFrame)


3 changes: 2 additions & 1 deletion dask_expr/tests/test_shuffle.py
@@ -1,3 +1,4 @@
+import random
 from collections import OrderedDict
 
 import dask
@@ -640,7 +641,7 @@ def test_set_index_sort_values_one_partition(pdf):

 def test_set_index_triggers_calc_when_accessing_divisions(pdf, df):
     divisions_lru.data = OrderedDict()
-    query = df.set_index("x")
+    query = df.fillna(random.randint(1, 100)).set_index("x")
@phofl (Collaborator) commented on Feb 14, 2024:

    This one is funny: the previous expression was cached, so we were never calculating divisions with this PR. Using a random number here avoids this.

     assert len(divisions_lru.data) == 0
     divisions = query.divisions  # noqa: F841
     assert len(divisions_lru.data) == 1
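
The cached-expression effect in the comment above follows directly from interning: identical operand chains tokenize to the same `_name` and resolve to the same instance, so state computed in an earlier test (like divisions) survives into this one; salting the frame with a random fill value yields a fresh name. Illustrative of the effect (assuming the test's `df` fixture and the `random` import):

    q1 = df.set_index("x")
    q2 = df.set_index("x")
    assert q1.expr._name == q2.expr._name  # same inputs -> same interned expression

    q3 = df.fillna(random.randint(1, 100)).set_index("x")
    assert q3.expr._name != q1.expr._name  # the fillna salt changes the token
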