Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing dependence on inspect #226

Merged
merged 13 commits into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def compile_object(
module_setup: ModuleSetup,
space: ExecutionSpace,
force_uvm: bool,
updated_decorator: UpdatedDecorator,
updated_decorator: Optional[UpdatedDecorator] = None,
updated_types: Optional[UpdatedTypes] = None,
types_signature: Optional[str] = None,
**kwargs
Expand Down
22 changes: 11 additions & 11 deletions pykokkos/core/fusion/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,40 +52,40 @@ def visit_keyword(self, node: ast.keyword) -> Any:


def fuse_workunit_kwargs_and_params(
workunits: List[Callable],
workunit_trees: List[ast.AST],
kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[inspect.Parameter]]:
) -> Tuple[Dict[str, Any], List[ast.arg]]:
"""
Fuse the parameters and runtime arguments of a list of workunits and rename them as necessary

:param workunits: the list of workunits being merged
:param workunits_trees: the list of workunit trees (ASTs) being merged
:param kwargs: the keyword arguments passed to the call
:returns: a tuple of the fused kwargs and the combined inspected parameters
"""

fused_kwargs: Dict[str, Any] = {}
fused_params: List[inspect.Parameter] = []
fused_params.append(inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int))
fused_params: List[ast.arg] = []
fused_params.append(ast.arg(arg="fused_tid", annotation=int))

view_ids: Set[int] = set()

for workunit_idx, workunit in enumerate(workunits):
for workunit_idx, tree in enumerate(workunit_trees):
key: str = f"args_{workunit_idx}"
if key not in kwargs:
raise RuntimeError(f"kwargs not specified for workunit {workunit_idx} with key {key}")
current_kwargs: Dict[str, Any] = kwargs[key]

current_params: List[inspect.Parameter] = list(inspect.signature(workunit).parameters.values())
current_params: List[ast.arg] = [p for p in tree.args.args]
for p in current_params[1:]: # Skip the thread ID
current_arg = current_kwargs[p.name]
current_arg = current_kwargs[p.arg]
if "PK_FUSE_ARGS" in os.environ and id(current_arg) in view_ids:
continue

view_ids.add(id(current_arg))

fused_name: str = f"fused_{p.name}_{workunit_idx}"
fused_kwargs[fused_name] = current_kwargs[p.name]
fused_params.append(inspect.Parameter(fused_name, p.kind, annotation=p.annotation))
fused_name: str = f"fused_{p.arg}_{workunit_idx}"
fused_kwargs[fused_name] = current_kwargs[p.arg]
fused_params.append(ast.arg(arg=fused_name, annotation=p.annotation))

return fused_kwargs, fused_params

