Skip to content

Commit

Permalink
primitives: replace __new__ vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Nov 16, 2024
1 parent b009509 commit bef9a2a
Showing 1 changed file with 138 additions and 45 deletions.
183 changes: 138 additions & 45 deletions pytential/symbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
THE SOFTWARE.
"""

from collections.abc import Callable
from dataclasses import field
from warnings import warn
from functools import partial
from typing import Any, Union, Literal
from typing import Any, Concatenate, Literal, TypeVar, Union, cast

import numpy as np

Expand All @@ -38,6 +39,7 @@
NablaComponent, Derivative as DerivativeBase)
from pymbolic.primitives import make_sym_vector

from pytools import P
from pytools.obj_array import make_obj_array, flat_obj_array
from sumpy.kernel import Kernel, SpatialConstant

Expand Down Expand Up @@ -347,9 +349,15 @@
)


Operand = Union["Expression", np.ndarray, MultiVector]
# {{{ helpers


Operand = Union["ExpressionBase", np.ndarray[Any, np.dtype[Any]], MultiVector]
QBXForcedLimit = int | Literal["avg"] | None

ExpressionT = TypeVar("ExpressionT", bound=ExpressionBase)
OperandT = TypeVar("OperandT", bound=Operand)


class _NoArgSentinel:
pass
Expand All @@ -359,17 +367,27 @@ class cse_scope(cse_scope_base): # noqa: N801
DISCRETIZATION = "pytential_discretization"


# {{{ helper functions

def array_to_tuple(ary):
"""This function is typically used to make :class:`numpy.ndarray`
instances hashable by converting them to tuples.
def for_each_expression(
f: Callable[Concatenate[ExpressionBase, P], ExpressionBase]
) -> Callable[Concatenate[OperandT, P], OperandT]:
"""A decorator that takes a function that can only work on expressions
and transforms it into a function that can be applied componentwise on
:class:`numpy.ndarray` or :class:`~pymbolic.geometric_algebra.MultiVector`.
"""

if isinstance(ary, np.ndarray):
return tuple(ary)
else:
return ary
from functools import wraps

@wraps(f)
def wrapper(operand: OperandT, *args: P.args, **kwargs: P.kwargs) -> OperandT:
if isinstance(operand, np.ndarray | MultiVector):
def func(operand_i: ExpressionBase) -> ExpressionBase:
return f(operand_i, *args, **kwargs)

return cast(OperandT, componentwise(func, operand))
else:
return operand

return wrapper

# }}}

Expand Down Expand Up @@ -560,11 +578,20 @@ def __new__(cls,
operand: Operand | None = None,
dofdesc: DOFDescriptor | None = None,
) -> "NumReferenceDerivative":
# If the constructor is handed a multivector object, return an
# object array of the operator applied to each of the
# coefficients in the multivector.
if isinstance(ref_axes, int):
warn(f"Passing an 'int' as 'ref_axes' to {cls.__name__!r} "
"is deprecated and will result in an error in 2025. Pass the "
"well-formatted tuple '((ref_axes, 1),)' instead.",
DeprecationWarning, stacklevel=2)

ref_axes = ((ref_axes, 1),)

if isinstance(operand, np.ndarray | MultiVector):
warn(f"Passing {type(operand)} directly to {cls.__name__!r} "
"is deprecated and will result in an error from 2025. Use "
"the 'num_reference_derivative' function instead.",
DeprecationWarning, stacklevel=3)

if isinstance(operand, np.ndarray):
def make_op(operand_i):
return cls(ref_axes, operand_i, as_dofdesc(dofdesc))

Expand All @@ -574,35 +601,46 @@ def make_op(operand_i):
else:
return DiscretizationProperty.__new__(cls)

