Skip to content

Commit

Permalink
Merge pull request #224 from NaderAlAwar/fuse_args
Browse files Browse the repository at this point in the history
Add option to fuse kernel arguments
  • Loading branch information
NaderAlAwar authored Dec 12, 2023
2 parents 2db7c16 + 8667e49 commit becf8c7
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
13 changes: 7 additions & 6 deletions pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self):
logging.basicConfig(stream=sys.stdout, level=numeric_level)
self.logger = logging.getLogger()

def fuse_objects(self, metadata: List[EntityMetadata], fuse_ASTs: bool) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]:
def fuse_objects(self, metadata: List[EntityMetadata], fuse_ASTs: bool, **kwargs) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]:
"""
Fuse two or more workunits into one
Expand Down Expand Up @@ -98,7 +98,7 @@ def fuse_objects(self, metadata: List[EntityMetadata], fuse_ASTs: bool) -> Tuple

fused_name: str = "_".join(names)
if fuse_ASTs:
AST, source = fuse_workunits(fused_name, ASTs, sources)
AST, source = fuse_workunits(fused_name, ASTs, sources, **kwargs)
else:
AST = None
source = None
Expand All @@ -115,7 +115,8 @@ def compile_object(
force_uvm: bool,
updated_decorator: UpdatedDecorator,
updated_types: Optional[UpdatedTypes] = None,
types_signature: Optional[str] = None
types_signature: Optional[str] = None,
**kwargs
) -> Optional[PyKokkosMembers]:
"""
Compile an entity object for a single execution space
Expand All @@ -139,7 +140,7 @@ def compile_object(
classtypes = parser.get_classtypes()
else:
# Avoid fusing the ASTs before checking if it was already compiled
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=False)
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=False, **kwargs)

hash: str = self.members_hash(entity.path, entity.name, types_signature)

Expand All @@ -152,7 +153,7 @@ def compile_object(
if self.is_compiled(module_setup.output_dir):
if hash not in self.members: # True if pre-compiled
if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True)
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True, **kwargs)

if types_inferred:
entity.AST = parser.fix_types(entity, updated_types)
Expand All @@ -163,7 +164,7 @@ def compile_object(
return self.members[hash]

if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True)
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True, **kwargs)

self.is_compiled_cache[module_setup.output_dir] = True

Expand Down
35 changes: 29 additions & 6 deletions pykokkos/core/fusion/fuse.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import ast
import inspect
import os
from typing import Any, Callable, Dict, List, Set, Tuple, Union


