Skip to content

Commit

Permalink
Provide default implementations of transform_loopy_program, with warning
Browse files Browse the repository at this point in the history
Closes gh-272
  • Loading branch information
inducer committed Aug 5, 2024
1 parent 3f848dc commit 1da13a5
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 24 deletions.
4 changes: 4 additions & 0 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,4 +583,8 @@ def tag_axes(

# }}}


class UntransformedCodeWarning(UserWarning):
pass

# vim: foldmethod=marker
34 changes: 23 additions & 11 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""
from __future__ import annotations


__doc__ = """
.. currentmodule:: arraycontext
.. autoclass:: PyOpenCLArrayContext
.. automodule:: arraycontext.impl.pyopencl.taggable_cl_array
Expand Down Expand Up @@ -36,7 +39,13 @@
from pytools.tag import ToTagSetConvertible

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ScalarLike,
UntransformedCodeWarning,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -72,8 +81,8 @@ class PyOpenCLArrayContext(ArrayContext):
"""

def __init__(self,
queue: "pyopencl.CommandQueue",
allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
queue: pyopencl.CommandQueue,
allocator: Optional[pyopencl.tools.AllocatorBase] = None,
wait_event_queue_length: Optional[int] = None,
force_device_scalars: bool = False) -> None:
r"""
Expand Down Expand Up @@ -301,16 +310,19 @@ def clone(self):

# {{{ transform_loopy_program

def transform_loopy_program(self, t_unit):
def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
from warnings import warn
warn("Using arraycontext.PyOpenCLArrayContext.transform_loopy_program "
"to transform a program. This is deprecated and will stop working "
"in 2022. Instead, subclass PyOpenCLArrayContext and implement "
"the specific logic required to transform the program for your "
"package or application. Check higher-level packages "
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is largely a no-op and unlikely to result in fast generated "
"code."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
DeprecationWarning, stacklevel=2)
UntransformedCodeWarning, stacklevel=2)

# accommodate loopy with and without kernel callables

Expand Down
45 changes: 33 additions & 12 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""
from __future__ import annotations


__doc__ = """
.. currentmodule:: arraycontext
A :mod:`pytato`-based array context defers the evaluation of an array until its
Expand Down Expand Up @@ -62,11 +65,18 @@
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ScalarLike,
UntransformedCodeWarning,
)
from arraycontext.metadata import NameHint


if TYPE_CHECKING:
import loopy as lp
import pyopencl as cl
import pytato

Expand Down Expand Up @@ -137,7 +147,6 @@ def __init__(
"""
super().__init__()

import loopy as lp
import pytato as pt
self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
self._dag_transform_cache: Dict[
Expand Down Expand Up @@ -180,8 +189,8 @@ def empty_like(self, ary):

# {{{ compilation

def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
Expand All @@ -194,10 +203,22 @@ def transform_dag(self, dag: "pytato.DictOfNamedArrays"
"""
return dag

def transform_loopy_program(self, t_unit):
raise ValueError(
f"{type(self).__name__} does not implement transform_loopy_program. "
"Sub-classes are supposed to implement it.")
def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
from warnings import warn
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is a no-op and will result in unoptimized C code for"
"the requested optimization, all in a single statement."
"This will work, but is unlikely to be performatn."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
UntransformedCodeWarning, stacklevel=2)

return t_unit

@abc.abstractmethod
def einsum(self, spec, *args, arg_names=None, tagged=()):
Expand Down Expand Up @@ -250,7 +271,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
.. automethod:: compile
"""
def __init__(
self, queue: "cl.CommandQueue", allocator=None, *,
self, queue: cl.CommandQueue, allocator=None, *,
use_memory_pool: Optional[bool] = None,
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,

Expand Down Expand Up @@ -642,8 +663,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
from .compile import LazilyPyOpenCLCompilingFunctionCaller
return LazilyPyOpenCLCompilingFunctionCaller(self, f)

def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
import pytato as pt
dag = pt.transform.materialize_with_mpms(dag)
return dag
Expand Down
1 change: 0 additions & 1 deletion arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def is_available(cls) -> bool:
def actx_class(self):
from arraycontext import PytatoPyOpenCLArrayContext
actx_cls = PytatoPyOpenCLArrayContext
actx_cls.transform_loopy_program = lambda s, t_unit: t_unit
return actx_cls

def __call__(self):
Expand Down

0 comments on commit 1da13a5

Please sign in to comment.