diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index b6ebb01bd..a89b9a485 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -31,6 +31,8 @@ ) from pymbolic.mapper.dependency import ( DependencyMapper as DependencyMapperBase) +from pymbolic.mapper.equality import ( + EqualityMapper as EqualityMapperBase) from pymbolic.geometric_algebra.mapper import ( CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, @@ -51,6 +53,8 @@ import pytential.symbolic.primitives as prim +# {{{ IdentityMapper + def rec_int_g_arguments(mapper, expr): densities = mapper.rec(expr.densities) kernel_arguments = { @@ -138,6 +142,11 @@ def map_interpolation(self, expr): return type(expr)(expr.from_dd, expr.to_dd, operand) +# }}} + + +# {{{ CombineMapper + class CombineMapper(CombineMapperBase): def map_node_sum(self, expr): return self.rec(expr.operand) @@ -168,6 +177,10 @@ def map_is_shape_class(self, expr): map_error_expression = map_is_shape_class +# }}} + + +# {{{ Collector class Collector(CollectorBase, CombineMapper): def map_ones(self, expr): @@ -186,6 +199,10 @@ def map_int_g(self, expr): class DependencyMapper(DependencyMapperBase, Collector): pass +# }}} + + +# {{{ EvaluationMapper class EvaluationMapper(EvaluationMapperBase): """Unlike :mod:`pymbolic.mapper.evaluation.EvaluationMapper`, this class @@ -249,8 +266,10 @@ def map_common_subexpression(self, expr): expr.prefix, expr.scope) +# }}} + -# {{{ dofdesc tagging +# {{{ dofdesc tagging: LocationTagger, ToTargetTagger class LocationTagger(CSECachingMapperMixin, IdentityMapper): """Used internally by :class:`ToTargetTagger`.""" @@ -655,6 +674,88 @@ def map_int_g(self, expr): # }}} +# {{{ EqualityMapper + +class EqualityMapper(EqualityMapperBase): + def map_ones(self, expr, other) -> bool: + return expr.dofdesc == other.dofdesc + + map_q_weight = map_ones + + def map_node_coordinate_component(self, expr, other) -> bool: + return ( + expr.ambient_axis == other.ambient_axis + and expr.dofdesc == other.dofdesc) + + def map_num_reference_derivative(self, expr, other) -> bool: + return ( + expr.ref_axes == other.ref_axes + and expr.dofdesc == other.dofdesc + and self.rec(expr.operand, other.operand) + ) + + def map_node_sum(self, expr, other) -> bool: + return self.rec(expr.operand, other.operand) + + map_node_max = map_node_sum + map_node_min = map_node_sum + + def map_elementwise_sum(self, expr, other) -> bool: + return ( + expr.dofdesc == other.dofdesc + and self.rec(expr.operand) == other.operand) + + map_elementwise_max = map_elementwise_sum + map_elementwise_min = map_elementwise_sum + + def map_int_g(self, expr, other) -> bool: + import numpy as np + + def as_hashable(kernel_arg_value): + # FIXME: this is here to match the fact that pickled IntGs get + # restored as tuples, not ndarray, so they don't equal anymore + if isinstance(kernel_arg_value, np.ndarray): + return tuple(kernel_arg_value) + return kernel_arg_value + + return ( + expr.qbx_forced_limit == other.qbx_forced_limit + and expr.source == other.source + and expr.target == other.target + and len(expr.kernel_arguments) == len(other.kernel_arguments) + and len(expr.source_kernels) == len(other.source_kernels) + and len(expr.densities) == len(other.densities) + and expr.target_kernel == other.target_kernel + and all(knl == other_knl for knl, other_knl in zip( + expr.source_kernels, other.source_kernels) + ) + and all(d == other_d for d, other_d in zip( + expr.densities, other.densities)) + and all(k == other_k + and self.rec(as_hashable(v), as_hashable(other_v)) + for (k, v), (other_k, other_v) in zip( + sorted(expr.kernel_arguments.items()), + sorted(other.kernel_arguments.items()))) + ) + + def map_interpolation(self, expr, other) -> bool: + return ( + expr.from_dd == other.from_dd + and expr.to_dd == other.to_dd + and self.rec(expr.operand, other.operand)) + + def map_is_shape_class(self, expr, other) -> bool: + return ( + expr.shape is other.shape, + expr.dofdesc == other.dofdesc + ) + + def map_error_expression(self, expr, other) -> bool: + return expr.message == other.message + +# }}} + + # {{{ stringifier def stringify_where(where): @@ -768,13 +869,13 @@ def map_is_shape_class(self, expr, enclosing_prec): return "IsShape[{}]({})".format(stringify_where(expr.dofdesc), expr.shape.__name__) -# }}} - class PrettyStringifyMapper( CSESplittingStringifyMapperMixin, StringifyMapper): pass +# }}} + # {{{ graphviz diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index 82f12786e..25a21e500 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -248,6 +248,10 @@ def array_to_tuple(ary): class Expression(ExpressionBase): + def make_equality_mapper(self): + from pytential.symbolic.mappers import EqualityMapper + return EqualityMapper() + def make_stringifier(self, originating_stringifier=None): from pytential.symbolic.mappers import StringifyMapper return StringifyMapper() diff --git a/requirements.txt b/requirements.txt index 88777f180..42faff8a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,12 +2,12 @@ numpy != 1.22.0 git+https://github.com/inducer/pytools.git#egg=pytools -git+https://github.com/inducer/pymbolic.git#egg=pymbolic +git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic sympy git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy git+https://github.com/inducer/boxtree.git#egg=boxtree git+https://github.com/inducer/arraycontext.git#egg=arraycontext git+https://github.com/inducer/meshmode.git#egg=meshmode diff --git a/test/test_symbolic.py b/test/test_symbolic.py index 63b9edb62..66b1fd829 100644 --- a/test/test_symbolic.py +++ b/test/test_symbolic.py @@ -380,7 +380,7 @@ def test_derivative_binder_expr(): d1, d2 = principal_directions(ambient_dim, dim=dim) expr = (d1 @ d2 + d1 @ d1) / (d2 @ d2) - nruns = 4 + nruns = 1 for i in range(nruns): from pytools import ProcessTimer with ProcessTimer() as pd: