Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce fusion overhead #284

Merged
merged 2 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pykokkos/core/fusion/access_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, tid_name: str, view_args: Dict[str, int]):
self.view_args = view_args

# Map from each view (str) + dimension (int) to a tuple of its AccessIndex, its AccessMode, and the index expression as a string
self.access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {}
self.access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}
self.current_iters: List[Tuple[str, bool]] = []

def visit_For(self, node: ast.For) -> None:
Expand All @@ -119,7 +119,7 @@ def visit_Call(self, node: ast.Call) -> None:
if arg.id in self.view_args:
rank: int = self.view_args[arg.id]
for i in range(rank):
self.access_indices[(arg.id, i)] = (AccessIndex.All, AccessMode.ReadWrite)
self.access_indices[(arg.id, i)] = (AccessIndex.All, AccessMode.ReadWrite, "")

def visit_Subscript(self, node: ast.Subscript) -> None:
current_node: ast.Subscript = node
Expand Down Expand Up @@ -160,7 +160,7 @@ def visit_Subscript(self, node: ast.Subscript) -> None:
index_to_set: AccessIndex
mode_to_set: AccessMode

existing_access: Optional[Tuple[AccessIndex, AccessMode]] = self.access_indices.get((view_name, i))
existing_access: Optional[Tuple[AccessIndex, AccessMode, str]] = self.access_indices.get((view_name, i))
if existing_access is None:
index_to_set = new_index
mode_to_set = AccessMode.Read if isinstance(node.ctx, ast.Load) else AccessMode.Write
Expand Down Expand Up @@ -191,10 +191,10 @@ def visit_Subscript(self, node: ast.Subscript) -> None:
if mode_to_set is AccessMode.Write and isinstance(node.parent, ast.AugAssign):
mode_to_set = AccessMode.ReadWrite

self.access_indices[(view_name, i)] = (index_to_set, mode_to_set)
self.access_indices[(view_name, i)] = (index_to_set, mode_to_set, index_node_str)


def get_view_write_indices_and_modes(AST: ast.FunctionDef, view_args: Dict[str, int]) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]]:
def get_view_write_indices_and_modes(AST: ast.FunctionDef, view_args: Dict[str, int]) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]:
"""
Get information from the AST needed for fusion safety

Expand All @@ -207,6 +207,6 @@ def get_view_write_indices_and_modes(AST: ast.FunctionDef, view_args: Dict[str,
tid_name: str = AST.args.args[0].arg
visitor = WriteIndicesVisitor(tid_name, view_args)
visitor.visit(AST)
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = visitor.access_indices
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = visitor.access_indices

return access_indices
59 changes: 41 additions & 18 deletions pykokkos/core/fusion/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TracerOperation:
entity_name: str
args: Dict[str, Any]
dependencies: Set[DataDependency]
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]]
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]

def __hash__(self) -> int:
return self.op_id
Expand Down Expand Up @@ -81,6 +81,10 @@ def __init__(self) -> None:
# Map from data version to tracer operation
self.data_operation: Dict[DataDependency, TracerOperation] = {}

# Cache the results of expensive AST traversals, keyed by (parser path, entity name)
self.access_modes_cache: Dict[Tuple[str, str], Dict[str, AccessMode]] = {}
self.safety_cache: Dict[Tuple[str, str], Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]] = {}

def log_operation(
self,
future: Optional[Future],
Expand Down Expand Up @@ -108,10 +112,19 @@ def log_operation(
entity: PyKokkosEntity = parser.get_entity(entity_name)
AST: ast.FunctionDef = entity.AST

cache_key: Tuple[str, str] = (parser.path, entity_name)

dependencies: Set[DataDependency]
access_modes: Dict[str, AccessMode]
dependencies, access_modes = self.get_data_dependencies(kwargs, AST)
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = self.get_safety_info(kwargs, AST)
dependencies, access_modes = self.get_data_dependencies(kwargs, AST, cache_key)

access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]

if cache_key in self.safety_cache:
access_indices = self.safety_cache[cache_key]
else:
access_indices = self.get_safety_info(kwargs, AST)
self.safety_cache[cache_key] = access_indices

tracer_op = TracerOperation(self.op_id, future, name, policy, workunit, operation, parser, entity_name, dict(kwargs), dependencies, access_indices)
self.op_id += 1
Expand All @@ -120,7 +133,7 @@ def log_operation(

self.operations[tracer_op] = None

def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]]:
def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]:
"""
Get the view access indices needed to check for safety

