Skip to content

Commit

Permalink
Configure, pass mypy, add to CI
Browse files Browse the repository at this point in the history
Also

- Drop SingleGridWorkBalancingPytatoArrayContext
- Eliminate mixin-style MPIPytatoArrayContextBase
  • Loading branch information
inducer committed Dec 2, 2024
1 parent 52da71a commit ce6da6b
Show file tree
Hide file tree
Showing 17 changed files with 268 additions and 185 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ jobs:
pipx install ruff
ruff check
mypy:
name: Mypy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: "Main Script"
run: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_conda_env
python -m pip install mypy
./run-mypy.sh
typos:
name: Typos
runs-on: ubuntu-latest
Expand Down
15 changes: 14 additions & 1 deletion .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Documentation:
tags:
- python3

Flake8:
Ruff:
script:
- pipx install ruff
- ruff check
Expand All @@ -108,6 +108,19 @@ Flake8:
except:
- tags

Mypy:
script: |
EXTRA_INSTALL="Cython mpi4py"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
python -m pip install mypy
./run-mypy.sh
tags:
- python3
except:
- tags

Pylint:
script: |
EXTRA_INSTALL="pybind11 make numpy scipy matplotlib mpi4py"
Expand Down
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@

# index-page demo uses pyopencl via plot_directive
os.environ["PYOPENCL_TEST"] = "port:cpu"

nitpick_ignore_regex = [
["py:class", r"np\.ndarray"],
["py:data|py:class", r"arraycontext.*ContainerTc"],
]
163 changes: 61 additions & 102 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
.. autofunction:: get_reasonable_array_context_class
"""

from __future__ import annotations


__copyright__ = "Copyright (C) 2020 Andreas Kloeckner"

__license__ = """
Expand Down Expand Up @@ -35,9 +38,11 @@
import logging
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from warnings import warn

from typing_extensions import Self

