Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing CUDAGraph Target #4

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@
Python 3 Titan V:
  script:
  - py_version=3
  - export PYOPENCL_TEST=nvi:titan
  - export EXTRA_INSTALL="pyopencl mpi4py jax[cpu]"
  # Write a siteconf.py that sets CUDADRV_LIB_DIR to the NVIDIA driver
  # library directory.
  - echo "CUDADRV_LIB_DIR = ['/usr/lib/x86_64-linux-gnu/nvidia/current']" > siteconf.py
  # Fetch the shared CI driver script. TLS certificate verification is left
  # enabled (no `-k`) because the downloaded script is sourced below; this
  # matches how the "Python 3 POCL" job downloads the same file.
  - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
  - "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
  - ". ./build-and-test-py-project.sh"
  tags:
  - python3
  - nvidia-titan-v
  except:
  - tags
  artifacts:
    reports:
      junit: test/pytest.xml

Python 3 POCL:
script:
- export PY_EXE=python3
- export PYOPENCL_TEST=portable:pthread
- export EXTRA_INSTALL="pyopencl mpi4py jax[cpu]"
- echo "CUDADRV_LIB_DIR = ['/usr/lib/x86_64-linux-gnu/nvidia/current']" > siteconf.py
- curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
- "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
- ". ./build-and-test-py-project.sh"
tags:
- python3
Expand Down
1 change: 1 addition & 0 deletions .pylintrc-local.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- ply
- pygments.lexers
- pygments.formatters
- pycuda

# https://github.com/PyCQA/pylint/issues/7623
- arg: disable
Expand Down
Empty file added config
Empty file.
3 changes: 2 additions & 1 deletion pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.target import Target
from pytato.target.loopy import LoopyPyOpenCLTarget
from pytato.target.python.jax import generate_jax
from pytato.target.pycuda.cudagraph import generate_cudagraph
from pytato.visualization import (get_dot_graph, show_dot_graph,
get_ascii_graph, show_ascii_graph,
get_dot_graph_from_partition)
Expand Down Expand Up @@ -127,7 +128,7 @@ def set_debug_enabled(flag: bool) -> None:
"matmul", "roll", "transpose", "stack", "reshape", "expand_dims",
"concatenate",

"generate_loopy", "generate_jax",
"generate_loopy", "generate_jax", "generate_cudagraph",

"Target", "LoopyPyOpenCLTarget",

Expand Down
174 changes: 174 additions & 0 deletions pytato/target/pycuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from __future__ import annotations

__copyright__ = """Copyright (C) 2022 Mit Kotak"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

__doc__ = """
.. currentmodule:: pytato

.. autofunction:: generate_cudagraph

.. currentmodule:: pytato.target.python

