Skip to content

Commit

Permalink
Interface: add TeamThreadMDRange
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar committed Sep 21, 2024
1 parent bed9349 commit 7e569b2
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 2 deletions.
35 changes: 35 additions & 0 deletions examples/pykokkos/team_thread_mdrange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pykokkos as pk


@pk.workunit
def kernel(team, A, B, C, N1, N2):
league_rank: int = team.league_rank()

def inner_for(i0: int, i1: int):
A[league_rank][i0][i1] = B[league_rank][i0] + C[i1]

pk.parallel_for(pk.TeamThreadMDRange(team, N1, N2), inner_for)
team.team_barrier()

def run():
N0 = 16
N1 = 4
N2 = 4

A = pk.View((N0, N1, N2))
B = pk.View((N0, N1))
C = pk.View((N2,))

B.fill(1)
C.fill(1)

print(N0)
print(N1 * N2)

policy = pk.TeamPolicy(N0, N1 * N2)
pk.parallel_for(policy, kernel, A=A, B=B, C=C, N1=N1, N2=N2)

print(A)

if __name__ == "__main__":
run()
2 changes: 1 addition & 1 deletion pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
return call

function = cppast.DeclRefExpr(f"Kokkos::{name}")
if name in ("TeamThreadRange", "ThreadVectorRange"):
if name in ("TeamThreadRange", "ThreadVectorRange", "TeamThreadMDRange"):
return cppast.CallExpr(function, args)

if name in ("parallel_for", "single"):
Expand Down
2 changes: 1 addition & 1 deletion pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from .execution_policy import (
ExecutionPolicy, RangePolicy, MDRangePolicy, TeamPolicy,
TeamThreadRange, ThreadVectorRange, Iterate, Rank
TeamThreadRange, ThreadVectorRange, TeamThreadMDRange, Iterate, Rank
)
from .execution_space import ExecutionSpace, ExecutionSpaceInstance, is_host_execution_space
from .layout import Layout, get_default_layout
Expand Down
5 changes: 5 additions & 0 deletions pykokkos/interface/execution_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,8 @@ def __init__(self, team_member: TeamMember, count: int):
self.team_member = team_member
self.count: Final = count
self.space: ExecutionSpace = ExecutionSpace.Debug


class TeamThreadMDRange(ExecutionPolicy):
def __init__(self, *args) -> None:
pass

0 comments on commit 7e569b2

Please sign in to comment.