From ce6da6b8302a23f2c0de707917e4c0f099a256c7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 29 Nov 2024 15:44:04 -0600 Subject: [PATCH] Configure, pass mypy, add to CI Also - Drop SingleGridWorkBalancingPytatoArrayContext - Eliminate mixin-style MPIPytatoArrayContextBase --- .github/workflows/ci.yml | 13 +++ .gitlab-ci.yml | 15 +++- doc/conf.py | 5 ++ grudge/array_context.py | 163 ++++++++++++++----------------------- grudge/discretization.py | 36 ++++---- grudge/dof_desc.py | 2 +- grudge/dt_utils.py | 26 ++++-- grudge/geometry/metrics.py | 7 +- grudge/models/euler.py | 8 +- grudge/op.py | 36 +++++--- grudge/projection.py | 9 +- grudge/py.typed | 0 grudge/reductions.py | 26 +++--- grudge/tools.py | 30 ++++++- grudge/trace_pair.py | 68 ++++++++++++---- pyproject.toml | 6 +- run-mypy.sh | 3 + 17 files changed, 268 insertions(+), 185 deletions(-) create mode 100644 grudge/py.typed create mode 100755 run-mypy.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eba956715..ed8346974 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,6 +22,19 @@ jobs: pipx install ruff ruff check + mypy: + name: Mypy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: "Main Script" + run: | + curl -L -O https://tiker.net/ci-support-v0 + . ./ci-support-v0 + build_py_project_in_conda_env + python -m pip install mypy + ./run-mypy.sh + typos: name: Typos runs-on: ubuntu-latest diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 16413ab6b..375875efd 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -99,7 +99,7 @@ Documentation: tags: - python3 -Flake8: +Ruff: script: - pipx install ruff - ruff check @@ -108,6 +108,19 @@ Flake8: except: - tags +Mypy: + script: | + EXTRA_INSTALL="Cython mpi4py" + curl -L -O https://tiker.net/ci-support-v0 + . ./ci-support-v0 + build_py_project_in_venv + python -m pip install mypy + ./run-mypy.sh + tags: + - python3 + except: + - tags + Pylint: script: | EXTRA_INSTALL="pybind11 make numpy scipy matplotlib mpi4py" diff --git a/doc/conf.py b/doc/conf.py index 89ba77226..9f475cc58 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -30,3 +30,8 @@ # index-page demo uses pyopencl via plot_directive os.environ["PYOPENCL_TEST"] = "port:cpu" + +nitpick_ignore_regex = [ + ["py:class", r"np\.ndarray"], + ["py:data|py:class", r"arraycontext.*ContainerTc"], +] diff --git a/grudge/array_context.py b/grudge/array_context.py index aca1edc08..50fa90454 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -8,6 +8,9 @@ .. autofunction:: get_reasonable_array_context_class """ +from __future__ import annotations + + __copyright__ = "Copyright (C) 2020 Andreas Kloeckner" __license__ = """ @@ -35,9 +38,11 @@ import logging from collections.abc import Callable, Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from warnings import warn +from typing_extensions import Self + from meshmode.array_context import ( PyOpenCLArrayContext as _PyOpenCLArrayContextBase, PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase, @@ -48,28 +53,6 @@ logger = logging.getLogger(__name__) -try: - # FIXME: temporary workaround while SingleGridWorkBalancingPytatoArrayContext - # is not available in meshmode's main branch - # (it currently needs - # https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms) - from meshmode.array_context import SingleGridWorkBalancingPytatoArrayContext - - try: - # Crude check if we have the correct loopy branch - # (https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) - from loopy.codegen.result import get_idis_for_kernel # noqa - except ImportError: - # warn("Your loopy and meshmode branches are mismatched. " - # "Please make sure that you have the " - # "https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms " # noqa - # "branch of loopy.") - _HAVE_SINGLE_GRID_WORK_BALANCING = False - else: - _HAVE_SINGLE_GRID_WORK_BALANCING = True - -except ImportError: - _HAVE_SINGLE_GRID_WORK_BALANCING = False try: # FIXME: temporary workaround while FusionContractorArrayContext @@ -119,8 +102,8 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase): to understand :mod:`grudge`-specific transform metadata. (Of which there isn't any, for now.) """ - def __init__(self, queue: "pyopencl.CommandQueue", - allocator: Optional["pyopencl.tools.AllocatorBase"] = None, + def __init__(self, queue: pyopencl.CommandQueue, + allocator: pyopencl.tools.AllocatorBase | None = None, wait_event_queue_length: int | None = None, force_device_scalars: bool = True) -> None: @@ -165,8 +148,8 @@ def __init__(self, queue, allocator=None, # }}} -class MPIBasedArrayContext: - mpi_communicator: "MPI.Comm" +class MPIBasedArrayContext(ArrayContext): + mpi_communicator: MPI.Intracomm # {{{ distributed + pytato @@ -345,13 +328,13 @@ class _DistributedCompiledFunction: type of the callable. """ - actx: "MPISingleGridWorkBalancingPytatoArrayContext" - distributed_partition: "DistributedGraphPartition" - part_id_to_prg: "Mapping[PartId, pt.target.BoundProgram]" + actx: MPIBasedArrayContext + distributed_partition: DistributedGraphPartition + part_id_to_prg: Mapping[PartId, pt.target.BoundProgram] input_id_to_name_in_program: Mapping[tuple[Any, ...], str] output_id_to_name_in_program: Mapping[tuple[Any, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] - name_in_program_to_axes: Mapping[str, tuple["pt.Axis", ...]] + name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]] output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -368,10 +351,11 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: self.actx, self.input_id_to_name_in_program, arg_id_to_arg) from pytato import execute_distributed_partition + assert isinstance(self.actx, PytatoPyOpenCLArrayContext | PyOpenCLArrayContext) out_dict = execute_distributed_partition( self.distributed_partition, self.part_id_to_prg, - self.actx.queue, self.actx.mpi_communicator, - allocator=self.actx.allocator, + self.actx.queue, self.actx.mpi_communicator, # pylint: disable=no-member + allocator=self.actx.allocator, # pylint: disable=no-member input_args=input_args_for_prg) def to_output_template(keys, _): @@ -387,42 +371,6 @@ def to_output_template(keys, _): self.output_template) -class MPIPytatoArrayContextBase(MPIBasedArrayContext): - def __init__( - self, mpi_communicator, queue, *, mpi_base_tag, allocator=None, - compile_trace_callback: Callable[[Any, str, Any], None] | None = None, - ) -> None: - """ - :arg compile_trace_callback: A function of three arguments - *(what, stage, ir)*, where *what* identifies the object - being compiled, *stage* is a string describing the compilation - pass, and *ir* is an object containing the intermediate - representation. This interface should be considered - unstable. - """ - if allocator is None: - warn("No memory allocator specified, please pass one. " - "(Preferably a pyopencl.tools.MemoryPool in order " - "to reduce device allocations)", stacklevel=2) - - super().__init__(queue, allocator, - compile_trace_callback=compile_trace_callback) - - self.mpi_communicator = mpi_communicator - self.mpi_base_tag = mpi_base_tag - - # FIXME: implement distributed-aware freeze - - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: - return _DistributedLazilyPyOpenCLCompilingFunctionCaller(self, f) - - def clone(self): - # type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member - # pylint: disable=no-member - return type(self)(self.mpi_communicator, self.queue, - mpi_base_tag=self.mpi_base_tag, - allocator=self.allocator) - # }}} @@ -437,8 +385,8 @@ class MPIPyOpenCLArrayContext(PyOpenCLArrayContext, MPIBasedArrayContext): def __init__(self, mpi_communicator, - queue: "pyopencl.CommandQueue", - *, allocator: Optional["pyopencl.tools.AllocatorBase"] = None, + queue: pyopencl.CommandQueue, + *, allocator: pyopencl.tools.AllocatorBase | None = None, wait_event_queue_length: int | None = None, force_device_scalars: bool = True) -> None: """ @@ -451,7 +399,7 @@ def __init__(self, self.mpi_communicator = mpi_communicator - def clone(self): + def clone(self) -> Self: # type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member # pylint: disable=no-member return type(self)(self.mpi_communicator, self.queue, @@ -476,7 +424,7 @@ def __init__(self, mpi_communicator) -> None: self.mpi_communicator = mpi_communicator - def clone(self): + def clone(self) -> Self: return type(self)(self.mpi_communicator) # }}} @@ -485,28 +433,50 @@ def clone(self): # {{{ distributed + pytato array context subclasses class MPIBasePytatoPyOpenCLArrayContext( - MPIPytatoArrayContextBase, PytatoPyOpenCLArrayContext): + MPIBasedArrayContext, PytatoPyOpenCLArrayContext): """ .. autofunction:: __init__ """ - pass - - -if _HAVE_SINGLE_GRID_WORK_BALANCING: - class MPISingleGridWorkBalancingPytatoArrayContext( - MPIPytatoArrayContextBase, SingleGridWorkBalancingPytatoArrayContext): + def __init__( + self, mpi_communicator, queue, *, mpi_base_tag, allocator=None, + compile_trace_callback: Callable[[Any, str, Any], None] | None = None, + ) -> None: """ - .. autofunction:: __init__ + :arg compile_trace_callback: A function of three arguments + *(what, stage, ir)*, where *what* identifies the object + being compiled, *stage* is a string describing the compilation + pass, and *ir* is an object containing the intermediate + representation. This interface should be considered + unstable. """ + if allocator is None: + warn("No memory allocator specified, please pass one. " + "(Preferably a pyopencl.tools.MemoryPool in order " + "to reduce device allocations)", stacklevel=2) - MPIPytatoArrayContext = MPISingleGridWorkBalancingPytatoArrayContext -else: - MPIPytatoArrayContext = MPIBasePytatoPyOpenCLArrayContext + super().__init__(queue, allocator, + compile_trace_callback=compile_trace_callback) + + self.mpi_communicator = mpi_communicator + self.mpi_base_tag = mpi_base_tag + + # FIXME: implement distributed-aware freeze + + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + return _DistributedLazilyPyOpenCLCompilingFunctionCaller(self, f) + + def clone(self) -> Self: + return type(self)(self.mpi_communicator, self.queue, + mpi_base_tag=self.mpi_base_tag, + allocator=self.allocator) + + +MPIPytatoArrayContext: type[MPIBasedArrayContext] = MPIBasePytatoPyOpenCLArrayContext if _HAVE_FUSION_ACTX: class MPIFusionContractorArrayContext( - MPIPytatoArrayContextBase, FusionContractorArrayContext): + MPIBasePytatoPyOpenCLArrayContext, FusionContractorArrayContext): """ .. autofunction:: __init__ """ @@ -570,25 +540,14 @@ def __call__(self): def _get_single_grid_pytato_actx_class(distributed: bool) -> type[ArrayContext]: - if not _HAVE_SINGLE_GRID_WORK_BALANCING: - warn("No device-parallel actx available, execution will be slow. " - "Please make sure you have the right branches for loopy " - "(https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) " # noqa - "and meshmode " - "(https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms).", - stacklevel=1) + warn("No device-parallel actx available, execution will be slow.", + stacklevel=1) # lazy, non-distributed if not distributed: - if _HAVE_SINGLE_GRID_WORK_BALANCING: - return SingleGridWorkBalancingPytatoArrayContext - else: - return PytatoPyOpenCLArrayContext + return PytatoPyOpenCLArrayContext else: # distributed+lazy: - if _HAVE_SINGLE_GRID_WORK_BALANCING: - return MPISingleGridWorkBalancingPytatoArrayContext - else: - return MPIBasePytatoPyOpenCLArrayContext + return MPIBasePytatoPyOpenCLArrayContext def get_reasonable_array_context_class( @@ -603,7 +562,7 @@ def get_reasonable_array_context_class( if numpy: assert not (lazy or fusion) if distributed: - actx_class = MPINumpyArrayContext + actx_class: type[ArrayContext] = MPINumpyArrayContext else: actx_class = NumpyArrayContext @@ -641,7 +600,7 @@ def get_reasonable_array_context_class( "device-parallel=%r", actx_class.__name__, lazy, distributed, # eager is always device-parallel: - (_HAVE_SINGLE_GRID_WORK_BALANCING or _HAVE_FUSION_ACTX or not lazy)) + (_HAVE_FUSION_ACTX or not lazy)) return actx_class # }}} diff --git a/grudge/discretization.py b/grudge/discretization.py index 4c5e36f7a..5587bfeb7 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -34,7 +34,7 @@ """ from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, cast from warnings import warn import numpy as np @@ -44,6 +44,7 @@ from meshmode.discretization.connection import ( FACE_RESTR_ALL, FACE_RESTR_INTERIOR, + DirectDiscretizationConnection, DiscretizationConnection, make_face_restriction, ) @@ -52,7 +53,7 @@ ModalGroupFactory, ) from meshmode.dof_array import DOFArray -from meshmode.mesh import BTAG_PARTITION, Mesh +from meshmode.mesh import BTAG_PARTITION, Mesh, ModepyElementGroup from pytools import memoize_method, single_valued from grudge.dof_desc import ( @@ -82,7 +83,6 @@ # {{{ discr_tag_to_group_factory normalization def _normalize_discr_tag_to_group_factory( - dim: int, discr_tag_to_group_factory: TagToElementGroupFactory | None, order: int | None ) -> TagToElementGroupFactory: @@ -206,7 +206,6 @@ def __init__(self, array_context: ArrayContext, mesh = volume_discrs discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory( - dim=mesh.dim, discr_tag_to_group_factory=discr_tag_to_group_factory, order=order) self._discr_tag_to_group_factory = discr_tag_to_group_factory @@ -363,7 +362,7 @@ def _has_affine_groups(self, domain_tag: DomainTag) -> bool: discr = self.discr_from_dd(DOFDesc(domain_tag, DISCR_TAG_BASE)) return any( megrp.is_affine - and issubclass(megrp._modepy_shape_cls, Simplex) + and issubclass(cast(ModepyElementGroup, megrp).shape_cls, Simplex) for megrp in discr.mesh.groups) @memoize_method @@ -468,6 +467,8 @@ def connection_from_dds( make_face_to_all_faces_embedding, ) + assert isinstance(faces_conn, DirectDiscretizationConnection) + return make_face_to_all_faces_embedding( self._setup_actx, faces_conn, self.discr_from_dd(to_dd), @@ -789,7 +790,7 @@ def normal(self, dd): def make_discretization_collection( array_context: ArrayContext, - volumes: MeshOrDiscr | Mapping[VolumeTag, MeshOrDiscr], + volumes: Mesh | Mapping[VolumeTag, Mesh], order: int | None = None, discr_tag_to_group_factory: TagToElementGroupFactory | None = None, ) -> DiscretizationCollection: @@ -826,34 +827,31 @@ def make_discretization_collection( i.e. all ranks in the communicator must enter this function at the same time. """ - if isinstance(volumes, Mesh | Discretization): - volumes = {VTAG_ALL: volumes} + if not isinstance(volumes, Mesh): + volumes_dict = volumes + else: + volumes_dict = {VTAG_ALL: volumes} from pytools import is_single_valued - assert len(volumes) > 0 - assert is_single_valued(mesh_or_discr.ambient_dim - for mesh_or_discr in volumes.values()) + assert len(volumes_dict) > 0 + if not is_single_valued(mesh_or_discr.ambient_dim + for mesh_or_discr in volumes_dict.values()): + raise ValueError("all parts of a discretization collection must share " + "an ambient dimension") discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory( - dim=single_valued( - mesh_or_discr.dim for mesh_or_discr in volumes.values()), discr_tag_to_group_factory=discr_tag_to_group_factory, order=order) del order - if any( - isinstance(mesh_or_discr, Discretization) - for mesh_or_discr in volumes.values()): - raise NotImplementedError("Doesn't work at the moment") - volume_discrs = { vtag: Discretization( array_context, mesh, discr_tag_to_group_factory[DISCR_TAG_BASE]) - for vtag, mesh in volumes.items()} + for vtag, mesh in volumes_dict.items()} return DiscretizationCollection( array_context=array_context, diff --git a/grudge/dof_desc.py b/grudge/dof_desc.py index 18cb21627..2fefbffaa 100644 --- a/grudge/dof_desc.py +++ b/grudge/dof_desc.py @@ -233,7 +233,7 @@ class DOFDesc: def __init__(self, domain_tag: Any, - discretization_tag: type[DiscretizationTag] | None = None) -> None: + discretization_tag: DiscretizationTag | None = None) -> None: if ( not isinstance(domain_tag, diff --git a/grudge/dt_utils.py b/grudge/dt_utils.py index 36aac52bf..62390060b 100644 --- a/grudge/dt_utils.py +++ b/grudge/dt_utils.py @@ -43,11 +43,13 @@ """ from collections.abc import Sequence +from typing import cast import numpy as np from arraycontext import ArrayContext, Scalar, tag_axes from arraycontext.metadata import NameHint +from meshmode.discretization import NodalElementGroupBase from meshmode.dof_array import DOFArray from meshmode.transform_metadata import ( DiscretizationDOFAxisTag, @@ -64,6 +66,7 @@ FACE_RESTR_ALL, BoundaryDomainTag, DOFDesc, + ScalarDomainTag, as_dofdesc, ) @@ -120,7 +123,7 @@ def _compute_characteristic_lengthscales(): @memoize_on_first_arg def dt_non_geometric_factors( dcoll: DiscretizationCollection, dd: DOFDesc | None = None - ) -> Sequence[float]: + ) -> Sequence[float | np.floating]: r"""Computes the non-geometric scale factors following [Hesthaven_2008]_, section 6.4, for each element group in the *dd* discretization: @@ -140,8 +143,10 @@ def dt_non_geometric_factors( dd = DD_VOLUME_ALL discr = dcoll.discr_from_dd(dd) - min_delta_rs = [] + min_delta_rs: list[np.floating | float] = [] for grp in discr.groups: + assert isinstance(grp, NodalElementGroupBase) + nodes = np.asarray(list(zip(*grp.unit_nodes, strict=True))) nnodes = grp.nunit_dofs @@ -157,7 +162,7 @@ def dt_non_geometric_factors( else: min_delta_rs.append( min( - np.linalg.norm(nodes[i] - nodes[j]) + float(np.linalg.norm(nodes[i] - nodes[j])) for i in range(nnodes) for j in range(nnodes) if i != j ) ) @@ -263,6 +268,9 @@ def dt_geometric_factors( actx = dcoll._setup_actx volm_discr = dcoll.discr_from_dd(dd) + if isinstance(dd.domain_tag, ScalarDomainTag): + raise TypeError("not sensible for scalar domains") + if any(not isinstance(grp, SimplexElementGroupBase) for grp in volm_discr.groups): raise NotImplementedError( @@ -275,10 +283,10 @@ def dt_geometric_factors( "time step estimation is not necessarily valid for non-volume-" "filling discretizations. Continuing anyway.", stacklevel=3) - cell_vols = abs( - op.elementwise_integral( + cell_vols: DOFArray = abs( + cast(DOFArray, op.elementwise_integral( dcoll, dd, volm_discr.zeros(actx) + 1.0 - ) + )) ) if dcoll.dim == 1: @@ -290,10 +298,10 @@ def dt_geometric_factors( face_discr = dcoll.discr_from_dd(dd_face) # Compute areas of each face - face_areas = abs( - op.elementwise_integral( + face_areas: DOFArray = abs( + cast(DOFArray, op.elementwise_integral( dcoll, dd_face, face_discr.zeros(actx) + 1.0 - ) + )) ) if actx.supports_nonscalar_broadcasting: diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index b36e3ccd1..af8267cee 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -62,6 +62,7 @@ from arraycontext import ArrayContext, register_multivector_as_array_container, tag_axes from arraycontext.metadata import NameHint +from meshmode.discretization.connection import DirectDiscretizationConnection from meshmode.dof_array import DOFArray from meshmode.transform_metadata import ( DiscretizationAmbientDimAxisTag, @@ -558,6 +559,7 @@ def _signed_face_ones( all_faces_conn = dcoll.connection_from_dds( dd_base.untrace(), dd_base ) + assert isinstance(all_faces_conn, DirectDiscretizationConnection) signed_ones = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)).zeros( actx, dtype=dcoll.real_dtype ) + 1 @@ -566,6 +568,7 @@ def _signed_face_ones( for igrp, grp in enumerate(all_faces_conn.groups): for batch in grp.batches: + assert batch.to_element_face is not None i = actx.to_numpy(actx.thaw(batch.to_element_indices)) grp_field = _signed_face_ones_numpy[igrp].reshape(-1) grp_field[i] = \ @@ -831,7 +834,7 @@ def second_fundamental_form( normal = rel_mv_normal(actx, dcoll, dd=dd).as_vector(dtype=object) if dim == 1: - second_ref_axes = [((0, 2),)] + second_ref_axes: list[tuple[tuple[int, int], ...]] = [((0, 2),)] elif dim == 2: second_ref_axes = [((0, 2),), ((0, 1), (1, 1)), ((1, 2),)] else: @@ -879,7 +882,7 @@ def shape_operator(actx: ArrayContext, dcoll: DiscretizationCollection, def summed_curvature(actx: ArrayContext, dcoll: DiscretizationCollection, - dd: DOFDesc | None = None) -> DOFArray: + dd: DOFDesc | None = None) -> DOFArray | float: r"""Computes the sum of the principal curvatures: .. math:: diff --git a/grudge/models/euler.py b/grudge/models/euler.py index ef55ad63e..66993e5af 100644 --- a/grudge/models/euler.py +++ b/grudge/models/euler.py @@ -75,8 +75,9 @@ @dataclass_array_container @dataclass(frozen=True) class ConservedEulerField: - mass: DOFArray - energy: DOFArray + # mass and energy become arrays when computing fluxes. + mass: DOFArray | np.ndarray + energy: DOFArray | np.ndarray momentum: np.ndarray @property @@ -181,7 +182,6 @@ def boundary_tpair( dcoll: DiscretizationCollection, dd_bc: DOFDesc, state: ConservedEulerField, t=0): - actx = state.array_context dd_base = as_dofdesc("vol", DISCR_TAG_BASE) return TracePair( @@ -359,7 +359,7 @@ def interp_to_quad(u): return op.inverse_mass( dcoll, - volume_fluxes - op.face_mass(dcoll, df, interface_fluxes) + volume_fluxes - op.face_mass(dcoll, df, interface_fluxes) # type: ignore[operator] ) # }}} diff --git a/grudge/op.py b/grudge/op.py index 1937f35d4..7b0aaf071 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -38,13 +38,11 @@ .. class:: ArrayOrContainer - See :class:`arraycontext.ArrayOrContainer`. + See :data:`arraycontext.ArrayOrContainer`. """ from __future__ import annotations -from meshmode.discretization import InterpolatoryElementGroupBase, NodalElementGroupBase - __copyright__ = """ Copyright (C) 2021 Andreas Kloeckner @@ -72,12 +70,23 @@ """ +from collections.abc import Hashable from functools import partial import numpy as np import modepy as mp -from arraycontext import ArrayContext, ArrayOrContainer, map_array_container, tag_axes +from arraycontext import ( + Array, + ArrayContext, + ArrayOrContainer, + map_array_container, + tag_axes, +) +from meshmode.discretization import ( + InterpolatoryElementGroupBase, + NodalElementGroupBase, +) from meshmode.dof_array import DOFArray from meshmode.transform_metadata import ( DiscretizationDOFAxisTag, @@ -117,7 +126,7 @@ from grudge.trace_pair import ( bdry_trace_pair, bv_trace_pair, - connected_ranks, + connected_parts, cross_rank_trace_pairs, interior_trace_pair, interior_trace_pairs, @@ -130,7 +139,7 @@ __all__ = ( "bdry_trace_pair", "bv_trace_pair", - "connected_ranks", + "connected_parts", "cross_rank_trace_pairs", "elementwise_integral", "elementwise_max", @@ -257,16 +266,21 @@ def _divergence_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec def _reference_derivative_matrices(actx: ArrayContext, out_element_group: NodalElementGroupBase, - in_element_group: InterpolatoryElementGroupBase): + in_element_group: InterpolatoryElementGroupBase) -> Array: + + def memoize_key( + out_grp: NodalElementGroupBase, in_grp: InterpolatoryElementGroupBase + ) -> Hashable: + return ( + out_grp.discretization_key(), + in_grp.discretization_key()) @keyed_memoize_in( actx, _reference_derivative_matrices, - lambda outgrp, ingrp: ( - outgrp.discretization_key(), - ingrp.discretization_key())) + memoize_key) def get_ref_derivative_mats( out_grp: NodalElementGroupBase, - in_grp: InterpolatoryElementGroupBase): + in_grp: InterpolatoryElementGroupBase) -> Array: return actx.freeze( actx.tag_axis( 1, DiscretizationDOFAxisTag(), diff --git a/grudge/projection.py b/grudge/projection.py index d5ebdd689..1372ba608 100644 --- a/grudge/projection.py +++ b/grudge/projection.py @@ -35,7 +35,7 @@ """ -from arraycontext import ArrayOrContainer +from arraycontext import ArrayOrContainerOrScalarT from grudge.discretization import DiscretizationCollection from grudge.dof_desc import ( @@ -47,9 +47,10 @@ def project( - dcoll: DiscretizationCollection, - src: ConvertibleToDOFDesc, - tgt: ConvertibleToDOFDesc, vec) -> ArrayOrContainer: + dcoll: DiscretizationCollection, + src: ConvertibleToDOFDesc, + tgt: ConvertibleToDOFDesc, vec: ArrayOrContainerOrScalarT + ) -> ArrayOrContainerOrScalarT: """Project from one discretization to another, e.g. from the volume to the boundary, or from the base to the an overintegrated quadrature discretization. diff --git a/grudge/py.typed b/grudge/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/grudge/reductions.py b/grudge/reductions.py index 6dcde1316..efe078af3 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -59,7 +59,6 @@ THE SOFTWARE. """ - from functools import partial, reduce import numpy as np @@ -77,6 +76,7 @@ DiscretizationDOFAxisTag, DiscretizationElementAxisTag, ) +from pymbolic import Number, RealNumber from pytools import memoize_in import grudge.dof_desc as dof_desc @@ -86,7 +86,7 @@ # {{{ Nodal reductions -def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> Scalar: +def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> RealNumber: r"""Return the vector p-norm of a function represented by its vector of degrees of freedom *vec*. @@ -104,6 +104,8 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> Scalar: from arraycontext import get_container_context_recursively actx = get_container_context_recursively(vec) + assert actx is not None + dd = dof_desc.as_dofdesc(dd) if p == 2: @@ -120,7 +122,7 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> Scalar: raise ValueError("unsupported norm order") -def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Scalar: +def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Number: r"""Return the nodal sum of a vector of degrees of freedom *vec*. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value @@ -144,7 +146,7 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Scalar: comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM)) -def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Scalar: +def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Number: r"""Return the rank-local nodal sum of a vector of degrees of freedom *vec*. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value @@ -166,7 +168,7 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Scalar: for grp_ary in vec) -def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scalar: +def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> RealNumber: r"""Return the nodal minimum of a vector of degrees of freedom *vec*. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value @@ -194,7 +196,7 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal def nodal_min_loc( - dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scalar: + dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> RealNumber: r"""Return the rank-local nodal minimum of a vector of degrees of freedom *vec*. @@ -206,10 +208,10 @@ def nodal_min_loc( :returns: a scalar denoting the rank-local nodal minimum. """ if not isinstance(vec, DOFArray): - return min( + return np.min([ nodal_min_loc(dcoll, dd, comp, initial=initial) for _, comp in serialize_container(vec) - ) + ]) actx = vec.array_context @@ -226,7 +228,7 @@ def nodal_min_loc( vec, initial) -def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scalar: +def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> RealNumber: r"""Return the nodal maximum of a vector of degrees of freedom *vec*. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value @@ -254,7 +256,7 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal def nodal_max_loc( - dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scalar: + dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> RealNumber: r"""Return the rank-local nodal maximum of a vector of degrees of freedom *vec*. @@ -266,10 +268,10 @@ def nodal_max_loc( :returns: a scalar denoting the rank-local nodal maximum. """ if not isinstance(vec, DOFArray): - return max( + return np.max([ nodal_max_loc(dcoll, dd, comp, initial=initial) for _, comp in serialize_container(vec) - ) + ]) actx = vec.array_context diff --git a/grudge/tools.py b/grudge/tools.py index 69405d62d..892662e84 100644 --- a/grudge/tools.py +++ b/grudge/tools.py @@ -2,6 +2,20 @@ .. autofunction:: build_jacobian .. autofunction:: map_subarrays .. autofunction:: rec_map_subarrays + +Links to canonical locations of external symbols +------------------------------------------------ + +(This section only exists because Sphinx does not appear able to resolve +these symbols correctly.) + +.. class:: ArrayContext + + See :class:`arraycontext.ArrayContext`. + +.. class:: ArrayOrArithContainerTc + + See :data:`arraycontext.context.ArrayOrArithContainerTc`. """ from __future__ import annotations @@ -35,7 +49,13 @@ import numpy as np -from arraycontext import ArrayContext, ArrayOrContainer, ArrayOrContainerT +from arraycontext import ( + ArrayContext, + ArrayOrContainer, +) +from arraycontext.context import ( + ArrayOrArithContainerTc, +) from pytools import product @@ -43,8 +63,8 @@ def build_jacobian( actx: ArrayContext, - f: Callable[[ArrayOrContainerT], ArrayOrContainerT], - base_state: ArrayOrContainerT, + f: Callable[[ArrayOrArithContainerTc], ArrayOrArithContainerTc], + base_state: ArrayOrArithContainerTc, stepsize: float) -> np.ndarray: """Returns a Jacobian matrix of *f* determined by a one-sided finite difference approximation with *stepsize*. @@ -72,7 +92,9 @@ def build_jacobian( f_unit_i = f(base_state + unflatten( base_state, actx.from_numpy(unit_i_flat), actx)) - mat[:, i] = actx.to_numpy(flatten((f_unit_i - f_base) / stepsize, actx)) + mat[:, i] = actx.to_numpy(flatten(( + f_unit_i - f_base + ) / stepsize, actx)) return mat diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 1903a6a2f..235aebe27 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -24,6 +24,24 @@ .. autofunction:: interior_trace_pairs .. autofunction:: local_interior_trace_pair .. autofunction:: cross_rank_trace_pairs + +Links to canonical locations of external symbols +------------------------------------------------ + +(This section only exists because Sphinx does not appear able to resolve +these symbols correctly.) + +.. class:: Array + + See :class:`arraycontext.Array`. + +.. class:: ArrayContainer + + See :class:`arraycontext.ArrayContainer`. + +.. class:: ArrayOrArithContainer + + See :data:`arraycontext.ArrayOrArithContainer`. """ __copyright__ = """ @@ -50,10 +68,10 @@ THE SOFTWARE. """ -from collections.abc import Hashable +from collections.abc import Hashable, Sequence from dataclasses import dataclass from numbers import Number -from typing import Any +from typing import cast from warnings import warn import numpy as np @@ -69,18 +87,21 @@ unflatten, with_container_arithmetic, ) -from meshmode.mesh import BTAG_PARTITION +from meshmode.mesh import BTAG_PARTITION, PartID from pytools import memoize_on_first_arg -from pytools.persistent_dict import KeyBuilder +from pytools.persistent_dict import Hash, KeyBuilder import grudge.dof_desc as dof_desc +from grudge.array_context import MPIBasedArrayContext from grudge.discretization import DiscretizationCollection from grudge.dof_desc import ( DD_VOLUME_ALL, DISCR_TAG_BASE, FACE_RESTR_INTERIOR, + BoundaryDomainTag, ConvertibleToDOFDesc, DOFDesc, + ScalarDomainTag, VolumeDomainTag, ) from grudge.projection import project @@ -290,6 +311,7 @@ def local_interior_trace_pair( interior = project(dcoll, volume_dd, trace_dd, vec) + assert isinstance(trace_dd.domain_tag, BoundaryDomainTag) opposite_face_conn = dcoll.opposite_face_connection(trace_dd.domain_tag) def get_opposite_trace(ary): @@ -350,17 +372,19 @@ def interior_trace_pairs( # {{{ distributed: helper functions class _TagKeyBuilder(KeyBuilder): - def update_for_type(self, key_hash, key: type[Any]): + def update_for_type(self, key_hash: Hash, key: type) -> None: self.rec(key_hash, (key.__module__, key.__name__, key.__name__,)) @memoize_on_first_arg -def connected_ranks( +def connected_parts( dcoll: DiscretizationCollection, - volume_dd: DOFDesc | None = None): + volume_dd: DOFDesc | None = None) -> Sequence[PartID]: if volume_dd is None: volume_dd = DD_VOLUME_ALL + if isinstance(volume_dd.domain_tag, ScalarDomainTag): + return [] from meshmode.distributed import get_connected_parts return get_connected_parts( dcoll._volume_discrs[volume_dd.domain_tag.tag].mesh) @@ -380,6 +404,7 @@ def _sym_tag_to_num_tag(comm_tag: Hashable | None) -> int | None: from mpi4py import MPI tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) + assert tag_ub is not None key_builder = _TagKeyBuilder() digest = key_builder(comm_tag) @@ -402,11 +427,11 @@ class _RankBoundaryCommunicationEager: base_comm_tag = 1273 def __init__(self, + actx: MPIBasedArrayContext, dcoll: DiscretizationCollection, array_container: ArrayOrContainer, - remote_rank, comm_tag: int | None = None, + remote_rank, comm_tag: Hashable = None, volume_dd=DD_VOLUME_ALL): - actx = get_container_context_recursively(array_container) bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank)) local_bdry_data = project(dcoll, volume_dd, bdry_dd, array_container) @@ -484,6 +509,7 @@ def finish(self): class _RankBoundaryCommunicationLazy: def __init__(self, + actx: MPIBasedArrayContext, dcoll: DiscretizationCollection, array_container: ArrayOrContainer, remote_rank: int, comm_tag: Hashable, @@ -589,21 +615,33 @@ def cross_rank_trace_pairs( return [TracePair( volume_dd.trace(BTAG_PARTITION(remote_rank)), interior=ary, exterior=ary) - for remote_rank in connected_ranks(dcoll, volume_dd=volume_dd)] + for remote_rank in connected_parts(dcoll, volume_dd=volume_dd)] actx = get_container_context_recursively(ary) - from grudge.array_context import MPIPytatoArrayContextBase + from grudge.array_context import MPIBasePytatoPyOpenCLArrayContext - if isinstance(actx, MPIPytatoArrayContextBase): - rbc_class = _RankBoundaryCommunicationLazy + if isinstance(actx, MPIBasePytatoPyOpenCLArrayContext): + rbc_class: type[ + _RankBoundaryCommunicationEager | _RankBoundaryCommunicationLazy + ] = _RankBoundaryCommunicationLazy else: rbc_class = _RankBoundaryCommunicationEager + cparts = connected_parts(dcoll, volume_dd=volume_dd) + + if not cparts: + return [] + assert isinstance(actx, MPIBasedArrayContext) + # Initialize and post all sends/receives rank_bdry_communicators = [ - rbc_class(dcoll, ary, remote_rank, comm_tag=comm_tag, volume_dd=volume_dd) - for remote_rank in connected_ranks(dcoll, volume_dd=volume_dd) + rbc_class(actx, dcoll, ary, + # FIXME: This is a casualty of incomplete multi-volume support + # for now. + cast(int, remote_rank), + comm_tag=comm_tag, volume_dd=volume_dd) + for remote_rank in cparts ] # Complete send/receives and return communicated data diff --git a/pyproject.toml b/pyproject.toml index 165f991d5..f1ff5ceb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "modepy>=2021.1", "pymbolic>=2022.2", "pyopencl>=2022.1", - "pytools>=2024.1.3", + "pytools>=2024.1.18", ] [project.optional-dependencies] @@ -119,7 +119,11 @@ lines-after-imports = 2 [tool.mypy] python_version = "3.10" +ignore_missing_imports = true warn_unused_ignores = true +# TODO: enable this at some point +# check_untyped_defs = true + [tool.typos.default] extend-ignore-re = [ diff --git a/run-mypy.sh b/run-mypy.sh new file mode 100755 index 000000000..c0b5ae38d --- /dev/null +++ b/run-mypy.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +python -m mypy grudge