# FIXME: this is added for backwards compatibility with pre-dataclass expressions
# FIXME: this is added for backwards compatibility with pre-dataclass expressions.
# Ideally, we'd just have a __post_init__, but the order of the arguments is
# different..
def __init__(self,
ref_axes: tuple[tuple[int, int], ...],
operand: Expression,
operand: ExpressionBase,
dofdesc: DOFDescriptorLike) -> None:
if not isinstance(ref_axes, tuple):
raise ValueError(f"'ref_axes' must be a tuple: {type(ref_axes)}")

if tuple(sorted(ref_axes)) != ref_axes:
raise ValueError(
f"'ref_axes' must be sorted by axis index: {ref_axes}"
)

if len(dict(ref_axes)) != len(ref_axes):
raise ValueError(
f"'ref_axes' must not contain an axis more than once: {ref_axes}"
)

object.__setattr__(self, "ref_axes", ref_axes)
object.__setattr__(self, "operand", operand)
super().__init__(dofdesc) # type: ignore[arg-type]

if isinstance(self.ref_axes, int):
warn(f"Passing an 'int' as 'ref_axes' to {type(self).__name__!r} "
"is deprecated and will be removed in 2025. Pass the "
"well-formatted tuple '((ref_axes, 1),)' instead.",
DeprecationWarning, stacklevel=2)

object.__setattr__(self, "ref_axes", ((self.ref_axes, 1),))
@for_each_expression
def num_reference_derivative(
expr: ExpressionBase,
ref_axes: tuple[tuple[int, int], ...] | int,
dofdesc: DOFDescriptorLike | None) -> NumReferenceDerivative:
"""Take a derivative of *expr* with respect to the the element reference
coordinates.
if not isinstance(self.ref_axes, tuple):
raise ValueError(f"'ref_axes' must be a tuple: {type(self)}")
See :class:`~pytential.symbolic.primitives.NumReferenceDerivative`.
"""

if tuple(sorted(self.ref_axes)) != self.ref_axes:
raise ValueError(
f"'ref_axes' must be sorted by axis index: {self.ref_axes}"
)
if isinstance(ref_axes, int):
ref_axes = ((ref_axes, 1),)

if len(dict(self.ref_axes)) != len(self.ref_axes):
raise ValueError(
f"'ref_axes' must not contain an axis more than once: {self.ref_axes}"
)
return NumReferenceDerivative(ref_axes, expr, as_dofdesc(dofdesc))