def get_node_name(node: Union[ast.Attribute, ast.Name]) -> str:
"""
Copied from visitors_util.py due to circular import
Expand Down Expand Up @@ -67,14 +67,22 @@ def fuse_workunit_kwargs_and_params(
fused_params: List[inspect.Parameter] = []
fused_params.append(inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int))

view_ids: Set[int] = set()

for workunit_idx, workunit in enumerate(workunits):
key: str = f"args_{workunit_idx}"
if key not in kwargs:
raise RuntimeError(f"kwargs not specified for workunit {workunit_idx} with key {key}")
current_kwargs: Dict[str, Any] = kwargs[key]

current_params: List[inspect.Parameter] = list(inspect.signature(workunit).parameters.values())
for p in current_params[1:]: # Skip the thread ID
current_arg = current_kwargs[p.name]
if "PK_FUSE_ARGS" in os.environ and id(current_arg) in view_ids:
continue

view_ids.add(id(current_arg))

fused_name: str = f"fused_{p.name}_{workunit_idx}"
fused_kwargs[fused_name] = current_kwargs[p.name]
fused_params.append(inspect.Parameter(fused_name, p.kind, annotation=p.annotation))
Expand Down Expand Up @@ -102,7 +110,7 @@ def fuse_workunit_kwargs_and_params(
# return fused_kwargs, fused_params


def fuse_arguments(all_args: List[ast.arguments]) -> Tuple[ast.arguments, Dict[Tuple[str, int], str]]:
def fuse_arguments(all_args: List[ast.arguments], **kwargs) -> Tuple[ast.arguments, Dict[Tuple[str, int], str]]:
"""
Fuse the ast argument object into one
Expand All @@ -116,7 +124,15 @@ def fuse_arguments(all_args: List[ast.arguments]) -> Tuple[ast.arguments, Dict[T
new_tid: str = "fused_tid"
fused_args = ast.arguments(args=[ast.arg(arg=new_tid, annotation=ast.Name(id='int', ctx=ast.Load()))])

# Map from view ID to fused name
fused_view_names: Dict[int, str] = {}

for workunit_idx, args in enumerate(all_args):
key: str = f"args_{workunit_idx}"
if key not in kwargs:
raise RuntimeError(f"kwargs not specified for workunit {workunit_idx} with key {key}")
current_kwargs: Dict[str, Any] = kwargs[key]

for arg_idx, arg in enumerate(args.args):
old_name: str = arg.arg
key = (old_name, workunit_idx)
Expand All @@ -127,7 +143,13 @@ def fuse_arguments(all_args: List[ast.arguments]) -> Tuple[ast.arguments, Dict[T
name_map[key] = new_tid
continue

current_arg = current_kwargs[old_name]
if "PK_FUSE_ARGS" in os.environ and id(current_arg) in fused_view_names:
name_map[key] = fused_view_names[id(current_arg)]
continue

new_name = f"fused_{old_name}_{workunit_idx}"
fused_view_names[id(current_arg)] = new_name
name_map[key] = new_name
fused_args.args.append(ast.arg(arg=new_name, annotation=arg.annotation))

Expand Down Expand Up @@ -175,7 +197,7 @@ def fuse_decorators(decorators: List[Union[ast.Attribute, ast.Call]], name_map:
return ast.Call(func=decorators[0].func, args=[], keywords=fused_keywords)


def fuse_ASTs(ASTs: List[ast.FunctionDef], name: str) -> ast.FunctionDef:
def fuse_ASTs(ASTs: List[ast.FunctionDef], name: str, **kwargs) -> ast.FunctionDef:
"""
Fuse the ASTs of multiple workunits together
Expand All @@ -186,7 +208,7 @@ def fuse_ASTs(ASTs: List[ast.FunctionDef], name: str) -> ast.FunctionDef:

args: ast.arguments
name_map: Dict[str, str]
args, name_map = fuse_arguments([AST.args for AST in ASTs])
args, name_map = fuse_arguments([AST.args for AST in ASTs], **kwargs)

# decorator: ast.Call = fuse_decorators([AST.decorator_list[0] for AST in ASTs], name_map)
body: List[ast.stmt] = fuse_bodies([AST.body for AST in ASTs], name_map)
Expand All @@ -203,6 +225,7 @@ def fuse_workunits(
fused_name: str,
ASTs: List[ast.FunctionDef],
sources: List[Tuple[List[str], int]],
**kwargs
) -> Tuple[ast.FunctionDef, Tuple[List[str], int]]:
"""
Merge a list of workunits into a single object
Expand All @@ -212,7 +235,7 @@ def fuse_workunits(
:param sources: the raw source of the workunits to be fused
"""

AST: ast.FunctionDef = fuse_ASTs(ASTs, fused_name)
AST: ast.FunctionDef = fuse_ASTs(ASTs, fused_name, **kwargs)
source: Tuple[List[str], int] = fuse_sources(sources)

return AST, source
5 changes: 3 additions & 2 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def precompile_workunit(
updated_decorator: UpdatedDecorator,
updated_types: Optional[UpdatedTypes] = None,
types_signature: Optional[str] = None,
**kwargs,
) -> Optional[PyKokkosMembers]:
"""
precompile the workunit
Expand All @@ -76,7 +77,7 @@ def precompile_workunit(
members: Optional[PyKokkosMembers] = self.compiler.compile_object(module_setup,
space, km.is_uvm_enabled(),
updated_decorator,
updated_types, types_signature)
updated_types, types_signature, **kwargs)

return members

Expand Down Expand Up @@ -128,7 +129,7 @@ def run_workunit(
return run_workunit_debug(policy, workunit, operation, initial_value, **kwargs)

types_signature: str = get_types_signature(updated_types, updated_decorator, execution_space)
members: Optional[PyKokkosMembers] = self.precompile_workunit(workunit, execution_space, updated_decorator, updated_types, types_signature)
members: Optional[PyKokkosMembers] = self.precompile_workunit(workunit, execution_space, updated_decorator, updated_types, types_signature, **kwargs)
if members is None:
raise RuntimeError("ERROR: members cannot be none")

Expand Down

0 comments on commit becf8c7

Please sign in to comment.