from meshmode.array_context import (
PyOpenCLArrayContext as _PyOpenCLArrayContextBase,
PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase,
Expand All @@ -48,28 +53,6 @@

logger = logging.getLogger(__name__)

try:
# FIXME: temporary workaround while SingleGridWorkBalancingPytatoArrayContext
# is not available in meshmode's main branch
# (it currently needs
# https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms)
from meshmode.array_context import SingleGridWorkBalancingPytatoArrayContext

try:
# Crude check if we have the correct loopy branch
# (https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms)
from loopy.codegen.result import get_idis_for_kernel # noqa
except ImportError:
# warn("Your loopy and meshmode branches are mismatched. "
# "Please make sure that you have the "
# "https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms " # noqa
# "branch of loopy.")
_HAVE_SINGLE_GRID_WORK_BALANCING = False
else:
_HAVE_SINGLE_GRID_WORK_BALANCING = True

except ImportError:
_HAVE_SINGLE_GRID_WORK_BALANCING = False

try:
# FIXME: temporary workaround while FusionContractorArrayContext
Expand Down Expand Up @@ -119,8 +102,8 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
to understand :mod:`grudge`-specific transform metadata. (Of which there isn't
any, for now.)
"""
def __init__(self, queue: "pyopencl.CommandQueue",
allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
def __init__(self, queue: pyopencl.CommandQueue,
allocator: pyopencl.tools.AllocatorBase | None = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool = True) -> None:

Expand Down Expand Up @@ -165,8 +148,8 @@ def __init__(self, queue, allocator=None,
# }}}


class MPIBasedArrayContext:
mpi_communicator: "MPI.Comm"
class MPIBasedArrayContext(ArrayContext):
mpi_communicator: MPI.Intracomm


# {{{ distributed + pytato
Expand Down Expand Up @@ -345,13 +328,13 @@ class _DistributedCompiledFunction:
type of the callable.
"""

actx: "MPISingleGridWorkBalancingPytatoArrayContext"
distributed_partition: "DistributedGraphPartition"
part_id_to_prg: "Mapping[PartId, pt.target.BoundProgram]"
actx: MPIBasedArrayContext
distributed_partition: DistributedGraphPartition
part_id_to_prg: Mapping[PartId, pt.target.BoundProgram]
input_id_to_name_in_program: Mapping[tuple[Any, ...], str]
output_id_to_name_in_program: Mapping[tuple[Any, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
name_in_program_to_axes: Mapping[str, tuple["pt.Axis", ...]]
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
output_template: ArrayContainer

def __call__(self, arg_id_to_arg) -> ArrayContainer:
Expand All @@ -368,10 +351,11 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

from pytato import execute_distributed_partition
assert isinstance(self.actx, PytatoPyOpenCLArrayContext | PyOpenCLArrayContext)
out_dict = execute_distributed_partition(
self.distributed_partition, self.part_id_to_prg,
self.actx.queue, self.actx.mpi_communicator,
allocator=self.actx.allocator,
self.actx.queue, self.actx.mpi_communicator, # pylint: disable=no-member
allocator=self.actx.allocator, # pylint: disable=no-member
input_args=input_args_for_prg)

def to_output_template(keys, _):
Expand All @@ -387,42 +371,6 @@ def to_output_template(keys, _):
self.output_template)


class MPIPytatoArrayContextBase(MPIBasedArrayContext):
def __init__(
self, mpi_communicator, queue, *, mpi_base_tag, allocator=None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
) -> None:
"""
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
"""
if allocator is None:
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)", stacklevel=2)

super().__init__(queue, allocator,
compile_trace_callback=compile_trace_callback)

self.mpi_communicator = mpi_communicator
self.mpi_base_tag = mpi_base_tag

# FIXME: implement distributed-aware freeze

def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
return _DistributedLazilyPyOpenCLCompilingFunctionCaller(self, f)

def clone(self):
# type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
# pylint: disable=no-member
return type(self)(self.mpi_communicator, self.queue,
mpi_base_tag=self.mpi_base_tag,
allocator=self.allocator)

# }}}


Expand All @@ -437,8 +385,8 @@ class MPIPyOpenCLArrayContext(PyOpenCLArrayContext, MPIBasedArrayContext):

def __init__(self,
mpi_communicator,
queue: "pyopencl.CommandQueue",
*, allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
queue: pyopencl.CommandQueue,
*, allocator: pyopencl.tools.AllocatorBase | None = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool = True) -> None:
"""
Expand All @@ -451,7 +399,7 @@ def __init__(self,

self.mpi_communicator = mpi_communicator

def clone(self):
def clone(self) -> Self:
# type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
# pylint: disable=no-member
return type(self)(self.mpi_communicator, self.queue,
Expand All @@ -476,7 +424,7 @@ def __init__(self, mpi_communicator) -> None:

self.mpi_communicator = mpi_communicator

def clone(self):
def clone(self) -> Self:
return type(self)(self.mpi_communicator)

# }}}
Expand All @@ -485,28 +433,50 @@ def clone(self):
# {{{ distributed + pytato array context subclasses

class MPIBasePytatoPyOpenCLArrayContext(
MPIPytatoArrayContextBase, PytatoPyOpenCLArrayContext):
MPIBasedArrayContext, PytatoPyOpenCLArrayContext):
"""
.. autofunction:: __init__
"""
pass


if _HAVE_SINGLE_GRID_WORK_BALANCING:
class MPISingleGridWorkBalancingPytatoArrayContext(
MPIPytatoArrayContextBase, SingleGridWorkBalancingPytatoArrayContext):
def __init__(
self, mpi_communicator, queue, *, mpi_base_tag, allocator=None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
) -> None:
"""
.. autofunction:: __init__
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
"""
if allocator is None:
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)", stacklevel=2)

MPIPytatoArrayContext = MPISingleGridWorkBalancingPytatoArrayContext
else:
MPIPytatoArrayContext = MPIBasePytatoPyOpenCLArrayContext
super().__init__(queue, allocator,
compile_trace_callback=compile_trace_callback)

self.mpi_communicator = mpi_communicator
self.mpi_base_tag = mpi_base_tag

# FIXME: implement distributed-aware freeze

def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
return _DistributedLazilyPyOpenCLCompilingFunctionCaller(self, f)

def clone(self) -> Self:
return type(self)(self.mpi_communicator, self.queue,
mpi_base_tag=self.mpi_base_tag,
allocator=self.allocator)


MPIPytatoArrayContext: type[MPIBasedArrayContext] = MPIBasePytatoPyOpenCLArrayContext


if _HAVE_FUSION_ACTX:
class MPIFusionContractorArrayContext(
MPIPytatoArrayContextBase, FusionContractorArrayContext):
MPIBasePytatoPyOpenCLArrayContext, FusionContractorArrayContext):
"""
.. autofunction:: __init__
"""
Expand Down Expand Up @@ -570,25 +540,14 @@ def __call__(self):


def _get_single_grid_pytato_actx_class(distributed: bool) -> type[ArrayContext]:
if not _HAVE_SINGLE_GRID_WORK_BALANCING:
warn("No device-parallel actx available, execution will be slow. "
"Please make sure you have the right branches for loopy "
"(https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) " # noqa
"and meshmode "
"(https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms).",
stacklevel=1)
warn("No device-parallel actx available, execution will be slow.",
stacklevel=1)
# lazy, non-distributed
if not distributed:
if _HAVE_SINGLE_GRID_WORK_BALANCING:
return SingleGridWorkBalancingPytatoArrayContext
else:
return PytatoPyOpenCLArrayContext
return PytatoPyOpenCLArrayContext
else:
# distributed+lazy:
if _HAVE_SINGLE_GRID_WORK_BALANCING:
return MPISingleGridWorkBalancingPytatoArrayContext
else:
return MPIBasePytatoPyOpenCLArrayContext
return MPIBasePytatoPyOpenCLArrayContext


def get_reasonable_array_context_class(
Expand All @@ -603,7 +562,7 @@ def get_reasonable_array_context_class(
if numpy:
assert not (lazy or fusion)
if distributed:
actx_class = MPINumpyArrayContext
actx_class: type[ArrayContext] = MPINumpyArrayContext
else:
actx_class = NumpyArrayContext

Expand Down Expand Up @@ -641,7 +600,7 @@ def get_reasonable_array_context_class(
"device-parallel=%r",
actx_class.__name__, lazy, distributed,
# eager is always device-parallel:
(_HAVE_SINGLE_GRID_WORK_BALANCING or _HAVE_FUSION_ACTX or not lazy))
(_HAVE_FUSION_ACTX or not lazy))
return actx_class

# }}}
Expand Down
Loading

0 comments on commit ce6da6b

Please sign in to comment.