.. autoclass:: PythonTarget
.. autoclass:: BoundPythonProgram
.. autoclass:: CUDAGraphTarget
.. autoclass:: BoundCUDAGraphProgram
"""

from dataclasses import dataclass
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, Mapping, FrozenSet, Callable, Dict, Set

from pytato.target import Target, BoundProgram


# {{{ abstract types

class PythonTarget(Target, ABC):
    """
    Abstract base class for targets whose generated code is a python
    program, typically one that invokes some :mod:`numpy`-like module for
    the array operations.

    Subclasses must implement :meth:`bind_program`.
    """

    @abstractmethod
    def bind_program(
            self,
            program: str,
            entrypoint: str,
            expected_arguments: FrozenSet[str],
            bound_arguments: Mapping[str, Any],
            ) -> BoundProgram:
        """
        Package generated python code together with its bound arguments.

        :arg program: The python code containing the compiled routine.
        :arg entrypoint: Name of the entrypoint inside *program*.
        :arg expected_arguments: Argument names associated with the
            entrypoint (presumably the keyword arguments it accepts --
            confirm against the code generator).
        :arg bound_arguments: Mapping from argument name to the value that
            is already bound to it.
        :returns: A :class:`~pytato.target.BoundProgram` for *program*.
        """


@dataclass(repr=False, eq=False)
class BoundCUDAGraphProgram(BoundProgram):
    """
    A wrapper for executing generated python programs with bound arguments.

    The generated source is executed once (lazily, on the first call) in a
    private namespace that exposes the names ``cuda_allocator`` and
    ``cuda_dev`` to the generated code. Those two entries are refreshed on
    every call so that the entrypoint always sees the allocator and device
    selected for that call.

    .. automethod:: __call__
    .. automethod:: copy
    .. automethod:: with_transformed_program
    """
    # Names of the keyword arguments that are forwarded to the entrypoint;
    # any other keyword is dropped in __call__.
    expected_arguments: FrozenSet[str]
    # Name of the callable inside *program* that gets invoked.
    entrypoint: str

    @cached_property
    def _program_namespace(self) -> Dict[str, Any]:
        """
        Execute the generated source once and return its global namespace.

        Reads ``self.cuda_allocator`` and ``self.cuda_dev``, which are only
        set by :meth:`__call__`; this must not be accessed before the first
        call.
        """
        namespace: Dict[str, Any] = {
            "_MODULE_SOURCE_CODE": self.program,  # helps pudb
            "cuda_allocator": self.cuda_allocator,
            "cuda_dev": self.cuda_dev
            }
        # NOTE(review): exec() runs arbitrary code. *program* is expected to
        # be code produced by pytato's own code generator; never feed
        # untrusted source through this class.
        exec(self.program, namespace)
        return namespace

    @cached_property
    def _compiled_function(self) -> Callable[..., Any]:
        """The entrypoint callable looked up in the executed namespace."""
        entry = self._program_namespace[self.entrypoint]
        assert callable(entry)
        return entry  # type: ignore[no-any-return]

    @cached_property
    def _bound_argment_names(self) -> Set[str]:
        """Names of all arguments that were bound at construction time."""
        return set(self.bound_arguments.keys())

    @cached_property
    def _processed_bound_args(self) -> Mapping[str, Any]:
        """
        Bound arguments converted to device arrays.

        :class:`numpy.ndarray` values are copied to the device with
        :func:`pycuda.gpuarray.to_gpu`; :class:`pycuda.gpuarray.GPUArray`
        values are passed through unchanged.

        :raises ValueError: if a bound argument has any other type.
        """
        import pycuda.autoinit  # noqa: F401
        import pycuda.gpuarray as gpuarray
        import numpy as np

        processed_bound_args = {}
        for name, arg in self.bound_arguments.items():
            if isinstance(arg, np.ndarray):
                processed_bound_args[name] = gpuarray.to_gpu(arg)
            elif isinstance(arg, gpuarray.GPUArray):
                processed_bound_args[name] = arg
            else:
                raise ValueError("Array format not supported")
        return processed_bound_args

    def __call__(self, cuda_allocator=None, *args: Any, **kwargs: Any) -> Any:
        """
        Run the entrypoint with the bound arguments plus *kwargs*.

        :arg cuda_allocator: Optional allocator made available to the
            generated code as ``cuda_allocator``. When omitted (or falsy),
            the allocator from a previous call is reused; if there is none,
            the ``allocate`` method of a new
            :class:`pycuda.tools.DeviceMemoryPool` is used.
        :arg kwargs: Values for arguments that were not bound. Keywords not
            listed in :attr:`expected_arguments` are ignored.
        :raises ValueError: if extra positional arguments are given or if a
            keyword names an argument that is already bound.
        :returns: Whatever the entrypoint returns.
        """
        # Validate the call before touching any CUDA state.
        if args:
            raise ValueError(f"'{type(self).__call__}' does not take positional"
                             " arguments.")

        rebound = set(kwargs.keys()) & self._bound_argment_names
        if rebound:
            raise ValueError("Got arguments that were previously bound: "
                             f"'{rebound}'.")

        if cuda_allocator:
            self.cuda_allocator = cuda_allocator
        elif getattr(self, "cuda_allocator", None) is None:
            # A bare `import pycuda` does not load the `tools` submodule, so
            # import it explicitly. The pool is created only once so that
            # its cached allocations can be reused across calls.
            import pycuda.tools
            self.cuda_allocator = pycuda.tools.DeviceMemoryPool().allocate

        import pycuda.driver as drv
        self.cuda_dev = drv.Context.get_device()

        # The namespace is built only once; refresh the entries that may
        # change between calls so that the generated code does not keep
        # using the allocator/device of the very first call.
        namespace = self._program_namespace
        namespace["cuda_allocator"] = self.cuda_allocator
        namespace["cuda_dev"] = self.cuda_dev

        updated_kwargs = dict(self._processed_bound_args)
        updated_kwargs.update(kwargs)
        updated_kwargs = {kw: arg
                          for kw, arg in updated_kwargs.items()
                          if kw in self.expected_arguments}
        return self._compiled_function(**updated_kwargs)

    def copy(self, **kwargs: Any) -> BoundCUDAGraphProgram:
        """Return a copy of *self* with the fields in *kwargs* replaced."""
        from dataclasses import replace
        return replace(self, **kwargs)

    def with_transformed_program(self, *args: Any, **kwargs: Any
                                 ) -> BoundCUDAGraphProgram:
        """Not supported for python programs; always raises ValueError."""
        raise ValueError("Cannot transform python program.")

# }}}


# {{{ cudagraph-numpy target


class CUDAGraphTarget(PythonTarget):
    """
    A :class:`PythonTarget` that names ``pycuda.driver`` as its numpy-like
    module, with ``_pt_drv`` as the shorthand for it, and that binds
    programs as :class:`BoundCUDAGraphProgram` instances.
    """

    @property
    def numpy_like_module_name(self) -> str:
        """Dotted name of the module providing the array operations."""
        return "pycuda.driver"

    @property
    def numpy_like_module_name_shorthand(self) -> str:
        """Short alias for :attr:`numpy_like_module_name`."""
        return "_pt_drv"

    def bind_program(
            self,
            program: str,
            entrypoint: str,
            expected_arguments: FrozenSet[str],
            bound_arguments: Mapping[str, Any],
            ) -> BoundCUDAGraphProgram:
        """
        Wrap *program* and its arguments in a :class:`BoundCUDAGraphProgram`
        whose target is *self*.

        :arg program: The python code containing the compiled routine.
        :arg entrypoint: Name of the entrypoint in *program*.
        :arg expected_arguments: Stored on the bound program as
            ``expected_arguments``.
        :arg bound_arguments: Mapping from argument name to bound value.
        """
        bound_program = BoundCUDAGraphProgram(
            target=self,
            program=program,
            bound_arguments=bound_arguments,
            expected_arguments=expected_arguments,
            entrypoint=entrypoint)
        return bound_program

# }}}

# vim: foldmethod=marker
57 changes: 57 additions & 0 deletions pytato/target/pycuda/cudagraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
__copyright__ = """
Copyright (C) 2022 Mit Kotak
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from typing import Union, Optional, Mapping