def reference_jacobian(func, output_dim, dim, dofdesc=None):
Expand Down Expand Up @@ -1135,11 +1173,12 @@ def __new__(cls,
from_dd = as_dofdesc(from_dd)
to_dd = as_dofdesc(to_dd)

if from_dd == to_dd:
# FIXME: __new__ should return a class instance
return operand # type: ignore[return-value]
if isinstance(operand, np.ndarray | MultiVector):
warn(f"Passing {type(operand)} directly to {cls.__name__!r} "
"is deprecated and will result in an error from 2025. Use "
"the 'interpolate' function instead.",
DeprecationWarning, stacklevel=3)

if isinstance(operand, np.ndarray):
def make_op(operand_i):
return cls(from_dd, to_dd, operand_i)

Expand All @@ -1166,9 +1205,26 @@ def __post_init__(self) -> None:


def interp(from_dd, to_dd, operand):
warn("Calling 'interp' is deprecated and it will be removed in 2025. Use "
"'interpolate' instead (has a different argument order).",
DeprecationWarning, stacklevel=2)

return Interpolation(as_dofdesc(from_dd), as_dofdesc(to_dd), operand)


@for_each_expression
def interpolate(operand: ExpressionT,
from_dd: DOFDescriptorLike,
to_dd: DOFDescriptorLike) -> ExpressionT | Interpolation:
from_dd = as_dofdesc(from_dd)
to_dd = as_dofdesc(to_dd)

if from_dd == to_dd:
return operand

return Interpolation(from_dd, to_dd, operand)


@expr_dataclass()
class SingleScalarOperandExpression(Expression):
"""
Expand All @@ -1180,11 +1236,13 @@ class SingleScalarOperandExpression(Expression):

def __new__(cls,
operand: Operand | None = None) -> "SingleScalarOperandExpression":
# If the constructor is handed a multivector object, return an
# object array of the operator applied to each of the
# coefficients in the multivector.

if isinstance(operand, np.ndarray | MultiVector):
name = cls.mapper_method[4:]
warn(f"Passing {type(operand)} directly to {cls.__name__!r} "
"is deprecated and will result in an error from 2025. Use "
f"the '{name}' function instead.",
DeprecationWarning, stacklevel=3)

def make_op(operand_i):
return cls(operand_i)

Expand All @@ -1201,6 +1259,11 @@ class NodeSum(SingleScalarOperandExpression):
"""


@for_each_expression
def node_sum(expr: ExpressionBase) -> NodeSum:
return NodeSum(expr)


@expr_dataclass()
class NodeMax(SingleScalarOperandExpression):
"""Bases: :class:`~pytential.symbolic.primitives.Expression`.
Expand All @@ -1209,6 +1272,11 @@ class NodeMax(SingleScalarOperandExpression):
"""


@for_each_expression
def node_max(expr: ExpressionBase) -> NodeMax:
return NodeMax(expr)


@expr_dataclass()
class NodeMin(SingleScalarOperandExpression):
"""Bases: :class:`~pytential.symbolic.primitives.Expression`.
Expand All @@ -1217,6 +1285,11 @@ class NodeMin(SingleScalarOperandExpression):
"""


@for_each_expression
def node_min(expr: ExpressionBase) -> NodeMin:
return NodeMin(expr)


def integral(ambient_dim, dim, operand, dofdesc=None):
"""A volume integral of *operand*."""

Expand All @@ -1243,11 +1316,13 @@ def __new__(cls,
operand: Operand | None = None,
dofdesc: DOFDescriptorLike | None = None,
) -> "SingleScalarOperandExpressionWithWhere":
# If the constructor is handed a multivector object, return an
# object array of the operator applied to each of the
# coefficients in the multivector.

if isinstance(operand, np.ndarray | MultiVector):
name = cls.mapper_method[4:]
warn(f"Passing {type(operand)} directly to {cls.__name__!r} "
"is deprecated and will result in an error from 2025. Use "
f"the '{name}' function instead.",
DeprecationWarning, stacklevel=2)

def make_op(operand_i):
return cls(operand_i, as_dofdesc(dofdesc))

Expand All @@ -1274,6 +1349,12 @@ class ElementwiseSum(SingleScalarOperandExpressionWithWhere):
"""


@for_each_expression
def elementwise_sum(expr: ExpressionBase,
dofdesc: DOFDescriptorLike = None) -> ElementwiseSum:
return ElementwiseSum(expr, as_dofdesc(dofdesc))


@expr_dataclass()
class ElementwiseMin(SingleScalarOperandExpressionWithWhere):
"""Bases: :class:`~pytential.symbolic.primitives.Expression`.
Expand All @@ -1283,6 +1364,12 @@ class ElementwiseMin(SingleScalarOperandExpressionWithWhere):
"""


@for_each_expression
def elementwise_min(expr: ExpressionBase,
dofdesc: DOFDescriptorLike = None) -> ElementwiseMin:
return ElementwiseMin(expr, as_dofdesc(dofdesc))


@expr_dataclass()
class ElementwiseMax(SingleScalarOperandExpressionWithWhere):
"""Bases: :class:`~pytential.symbolic.primitives.Expression`.
Expand All @@ -1292,6 +1379,12 @@ class ElementwiseMax(SingleScalarOperandExpressionWithWhere):
"""


@for_each_expression
def elementwise_max(expr: ExpressionBase,
dofdesc: DOFDescriptorLike = None) -> ElementwiseMax:
return ElementwiseMax(expr, as_dofdesc(dofdesc))


@expr_dataclass()
class Ones(Expression):
"""A DOF-vector that is constant *one* on the whole discretization.
Expand Down

0 comments on commit bef9a2a

Please sign in to comment.