diff --git a/examples/pykokkos/team_thread_mdrange.py b/examples/pykokkos/team_thread_mdrange.py new file mode 100644 index 00000000..6fcde2c4 --- /dev/null +++ b/examples/pykokkos/team_thread_mdrange.py @@ -0,0 +1,35 @@ +import pykokkos as pk + + +@pk.workunit +def kernel(team, A, B, C, N1, N2): + league_rank: int = team.league_rank() + + def inner_for(i0: int, i1: int): + A[league_rank][i0][i1] = B[league_rank][i0] + C[i1] + + pk.parallel_for(pk.TeamThreadMDRange(team, N1, N2), inner_for) + team.team_barrier() + +def run(): + N0 = 16 + N1 = 4 + N2 = 4 + + A = pk.View((N0, N1, N2)) + B = pk.View((N0, N1)) + C = pk.View((N2,)) + + B.fill(1) + C.fill(1) + + print(N0) + print(N1 * N2) + + policy = pk.TeamPolicy(N0, N1 * N2) + pk.parallel_for(policy, kernel, A=A, B=B, C=C, N1=N1, N2=N2) + + print(A) + +if __name__ == "__main__": + run() \ No newline at end of file diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index ecb14317..3d52a7b4 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -255,7 +255,7 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: return call function = cppast.DeclRefExpr(f"Kokkos::{name}") - if name in ("TeamThreadRange", "ThreadVectorRange"): + if name in ("TeamThreadRange", "ThreadVectorRange", "TeamThreadMDRange"): return cppast.CallExpr(function, args) if name in ("parallel_for", "single"): diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index 701ae3e9..bf6f9210 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -25,7 +25,7 @@ ) from .execution_policy import ( ExecutionPolicy, RangePolicy, MDRangePolicy, TeamPolicy, - TeamThreadRange, ThreadVectorRange, Iterate, Rank + TeamThreadRange, ThreadVectorRange, TeamThreadMDRange, Iterate, Rank ) from .execution_space import ExecutionSpace, ExecutionSpaceInstance, is_host_execution_space from .layout import Layout, get_default_layout diff --git a/pykokkos/interface/execution_policy.py b/pykokkos/interface/execution_policy.py index 2b8ae762..0220c933 100644 --- a/pykokkos/interface/execution_policy.py +++ b/pykokkos/interface/execution_policy.py @@ -195,3 +195,8 @@ def __init__(self, team_member: TeamMember, count: int): self.team_member = team_member self.count: Final = count self.space: ExecutionSpace = ExecutionSpace.Debug + + +class TeamThreadMDRange(ExecutionPolicy): + def __init__(self, *args) -> None: + pass \ No newline at end of file