from pytato.array import Array, DictOfNamedArrays
from pytato.target.pycuda import CUDAGraphTarget, BoundCUDAGraphProgram
from pytato.target.pycuda.numpy_like import generate_cudagraph_numpy_like

__doc__ = """
.. autofunction:: generate_cudagraph
"""


def generate_cudagraph(
        expr: Union[Array, Mapping[str, Array], DictOfNamedArrays],
        *,
        target: Optional[CUDAGraphTarget] = None,
        function_name: str = "_pt_kernel",
        show_code: bool = False,
        colorize_show_code: bool = True,
        dot_graph_path: Optional[str] = "") -> BoundCUDAGraphProgram:
    """
    Returns a :class:`pytato.target.pycuda.BoundCUDAGraphProgram` for the
    array expressions in *expr*.

    :arg target: The target to generate code for. A new
        :class:`~pytato.target.pycuda.CUDAGraphTarget` is created when
        *None*.
    :arg function_name: Name of the entrypoint function in the generated
        code.
    :arg show_code: If *True*, the generated code is printed to ``stdout``.
    :arg colorize_show_code: Forwarded unchanged to the code generator
        (presumably toggles syntax highlighting of the printed code --
        confirm against ``generate_cudagraph_numpy_like``).
    :arg dot_graph_path: Forwarded unchanged to the code generator.
    """
    chosen_target = CUDAGraphTarget() if target is None else target
    assert isinstance(chosen_target, CUDAGraphTarget)

    codegen_options = {
        "target": chosen_target,
        "function_name": function_name,
        "show_code": show_code,
        "colorize_show_code": colorize_show_code,
        "dot_graph_path": dot_graph_path,
        }
    return generate_cudagraph_numpy_like(expr, **codegen_options)
101 changes: 101 additions & 0 deletions pytato/target/pycuda/levels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import ast
import sys
import os
import numpy as np
import loopy as lp

from typing import (Callable, Union, Optional, Mapping, Dict, TypeVar,
List, Set, Tuple, Any)

import pymbolic.primitives as prim
from pytools import (UniqueNameGenerator, memoize_method)
from pytools.graph import reverse_graph
from pytato.transform import CachedMapper, ArrayOrNames
from pytato.array import (IndexLambda, DataWrapper,
Placeholder, SizeParam,
Array, DictOfNamedArrays,
DataInterface)
from pytato.scalar_expr import (IdentityMapper, Reduce,
WalkMapper as ScalarWalkMapper)
from pytato.target.pycuda import (CUDAGraphTarget,
BoundCUDAGraphProgram)
from pyrsistent import pmap
from loopy.symbolic import Reduction as LoopyReduction
from pytato.target.loopy.codegen import PYTATO_REDUCTION_TO_LOOPY_REDUCTION
from dataclasses import dataclass

T = TypeVar("T")


