From fb3139c3468f4340a9eaf3fd00addf13182c21ef Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 14 Nov 2021 21:33:58 -0600 Subject: [PATCH] add support for pymbolic.EqualityMapper --- pytential/symbolic/mappers.py | 76 ++++++++++++++++++++++++++++++++ pytential/symbolic/primitives.py | 4 ++ requirements.txt | 4 +- test/test_symbolic.py | 7 ++- 4 files changed, 88 insertions(+), 3 deletions(-) diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index d9cb0f78c..424ac46be 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -31,6 +31,8 @@ ) from pymbolic.mapper.dependency import ( DependencyMapper as DependencyMapperBase) +from pymbolic.mapper.equality import ( + EqualityMapper as EqualityMapperBase) from pymbolic.geometric_algebra.mapper import ( CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, @@ -676,6 +678,80 @@ def map_int_g(self, expr): # }}} +# {{{ EqualityMapper + +class EqualityMapper(EqualityMapperBase): + def map_ones(self, expr, other) -> bool: + return expr.dofdesc == other.dofdesc + + map_q_weight = map_ones + + def map_node_coordinate_component(self, expr, other) -> bool: + return ( + expr.ambient_axis == other.ambient_axis + and expr.dofdesc == other.dofdesc) + + def map_num_reference_derivative(self, expr, other) -> bool: + return ( + expr.ref_axes == other.ref_axes + and expr.dofdesc == other.dofdesc + and self.rec(expr.operand, other.operand) + ) + + def map_node_sum(self, expr, other) -> bool: + return self.rec(expr.operand, other.operand) + + map_node_max = map_node_sum + map_node_min = map_node_sum + + def map_elementwise_sum(self, expr, other) -> bool: + return ( + expr.dofdesc == other.dofdesc + and self.rec(expr.operand, other.operand)) + + map_elementwise_max = map_elementwise_sum + map_elementwise_min = map_elementwise_sum + + def map_int_g(self, expr, other) -> bool: + from pytential.symbolic.primitives import hashable_kernel_args + return ( + expr.qbx_forced_limit == other.qbx_forced_limit + and expr.source == other.source + and expr.target == other.target + and len(expr.kernel_arguments) == len(other.kernel_arguments) + and len(expr.source_kernels) == len(other.source_kernels) + and len(expr.densities) == len(other.densities) + and expr.target_kernel == other.target_kernel + and all(knl == other_knl for knl, other_knl in zip( + expr.source_kernels, other.source_kernels) + ) + and all(d == other_d for d, other_d in zip( + expr.densities, other.densities)) + and all(k == other_k + and self.rec(v, other_v) + for (k, v), (other_k, other_v) in zip( + sorted(hashable_kernel_args(expr.kernel_arguments)), + sorted(hashable_kernel_args(other.kernel_arguments)))) + ) + + def map_interpolation(self, expr, other) -> bool: + return ( + expr.from_dd == other.from_dd + and expr.to_dd == other.to_dd + and self.rec(expr.operand, other.operand)) + + def map_is_shape_class(self, expr, other) -> bool: + return ( + expr.shape is other.shape, + expr.dofdesc == other.dofdesc + ) + + def map_error_expression(self, expr, other) -> bool: + return expr.message == other.message + +# }}} + + # {{{ StringifyMapper def stringify_where(where): diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index a4a652c71..8b0958c18 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -248,6 +248,10 @@ def array_to_tuple(ary): class Expression(ExpressionBase): + def make_equality_mapper(self): + from pytential.symbolic.mappers import EqualityMapper + return EqualityMapper() + def make_stringifier(self, originating_stringifier=None): from pytential.symbolic.mappers import StringifyMapper return StringifyMapper() diff --git a/requirements.txt b/requirements.txt index 88777f180..42faff8a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,12 +2,12 @@ numpy != 1.22.0 git+https://github.com/inducer/pytools.git#egg=pytools -git+https://github.com/inducer/pymbolic.git#egg=pymbolic +git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic sympy git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy git+https://github.com/inducer/boxtree.git#egg=boxtree git+https://github.com/inducer/arraycontext.git#egg=arraycontext git+https://github.com/inducer/meshmode.git#egg=meshmode diff --git a/test/test_symbolic.py b/test/test_symbolic.py index 62c71cde0..eae16db8c 100644 --- a/test/test_symbolic.py +++ b/test/test_symbolic.py @@ -380,7 +380,7 @@ def test_derivative_binder_expr(): d1, d2 = principal_directions(ambient_dim, dim=dim) expr = (d1 @ d2 + d1 @ d1) / (d2 @ d2) - nruns = 4 + nruns = 1 for i in range(nruns): from pytools import ProcessTimer with ProcessTimer() as pd: @@ -446,6 +446,8 @@ def is_base_kernel(knl): @pytest.mark.parametrize("op_name", ["dirichlet", "neumann"]) def test_mapper_int_g_term_collector(op_name, k=0): + logging.basicConfig(level=logging.INFO) + ambient_dim = 3 op = _make_operator(ambient_dim, op_name, k) expr = op.operator(op.get_density_var("sigma")) @@ -463,6 +465,9 @@ def test_mapper_int_g_term_collector(op_name, k=0): else: raise ValueError(f"unknown operator name: {op_name}") + print(sym.pretty(expr_only_intgs)) + print(sym.pretty(expected_expr)) + assert expr_only_intgs == expected_expr # }}}