diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index 459f75b77..4ff540135 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -20,10 +20,11 @@ THE SOFTWARE. """ +from collections.abc import Callable from dataclasses import field from warnings import warn from functools import partial -from typing import Any, Union, Literal +from typing import Any, Concatenate, Literal, TypeVar, Union, cast import numpy as np @@ -38,6 +39,7 @@ NablaComponent, Derivative as DerivativeBase) from pymbolic.primitives import make_sym_vector +from pytools import P from pytools.obj_array import make_obj_array, flat_obj_array from sumpy.kernel import Kernel, SpatialConstant @@ -347,9 +349,15 @@ ) -Operand = Union["Expression", np.ndarray, MultiVector] +# {{{ helpers + + +Operand = Union["ExpressionBase", np.ndarray[Any, np.dtype[Any]], MultiVector] QBXForcedLimit = int | Literal["avg"] | None +ExpressionT = TypeVar("ExpressionT", bound=ExpressionBase) +OperandT = TypeVar("OperandT", bound=Operand) + class _NoArgSentinel: pass @@ -359,17 +367,27 @@ class cse_scope(cse_scope_base): # noqa: N801 DISCRETIZATION = "pytential_discretization" -# {{{ helper functions - -def array_to_tuple(ary): - """This function is typically used to make :class:`numpy.ndarray` - instances hashable by converting them to tuples. +def for_each_expression( + f: Callable[Concatenate[ExpressionBase, P], ExpressionBase] + ) -> Callable[Concatenate[OperandT, P], OperandT]: + """A decorator that takes a function that can only work on expressions + and transforms it into a function that can be applied componentwise on + :class:`numpy.ndarray` or :class:`~pymbolic.geometric_algebra.MultiVector`. """ - if isinstance(ary, np.ndarray): - return tuple(ary) - else: - return ary + from functools import wraps + + @wraps(f) + def wrapper(operand: OperandT, *args: P.args, **kwargs: P.kwargs) -> OperandT: + if isinstance(operand, np.ndarray | MultiVector): + def func(operand_i: ExpressionBase) -> ExpressionBase: + return f(operand_i, *args, **kwargs) + + return cast(OperandT, componentwise(func, operand)) + else: + return operand + + return wrapper # }}} @@ -560,11 +578,20 @@ def __new__(cls, operand: Operand | None = None, dofdesc: DOFDescriptor | None = None, ) -> "NumReferenceDerivative": - # If the constructor is handed a multivector object, return an - # object array of the operator applied to each of the - # coefficients in the multivector. + if isinstance(ref_axes, int): + warn(f"Passing an 'int' as 'ref_axes' to {cls.__name__!r} " + "is deprecated and will result in an error in 2025. Pass the " + "well-formatted tuple '((ref_axes, 1),)' instead.", + DeprecationWarning, stacklevel=2) + + ref_axes = ((ref_axes, 1),) + + if isinstance(operand, np.ndarray | MultiVector): + warn(f"Passing {type(operand)} directly to {cls.__name__!r} " + "is deprecated and will result in an error from 2025. Use " + "the 'num_reference_derivative' function instead.", + DeprecationWarning, stacklevel=3) - if isinstance(operand, np.ndarray): def make_op(operand_i): return cls(ref_axes, operand_i, as_dofdesc(dofdesc)) @@ -574,35 +601,46 @@ def make_op(operand_i): else: return DiscretizationProperty.__new__(cls) - # FIXME: this is added for backwards compatibility with pre-dataclass expressions + # FIXME: this is added for backwards compatibility with pre-dataclass expressions. + # Ideally, we'd just have a __post_init__, but the order of the arguments is + # different.. def __init__(self, ref_axes: tuple[tuple[int, int], ...], - operand: Expression, + operand: ExpressionBase, dofdesc: DOFDescriptorLike) -> None: + if not isinstance(ref_axes, tuple): + raise ValueError(f"'ref_axes' must be a tuple: {type(ref_axes)}") + + if tuple(sorted(ref_axes)) != ref_axes: + raise ValueError( + f"'ref_axes' must be sorted by axis index: {ref_axes}" + ) + + if len(dict(ref_axes)) != len(ref_axes): + raise ValueError( + f"'ref_axes' must not contain an axis more than once: {ref_axes}" + ) + object.__setattr__(self, "ref_axes", ref_axes) object.__setattr__(self, "operand", operand) super().__init__(dofdesc) # type: ignore[arg-type] - if isinstance(self.ref_axes, int): - warn(f"Passing an 'int' as 'ref_axes' to {type(self).__name__!r} " - "is deprecated and will be removed in 2025. Pass the " - "well-formatted tuple '((ref_axes, 1),)' instead.", - DeprecationWarning, stacklevel=2) - object.__setattr__(self, "ref_axes", ((self.ref_axes, 1),)) +@for_each_expression +def num_reference_derivative( + expr: ExpressionBase, + ref_axes: tuple[tuple[int, int], ...] | int, + dofdesc: DOFDescriptorLike | None) -> NumReferenceDerivative: + """Take a derivative of *expr* with respect to the the element reference + coordinates. - if not isinstance(self.ref_axes, tuple): - raise ValueError(f"'ref_axes' must be a tuple: {type(self)}") + See :class:`~pytential.symbolic.primitives.NumReferenceDerivative`. + """ - if tuple(sorted(self.ref_axes)) != self.ref_axes: - raise ValueError( - f"'ref_axes' must be sorted by axis index: {self.ref_axes}" - ) + if isinstance(ref_axes, int): + ref_axes = ((ref_axes, 1),) - if len(dict(self.ref_axes)) != len(self.ref_axes): - raise ValueError( - f"'ref_axes' must not contain an axis more than once: {self.ref_axes}" - ) + return NumReferenceDerivative(ref_axes, expr, as_dofdesc(dofdesc)) def reference_jacobian(func, output_dim, dim, dofdesc=None): @@ -1135,11 +1173,12 @@ def __new__(cls, from_dd = as_dofdesc(from_dd) to_dd = as_dofdesc(to_dd) - if from_dd == to_dd: - # FIXME: __new__ should return a class instance - return operand # type: ignore[return-value] + if isinstance(operand, np.ndarray | MultiVector): + warn(f"Passing {type(operand)} directly to {cls.__name__!r} " + "is deprecated and will result in an error from 2025. Use " + "the 'interpolate' function instead.", + DeprecationWarning, stacklevel=3) - if isinstance(operand, np.ndarray): def make_op(operand_i): return cls(from_dd, to_dd, operand_i) @@ -1166,9 +1205,26 @@ def __post_init__(self) -> None: def interp(from_dd, to_dd, operand): + warn("Calling 'interp' is deprecated and it will be removed in 2025. Use " + "'interpolate' instead (has a different argument order).", + DeprecationWarning, stacklevel=2) + return Interpolation(as_dofdesc(from_dd), as_dofdesc(to_dd), operand) +@for_each_expression +def interpolate(operand: ExpressionT, + from_dd: DOFDescriptorLike, + to_dd: DOFDescriptorLike) -> ExpressionT | Interpolation: + from_dd = as_dofdesc(from_dd) + to_dd = as_dofdesc(to_dd) + + if from_dd == to_dd: + return operand + + return Interpolation(from_dd, to_dd, operand) + + @expr_dataclass() class SingleScalarOperandExpression(Expression): """ @@ -1180,11 +1236,13 @@ class SingleScalarOperandExpression(Expression): def __new__(cls, operand: Operand | None = None) -> "SingleScalarOperandExpression": - # If the constructor is handed a multivector object, return an - # object array of the operator applied to each of the - # coefficients in the multivector. - if isinstance(operand, np.ndarray | MultiVector): + name = cls.mapper_method[4:] + warn(f"Passing {type(operand)} directly to {cls.__name__!r} " + "is deprecated and will result in an error from 2025. Use " + f"the '{name}' function instead.", + DeprecationWarning, stacklevel=3) + def make_op(operand_i): return cls(operand_i) @@ -1201,6 +1259,11 @@ class NodeSum(SingleScalarOperandExpression): """ +@for_each_expression +def node_sum(expr: ExpressionBase) -> NodeSum: + return NodeSum(expr) + + @expr_dataclass() class NodeMax(SingleScalarOperandExpression): """Bases: :class:`~pytential.symbolic.primitives.Expression`. @@ -1209,6 +1272,11 @@ class NodeMax(SingleScalarOperandExpression): """ +@for_each_expression +def node_max(expr: ExpressionBase) -> NodeMax: + return NodeMax(expr) + + @expr_dataclass() class NodeMin(SingleScalarOperandExpression): """Bases: :class:`~pytential.symbolic.primitives.Expression`. @@ -1217,6 +1285,11 @@ class NodeMin(SingleScalarOperandExpression): """ +@for_each_expression +def node_min(expr: ExpressionBase) -> NodeMin: + return NodeMin(expr) + + def integral(ambient_dim, dim, operand, dofdesc=None): """A volume integral of *operand*.""" @@ -1243,11 +1316,13 @@ def __new__(cls, operand: Operand | None = None, dofdesc: DOFDescriptorLike | None = None, ) -> "SingleScalarOperandExpressionWithWhere": - # If the constructor is handed a multivector object, return an - # object array of the operator applied to each of the - # coefficients in the multivector. - if isinstance(operand, np.ndarray | MultiVector): + name = cls.mapper_method[4:] + warn(f"Passing {type(operand)} directly to {cls.__name__!r} " + "is deprecated and will result in an error from 2025. Use " + f"the '{name}' function instead.", + DeprecationWarning, stacklevel=2) + def make_op(operand_i): return cls(operand_i, as_dofdesc(dofdesc)) @@ -1274,6 +1349,12 @@ class ElementwiseSum(SingleScalarOperandExpressionWithWhere): """ +@for_each_expression +def elementwise_sum(expr: ExpressionBase, + dofdesc: DOFDescriptorLike = None) -> ElementwiseSum: + return ElementwiseSum(expr, as_dofdesc(dofdesc)) + + @expr_dataclass() class ElementwiseMin(SingleScalarOperandExpressionWithWhere): """Bases: :class:`~pytential.symbolic.primitives.Expression`. @@ -1283,6 +1364,12 @@ class ElementwiseMin(SingleScalarOperandExpressionWithWhere): """ +@for_each_expression +def elementwise_min(expr: ExpressionBase, + dofdesc: DOFDescriptorLike = None) -> ElementwiseMin: + return ElementwiseMin(expr, as_dofdesc(dofdesc)) + + @expr_dataclass() class ElementwiseMax(SingleScalarOperandExpressionWithWhere): """Bases: :class:`~pytential.symbolic.primitives.Expression`. @@ -1292,6 +1379,12 @@ class ElementwiseMax(SingleScalarOperandExpressionWithWhere): """ +@for_each_expression +def elementwise_max(expr: ExpressionBase, + dofdesc: DOFDescriptorLike = None) -> ElementwiseMax: + return ElementwiseMax(expr, as_dofdesc(dofdesc)) + + @expr_dataclass() class Ones(Expression): """A DOF-vector that is constant *one* on the whole discretization.