class StringMapper(CachedMapper[ArrayOrNames]):
    """
    Walks an array expression, gives every visited node a generated string
    name and records which named operation feeds which other operation.

    .. attribute:: string_map

        A ``bidict`` from expression node to its generated name; the
        inverse direction is used to get from a name back to its node.

    .. attribute:: dep_map

        A mapping from the name of an operation to the set of names of the
        operations that take it as a binding.
    """

    def __init__(self):
        super().__init__()
        from bidict import bidict
        self.string_map = bidict()
        self.dep_map = {}
        self.vng = UniqueNameGenerator()

    def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
        """
        Normalize and preprocess *expr* for a loopy CUDA target, then visit
        every output that is not a :class:`~pytato.array.Placeholder`.
        """
        from pytato.codegen import normalize_outputs
        from pytato.target.loopy.codegen import preprocess as preprocess_loopy
        result = normalize_outputs(expr)
        preproc_result = preprocess_loopy(result, target=lp.CudaTarget())
        for subexpr in preproc_result.outputs._data.values():
            if not isinstance(subexpr, Placeholder):
                self.string_map[subexpr] = self.rec(subexpr)

    def map_placeholder(self, expr: Placeholder) -> str:
        """Return a freshly generated name for *expr*."""
        return self.vng("_pt_N")

    def map_index_lambda(self, expr: IndexLambda) -> str:
        """
        Name *expr*, visit its non-placeholder bindings and record *expr* as
        a dependent of each of them in :attr:`dep_map`.

        :returns: The name generated for *expr*.
        """
        operation = self.vng("_pt_N")
        # Register the very name that is returned and used in dep_map;
        # generating a second name here would leave a stale entry until a
        # caller happened to overwrite it.
        self.string_map[expr] = operation
        for bnd in expr.bindings.values():
            if isinstance(bnd, Placeholder):
                continue
            bnd_name = self.rec(bnd)
            self.string_map[bnd] = bnd_name
            self.dep_map.setdefault(bnd_name, set()).add(operation)
        # Make sure the operation is a node of the graph even if nothing
        # depends on it and none of its bindings were recorded; otherwise it
        # would be missing from the levels of topological_sort().
        self.dep_map.setdefault(operation, set())
        return operation

    def topological_sort(self):
        """
        Group the recorded operations into dependency levels.

        Level 0 holds the operations without recorded predecessors; each
        following level holds the operations whose predecessors all lie in
        earlier levels.

        :returns: A tuple ``(levels, levels_ops)``. ``levels[i]`` is the
            number of operations in level *i*, ``levels_ops[i]`` is the list
            of expression nodes (looked up through :attr:`string_map`) in
            level *i*.
        :raises ValueError: if the recorded dependencies contain a cycle.
        """
        graph_predecessor = reverse_graph(self.dep_map)
        graph_successor = reverse_graph(graph_predecessor)
        name_to_expr = self.string_map.inverse

        levels = []
        levels_ops = []
        while graph_predecessor:
            # Decide on the whole level *before* removing anything, so that
            # a node freed by a removal in this round is only picked up in
            # the next round.
            ready = [name
                     for name, preds in graph_predecessor.items()
                     if not preds]
            if not ready:
                raise ValueError("Dependency graph contains a cycle.")

            for name in ready:
                for succ in graph_successor[name]:
                    graph_predecessor[succ] = frozenset(
                            graph_predecessor[succ]) - {name}
                del graph_predecessor[name]

            levels.append(len(ready))
            levels_ops.append([name_to_expr[name]
                               for name in ready
                               if name in name_to_expr])
        return levels, levels_ops

def weight_calculator(levels_ops):
    """
    Compute the relative memory weight of each level.

    For every node of a level, the byte size of the node itself
    (``size * itemsize``) and of each of its bindings is accumulated; the
    per-level totals are then divided by their overall sum.

    :arg levels_ops: A sequence of levels, each a sequence of
        :class:`~pytato.array.IndexLambda` nodes.
    :returns: A one-dimensional :class:`numpy.ndarray` with one entry per
        level. The entries sum to 1, unless the overall byte count is zero
        (e.g. no levels, or only empty levels), in which case all entries
        are zero.
    """
    levels_weights = np.zeros(shape=(len(levels_ops)))
    for level_i, level_op in enumerate(levels_ops):
        level_bytes = 0
        for node in level_op:
            assert isinstance(node, IndexLambda)
            level_bytes += node.size * np.dtype(node.dtype).itemsize
            for bnd in node.bindings.values():
                level_bytes += bnd.size * np.dtype(bnd.dtype).itemsize
        levels_weights[level_i] = level_bytes

    total_bytes = levels_weights.sum()
    if total_bytes == 0:
        # Nothing to normalize by; dividing would only produce NaNs.
        return levels_weights
    return levels_weights / total_bytes

Loading