diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cff3ac06..eac6fbcf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,8 +10,8 @@ on: - cron: '17 3 * * 0' concurrency: - group: ${{ github.head_ref || github.ref_name }} - cancel-in-progress: true + group: ${{ github.head_ref || github.ref_name }} + cancel-in-progress: true jobs: typos: @@ -77,7 +77,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.x"] + python-version: ["3.10", "3.12", "3.x"] steps: - uses: actions/checkout@v4 - diff --git a/doc/algorithms.rst b/doc/algorithms.rst index 6ef962a3..25237233 100644 --- a/doc/algorithms.rst +++ b/doc/algorithms.rst @@ -2,12 +2,3 @@ Algorithms ========== .. automodule:: pymbolic.algorithm - -.. autofunction:: integer_power -.. autofunction:: extended_euclidean -.. autofunction:: gcd -.. autofunction:: lcm -.. autofunction:: fft -.. autofunction:: ifft -.. autofunction:: sym_fft - diff --git a/doc/conf.py b/doc/conf.py index 5d2018a3..a3c1b0af 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -31,3 +31,8 @@ "ExpressionT": "ExpressionT", "ArithmeticExpressionT": "ArithmeticExpressionT", } + +import sys + + +sys._BUILDING_SPHINX_DOCS = True diff --git a/doc/mappers.rst b/doc/mappers.rst index 4a92531f..205aded6 100644 --- a/doc/mappers.rst +++ b/doc/mappers.rst @@ -44,5 +44,10 @@ Analysis tools .. automodule:: pymbolic.mapper.analysis +Simplification +^^^^^^^^^^^^^^ + +.. automodule:: pymbolic.mapper.flattener + .. vim: sw=4 diff --git a/pymbolic/__init__.py b/pymbolic/__init__.py index 1d92fb7e..5ed3cb35 100644 --- a/pymbolic/__init__.py +++ b/pymbolic/__init__.py @@ -38,8 +38,6 @@ from .mapper import flattener from . import primitives -from .polynomial import Polynomial - from .primitives import (Variable as var, # noqa: N813 Variable, Expression, @@ -73,7 +71,6 @@ "Expression", "ExpressionT", "NumberT", - "Polynomial", "ScalarT", "Variable", "compile", diff --git a/pymbolic/algorithm.py b/pymbolic/algorithm.py index 2bdbb6b8..f66fc604 100644 --- a/pymbolic/algorithm.py +++ b/pymbolic/algorithm.py @@ -1,3 +1,15 @@ +""" +.. autofunction:: integer_power +.. autofunction:: extended_euclidean +.. autofunction:: gcd +.. autofunction:: lcm +.. autofunction:: fft +.. autofunction:: ifft +.. autofunction:: sym_fft +.. autofunction:: reduced_row_echelon_form +.. autofunction:: solve_affine_equations_for +""" + from __future__ import annotations @@ -23,7 +35,16 @@ THE SOFTWARE. """ -from pytools import memoize +import operator +import sys +from typing import TYPE_CHECKING, overload +from warnings import warn + +from pytools import MovedFunctionDeprecationWrapper, memoize + + +if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", None): + import numpy as np # {{{ integer powers @@ -281,10 +302,47 @@ def csr_matrix_multiply(S, x): # noqa return result -# {{{ gaussian elimination +# {{{ reduced_row_echelon_form -def gaussian_elimination(mat, rhs): +@overload +def reduced_row_echelon_form( + mat: np.ndarray, + *, integral: bool | None = None, + ) -> np.ndarray: + ... + + +@overload +def reduced_row_echelon_form( + mat: np.ndarray, + rhs: np.ndarray, + *, integral: bool | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + ... + + +def reduced_row_echelon_form( + mat: np.ndarray, + rhs: np.ndarray | None = None, + integral: bool | None = None, + ) -> tuple[np.ndarray, np.ndarray] | np.ndarray: m, n = mat.shape + + mat = mat.copy() + if rhs is not None: + rhs = rhs.copy() + + if integral is None: + warn( + "Not specifying 'integral' is deprecated, please add it as an argument. " + "This will stop being supported in 2025.", + DeprecationWarning, stacklevel=2) + + if integral: + div_func = operator.floordiv + else: + div_func = operator.truediv + i = 0 j = 0 @@ -303,8 +361,9 @@ def gaussian_elimination(mat, rhs): # swap rows i and nonz mat[i], mat[nonz_row] = \ (mat[nonz_row].copy(), mat[i].copy()) - rhs[i], rhs[nonz_row] = \ - (rhs[nonz_row].copy(), rhs[i].copy()) + if rhs is not None: + rhs[i], rhs[nonz_row] = \ + (rhs[nonz_row].copy(), rhs[i].copy()) for u in range(0, m): if u == i: @@ -314,11 +373,12 @@ def gaussian_elimination(mat, rhs): continue ell = lcm(mat[u, j], mat[i, j]) - u_fac = ell//mat[u, j] - i_fac = ell//mat[i, j] + u_fac = div_func(ell, mat[u, j]) + i_fac = div_func(ell, mat[i, j]) mat[u] = u_fac*mat[u] - i_fac*mat[i] - rhs[u] = u_fac*rhs[u] - i_fac*rhs[i] + if rhs is not None: + rhs[u] = u_fac*rhs[u] - i_fac*rhs[i] assert mat[u, j] == 0 @@ -326,20 +386,38 @@ def gaussian_elimination(mat, rhs): j += 1 - for i in range(m): - g = gcd_many(*( - [a for a in mat[i] if a] - + - [a for a in rhs[i] if a])) + if integral: + for i in range(m): + g = gcd_many(*( + [a for a in mat[i] if a] + + + [a for a in rhs[i] if a] if rhs is not None else [])) + + mat[i] //= g + if rhs is not None: + rhs[i] //= g + + import numpy as np + + from pymbolic.mapper.flattener import flatten + vec_flatten = np.vectorize(flatten, otypes=[object]) - mat[i] //= g - rhs[i] //= g + for i in range(m): + mat[i] = vec_flatten(mat[i]) + if rhs is not None: + rhs[i] = vec_flatten(rhs[i]) - return mat, rhs + if rhs is None: + return mat + else: + return mat, rhs # }}} +gaussian_elimination = MovedFunctionDeprecationWrapper(reduced_row_echelon_form, "2025") + + # {{{ symbolic (linear) equation solving def solve_affine_equations_for(unknowns, equations): @@ -393,7 +471,7 @@ def solve_affine_equations_for(unknowns, equations): # }}} - mat, rhs_mat = gaussian_elimination(mat, rhs_mat) + mat, rhs_mat = reduced_row_echelon_form(mat, rhs_mat, integral=True) # FIXME /!\ Does not check for overdetermined system. @@ -411,7 +489,8 @@ def solve_affine_equations_for(unknowns, equations): div = mat[nonz_row, j] unknown_val = int(rhs_mat[nonz_row, -1]) // div - for parameter, coeff in zip(parameters_list, rhs_mat[nonz_row]): + for parameter, coeff in zip( + parameters_list, rhs_mat[nonz_row, :-1], strict=True): unknown_val += (int(coeff) // div) * parameter result[unknown] = unknown_val diff --git a/pymbolic/compiler.py b/pymbolic/compiler.py index ab9a811f..cad698ae 100644 --- a/pymbolic/compiler.py +++ b/pymbolic/compiler.py @@ -26,7 +26,7 @@ import math import pymbolic -from pymbolic.mapper.stringifier import PREC_NONE, PREC_POWER, PREC_SUM, StringifyMapper +from pymbolic.mapper.stringifier import PREC_NONE, StringifyMapper class CompileMapper(StringifyMapper): @@ -45,34 +45,6 @@ def map_constant(self, expr, enclosing_prec): return repr(expr) - def map_polynomial(self, expr, enclosing_prec): - # Use Horner's scheme to evaluate the polynomial - - sbase = self(expr.base, PREC_POWER) - - def stringify_exp(exp): - if exp == 0: - return "" - elif exp == 1: - return f"*{sbase}" - else: - return f"*{sbase}**{exp}" - - result = "" - rev_data = expr.data[::-1] - for i, (exp, coeff) in enumerate(rev_data): - if i+1 < len(rev_data): - next_exp = rev_data[i+1][0] - else: - next_exp = 0 - result = "({}+{}){}".format(result, self(coeff, PREC_SUM), - stringify_exp(exp-next_exp)) - - if enclosing_prec > PREC_SUM and len(expr.data) > 1: - return f"({result})" - else: - return result - def map_numpy_array(self, expr, enclosing_prec): def stringify_leading_dimension(ary): if len(ary.shape) == 1: diff --git a/pymbolic/geometric_algebra/__init__.py b/pymbolic/geometric_algebra/__init__.py index 0efd8134..c14cfbad 100644 --- a/pymbolic/geometric_algebra/__init__.py +++ b/pymbolic/geometric_algebra/__init__.py @@ -23,10 +23,18 @@ THE SOFTWARE. """ +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import Any, Generic, TypeVar, cast + import numpy as np from pytools import memoize, memoize_method +from pymbolic.primitives import expr_dataclass +from pymbolic.typing import ArithmeticExpressionT + __doc__ = """ See `Wikipedia `__ for an idea @@ -50,6 +58,10 @@ Multivectors ------------ +.. class:: CoeffT + + A type variable for coefficients of :class:`MultiVector`. Requires some arithmetic. + .. autoclass:: MultiVector .. _ga-examples: @@ -178,16 +190,20 @@ def canonical_reordering_sign(a_bits, b_bits): # {{{ space +@dataclass(frozen=True, init=False) class Space: """ - .. attribute :: basis_names - - A sequence of names of basis vectors. + .. autoattribute :: basis_names + .. autoattribute :: metric_matrix + """ - .. attribute :: metric_matrix + basis_names: Sequence[str] + "A sequence of names of basis vectors." - A *(dims,dims)*-shaped matrix, whose *(i,j)*-th entry represents the - inner product of basis vector *i* and basis vector *j*. + metric_matrix: np.ndarray + """ + A *(dims,dims)*-shaped matrix, whose *(i,j)*-th entry represents the + inner product of basis vector *i* and basis vector *j*. """ def __init__(self, basis=None, metric_matrix=None): @@ -217,16 +233,13 @@ def __init__(self, basis=None, metric_matrix=None): and all(dim == len(basis) for dim in metric_matrix.shape)): raise ValueError("metric_matrix has the wrong shape") - self.basis_names = basis - self.metric_matrix = metric_matrix + object.__setattr__(self, "basis_names", basis) + object.__setattr__(self, "metric_matrix", metric_matrix) @property - def dimensions(self): + def dimensions(self) -> int: return len(self.basis_names) - def __getinitargs__(self): - return (self.basis_names, self.metric_matrix) - @memoize_method def bits_and_sign(self, basis_indices): # assert no repetitions @@ -277,6 +290,9 @@ def get_euclidean_space(n): # }}} +CoeffT = TypeVar("CoeffT", bound=ArithmeticExpressionT) + + # {{{ blade product weights def _shared_metric_coeff(shared_bits, space): @@ -294,8 +310,18 @@ def _shared_metric_coeff(shared_bits, space): return result -class _GAProduct: - pass +class _GAProduct(ABC, Generic[CoeffT]): + @staticmethod + @abstractmethod + def generic_blade_product_weight(a_bits: int, b_bits: int, space: Space) -> CoeffT: + ... + + @staticmethod + @abstractmethod + def orthogonal_blade_product_weight( + a_bits: int, b_bits: int, space: Space + ) -> CoeffT: + ... class _OuterProduct(_GAProduct): @@ -388,25 +414,23 @@ def orthogonal_blade_product_weight(a_bits, b_bits, space): # {{{ multivector -def _cast_or_ni(obj, space): +def _cast_to_mv(obj: Any, space: Space) -> MultiVector: if isinstance(obj, MultiVector): return obj else: return MultiVector(obj, space) -class MultiVector: +@expr_dataclass(init=False, hash=False) +class MultiVector(Generic[CoeffT]): r"""An immutable multivector type. Its implementation follows [DFM]. It is pickleable, and not picky about what data is used as coefficients. It supports :class:`pymbolic.primitives.Expression` objects of course, but it can take just about any other scalar-ish coefficients. - .. attribute:: data + .. autoattribute:: data - A mapping from a basis vector bitmap (indicating blades) to coefficients. - (see [DFM], Chapter 19 for idea and rationale) - - .. attribute:: space + .. autoattribute:: space See the following literature: @@ -496,9 +520,22 @@ class MultiVector: """ + data: Mapping[int, CoeffT] + """A mapping from a basis vector bitmap (indicating blades) to coefficients. + (see [DFM], Chapter 19 for idea and rationale) + """ + + space: Space + + mapper_method = "map_multivector" + # {{{ construction - def __init__(self, data, space=None): + def __init__( + self, + data: Mapping[tuple[int, ...] | int, CoeffT] | np.ndarray | CoeffT, + space: Space | None = None + ) -> None: """ :arg data: This may be one of the following: @@ -523,12 +560,11 @@ def __init__(self, data, space=None): raise ValueError("only numpy vectors (not higher-rank objects) " "are supported for 'data'") dimensions, = data.shape - data = { - (i,): xi for i, xi in enumerate(data)} + data = {(i,): xi for i, xi in enumerate(data)} elif isinstance(data, dict): pass else: - data = {0: data} + data = {0: cast(CoeffT, data)} if space is None: space = get_euclidean_space(dimensions) @@ -544,34 +580,28 @@ def __init__(self, data, space=None): from pymbolic.primitives import is_zero if data and single_valued(isinstance(k, tuple) for k in data.keys()): # data is in non-normalized non-bits tuple form - new_data = {} + new_data: dict[int, CoeffT] = {} for basis_indices, coeff in data.items(): bits, sign = space.bits_and_sign(basis_indices) - new_coeff = new_data.setdefault(bits, 0) + sign*coeff + new_coeff = new_data.setdefault(bits, cast(CoeffT, 0)) + sign*coeff if is_zero(new_coeff): del new_data[bits] else: new_data[bits] = new_coeff - - data = new_data + else: + new_data = cast(dict[int, CoeffT], data) # }}} # assert that multivectors don't get nested - assert not any(isinstance(coeff, MultiVector) - for coeff in data.values()) + assert not any(isinstance(coeff, MultiVector) for coeff in new_data.values()) - self.space = space - self.data = data + object.__setattr__(self, "space", space) + object.__setattr__(self, "data", new_data) # }}} - def __getinitargs__(self): - return (self.data, self.space) - - mapper_method = "map_multivector" - # {{{ stringification def stringify(self, coeff_stringifier, enclosing_prec): @@ -630,16 +660,14 @@ def __repr__(self): # {{{ additive operators - def __neg__(self): + def __neg__(self) -> MultiVector: return MultiVector( {bits: -coeff for bits, coeff in self.data.items()}, self.space) - def __add__(self, other): - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + def __add__(self, other) -> MultiVector: + other = _cast_to_mv(other, self.space) if self.space is not other.space: raise ValueError("can only add multivectors from identical spaces") @@ -649,7 +677,8 @@ def __add__(self, other): from pymbolic.primitives import is_zero new_data = {} for bits in all_bits: - new_coeff = self.data.get(bits, 0) + other.data.get(bits, 0) + new_coeff = (self.data.get(bits, cast(CoeffT, 0)) + + other.data.get(bits, cast(CoeffT, 0))) if not is_zero(new_coeff): new_data[bits] = new_coeff @@ -669,7 +698,10 @@ def __rsub__(self, other): # {{{ multiplicative operators - def _generic_product(self, other, product_class): + def _generic_product(self, + other: MultiVector, + product_class: _GAProduct + ) -> MultiVector: """ :arg product_class: A subclass of :class:`_GAProduct`. """ @@ -684,7 +716,7 @@ def _generic_product(self, other, product_class): "from identical spaces") from pymbolic.primitives import is_zero - new_data = {} + new_data: dict[int, CoeffT] = {} for sbits, scoeff in self.data.items(): for obits, ocoeff in other.data.items(): new_bits = sbits ^ obits @@ -695,7 +727,7 @@ def _generic_product(self, other, product_class): coeff = (weight * canonical_reordering_sign(sbits, obits) * scoeff * ocoeff) - new_coeff = new_data.setdefault(new_bits, 0) + coeff + new_coeff = new_data.setdefault(new_bits, cast(CoeffT, 0)) + coeff if is_zero(new_coeff): del new_data[new_bits] else: @@ -704,9 +736,7 @@ def _generic_product(self, other, product_class): return MultiVector(new_data, self.space) def __mul__(self, other): - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + other = _cast_to_mv(other, self.space) return self._generic_product(other, _GeometricProduct) @@ -715,9 +745,7 @@ def __rmul__(self, other): ._generic_product(self, _GeometricProduct) def __xor__(self, other): - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + other = _cast_to_mv(other, self.space) return self._generic_product(other, _OuterProduct) @@ -726,9 +754,7 @@ def __rxor__(self, other): ._generic_product(self, _OuterProduct) def __or__(self, other): - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + other = _cast_to_mv(other, self.space) return self._generic_product(other, _InnerProduct) @@ -737,9 +763,7 @@ def __ror__(self, other): ._generic_product(self, _InnerProduct) def __lshift__(self, other): - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + other = _cast_to_mv(other, self.space) return self._generic_product(other, _LeftContractionProduct) @@ -748,9 +772,7 @@ def __rlshift__(self, other): ._generic_product(self, _LeftContractionProduct) def __rshift__(self, other): - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + other = _cast_to_mv(other, self.space) return self._generic_product(other, _RightContractionProduct) @@ -764,10 +786,7 @@ def scalar_product(self, other): Often written :math:`A*B`. """ - other_new = _cast_or_ni(other, self.space) - if other_new is NotImplemented: - raise NotImplementedError( - f"scalar product between multivector and '{type(other)}'") + other_new = _cast_to_mv(other, self.space) return self._generic_product(other_new, _ScalarProduct).as_scalar() @@ -791,9 +810,7 @@ def __pow__(self, other): def __truediv__(self, other): """Return ``self*(1/other)``. """ - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + other = _cast_to_mv(other, self.space) return self*other.inv() @@ -913,9 +930,7 @@ def __bool__(self): __nonzero__ = __bool__ def __eq__(self, other): - other = _cast_or_ni(other, self.space) - if other is NotImplemented: - return NotImplemented + other = _cast_to_mv(other, self.space) return self.data == other.data diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index 61ca8333..bf5b64ac 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -285,7 +285,8 @@ def map_product(self, expr): if not has_d_source_nablas: rec_children = [self.rec(child) for child in expr.children] if all(rec_child is child - for rec_child, child in zip(rec_children, expr.children)): + for rec_child, child in zip( + rec_children, expr.children, strict=True)): return expr return type(expr)(tuple(rec_children)) @@ -296,7 +297,7 @@ def map_product(self, expr): result = [list(expr.children)] for child_idx, (d_source_nabla_ids, _child) in enumerate( - zip(d_source_nabla_ids_per_child, expr.children)): + zip(d_source_nabla_ids_per_child, expr.children, strict=True)): if not d_source_nabla_ids: continue diff --git a/pymbolic/geometric_algebra/primitives.py b/pymbolic/geometric_algebra/primitives.py index b8140f3a..00cf299c 100644 --- a/pymbolic/geometric_algebra/primitives.py +++ b/pymbolic/geometric_algebra/primitives.py @@ -26,7 +26,8 @@ # This is experimental, undocumented, and could go away any second. # Consider yourself warned. -from typing import ClassVar, Hashable +from collections.abc import Hashable +from typing import ClassVar from pymbolic.primitives import Expression, Variable, expr_dataclass from pymbolic.typing import ExpressionT diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py index 03ee7fe2..d4090e9e 100644 --- a/pymbolic/interop/ast.py +++ b/pymbolic/interop/ast.py @@ -31,7 +31,7 @@ import pymbolic.primitives as p from pymbolic.mapper import CachedMapper -from pymbolic.typing import ExpressionT, ScalarT +from pymbolic.typing import ExpressionT __doc__ = r''' @@ -215,15 +215,10 @@ def map_Str(self, expr): # noqa def map_Bytes(self, expr): # noqa return expr.s - # 3.8 and up def map_Constant(self, expr): # noqa # (singleton value) return expr.value - def map_NameConstant(self, expr): # noqa - # (singleton value) - return expr.value - def map_Attribute(self, expr): # noqa # (expr value, identifier attr, expr_context ctx) return p.Lookup(self.rec(expr.value), expr.attr) @@ -263,7 +258,7 @@ def map_Tuple(self, expr): # noqa # {{{ PymbolicToASTMapper -class PymbolicToASTMapper(CachedMapper): +class PymbolicToASTMapper(CachedMapper[ast.expr, []]): def map_variable(self, expr) -> ast.expr: return ast.Name(id=expr.name) @@ -283,11 +278,8 @@ def map_sum(self, expr: p.Sum) -> ast.expr: def map_product(self, expr: p.Product) -> ast.expr: return self._map_multi_children_op(expr.children, ast.Mult()) - def map_constant(self, expr: ScalarT) -> ast.expr: - if isinstance(expr, bool): - return ast.NameConstant(expr) - else: - return ast.Constant(expr, None) + def map_constant(self, expr: object) -> ast.expr: + return ast.Constant(expr, None) def map_call(self, expr: p.Call) -> ast.expr: return ast.Call( @@ -393,7 +385,7 @@ def map_nan(self, expr: p.NaN) -> ast.expr: raise NotImplementedError("Non-float nan not implemented") def map_slice(self, expr: p.Slice) -> ast.expr: - return ast.Slice(*[self.rec(child) + return ast.Slice(*[None if child is None else self.rec(child) for child in expr.children]) def map_numpy_array(self, expr) -> ast.expr: @@ -417,9 +409,6 @@ def map_if_positive(self, expr) -> ast.expr: def map_comparison(self, expr: p.Comparison) -> ast.expr: raise NotImplementedError - def map_polynomial(self, expr) -> ast.expr: - raise NotImplementedError - def map_wildcard(self, expr) -> ast.expr: raise NotImplementedError @@ -468,19 +457,10 @@ def to_evaluatable_python_function(expr: ExpressionT, def foo(*, E, S): return S // 32 + E % 32 """ - import sys from pymbolic.mapper.dependency import CachedDependencyMapper - if sys.version_info < (3, 9): - try: - from astunparse import unparse - except ImportError: - raise RuntimeError("'to_evaluate_python_function' needs" - "astunparse for Py<3.9. Install via `pip" - " install astunparse`") from None - else: - unparse = ast.unparse + unparse = ast.unparse dep_mapper = CachedDependencyMapper(composite_leaves=True) deps = sorted({dep.name for dep in dep_mapper(expr)}) diff --git a/pymbolic/interop/matchpy/__init__.py b/pymbolic/interop/matchpy/__init__.py index 38e8156c..a7a9e912 100644 --- a/pymbolic/interop/matchpy/__init__.py +++ b/pymbolic/interop/matchpy/__init__.py @@ -42,9 +42,10 @@ import abc +from collections.abc import Callable, Iterable, Iterator, Mapping from dataclasses import dataclass, field, fields from functools import partial -from typing import Callable, ClassVar, Generic, Iterable, Iterator, Mapping, TypeVar +from typing import ClassVar, Generic, TypeAlias, TypeVar from matchpy import ( Arity, @@ -54,7 +55,6 @@ ReplacementRule, Wildcard as BaseWildcard, ) -from typing_extensions import TypeAlias import pymbolic.primitives as p from pymbolic.typing import ScalarT @@ -85,7 +85,7 @@ def head(self): def __lt__(self, other): # Used by matchpy internally to order subexpressions - if not isinstance(other, (Expression,)): + if not isinstance(other, Expression): return NotImplemented if type(other) is type(self): if self.value == other.value: diff --git a/pymbolic/interop/matchpy/mapper.py b/pymbolic/interop/matchpy/mapper.py index 1af6aa0a..44f25b6e 100644 --- a/pymbolic/interop/matchpy/mapper.py +++ b/pymbolic/interop/matchpy/mapper.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from pymbolic.interop.matchpy import PymbolicOp diff --git a/pymbolic/interop/matchpy/tofrom.py b/pymbolic/interop/matchpy/tofrom.py index e4ae9d9c..63996881 100644 --- a/pymbolic/interop/matchpy/tofrom.py +++ b/pymbolic/interop/matchpy/tofrom.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import multiset import numpy as np diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 8ad8c92d..2a32cae4 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -24,17 +24,39 @@ """ from abc import ABC, abstractmethod -from typing import Any +from collections.abc import Callable, Hashable, Iterable, Mapping, Set +from typing import ( + TYPE_CHECKING, + Concatenate, + Generic, + TypeAlias, + TypeVar, + cast, +) +from warnings import warn from immutabledict import immutabledict +from typing_extensions import ParamSpec, TypeIs -import pymbolic.primitives as primitives +import pymbolic.primitives as p +from pymbolic.typing import ArithmeticExpressionT, ExpressionT + + +if TYPE_CHECKING: + import numpy as np + + from pymbolic.geometric_algebra import MultiVector + from pymbolic.rational import Rational __doc__ = """ Basic dispatch -------------- +.. class:: ResultT + + A type variable for the result returned by a :class:`Mapper`. + .. autoclass:: Mapper .. automethod:: __call__ @@ -96,14 +118,20 @@ """ -try: - import numpy +if TYPE_CHECKING: + import numpy as np - def is_numpy_array(val): - return isinstance(val, numpy.ndarray) -except ImportError: - def is_numpy_array(ary): - return False + def is_numpy_array(val) -> TypeIs[np.ndarray]: + return isinstance(val, np.ndarray) +else: + try: + import numpy as np + + def is_numpy_array(val): + return isinstance(val, np.ndarray) + except ImportError: + def is_numpy_array(ary): + return False class UnsupportedExpressionError(ValueError): @@ -112,15 +140,28 @@ class UnsupportedExpressionError(ValueError): # {{{ mapper base -class Mapper: +ResultT = TypeVar("ResultT") + +# This ParamSpec could be marked contravariant (just like Callable is contravariant +# in its arguments). As of mypy 1.14/Py3.13 (Nov 2024), mypy complains of as-yet +# undefined semantics, so it's probably too soon. +P = ParamSpec("P") + + +class Mapper(Generic[ResultT, P]): """A visitor for trees of :class:`pymbolic.Expression` subclasses. Each expression-derived object is dispatched to the method named by the :attr:`pymbolic.Expression.mapper_method` attribute and if not found, the methods named by the class attribute *mapper_method* in the method resolution order of the object. + + ..automethod:: handle_unsupported_expression + ..automethod:: __call__ + ..automethod:: rec """ - def handle_unsupported_expression(self, expr, *args, **kwargs): + def handle_unsupported_expression(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Mapper method that is invoked for :class:`pymbolic.Expression` subclasses for which a mapper method does not exist in this mapper. @@ -130,7 +171,8 @@ def handle_unsupported_expression(self, expr, *args, **kwargs): "{} cannot handle expressions of type {}".format( type(self), type(expr))) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Dispatch *expr* to its corresponding mapper method. Pass on ``*args`` and ``**kwargs`` unmodified. @@ -148,7 +190,7 @@ def __call__(self, expr, *args, **kwargs): result = method(expr, *args, **kwargs) return result - if isinstance(expr, primitives.Expression): + if isinstance(expr, p.Expression): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) if method_name: @@ -162,8 +204,9 @@ def __call__(self, expr, *args, **kwargs): rec = __call__ - def rec_fallback(self, expr, *args, **kwargs): - if isinstance(expr, primitives.Expression): + def rec_fallback(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: + if isinstance(expr, p.Expression): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) if method_name: @@ -175,76 +218,188 @@ def rec_fallback(self, expr, *args, **kwargs): else: return self.map_foreign(expr, *args, **kwargs) - def map_algebraic_leaf(self, expr, *args, **kwargs): + def map_algebraic_leaf(self, + expr: p.AlgebraicLeaf, + *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, + expr: p.Variable, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_rational(self, expr, *args, **kwargs): - return self.map_quotient(expr, *args, **kwargs) + def map_if(self, + expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError - def map_quotient(self, expr, *args, **kwargs): + def map_sum(self, + expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_constant(self, expr, *args, **kwargs): + def map_product(self, + expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_list(self, expr, *args, **kwargs): + def map_rational(self, + expr: Rational, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_tuple(self, expr, *args, **kwargs): + def map_quotient(self, + expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_numpy_array(self, expr, *args, **kwargs): + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_nan(self, expr, *args, **kwargs): + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_constant(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_comparison(self, + expr: p.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_min(self, + expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_max(self, + expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_tuple(self, + expr: tuple[ExpressionT, ...], + *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + raise NotImplementedError + + def map_nan(self, + expr: p.NaN, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_foreign(self, expr, *args, **kwargs): + def map_foreign(self, + expr: object, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: """Mapper method dispatch for non-:mod:`pymbolic` objects.""" - if isinstance(expr, primitives.VALID_CONSTANT_CLASSES): + if isinstance(expr, p.VALID_CONSTANT_CLASSES): return self.map_constant(expr, *args, **kwargs) elif is_numpy_array(expr): return self.map_numpy_array(expr, *args, **kwargs) - elif isinstance(expr, list): - return self.map_list(expr, *args, **kwargs) elif isinstance(expr, tuple): return self.map_tuple(expr, *args, **kwargs) + elif isinstance(expr, list): + warn("List found in expression graph. " + "This is deprecated and will stop working in 2025. " + "Use tuples instead.", DeprecationWarning, stacklevel=2 + ) + return self.map_list(expr, *args, **kwargs) else: raise ValueError( "{} encountered invalid foreign object: {}".format( self.__class__, repr(expr))) -_NOT_IN_CACHE = object() +class _NotInCache: + pass -class CachedMapper(Mapper): +CacheKeyT: TypeAlias = Hashable + + +class CachedMapper(Mapper[ResultT, P]): """ A mapper that memoizes the mapped result for the expressions traversed. .. automethod:: get_cache_key """ - def __init__(self): - self._cache: dict[Any, Any] = {} + def __init__(self) -> None: + self._cache: dict[CacheKeyT, ResultT] = {} Mapper.__init__(self) - def get_cache_key(self, expr, *args, **kwargs): + def get_cache_key(self, + expr: ExpressionT, + *args: P.args, + **kwargs: P.kwargs + ) -> CacheKeyT: """ Returns the key corresponding to which the result of a mapper method is stored in the cache. @@ -260,16 +415,23 @@ def get_cache_key(self, expr, *args, **kwargs): # and "4 == 4.0", but their traversal results cannot be re-used. return (type(expr), expr, args, immutabledict(kwargs)) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: result = self._cache.get( (cache_key := self.get_cache_key(expr, *args, **kwargs)), - _NOT_IN_CACHE) - if result is not _NOT_IN_CACHE: + _NotInCache) + if not isinstance(result, type): return result method_name = getattr(expr, "mapper_method", None) if method_name is not None: - method = getattr(self, method_name, None) + method = cast( + Callable[Concatenate[ExpressionT, P], ResultT], + getattr(self, method_name, None) + ) if method is not None: result = method(expr, *args, **kwargs) self._cache[cache_key] = result @@ -284,12 +446,9 @@ def __call__(self, expr, *args, **kwargs): # }}} -RecursiveMapper = Mapper - - # {{{ combine mapper -class CombineMapper(RecursiveMapper): +class CombineMapper(Mapper[ResultT, P]): """A mapper whose goal it is to *combine* all branches of the expression tree into one final result. The default implementation of all mapper methods simply recurse (:meth:`Mapper.rec`) on all branches emanating from @@ -307,16 +466,19 @@ class CombineMapper(RecursiveMapper): :class:`pymbolic.mapper.dependency.DependencyMapper` is another example. """ - def combine(self, values): + def combine(self, values: Iterable[ResultT]) -> ResultT: raise NotImplementedError - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters] )) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters], @@ -324,87 +486,141 @@ def map_call_with_kwargs(self, expr, *args, **kwargs): for child in expr.kw_parameters.values()] )) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine( [self.rec(expr.aggregate, *args, **kwargs), self.rec(expr.index, *args, **kwargs)]) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.rec(expr.aggregate, *args, **kwargs) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, + expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr.children) - map_product = map_sum + def map_product(self, + expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_quotient(self, expr, *args, **kwargs): + def map_quotient(self, + expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.numerator, *args, **kwargs), self.rec(expr.denominator, *args, **kwargs))) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(( + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) + + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(( + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) - def map_power(self, expr, *args, **kwargs): + def map_power(self, + expr: p.Power, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.base, *args, **kwargs), self.rec(expr.exponent, *args, **kwargs))) - def map_polynomial(self, expr, *args, **kwargs): + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( - self.rec(expr.base, *args, **kwargs), - *[self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data] - )) + self.rec(expr.shiftee, *args, **kwargs), + self.rec(expr.shift, *args, **kwargs))) - def map_left_shift(self, expr, *args, **kwargs): + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.shiftee, *args, **kwargs), self.rec(expr.shift, *args, **kwargs))) - map_right_shift = map_left_shift + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.rec(expr.child, *args, **kwargs) - def map_bitwise_not(self, expr, *args, **kwargs): + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.rec(expr.child, *args, **kwargs) - map_bitwise_or = map_sum - map_bitwise_xor = map_sum - map_bitwise_and = map_sum - map_logical_not = map_bitwise_not - map_logical_and = map_sum - map_logical_or = map_sum + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_comparison(self, expr, *args, **kwargs): + def map_comparison(self, + expr: p.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.left, *args, **kwargs), self.rec(expr.right, *args, **kwargs))) - map_max = map_sum - map_min = map_sum + def map_max(self, + expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_list(self, expr, *args, **kwargs): + def map_min(self, + expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr) - map_tuple = map_list + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) for child in expr) - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat) - def map_multivector(self, expr, *args, **kwargs): + def map_multivector(self, + expr: MultiVector[ArithmeticExpressionT], + *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine( self.rec(coeff, *args, **kwargs) for bits, coeff in expr.data.items()) - def map_common_subexpression(self, expr, *args, **kwargs): + def map_common_subexpression(self, + expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.rec(expr.child, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): - return self.combine([ - self.rec(expr.criterion, *args, **kwargs), - self.rec(expr.then, *args, **kwargs), - self.rec(expr.else_, *args, **kwargs)]) - - def map_if(self, expr, *args, **kwargs): + def map_if(self, + expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine([ self.rec(expr.condition, *args, **kwargs), self.rec(expr.then, *args, **kwargs), @@ -419,7 +635,10 @@ class CachedCombineMapper(CachedMapper, CombineMapper): # {{{ collector -class Collector(CombineMapper): +CollectedT = TypeVar("CollectedT") + + +class Collector(CombineMapper[Set[CollectedT], P]): """A subclass of :class:`CombineMapper` for the common purpose of collecting data derived from an expression in a set that gets 'unioned' across children at each non-leaf node in the expression tree. @@ -429,19 +648,36 @@ class Collector(CombineMapper): .. versionadded:: 2014.3 """ - def combine(self, values): + def combine(self, + values: Iterable[Set[CollectedT]] + ) -> Set[CollectedT]: import operator from functools import reduce return reduce(operator.or_, values, set()) - def map_constant(self, expr, *args, **kwargs): + def map_constant(self, expr: object, + *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: return set() - map_variable = map_constant - map_wildcard = map_constant - map_dot_wildcard = map_constant - map_star_wildcard = map_constant - map_function_symbol = map_constant + def map_variable(self, expr: p.Variable, + *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + return set() + + def map_wildcard(self, expr: p.Wildcard, + *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + return set() + + def map_dot_wildcard(self, expr: p.DotWildcard, + *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + return set() + + def map_star_wildcard(self, expr: p.StarWildcard, + *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + return set() + + def map_function_symbol(self, expr: p.FunctionSymbol, + *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + return set() class CachedCollector(CachedMapper, Collector): @@ -452,142 +688,250 @@ class CachedCollector(CachedMapper, Collector): # {{{ identity mapper -class IdentityMapper(Mapper): +class IdentityMapper(Mapper[ExpressionT, P]): """A :class:`Mapper` whose default mapper methods make a deep copy of each subexpression. See :ref:`custom-manipulation` for an example of the manipulations that can be implemented this way. + + .. automethod:: rec_arith """ - def map_constant(self, expr, *args, **kwargs): + + def rec_arith(self, + expr: ArithmeticExpressionT, *args: P.args, **kwargs: P.kwargs + ) -> ArithmeticExpressionT: + res = self.rec(expr, *args, **kwargs) + assert p.is_arithmetic_expression(res) + return res + + def map_constant(self, + expr: object, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: # leaf -- no need to rebuild + assert p.is_valid_operand(expr) return expr - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, + expr: p.Variable, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: # leaf -- no need to rebuild return expr - def map_wildcard(self, expr, *args, **kwargs): + def map_wildcard(self, + expr: p.Wildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_dot_wildcard(self, expr, *args, **kwargs): + def map_dot_wildcard(self, + expr: p.DotWildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_star_wildcard(self, expr, *args, **kwargs): + def map_star_wildcard(self, + expr: p.StarWildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_function_symbol(self, expr, *args, **kwargs): + def map_function_symbol(self, + expr: p.FunctionSymbol, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: p.Call, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters ]) if (function is expr.function and all(child is orig_child - for child, orig_child in zip(expr.parameters, parameters))): + for child, orig_child in zip( + expr.parameters, parameters, strict=True))): return expr return type(expr)(function, parameters) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters ]) - kw_parameters = immutabledict({ + kw_parameters: Mapping[str, ExpressionT] = immutabledict({ key: self.rec(val, *args, **kwargs) for key, val in expr.kw_parameters.items()}) if (function is expr.function and all(child is orig_child for child, orig_child in - zip(parameters, expr.parameters)) + zip(parameters, expr.parameters, strict=True)) and all(kw_parameters[k] is v for k, v in expr.kw_parameters.items())): return expr return type(expr)(function, parameters, kw_parameters) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: aggregate = self.rec(expr.aggregate, *args, **kwargs) index = self.rec(expr.index, *args, **kwargs) if aggregate is expr.aggregate and index is expr.index: return expr return type(expr)(aggregate, index) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: aggregate = self.rec(expr.aggregate, *args, **kwargs) if aggregate is expr.aggregate: return expr return type(expr)(aggregate, expr.name) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, + expr: p.Sum, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) - map_product = map_sum + def map_product(self, + expr: p.Product, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children, strict=True)): + return expr + + return type(expr)(tuple(children)) + + def map_quotient(self, + expr: p.Quotient, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec_arith(expr.numerator, *args, **kwargs) + denominator = self.rec_arith(expr.denominator, *args, **kwargs) + if numerator is expr.numerator and denominator is expr.denominator: + return expr + return expr.__class__(numerator, denominator) - def map_quotient(self, expr, *args, **kwargs): - numerator = self.rec(expr.numerator, *args, **kwargs) - denominator = self.rec(expr.denominator, *args, **kwargs) + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec_arith(expr.numerator, *args, **kwargs) + denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: return expr return expr.__class__(numerator, denominator) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec_arith(expr.numerator, *args, **kwargs) + denominator = self.rec_arith(expr.denominator, *args, **kwargs) + if numerator is expr.numerator and denominator is expr.denominator: + return expr + return expr.__class__(numerator, denominator) - def map_power(self, expr, *args, **kwargs): - base = self.rec(expr.base, *args, **kwargs) - exponent = self.rec(expr.exponent, *args, **kwargs) + def map_power(self, + expr: p.Power, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + base = self.rec_arith(expr.base, *args, **kwargs) + exponent = self.rec_arith(expr.exponent, *args, **kwargs) if base is expr.base and exponent is expr.exponent: return expr return expr.__class__(base, exponent) - def map_polynomial(self, expr, *args, **kwargs): - base = self.rec(expr.base, *args, **kwargs) - data = ((exp, self.rec(coeff, *args, **kwargs)) - for exp, coeff in expr.data) - if base is expr.base and all( - t[1] is orig_t[1] for t, orig_t in zip(data, expr.data)): + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + shiftee = self.rec(expr.shiftee, *args, **kwargs) + shift = self.rec(expr.shift, *args, **kwargs) + if shiftee is expr.shiftee and shift is expr.shift: return expr - return expr.__class__(base, data) + return type(expr)(shiftee, shift) - def map_left_shift(self, expr, *args, **kwargs): + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: shiftee = self.rec(expr.shiftee, *args, **kwargs) shift = self.rec(expr.shift, *args, **kwargs) if shiftee is expr.shiftee and shift is expr.shift: return expr return type(expr)(shiftee, shift) - map_right_shift = map_left_shift + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + child = self.rec(expr.child, *args, **kwargs) + if child is expr.child: + return expr + return type(expr)(child) + + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children, strict=True)): + return expr + + return type(expr)(tuple(children)) + + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children, strict=True)): + return expr - def map_bitwise_not(self, expr, *args, **kwargs): + return type(expr)(tuple(children)) + + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children, strict=True)): + return expr + + return type(expr)(tuple(children)) + + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr return type(expr)(child) - def map_bitwise_or(self, expr, *args, **kwargs): + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) - map_bitwise_xor = map_bitwise_or - map_bitwise_and = map_bitwise_or + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children, strict=True)): + return expr - map_logical_not = map_bitwise_not - map_logical_or = map_bitwise_or - map_logical_and = map_bitwise_or + return type(expr)(tuple(children)) - def map_comparison(self, expr, *args, **kwargs): + def map_comparison(self, + expr: p.Comparison, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: left = self.rec(expr.left, *args, **kwargs) right = self.rec(expr.right, *args, **kwargs) if left is expr.left and right is expr.right: @@ -595,32 +939,45 @@ def map_comparison(self, expr, *args, **kwargs): return type(expr)(left, expr.operator, right) - def map_list(self, expr, *args, **kwargs): - return [self.rec(child, *args, **kwargs) for child in expr] + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + + # True fact: lists aren't expressions + return [self.rec(child, *args, **kwargs) for child in expr] # type: ignore[return-value] - def map_tuple(self, expr, *args, **kwargs): + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr] if all(child is orig_child - for child, orig_child in zip(children, expr)): + for child, orig_child in zip(children, expr, strict=True)): return expr return tuple(children) - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + import numpy result = numpy.empty(expr.shape, dtype=object) for i in numpy.ndindex(expr.shape): result[i] = self.rec(expr[i], *args, **kwargs) - return result - def map_multivector(self, expr, *args, **kwargs): + # True fact: ndarrays aren't expressions + return result # type: ignore[return-value] + + def map_multivector(self, + expr: MultiVector[ArithmeticExpressionT], + *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr.map(lambda ch: self.rec(ch, *args, **kwargs)) - def map_common_subexpression(self, expr, *args, **kwargs): - from pymbolic.primitives import is_zero + def map_common_subexpression(self, + expr: p.CommonSubexpression, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: result = self.rec(expr.child, *args, **kwargs) - if is_zero(result): - return 0 if result is expr.child: return expr @@ -630,45 +987,40 @@ def map_common_subexpression(self, expr, *args, **kwargs): expr.scope, **expr.get_extra_properties()) - def map_substitution(self, expr, *args, **kwargs): + def map_substitution(self, + expr: p.Substitution, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: child = self.rec(expr.child, *args, **kwargs) values = tuple([self.rec(v, *args, **kwargs) for v in expr.values]) if child is expr.child and all(val is orig_val - for val, orig_val in zip(values, expr.values)): + for val, orig_val in zip(values, expr.values, strict=True)): return expr return type(expr)(child, expr.variables, values) - def map_derivative(self, expr, *args, **kwargs): + def map_derivative(self, + expr: p.Derivative, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr return type(expr)(child, expr.variables) - def map_slice(self, expr, *args, **kwargs): - children = tuple([ + def map_slice(self, + expr: p.Slice, + *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + children: p.SliceChildrenT = cast(p.SliceChildrenT, tuple([ None if child is None else self.rec(child, *args, **kwargs) for child in expr.children - ]) + ])) if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(children) - def map_if_positive(self, expr, *args, **kwargs): - criterion = self.rec(expr.criterion, *args, **kwargs) - then = self.rec(expr.then, *args, **kwargs) - else_ = self.rec(expr.else_, *args, **kwargs) - if criterion is expr.criterion \ - and then is expr.then \ - and else_ is expr.else_: - return expr - - return type(expr)(criterion, then, else_) - - def map_if(self, expr, *args, **kwargs): + def map_if(self, expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: condition = self.rec(expr.condition, *args, **kwargs) then = self.rec(expr.then, *args, **kwargs) else_ = self.rec(expr.else_, *args, **kwargs) @@ -679,24 +1031,32 @@ def map_if(self, expr, *args, **kwargs): return type(expr)(condition, then, else_) - def map_min(self, expr, *args, **kwargs): + def map_min(self, expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: children = tuple([ self.rec(child, *args, **kwargs) for child in expr.children ]) if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(children) - map_max = map_min + def map_max(self, expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + children = tuple([ + self.rec(child, *args, **kwargs) for child in expr.children + ]) + if all(child is orig_child + for child, orig_child in zip(children, expr.children, strict=True)): + return expr + + return type(expr)(children) - def map_nan(self, expr, *args, **kwargs): + def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: # Leaf node -- don't recurse return expr -class CachedIdentityMapper(CachedMapper, IdentityMapper): +class CachedIdentityMapper(CachedMapper[ExpressionT, P], IdentityMapper[P]): pass # }}} @@ -704,7 +1064,7 @@ class CachedIdentityMapper(CachedMapper, IdentityMapper): # {{{ walk mapper -class WalkMapper(RecursiveMapper): +class WalkMapper(Mapper[None, P]): """A mapper whose default mapper method implementations simply recurse without propagating any result. Also calls :meth:`visit` for each visited subexpression. @@ -721,21 +1081,39 @@ class WalkMapper(RecursiveMapper): Is called after a node's children are visited. """ - def map_constant(self, expr, *args, **kwargs): + def map_constant(self, expr: object, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_variable(self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_wildcard(self, expr: p.Wildcard, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_dot_wildcard(self, + expr: p.DotWildcard, *args: P.args, **kwargs: P.kwargs) -> None: self.visit(expr, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - def map_variable(self, expr, *args, **kwargs): + def map_star_wildcard(self, + expr: p.StarWildcard, *args: P.args, **kwargs: P.kwargs) -> None: self.visit(expr, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - map_wildcard = map_variable - map_dot_wildcard = map_variable - map_star_wildcard = map_variable - map_function_symbol = map_variable - map_nan = map_variable + def map_function_symbol(self, + expr: p.FunctionSymbol, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) - def map_call(self, expr, *args, **kwargs): + def map_nan(self, + expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> None: + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_call(self, expr: p.Call, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -745,7 +1123,9 @@ def map_call(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: p.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -758,7 +1138,9 @@ def map_call_with_kwargs(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: p.Subscript, + *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -767,7 +1149,8 @@ def map_subscript(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: p.Lookup, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -775,7 +1158,7 @@ def map_lookup(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, expr: p.Sum, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -784,9 +1167,16 @@ def map_sum(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - map_product = map_sum + def map_product(self, expr: p.Product, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) - def map_quotient(self, expr, *args, **kwargs): + def map_quotient(self, expr: p.Quotient, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -795,29 +1185,37 @@ def map_quotient(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_floor_div(self, + expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.numerator, *args, **kwargs) + self.rec(expr.denominator, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) - def map_power(self, expr, *args, **kwargs): + def map_remainder(self, + expr: p.Remainder, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return - self.rec(expr.base, *args, **kwargs) - self.rec(expr.exponent, *args, **kwargs) + self.rec(expr.numerator, *args, **kwargs) + self.rec(expr.denominator, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - def map_polynomial(self, expr, *args, **kwargs): + def map_power(self, expr: p.Power, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return self.rec(expr.base, *args, **kwargs) - for _exp, coeff in expr.data: - self.rec(coeff, *args, **kwargs) + self.rec(expr.exponent, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - def map_list(self, expr, *args, **kwargs): + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -826,9 +1224,8 @@ def map_list(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - map_tuple = map_list - - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -838,16 +1235,19 @@ def map_numpy_array(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_multivector(self, expr, *args, **kwargs): + def map_multivector(self, + expr: MultiVector[ArithmeticExpressionT], + *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return for _bits, coeff in expr.data.items(): - self.rec(coeff) + self.rec(coeff, *args, **kwargs) self.post_visit(expr, *args, **kwargs) - def map_common_subexpression(self, expr, *args, **kwargs): + def map_common_subexpression(self, + expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -855,7 +1255,8 @@ def map_common_subexpression(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_left_shift(self, expr, *args, **kwargs): + def map_left_shift(self, + expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -864,9 +1265,18 @@ def map_left_shift(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - map_right_shift = map_left_shift + def map_right_shift(self, + expr: p.RightShift, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.shift, *args, **kwargs) + self.rec(expr.shiftee, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) - def map_bitwise_not(self, expr, *args, **kwargs): + def map_bitwise_not(self, + expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -874,11 +1284,37 @@ def map_bitwise_not(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - map_bitwise_or = map_sum - map_bitwise_xor = map_sum - map_bitwise_and = map_sum + def map_bitwise_or(self, + expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_bitwise_xor(self, + expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_bitwise_and(self, + expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return - def map_comparison(self, expr, *args, **kwargs): + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_comparison(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -887,11 +1323,36 @@ def map_comparison(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - map_logical_not = map_bitwise_not - map_logical_and = map_sum - map_logical_or = map_sum + def map_logical_not(self, + expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_logical_or(self, + expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_logical_and(self, + expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) - def map_if(self, expr, *args, **kwargs): + def map_if(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -901,7 +1362,7 @@ def map_if(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): + def map_if_positive(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -911,11 +1372,28 @@ def map_if_positive(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - map_min = map_sum - map_max = map_sum + def map_min(self, + expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_max(self, + expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) - def map_substitution(self, expr, *args, **kwargs): - if not self.visit(expr): + self.post_visit(expr, *args, **kwargs) + + def map_substitution(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): return self.rec(expr.child, *args, **kwargs) @@ -924,7 +1402,7 @@ def map_substitution(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_derivative(self, expr, *args, **kwargs): + def map_derivative(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -932,7 +1410,7 @@ def map_derivative(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def map_slice(self, expr, *args, **kwargs): + def map_slice(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -945,10 +1423,10 @@ def map_slice(self, expr, *args, **kwargs): self.post_visit(expr, *args, **kwargs) - def visit(self, expr, *args, **kwargs): + def visit(self, expr, *args: P.args, **kwargs: P.kwargs) -> bool: return True - def post_visit(self, expr, *args, **kwargs): + def post_visit(self, expr, *args: P.args, **kwargs: P.kwargs) -> None: pass @@ -960,7 +1438,8 @@ class CachedWalkMapper(CachedMapper, WalkMapper): # {{{ callback mapper -class CallbackMapper(RecursiveMapper): +# FIXME: Is it worth typing this? +class CallbackMapper(Mapper): def __init__(self, function, fallback_mapper): self.function = function self.fallback_mapper = fallback_mapper @@ -993,7 +1472,6 @@ def map_constant(self, expr, *args, **kwargs): map_logical_or = map_constant map_logical_and = map_constant - map_polynomial = map_constant map_list = map_constant map_tuple = map_constant map_numpy_array = map_constant @@ -1007,44 +1485,7 @@ def map_constant(self, expr, *args, **kwargs): # {{{ caching mixins -class CachingMapperMixin: - def __init__(self): - super().__init__() - self.result_cache = {} - - from warnings import warn - warn("CachingMapperMixin is deprecated and will be removed " - "in version 2023.x. Use CachedMapper instead.", - DeprecationWarning, stacklevel=2) - - def rec(self, expr): - try: - return self.result_cache[expr] - except TypeError: - # not hashable, oh well - method_name = getattr(expr, "mapper_method", None) - if method_name is not None: - method = getattr(self, method_name, None) - if method is not None: - return method(expr, ) - return super().rec(expr) - except KeyError: - method_name = getattr(expr, "mapper_method", None) - if method_name is not None: - method = getattr(self, method_name, None) - if method is not None: - result = method(expr, ) - self.result_cache[expr] = result - return result - - result = super().rec(expr) - self.result_cache[expr] = result - return result - - __call__ = rec - - -class CSECachingMapperMixin(ABC): +class CSECachingMapperMixin(ABC, Generic[ResultT, P]): """A :term:`mix-in` that helps subclassed mappers implement caching for :class:`pymbolic.primitives.CommonSubexpression` @@ -1059,25 +1500,40 @@ class CSECachingMapperMixin(ABC): This method deliberately does not support extra arguments in mapper dispatch, to avoid spurious dependencies of the cache on these arguments. """ + _cse_cache_dict: dict[tuple[ExpressionT, P.args, P.kwargs], ResultT] - def map_common_subexpression(self, expr, *args): + def map_common_subexpression(self, + expr: p.CommonSubexpression, + *args: P.args, **kwargs: P.kwargs) -> ResultT: try: ccd = self._cse_cache_dict except AttributeError: ccd = self._cse_cache_dict = {} - key = (expr, *args) + key: tuple[ExpressionT, P.args, P.kwargs] = (expr, args, immutabledict(kwargs)) try: return ccd[key] except KeyError: - result = self.map_common_subexpression_uncached(expr, *args) + result = self.map_common_subexpression_uncached(expr, *args, **kwargs) ccd[key] = result return result @abstractmethod - def map_common_subexpression_uncached(self, expr, *args): + def map_common_subexpression_uncached(self, + expr: p.CommonSubexpression, + *args: P.args, **kwargs: P.kwargs) -> ResultT: pass # }}} + +def __getattr__(name: str) -> object: + if name == "RecursiveMapper": + warn("RecursiveMapper is deprecated. Use Mapper instead. " + "RecursiveMapper will go away in 2026.", + DeprecationWarning, stacklevel=2) + return Mapper + + return None + # vim: foldmethod=marker diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py index f3c6dec7..6a7feb55 100644 --- a/pymbolic/mapper/coefficient.py +++ b/pymbolic/mapper/coefficient.py @@ -23,17 +23,25 @@ THE SOFTWARE. """ +from collections.abc import Collection, Mapping +from typing import Literal, TypeAlias, cast + +import pymbolic.primitives as p from pymbolic.mapper import Mapper +from pymbolic.typing import ArithmeticExpressionT + + +CoeffsT: TypeAlias = Mapping[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT] -class CoefficientCollector(Mapper): - def __init__(self, target_names=None): +class CoefficientCollector(Mapper[CoeffsT, []]): + def __init__(self, target_names: Collection[str] | None = None) -> None: self.target_names = target_names - def map_sum(self, expr): + def map_sum(self, expr: p.Sum) -> CoeffsT: stride_dicts = [self.rec(ch) for ch in expr.children] - result = {} + result: dict[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT] = {} for stride_dict in stride_dicts: for var, stride in stride_dict.items(): if var in result: @@ -43,9 +51,7 @@ def map_sum(self, expr): return result - def map_product(self, expr): - result = {} - + def map_product(self, expr: p.Product) -> CoeffsT: children_coeffs = [self.rec(child) for child in expr.children] idx_of_child_with_vars = None @@ -58,35 +64,33 @@ def map_product(self, expr): "nonlinear expression") idx_of_child_with_vars = i - other_coeffs = 1 + other_coeffs: ArithmeticExpressionT = 1 for i, child_coeffs in enumerate(children_coeffs): if i != idx_of_child_with_vars: assert len(child_coeffs) == 1 - other_coeffs *= child_coeffs[1] + other_coeffs *= cast(ArithmeticExpressionT, child_coeffs[1]) if idx_of_child_with_vars is None: return {1: other_coeffs} else: return { - var: other_coeffs*coeff + var: p.flattened_product((other_coeffs, coeff)) for var, coeff in children_coeffs[idx_of_child_with_vars].items()} - return result - - def map_quotient(self, expr): + def map_quotient(self, expr: p.Quotient) -> CoeffsT: from pymbolic.primitives import Quotient - d_num = self.rec(expr.numerator) + d_num = dict(self.rec(expr.numerator)) d_den = self.rec(expr.denominator) # d_den should look like {1: k} if len(d_den) > 1 or 1 not in d_den: raise RuntimeError("nonlinear expression") val = d_den[1] for k in d_num.keys(): - d_num[k] *= Quotient(1, val) + d_num[k] = p.flattened_product((d_num[k], Quotient(1, val))) return d_num - def map_power(self, expr): + def map_power(self, expr: p.Power) -> CoeffsT: d_base = self.rec(expr.base) d_exponent = self.rec(expr.exponent) # d_exponent should look like {1: k} @@ -97,11 +101,19 @@ def map_power(self, expr): raise RuntimeError("nonlinear expression") return {1: expr} - def map_constant(self, expr): - return {1: expr} + def map_constant(self, expr: object) -> CoeffsT: + assert p.is_arithmetic_expression(expr) + from pymbolic.primitives import is_zero + return {} if is_zero(expr) else {1: expr} - def map_algebraic_leaf(self, expr): + def map_variable(self, expr: p.Variable) -> CoeffsT: if self.target_names is None or expr.name in self.target_names: return {expr: 1} else: return {1: expr} + + def map_algebraic_leaf(self, expr: p.AlgebraicLeaf) -> CoeffsT: + if self.target_names is None: + return {expr: 1} + else: + return {1: expr} diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py index 2799cba3..5b211948 100644 --- a/pymbolic/mapper/collector.py +++ b/pymbolic/mapper/collector.py @@ -26,11 +26,17 @@ THE SOFTWARE. """ +from collections.abc import Sequence, Set +from typing import cast + import pymbolic +import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper +from pymbolic.mapper.dependency import DependenciesT +from pymbolic.typing import ArithmeticExpressionT, ExpressionT -class TermCollector(IdentityMapper): +class TermCollector(IdentityMapper[[]]): """A term collector that assumes that multiplication is commutative. Allows specifying *parameters* (a set of @@ -38,16 +44,19 @@ class TermCollector(IdentityMapper): coefficients and are not used for term collection. """ - def __init__(self, parameters=None): + def __init__(self, parameters: Set[p.AlgebraicLeaf] | None = None): if parameters is None: parameters = set() self.parameters = parameters - def get_dependencies(self, expr): + def get_dependencies(self, expr: ExpressionT) -> DependenciesT: from pymbolic.mapper.dependency import DependencyMapper return DependencyMapper()(expr) - def split_term(self, mul_term): + def split_term(self, mul_term: ExpressionT) -> tuple[ + Set[tuple[ArithmeticExpressionT, ArithmeticExpressionT]], + ArithmeticExpressionT + ]: """Returns a pair consisting of: - a frozenset of (base, exponent) pairs - a product of coefficients (i.e. constants and parameters) @@ -58,28 +67,29 @@ def split_term(self, mul_term): """ from pymbolic.primitives import AlgebraicLeaf, Power, Product - def base(term): + def base(term: ExpressionT) -> ArithmeticExpressionT: if isinstance(term, Power): return term.base else: + assert p.is_arithmetic_expression(term) return term - def exponent(term): + def exponent(term: ExpressionT) -> ArithmeticExpressionT: if isinstance(term, Power): return term.exponent else: return 1 if isinstance(mul_term, Product): - terms = mul_term.children - elif isinstance(mul_term, (Power, AlgebraicLeaf)): + terms: Sequence[ExpressionT] = mul_term.children + elif isinstance(mul_term, Power | AlgebraicLeaf): terms = [mul_term] elif not bool(self.get_dependencies(mul_term)): terms = [mul_term] else: raise RuntimeError("split_term expects a multiplicative term") - base2exp = {} + base2exp: dict[ArithmeticExpressionT, ArithmeticExpressionT] = {} for term in terms: mybase = base(term) myexp = exponent(term) @@ -91,20 +101,23 @@ def exponent(term): coefficients = [] cleaned_base2exp = {} - for base, exp in base2exp.items(): - term = base**exp + for item_base, item_exp in base2exp.items(): + term = item_base**item_exp if self.get_dependencies(term) <= self.parameters: coefficients.append(term) else: - cleaned_base2exp[base] = exp + cleaned_base2exp[item_base] = item_exp - term = frozenset( + base_exp_set = frozenset( (base, exp) for base, exp in cleaned_base2exp.items()) - return term, self.rec(pymbolic.flattened_product(coefficients)) - - def map_sum(self, mysum): - term2coeff = {} - for child in mysum.children: + return base_exp_set, cast(ArithmeticExpressionT, + self.rec(pymbolic.flattened_product(coefficients))) + + def map_sum(self, expr: p.Sum) -> ExpressionT: + term2coeff: dict[ + Set[tuple[ArithmeticExpressionT, ArithmeticExpressionT]], + ArithmeticExpressionT] = {} + for child in expr.children: term, coeff = self.split_term(child) term2coeff[term] = term2coeff.get(term, 0) + coeff diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py index 62b3eddf..e68971ab 100644 --- a/pymbolic/mapper/constant_folder.py +++ b/pymbolic/mapper/constant_folder.py @@ -27,13 +27,18 @@ THE SOFTWARE. """ +from collections.abc import Callable + from pymbolic.mapper import ( CSECachingMapperMixin, IdentityMapper, + Mapper, ) +from pymbolic.primitives import Product, Sum, is_arithmetic_expression +from pymbolic.typing import ArithmeticExpressionT, ExpressionT -class ConstantFoldingMapperBase: +class ConstantFoldingMapperBase(Mapper[ExpressionT, []]): def is_constant(self, expr): from pymbolic.mapper.dependency import DependencyMapper return not bool(DependencyMapper()(expr)) @@ -45,15 +50,27 @@ def evaluate(self, expr): except ValueError: return None - def fold(self, expr, klass, op, constructor): + def fold(self, + expr: Sum | Product, + op: Callable[ + [ArithmeticExpressionT, ArithmeticExpressionT], + ArithmeticExpressionT], + constructor: Callable[ + [tuple[ArithmeticExpressionT, ...]], + ArithmeticExpressionT], + ) -> ExpressionT: + klass = type(expr) - constants = [] - nonconstants = [] + constants: list[ArithmeticExpressionT] = [] + nonconstants: list[ArithmeticExpressionT] = [] queue = list(expr.children) while queue: - child = self.rec(queue.pop(0)) # pylint:disable=no-member + child = self.rec(queue.pop(0)) + assert is_arithmetic_expression(child) + if isinstance(child, klass): + assert isinstance(child, Sum | Product) queue = list(child.children) + queue else: if self.is_constant(child): @@ -73,37 +90,36 @@ def fold(self, expr, klass, op, constructor): else: return constructor(tuple(nonconstants)) - def map_sum(self, expr): + def map_sum(self, expr: Sum) -> ExpressionT: import operator - from pymbolic.primitives import Sum, flattened_sum + from pymbolic.primitives import flattened_sum - return self.fold(expr, Sum, operator.add, flattened_sum) + return self.fold(expr, operator.add, flattened_sum) class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): def map_product(self, expr): import operator - from pymbolic.primitives import Product, flattened_product + from pymbolic.primitives import flattened_product - return self.fold(expr, Product, operator.mul, flattened_product) + return self.fold(expr, operator.mul, flattened_product) class ConstantFoldingMapper( - CSECachingMapperMixin, + CSECachingMapperMixin[ExpressionT, []], ConstantFoldingMapperBase, - IdentityMapper): + IdentityMapper[[]]): map_common_subexpression_uncached = \ IdentityMapper.map_common_subexpression -# Yes, map_product incompatible: missing *args, **kwargs -class CommutativeConstantFoldingMapper( # type: ignore[misc] - CSECachingMapperMixin, +class CommutativeConstantFoldingMapper( + CSECachingMapperMixin[ExpressionT, []], CommutativeConstantFoldingMapperBase, - IdentityMapper): + IdentityMapper[[]]): map_common_subexpression_uncached = \ IdentityMapper.map_common_subexpression diff --git a/pymbolic/mapper/cse_tagger.py b/pymbolic/mapper/cse_tagger.py index 7d46f797..7734dfb0 100644 --- a/pymbolic/mapper/cse_tagger.py +++ b/pymbolic/mapper/cse_tagger.py @@ -54,7 +54,6 @@ def map_call(self, expr): map_floor_div = map_call map_remainder = map_call map_power = map_call - map_polynomial = map_call map_left_shift = map_call map_right_shift = map_call diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index b4473e3a..89b91db2 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -2,6 +2,7 @@ .. autoclass:: DependencyMapper .. autoclass:: CachedDependencyMapper """ + from __future__ import annotations @@ -27,10 +28,20 @@ THE SOFTWARE. """ -from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin +from collections.abc import Set +from typing import TypeAlias + +import pymbolic.primitives as p +from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin, P + + +DependenciesT: TypeAlias = Set[p.AlgebraicLeaf | p.CommonSubexpression] -class DependencyMapper(CSECachingMapperMixin, Collector): +class DependencyMapper( + CSECachingMapperMixin[DependenciesT, P], + Collector[p.AlgebraicLeaf | p.CommonSubexpression, P], +): """Maps an expression to the :class:`set` of expressions it is based on. The ``include_*`` arguments to the constructor determine which types of objects occur in this output set. @@ -38,12 +49,14 @@ class DependencyMapper(CSECachingMapperMixin, Collector): instances are included. """ - def __init__(self, - include_subscripts=True, - include_lookups=True, - include_calls=True, - include_cses=False, - composite_leaves=None): + def __init__( + self, + include_subscripts: bool = True, + include_lookups: bool = True, + include_calls: bool = True, + include_cses: bool = False, + composite_leaves: bool | None = None, + ): """ :arg composite_leaves: Setting this is equivalent to setting all preceding ``include_*`` flags. @@ -66,68 +79,92 @@ def __init__(self, self.include_cses = include_cses - def map_variable(self, expr, *args, **kwargs): + def map_variable( + self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: return {expr} - def map_call(self, expr, *args, **kwargs): + def map_call( + self, expr: p.Call, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_calls == "descend_args": - return self.combine( - [self.rec(child, *args, **kwargs) for child in expr.parameters]) + return self.combine([ + self.rec(child, *args, **kwargs) for child in expr.parameters + ]) elif self.include_calls: return {expr} else: return super().map_call(expr, *args, **kwargs) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs( + self, expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_calls == "descend_args": return self.combine( - [self.rec(child, *args, **kwargs) for child in expr.parameters] - + [self.rec(val, *args, **kwargs) for name, val in - expr.kw_parameters.items()] - ) + [self.rec(child, *args, **kwargs) for child in expr.parameters] + + [ + self.rec(val, *args, **kwargs) + for name, val in expr.kw_parameters.items() + ] + ) elif self.include_calls: return {expr} else: return super().map_call_with_kwargs(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup( + self, expr: p.Lookup, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_lookups: return {expr} else: return super().map_lookup(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript( + self, expr: p.Subscript, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_subscripts: return {expr} else: return super().map_subscript(expr, *args, **kwargs) - def map_common_subexpression_uncached(self, expr, *args, **kwargs): + def map_common_subexpression_uncached( + self, expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: if self.include_cses: return {expr} else: - return Collector.map_common_subexpression(self, expr, *args, **kwargs) - - def map_slice(self, expr, *args, **kwargs): - return self.combine( - [self.rec(child, *args, **kwargs) for child in expr.children - if child is not None]) - - def map_nan(self, expr, *args, **kwargs): + # FIXME: These look like mypy bugs, revisit + return Collector.map_common_subexpression(self, expr, *args, **kwargs) # type: ignore[return-value, arg-type] + + def map_slice( + self, expr: p.Slice, *args: P.args, **kwargs: P.kwargs + ) -> DependenciesT: + return self.combine([ + self.rec(child, *args, **kwargs) + for child in expr.children + if child is not None + ]) + + def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> DependenciesT: return set() class CachedDependencyMapper(CachedMapper, DependencyMapper): - def __init__(self, - include_subscripts=True, - include_lookups=True, - include_calls=True, - include_cses=False, - composite_leaves=None): + def __init__( + self, + include_subscripts=True, + include_lookups=True, + include_calls=True, + include_cses=False, + composite_leaves=None, + ): CachedMapper.__init__(self) - DependencyMapper.__init__(self, - include_subscripts=include_subscripts, - include_lookups=include_lookups, - include_calls=include_calls, - include_cses=include_cses, - composite_leaves=composite_leaves) + DependencyMapper.__init__( + self, + include_subscripts=include_subscripts, + include_lookups=include_lookups, + include_calls=include_calls, + include_cses=include_cses, + composite_leaves=composite_leaves, + ) diff --git a/pymbolic/mapper/differentiator.py b/pymbolic/mapper/differentiator.py index 30b7f23b..c5ec0599 100644 --- a/pymbolic/mapper/differentiator.py +++ b/pymbolic/mapper/differentiator.py @@ -73,7 +73,7 @@ def make_f(name): raise RuntimeError("unrecognized function, cannot differentiate") -class DifferentiationMapper(pymbolic.mapper.RecursiveMapper, +class DifferentiationMapper(pymbolic.mapper.Mapper, pymbolic.mapper.CSECachingMapperMixin): """Example usage: @@ -84,8 +84,9 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper, >>> x = p.Variable("x") >>> expr = x*(x+5)**3/(x-1)**2 + >>> from pymbolic import flatten >>> from pymbolic.mapper.differentiator import DifferentiationMapper as DM - >>> print(DM(x)(expr)) + >>> print(flatten(DM(x)(expr))) (((x + 5)**3 + x*3*(x + 5)**2)*(x + -1)**2 + (-1)*2*(x + -1)*x*(x + 5)**3) / (x + -1)**2**2 """ # noqa: E501 @@ -193,28 +194,6 @@ def map_power(self, expr, *args): return log(f) * f**g * dg + \ g * f**(g-1) * df - def map_polynomial(self, expr, *args): - # (a(x)*f(x))^n)' = a'(x)f(x)^n + a(x)f'(x)*n*f(x)^(n-1) - deriv_coeff = [] - deriv_base = [] - - dbase = self.rec(expr.base, *args) - - for exp, coeff in expr.data: - dcoeff = self.rec(coeff, *args) - if dcoeff: - deriv_coeff.append((exp, dcoeff)) - if dbase and exp > 0: - deriv_base.append((exp-1, exp*dbase*self.rec_undiff(coeff, *args))) - - from pymbolic import Polynomial - - return ( - Polynomial(self.rec_undiff(expr.base, *args), - tuple(deriv_coeff), expr.unit) - + Polynomial(self.rec_undiff(expr.base, *args), - tuple(deriv_base), expr.unit)) - def map_numpy_array(self, expr, *args): import numpy result = numpy.empty(expr.shape, dtype=object) @@ -243,8 +222,9 @@ def differentiate(expression, variable, func_mapper=map_math_functions_by_name, allowed_nonsmoothness="none"): - if not isinstance(variable, (primitives.Variable, primitives.Subscript)): + if not isinstance(variable, primitives.Variable | primitives.Subscript): variable = primitives.make_variable(variable) - return DifferentiationMapper( + from pymbolic import flatten + return flatten(DifferentiationMapper( variable, func_mapper, allowed_nonsmoothness=allowed_nonsmoothness - )(expression) + )(expression)) diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py index faa523e6..85d13504 100644 --- a/pymbolic/mapper/distributor.py +++ b/pymbolic/mapper/distributor.py @@ -27,14 +27,17 @@ THE SOFTWARE. """ +from typing import cast + import pymbolic +import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper from pymbolic.mapper.collector import TermCollector from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper -from pymbolic.primitives import Product, Sum, is_zero +from pymbolic.typing import ArithmeticExpressionT, ExpressionT -class DistributeMapper(IdentityMapper): +class DistributeMapper(IdentityMapper[[]]): """Example usage: .. doctest:: @@ -47,7 +50,7 @@ class DistributeMapper(IdentityMapper): 7*x**6 + 21*x**5 + 21*x**2 + 35*x**3 + 1 + 35*x**4 + 7*x + x**7 """ - def __init__(self, collector=None, const_folder=None): + def __init__(self, collector=None, const_folder=None) -> None: if collector is None: collector = TermCollector() if const_folder is None: @@ -61,19 +64,19 @@ def collect(self, expr): def map_sum(self, expr): res = IdentityMapper.map_sum(self, expr) - if isinstance(res, Sum): + if isinstance(res, p.Sum): return self.collect(res) else: return res - def map_product(self, expr): + def map_product(self, expr: p.Product) -> ExpressionT: def dist(prod): - if not isinstance(prod, Product): + if not isinstance(prod, p.Product): return prod leading = [] for i in prod.children: - if isinstance(i, Sum): + if isinstance(i, p.Sum): break else: leading.append(i) @@ -84,10 +87,10 @@ def dist(prod): return result else: sum = prod.children[len(leading)] - assert isinstance(sum, Sum) + assert isinstance(sum, p.Sum) rest = prod.children[len(leading)+1:] if rest: - rest = dist(Product(rest)) + rest = dist(p.Product(rest)) else: rest = 1 @@ -100,7 +103,7 @@ def dist(prod): return dist(IdentityMapper.map_product(self, expr)) def map_quotient(self, expr): - if is_zero(expr.numerator - 1): + if p.is_zero(expr.numerator - 1): return expr else: # not the smartest thing we can do, but at least *something* @@ -109,18 +112,19 @@ def map_quotient(self, expr): self.rec(expr.numerator) ]) - def map_power(self, expr): + def map_power(self, expr: p.Power) -> ExpressionT: from pymbolic.primitives import Sum newbase = self.rec(expr.base) - if isinstance(expr.base, Product): + if isinstance(newbase, p.Product): return self.rec(pymbolic.flattened_product([ - child**expr.exponent for child in newbase + cast(ArithmeticExpressionT, child)**expr.exponent + for child in newbase.children ])) if isinstance(expr.exponent, int): if isinstance(newbase, Sum): - return self.map_product( + return self.rec( pymbolic.flattened_product( expr.exponent*(newbase,))) else: @@ -129,7 +133,7 @@ def map_power(self, expr): return IdentityMapper.map_power(self, expr) -def distribute(expr, parameters=None, commutative=True): +def distribute(expr: ExpressionT, parameters=None, commutative=True) -> ExpressionT: if parameters is None: parameters = frozenset() if commutative: diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index d95a9676..67b7334f 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -33,19 +33,27 @@ THE SOFTWARE. """ - import operator as op +from collections.abc import Mapping from functools import reduce -from typing import Any +from typing import TYPE_CHECKING, Any + +import pymbolic.primitives as p +from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper +from pymbolic.typing import ExpressionT + -from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, RecursiveMapper +if TYPE_CHECKING: + import numpy as np + + from pymbolic.geometric_algebra import MultiVector class UnknownVariableError(Exception): pass -class EvaluationMapper(RecursiveMapper, CSECachingMapperMixin): +class EvaluationMapper(Mapper[Any, []], CSECachingMapperMixin): """Example usage: .. doctest:: @@ -62,7 +70,9 @@ class EvaluationMapper(RecursiveMapper, CSECachingMapperMixin): 110 """ - def __init__(self, context=None): + context: Mapping[str, Any] + + def __init__(self, context: Mapping[str, Any] | None = None) -> None: """ :arg context: a mapping from variable names to values """ @@ -70,21 +80,20 @@ def __init__(self, context=None): context = {} self.context = context - self.common_subexp_cache = {} - def map_constant(self, expr): + def map_constant(self, expr: object) -> Any: return expr - def map_variable(self, expr): + def map_variable(self, expr: p.Variable) -> None: try: return self.context[expr.name] except KeyError: raise UnknownVariableError(expr.name) from None - def map_call(self, expr): + def map_call(self, expr: p.Call) -> Any: return self.rec(expr.function)(*[self.rec(par) for par in expr.parameters]) - def map_call_with_kwargs(self, expr): + def map_call_with_kwargs(self, expr: p.CallWithKwargs) -> Any: args = [self.rec(par) for par in expr.parameters] kwargs = { k: self.rec(v) @@ -92,124 +101,97 @@ def map_call_with_kwargs(self, expr): return self.rec(expr.function)(*args, **kwargs) - def map_subscript(self, expr): - rec_result = self.rec(expr.aggregate) + def map_subscript(self, expr: p.Subscript) -> Any: + return self.rec(expr.aggregate)[self.rec(expr.index)] - from pymbolic.primitives import Expression - if isinstance(rec_result, Expression): - return rec_result.index(self.rec(expr.index)) - else: - return rec_result[self.rec(expr.index)] - - def map_lookup(self, expr): + def map_lookup(self, expr: p.Lookup) -> Any: return getattr(self.rec(expr.aggregate), expr.name) - def map_sum(self, expr): + def map_sum(self, expr: p.Sum) -> Any: return sum(self.rec(child) for child in expr.children) - def map_product(self, expr): + def map_product(self, expr: p.Product) -> Any: from pytools import product return product(self.rec(child) for child in expr.children) - def map_quotient(self, expr): + def map_quotient(self, expr: p.Quotient) -> Any: return self.rec(expr.numerator) / self.rec(expr.denominator) - def map_floor_div(self, expr): + def map_floor_div(self, expr: p.FloorDiv) -> Any: return self.rec(expr.numerator) // self.rec(expr.denominator) - def map_remainder(self, expr): + def map_remainder(self, expr: p.Remainder) -> Any: return self.rec(expr.numerator) % self.rec(expr.denominator) - def map_power(self, expr): + def map_power(self, expr: p.Power) -> Any: return self.rec(expr.base) ** self.rec(expr.exponent) - def map_left_shift(self, expr): + def map_left_shift(self, expr: p.LeftShift) -> Any: return self.rec(expr.shiftee) << self.rec(expr.shift) - def map_right_shift(self, expr): + def map_right_shift(self, expr: p.RightShift) -> Any: return self.rec(expr.shiftee) >> self.rec(expr.shift) - def map_bitwise_not(self, expr): + def map_bitwise_not(self, expr: p.BitwiseNot) -> Any: # ??? Why, pylint, why ??? # pylint: disable=invalid-unary-operand-type return ~self.rec(expr.child) - def map_bitwise_or(self, expr): + def map_bitwise_or(self, expr: p.BitwiseOr) -> Any: return reduce(op.or_, (self.rec(ch) for ch in expr.children)) - def map_bitwise_xor(self, expr): + def map_bitwise_xor(self, expr: p.BitwiseXor) -> Any: return reduce(op.xor, (self.rec(ch) for ch in expr.children)) - def map_bitwise_and(self, expr): + def map_bitwise_and(self, expr: p.BitwiseAnd) -> Any: return reduce(op.and_, (self.rec(ch) for ch in expr.children)) - def map_logical_not(self, expr): + def map_logical_not(self, expr: p.LogicalNot) -> Any: return not self.rec(expr.child) - def map_logical_or(self, expr): + def map_logical_or(self, expr: p.LogicalOr) -> Any: return any(self.rec(ch) for ch in expr.children) - def map_logical_and(self, expr): + def map_logical_and(self, expr: p.LogicalAnd) -> Any: return all(self.rec(ch) for ch in expr.children) - def map_polynomial(self, expr): - # evaluate using Horner's scheme - result = 0 - rev_data = expr.data[::-1] - ev_base = self.rec(expr.base) - - for i, (exp, coeff) in enumerate(rev_data): - if i+1 < len(rev_data): - next_exp = rev_data[i+1][0] - else: - next_exp = 0 - result = (result+coeff)*ev_base**(exp-next_exp) - - return result - - def map_list(self, expr): + def map_list(self, expr: list[ExpressionT]) -> Any: return [self.rec(child) for child in expr] - def map_numpy_array(self, expr): + def map_numpy_array(self, expr: np.ndarray) -> Any: import numpy result = numpy.empty(expr.shape, dtype=object) for i in numpy.ndindex(expr.shape): result[i] = self.rec(expr[i]) return result - def map_multivector(self, expr, *args): - return expr.map(lambda ch: self.rec(ch, *args)) + def map_multivector(self, expr: MultiVector) -> Any: + return expr.map(lambda ch: self.rec(ch)) - def map_common_subexpression_uncached(self, expr): + def map_common_subexpression_uncached(self, expr: p.CommonSubexpression) -> Any: return self.rec(expr.child) - def map_if_positive(self, expr): - if self.rec(expr.criterion) > 0: + def map_if(self, expr: p.If) -> Any: + if self.rec(expr.condition): return self.rec(expr.then) else: return self.rec(expr.else_) - def map_comparison(self, expr): + def map_comparison(self, expr: p.Comparison) -> Any: import operator return getattr(operator, expr.operator_to_name[expr.operator])( self.rec(expr.left), self.rec(expr.right)) - def map_if(self, expr): - if self.rec(expr.condition): - return self.rec(expr.then) - else: - return self.rec(expr.else_) - - def map_min(self, expr): + def map_min(self, expr: p.Min) -> Any: return min(self.rec(child) for child in expr.children) - def map_max(self, expr): + def map_max(self, expr: p.Max) -> Any: return max(self.rec(child) for child in expr.children) - def map_tuple(self, expr): + def map_tuple(self, expr: tuple[ExpressionT, ...]) -> Any: return tuple([self.rec(child) for child in expr]) - def map_nan(self, expr): + def map_nan(self, expr: p.NaN) -> Any: if expr.data_type is None: from math import nan return nan diff --git a/pymbolic/mapper/flattener.py b/pymbolic/mapper/flattener.py index 9fa1a7d5..121cf657 100644 --- a/pymbolic/mapper/flattener.py +++ b/pymbolic/mapper/flattener.py @@ -1,3 +1,11 @@ +""" +.. autoclass:: FlattenMapper + +.. currentmodule:: pymbolic + +.. autofunction:: flatten +""" + from __future__ import annotations @@ -23,18 +31,76 @@ THE SOFTWARE. """ +from typing import cast + +import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper +from pymbolic.typing import ArithmeticExpressionT, ArithmeticOrExpressionT, ExpressionT -class FlattenMapper(IdentityMapper): - def map_sum(self, expr): +class FlattenMapper(IdentityMapper[[]]): + """ + Applies :func:`pymbolic.primitives.flattened_sum` + to :class:`~pymbolic.primitives.Sum`" + and :func:`pymbolic.primitives.flattened_product` + to :class:`~pymbolic.primitives.Product`." + Also applies light-duty simplification to other operators. + + This parallels what was done implicitly in the expression node + constructors. + """ + def map_sum(self, expr: p.Sum) -> ExpressionT: from pymbolic.primitives import flattened_sum - return flattened_sum([self.rec(ch) for ch in expr.children]) + return flattened_sum([ + cast(ArithmeticExpressionT, self.rec(ch)) + for ch in expr.children]) - def map_product(self, expr): + def map_product(self, expr: p.Product) -> ExpressionT: from pymbolic.primitives import flattened_product - return flattened_product([self.rec(ch) for ch in expr.children]) + return flattened_product([ + cast(ArithmeticExpressionT, self.rec(ch)) + for ch in expr.children]) + + def map_quotient(self, expr: p.Quotient) -> ExpressionT: + r_num = self.rec_arith(expr.numerator) + r_den = self.rec_arith(expr.denominator) + if p.is_zero(r_num): + return 0 + if p.is_zero(r_den - 1): + return r_num + + return expr.__class__(r_num, r_den) + + def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT: + r_num = self.rec_arith(expr.numerator) + r_den = self.rec_arith(expr.denominator) + if p.is_zero(r_num): + return 0 + if p.is_zero(r_den - 1): + return r_num + + return expr.__class__(r_num, r_den) + + def map_remainder(self, expr: p.Remainder) -> ExpressionT: + r_num = self.rec_arith(expr.numerator) + r_den = self.rec_arith(expr.denominator) + assert p.is_arithmetic_expression(r_den) + if p.is_zero(r_num): + return 0 + if p.is_zero(r_den - 1): + return r_num + + return expr.__class__(r_num, r_den) + + def map_power(self, expr: p.Power) -> ExpressionT: + r_base = self.rec_arith(expr.base) + r_exp = self.rec_arith(expr.exponent) + + if p.is_zero(r_exp - 1): + return r_base + + return expr.__class__(r_base, r_exp) -def flatten(expr): - return FlattenMapper()(expr) +def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT: + return cast(ArithmeticOrExpressionT, FlattenMapper()(expr)) diff --git a/pymbolic/mapper/optimize.py b/pymbolic/mapper/optimize.py index 4295e097..a07757cc 100644 --- a/pymbolic/mapper/optimize.py +++ b/pymbolic/mapper/optimize.py @@ -24,7 +24,9 @@ """ import ast +from collections.abc import Callable, Iterable, MutableMapping from functools import cached_property, lru_cache +from typing import TextIO, TypeVar, cast # This machinery applies AST rewriting to the mapper in a mildly brutal @@ -39,7 +41,14 @@ # {{{ ast retrieval -def _get_def_from_ast_container(container, name, node_type): +AstDefNodeT = TypeVar("AstDefNodeT", ast.FunctionDef, ast.ClassDef) + + +def _get_def_from_ast_container( + container: Iterable[ast.AST], + name: str, + node_type: type[AstDefNodeT] + ) -> AstDefNodeT: for entry in container: if isinstance(entry, node_type) and entry.name == name: return entry @@ -48,17 +57,17 @@ def _get_def_from_ast_container(container, name, node_type): @lru_cache -def _get_ast_for_file(filename): +def _get_ast_for_file(filename: str) -> ast.Module: with open(filename) as inf: return ast.parse(inf.read(), filename) -def _get_file_name_for_module_name(module_name): +def _get_file_name_for_module_name(module_name: str) -> str | None: from importlib import import_module return import_module(module_name).__file__ -def _get_ast_for_module_name(module_name): +def _get_ast_for_module_name(module_name: str) -> ast.Module: return _get_ast_for_file(_get_file_name_for_module_name(module_name)) @@ -66,13 +75,13 @@ def _get_module_ast_for_object(obj): return _get_ast_for_module_name(obj.__module__) -def _get_ast_for_class(cls): +def _get_ast_for_class(cls: type) -> ast.ClassDef: mod_ast = _get_module_ast_for_object(cls) return _get_def_from_ast_container( mod_ast.body, cls.__name__, ast.ClassDef) -def _get_ast_for_method(f): +def _get_ast_for_method(f: Callable) -> ast.FunctionDef: dot_components = f.__qualname__.split(".") assert dot_components[-1] == f.__name__ cls_name, = dot_components[:-1] @@ -120,15 +129,15 @@ def __init__(self, *, inline_rec, inline_cache): self.inline_rec = inline_rec self.inline_cache = inline_cache - def visit_Call(self, node): # noqa: N802 - node = self.generic_visit(node) + def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802 + node = cast(ast.Call, self.generic_visit(node)) - result_expr = node + result_expr: ast.expr = node if (isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and node.func.value.id == "self" - and node.func.attr == "rec"): + and node.func.attr in ["rec", "rec_arith"]): from ast import ( Attribute, @@ -182,7 +191,7 @@ def expr_assign(name, value): args=[expr], keywords=[]) cache_key_expr = ast.Tuple([expr_type, expr], ctx=Load()) - nic = Name(id="_NOT_IN_CACHE", ctx=Load()) + nic = Name(id="_NotInCache", ctx=Load()) result_expr = IfExp( test=Compare( @@ -238,22 +247,31 @@ def visit_Call(self, node): # noqa: N802 return result_expr -def _set_and_return(mapping, key, value): +KeyT = TypeVar("KeyT") +ValueT = TypeVar("ValueT") + + +def _set_and_return( + mapping: MutableMapping[KeyT, ValueT], + key: KeyT, + value: ValueT + ) -> ValueT: mapping[key] = value return value def optimize_mapper( - *, drop_args=False, drop_kwargs=False, - inline_rec=False, inline_cache=False, inline_get_cache_key=False, - print_modified_code_file=None): + *, drop_args: bool = False, drop_kwargs: bool = False, + inline_rec: bool = False, inline_cache: bool = False, + inline_get_cache_key: bool = False, + print_modified_code_file: TextIO | None = None) -> Callable[[type], type]: """ :param print_modified_code_file: a file-like object to which the modified code will be printed, or ``None``. """ # This is a crime, an abomination. But a somewhat effective one. - def wrapper(cls): + def wrapper(cls: type) -> type: try: # Introduced in Py3.9 ast.unparse # noqa: B018 @@ -283,7 +301,8 @@ def wrapper(cls): for name in dir(cls): if not name.startswith("__") or name == "__call__": method = getattr(cls, name) - if isinstance(method, (property, cached_property)): + if (not callable(method) + or isinstance(method, property | cached_property)): # properties don't have *args, **kwargs continue @@ -378,7 +397,7 @@ def wrapper(cls): "exec"), compile_dict) - return compile_dict[cls.__name__] + return cast(type, compile_dict[cls.__name__]) return wrapper diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 47e062c9..a3ccd3fc 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -22,11 +22,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from collections.abc import Sequence +from typing import TYPE_CHECKING, ClassVar, Concatenate +from warnings import warn -from typing import ClassVar +from typing_extensions import deprecated import pymbolic.primitives as p -from pymbolic.mapper import CachedMapper, Mapper +from pymbolic.mapper import CachedMapper, Mapper, P +from pymbolic.typing import ExpressionT + + +if TYPE_CHECKING: + import numpy as np + + from pymbolic.geometric_algebra import MultiVector __doc__ = """ @@ -81,7 +91,8 @@ # {{{ stringifier -class StringifyMapper(Mapper): + +class StringifyMapper(Mapper[str, Concatenate[int, P]]): """A mapper to turn an expression tree into a string. :class:`pymbolic.Expression.__str__` is often implemented using @@ -94,13 +105,25 @@ class StringifyMapper(Mapper): # {{{ replaceable string composition interface - def format(self, s, *args): + def format(self, s: str, *args: object) -> str: return s % args - def join(self, joiner, iterable): - return self.format(joiner.join("%s" for _ in iterable), *iterable) + def join(self, joiner: str, seq: Sequence[ExpressionT]) -> str: + return self.format(joiner.join("%s" for _ in seq), *seq) + + # {{{ deprecated junk + @deprecated("interface not type-safe, use rec_with_parens_around_types") def rec_with_force_parens_around(self, expr, *args, **kwargs): + warn( + "rec_with_force_parens_around is deprecated and will be removed in 2025. " + "Use rec_with_parens_around_types instead. ", + DeprecationWarning, + stacklevel=2, + ) + # Not currently possible to make this type-safe: + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters + force_parens_around = kwargs.pop("force_parens_around", ()) result = self.rec(expr, *args, **kwargs) @@ -110,16 +133,77 @@ def rec_with_force_parens_around(self, expr, *args, **kwargs): return result - def join_rec(self, joiner, iterable, prec, *args, **kwargs): - f = joiner.join("%s" for _ in iterable) - return self.format(f, - *[self.rec_with_force_parens_around(i, prec, *args, **kwargs) - for i in iterable]) + def join_rec( + self, + joiner: str, + seq: Sequence[ExpressionT], + prec: int, + *args, + **kwargs, # force_with_parens_around may hide in here + ) -> str: + f = joiner.join("%s" for _ in seq) + + if "force_parens_around" in kwargs: + warn( + "Passing force_parens_around join_rec is deprecated and will be " + "removed in 2025. " + "Use join_rec_with_parens_around_types instead. ", + DeprecationWarning, + stacklevel=2, + ) + # Not currently possible to make this type-safe: + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters + parens_around_types: tuple[type, ...] = kwargs.pop("force_parens_around") + return self.join_rec_with_parens_around_types( + joiner, seq, prec, parens_around_types, *args, **kwargs + ) + + return self.format( + f, + *[self.rec(i, prec, *args, **kwargs) for i in seq], + ) + + # }}} + + def rec_with_parens_around_types( + self, + expr: ExpressionT, + enclosing_prec: int, + parens_around: tuple[type, ...], + *args: P.args, + **kwargs: P.kwargs, + ) -> str: + result = self.rec(expr, enclosing_prec, *args, **kwargs) + + if isinstance(expr, parens_around): + result = f"({result})" + + return result + + def join_rec_with_parens_around_types( + self, + joiner: str, + seq: Sequence[ExpressionT], + prec: int, + parens_around_types: tuple[type, ...], + *args: P.args, + **kwargs: P.kwargs, + ) -> str: + f = joiner.join("%s" for _ in seq) + return self.format( + f, + *[ + self.rec_with_parens_around_types( + i, prec, parens_around_types, *args, **kwargs + ) + for i in seq + ], + ) - def parenthesize(self, s): + def parenthesize(self, s: str) -> str: return f"({s})" - def parenthesize_if_needed(self, s, enclosing_prec, my_prec): + def parenthesize_if_needed(self, s: str, enclosing_prec: int, my_prec: int) -> str: if enclosing_prec > my_prec: return f"({s})" else: @@ -129,216 +213,391 @@ def parenthesize_if_needed(self, s, enclosing_prec, my_prec): # {{{ mappings - def handle_unsupported_expression(self, expr, enclosing_prec, *args, **kwargs): + def handle_unsupported_expression( + self, expr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: strifier = expr.make_stringifier(self) if isinstance(self, type(strifier)): - raise ValueError( - f"stringifier '{self}' can't handle '{expr.__class__}'") - return strifier( - expr, enclosing_prec, *args, **kwargs) + raise ValueError(f"stringifier '{self}' can't handle '{expr.__class__}'") + return strifier(expr, enclosing_prec, *args, **kwargs) - def map_constant(self, expr, enclosing_prec, *args, **kwargs): + def map_constant( + self, expr: object, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: result = str(expr) - if not (result.startswith("(") and result.endswith(")")) \ - and ("-" in result or "+" in result) \ - and (enclosing_prec > PREC_SUM): + if ( + not (result.startswith("(") and result.endswith(")")) + and ("-" in result or "+" in result) + and (enclosing_prec > PREC_SUM) + ): return self.parenthesize(result) else: return result - def map_variable(self, expr, enclosing_prec, *args, **kwargs): + def map_variable( + self, expr: p.Variable, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return expr.name - def map_wildcard(self, expr, enclosing_prec, *args, **kwargs): + def map_wildcard( + self, expr: p.Wildcard, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return "*" - def map_function_symbol(self, expr, enclosing_prec, *args, **kwargs): + def map_function_symbol( + self, + expr: p.FunctionSymbol, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return expr.__class__.__name__ - def map_call(self, expr, enclosing_prec, *args, **kwargs): - return self.format("%s(%s)", - self.rec(expr.function, PREC_CALL, *args, **kwargs), - self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs)) - - def map_call_with_kwargs(self, expr, enclosing_prec, *args, **kwargs): - args_strings = ( - tuple([ - self.rec(ch, PREC_NONE, *args, **kwargs) - for ch in expr.parameters - ]) - + - tuple([ - "{}={}".format(name, self.rec(ch, PREC_NONE, *args, **kwargs)) - for name, ch in expr.kw_parameters.items() - ]) - ) - return self.format("%s(%s)", - self.rec(expr.function, PREC_CALL, *args, **kwargs), - ", ".join(args_strings)) - - def map_subscript(self, expr, enclosing_prec, *args, **kwargs): + def map_call( + self, expr: p.Call, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.format( + "%s(%s)", + self.rec(expr.function, PREC_CALL, *args, **kwargs), + self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs), + ) + + def map_call_with_kwargs( + self, + expr: p.CallWithKwargs, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: + args_strings = tuple([ + self.rec(ch, PREC_NONE, *args, **kwargs) for ch in expr.parameters + ]) + tuple([ + "{}={}".format(name, self.rec(ch, PREC_NONE, *args, **kwargs)) + for name, ch in expr.kw_parameters.items() + ]) + return self.format( + "%s(%s)", + self.rec(expr.function, PREC_CALL, *args, **kwargs), + ", ".join(args_strings), + ) + + def map_subscript( + self, expr: p.Subscript, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: if isinstance(expr.index, tuple): index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs) else: index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs) return self.parenthesize_if_needed( - self.format("%s[%s]", - self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), - index_str), - enclosing_prec, PREC_CALL) - - def map_lookup(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s[%s]", + self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), + index_str, + ), + enclosing_prec, + PREC_CALL, + ) + + def map_lookup( + self, expr: p.Lookup, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s.%s", - self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), - expr.name), - enclosing_prec, PREC_CALL) - - def map_sum(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s.%s", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), expr.name + ), + enclosing_prec, + PREC_CALL, + ) + + def map_sum( + self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs), - enclosing_prec, PREC_SUM) + self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, + PREC_SUM, + ) # {{{ multiplicative operators multiplicative_primitives = (p.Product, p.Quotient, p.FloorDiv, p.Remainder) - def map_product(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = (p.Quotient, p.FloorDiv, p.Remainder) + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec("*", expr.children, PREC_PRODUCT, *args, **kwargs), - enclosing_prec, PREC_PRODUCT) - - def map_quotient(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = self.multiplicative_primitives + self.join_rec_with_parens_around_types( + "*", + expr.children, + PREC_PRODUCT, + (p.Quotient, p.FloorDiv, p.Remainder), + *args, + **kwargs, + ), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_quotient( + self, expr: p.Quotient, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s / %s", - # space is necessary--otherwise '/*' becomes - # start-of-comment in C. ('*' from dereference) - self.rec_with_force_parens_around(expr.numerator, PREC_PRODUCT, - *args, **kwargs), - self.rec_with_force_parens_around( - expr.denominator, PREC_PRODUCT, *args, **kwargs)), - enclosing_prec, PREC_PRODUCT) - - def map_floor_div(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = self.multiplicative_primitives + self.format( + "%s / %s", + # space is necessary--otherwise '/*' becomes + # start-of-comment in C. ('*' from dereference) + self.rec_with_parens_around_types( + expr.numerator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + self.rec_with_parens_around_types( + expr.denominator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + ), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_floor_div( + self, expr: p.FloorDiv, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s // %s", - self.rec_with_force_parens_around( - expr.numerator, PREC_PRODUCT, *args, **kwargs), - self.rec_with_force_parens_around( - expr.denominator, PREC_PRODUCT, *args, **kwargs)), - enclosing_prec, PREC_PRODUCT) - - def map_remainder(self, expr, enclosing_prec, *args, **kwargs): - kwargs["force_parens_around"] = self.multiplicative_primitives + self.format( + "%s // %s", + self.rec_with_parens_around_types( + expr.numerator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + self.rec_with_parens_around_types( + expr.denominator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + ), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_remainder( + self, expr: p.Remainder, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s %% %s", - self.rec_with_force_parens_around( - expr.numerator, PREC_PRODUCT, *args, **kwargs), - self.rec_with_force_parens_around( - expr.denominator, PREC_PRODUCT, *args, **kwargs)), - enclosing_prec, PREC_PRODUCT) + self.format( + "%s %% %s", + self.rec_with_parens_around_types( + expr.numerator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + self.rec_with_parens_around_types( + expr.denominator, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + ), + enclosing_prec, + PREC_PRODUCT, + ) # }}} - def map_power(self, expr, enclosing_prec, *args, **kwargs): + def map_power( + self, expr: p.Power, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s**%s", - self.rec(expr.base, PREC_POWER, *args, **kwargs), - self.rec(expr.exponent, PREC_POWER, *args, **kwargs)), - enclosing_prec, PREC_POWER) - - def map_polynomial(self, expr, enclosing_prec, *args, **kwargs): - from pymbolic.primitives import flattened_sum - return self.rec(flattened_sum( - [coeff*expr.base**exp for exp, coeff in expr.data[::-1]]), - enclosing_prec, *args, **kwargs) - - def map_left_shift(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s**%s", + self.rec(expr.base, PREC_POWER, *args, **kwargs), + self.rec(expr.exponent, PREC_POWER, *args, **kwargs), + ), + enclosing_prec, + PREC_POWER, + ) + + def map_left_shift( + self, expr: p.LeftShift, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - # +1 to address - # https://gitlab.tiker.net/inducer/pymbolic/issues/6 - self.format("%s << %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_right_shift(self, expr, enclosing_prec, *args, **kwargs): + # +1 to address + # https://gitlab.tiker.net/inducer/pymbolic/issues/6 + self.format( + "%s << %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_right_shift( + self, + expr: p.RightShift, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - # +1 to address - # https://gitlab.tiker.net/inducer/pymbolic/issues/6 - self.format("%s >> %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_bitwise_not(self, expr, enclosing_prec, *args, **kwargs): + # +1 to address + # https://gitlab.tiker.net/inducer/pymbolic/issues/6 + self.format( + "%s >> %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_bitwise_not( + self, + expr: p.BitwiseNot, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs), - enclosing_prec, PREC_UNARY) - - def map_bitwise_or(self, expr, enclosing_prec, *args, **kwargs): + "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs), + enclosing_prec, + PREC_UNARY, + ) + + def map_bitwise_or( + self, expr: p.BitwiseOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " | ", expr.children, PREC_BITWISE_OR, *args, **kwargs), - enclosing_prec, PREC_BITWISE_OR) - - def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" | ", expr.children, PREC_BITWISE_OR, *args, **kwargs), + enclosing_prec, + PREC_BITWISE_OR, + ) + + def map_bitwise_xor( + self, + expr: p.BitwiseXor, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), - enclosing_prec, PREC_BITWISE_XOR) - - def map_bitwise_and(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), + enclosing_prec, + PREC_BITWISE_XOR, + ) + + def map_bitwise_and( + self, + expr: p.BitwiseAnd, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " & ", expr.children, PREC_BITWISE_AND, *args, **kwargs), - enclosing_prec, PREC_BITWISE_AND) - - def map_comparison(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" & ", expr.children, PREC_BITWISE_AND, *args, **kwargs), + enclosing_prec, + PREC_BITWISE_AND, + ) + + def map_comparison( + self, expr: p.Comparison, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s %s %s", - self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), - expr.operator, - self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)), - enclosing_prec, PREC_COMPARISON) - - def map_logical_not(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s %s %s", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + expr.operator, + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs), + ), + enclosing_prec, + PREC_COMPARISON, + ) + + def map_logical_not( + self, + expr: p.LogicalNot, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), - enclosing_prec, PREC_UNARY) - - def map_logical_or(self, expr, enclosing_prec, *args, **kwargs): + "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), + enclosing_prec, + PREC_UNARY, + ) + + def map_logical_or( + self, expr: p.LogicalOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_OR) - - def map_logical_and(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), + enclosing_prec, + PREC_LOGICAL_OR, + ) + + def map_logical_and( + self, + expr: p.LogicalAnd, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - " and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_AND) - - def map_list(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), + enclosing_prec, + PREC_LOGICAL_AND, + ) + + def map_list( + self, + expr: list[ExpressionT], + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.format( - "[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs)) + "[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs) + ) map_vector = map_list - def map_tuple(self, expr, enclosing_prec, *args, **kwargs): + def map_tuple( + self, + expr: tuple[ExpressionT, ...], + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: el_str = ", ".join( - self.rec(child, PREC_NONE, *args, **kwargs) for child in expr) + self.rec(child, PREC_NONE, *args, **kwargs) for child in expr + ) if len(expr) == 1: el_str += "," return f"({el_str})" - def map_numpy_array(self, expr, enclosing_prec, *args, **kwargs): + def map_numpy_array( + self, + expr: np.ndarray, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: import numpy str_array = numpy.zeros(expr.shape, dtype="object") @@ -351,68 +610,102 @@ def map_numpy_array(self, expr, enclosing_prec, *args, **kwargs): if len(expr.shape) == 1 and max_length < 15: return "array({})".format(", ".join(str_array)) else: - lines = [" {}: {}\n".format( - ",".join(str(i_i) for i_i in i), val) - for i, val in numpy.ndenumerate(str_array)] + lines = [ + " {}: {}\n".format(",".join(str(i_i) for i_i in i), val) + for i, val in numpy.ndenumerate(str_array) + ] if max_length > 70: - splitter = " " + "-"*75 + "\n" + splitter = " " + "-" * 75 + "\n" return "array(\n{})".format(splitter.join(lines)) else: return "array(\n{})".format("".join(lines)) - def map_multivector(self, expr, enclosing_prec, *args, **kwargs): + def map_multivector( + self, + expr: MultiVector, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return expr.stringify(self.rec, enclosing_prec, *args, **kwargs) - def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs): + def map_common_subexpression( + self, + expr: p.CommonSubexpression, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: from pymbolic.primitives import CommonSubexpression + if type(expr) is CommonSubexpression: type_name = "CSE" else: type_name = type(expr).__name__ - return self.format("%s(%s)", - type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs)) - - def map_if(self, expr, enclosing_prec, *args, **kwargs): - return self.parenthesize_if_needed( - "{} if {} else {}".format( - self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)), - enclosing_prec, PREC_IF) + return self.format( + "%s(%s)", type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs) + ) - def map_if_positive(self, expr, enclosing_prec, *args, **kwargs): + def map_if( + self, expr: p.If, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - "{} if {} > 0 else {}".format( - self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.criterion, PREC_LOGICAL_OR, *args, **kwargs), - self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)), - enclosing_prec, PREC_IF) - - def map_min(self, expr, enclosing_prec, *args, **kwargs): + "{} if {} else {}".format( + self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs), + ), + enclosing_prec, + PREC_IF, + ) + + def map_min( + self, expr: p.Min, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: what = type(expr).__name__.lower() - return self.format("%s(%s)", - what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs)) - - map_max = map_min - - def map_derivative(self, expr, enclosing_prec, *args, **kwargs): + return self.format( + "%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_max( + self, expr: p.Max, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + what = type(expr).__name__.lower() + return self.format( + "%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_derivative( + self, expr: p.Derivative, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: derivs = " ".join(f"d/d{v}" for v in expr.variables) return "{} {}".format( - derivs, - self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)) - - def map_substitution(self, expr, enclosing_prec, *args, **kwargs): + derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs) + ) + + def map_substitution( + self, + expr: p.Substitution, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: substs = ", ".join( - "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) - for name, val in zip(expr.variables, expr.values)) + "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) + for name, val in zip(expr.variables, expr.values, strict=True) + ) - return "[%s]{%s}" % ( - self.rec(expr.child, PREC_NONE, *args, **kwargs), - substs) + return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE, *args, **kwargs), substs) - def map_slice(self, expr, enclosing_prec, *args, **kwargs): + def map_slice( + self, expr: p.Slice, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: children = [] for child in expr.children: if child is None: @@ -421,15 +714,17 @@ def map_slice(self, expr, enclosing_prec, *args, **kwargs): children.append(self.rec(child, PREC_NONE, *args, **kwargs)) return self.parenthesize_if_needed( - self.join(":", children), - enclosing_prec, PREC_NONE) + ":".join(children), enclosing_prec, PREC_NONE + ) - def map_nan(self, expr, enclosing_prec, *args, **kwargs): + def map_nan( + self, expr: p.NaN, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return "NaN" # }}} - def __call__(self, expr, prec=PREC_NONE, *args, **kwargs): + def __call__(self, expr, prec=PREC_NONE, *args: P.args, **kwargs: P.kwargs) -> str: """Return a string corresponding to *expr*. If the enclosing precedence level *prec* is higher than *prec* (see :ref:`prec-constants`), parenthesize the result. @@ -443,15 +738,17 @@ def __init__(self) -> None: StringifyMapper.__init__(self) CachedMapper.__init__(self) - def __call__(self, expr, prec=PREC_NONE, *args, **kwargs): + def __call__(self, expr, prec=PREC_NONE, *args: P.args, **kwargs: P.kwargs) -> str: return CachedMapper.__call__(expr, prec, *args, **kwargs) + # }}} # {{{ cse-splitting stringifier -class CSESplittingStringifyMapperMixin: + +class CSESplittingStringifyMapperMixin(Mapper[str, Concatenate[int, P]]): """A :term:`mix-in` for subclasses of :class:`StringifyMapper` that collects "variable assignments" for @@ -475,44 +772,45 @@ class CSESplittingStringifyMapperMixin: See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example of the use of this mix-in. """ - def __init__(self): + + cse_to_name: dict[ExpressionT, str] + cse_names: set[str] + cse_name_list: list[tuple[str, str]] + + def __init__(self) -> None: self.cse_to_name = {} self.cse_names = set() self.cse_name_list = [] super().__init__() - def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs): - # This is here for compatibility, in case the constructor did not get called. - try: - self.cse_to_name # noqa: B018 - except AttributeError: - from warnings import warn - warn("Constructor of CSESplittingStringifyMapperMixin did not get " - "called. This is deprecated and will stop working in 2022.", - DeprecationWarning, stacklevel=2) - - self.cse_to_name = {} - self.cse_names = set() - self.cse_name_list = [] - + def map_common_subexpression( + self, + expr: p.CommonSubexpression, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: try: cse_name = self.cse_to_name[expr.child] except KeyError: str_child = self.rec(expr.child, PREC_NONE, *args, **kwargs) if expr.prefix is not None: + def generate_cse_names(): yield expr.prefix i = 2 while True: yield expr.prefix + f"_{i}" i += 1 + else: + def generate_cse_names(): i = 0 while True: - yield "CSE"+str(i) + yield "CSE" + str(i) i += 1 for cse_name in generate_cse_names(): @@ -525,51 +823,63 @@ def generate_cse_names(): return cse_name - def get_cse_strings(self): - return [f"{cse_name} : {cse_str}" - for cse_name, cse_str in - sorted(getattr(self, "cse_name_list", []))] + def get_cse_strings(self) -> list[str]: + return [ + f"{cse_name} : {cse_str}" + for cse_name, cse_str in sorted(getattr(self, "cse_name_list", [])) + ] + # }}} # {{{ sorting stringifier -class SortingStringifyMapper(StringifyMapper): + +class SortingStringifyMapper(StringifyMapper[P]): def __init__(self, reverse=True): super().__init__() self.reverse = reverse - def map_sum(self, expr, enclosing_prec, *args, **kwargs): + def map_sum( + self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: entries = [self.rec(i, PREC_SUM, *args, **kwargs) for i in expr.children] entries.sort(reverse=self.reverse) - return self.parenthesize_if_needed( - self.join(" + ", entries), - enclosing_prec, PREC_SUM) + return self.parenthesize_if_needed("+".join(entries), enclosing_prec, PREC_SUM) - def map_product(self, expr, enclosing_prec, *args, **kwargs): + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: entries = [self.rec(i, PREC_PRODUCT, *args, **kwargs) for i in expr.children] entries.sort(reverse=self.reverse) return self.parenthesize_if_needed( - self.join("*", entries), - enclosing_prec, PREC_PRODUCT) + "*".join(entries), enclosing_prec, PREC_PRODUCT + ) + # }}} # {{{ simplifying, sorting stringifier + class SimplifyingSortingStringifyMapper(StringifyMapper): def __init__(self, reverse=True): super().__init__() self.reverse = reverse - def map_sum(self, expr, enclosing_prec, *args, **kwargs): + def map_sum( + self, expr: p.Sum, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: def get_neg_product(expr): from pymbolic.primitives import Product, is_zero - if isinstance(expr, Product) \ - and len(expr.children) and is_zero(expr.children[0]+1): + if ( + isinstance(expr, Product) + and len(expr.children) + and is_zero(expr.children[0] + 1) + ): if len(expr.children) == 2: # only the minus sign and the other child return expr.children[1] @@ -589,29 +899,32 @@ def get_neg_product(expr): positives.append(self.rec(ch, PREC_SUM, *args, **kwargs)) positives.sort(reverse=self.reverse) - positives = " + ".join(positives) + positives_str = " + ".join(positives) negatives.sort(reverse=self.reverse) - negatives = self.join("", - [self.format(" - %s", entry) for entry in negatives]) + negatives_str = "".join(self.format(" - %s", entry) for entry in negatives) - result = positives + negatives + result = positives_str + negatives_str return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM) - def map_product(self, expr, enclosing_prec, *args, **kwargs): + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: entries = [] i = 0 from pymbolic.primitives import is_zero while i < len(expr.children): child = expr.children[i] - if False and is_zero(child+1) and i+1 < len(expr.children): + if False and is_zero(child + 1) and i + 1 < len(expr.children): # NOTE: That space needs to be there. # Otherwise two unary minus signs merge into a pre-decrement. entries.append( - self.format( - "- %s", self.rec( - expr.children[i+1], PREC_UNARY, *args, **kwargs))) + self.format( + "- %s", + self.rec(expr.children[i + 1], PREC_UNARY, *args, **kwargs), + ) + ) i += 2 else: entries.append(self.rec(child, PREC_PRODUCT, *args, **kwargs)) @@ -622,127 +935,226 @@ def map_product(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT) + # }}} # {{{ latex stringifier -class LaTeXMapper(StringifyMapper): +class LaTeXMapper(StringifyMapper): COMPARISON_OP_TO_LATEX: ClassVar[dict[str, str]] = { "==": r"=", "!=": r"\ne", "<=": r"\le", ">=": r"\ge", - "<": r"<", - ">": r">", - } - - def map_remainder(self, expr, enclosing_prec, *args, **kwargs): - return self.format(r"(%s \bmod %s)", - self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs), - self.rec(expr.denominator, PREC_POWER, *args, **kwargs)), + "<": r"<", + ">": r">", + } - def map_left_shift(self, expr, enclosing_prec, *args, **kwargs): + def map_remainder( + self, expr: p.Remainder, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.format( + r"(%s \bmod %s)", + self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs), + self.rec(expr.denominator, PREC_POWER, *args, **kwargs), + ) + + def map_left_shift( + self, expr: p.LeftShift, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format(r"%s \ll %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_right_shift(self, expr, enclosing_prec, *args, **kwargs): + self.format( + r"%s \ll %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_right_shift( + self, + expr: p.RightShift, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.format(r"%s \gg %s", - self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), - enclosing_prec, PREC_SHIFT) - - def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs): + self.format( + r"%s \gg %s", + self.rec(expr.shiftee, PREC_SHIFT + 1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT + 1, *args, **kwargs), + ), + enclosing_prec, + PREC_SHIFT, + ) + + def map_bitwise_xor( + self, + expr: p.BitwiseXor, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - r" \wedge ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), - enclosing_prec, PREC_BITWISE_XOR) - - def map_product(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec( + r" \wedge ", expr.children, PREC_BITWISE_XOR, *args, **kwargs + ), + enclosing_prec, + PREC_BITWISE_XOR, + ) + + def map_product( + self, expr: p.Product, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs), - enclosing_prec, PREC_PRODUCT) - - def map_power(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs), + enclosing_prec, + PREC_PRODUCT, + ) + + def map_power( + self, expr: p.Power, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("{%s}^{%s}", - self.rec(expr.base, PREC_NONE, *args, **kwargs), - self.rec(expr.exponent, PREC_NONE, *args, **kwargs)), - enclosing_prec, PREC_NONE) - - def map_min(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "{%s}^{%s}", + self.rec(expr.base, PREC_NONE, *args, **kwargs), + self.rec(expr.exponent, PREC_NONE, *args, **kwargs), + ), + enclosing_prec, + PREC_NONE, + ) + + def map_min( + self, expr: p.Min, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: from pytools import is_single_valued + if is_single_valued(expr.children): return self.rec(expr.children[0], enclosing_prec) what = type(expr).__name__.lower() - return self.format(r"\%s(%s)", - what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs)) - - def map_max(self, expr, enclosing_prec): - return self.map_min(expr, enclosing_prec) + return self.format( + r"\%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_max( + self, expr: p.Max, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + from pytools import is_single_valued - def map_floor_div(self, expr, enclosing_prec, *args, **kwargs): - return self.format(r"\lfloor {%s} / {%s} \rfloor", - self.rec(expr.numerator, PREC_NONE, *args, **kwargs), - self.rec(expr.denominator, PREC_NONE, *args, **kwargs)) + if is_single_valued(expr.children): + return self.rec(expr.children[0], enclosing_prec) - def map_subscript(self, expr, enclosing_prec, *args, **kwargs): + what = type(expr).__name__.lower() + return self.format( + r"\%s(%s)", + what, + self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), + ) + + def map_floor_div( + self, expr: p.FloorDiv, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.format( + r"\lfloor {%s} / {%s} \rfloor", + self.rec(expr.numerator, PREC_NONE, *args, **kwargs), + self.rec(expr.denominator, PREC_NONE, *args, **kwargs), + ) + + def map_subscript( + self, expr: p.Subscript, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: if isinstance(expr.index, tuple): index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs) else: index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs) - return self.format("{%s}_{%s}", - self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), - index_str) - - def map_logical_not(self, expr, enclosing_prec, *args, **kwargs): + return self.format( + "{%s}_{%s}", self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), index_str + ) + + def map_logical_not( + self, + expr: p.LogicalNot, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), - enclosing_prec, PREC_UNARY) - - def map_logical_or(self, expr, enclosing_prec, *args, **kwargs): + r"\neg " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), + enclosing_prec, + PREC_UNARY, + ) + + def map_logical_or( + self, expr: p.LogicalOr, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_OR) - - def map_logical_and(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec(r" \vee ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), + enclosing_prec, + PREC_LOGICAL_OR, + ) + + def map_logical_and( + self, + expr: p.LogicalAnd, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: return self.parenthesize_if_needed( - self.join_rec( - r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), - enclosing_prec, PREC_LOGICAL_AND) - - def map_comparison(self, expr, enclosing_prec, *args, **kwargs): + self.join_rec( + r" \wedge ", expr.children, PREC_LOGICAL_AND, *args, **kwargs + ), + enclosing_prec, + PREC_LOGICAL_AND, + ) + + def map_comparison( + self, expr: p.Comparison, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: return self.parenthesize_if_needed( - self.format("%s %s %s", - self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), - self.COMPARISON_OP_TO_LATEX[expr.operator], - self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)), - enclosing_prec, PREC_COMPARISON) - - def map_substitution(self, expr, enclosing_prec, *args, **kwargs): + self.format( + "%s %s %s", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + self.COMPARISON_OP_TO_LATEX[expr.operator], + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs), + ), + enclosing_prec, + PREC_COMPARISON, + ) + + def map_substitution( + self, + expr: p.Substitution, + enclosing_prec: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: substs = ", ".join( - "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) - for name, val in zip(expr.variables, expr.values)) + "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) + for name, val in zip(expr.variables, expr.values, strict=True) + ) - return self.format(r"[%s]\{%s\}", - self.rec(expr.child, PREC_NONE, *args, **kwargs), - substs) + return self.format( + r"[%s]\{%s\}", self.rec(expr.child, PREC_NONE, *args, **kwargs), substs + ) - def map_derivative(self, expr, enclosing_prec, *args, **kwargs): - derivs = " ".join( - r"\frac{\partial}{\partial %s}" % v - for v in expr.variables) + def map_derivative( + self, expr: p.Derivative, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + derivs = " ".join(r"\frac{\partial}{\partial %s}" % v for v in expr.variables) + + return self.format( + "%s %s", derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs) + ) - return self.format("%s %s", - derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)) # }}} diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py index 8cdb3e12..2194c1cf 100644 --- a/pymbolic/mapper/substitutor.py +++ b/pymbolic/mapper/substitutor.py @@ -4,7 +4,16 @@ .. autofunction:: make_subst_func .. autofunction:: substitute +.. autoclass:: Callable[[AlgebraicLeaf], ExpressionT | None] + +References +---------- + +.. class:: SupportsGetItem + + A protocol with a ``__getitem__`` method. """ + from __future__ import annotations @@ -29,12 +38,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from collections.abc import Callable +from typing import Any + +from useful_types import SupportsGetItem, SupportsItems from pymbolic.mapper import CachedIdentityMapper, IdentityMapper +from pymbolic.primitives import AlgebraicLeaf +from pymbolic.typing import ExpressionT -class SubstitutionMapper(IdentityMapper): - def __init__(self, subst_func): +class SubstitutionMapper(IdentityMapper[[]]): + def __init__( + self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None] + ) -> None: self.subst_func = subst_func def map_variable(self, expr): @@ -59,17 +76,26 @@ def map_lookup(self, expr): return IdentityMapper.map_lookup(self, expr) -class CachedSubstitutionMapper(CachedIdentityMapper, - SubstitutionMapper): - def __init__(self, subst_func): - CachedIdentityMapper.__init__(self) +class CachedSubstitutionMapper(CachedIdentityMapper[[]], SubstitutionMapper): + def __init__( + self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None] + ) -> None: + # FIXME Mypy says: + # error: Argument 1 to "__init__" of "CachedMapper" has incompatible type + # "CachedSubstitutionMapper"; expected "CachedMapper[ResultT, P]" [arg-type] + # This seems spurious? + CachedIdentityMapper.__init__(self) # type: ignore[arg-type] SubstitutionMapper.__init__(self, subst_func) -def make_subst_func(variable_assignments): +def make_subst_func( + # "Any" here avoids the whole Mapping variance disaster + # e.g. https://github.com/python/typing/issues/445 + variable_assignments: SupportsGetItem[Any, ExpressionT], +) -> Callable[[AlgebraicLeaf], ExpressionT | None]: import pymbolic.primitives as primitives - def subst_func(var): + def subst_func(var: AlgebraicLeaf) -> ExpressionT | None: try: return variable_assignments[var] except KeyError: @@ -84,15 +110,23 @@ def subst_func(var): return subst_func -def substitute(expression, variable_assignments=None, - mapper_cls=CachedSubstitutionMapper, **kwargs): +def substitute( + expression: ExpressionT, + variable_assignments: SupportsItems[AlgebraicLeaf | str, ExpressionT] | None = None, + mapper_cls=CachedSubstitutionMapper, + **kwargs: ExpressionT, +): """ :arg mapper_cls: A :class:`type` of the substitution mapper whose instance applies the substitution. """ if variable_assignments is None: - variable_assignments = {} - variable_assignments = variable_assignments.copy() - variable_assignments.update(kwargs) + # "Any" here avoids pointless grief about variance + # e.g. https://github.com/python/typing/issues/445 + v_ass_copied: dict[Any, ExpressionT] = {} + else: + v_ass_copied = dict(variable_assignments.items()) + + v_ass_copied.update(kwargs) - return mapper_cls(make_subst_func(variable_assignments))(expression) + return mapper_cls(make_subst_func(v_ass_copied))(expression) diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index 350eae0e..9409efd3 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -23,7 +23,7 @@ THE SOFTWARE. """ -from pymbolic.mapper import RecursiveMapper +from pymbolic.mapper import Mapper from pymbolic.primitives import Variable @@ -92,7 +92,7 @@ def unify_many(unis1, uni2): return result -class UnifierBase(RecursiveMapper): +class UnifierBase(Mapper): # The idea of the algorithm here is that the unifier accumulates a set of # unification possibilities (:class:`UnificationRecord`) as it descends the # expression tree. :func:`unify_many` above then checks if these possibilities @@ -120,7 +120,7 @@ def treat_mismatch(self, expr, other, urecs): raise NotImplementedError def unification_record_from_equation(self, lhs, rhs): - if isinstance(lhs, (tuple, list)) or isinstance(rhs, (tuple, list)): + if isinstance(lhs, tuple | list) or isinstance(rhs, tuple | list): # Always force lists/tuples to agree elementwise, never # generate a unification record between them directly. # This pushes the matching process down to the elementwise @@ -214,7 +214,7 @@ def map_sum(self, expr, other, urecs): for my_child, other_child in zip( expr.children, - (other.children[i] for i in perm)): + (other.children[i] for i in perm), strict=True): it_assignments = self.rec(my_child, other_child, it_assignments) if not it_assignments: break @@ -302,7 +302,7 @@ def map_list(self, expr, other, urecs): or len(expr) != len(other)): return [] - for my_child, other_child in zip(expr, other): + for my_child, other_child in zip(expr, other, strict=True): urecs = self.rec(my_child, other_child, urecs) if not urecs: break @@ -399,7 +399,7 @@ def partitions(s, k): for partition in partitions( other_leftovers, len(plain_var_candidates)): result = urec - for subset, var in zip(partition, plain_var_candidates): + for subset, var in zip(partition, plain_var_candidates, strict=True): rec = self.unification_record_from_equation( var, factory(other.children[i] for i in subset)) result = result.unify(rec) diff --git a/pymbolic/parser.py b/pymbolic/parser.py index bb79859e..fd4248f2 100644 --- a/pymbolic/parser.py +++ b/pymbolic/parser.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pymbolic.typing import ExpressionT + __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" @@ -23,11 +25,11 @@ THE SOFTWARE. """ +from collections.abc import Sequence from sys import intern -from typing import ClassVar, Sequence, Tuple, Union +from typing import ClassVar, TypeAlias from immutabledict import immutabledict -from typing_extensions import TypeAlias import pytools.lex from pytools import memoize_method @@ -92,8 +94,8 @@ _PREC_SHIFT = 205 _PREC_PLUS = 210 _PREC_TIMES = 220 -_PREC_POWER = 230 -_PREC_UNARY = 240 +_PREC_UNARY = 230 +_PREC_POWER = 240 _PREC_CALL = 250 @@ -126,7 +128,7 @@ def __hash__(self) -> int: # type: ignore[override] LexTable: TypeAlias = Sequence[ - Tuple[str, Union[pytools.lex.RE, Tuple[Union[str, pytools.lex.RE], ...]]]] + tuple[str, pytools.lex.RE | tuple[str | pytools.lex.RE, ...]]] class Parser: @@ -498,7 +500,7 @@ def parse_postfix(self, pstate, min_precedence, left_exp): pstate.advance() if pstate.is_at_end() or pstate.next_tag() is _closepar: - if isinstance(left_exp, (tuple, list)) \ + if isinstance(left_exp, tuple | list) \ and not isinstance(left_exp, FinalizedContainer): # left_expr is a container with trailing commas pass @@ -506,7 +508,7 @@ def parse_postfix(self, pstate, min_precedence, left_exp): left_exp = (left_exp,) else: new_el = self.parse_expression(pstate, _PREC_COMMA) - if isinstance(left_exp, (tuple, list)) \ + if isinstance(left_exp, tuple | list) \ and not isinstance(left_exp, FinalizedContainer): left_exp = (*left_exp, new_el) else: @@ -559,7 +561,7 @@ def parse_arglist(self, pstate): comma_allowed = True - def __call__(self, expr_str, min_precedence=0): + def __call__(self, expr_str: str, min_precedence: int = 0) -> ExpressionT: lex_result = [(tag, s, idx, match_obj) for (tag, s, idx, match_obj) in pytools.lex.lex( self.lex_table, expr_str, diff --git a/pymbolic/polynomial.py b/pymbolic/polynomial.py deleted file mode 100644 index edd38f35..00000000 --- a/pymbolic/polynomial.py +++ /dev/null @@ -1,368 +0,0 @@ -from __future__ import annotations - - -__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -from sys import intern - -import pymbolic -import pymbolic.algorithm as algorithm -from pymbolic.primitives import Expression -from pymbolic.traits import EuclideanRingTraits, FieldTraits, traits - - -def _sort_uniq(data): - def sortkey(key): - exp, _coeff = key - return exp - - data.sort(key=sortkey) - - uniq_result = [] - last_exp = None - for exp, coeff in data: - if last_exp == exp: - newcoeff = uniq_result[-1][1]+coeff - if not newcoeff: - uniq_result.pop() - else: - uniq_result[-1] = last_exp, newcoeff - - else: - uniq_result.append((exp, coeff)) - last_exp = exp - return uniq_result - - -def _get_dependencies(expr): - from pymbolic.mapper.dependency import DependencyMapper - return DependencyMapper()(expr) - - -class LexicalMonomialOrder: - def __call__(self, a, b): - from pymbolic.primitives import Variable - # is a < b? - assert isinstance(a, Variable) and isinstance(b, Variable) - return a.name < b.name - - def __eq__(self, other): - return isinstance(other, LexicalMonomialOrder) - - def __repr__(self): - return "LexicalMonomialOrder()" - - -class Polynomial(Expression): - def __init__(self, base, data=None, unit=1, var_less=None): - if var_less is None: - var_less = LexicalMonomialOrder() - - self.Base = base - self.Unit = unit - self.VarLess = var_less - - # list of (exponent, coefficient tuples) - # sorted in increasing order - # one entry per degree - if data is None: - self.Data = ((1, unit),) - else: - self.Data = tuple(data) - - # Remember the Zen, Luke: Sparse is better than dense. - - def coefficients(self): - return [coeff for (exp, coeff) in self.Data] - - def traits(self): - return PolynomialTraits() - - def __nonzero__(self): - return len(self.Data) != 0 - - def __eq__(self, other): - return (isinstance(other, Polynomial) - and (self.Base == other.Base) - and (self.Data == other.Data)) - - def __ne__(self, other): - return not self.__eq__(other) - - def __neg__(self): - return Polynomial(self.Base, - [(exp, -coeff) - for (exp, coeff) in self.Data]) - - def __add__(self, other): - if not other: - return self - - if not isinstance(other, Polynomial): - other = Polynomial(self.Base, ((0, other),)) - - if other.Base != self.Base: - assert self.VarLess == other.VarLess - - if self.VarLess(self.Base, other.Base): - other = Polynomial(self.Base, ((0, other),)) - else: - return other.__add__(self) - - i_self = 0 - i_other = 0 - - result = [] - while i_self < len(self.Data) and i_other < len(other.Data): - exp_self = self.Data[i_self][0] - exp_other = other.Data[i_other][0] - if exp_self == exp_other: - coeff = self.Data[i_self][1] + other.Data[i_other][1] - if coeff: - result.append((exp_self, coeff)) - i_self += 1 - i_other += 1 - elif exp_self > exp_other: - result.append((exp_other, other.Data[i_other][1])) - i_other += 1 - elif exp_self < exp_other: - result.append((exp_self, self.Data[i_self][1])) - i_self += 1 - - # we have exhausted at least one list, exhaust the other - while i_self < len(self.Data): - exp_self = self.Data[i_self][0] - result.append((exp_self, self.Data[i_self][1])) - i_self += 1 - - while i_other < len(other.Data): - exp_other = other.Data[i_other][0] - result.append((exp_other, other.Data[i_other][1])) - i_other += 1 - - return Polynomial(self.Base, tuple(result)) - - def __radd__(self, other): - return self.__add__(other) - - def __sub__(self, other): - return self+(-other) - - def __rsub__(self, other): - return (-other)+self - - def __mul__(self, other): - if not isinstance(other, Polynomial): - if other == self.Base: - other = Polynomial(self.Base) - else: - return Polynomial(self.Base, [(exp, coeff * other) - for exp, coeff in self.Data]) - - if other.Base != self.Base: - assert self.VarLess == other.VarLess - - if self.VarLess(self.Base, other.Base): - return Polynomial(self.Base, [(exp, coeff * other) - for exp, coeff in self.Data]) - else: - return other.__mul__(self) - - result = [] - for s_exp, s_coeff in self.Data: - for o_exp, o_coeff in other.Data: - result.append((s_exp+o_exp, s_coeff*o_coeff)) - - return Polynomial(self.Base, tuple(_sort_uniq(result))) - - def __rmul__(self, other): - return Polynomial(self.Base, [(exp, other * coeff) - for exp, coeff in self.Data]) - - def __pow__(self, other): - return algorithm.integer_power(self, int(other), - Polynomial(self.Base, ((0, 1),))) - - def __divmod__(self, other): - if not isinstance(other, Polynomial): - dm_list = [(exp, divmod(coeff, other)) for exp, coeff in self.Data] - return ( - Polynomial(self.Base, [(exp, quot) for exp, (quot, _) in dm_list]), - Polynomial(self.Base, [(exp, rem) for exp, (_, rem) in dm_list])) - - if other.Base != self.Base: - assert self.VarLess == other.VarLess - - if self.VarLess(self.Base, other.Base): - dm_list = [(exp, divmod(coeff, other)) for exp, coeff in self.Data] - return ( - Polynomial(self.Base, [(exp, quot) for exp, (quot, _) in dm_list]), - Polynomial(self.Base, [(exp, rem) for exp, (_, rem) in dm_list])) - - else: - other_unit = Polynomial(other.Base, ((0, other.unit),), self.VarLess) - quot, rem = divmod(other_unit, other) - return quot * self, rem * self - - if other.degree == -1: - raise ZeroDivisionError - - quot = Polynomial(self.Base, ()) - rem = self - other_lead_coeff = other.Data[-1][1] - other_lead_exp = other.Data[-1][0] - - coeffs_are_field = isinstance(traits(self.Unit), FieldTraits) - - from pymbolic.primitives import quotient - - while rem.degree >= other.degree: - if coeffs_are_field: - coeff_factor = quotient(rem.Data[-1][1], other_lead_coeff) - else: - coeff_factor, lead_rem = divmod(rem.Data[-1][1], other_lead_coeff) - if lead_rem: - return quot, rem - deg_diff = rem.Data[-1][0] - other_lead_exp - - this_fac = Polynomial(self.Base, ((deg_diff, coeff_factor),)) - quot += this_fac - rem -= this_fac * other - return quot, rem - - def __div__(self, other): - if not isinstance(other, Polynomial): - return 1/other * self - q, r = divmod(self, other) - if r.degree != -1: - raise ValueError("division yielded a remainder") - return q - - __truediv__ = __div__ - - def __floordiv__(self, other): - return self.__divmod__(other)[0] - - def __mod__(self, other): - return self.__divmod__(other)[1] - - def _data(self): - return self.Data - data = property(_data) - - def _base(self): - return self.Base - base = property(_base) - - def _unit(self): - return self.Unit - unit = property(_unit) - - def _degree(self): - try: - return self.Data[-1][0] - except IndexError: - return -1 - degree = property(_degree) - - def __getinitargs__(self): - return (self.Base, self.Data, self.Unit, self.VarLess) - - mapper_method = intern("map_polynomial") - - def as_primitives(self): - deps = _get_dependencies(self) - context = {dep: dep for dep in deps} - return pymbolic.evaluate(self, context) - - def get_coefficient(self, sought_exp): - # FIXME use bisection - for exp, coeff in self.Data: - if exp == sought_exp: - return coeff - return 0 - - -def differentiate(poly): - return Polynomial( - poly.base, - tuple((exp-1, exp*coeff) - for exp, coeff in poly.data - if not exp == 0)) - - -def integrate(poly): - return Polynomial( - poly.base, - tuple((exp+1, pymbolic.quotient(poly.unit, (exp+1))*coeff) - for exp, coeff in poly.data)) - - -def integrate_definite(poly, a, b): - antideriv = integrate(poly) - a_bound = pymbolic.substitute(antideriv, {poly.base: a}) - b_bound = pymbolic.substitute(antideriv, {poly.base: b}) - - from pymbolic.primitives import Sum - return Sum((b_bound, -a_bound)) - - -def leading_coefficient(poly): - return poly.data[-1][1] - - -def general_polynomial(base, coefflist, degree): - return Polynomial(base, - ((i, coefflist[i]) for i in range(degree+1))) - - -class PolynomialTraits(EuclideanRingTraits): - @staticmethod - def norm(x): - return x.degree - - @staticmethod - def get_unit(x): - lc = leading_coefficient(x) - return traits(lc).get_unit(lc) - - -if __name__ == "__main__": - x = Polynomial(pymbolic.var("x")) - y = Polynomial(pymbolic.var("y")) - - u = (x+1)**5 - v = pymbolic.evaluate_kw(u, x=x) - print(u) - print(v) - - if False: - # NOT WORKING INTRODUCE TESTS - u = (x+y)**5 - v = x+y - # u = x+1 - # v = 3*x+1 - q, r = divmod(u, v) - print(q, "R", r) - print(q*v) - print("REASSEMBLY:", q*v + r) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index aa08cf11..9718f6b0 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -24,15 +24,16 @@ """ import re +from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass, fields from sys import intern from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, - Mapping, NoReturn, + Protocol, + TypeAlias, TypeVar, cast, ) @@ -295,6 +296,10 @@ An instance of a :func:`~dataclasses.dataclass`. +.. class:: ArithmeticExpressionT + + See :class:`pymbolic.ArithmeticExpressionT` + .. class:: _T A type variable. @@ -446,140 +451,84 @@ def init_arg_names(self) -> tuple[str, ...]: # {{{ arithmetic - def __add__(self, other: object) -> ArithmeticExpressionT: + def __add__(self, other: object) -> Sum: if not is_arithmetic_expression(other): return NotImplemented - if is_nonzero(other): - if self: - if isinstance(other, Sum): - return Sum((self, *other.children)) - else: - return Sum((self, other)) - else: - return other - else: - return self + return Sum((self, other)) - def __radd__(self, other: object) -> ArithmeticExpressionT: - assert is_number(other) - if is_nonzero(other): - if self: - return Sum((other, self)) - else: - return other - else: - return self - - def __sub__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __radd__(self, other: object) -> Sum: + if not is_arithmetic_expression(other): return NotImplemented + return Sum((other, self)) - if is_nonzero(other): - return self.__add__(-cast(NumberT, other)) - else: - return self - - def __rsub__(self, other: object) -> ArithmeticExpressionT: - if not is_constant(other): + def __sub__(self, other: object) -> Sum: + if not is_arithmetic_expression(other): return NotImplemented + return Sum((self, -other)) - if is_nonzero(other): - return Sum((other, -self)) - else: - return -self + def __rsub__(self, other: object) -> Sum: + if not is_arithmetic_expression(other): + return NotImplemented + return Sum((other, -self)) - def __mul__(self, other: object) -> ArithmeticExpressionT: + def __mul__(self, other: object) -> Product: if not is_valid_operand(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other - 1): - return self - elif is_zero(other): - return 0 - else: - return Product((self, other)) + return Product((self, other)) - def __rmul__(self, other: object) -> ArithmeticExpressionT: - if not is_constant(other): + def __rmul__(self, other: object) -> Product: + if not is_valid_operand(other): return NotImplemented - if is_zero(other-1): - return self - elif is_zero(other): - return 0 - else: - return Product((other, self)) + return Product((other, self)) - def __div__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __truediv__(self, other: object) -> Quotient: + if not is_arithmetic_expression(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other-1): - return self - return quotient(self, other) - __truediv__ = __div__ + return Quotient(self, other) - def __rdiv__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __rtruediv__(self, other: object) -> Quotient: + if not is_arithmetic_expression(other): return NotImplemented - if is_zero(other): - return 0 - return quotient(other, self) - __rtruediv__ = __rdiv__ + return Quotient(other, self) - def __floordiv__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __floordiv__(self, other: object) -> FloorDiv: + if not is_arithmetic_expression(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other-1): - return self return FloorDiv(self, other) - def __rfloordiv__(self, other: object) -> ArithmeticExpressionT: + def __rfloordiv__(self, other: object) -> FloorDiv: if not is_arithmetic_expression(other): return NotImplemented - if is_zero(self-1): - return other return FloorDiv(other, self) - def __mod__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __mod__(self, other: object) -> Remainder: + if not is_arithmetic_expression(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other-1): - return 0 return Remainder(self, other) - def __rmod__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __rmod__(self, other: object) -> Remainder: + if not is_arithmetic_expression(other): return NotImplemented return Remainder(other, self) - def __pow__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __pow__(self, other: object) -> Power: + if not is_arithmetic_expression(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other): # exponent zero - return 1 - elif is_zero(other-1): # exponent one - return self return Power(self, other) - def __rpow__(self, other: object) -> ArithmeticExpressionT: - assert is_constant(other) + def __rpow__(self, other: object) -> Power: + if not is_arithmetic_expression(other): + return NotImplemented - if is_zero(other): # base zero - return 0 - elif is_zero(other-1): # base one - return 1 return Power(other, self) # }}} @@ -671,6 +620,9 @@ def __call__(self, *args, **kwargs) -> Call | CallWithKwargs: return Call(self, args) if not TYPE_CHECKING: + # Subscript has an attribute 'index' which can't coexist with this. + # Thus we're hiding this from mypy until it goes away. + def index(self, subscript: Expression) -> Expression: """Return an expression representing ``self[subscript]``. @@ -818,7 +770,7 @@ def __getstate__(self) -> tuple[Any]: def __setstate__(self, state) -> None: # Can't use trivial pickling: _hash_value cache must stay unset assert len(self.init_arg_names) == len(state), type(self) - for name, value in zip(self.init_arg_names, state): + for name, value in zip(self.init_arg_names, state, strict=True): object.__setattr__(self, name, value) # }}} @@ -946,9 +898,13 @@ def __iter__(self): ) +class _HasMapperMethod(Protocol): + mapper_method: ClassVar[str] + + def _augment_expression_dataclass( cls: type[DataclassInstance], - hash: bool, + generate_hash: bool, ) -> None: attr_tuple = ", ".join(f"self.{fld.name}" for fld in fields(cls)) if attr_tuple: @@ -981,8 +937,9 @@ def {cls.__name__}_eq(self, other): return True if self.__class__ is not other.__class__: return False - if hash(self) != hash(other): - return False + if {generate_hash}: + if hash(self) != hash(other): + return False if self.__class__ is not cls and self.init_arg_names != {fld_name_tuple}: warn(f"{{self.__class__}} is derived from {cls}, which is now " f"a dataclass. {{self.__class__}} should be converted to being " @@ -1017,7 +974,7 @@ def {cls.__name__}_hash(self): object.__setattr__(self, "_hash_value", hash_val) return hash_val - if {hash}: + if {generate_hash}: cls.__hash__ = {cls.__name__}_hash @@ -1083,23 +1040,23 @@ def {cls.__name__}_setstate(self, state): # {{{ assign mapper_method - assert issubclass(cls, Expression) + mm_cls = cast(type[_HasMapperMethod], cls) - snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", cls.__name__).lower() + snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", mm_cls.__name__).lower() default_mapper_method_name = f"map_{snake_clsname}" # This covers two cases: the class does not have the attribute in the first # place, or it inherits a value but does not set it itself. - sets_mapper_method = "mapper_method" in cls.__dict__ + sets_mapper_method = "mapper_method" in mm_cls.__dict__ if sets_mapper_method: - if default_mapper_method_name == cls.mapper_method: - warn(f"Explicit mapper_method on {cls} not needed, default matches " + if default_mapper_method_name == mm_cls.mapper_method: + warn(f"Explicit mapper_method on {mm_cls} not needed, default matches " "explicit assignment. Just delete the explicit assignment.", stacklevel=3) if not sets_mapper_method: - cls.mapper_method = intern(default_mapper_method_name) + mm_cls.mapper_method = intern(default_mapper_method_name) # }}} @@ -1110,18 +1067,21 @@ def {cls.__name__}_setstate(self, state): @dataclass_transform(frozen_default=True) def expr_dataclass( init: bool = True, - hash: bool = True + hash: bool = True, ) -> Callable[[type[_T]], type[_T]]: - """A class decorator that makes the class a :func:`~dataclasses.dataclass` + r"""A class decorator that makes the class a :func:`~dataclasses.dataclass` while also adding functionality needed for :class:`Expression` nodes. Specifically, it adds cached hashing, equality comparisons with ``self is other`` shortcuts as well as some methods/attributes - for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``) + for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``). It also adds a :attr:`Expression.mapper_method` based on the class name if not already present. If :attr:`~Expression.mapper_method` is inherited, it will be viewed as unset and replaced. + Note that the class to which this decorator is applied need not be + a subclass of :class:`~pymbolic.Expression`. + .. versionadded:: 2024.1 """ def map_cls(cls: type[_T]) -> type[_T]: @@ -1135,7 +1095,7 @@ def map_cls(cls: type[_T]) -> type[_T]: # It should just understand that? _augment_expression_dataclass( dc_cls, # type: ignore[arg-type] - hash=hash + generate_hash=hash, ) return dc_cls @@ -1395,8 +1355,8 @@ class Max(Expression): @expr_dataclass() class QuotientBase(Expression): - numerator: ExpressionT - denominator: ExpressionT + numerator: ArithmeticExpressionT + denominator: ArithmeticExpressionT @property def num(self): @@ -1446,8 +1406,8 @@ class Power(Expression): .. autoattribute:: exponent """ - base: ExpressionT - exponent: ExpressionT + base: ArithmeticExpressionT + exponent: ArithmeticExpressionT # }}} @@ -1707,6 +1667,12 @@ class Derivative(Expression): variables: tuple[str, ...] +SliceChildrenT: TypeAlias = (tuple[()] + | tuple[ExpressionT | None] + | tuple[ExpressionT | None, ExpressionT | None] + | tuple[ExpressionT | None, ExpressionT | None, ExpressionT | None]) + + @expr_dataclass() class Slice(Expression): """A slice expression as in a[1:7]. @@ -1718,10 +1684,7 @@ class Slice(Expression): .. autoproperty:: step """ - children: (tuple[()] - | tuple[ExpressionT] - | tuple[ExpressionT, ExpressionT] - | tuple[ExpressionT, ExpressionT, ExpressionT]) + children: SliceChildrenT def __bool__(self): return True @@ -1753,7 +1716,7 @@ def step(self): @expr_dataclass() -class NaN(Expression): +class NaN(AlgebraicLeaf): """ An expression node representing not-a-number as a floating point number. Unlike, :data:`math.nan`, all instances of :class:`NaN` compare equal, as @@ -1794,7 +1757,7 @@ def subscript(expression, index): return Subscript(expression, index) -def flattened_sum(terms): +def flattened_sum(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpressionT: r"""Recursively flattens all the top level :class:`Sum`\ s in *terms*. :arg terms: an :class:`~collections.abc.Iterable` of expressions. @@ -1811,7 +1774,8 @@ def flattened_sum(terms): continue if isinstance(item, Sum): - queue += item.children + ch = cast(tuple[ArithmeticExpressionT], item.children) + queue.extend(ch) else: done.append(item) @@ -1825,11 +1789,12 @@ def flattened_sum(terms): def linear_combination(coefficients, expressions): return sum(coefficient * expression - for coefficient, expression in zip(coefficients, expressions) + for coefficient, expression + in zip(coefficients, expressions, strict=True) if coefficient and expression) -def flattened_product(terms): +def flattened_product(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpressionT: r"""Recursively flattens all the top level :class:`Product`\ s in *terms*. This operation does not change the order of the terms in the products, so @@ -1851,7 +1816,8 @@ def flattened_product(terms): continue if isinstance(item, Product): - queue += item.children + ch = cast(tuple[ArithmeticExpressionT], item.children) + queue.extend(ch) else: done.append(item) @@ -1933,7 +1899,7 @@ def unregister_constant_class(class_): VALID_CONSTANT_CLASSES = tuple(tmp) -def is_nonzero(value): +def is_nonzero(value: object) -> bool: if value is None: raise ValueError("is_nonzero is undefined for None") @@ -1943,12 +1909,12 @@ def is_nonzero(value): return True -def is_zero(value): +def is_zero(value: object) -> bool: return not is_nonzero(value) -def wrap_in_cse(expr, prefix=None): - if isinstance(expr, (Variable, Subscript)): +def wrap_in_cse(expr: ExpressionT, prefix=None) -> ExpressionT: + if isinstance(expr, Variable | Subscript): return expr if isinstance(expr, CommonSubexpression): diff --git a/pymbolic/traits.py b/pymbolic/traits.py index 321d29b0..4411cc2e 100644 --- a/pymbolic/traits.py +++ b/pymbolic/traits.py @@ -40,7 +40,7 @@ def traits(x): try: return x.traits() except AttributeError: - if isinstance(x, (complex, float)): + if isinstance(x, complex | float): return FieldTraits() elif isinstance(x, int): return IntegerTraits() diff --git a/pymbolic/typing.py b/pymbolic/typing.py index fce695d1..a16ed6cf 100644 --- a/pymbolic/typing.py +++ b/pymbolic/typing.py @@ -4,25 +4,28 @@ Typing helpers -------------- -.. autodata:: BoolT -.. autodata:: NumberT -.. autodata:: ScalarT -.. autodata:: ArithmeticExpressionT +.. autoclass:: BoolT +.. autoclass:: NumberT +.. autoclass:: ScalarT +.. autoclass:: ArithmeticExpressionT A narrower type alias than :class:`ExpressionT` that is returned by arithmetic operators, to allow continue doing arithmetic with the result of arithmetic. - > +.. autoclass:: ExpressionT -.. autodata:: ExpressionT +.. currentmodule:: pymbolic.typing + +.. autoclass:: ArithmeticOrExpressionT + + A type variable that can be either :data:`ArithmeticExpressionT` + or :data:`ExpressionT`. """ from __future__ import annotations -from typing import TYPE_CHECKING, Tuple, TypeVar, Union - -from typing_extensions import TypeAlias +from typing import TYPE_CHECKING, TypeAlias, TypeVar, Union # FIXME: This is a lie. Many more constant types (e.g. numpy and such) @@ -50,15 +53,15 @@ # (e.g. 'Unsupported operand types for * ("Decimal" and "Fraction")') # And leaving them out doesn't really make any of this more precise. -_StdlibInexactNumberT = Union[float, complex] +_StdlibInexactNumberT = float | complex if TYPE_CHECKING: # Yes, type-checking pymbolic will require numpy. That's OK. import numpy as np - BoolT = Union[bool, np.bool_] - IntegerT: TypeAlias = Union[int, np.integer] - InexactNumberT: TypeAlias = Union[_StdlibInexactNumberT, np.inexact] + BoolT = bool | np.bool_ + IntegerT: TypeAlias = int | np.integer + InexactNumberT: TypeAlias = _StdlibInexactNumberT | np.inexact else: try: import numpy as np @@ -67,18 +70,23 @@ IntegerT: TypeAlias = int InexactNumberT: TypeAlias = _StdlibInexactNumberT else: - BoolT = Union[bool, np.bool_] - IntegerT: TypeAlias = Union[int, np.integer] - InexactNumberT: TypeAlias = Union[_StdlibInexactNumberT, np.inexact] + BoolT = bool | np.bool_ + IntegerT: TypeAlias = int | np.integer + InexactNumberT: TypeAlias = _StdlibInexactNumberT | np.inexact -NumberT: TypeAlias = Union[IntegerT, InexactNumberT] -ScalarT: TypeAlias = Union[NumberT, BoolT] +NumberT: TypeAlias = IntegerT | InexactNumberT +ScalarT: TypeAlias = NumberT | BoolT _ScalarOrExpression = Union[ScalarT, "Expression"] ArithmeticExpressionT: TypeAlias = Union[NumberT, "Expression"] -ExpressionT: TypeAlias = Union[_ScalarOrExpression, Tuple["ExpressionT", ...]] +ExpressionT: TypeAlias = _ScalarOrExpression | tuple["ExpressionT", ...] + +ArithmeticOrExpressionT = TypeVar( + "ArithmeticOrExpressionT", + ArithmeticExpressionT, + ExpressionT) T = TypeVar("T") diff --git a/pyproject.toml b/pyproject.toml index 6d18ba83..e4f2c6ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ license = { text = "MIT" } authors = [ { name = "Andreas Kloeckner", email = "inform@tiker.net" }, ] -requires-python = ">=3.8" +requires-python = ">=3.10" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -29,11 +29,11 @@ classifiers = [ "Topic :: Utilities", ] dependencies = [ - "astunparse; python_version<='3.9'", "immutabledict", "pytools>=2022.1.14", - # for dataclass_transform, TypeAlias - "typing-extensions>=4", + # for dataclass_transform, TypeAlias, deprecated + "typing-extensions>=4.5", + "useful-types", ] [project.optional-dependencies] diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 729a2202..00000000 --- a/test/conftest.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import annotations - -import sys - - -collect_ignore = [] -if sys.version_info < (3, 10): - collect_ignore.append("test_pattern_match.py") diff --git a/test/test_matchpy.py b/test/test_matchpy.py index d11bf245..27f68932 100644 --- a/test/test_matchpy.py +++ b/test/test_matchpy.py @@ -101,4 +101,5 @@ def test_make_subexpr_subst(): replaced_expr = m.replace_all(subject, [rule]) - assert replaced_expr == flatten(parse("subst(i, j)*a[(k,)]*d[(k,)]")) + ref_expr = flatten(parse("subst(i, j)*a[(k,)]*d[(k,)]")) + assert flatten(replaced_expr) == ref_expr diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index eb8ac768..8dc3a28a 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -1,5 +1,8 @@ from __future__ import annotations +from pymbolic.mapper.stringifier import StringifyMapper +from pymbolic.typing import ExpressionT + __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" @@ -329,6 +332,8 @@ def test_parser(): assert_parsed_same_as_python("0 if 1 if 2 else 3 else 4") assert_parsed_same_as_python("0 if (1 if 2 else 3) else 4") assert_parsed_same_as_python("(2, 3,)") + assert_parsed_same_as_python("-3**0.5") + assert_parsed_same_as_python("1/2/7") with pytest.deprecated_call(): parse("1+if(0, 1, 2)") @@ -1027,6 +1032,27 @@ def test_python_ast_interop_roundtrip(): assert ast2p(p2ast(expr)) == expr +# {{{ test derived stringifiers + +@prim.expr_dataclass() +class CustomOperator: + child: ExpressionT + + def make_stringifier(self, originating_stringifier=None): + return OperatorStringifier() + + +class OperatorStringifier(StringifyMapper[[]]): + def map_custom_operator(self, expr: CustomOperator): + return f"Op({self.rec(expr.child)})" + + +def test_derived_stringifier() -> None: + str(CustomOperator(5)) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: