Skip to content

Commit

Permalink
Merge pull request #231 from NaderAlAwar/trace
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar authored Jan 3, 2024
2 parents e49fdf0 + 50f136e commit 2e0b664
Show file tree
Hide file tree
Showing 16 changed files with 525 additions and 141 deletions.
9 changes: 9 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ ignore_errors = True
[mypy-pykokkos.core.fusion.fuse]
ignore_errors = True

[mypy-pykokkos.core.fusion.trace]
ignore_errors = True

[mypy-pykokkos.core.fusion.access_modes]
ignore_errors = True

[mypy-pykokkos.core.fusion.future]
ignore_errors = True

[mypy-pykokkos.core.translators.functor]
ignore_errors = True

Expand Down
10 changes: 1 addition & 9 deletions pykokkos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

from pykokkos.runtime import runtime_singleton
from pykokkos.core import CompilationDefaults, Runtime
from pykokkos.core import Runtime
from pykokkos.interface import *
from pykokkos.kokkos_manager import (
initialize, finalize,
Expand Down Expand Up @@ -79,14 +79,6 @@
__all__ = ["__array_api_version__"]

runtime_singleton.runtime = Runtime()
defaults: Optional[CompilationDefaults] = runtime_singleton.runtime.compiler.read_defaults()

if defaults is not None:
set_default_space(ExecutionSpace[defaults.space])
if defaults.force_uvm:
enable_uvm()
else:
disable_uvm()

def cleanup():
"""
Expand Down
2 changes: 0 additions & 2 deletions pykokkos/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pykokkos.interface import *

from .compiler import CompilationDefaults, Compiler
from .keywords import Keywords
from .runtime import Runtime
2 changes: 1 addition & 1 deletion pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def compile_object(
updated_types: Optional[UpdatedTypes] = None,
types_signature: Optional[str] = None,
**kwargs
) -> Optional[PyKokkosMembers]:
) -> PyKokkosMembers:
"""
Compile an entity object for a single execution space
Expand Down
3 changes: 2 additions & 1 deletion pykokkos/core/fusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .fuse import fuse_workunit_kwargs_and_params, fuse_workunits
from .fuse import fuse_workunit_kwargs_and_params, fuse_workunits
from .trace import Future, Tracer, TracerOperation
53 changes: 53 additions & 0 deletions pykokkos/core/fusion/access_modes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import ast
from enum import auto, Enum
from typing import Dict, Optional, Set

from pykokkos.core.translators.static import StaticTranslator


class AccessMode(Enum):
Read = auto()
Write = auto()
ReadWrite = auto()


def get_view_access_modes(AST: ast.FunctionDef, view_args: Set[str]) -> Dict[str, AccessMode]:
AST = StaticTranslator.add_parent_refs(AST)
access_modes: Dict[str, AccessMode] = {}

for node in ast.walk(AST):
if not isinstance(node, ast.Subscript): # We are only interested in view accesses
continue

if not isinstance(node.value, ast.Name): # Skip type annotations
continue

name: str = node.value.id
if name not in view_args:
continue

existing_mode: Optional[AccessMode] = access_modes.get(name)
new_mode: AccessMode

if isinstance(node.ctx, ast.Load):
if existing_mode is None:
new_mode = AccessMode.Read
elif existing_mode is AccessMode.Write:
new_mode = AccessMode.ReadWrite
else:
new_mode = existing_mode

if isinstance(node.ctx, ast.Store):
if existing_mode is None:
new_mode = AccessMode.Write
elif existing_mode is AccessMode.Read:
new_mode = AccessMode.ReadWrite
else:
new_mode = existing_mode

if new_mode is AccessMode.Write and isinstance(node.parent, ast.AugAssign):
new_mode = AccessMode.ReadWrite

access_modes[name] = new_mode

return access_modes
44 changes: 44 additions & 0 deletions pykokkos/core/fusion/future.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from pykokkos.runtime import runtime_singleton


class Future:
"""
Delayed reductions and scans return a Future
"""

def __init__(self) -> None:
self.value = None

def assign_value(self, value) -> None:
self.value = value

def __add__(self, other):
self.flush_trace()
return self.value + other

def __sub__(self, other):
self.flush_trace()
return self.value - other

def __mul__(self, other):
self.flush_trace()
return self.value * other

def __truediv__(self, other):
self.flush_trace()
return self.value / other

def __floordiv__(self, other):
self.flush_trace()
return self.value // other

def __str__(self):
self.flush_trace()
return str(self.value)

def __repr__(self) -> str:
return str(f"Future(value={self.value})")

def flush_trace(self) -> None:
runtime_singleton.runtime.flush_data(self)
assert self.value is not None
233 changes: 233 additions & 0 deletions pykokkos/core/fusion/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import ast
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from pykokkos.core.parsers import Parser, PyKokkosEntity
from pykokkos.interface import ExecutionPolicy, ViewType

from .access_modes import AccessMode, get_view_access_modes
from .future import Future


@dataclass
class DataDependency:
"""
Represents data + version
"""

name: Optional[str] # for debugging purposes
data_id: int
version: int

def __hash__(self) -> int:
return hash((self.data_id, self.version))

def __eq__(self, other: Any) -> bool:
if not isinstance(other, DataDependency):
return False

return self.data_id == other.data_id and self.version == other.version


@dataclass
class TracerOperation:
"""
A single operation in a trace
"""

op_id: int
future: Optional[Future]
name: Optional[str]
policy: ExecutionPolicy
workunit: Callable[..., None]
operation: str
parser: Parser
args: Dict[str, Any]
dependencies: Set[DataDependency]

def __hash__(self) -> int:
return self.op_id

def __eq__(self, other: Any) -> bool:
if not isinstance(other, TracerOperation):
return False

return self.op_id == other.op_id

def __repr__(self) -> str:
if self.name is None:
return self.workunit.__name__

return self.name


class Tracer:
"""
Holds traces of operations
"""

def __init__(self) -> None:
self.op_id: int = 0

# This functions as an ordered set
self.operations: Dict[TracerOperation, None] = {}

# Map from each data object id (future or array) to the current version
self.data_version: Dict[int, int] = {}

# Map from data version to tracer operation
self.data_operation: Dict[DataDependency, TracerOperation] = {}

def log_operation(
self,
future: Optional[Future],
name: Optional[str],
policy: ExecutionPolicy,
workunit: Callable[..., None],
operation: str,
parser: Parser,
entity_name: str,
**kwargs
) -> None:
"""
Log the workunit and its arguments in the trace
:param future: the future object corresponding to the output of reductions and scans
:param name: the name of the kernel
:param policy: the execution policy of the operation
:param workunit: the workunit function object
:param kwargs: the keyword arguments passed to the workunit
:param operation: the name of the operation "for", "reduce", or "scan"
:param parser: the parser containing the AST of the workunit
:param entity_name: the name of the workunit entity
"""

entity: PyKokkosEntity = parser.get_entity(entity_name)
AST: ast.FunctionDef = entity.AST

dependencies: Set[DataDependency]
access_modes: Dict[str, AccessMode]
dependencies, access_modes = self.get_data_dependencies(kwargs, AST)

tracer_op = TracerOperation(self.op_id, future, name, policy, workunit, operation, parser, dict(kwargs), dependencies)
self.op_id += 1

self.update_output_data_operations(kwargs, access_modes, tracer_op, future, operation)

self.operations[tracer_op] = None

def get_operations(self, data: Union[Future, ViewType]) -> List[TracerOperation]:
"""
Get all the operations needed to update the data of a future
or view and remove them from the trace
:param future: the future corresponding to the value that needs to be updated
:returns: the list of operations to be executed
"""

version: int = self.data_version.get(id(data), 0)
dependency = DataDependency(None, id(data), version)

operation: TracerOperation = self.data_operation[dependency]
if operation not in self.operations:
# This means that the dependency was already updated
return []

operations: List[TracerOperation] = [operation]
del self.operations[operation]

# Ideally, we would not have to do this. By adding an
# operation to this list, its dependencies should be
# automatically updated when the operation is executed.
# However, since the operations are not executed in Python, we
# cannot trigger the flush. We could also potentially iterate
# over kwargs prior to invoking a kernel and call flush_data()
# for all futures and views. We should implement both and
# benchmark them
i: int = 0
while i < len(operations):
current_op = operations[i]

for dep in current_op.dependencies:
if dep not in self.data_operation:
assert dep.version == 0
continue

dependency_op: TracerOperation = self.data_operation[dep]
if dependency_op not in self.operations:
# This means that the dependency was already updated
continue

operations.append(dependency_op)
del self.operations[dependency_op]

i += 1

operations.sort(key=lambda op: op.op_id, reverse=True)

return operations

def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Tuple[Set[DataDependency], Dict[str, AccessMode]]:
"""
Get the data dependencies of an operation from its input arguments
:param kwargs: the keyword arguments passed to the workunit
:param AST: the AST of the input workunit
:returns: the set of data dependencies and the access modes of the views
"""

dependencies: Set[DataDependency] = set()
view_args: Set[str] = set()

# First pass to get the Future dependencies and record all the views
for arg, value in kwargs.items():
if isinstance(value, Future):
version: int = self.data_version.get(id(value), 0)
dependency = DataDependency(arg, id(value), version)

dependencies.add(dependency)

if isinstance(value, ViewType):
view_args.add(arg)

access_modes: Dict[str, AccessMode] = get_view_access_modes(AST, view_args)

# Second pass to check if the views are dependencies
for arg, value in kwargs.items():
if isinstance(value, ViewType) and access_modes[arg] in {AccessMode.Read, AccessMode.ReadWrite}:
version: int = self.data_version.get(id(value), 0)
dependency = DataDependency(arg, id(value), version)

dependencies.add(dependency)

return dependencies, access_modes

def update_output_data_operations(
self,
kwargs: Dict[str, Any],
access_modes: Dict[str, AccessMode],
tracer_op: TracerOperation,
future: Optional[Future],
operation: str
) -> None:
"""
Update the data versions and operations of all data being written to
:param kwargs: the keyword arguments passed to the workunit
:param access_modes: how the passed views are being accessed
:param tracer_op: the current tracer operation being logged
:param future: the future object corresponding to the output of reductions and scans
:param operation: the name of the operation "for", "reduce", or "scan"
"""

for arg, value in kwargs.items():
if isinstance(value, ViewType) and access_modes[arg] in {AccessMode.Write, AccessMode.ReadWrite}:
version: int = self.data_version.get(id(value), 0)
self.data_version[id(value)] = version + 1
dependency = DataDependency(arg, id(value), version + 1)

self.data_operation[dependency] = tracer_op

if operation in {"reduce", "scan"}:
assert future is not None
self.data_operation[DataDependency(None, id(future), 0)] = tracer_op
Loading

0 comments on commit 2e0b664

Please sign in to comment.