Expand Down
2 changes: 1 addition & 1 deletion pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def reset_entity_tree(self, entity_tree: ast.AST, updated_obj: Union[UpdatedType
param_list = updated_obj.param_list
for param in param_list:
arg_obj = ast.arg(arg=param.name)
if param.annotation is not inspect._empty:
if param.annotation is not None:
type_str = get_type_str(param.annotation) # simplify inspect.annotation to string
arg_obj.annotation = self.get_annotation_node(type_str)
args_list.append(arg_obj)
Expand Down
59 changes: 47 additions & 12 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import sys
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, List
import sysconfig

import ast
import numpy as np
HannanNaeem marked this conversation as resolved.
Show resolved Hide resolved

from pykokkos.core.fusion import fuse_workunit_kwargs_and_params
from pykokkos.core.keywords import Keywords
from pykokkos.core.translators import PyKokkosMembers
from pykokkos.core.visitors import visitors_util
from pykokkos.core.type_inference import UpdatedTypes, UpdatedDecorator, get_types_signature
from pykokkos.core.type_inference import (
UpdatedTypes, UpdatedDecorator,
get_annotations, get_views_decorator, get_types_signature, process_arg_nodes,
prepare_fusion_args
)
from pykokkos.interface import (
DataType, ExecutionPolicy, ExecutionSpace, MemorySpace,
RandomPool, RangePolicy, TeamPolicy, View, ViewType,
Expand All @@ -19,7 +23,7 @@
import pykokkos.kokkos_manager as km

from .compiler import Compiler
from .module_setup import ModuleSetup
from .module_setup import ModuleSetup, EntityMetadata, get_metadata
from .run_debug import run_workload_debug, run_workunit_debug


Expand All @@ -34,6 +38,9 @@ def __init__(self):
# cache module_setup objects using a workload/workunit and space tuple
self.module_setups: Dict[Tuple, ModuleSetup] = {}

# original parameters for a workunit are preserved for inference changes in the future
self.workunit_params: Dict[str, ast.arguments] = {}

HannanNaeem marked this conversation as resolved.
Show resolved Hide resolved
def run_workload(self, space: ExecutionSpace, workload: object) -> None:
"""
Run the workload
Expand All @@ -59,7 +66,7 @@ def precompile_workunit(
self,
workunit: Callable[..., None],
space: ExecutionSpace,
updated_decorator: UpdatedDecorator,
updated_decorator: Optional[UpdatedDecorator] = None,
updated_types: Optional[UpdatedTypes] = None,
types_signature: Optional[str] = None,
**kwargs,
Expand Down Expand Up @@ -103,8 +110,6 @@ def run_workunit(
policy: ExecutionPolicy,
workunit: Union[Callable[..., None], List[Callable[..., None]]],
operation: str,
updated_decorator: UpdatedDecorator,
updated_types: Optional[UpdatedTypes] = None,
initial_value: Union[float, int] = 0,
**kwargs
) -> Optional[Union[float, int]]:
Expand All @@ -115,8 +120,6 @@ def run_workunit(
:param policy: the execution policy of the operation
:param workunit: the workunit function object
:param kwargs: the keyword arguments passed to the workunit
:param updated_decorator: Object with decorator specifier information
:param updated_types: UpdatedTypes object with type inference information
:param operation: the name of the operation "for", "reduce", or "scan"
:param initial_value: the initial value of the accumulator
:returns: the result of the operation (None for parallel_for)
Expand All @@ -129,13 +132,42 @@ def run_workunit(
raise RuntimeError("ERROR: operation cannot be None for Debug")
return run_workunit_debug(policy, workunit, operation, initial_value, **kwargs)

types_signature: str = get_types_signature(updated_types, updated_decorator, execution_space)
list_passed: bool = True

if not isinstance(workunit, list):
workunit = [workunit] # for easier transformations
list_passed = False

parser = self.compiler.get_parser(get_metadata(workunit[0]).path)

entity_AST: Union[List[ast.AST], ast.AST] = []
is_standalone_workunit: bool

entity_AST, is_standalone_workunit = process_arg_nodes(
list_passed,
parser,
workunit,
entity_AST,
)

workunit_trees: Union[List[Tuple[Callable, ast.AST]], Tuple[Callable, ast.AST]]
workunit_trees, workunit, entity_AST = prepare_fusion_args(list_passed, workunit, entity_AST)

updated_types: UpdatedTypes = None
updated_decorator: UpdatedDecorator = None
types_signature: str = None

if is_standalone_workunit:
updated_types = get_annotations(f"parallel_{operation}", workunit_trees, policy, passed_kwargs=kwargs)
updated_decorator = get_views_decorator(workunit_trees, passed_kwargs=kwargs)
types_signature = get_types_signature(updated_types, updated_decorator, execution_space)

HannanNaeem marked this conversation as resolved.
Show resolved Hide resolved
members: Optional[PyKokkosMembers] = self.precompile_workunit(workunit, execution_space, updated_decorator, updated_types, types_signature, **kwargs)
if members is None:
raise RuntimeError("ERROR: members cannot be none")

module_setup: ModuleSetup = self.get_module_setup(workunit, execution_space, types_signature)
return self.execute(workunit, module_setup, members, execution_space, policy=policy, name=name, **kwargs)
return self.execute(workunit, module_setup, members, execution_space, entity_trees=entity_AST, policy=policy, name=name, **kwargs)

def is_debug(self, space: ExecutionSpace) -> bool:
"""
Expand All @@ -156,6 +188,7 @@ def execute(
space: ExecutionSpace,
policy: Optional[ExecutionPolicy] = None,
name: Optional[str] = None,
entity_trees: Optional[List[ast.AST]] = None,
**kwargs
) -> Optional[Union[float, int]]:
"""
Expand All @@ -167,6 +200,7 @@ def execute(
:param space: the execution space
:param policy: the execution policy for workunits
:param name: the name of the kernel
:param entity_trees: Optional parameter: List of ASTs of entities being fused - only provided when entity is a list
:param kwargs: the keyword arguments passed to the workunit
:returns: the result of the operation (None for "for" and workloads)
"""
Expand All @@ -180,7 +214,7 @@ def execute(

module = self.import_module(module_setup.name, module_path)

args: Dict[str, Any] = self.get_arguments(entity, members, space, policy, **kwargs)
args: Dict[str, Any] = self.get_arguments(entity, members, space, policy, entity_trees, **kwargs)
if name is None:
args["pk_kernel_name"] = ""
else:
Expand Down Expand Up @@ -222,6 +256,7 @@ def get_arguments(
members: PyKokkosMembers,
space: ExecutionSpace,
policy: Optional[ExecutionPolicy],
entity_trees: Optional[List[ast.AST]] = None,
**kwargs
) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -256,7 +291,7 @@ def get_arguments(
else:
is_fused: bool = isinstance(entity, list)
if is_fused:
kwargs, _ = fuse_workunit_kwargs_and_params(entity, kwargs)
kwargs, _ = fuse_workunit_kwargs_and_params(entity_trees, kwargs)
HannanNaeem marked this conversation as resolved.
Show resolved Hide resolved
entity_members = kwargs

args.update(self.get_fields(entity_members))
Expand Down
2 changes: 1 addition & 1 deletion pykokkos/core/type_inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .args_type_inference import (
UpdatedTypes, UpdatedDecorator, HandledArgs,
handle_args, get_annotations, get_type_str, get_types_signature,
get_views_decorator
get_views_decorator, prepare_fusion_args, process_arg_nodes
)
Loading