Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EqualityMapper to follow pymbolic #148

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from pymbolic.mapper.dependency import (
DependencyMapper as DependencyMapperBase)
from pymbolic.mapper.equality import (
EqualityMapper as EqualityMapperBase)
from pymbolic.geometric_algebra.mapper import (
CombineMapper as CombineMapperBase,
IdentityMapper as IdentityMapperBase,
Expand Down Expand Up @@ -676,6 +678,80 @@ def map_int_g(self, expr):
# }}}


# {{{ EqualityMapper

class EqualityMapper(EqualityMapperBase):
def map_ones(self, expr, other) -> bool:
return expr.dofdesc == other.dofdesc

map_q_weight = map_ones

def map_node_coordinate_component(self, expr, other) -> bool:
return (
expr.ambient_axis == other.ambient_axis
and expr.dofdesc == other.dofdesc)

def map_num_reference_derivative(self, expr, other) -> bool:
return (
expr.ref_axes == other.ref_axes
and expr.dofdesc == other.dofdesc
and self.rec(expr.operand, other.operand)
)

def map_node_sum(self, expr, other) -> bool:
return self.rec(expr.operand, other.operand)

map_node_max = map_node_sum
map_node_min = map_node_sum

def map_elementwise_sum(self, expr, other) -> bool:
return (
expr.dofdesc == other.dofdesc
and self.rec(expr.operand, other.operand))

map_elementwise_max = map_elementwise_sum
map_elementwise_min = map_elementwise_sum

def map_int_g(self, expr, other) -> bool:
from pytential.symbolic.primitives import hashable_kernel_args
return (
expr.qbx_forced_limit == other.qbx_forced_limit
and expr.source == other.source
and expr.target == other.target
and len(expr.kernel_arguments) == len(other.kernel_arguments)
and len(expr.source_kernels) == len(other.source_kernels)
and len(expr.densities) == len(other.densities)
and expr.target_kernel == other.target_kernel
and all(knl == other_knl for knl, other_knl in zip(
expr.source_kernels, other.source_kernels)
)
and all(d == other_d for d, other_d in zip(
expr.densities, other.densities))
and all(k == other_k
and self.rec(v, other_v)
for (k, v), (other_k, other_v) in zip(
sorted(hashable_kernel_args(expr.kernel_arguments)),
sorted(hashable_kernel_args(other.kernel_arguments))))
)

def map_interpolation(self, expr, other) -> bool:
return (
expr.from_dd == other.from_dd
and expr.to_dd == other.to_dd
and self.rec(expr.operand, other.operand))

def map_is_shape_class(self, expr, other) -> bool:
return (
expr.shape is other.shape,
expr.dofdesc == other.dofdesc
)

def map_error_expression(self, expr, other) -> bool:
return expr.message == other.message

# }}}


# {{{ StringifyMapper

def stringify_where(where):
Expand Down
4 changes: 4 additions & 0 deletions pytential/symbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def array_to_tuple(ary):


class Expression(ExpressionBase):
def make_equality_mapper(self):
from pytential.symbolic.mappers import EqualityMapper
return EqualityMapper()

def make_stringifier(self, originating_stringifier=None):
from pytential.symbolic.mappers import StringifyMapper
return StringifyMapper()
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
numpy != 1.22.0

git+https://github.com/inducer/pytools.git#egg=pytools
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
sympy
git+https://github.com/inducer/modepy.git#egg=modepy
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
git+https://github.com/inducer/islpy.git#egg=islpy
git+https://github.com/inducer/loopy.git#egg=loopy
git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy
git+https://github.com/inducer/boxtree.git#egg=boxtree
git+https://github.com/inducer/arraycontext.git#egg=arraycontext
git+https://github.com/inducer/meshmode.git#egg=meshmode
Expand Down
7 changes: 6 additions & 1 deletion test/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_derivative_binder_expr():
d1, d2 = principal_directions(ambient_dim, dim=dim)
expr = (d1 @ d2 + d1 @ d1) / (d2 @ d2)

nruns = 4
nruns = 1
for i in range(nruns):
from pytools import ProcessTimer
with ProcessTimer() as pd:
Expand Down Expand Up @@ -478,6 +478,8 @@ def is_base_kernel(knl):

@pytest.mark.parametrize("op_name", ["dirichlet", "neumann"])
def test_mapper_int_g_term_collector(op_name, k=0):
logging.basicConfig(level=logging.INFO)

ambient_dim = 3
op = _make_operator(ambient_dim, op_name, k)
expr = op.operator(op.get_density_var("sigma"))
Expand All @@ -495,6 +497,9 @@ def test_mapper_int_g_term_collector(op_name, k=0):
else:
raise ValueError(f"unknown operator name: {op_name}")

print(sym.pretty(expr_only_intgs))
print(sym.pretty(expected_expr))

assert expr_only_intgs == expected_expr

# }}}
Expand Down