Skip to content

Commit

Permalink
Debug prints
Browse files: browse the repository at this point in the history
  • Loading branch information
tylerjereddy committed Mar 19, 2023
1 parent 80e0d84 commit 1695b02
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions pykokkos/linalg/l3_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def dgemm(alpha: float,
# TODO: league and team size requests outside of these
# values can segfault...
league_size = int(C.size / (tile_width ** 2))
print("league_size:", league_size)
pk.parallel_for("tiled_matmul",
pk.TeamPolicy(league_size=league_size,
team_size=tile_width ** 2),
Expand Down
8 changes: 8 additions & 0 deletions pykokkos/linalg/workunits.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,
view_a: pk.View2D[pk.double],
view_b: pk.View2D[pk.double],
out: pk.View2D[pk.double]):
printf("tiled workunit checkpoint 1")
# early attempt at tiled matrix multiplication in PyKokkos

# for now, let's assume a 2x2 tiling arrangement and
Expand All @@ -44,10 +45,12 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,

# start off by getting a global thread id
global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank()
printf("tiled workunit checkpoint 2 for thread id: %d\n", global_tid)

# TODO: I have no idea how to get 2D scratch memory views?
scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size)
scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size)
printf("tiled workunit checkpoint 3 for thread id: %d\n", global_tid)
# in a 4 x 4 matrix with 2 x 2 tiling the leagues
# and teams have matching row/col assignment approaches
bx: int = team_member.league_rank() / 2
Expand All @@ -61,6 +64,7 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,
tmp: float = 0
col: int = by * 2 + ty
row: int = bx * 2 + tx
printf("tiled workunit checkpoint 4 for thread id: %d\n", global_tid)

# these variables are a bit silly--can we not get
# 2D scratch memory indexing?
Expand All @@ -70,12 +74,16 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,
for i in range(out.extent(1) / 2):
scratch_mem_a[team_member.team_rank()] = view_a[row][i * 2 + ty]
scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col]
printf("tiled workunit checkpoint 5 for thread id: %d\n", global_tid)
team_member.team_barrier()
printf("tiled workunit checkpoint 6 for thread id: %d\n", global_tid)

for k in range(2):
a_index = k + ((team_member.team_rank() // 2) * 2)
b_index = ty + (k * 2)
tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index]
team_member.team_barrier()
printf("tiled workunit checkpoint 7 for thread id: %d\n", global_tid)

printf("tiled workunit checkpoint 8 for thread id: %d\n", global_tid)
out[row][col] = tmp

0 comments on commit 1695b02

Please sign in to comment.