Expand All @@ -141,10 +154,10 @@ def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[

# Map from view name (str) + dimension (int) to the type of
# access to that view's dimension
write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = get_view_write_indices_and_modes(AST, view_name_and_rank)
write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = get_view_write_indices_and_modes(AST, view_name_and_rank)

# Now need to convert view name to view ID
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {}
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}
for (name, dim), access_index in write_indices.items():
view_id: int = view_args[name]
safety_info[(view_id, dim)] = access_index
Expand Down Expand Up @@ -225,7 +238,7 @@ def fuse(self, operations: List[TracerOperation], strategy: str) -> List[TracerO

raise RuntimeError(f"Unrecognized fusion strategy '{strategy}'")

def is_safe_to_fuse(self, current: List[TracerOperation], current_views: Set[ViewType], current_safety_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]], next: TracerOperation, next_views: Set[ViewType]) -> bool:
def is_safe_to_fuse(self, current: List[TracerOperation], current_views: Set[ViewType], current_safety_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]], next: TracerOperation, next_views: Set[ViewType]) -> bool:
"""
Check whether the next operation is safe to fuse with the
current operations
Expand All @@ -244,16 +257,20 @@ def is_safe_to_fuse(self, current: List[TracerOperation], current_views: Set[Vie
for dim in range(view.rank()):
key: Tuple[int, int] = (id(view), dim)

# assert key in current_safety_info and key in next_safety_info
assert key in current_safety_info
assert key in next_safety_info

current_access_index, current_access_mode = current_safety_info[key]
next_access_index, next_access_mode = next_safety_info[key]
current_access_index, current_access_mode, current_index_str = current_safety_info[key]
next_access_index, next_access_mode, next_index_str = next_safety_info[key]

if current_access_mode == AccessMode.Read and next_access_mode == AccessMode.Read:
continue

# If both the current and the next operations index this view
# dimension with the same function of the thread index, then
# this access will not prevent fusion.
if current_access_index == AccessIndex.TIDFunc and next_access_index == AccessIndex.TIDFunc and current_index_str == next_index_str:
continue

if current_access_index.value > AccessIndex.TID.value or next_access_index.value > AccessIndex.TID.value:
return False

Expand Down Expand Up @@ -377,7 +394,7 @@ def fuse_naive(self, operations: List[TracerOperation]) -> List[TracerOperation]

return fused_ops

def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]], info_1: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]]) -> Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]]:
def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]], info_1: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]]) -> Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]]:
"""
Fuse the safety info of two separate operations

Expand All @@ -386,13 +403,13 @@ def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, Acce
:returns: the fused safety info
"""

fused_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]] = {}
fused_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]] = {}
for key, value in info_0.items():
if key not in info_1:
fused_info[key] = value
else:
other_index, other_mode = info_1[key]
current_index, current_mode = value
other_index, other_mode, other_index_str = info_1[key]
current_index, current_mode, current_index_str = value

index_to_set: AccessIndex
mode_to_set: AccessMode
Expand All @@ -407,7 +424,7 @@ def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, Acce
else:
mode_to_set = AccessMode.ReadWrite

fused_info[key] = (index_to_set, mode_to_set)
fused_info[key] = (index_to_set, mode_to_set, other_index_str)

for key, value in info_1.items():
# Already handled in the previous loop
Expand Down Expand Up @@ -442,7 +459,7 @@ def fuse_operations(self, operations: List[TracerOperation], fused_safety_info:
parsers: List[Parser] = []
args: Dict[str, Dict[str, Any]] = {}
dependencies: Set[DataDependency] = set()
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {}
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}

for index, op in enumerate(operations):
assert isinstance(op.policy, RangePolicy) and policy.begin == op.policy.begin and policy.end == op.policy.end
Expand All @@ -463,12 +480,13 @@ def fuse_operations(self, operations: List[TracerOperation], fused_safety_info:

return TracerOperation(None, future, fused_name, policy, workunits, operation, parsers, fused_name, args, dependencies, fused_safety_info)

def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Tuple[Set[DataDependency], Dict[str, AccessMode]]:
def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef, cache_key: Tuple[str, str]) -> Tuple[Set[DataDependency], Dict[str, AccessMode]]:
"""
Get the data dependencies of an operation from its input arguments

:param kwargs: the keyword arguments passed to the workunit
:param AST: the AST of the input workunit
:param cache_key: the key used to cache the results of traversing the AST
:returns: the set of data dependencies and the access modes of the views
"""

Expand All @@ -489,7 +507,12 @@ def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) ->
if isinstance(value, ViewType):
view_args.add(arg)

access_modes: Dict[str, AccessMode] = get_view_access_modes(AST, view_args)
access_modes: Dict[str, AccessMode]
if cache_key in self.access_modes_cache:
access_modes = self.access_modes_cache[cache_key]
else:
access_modes = get_view_access_modes(AST, view_args)
self.access_modes_cache[cache_key] = access_modes

# Second pass to check if the views are dependencies
for arg, value in kwargs.items():
Expand Down
2 changes: 0 additions & 2 deletions pykokkos/core/type_inference/args_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,6 @@ def get_type_info(
is_missing_annotations: bool = (
workunit_str in ORIGINAL_PARAMS
or
list_passed
or
check_missing_annotations(this_tree.args.args)
)

Expand Down
Loading