Skip to content

Commit

Permalink
deprecation fixes for pymbolic 2024.2
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Nov 25, 2024
1 parent afacd2c commit 223ab83
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 27 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ dependencies = [
"loopy>=2024.1",
"meshmode>=2021.2",
"modepy>=2021.1",
"pymbolic>=2022.2",
"pymbolic>=2024.2",
"pyopencl>=2022.1",
"pytools>=2022.1",
"pytools>=2024.1",
"scipy>=1.2",
"sumpy>=2022.1",
]
Expand Down
3 changes: 2 additions & 1 deletion pytential/symbolic/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

import numpy as np

from pymbolic.primitives import cse_scope, Expression, Variable, Subscript
from pymbolic.primitives import cse_scope, Variable, Subscript
from pymbolic.typing import Expression
from sumpy.kernel import Kernel

from pytential.symbolic.primitives import (
Expand Down
6 changes: 3 additions & 3 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
as DerivativeSourceFinderBase,

GraphvizMapper as GraphvizMapperBase)
from pymbolic.typing import ExpressionT
from pymbolic.typing import Expression
import pytential.symbolic.primitives as prim


Expand Down Expand Up @@ -291,15 +291,15 @@ def flatten(expr):

# {{{ LocationTagger

class LocationTagger(CSECachingMapperMixin[ExpressionT, []],
class LocationTagger(CSECachingMapperMixin[Expression, []],
IdentityMapper):
"""Used internally by :class:`ToTargetTagger`."""

def __init__(self, default_target, default_source):
self.default_source = default_source
self.default_target = default_target

def map_common_subexpression_uncached(self, expr) -> ExpressionT:
def map_common_subexpression_uncached(self, expr) -> Expression:
# Mypy 1.13 complains about this:
# error: Too few arguments for "map_common_subexpression" of "IdentityMapper" [call-arg] # noqa: E501
# error: Argument 1 to "map_common_subexpression" of "IdentityMapper" has incompatible type "LocationTagger"; expected "IdentityMapper[P]" [arg-type] # noqa: E501
Expand Down
50 changes: 29 additions & 21 deletions pytential/symbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@
import numpy as np

import modepy as mp
from pymbolic import ExpressionNode, Variable
from pymbolic.primitives import ( # noqa: N813
Expression as ExpressionBase, Variable, Variable as var,
Variable as var,
cse_scope as cse_scope_base,
make_common_subexpression as cse,
expr_dataclass)
from pymbolic.geometric_algebra import MultiVector, componentwise
from pymbolic.geometric_algebra.primitives import (
NablaComponent, Derivative as DerivativeBase)
from pymbolic.primitives import make_sym_vector
from pymbolic.typing import ArithmeticExpressionT
from pymbolic.typing import ArithmeticExpression

from pytools import P
from pytools.obj_array import make_obj_array, flat_obj_array
Expand Down Expand Up @@ -99,12 +100,18 @@
.. autofunction:: for_each_expression
.. autoclass:: OperandT
.. autoclass:: Operand
.. autoclass:: ArithmeticExpressionT
.. class:: P
See :class:`pytools.P`
.. class:: ExpressionNode
See :class:`pymbolic.ExpressionNode`.
Diagnostics
^^^^^^^^^^^
Expand Down Expand Up @@ -376,8 +383,8 @@


@expr_dataclass()
class Expression(ExpressionBase):
"""A subclass of :class:`pymbolic.primitives.Expression` for use with
class Expression(ExpressionNode):
"""A subclass of :class:`pymbolic.primitives.ExpressionNode` for use with
:mod:`pytential` mappers.
"""

Expand All @@ -387,10 +394,11 @@ def make_stringifier(self, originating_stringifier=None):


Operand: TypeAlias = (
ArithmeticExpressionT | np.ndarray[Any, np.dtype[Any]] | MultiVector)
ArithmeticExpression | np.ndarray[Any, np.dtype[Any]] | MultiVector)
QBXForcedLimit = int | Literal["avg"] | None

OperandT = TypeVar("OperandT", bound=ArithmeticExpressionT)
# NOTE: this will likely live in pymbolic at some point, but for now we take it!
ArithmeticExpressionT = TypeVar("ArithmeticExpressionT", bound=ArithmeticExpression)


class _NoArgSentinel:
Expand All @@ -402,7 +410,7 @@ class cse_scope(cse_scope_base): # noqa: N801


def for_each_expression(
f: Callable[Concatenate[ArithmeticExpressionT, P], ArithmeticExpressionT]
f: Callable[Concatenate[ArithmeticExpression, P], ArithmeticExpression]
) -> Callable[Concatenate[Operand, P], Operand]:
"""A decorator that takes a function that can only work on expressions
and transforms it into a function that can be applied componentwise on
Expand All @@ -414,7 +422,7 @@ def for_each_expression(
@wraps(f)
def wrapper(operand: Operand, *args: P.args, **kwargs: P.kwargs) -> Operand:
if isinstance(operand, np.ndarray | MultiVector):
def func(operand_i: ArithmeticExpressionT) -> ArithmeticExpressionT:
def func(operand_i: ArithmeticExpression) -> ArithmeticExpression:
return f(operand_i, *args, **kwargs)

return componentwise(func, operand)
Expand Down Expand Up @@ -631,7 +639,7 @@ def make_op(operand_i):
# different..
def __init__(self,
ref_axes: tuple[tuple[int, int], ...],
operand: ArithmeticExpressionT,
operand: ArithmeticExpression,
dofdesc: DOFDescriptorLike) -> None:
if not isinstance(ref_axes, tuple):
raise ValueError(f"'ref_axes' must be a tuple: {type(ref_axes)}")
Expand All @@ -653,7 +661,7 @@ def __init__(self,

@for_each_expression
def num_reference_derivative(
expr: ArithmeticExpressionT,
expr: ArithmeticExpression,
ref_axes: tuple[tuple[int, int], ...] | int,
dofdesc: DOFDescriptorLike | None) -> NumReferenceDerivative:
"""Take a derivative of *expr* with respect to the the element reference
Expand Down Expand Up @@ -1239,9 +1247,9 @@ def interp(from_dd, to_dd, operand):


@for_each_expression
def interpolate(operand: OperandT,
def interpolate(operand: ArithmeticExpressionT,
from_dd: DOFDescriptorLike,
to_dd: DOFDescriptorLike) -> OperandT | Interpolation:
to_dd: DOFDescriptorLike) -> ArithmeticExpressionT | Interpolation:
from_dd = as_dofdesc(from_dd)
to_dd = as_dofdesc(to_dd)

Expand Down Expand Up @@ -1286,7 +1294,7 @@ class NodeSum(SingleScalarOperandExpression):


@for_each_expression
def node_sum(expr: ArithmeticExpressionT) -> NodeSum:
def node_sum(expr: ArithmeticExpression) -> NodeSum:
return NodeSum(expr)


Expand All @@ -1299,7 +1307,7 @@ class NodeMax(SingleScalarOperandExpression):


@for_each_expression
def node_max(expr: ArithmeticExpressionT) -> NodeMax:
def node_max(expr: ArithmeticExpression) -> NodeMax:
return NodeMax(expr)


Expand All @@ -1312,7 +1320,7 @@ class NodeMin(SingleScalarOperandExpression):


@for_each_expression
def node_min(expr: ArithmeticExpressionT) -> NodeMin:
def node_min(expr: ArithmeticExpression) -> NodeMin:
return NodeMin(expr)


Expand Down Expand Up @@ -1378,7 +1386,7 @@ class ElementwiseSum(SingleScalarOperandExpressionWithWhere):


@for_each_expression
def elementwise_sum(expr: ArithmeticExpressionT,
def elementwise_sum(expr: ArithmeticExpression,
dofdesc: DOFDescriptorLike = None) -> ElementwiseSum:
return ElementwiseSum(expr, as_dofdesc(dofdesc))

Expand All @@ -1393,7 +1401,7 @@ class ElementwiseMin(SingleScalarOperandExpressionWithWhere):


@for_each_expression
def elementwise_min(expr: ArithmeticExpressionT,
def elementwise_min(expr: ArithmeticExpression,
dofdesc: DOFDescriptorLike = None) -> ElementwiseMin:
return ElementwiseMin(expr, as_dofdesc(dofdesc))

Expand All @@ -1408,7 +1416,7 @@ class ElementwiseMax(SingleScalarOperandExpressionWithWhere):


@for_each_expression
def elementwise_max(expr: ArithmeticExpressionT,
def elementwise_max(expr: ArithmeticExpression,
dofdesc: DOFDescriptorLike = None) -> ElementwiseMax:
return ElementwiseMax(expr, as_dofdesc(dofdesc))

Expand Down Expand Up @@ -1462,9 +1470,9 @@ class IterativeInverse(Expression):
.. autoattribute:: dofdesc
"""

expression: ArithmeticExpressionT
expression: ArithmeticExpression
"""The operator *A* used in the linear solve."""
rhs: ArithmeticExpressionT
rhs: ArithmeticExpression
"""The right-hand side variable used in the linear solve."""
variable_name: str
"""The name of the variable to solve for."""
Expand Down

0 comments on commit 223ab83

Please sign in to comment.