Skip to content

Commit

Permalink
Try more threads
Browse files: browse the repository at this point in the history
  • Loading branch information
tylerjereddy committed Mar 19, 2023
1 parent 78385d3 commit 274b15a
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ jobs:
mypy pykokkos
- name: run tests
run: |
export OMP_NUM_THREADS=4
python runtests.py
8 changes: 0 additions & 8 deletions pykokkos/linalg/workunits.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,
view_a: pk.View2D[pk.double],
view_b: pk.View2D[pk.double],
out: pk.View2D[pk.double]):
printf("tiled workunit checkpoint 1")
# early attempt at tiled matrix multiplication in PyKokkos

# for now, let's assume a 2x2 tiling arrangement and
Expand All @@ -45,12 +44,10 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,

# start off by getting a global thread id
global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank()
printf("tiled workunit checkpoint 2 for thread id: %d\n", global_tid)

# TODO: I have no idea how to get 2D scratch memory views?
scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size)
scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size)
printf("tiled workunit checkpoint 3 for thread id: %d\n", global_tid)
# in a 4 x 4 matrix with 2 x 2 tiling the leagues
# and teams have matching row/col assignment approaches
bx: int = team_member.league_rank() / 2
Expand All @@ -64,7 +61,6 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,
tmp: float = 0
col: int = by * 2 + ty
row: int = bx * 2 + tx
printf("tiled workunit checkpoint 4 for thread id: %d\n", global_tid)

# these variables are a bit silly--can we not get
# 2D scratch memory indexing?
Expand All @@ -74,16 +70,12 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,
for i in range(out.extent(1) / 2):
scratch_mem_a[team_member.team_rank()] = view_a[row][i * 2 + ty]
scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col]
printf("tiled workunit checkpoint 5 for thread id: %d\n", global_tid)
team_member.team_barrier()
printf("tiled workunit checkpoint 6 for thread id: %d\n", global_tid)

for k in range(2):
a_index = k + ((team_member.team_rank() // 2) * 2)
b_index = ty + (k * 2)
tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index]
team_member.team_barrier()
printf("tiled workunit checkpoint 7 for thread id: %d\n", global_tid)

printf("tiled workunit checkpoint 8 for thread id: %d\n", global_tid)
out[row][col] = tmp

0 comments on commit 274b15a

Please sign in to comment.