From c88c279a653ab5b93d0c8f440815679ad26cf7e7 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 3 Jan 2023 14:02:54 -0700 Subject: [PATCH 01/21] ENH: DGEMM workunits * `dgemm` now uses `pykokkos` workunits/kernels to achieve much faster performance than before * I had to correct a mistake in the benchmark code--we now use larger tiling dimensions to expand the data to avoid having empty arrays there--the net effect is bigger benchmark sizes, which seems desirable anyway * the benchmark code was also adjusted to modulate/directly control the number of OpenMP threads used by PyKokkos using the `threadpoolctl` library--this seems to stabilize the timing from trial to trial a bit better but there is still quite a bit more variation than I'd like between trials (benchmarking concurrent code is hard...) for PyKokkos (warmup issues?) * the small, medium, large slowdowns vs. SciPy are more reasonable now (with kernels pre-compiled/cached) - from gh-134: 310X, 4014X, and 4985X slower, respectively - here with 1 OpenMP thread: 75X, 19X, 14X - here with 4 OpenMP threads: 62X, 66X, 10X - here with 10 OpenMP threads: 38X, 18X, 13X * it may also be interesting to check these on the GPU, although OpenBLAS is just using the host as well --- benchmarks/dgemm_compare.py | 126 ++++++++++++++++++----------------- mypy.ini | 3 + pykokkos/linalg/l3_blas.py | 27 +++++--- pykokkos/linalg/workunits.py | 28 ++++++++ 4 files changed, 115 insertions(+), 69 deletions(-) create mode 100644 pykokkos/linalg/workunits.py diff --git a/benchmarks/dgemm_compare.py b/benchmarks/dgemm_compare.py index 2cdea05b..e1ddcb8d 100644 --- a/benchmarks/dgemm_compare.py +++ b/benchmarks/dgemm_compare.py @@ -3,6 +3,7 @@ (i.e., a wheel with OpenBLAS 0.3.18) """ +from threadpoolctl import threadpool_limits import pykokkos as pk from pykokkos.linalg.l3_blas import dgemm as pk_dgemm @@ -16,66 +17,69 @@ if __name__ == "__main__": import timeit - num_repeats = 50 - results = {"PyKokkos": {}, - "SciPy": {}} - alpha, 
a, b, c, beta = (3.6, - np.array([[8, 7, 1, 200, 55.3], - [99.2, 1.11, 2.02, 17.7, 900.2], - [5.01, 15.21, 22.07, 1.09, 22.22], - [1, 2, 3, 4, 5]], dtype=np.float64), - np.array([[9, 0, 2, 19], - [77, 100, 4, 19], - [1, 500, 9, 19], - [226.68, 11.61, 12.12, 19], - [17.7, 200.10, 301.17, 20]], dtype=np.float64), - np.ones((4, 4)) * 3.3, - 4.3) - for system_size in ["small", "medium", "large"]: - print("-" * 20) - print(f"system size: {system_size}") + for num_threads in [1, 4, 10]: + print("num OpenMP threads:", num_threads) + num_repeats = 50 + results = {"PyKokkos": {}, + "SciPy": {}} + alpha, a, b, c, beta = (3.6, + np.array([[8, 7, 1, 200, 55.3], + [99.2, 1.11, 2.02, 17.7, 900.2], + [5.01, 15.21, 22.07, 1.09, 22.22], + [1, 2, 3, 4, 5]], dtype=np.float64), + np.array([[9, 0, 2, 19], + [77, 100, 4, 19], + [1, 500, 9, 19], + [226.68, 11.61, 12.12, 19], + [17.7, 200.10, 301.17, 20]], dtype=np.float64), + np.ones((4, 4)) * 3.3, + 4.3) + for system_size in ["small", "medium", "large"]: + print("-" * 20) + print(f"system size: {system_size}") - if system_size == "medium": - a_new = np.tile(a, (10, 0)) - b_new = np.tile(b, (0, 10)) - c_new = np.ones((40, 40)) * 3.3 - elif system_size == "large": - a_new = np.tile(a, (40, 0)) - b_new = np.tile(b, (0, 40)) - c_new = np.ones((160, 160)) * 3.3 - else: - a_new = a - b_new = b - c_new = c + if system_size == "medium": + a_new = np.tile(a, (10, 1)) + b_new = np.tile(b, (1, 10)) + c_new = np.ones((40, 40)) * 3.3 + elif system_size == "large": + a_new = np.tile(a, (40, 1)) + b_new = np.tile(b, (1, 40)) + c_new = np.ones((160, 160)) * 3.3 + else: + a_new = a + b_new = b + c_new = c - view_a = pk.from_numpy(a_new) - view_b = pk.from_numpy(b_new) - view_c = pk.from_numpy(c_new) - pk_dgemm_time_sec = timeit.timeit("pk_dgemm(alpha, view_a, view_b, beta, view_c)", - globals=globals(), - number=num_repeats) - results["PyKokkos"][system_size] = pk_dgemm_time_sec - print(f"PyKokkos DGEMM execution time (s) for {num_repeats} repeats: 
{pk_dgemm_time_sec}") - scipy_dgemm_time_sec = timeit.timeit("scipy_dgemm(alpha, a_new, b_new, beta, c_new)", - globals=globals(), - number=num_repeats) - results["SciPy"][system_size] = scipy_dgemm_time_sec - print(f"SciPy DGEMM execution time (s) for {num_repeats} repeats: {scipy_dgemm_time_sec}") - ratio = pk_dgemm_time_sec / scipy_dgemm_time_sec - if ratio == 1: - print("PyKokkos DGEMM timing is identical to SciPy") - elif ratio > 1: - print(f"PyKokkos DGEMM timing is slower than SciPy with ratio: {ratio:.2f} fold") - else: - print(f"PyKokkos DGEMM timing is faster than SciPy with ratio: {ratio:.2f} fold") - print("-" * 20) - df = pd.DataFrame.from_dict(results) - print("df:\n", df) - fig, ax = plt.subplots() - df.plot.bar(ax=ax, - rot=0, - logy=True, - xlabel="Problem Size", - ylabel=f"log of time (s) for {num_repeats} repeats", - title="DGEMM Performance Comparison with timeit") - fig.savefig("DGEMM_perf_compare.png", dpi=300) + view_a = pk.from_numpy(a_new) + view_b = pk.from_numpy(b_new) + view_c = pk.from_numpy(c_new) + with threadpool_limits(limits=num_threads, user_api='openmp'): + pk_dgemm_time_sec = timeit.timeit("pk_dgemm(alpha, view_a, view_b, beta, view_c)", + globals=globals(), + number=num_repeats) + results["PyKokkos"][system_size] = pk_dgemm_time_sec + print(f"PyKokkos DGEMM execution time (s) for {num_repeats} repeats: {pk_dgemm_time_sec}") + scipy_dgemm_time_sec = timeit.timeit("scipy_dgemm(alpha, a_new, b_new, beta, c_new)", + globals=globals(), + number=num_repeats) + results["SciPy"][system_size] = scipy_dgemm_time_sec + print(f"SciPy DGEMM execution time (s) for {num_repeats} repeats: {scipy_dgemm_time_sec}") + ratio = pk_dgemm_time_sec / scipy_dgemm_time_sec + if ratio == 1: + print("PyKokkos DGEMM timing is identical to SciPy") + elif ratio > 1: + print(f"PyKokkos DGEMM timing is slower than SciPy with ratio: {ratio:.2f} fold") + else: + print(f"PyKokkos DGEMM timing is faster than SciPy with ratio: {ratio:.2f} fold") + print("-" * 20) + 
df = pd.DataFrame.from_dict(results) + print("df:\n", df) + fig, ax = plt.subplots() + df.plot.bar(ax=ax, + rot=0, + logy=True, + xlabel="Problem Size", + ylabel=f"log of time (s) for {num_repeats} repeats", + title="DGEMM Performance Comparison with timeit") + fig.savefig(f"DGEMM_perf_compare_{num_threads}_threads.png", dpi=300) diff --git a/mypy.ini b/mypy.ini index 2a2ef3a7..cf45069c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -113,3 +113,6 @@ ignore_errors = True [mypy-pykokkos.lib.ufunc_workunits] ignore_errors = True + +[mypy-pykokkos.linalg.workunits] +ignore_errors = True diff --git a/pykokkos/linalg/l3_blas.py b/pykokkos/linalg/l3_blas.py index bb1a6645..c1dd92c9 100644 --- a/pykokkos/linalg/l3_blas.py +++ b/pykokkos/linalg/l3_blas.py @@ -1,4 +1,5 @@ import pykokkos as pk +from pykokkos.linalg import workunits # Level 3 BLAS functions @@ -45,12 +46,22 @@ def dgemm(alpha: float, C = pk.View([view_a.shape[0], view_b.shape[1]], dtype=pk.double) - for m in range(view_a.shape[0]): - for n in range(view_b.shape[1]): - for k in range(k_a): - subresult = view_a[m, k] * view_b[k, n] * alpha - C[m, n] += float(subresult) # type: ignore - if view_c is not None: - C[m, n] += (view_c[m, n] * beta) # type: ignore - + if view_c is None: + pk.parallel_for(view_a.shape[0], + workunits.dgemm_impl_no_view_c, + k_a=k_a, + alpha=alpha, + view_a=view_a, + view_b=view_b, + out=C) + else: + pk.parallel_for(view_a.shape[0], + workunits.dgemm_impl_view_c, + k_a=k_a, + alpha=alpha, + beta=beta, + view_a=view_a, + view_b=view_b, + view_c=view_c, + out=C) return C diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py new file mode 100644 index 00000000..b8bf58bd --- /dev/null +++ b/pykokkos/linalg/workunits.py @@ -0,0 +1,28 @@ +import pykokkos as pk + + +@pk.workunit +def dgemm_impl_view_c(tid: int, + k_a: int, + alpha: float, + beta: float, + view_a: pk.View2D[pk.double], + view_b: pk.View2D[pk.double], + view_c: pk.View2D[pk.double], + out: pk.View2D[pk.double]): + 
for n in range(view_b.extent(1)): + for k in range(k_a): + out[tid][n] += float(view_a[tid][k] * view_b[k][n] * alpha) + out[tid][n] += (view_c[tid][n] * beta) + + +@pk.workunit +def dgemm_impl_no_view_c(tid: int, + k_a: int, + alpha: float, + view_a: pk.View2D[pk.double], + view_b: pk.View2D[pk.double], + out: pk.View2D[pk.double]): + for n in range(view_b.extent(1)): + for k in range(k_a): + out[tid][n] += float(view_a[tid][k] * view_b[k][n] * alpha) From e49cbf5936a540bb8786fcc924e8f549c1acfedb Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 3 Jan 2023 15:16:45 -0700 Subject: [PATCH 02/21] MAINT: unpin mypy --- .github/workflows/main_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main_ci.yml b/.github/workflows/main_ci.yml index ec2fcb61..9bd9a090 100644 --- a/.github/workflows/main_ci.yml +++ b/.github/workflows/main_ci.yml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf + python -m pip install --upgrade numpy mypy cmake pytest pybind11 scikit-build patchelf - name: Install pykokkos-base run: | cd /tmp From 1cdbb9c57a1755440c023333aede3350d688101b Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Thu, 5 Jan 2023 17:07:06 -0700 Subject: [PATCH 03/21] BENCH: fixup benchmarks * remove `threadpoolctl` stuff and switch to using `OMP_NUM_THREADS` manually + do way more trials and use boxplots to better visualize outliers I might be concerned about --- benchmarks/dgemm_compare.py | 103 ++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 47 deletions(-) diff --git a/benchmarks/dgemm_compare.py b/benchmarks/dgemm_compare.py index e1ddcb8d..0d050c7d 100644 --- a/benchmarks/dgemm_compare.py +++ b/benchmarks/dgemm_compare.py @@ -3,7 +3,8 @@ (i.e., a wheel with OpenBLAS 0.3.18) """ -from threadpoolctl import threadpool_limits +import os + import pykokkos as pk from 
pykokkos.linalg.l3_blas import dgemm as pk_dgemm @@ -13,31 +14,43 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import pandas as pd +from tqdm import tqdm if __name__ == "__main__": import timeit - for num_threads in [1, 4, 10]: - print("num OpenMP threads:", num_threads) - num_repeats = 50 - results = {"PyKokkos": {}, - "SciPy": {}} - alpha, a, b, c, beta = (3.6, - np.array([[8, 7, 1, 200, 55.3], - [99.2, 1.11, 2.02, 17.7, 900.2], - [5.01, 15.21, 22.07, 1.09, 22.22], - [1, 2, 3, 4, 5]], dtype=np.float64), - np.array([[9, 0, 2, 19], - [77, 100, 4, 19], - [1, 500, 9, 19], - [226.68, 11.61, 12.12, 19], - [17.7, 200.10, 301.17, 20]], dtype=np.float64), - np.ones((4, 4)) * 3.3, - 4.3) - for system_size in ["small", "medium", "large"]: - print("-" * 20) - print(f"system size: {system_size}") + num_global_repeats = 50 + num_repeats = 5000 + results = { + "PyKokkos": {"small": [], + "medium": [], + "large": []}, + "SciPy": {"small": [], + "medium": [], + "large": []}, + } + alpha, a, b, c, beta = (3.6, + np.array([[8, 7, 1, 200, 55.3], + [99.2, 1.11, 2.02, 17.7, 900.2], + [5.01, 15.21, 22.07, 1.09, 22.22], + [1, 2, 3, 4, 5]], dtype=np.float64), + np.array([[9, 0, 2, 19], + [77, 100, 4, 19], + [1, 500, 9, 19], + [226.68, 11.61, 12.12, 19], + [17.7, 200.10, 301.17, 20]], dtype=np.float64), + np.ones((4, 4)) * 3.3, + 4.3) + num_threads = os.environ.get("OMP_NUM_THREADS") + df = pd.DataFrame(np.full(shape=(num_global_repeats * 2, 4), fill_value=np.nan), + columns=["backend", "small", "medium", "large"]) + df["backend"] = df["backend"].astype(str) + if num_threads is None: + raise ValueError("must set OMP_NUM_THREADS for benchmarks!") + counter = 0 + for global_repeat in tqdm(range(1, num_global_repeats + 1)): + for col_num, system_size in tqdm(enumerate(["small", "medium", "large"]), total=3): if system_size == "medium": a_new = np.tile(a, (10, 1)) b_new = np.tile(b, (1, 10)) @@ -54,32 +67,28 @@ view_a = pk.from_numpy(a_new) view_b = pk.from_numpy(b_new) view_c = 
pk.from_numpy(c_new) - with threadpool_limits(limits=num_threads, user_api='openmp'): - pk_dgemm_time_sec = timeit.timeit("pk_dgemm(alpha, view_a, view_b, beta, view_c)", - globals=globals(), - number=num_repeats) - results["PyKokkos"][system_size] = pk_dgemm_time_sec - print(f"PyKokkos DGEMM execution time (s) for {num_repeats} repeats: {pk_dgemm_time_sec}") + pk_dgemm_time_sec = timeit.timeit("pk_dgemm(alpha, view_a, view_b, beta, view_c)", + globals=globals(), + number=num_repeats) + results["PyKokkos"][system_size].append(pk_dgemm_time_sec) + df.iloc[counter, 0] = "PyKokkos" + df.iloc[counter, col_num + 1] = pk_dgemm_time_sec scipy_dgemm_time_sec = timeit.timeit("scipy_dgemm(alpha, a_new, b_new, beta, c_new)", globals=globals(), number=num_repeats) - results["SciPy"][system_size] = scipy_dgemm_time_sec - print(f"SciPy DGEMM execution time (s) for {num_repeats} repeats: {scipy_dgemm_time_sec}") - ratio = pk_dgemm_time_sec / scipy_dgemm_time_sec - if ratio == 1: - print("PyKokkos DGEMM timing is identical to SciPy") - elif ratio > 1: - print(f"PyKokkos DGEMM timing is slower than SciPy with ratio: {ratio:.2f} fold") - else: - print(f"PyKokkos DGEMM timing is faster than SciPy with ratio: {ratio:.2f} fold") - print("-" * 20) - df = pd.DataFrame.from_dict(results) - print("df:\n", df) - fig, ax = plt.subplots() - df.plot.bar(ax=ax, - rot=0, - logy=True, - xlabel="Problem Size", - ylabel=f"log of time (s) for {num_repeats} repeats", - title="DGEMM Performance Comparison with timeit") - fig.savefig(f"DGEMM_perf_compare_{num_threads}_threads.png", dpi=300) + results["SciPy"][system_size].append(scipy_dgemm_time_sec) + df.iloc[counter + 1, 0] = "SciPy" + df.iloc[counter + 1, col_num + 1] = scipy_dgemm_time_sec + counter += 2 + + print("df:\n", df) + fig, axes = plt.subplots(nrows=1, ncols=3) + fig.set_size_inches(12, 5) + df.boxplot(ax=axes, + by="backend", + ) + for ax in axes: + ax.set_xlabel("Backend") + axes[0].set_ylabel(f"Time (s) for {num_repeats} DGEMM 
executions") + fig.suptitle(f"DGEMM performance boxplots (OMP_NUM_THREADS={num_threads}; {num_global_repeats} trials) for different problem sizes") + fig.savefig(f"DGEMM_perf_compare_{num_threads}_threads.png", dpi=300) From 0de8ced31f2fb4bc979d55eba126f353b49237be Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Mon, 9 Jan 2023 13:51:26 -0700 Subject: [PATCH 04/21] ENH: PR 146 revisions * add fold ratios directly to plots to facilitate performance comparisons --- benchmarks/dgemm_compare.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/benchmarks/dgemm_compare.py b/benchmarks/dgemm_compare.py index 0d050c7d..f737eeb5 100644 --- a/benchmarks/dgemm_compare.py +++ b/benchmarks/dgemm_compare.py @@ -82,13 +82,30 @@ counter += 2 print("df:\n", df) + ratios = df[df["backend"] == "PyKokkos"].iloc[..., 1:].reset_index(drop=True) / df[df["backend"] == "SciPy"].iloc[..., 1:].reset_index(drop=True) + avg_ratios = ratios.mean(axis=0) + std_ratios = ratios.std(axis=0) + fig, axes = plt.subplots(nrows=1, ncols=3) fig.set_size_inches(12, 5) df.boxplot(ax=axes, by="backend", ) for ax in axes: - ax.set_xlabel("Backend") + problem_size = ax.get_title() + avg_ratio = avg_ratios[problem_size] + std_ratio = std_ratios[problem_size] + if avg_ratio == 1: + color = "gray" + prefix = "(Same Performance)" + elif avg_ratio > 1: + color = "red" + prefix = "(PyKokkos slower by)" + elif avg_ratio < 1: + color = "green" + prefix = "(PyKokkos faster by)" + final_ratio = f"{prefix} {avg_ratio:.1f} $\pm$ {std_ratio:.1f} Fold" + ax.set_xlabel(final_ratio, color=color) axes[0].set_ylabel(f"Time (s) for {num_repeats} DGEMM executions") fig.suptitle(f"DGEMM performance boxplots (OMP_NUM_THREADS={num_threads}; {num_global_repeats} trials) for different problem sizes") fig.savefig(f"DGEMM_perf_compare_{num_threads}_threads.png", dpi=300) From 632372434e863a6de04ee99b422b7dcee9ff3363 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 28 Feb 2023 17:14:42 -0700 
Subject: [PATCH 05/21] WIP: draft tiled matmul--very early stages with team policy --- pykokkos/linalg/l3_blas.py | 45 +++++++++++++++++++++++---------- pykokkos/linalg/workunits.py | 48 ++++++++++++++++++++++++++++++++++++ tests/test_linalg.py | 13 ++++++++++ 3 files changed, 93 insertions(+), 13 deletions(-) diff --git a/pykokkos/linalg/l3_blas.py b/pykokkos/linalg/l3_blas.py index c1dd92c9..553c1a5d 100644 --- a/pykokkos/linalg/l3_blas.py +++ b/pykokkos/linalg/l3_blas.py @@ -7,7 +7,8 @@ def dgemm(alpha: float, view_a, view_b, beta: float = 0.0, - view_c = None): + view_c = None, + tiled=False): """ Double precision floating point genernal matrix multiplication (GEMM). @@ -21,6 +22,10 @@ def dgemm(alpha: float, Shape (k, n) beta: float, optional view_c: pykokkos view of type double, optional + tiled: bool, optional + whether to use tiled matrix multiplication + (currently only supports 2x2 tiles and 4x4 matrices with + no C view) Returns ------- @@ -46,22 +51,36 @@ def dgemm(alpha: float, C = pk.View([view_a.shape[0], view_b.shape[1]], dtype=pk.double) - if view_c is None: - pk.parallel_for(view_a.shape[0], - workunits.dgemm_impl_no_view_c, - k_a=k_a, - alpha=alpha, - view_a=view_a, - view_b=view_b, - out=C) + if not tiled: + if view_c is None: + pk.parallel_for(view_a.shape[0], + workunits.dgemm_impl_no_view_c, + k_a=k_a, + alpha=alpha, + view_a=view_a, + view_b=view_b, + out=C) + else: + pk.parallel_for(view_a.shape[0], + workunits.dgemm_impl_view_c, + k_a=k_a, + alpha=alpha, + beta=beta, + view_a=view_a, + view_b=view_b, + view_c=view_c, + out=C) else: - pk.parallel_for(view_a.shape[0], - workunits.dgemm_impl_view_c, + # 2 x 2 tiled matrix multiplication on 4x4 matrices + # TODO: generalize a bit, but assume rows and columns are + # powers of 2 + pk.parallel_for("tiled_matmul", + pk.TeamPolicy(league_size=4, # four 2 x 2 blocks hard-coded for now + team_size=4), # 2 x 2 tiles (threads) hardcoded for now + workunits.dgemm_impl_tiled_no_view_c, k_a=k_a, 
alpha=alpha, - beta=beta, view_a=view_a, view_b=view_b, - view_c=view_c, out=C) return C diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index b8bf58bd..2dc3b362 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -26,3 +26,51 @@ def dgemm_impl_no_view_c(tid: int, for n in range(view_b.extent(1)): for k in range(k_a): out[tid][n] += float(view_a[tid][k] * view_b[k][n] * alpha) + + +@pk.workunit +def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, + k_a: int, + alpha: float, + view_a: pk.View2D[pk.double], + view_b: pk.View2D[pk.double], + out: pk.View2D[pk.double]): + # early attempt at tiled matrix multiplication in PyKokkos + + # for now, let's assume a 2x2 tiling arrangement and + # that `view_a`, `view_b`, and `out` views are all 4 x 4 matrices + + # start off by getting a global thread id + global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank() + printf("global tid: %d\n", global_tid) + # TODO: should be a simple equation for row/column indices + # in output, right?? not this conditional mess... 
+ # assume data layout is in "C" order in memory + row: int = 0 + column: int = 0 + if team_member.league_rank() < 2 and team_member.team_rank() < 2: + row = 0 + elif team_member.league_rank() < 2 and team_member.team_rank() >= 2: + row = 1 + elif team_member.league_rank() >= 2 and team_member.team_rank() < 2: + row = 2 + else: + row = 3 + if team_member.league_rank() == 0 and team_member.team_rank() < 2: + column = team_member.team_rank() + elif team_member.league_rank() == 2 and team_member.team_rank() < 2: + column = team_member.team_rank() + elif team_member.league_rank() == 1 and team_member.team_rank() < 2: + column = 2 + team_member.team_rank() + elif team_member.league_rank() == 3 and team_member.team_rank() < 2: + column = 2 + team_member.team_rank() + elif team_member.league_rank() == 0 and team_member.team_rank() >= 2: + column = team_member.team_rank() - 2 + elif team_member.league_rank() == 2 and team_member.team_rank() >= 2: + column = team_member.team_rank() - 2 + elif team_member.league_rank() == 1 and team_member.team_rank() >= 2: + column = team_member.team_rank() + elif team_member.league_rank() == 3 and team_member.team_rank() >= 2: + column = team_member.team_rank() + # TODO: assign actual value here + out[row][column] = 5 diff --git a/tests/test_linalg.py b/tests/test_linalg.py index d3fc37aa..329bc425 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -139,3 +139,16 @@ def test_dgemm_input_handling(): dgemm(alpha=alpha, view_a=view_a, view_b=view_b) + + +def test_dgemm_tiled(): + a = np.ones((4, 4)) + b = np.ones((4, 4)) + expected = np.full((4, 4), 5) + actual_c = dgemm(alpha=1.0, + view_a=pk.from_numpy(a), + view_b=pk.from_numpy(b), + beta=0.0, + view_c=None, + tiled=True) + assert_allclose(actual_c, expected) From 20317489677b9e294853a2b43d4896f06ea71ab7 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Thu, 2 Mar 2023 10:03:08 -0700 Subject: [PATCH 06/21] ENH: use scratch for tiled DGEMM * early draft of scratch memory setup 
for the tiled DGEMM workunit * at the moment this doesn't work because of gh-180, so will need to deal with that first --- pykokkos/linalg/workunits.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index 2dc3b362..ec0e6c61 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -39,6 +39,7 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # for now, let's assume a 2x2 tiling arrangement and # that `view_a`, `view_b`, and `out` views are all 4 x 4 matrices + tile_size: int = 4 # start off by getting a global thread id global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank() @@ -74,3 +75,6 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, column = team_member.team_rank() # TODO: assign actual value here out[row][column] = 5 + + # start setting up the scratch (shared) memory for each team + scratch_mem: pk.ScratchView1D[double] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) From ca7bf744e6112029828281ad298f1d9b1464104b Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 7 Mar 2023 17:16:16 -0700 Subject: [PATCH 07/21] WIP, ENH: more kernel growth * created two scratch mem locations per team, and add draft code to fill them up (probably wrong) * draft code to fill the result view with the tiling operations (probably wrong) * add some tests for the tiled kernel vs. 
SciPy `dgemm` (new cases are failing, which makes sense for now) --- pykokkos/linalg/workunits.py | 23 +++++++++++++++--- tests/test_linalg.py | 47 ++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index ec0e6c61..a17a2803 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -39,7 +39,7 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # for now, let's assume a 2x2 tiling arrangement and # that `view_a`, `view_b`, and `out` views are all 4 x 4 matrices - tile_size: int = 4 + tile_size: int = 4 # this is really just the team size... # start off by getting a global thread id global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank() @@ -73,8 +73,23 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, column = team_member.team_rank() elif team_member.league_rank() == 3 and team_member.team_rank() >= 2: column = team_member.team_rank() - # TODO: assign actual value here - out[row][column] = 5 # start setting up the scratch (shared) memory for each team - scratch_mem: pk.ScratchView1D[double] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) + scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) + scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) + tmp: float = 0 + # each thread should load a single element into the local + # shared memory from A and B, which will then be shared with other members + # of the team + scratch_mem_a[team_member.team_rank()] = view_a[row][column] + scratch_mem_b[team_member.team_rank()] = view_b[row][column] + # sync threads to ensure memory is ready for shared + # usage in the team + team_member.team_barrier() + + for k in range(0, 2): + tmp += scratch_mem_a[0] * scratch_mem_b[0] + tmp += scratch_mem_a[1] * scratch_mem_b[2] + + # TODO: assign 
actual value here + out[row][column] = tmp diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 329bc425..4209d9ed 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -141,14 +141,41 @@ def test_dgemm_input_handling(): view_b=view_b) -def test_dgemm_tiled(): - a = np.ones((4, 4)) - b = np.ones((4, 4)) - expected = np.full((4, 4), 5) - actual_c = dgemm(alpha=1.0, - view_a=pk.from_numpy(a), - view_b=pk.from_numpy(b), - beta=0.0, - view_c=None, - tiled=True) +@pytest.mark.parametrize("alpha, a, b, expected", [ + (1.0, + np.ones((4, 4)), + np.ones((4, 4)), + np.full((4, 4), 4), + ), + (1.0, + np.eye(4, 4), + np.array([[0, 6, 5, 0], + [9, 2, 2, 1], + [3, 1, 3, 8], + [4, 9, 4, 2]], dtype=float), + np.array([[0, 6, 5, 0], + [9, 2, 2, 1], + [3, 1, 3, 8], + [4, 9, 4, 2]], dtype=float), + ), + (1.0, + np.ones((4, 4)), + np.array([[0, 6, 5, 0], + [9, 2, 2, 1], + [3, 1, 3, 8], + [4, 9, 4, 2]], dtype=float), + np.array([[16., 18., 14., 11.], + [16., 18., 14., 11.], + [16., 18., 14., 11.], + [16., 18., 14., 11.]], dtype=float) + + ), + ]) +def test_dgemm_tiled(alpha, a, b, expected): + actual_c = dgemm(alpha=alpha, + view_a=pk.from_numpy(a), + view_b=pk.from_numpy(b), + beta=0.0, + view_c=None, + tiled=True) assert_allclose(actual_c, expected) From 2605bbc8cc18f4b650e7e991cade16d8a830931a Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Wed, 8 Mar 2023 17:00:55 -0700 Subject: [PATCH 08/21] MAINT: vastly simplify indexing logic. 
--- pykokkos/linalg/workunits.py | 31 ++++--------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index a17a2803..aa5a8a03 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -43,36 +43,13 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # start off by getting a global thread id global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank() - printf("global tid: %d\n", global_tid) # TODO: should be a simple equation for row/column indices # in output, right?? not this conditional mess... # assume data layout is in "C" order in memory - row: int = 0 - column: int = 0 - if team_member.league_rank() < 2 and team_member.team_rank() < 2: - row = 0 - elif team_member.league_rank() < 2 and team_member.team_rank() >= 2: - row = 1 - elif team_member.league_rank() >= 2 and team_member.team_rank() < 2: - row = 2 - else: - row = 3 - if team_member.league_rank() == 0 and team_member.team_rank() < 2: - column = team_member.team_rank() - elif team_member.league_rank() == 2 and team_member.team_rank() < 2: - column = team_member.team_rank() - elif team_member.league_rank() == 1 and team_member.team_rank() < 2: - column = 2 + team_member.team_rank() - elif team_member.league_rank() == 3 and team_member.team_rank() < 2: - column = 2 + team_member.team_rank() - elif team_member.league_rank() == 0 and team_member.team_rank() >= 2: - column = team_member.team_rank() - 2 - elif team_member.league_rank() == 2 and team_member.team_rank() >= 2: - column = team_member.team_rank() - 2 - elif team_member.league_rank() == 1 and team_member.team_rank() >= 2: - column = team_member.team_rank() - elif team_member.league_rank() == 3 and team_member.team_rank() >= 2: - column = team_member.team_rank() + row: int = global_tid / 4 + column: int = team_member.team_rank() + + #printf("global_tid, row, column, and element from a: %d: (%d, 
%d), %f\n", global_tid, row, column, view_a[row][column]) # start setting up the scratch (shared) memory for each team scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) From 73a18bbc53a17ab7fe128d8adaeeba589f3e74f3 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Thu, 9 Mar 2023 17:33:11 -0700 Subject: [PATCH 09/21] WIP: only 1 tiled test case failing. --- pykokkos/linalg/workunits.py | 119 ++++++++++++++++++++++++++++++++--- 1 file changed, 112 insertions(+), 7 deletions(-) diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index aa5a8a03..a6c5c40a 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -46,8 +46,22 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # TODO: should be a simple equation for row/column indices # in output, right?? not this conditional mess... # assume data layout is in "C" order in memory - row: int = global_tid / 4 - column: int = team_member.team_rank() + row: int = 0 + column: int = 0 + counter: int = 0 + for league_rank in range(4): + for base_row in range(tile_size / 2): + for base_column in range(tile_size / 2): + if global_tid == counter: + if league_rank % 2 != 0: + column = base_column + 2 + else: + column = base_column + if league_rank < 2: + row = base_row + else: + row = base_row + 2 + counter += 1 #printf("global_tid, row, column, and element from a: %d: (%d, %d), %f\n", global_tid, row, column, view_a[row][column]) @@ -58,15 +72,106 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # each thread should load a single element into the local # shared memory from A and B, which will then be shared with other members # of the team - scratch_mem_a[team_member.team_rank()] = view_a[row][column] - scratch_mem_b[team_member.team_rank()] = view_b[row][column] + if team_member.league_rank() == 0 or team_member.league_rank() == 3: + scratch_mem_a[team_member.team_rank()] = view_a[row][column] + 
scratch_mem_b[team_member.team_rank()] = view_b[row][column] + elif team_member.league_rank() == 1: + scratch_mem_a[team_member.team_rank()] = view_a[row][column - 2] + scratch_mem_b[team_member.team_rank()] = view_b[row][column] + elif team_member.league_rank() == 2: + scratch_mem_a[team_member.team_rank()] = view_a[row][column] + scratch_mem_b[team_member.team_rank()] = view_b[row - 2][column] # sync threads to ensure memory is ready for shared # usage in the team team_member.team_barrier() + # the first multiplication in the dot product + # is just the intersection of the row and column vectors + # in a and b: + if global_tid == 8: + printf("tmp checkpoint 1: %f\n", tmp) + printf("value to add to tmp: %f\n", scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()]) + tmp += scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()] + if global_tid == 8: + printf("tmp checkpoint 2: %f\n", tmp) + # the second multiplication in the dot product + # should include the adjacent tile members + new_index_a: int = 0 + new_index_b: int = 0 + if team_member.team_rank() == 0: + new_index_a = 1 + new_index_b = 2 + elif team_member.team_rank() == 1: + new_index_a = 0 + new_index_b = 3 + elif team_member.team_rank() == 2: + new_index_a = 3 + new_index_b = 0 + elif team_member.team_rank() == 3: + new_index_a = 2 + new_index_b = 1 + #if team_member.league_rank() == 3: + #for i in range(4): + #printf("global tid %d, scratch b element %d: %f\n", global_tid, i, scratch_mem_b[i]) - for k in range(0, 2): - tmp += scratch_mem_a[0] * scratch_mem_b[0] - tmp += scratch_mem_a[1] * scratch_mem_b[2] + #printf("global_tid: next A element, next B element in tile: %d: (%f, %f)\n", global_tid, scratch_mem_a[new_index_a], scratch_mem_b[new_index_b]) + tmp += scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b] + if global_tid == 8: + printf("new a value to add in: %f\n", scratch_mem_a[new_index_a]) + printf("new b value to add in: %f\n", 
scratch_mem_b[new_index_b]) + printf("value to add to tmp: %f\n", scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b]) + printf("tmp checkpoint 3: %f\n", tmp) + team_member.team_barrier() + # next, we need to load two more tiles from A and B to complete + # the row/column dot product + row_A: int = 0 + row_B: int = 0 + column_A: int = 0 + column_B: int = 0 + if team_member.league_rank() == 0: + # for the new A (row-wise) tile: + # the row number shouldn't change; + # the columns will iterate + row_A = row + column_A = column + 2 + # the reverse for the B tile + row_B = row + 2 + column_B = column + elif team_member.league_rank() == 1: + row_A = row + column_A = column - 2 + row_B = row + 2 + column_B = column + elif team_member.league_rank() == 2: + row_A = row + column_A = column + 2 + row_B = row + column_B = column + elif team_member.league_rank() == 3: + row_A = row + column_A = column - 2 + row_B = row - 2 + column_B = column + + # TODO: it should be possible to avoid this verbosity + # by looping... 
+ scratch_mem_a[team_member.team_rank()] = view_a[row_A][column_A] + scratch_mem_b[team_member.team_rank()] = view_b[row_B][column_B] + team_member.team_barrier() + tmp += scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()] + if global_tid == 8: + printf("row_A: %d\n", row_A) + printf("column_A: %d\n", column_A) + printf("row_B: %d\n", row_B) + printf("column_B: %d\n", column_B) + printf("new a value to add in: %f\n", scratch_mem_a[team_member.team_rank()]) + printf("new b value to add in: %f\n", scratch_mem_b[team_member.team_rank()]) + printf("value to add to tmp: %f\n", scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()]) + printf("tmp checkpoint 4: %f\n", tmp) + tmp += scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b] + if global_tid == 8: + printf("value to add to tmp: %f\n", scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b]) + printf("tmp checkpoint 5: %f\n", tmp) + team_member.team_barrier() # TODO: assign actual value here out[row][column] = tmp From 521c849aec6cf5d4c25762577b47000a96c34347 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 12 Mar 2023 13:22:19 -0600 Subject: [PATCH 10/21] ENH: add tiled matmul tests passing * all tiled matmul tests passing; simplified algorithm --- pykokkos/linalg/workunits.py | 154 +++++++---------------------------- 1 file changed, 29 insertions(+), 125 deletions(-) diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index a6c5c40a..b7e5b3f4 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -40,138 +40,42 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # for now, let's assume a 2x2 tiling arrangement and # that `view_a`, `view_b`, and `out` views are all 4 x 4 matrices tile_size: int = 4 # this is really just the team size... 
+ width: int = 4 # start off by getting a global thread id global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank() - # TODO: should be a simple equation for row/column indices - # in output, right?? not this conditional mess... - # assume data layout is in "C" order in memory - row: int = 0 - column: int = 0 - counter: int = 0 - for league_rank in range(4): - for base_row in range(tile_size / 2): - for base_column in range(tile_size / 2): - if global_tid == counter: - if league_rank % 2 != 0: - column = base_column + 2 - else: - column = base_column - if league_rank < 2: - row = base_row - else: - row = base_row + 2 - counter += 1 - #printf("global_tid, row, column, and element from a: %d: (%d, %d), %f\n", global_tid, row, column, view_a[row][column]) - - # start setting up the scratch (shared) memory for each team + # TODO: I have no idea how to get 2D scratch memory views? scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) + # in a 4 x 4 matrix with 2 x 2 tiling the leagues + # and teams have matching row/col assignment approaches + bx: int = team_member.league_rank() / 2 + by: int = 0 + if team_member.league_rank() % 2 != 0: + by = 1 + tx: int = team_member.team_rank() / 2 + ty: int = 0 + if team_member.team_rank() % 2 != 0: + ty = 1 tmp: float = 0 - # each thread should load a single element into the local - # shared memory from A and B, which will then be shared with other members - # of the team - if team_member.league_rank() == 0 or team_member.league_rank() == 3: - scratch_mem_a[team_member.team_rank()] = view_a[row][column] - scratch_mem_b[team_member.team_rank()] = view_b[row][column] - elif team_member.league_rank() == 1: - scratch_mem_a[team_member.team_rank()] = view_a[row][column - 2] - scratch_mem_b[team_member.team_rank()] = view_b[row][column] - elif 
team_member.league_rank() == 2: - scratch_mem_a[team_member.team_rank()] = view_a[row][column] - scratch_mem_b[team_member.team_rank()] = view_b[row - 2][column] - # sync threads to ensure memory is ready for shared - # usage in the team - team_member.team_barrier() - # the first multiplication in the dot product - # is just the intersection of the row and column vectors - # in a and b: - if global_tid == 8: - printf("tmp checkpoint 1: %f\n", tmp) - printf("value to add to tmp: %f\n", scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()]) - tmp += scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()] - if global_tid == 8: - printf("tmp checkpoint 2: %f\n", tmp) - # the second multiplication in the dot product - # should include the adjacent tile members - new_index_a: int = 0 - new_index_b: int = 0 - if team_member.team_rank() == 0: - new_index_a = 1 - new_index_b = 2 - elif team_member.team_rank() == 1: - new_index_a = 0 - new_index_b = 3 - elif team_member.team_rank() == 2: - new_index_a = 3 - new_index_b = 0 - elif team_member.team_rank() == 3: - new_index_a = 2 - new_index_b = 1 - #if team_member.league_rank() == 3: - #for i in range(4): - #printf("global tid %d, scratch b element %d: %f\n", global_tid, i, scratch_mem_b[i]) + col: int = by * 2 + ty + row: int = bx * 2 + tx + + # these variables are a bit silly--can we not get + # 2D scratch memory indexing? 
+ a_index: int = 0 + b_index: int = 0 - #printf("global_tid: next A element, next B element in tile: %d: (%f, %f)\n", global_tid, scratch_mem_a[new_index_a], scratch_mem_b[new_index_b]) - tmp += scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b] - if global_tid == 8: - printf("new a value to add in: %f\n", scratch_mem_a[new_index_a]) - printf("new b value to add in: %f\n", scratch_mem_b[new_index_b]) - printf("value to add to tmp: %f\n", scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b]) - printf("tmp checkpoint 3: %f\n", tmp) - team_member.team_barrier() - # next, we need to load two more tiles from A and B to complete - # the row/column dot product - row_A: int = 0 - row_B: int = 0 - column_A: int = 0 - column_B: int = 0 - if team_member.league_rank() == 0: - # for the new A (row-wise) tile: - # the row number shouldn't change; - # the columns will iterate - row_A = row - column_A = column + 2 - # the reverse for the B tile - row_B = row + 2 - column_B = column - elif team_member.league_rank() == 1: - row_A = row - column_A = column - 2 - row_B = row + 2 - column_B = column - elif team_member.league_rank() == 2: - row_A = row - column_A = column + 2 - row_B = row - column_B = column - elif team_member.league_rank() == 3: - row_A = row - column_A = column - 2 - row_B = row - 2 - column_B = column + for i in range(out.extent(1) / 2): + scratch_mem_a[team_member.team_rank()] = view_a[row][i * 2 + ty] + scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col] + team_member.team_barrier() - # TODO: it should be possible to avoid this verbosity - # by looping... 
- scratch_mem_a[team_member.team_rank()] = view_a[row_A][column_A] - scratch_mem_b[team_member.team_rank()] = view_b[row_B][column_B] - team_member.team_barrier() - tmp += scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()] - if global_tid == 8: - printf("row_A: %d\n", row_A) - printf("column_A: %d\n", column_A) - printf("row_B: %d\n", row_B) - printf("column_B: %d\n", column_B) - printf("new a value to add in: %f\n", scratch_mem_a[team_member.team_rank()]) - printf("new b value to add in: %f\n", scratch_mem_b[team_member.team_rank()]) - printf("value to add to tmp: %f\n", scratch_mem_a[team_member.team_rank()] * scratch_mem_b[team_member.team_rank()]) - printf("tmp checkpoint 4: %f\n", tmp) - tmp += scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b] - if global_tid == 8: - printf("value to add to tmp: %f\n", scratch_mem_a[new_index_a] * scratch_mem_b[new_index_b]) - printf("tmp checkpoint 5: %f\n", tmp) - team_member.team_barrier() + for k in range(2): + a_index = tx + (k * 2) + b_index = ty + (k * 2) + tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index] + team_member.team_barrier() - # TODO: assign actual value here - out[row][column] = tmp + out[row][col] = tmp From e40f5c489e099260e0f50f3b4a0b2fd6f4819d1c Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 12 Mar 2023 15:41:53 -0600 Subject: [PATCH 11/21] BUG, TST: more tests/fixes * more tiled DGEMM testing/bug fixing --- pykokkos/linalg/l3_blas.py | 23 ++++++++++++----------- pykokkos/linalg/workunits.py | 2 +- tests/test_linalg.py | 26 +++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/pykokkos/linalg/l3_blas.py b/pykokkos/linalg/l3_blas.py index 553c1a5d..173650e8 100644 --- a/pykokkos/linalg/l3_blas.py +++ b/pykokkos/linalg/l3_blas.py @@ -1,3 +1,5 @@ +from typing import Optional + import pykokkos as pk from pykokkos.linalg import workunits @@ -8,7 +10,7 @@ def dgemm(alpha: float, view_b, beta: float = 0.0, view_c = None, - 
tiled=False): + tile_width: Optional[int] = None): """ Double precision floating point genernal matrix multiplication (GEMM). @@ -22,10 +24,8 @@ def dgemm(alpha: float, Shape (k, n) beta: float, optional view_c: pykokkos view of type double, optional - tiled: bool, optional - whether to use tiled matrix multiplication - (currently only supports 2x2 tiles and 4x4 matrices with - no C view) + tile_width: int, optional + Number of elements along a dimension of the square tiles. Returns ------- @@ -51,7 +51,7 @@ def dgemm(alpha: float, C = pk.View([view_a.shape[0], view_b.shape[1]], dtype=pk.double) - if not tiled: + if not tile_width: if view_c is None: pk.parallel_for(view_a.shape[0], workunits.dgemm_impl_no_view_c, @@ -71,12 +71,13 @@ def dgemm(alpha: float, view_c=view_c, out=C) else: - # 2 x 2 tiled matrix multiplication on 4x4 matrices - # TODO: generalize a bit, but assume rows and columns are - # powers of 2 + # limited tiling support--only (some) convenient powers of two + # allowed for now... + # TODO: league and team size requests outside of these + # values can segfault... 
pk.parallel_for("tiled_matmul", - pk.TeamPolicy(league_size=4, # four 2 x 2 blocks hard-coded for now - team_size=4), # 2 x 2 tiles (threads) hardcoded for now + pk.TeamPolicy(league_size=4, + team_size=tile_width ** 2), workunits.dgemm_impl_tiled_no_view_c, k_a=k_a, alpha=alpha, diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index b7e5b3f4..69120a6a 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -73,7 +73,7 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, team_member.team_barrier() for k in range(2): - a_index = tx + (k * 2) + a_index = k + ((team_member.team_rank() // 2) * 2) b_index = ty + (k * 2) tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index] team_member.team_barrier() diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 4209d9ed..e81b4a45 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -3,6 +3,7 @@ import numpy as np from numpy.testing import assert_allclose +import scipy import pytest @@ -172,10 +173,33 @@ def test_dgemm_input_handling(): ), ]) def test_dgemm_tiled(alpha, a, b, expected): + # expected values hardcoded from SciPy output actual_c = dgemm(alpha=alpha, view_a=pk.from_numpy(a), view_b=pk.from_numpy(b), beta=0.0, view_c=None, - tiled=True) + tile_width=2) + assert_allclose(actual_c, expected) + + +@pytest.mark.parametrize("input_width, tile_width", [ + (4, 2), + ]) +@pytest.mark.parametrize("seed", [ + 100787, 90, 10, + ]) +def test_dgemm_square_tiled_vs_scipy(input_width, tile_width, seed): + rng = np.random.default_rng(seed) + a = rng.integers(low=0, high=10, size=(input_width, input_width)).astype(float) + b = rng.integers(low=0, high=19, size=(input_width, input_width)).astype(float) + expected = scipy.linalg.blas.dgemm(alpha=1.0, + a=a, + b=b) + actual_c = dgemm(alpha=1.0, + view_a=pk.from_numpy(a), + view_b=pk.from_numpy(b), + beta=0.0, + view_c=None, + tile_width=tile_width) assert_allclose(actual_c, expected) From 
64b8d0d8efcf8c8e1b869477780ad0d806c02655 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 12 Mar 2023 15:49:02 -0600 Subject: [PATCH 12/21] ENH: allow varied league_size * allow varied league_size, but currently segfaults when greater than `4` it seems... --- .github/workflows/main_ci.yml | 2 +- pykokkos/linalg/l3_blas.py | 3 ++- tests/test_linalg.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main_ci.yml b/.github/workflows/main_ci.yml index 9bd9a090..ec2fcb61 100644 --- a/.github/workflows/main_ci.yml +++ b/.github/workflows/main_ci.yml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install --upgrade numpy mypy cmake pytest pybind11 scikit-build patchelf + python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf - name: Install pykokkos-base run: | cd /tmp diff --git a/pykokkos/linalg/l3_blas.py b/pykokkos/linalg/l3_blas.py index 173650e8..f2316af0 100644 --- a/pykokkos/linalg/l3_blas.py +++ b/pykokkos/linalg/l3_blas.py @@ -75,8 +75,9 @@ def dgemm(alpha: float, # allowed for now... # TODO: league and team size requests outside of these # values can segfault... 
+ league_size = int(C.size / (tile_width ** 2)) pk.parallel_for("tiled_matmul", - pk.TeamPolicy(league_size=4, + pk.TeamPolicy(league_size=league_size, team_size=tile_width ** 2), workunits.dgemm_impl_tiled_no_view_c, k_a=k_a, diff --git a/tests/test_linalg.py b/tests/test_linalg.py index e81b4a45..6c1ac9c1 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -185,6 +185,7 @@ def test_dgemm_tiled(alpha, a, b, expected): @pytest.mark.parametrize("input_width, tile_width", [ (4, 2), + #(8, 2), ]) @pytest.mark.parametrize("seed", [ 100787, 90, 10, From 8087560d87ca927612c57da369cc5074147521c9 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 12 Mar 2023 16:11:56 -0600 Subject: [PATCH 13/21] Add SciPy --- .github/workflows/main_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main_ci.yml b/.github/workflows/main_ci.yml index ec2fcb61..5b30969b 100644 --- a/.github/workflows/main_ci.yml +++ b/.github/workflows/main_ci.yml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf + python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf scipy - name: Install pykokkos-base run: | cd /tmp From 92a4a258be41a99c97f5f3079e841247e1132ae1 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sat, 18 Mar 2023 19:21:13 -0600 Subject: [PATCH 14/21] Debug prints --- pykokkos/linalg/l3_blas.py | 1 + pykokkos/linalg/workunits.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/pykokkos/linalg/l3_blas.py b/pykokkos/linalg/l3_blas.py index f2316af0..842f066e 100644 --- a/pykokkos/linalg/l3_blas.py +++ b/pykokkos/linalg/l3_blas.py @@ -76,6 +76,7 @@ def dgemm(alpha: float, # TODO: league and team size requests outside of these # values can segfault... 
league_size = int(C.size / (tile_width ** 2)) + print("league_size:", league_size) pk.parallel_for("tiled_matmul", pk.TeamPolicy(league_size=league_size, team_size=tile_width ** 2), diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index 69120a6a..e2cc2898 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -35,6 +35,7 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, view_a: pk.View2D[pk.double], view_b: pk.View2D[pk.double], out: pk.View2D[pk.double]): + printf("tiled workunit checkpoint 1") # early attempt at tiled matrix multiplication in PyKokkos # for now, let's assume a 2x2 tiling arrangement and @@ -44,10 +45,12 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # start off by getting a global thread id global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank() + printf("tiled workunit checkpoint 2 for thread id: %d\n", global_tid) # TODO: I have no idea how to get 2D scratch memory views? scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) + printf("tiled workunit checkpoint 3 for thread id: %d\n", global_tid) # in a 4 x 4 matrix with 2 x 2 tiling the leagues # and teams have matching row/col assignment approaches bx: int = team_member.league_rank() / 2 @@ -61,6 +64,7 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, tmp: float = 0 col: int = by * 2 + ty row: int = bx * 2 + tx + printf("tiled workunit checkpoint 4 for thread id: %d\n", global_tid) # these variables are a bit silly--can we not get # 2D scratch memory indexing? 
@@ -70,12 +74,16 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, for i in range(out.extent(1) / 2): scratch_mem_a[team_member.team_rank()] = view_a[row][i * 2 + ty] scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col] + printf("tiled workunit checkpoint 5 for thread id: %d\n", global_tid) team_member.team_barrier() + printf("tiled workunit checkpoint 6 for thread id: %d\n", global_tid) for k in range(2): a_index = k + ((team_member.team_rank() // 2) * 2) b_index = ty + (k * 2) tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index] team_member.team_barrier() + printf("tiled workunit checkpoint 7 for thread id: %d\n", global_tid) + printf("tiled workunit checkpoint 8 for thread id: %d\n", global_tid) out[row][col] = tmp From e3166ebcdd6a816a1f82e4cc7cf6978dd8101ffe Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sat, 18 Mar 2023 20:05:04 -0600 Subject: [PATCH 15/21] Try more threads --- .github/workflows/main_ci.yml | 1 + pykokkos/linalg/workunits.py | 8 -------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/.github/workflows/main_ci.yml b/.github/workflows/main_ci.yml index 5b30969b..55c9ad12 100644 --- a/.github/workflows/main_ci.yml +++ b/.github/workflows/main_ci.yml @@ -39,4 +39,5 @@ jobs: mypy pykokkos - name: run tests run: | + export OMP_NUM_THREADS=4 python runtests.py diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index e2cc2898..69120a6a 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -35,7 +35,6 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, view_a: pk.View2D[pk.double], view_b: pk.View2D[pk.double], out: pk.View2D[pk.double]): - printf("tiled workunit checkpoint 1") # early attempt at tiled matrix multiplication in PyKokkos # for now, let's assume a 2x2 tiling arrangement and @@ -45,12 +44,10 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # start off by getting a global thread id global_tid: int = team_member.league_rank() * 
team_member.team_size() + team_member.team_rank() - printf("tiled workunit checkpoint 2 for thread id: %d\n", global_tid) # TODO: I have no idea how to get 2D scratch memory views? scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) - printf("tiled workunit checkpoint 3 for thread id: %d\n", global_tid) # in a 4 x 4 matrix with 2 x 2 tiling the leagues # and teams have matching row/col assignment approaches bx: int = team_member.league_rank() / 2 @@ -64,7 +61,6 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, tmp: float = 0 col: int = by * 2 + ty row: int = bx * 2 + tx - printf("tiled workunit checkpoint 4 for thread id: %d\n", global_tid) # these variables are a bit silly--can we not get # 2D scratch memory indexing? @@ -74,16 +70,12 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, for i in range(out.extent(1) / 2): scratch_mem_a[team_member.team_rank()] = view_a[row][i * 2 + ty] scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col] - printf("tiled workunit checkpoint 5 for thread id: %d\n", global_tid) team_member.team_barrier() - printf("tiled workunit checkpoint 6 for thread id: %d\n", global_tid) for k in range(2): a_index = k + ((team_member.team_rank() // 2) * 2) b_index = ty + (k * 2) tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index] team_member.team_barrier() - printf("tiled workunit checkpoint 7 for thread id: %d\n", global_tid) - printf("tiled workunit checkpoint 8 for thread id: %d\n", global_tid) out[row][col] = tmp From cc7d976c4a2c55463c5cad3a0118fbf1ddae2fd5 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sat, 18 Mar 2023 23:10:19 -0600 Subject: [PATCH 16/21] ENH: PR 146 revisions * `dgemm()` now accepts a `league_size` argument, in case that might be useful for GPU where more blocks of threads may be allowed? 
We no longer calculate `league_size` automatically because this can cause segfaults/issues... (wrt actually available resources I think...) * the tiled DGEMM kernel now passes tests with several input widths that are different powers of 2 --- pykokkos/linalg/l3_blas.py | 5 +---- pykokkos/linalg/workunits.py | 30 ++++++++++++++++-------------- tests/test_linalg.py | 6 ++++-- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/pykokkos/linalg/l3_blas.py b/pykokkos/linalg/l3_blas.py index 842f066e..fbe6af88 100644 --- a/pykokkos/linalg/l3_blas.py +++ b/pykokkos/linalg/l3_blas.py @@ -10,6 +10,7 @@ def dgemm(alpha: float, view_b, beta: float = 0.0, view_c = None, + league_size: int = 4, tile_width: Optional[int] = None): """ Double precision floating point genernal matrix multiplication (GEMM). @@ -73,10 +74,6 @@ def dgemm(alpha: float, else: # limited tiling support--only (some) convenient powers of two # allowed for now... - # TODO: league and team size requests outside of these - # values can segfault... - league_size = int(C.size / (tile_width ** 2)) - print("league_size:", league_size) pk.parallel_for("tiled_matmul", pk.TeamPolicy(league_size=league_size, team_size=tile_width ** 2), diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index 69120a6a..86d0de95 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -39,15 +39,14 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, # for now, let's assume a 2x2 tiling arrangement and # that `view_a`, `view_b`, and `out` views are all 4 x 4 matrices - tile_size: int = 4 # this is really just the team size... - width: int = 4 + width: int = out.extent(1) # start off by getting a global thread id global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank() # TODO: I have no idea how to get 2D scratch memory views? 
- scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) - scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), tile_size) + scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), team_member.team_size()) + scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), team_member.team_size()) # in a 4 x 4 matrix with 2 x 2 tiling the leagues # and teams have matching row/col assignment approaches bx: int = team_member.league_rank() / 2 @@ -67,15 +66,18 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, a_index: int = 0 b_index: int = 0 - for i in range(out.extent(1) / 2): - scratch_mem_a[team_member.team_rank()] = view_a[row][i * 2 + ty] - scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col] - team_member.team_barrier() + for row_factor in range(0, width, team_member.team_size()): + for col_factor in range(0, width, team_member.team_size()): + tmp = 0 + for i in range(width / 2): + scratch_mem_a[team_member.team_rank()] = view_a[row + row_factor][i * 2 + ty] + scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col + col_factor] + team_member.team_barrier() - for k in range(2): - a_index = k + ((team_member.team_rank() // 2) * 2) - b_index = ty + (k * 2) - tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index] - team_member.team_barrier() + for k in range(2): + a_index = k + ((team_member.team_rank() // 2) * 2) + b_index = ty + (k * 2) + tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index] + team_member.team_barrier() - out[row][col] = tmp + out[row + row_factor][col + col_factor] = tmp diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 6c1ac9c1..0a767be4 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -184,8 +184,10 @@ def test_dgemm_tiled(alpha, a, b, expected): @pytest.mark.parametrize("input_width, tile_width", [ - (4, 2), - #(8, 2), + (2 ** 2, 2), + (2 ** 3, 2), + (2 
** 5, 2), + (2 ** 7, 2), ]) @pytest.mark.parametrize("seed", [ 100787, 90, 10, From 87b8f00ecdf11b7b0f1274767df2fdcf38ea3f8c Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 19 Mar 2023 15:01:42 -0600 Subject: [PATCH 17/21] BENCH: simplify/improve DGEMM benchmarking code. --- benchmarks/dgemm_compare.py | 125 ++++++++++++------------------------ 1 file changed, 41 insertions(+), 84 deletions(-) diff --git a/benchmarks/dgemm_compare.py b/benchmarks/dgemm_compare.py index f737eeb5..12cc173b 100644 --- a/benchmarks/dgemm_compare.py +++ b/benchmarks/dgemm_compare.py @@ -1,14 +1,15 @@ """ -Compare DGEMM performance with SciPy -(i.e., a wheel with OpenBLAS 0.3.18) +Record DGEMM performance. """ import os +import time import pykokkos as pk from pykokkos.linalg.l3_blas import dgemm as pk_dgemm import numpy as np +from numpy.testing import assert_allclose from scipy.linalg.blas import dgemm as scipy_dgemm import matplotlib matplotlib.use("Agg") @@ -18,94 +19,50 @@ if __name__ == "__main__": - import timeit - num_global_repeats = 50 - num_repeats = 5000 - results = { - "PyKokkos": {"small": [], - "medium": [], - "large": []}, - "SciPy": {"small": [], - "medium": [], - "large": []}, - } - alpha, a, b, c, beta = (3.6, - np.array([[8, 7, 1, 200, 55.3], - [99.2, 1.11, 2.02, 17.7, 900.2], - [5.01, 15.21, 22.07, 1.09, 22.22], - [1, 2, 3, 4, 5]], dtype=np.float64), - np.array([[9, 0, 2, 19], - [77, 100, 4, 19], - [1, 500, 9, 19], - [226.68, 11.61, 12.12, 19], - [17.7, 200.10, 301.17, 20]], dtype=np.float64), - np.ones((4, 4)) * 3.3, - 4.3) + scenario_name = "pk_gp160_dgemm_NO_tiling_CPU_OpenMP" + space = pk.ExecutionSpace.OpenMP + pk.set_default_space(space) + + num_global_repeats = 5 + square_matrix_width = 2 ** 9 + + rng = np.random.default_rng(18898787) + alpha = 1.0 + a = rng.random((square_matrix_width, square_matrix_width)).astype(float) + b = rng.random((square_matrix_width, square_matrix_width)).astype(float) + view_a = pk.from_numpy(a) + view_b = pk.from_numpy(b) + 
#cuda_a = cp.array(a) + #cuda_b = cp.array(b) + num_threads = os.environ.get("OMP_NUM_THREADS") - df = pd.DataFrame(np.full(shape=(num_global_repeats * 2, 4), fill_value=np.nan), - columns=["backend", "small", "medium", "large"]) - df["backend"] = df["backend"].astype(str) if num_threads is None: raise ValueError("must set OMP_NUM_THREADS for benchmarks!") + df = pd.DataFrame(np.full(shape=(num_global_repeats, 2), fill_value=np.nan), + columns=["scenario", "time (s)"]) + df["scenario"] = df["scenario"].astype(str) + print("df before trials:\n", df) + + expected = scipy_dgemm(alpha, a, b) counter = 0 for global_repeat in tqdm(range(1, num_global_repeats + 1)): - for col_num, system_size in tqdm(enumerate(["small", "medium", "large"]), total=3): - if system_size == "medium": - a_new = np.tile(a, (10, 1)) - b_new = np.tile(b, (1, 10)) - c_new = np.ones((40, 40)) * 3.3 - elif system_size == "large": - a_new = np.tile(a, (40, 1)) - b_new = np.tile(b, (1, 40)) - c_new = np.ones((160, 160)) * 3.3 - else: - a_new = a - b_new = b - c_new = c + start = time.perf_counter() + #actual = pk_dgemm(alpha, view_a, view_b, beta=0.0, view_c=None) + #actual = pk_dgemm(alpha, view_a, view_b, beta=0.0, view_c=None, league_size=4, tile_width=2) + actual = scipy_dgemm(alpha, a, b) + end = time.perf_counter() + assert_allclose(actual, expected) + + dgemm_time_sec = end - start + df.iloc[counter, 0] = f"{scenario_name}" + df.iloc[counter, 1] = dgemm_time_sec + counter += 1 - view_a = pk.from_numpy(a_new) - view_b = pk.from_numpy(b_new) - view_c = pk.from_numpy(c_new) - pk_dgemm_time_sec = timeit.timeit("pk_dgemm(alpha, view_a, view_b, beta, view_c)", - globals=globals(), - number=num_repeats) - results["PyKokkos"][system_size].append(pk_dgemm_time_sec) - df.iloc[counter, 0] = "PyKokkos" - df.iloc[counter, col_num + 1] = pk_dgemm_time_sec - scipy_dgemm_time_sec = timeit.timeit("scipy_dgemm(alpha, a_new, b_new, beta, c_new)", - globals=globals(), - number=num_repeats) - 
results["SciPy"][system_size].append(scipy_dgemm_time_sec) - df.iloc[counter + 1, 0] = "SciPy" - df.iloc[counter + 1, col_num + 1] = scipy_dgemm_time_sec - counter += 2 - print("df:\n", df) - ratios = df[df["backend"] == "PyKokkos"].iloc[..., 1:].reset_index(drop=True) / df[df["backend"] == "SciPy"].iloc[..., 1:].reset_index(drop=True) - avg_ratios = ratios.mean(axis=0) - std_ratios = ratios.std(axis=0) + print("df after trials:\n", df) - fig, axes = plt.subplots(nrows=1, ncols=3) - fig.set_size_inches(12, 5) - df.boxplot(ax=axes, - by="backend", - ) - for ax in axes: - problem_size = ax.get_title() - avg_ratio = avg_ratios[problem_size] - std_ratio = std_ratios[problem_size] - if avg_ratio == 1: - color = "gray" - prefix = "(Same Performance)" - elif avg_ratio > 1: - color = "red" - prefix = "(PyKokkos slower by)" - elif avg_ratio < 1: - color = "green" - prefix = "(PyKokkos faster by)" - final_ratio = f"{prefix} {avg_ratio:.1f} $\pm$ {std_ratio:.1f} Fold" - ax.set_xlabel(final_ratio, color=color) - axes[0].set_ylabel(f"Time (s) for {num_repeats} DGEMM executions") - fig.suptitle(f"DGEMM performance boxplots (OMP_NUM_THREADS={num_threads}; {num_global_repeats} trials) for different problem sizes") - fig.savefig(f"DGEMM_perf_compare_{num_threads}_threads.png", dpi=300) + filename = f"{scenario_name}_square_matrix_width_{square_matrix_width}_{num_global_repeats}_trials.parquet.gzip" + df.to_parquet(filename, + engine="pyarrow", + compression="gzip") From 6c71f6d0a6bf33af9f8f71c543c168c8175bc10d Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 19 Mar 2023 15:25:53 -0600 Subject: [PATCH 18/21] ENH: support different league sizes * add limited league size variation support--size of 1 and some convenient multiples of 4 may work; tests for 1 and 4 are passing locally --- pykokkos/linalg/l3_blas.py | 11 ++++++++++- pykokkos/linalg/workunits.py | 14 +++++++++++--- tests/test_linalg.py | 6 +++++- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git 
a/pykokkos/linalg/l3_blas.py b/pykokkos/linalg/l3_blas.py index fbe6af88..91a0edb9 100644 --- a/pykokkos/linalg/l3_blas.py +++ b/pykokkos/linalg/l3_blas.py @@ -10,6 +10,8 @@ def dgemm(alpha: float, view_b, beta: float = 0.0, view_c = None, + # TODO: league_size support is pretty limited/confusing + # at the moment... league_size: int = 4, tile_width: Optional[int] = None): """ @@ -74,6 +76,12 @@ def dgemm(alpha: float, else: # limited tiling support--only (some) convenient powers of two # allowed for now... + # limited league size support for now as well... + if league_size == 1: + slide_factor = 0 + else: + slide_factor = int(league_size / 4) + pk.parallel_for("tiled_matmul", pk.TeamPolicy(league_size=league_size, team_size=tile_width ** 2), @@ -82,5 +90,6 @@ def dgemm(alpha: float, alpha=alpha, view_a=view_a, view_b=view_b, - out=C) + out=C, + slide_factor=slide_factor) return C diff --git a/pykokkos/linalg/workunits.py b/pykokkos/linalg/workunits.py index 86d0de95..336d1ec6 100644 --- a/pykokkos/linalg/workunits.py +++ b/pykokkos/linalg/workunits.py @@ -34,7 +34,8 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, alpha: float, view_a: pk.View2D[pk.double], view_b: pk.View2D[pk.double], - out: pk.View2D[pk.double]): + out: pk.View2D[pk.double], + slide_factor: int): # early attempt at tiled matrix multiplication in PyKokkos # for now, let's assume a 2x2 tiling arrangement and @@ -66,8 +67,15 @@ def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember, a_index: int = 0 b_index: int = 0 - for row_factor in range(0, width, team_member.team_size()): - for col_factor in range(0, width, team_member.team_size()): + # TODO: league size support is limited for now, probably + # only some convenient factors of the total matrix size + slide_size: int = 0 + if slide_factor == 0: + slide_size = 2 + else: + slide_size = 4 * slide_factor + for row_factor in range(0, width, slide_size): + for col_factor in range(0, width, slide_size): tmp = 0 for i in range(width / 
2): scratch_mem_a[team_member.team_rank()] = view_a[row + row_factor][i * 2 + ty] diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 0a767be4..4367d0c1 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -192,7 +192,10 @@ def test_dgemm_tiled(alpha, a, b, expected): @pytest.mark.parametrize("seed", [ 100787, 90, 10, ]) -def test_dgemm_square_tiled_vs_scipy(input_width, tile_width, seed): +@pytest.mark.parametrize("league_size", [ + 1, 4, + ]) +def test_dgemm_square_tiled_vs_scipy(input_width, tile_width, seed, league_size): rng = np.random.default_rng(seed) a = rng.integers(low=0, high=10, size=(input_width, input_width)).astype(float) b = rng.integers(low=0, high=19, size=(input_width, input_width)).astype(float) @@ -204,5 +207,6 @@ def test_dgemm_square_tiled_vs_scipy(input_width, tile_width, seed): view_b=pk.from_numpy(b), beta=0.0, view_c=None, + league_size=league_size, tile_width=tile_width) assert_allclose(actual_c, expected) From 40a654d4e4a5f61bdfb195975a523c9552f8ddb7 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 19 Mar 2023 15:38:06 -0600 Subject: [PATCH 19/21] BENCH: auto-rm pk_cpp folder.. 
--- benchmarks/dgemm_compare.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/benchmarks/dgemm_compare.py b/benchmarks/dgemm_compare.py index 12cc173b..dd25c0d7 100644 --- a/benchmarks/dgemm_compare.py +++ b/benchmarks/dgemm_compare.py @@ -3,6 +3,7 @@ """ import os +import shutil import time import pykokkos as pk @@ -39,6 +40,10 @@ if num_threads is None: raise ValueError("must set OMP_NUM_THREADS for benchmarks!") + cwd = os.getcwd() + shutil.rmtree(os.path.join(cwd, "pk_cpp"), + ignore_errors=True) + df = pd.DataFrame(np.full(shape=(num_global_repeats, 2), fill_value=np.nan), columns=["scenario", "time (s)"]) df["scenario"] = df["scenario"].astype(str) From 30fca1f42524fd8aa68f046ec3a5af4379714739 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sat, 25 Mar 2023 15:39:24 -0600 Subject: [PATCH 20/21] DGEMM benchmark improvements --- benchmarks/dgemm_compare.py | 85 +++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 22 deletions(-) diff --git a/benchmarks/dgemm_compare.py b/benchmarks/dgemm_compare.py index dd25c0d7..faeb1d5d 100644 --- a/benchmarks/dgemm_compare.py +++ b/benchmarks/dgemm_compare.py @@ -5,6 +5,8 @@ import os import shutil import time +import argparse +import socket import pykokkos as pk from pykokkos.linalg.l3_blas import dgemm as pk_dgemm @@ -19,27 +21,72 @@ from tqdm import tqdm +def setup_data(mode): + rng = np.random.default_rng(18898787) + a = rng.random((square_matrix_width, square_matrix_width)).astype(float) + b = rng.random((square_matrix_width, square_matrix_width)).astype(float) + if "pykokkos" in mode: + view_a = pk.View([square_matrix_width, square_matrix_width], dtype=pk.float64) + view_b = pk.View([square_matrix_width, square_matrix_width], dtype=pk.float64) + view_a[:] = a + view_b[:] = b + return view_a, view_b + else: + return a, b + + +def time_dgemm(expected, mode, league_size=4, tile_width=2): + start = time.perf_counter() + if mode == "pykokkos_no_tiling": + actual = pk_dgemm(alpha, a, b, beta=0.0, 
view_c=None) + elif mode == "pykokkos_with_tiling": + actual = pk_dgemm(alpha, a, b, beta=0.0, view_c=None, league_size=4, tile_width=2) + elif mode == "scipy": + actual = scipy_dgemm(alpha, a, b) + else: + raise ValueError(f"Unknown timing mode: {mode}") + # include check for correctness inside the + # timer code block to prevent i.e., async GPU + # execution; just be careful to select matrix sizes + # large enough that the assertion isn't slower than the + # DGEMM + assert_allclose(actual, expected) + end = time.perf_counter() + dgemm_time_sec = end - start + return dgemm_time_sec + + if __name__ == "__main__": - scenario_name = "pk_gp160_dgemm_NO_tiling_CPU_OpenMP" - space = pk.ExecutionSpace.OpenMP + parser = argparse.ArgumentParser() + parser.add_argument('-n', '--num-global-repeats', default=5) + parser.add_argument('-m', '--mode', default="scipy") + parser.add_argument('-p', '--power-of-two', default=10) + parser.add_argument('-w', '--tile-width', default=2) + parser.add_argument('-l', '--league-size', default=4) + parser.add_argument('-s', '--space', default="OpenMP") + args = parser.parse_args() + hostname = socket.gethostname() + + if args.space == "OpenMP": + space = pk.ExecutionSpace.OpenMP + elif args.space == "Cuda": + space = pk.ExecutionSpace.Cuda + else: + raise ValueError(f"Invalid execution space specified: {args.space}") pk.set_default_space(space) - num_global_repeats = 5 - square_matrix_width = 2 ** 9 - rng = np.random.default_rng(18898787) - alpha = 1.0 - a = rng.random((square_matrix_width, square_matrix_width)).astype(float) - b = rng.random((square_matrix_width, square_matrix_width)).astype(float) - view_a = pk.from_numpy(a) - view_b = pk.from_numpy(b) - #cuda_a = cp.array(a) - #cuda_b = cp.array(b) + num_global_repeats = int(args.num_global_repeats) + square_matrix_width = 2 ** int(args.power_of_two) + num_threads = os.environ.get("OMP_NUM_THREADS") if num_threads is None: raise ValueError("must set OMP_NUM_THREADS for benchmarks!") + 
space_name = str(space).split(".")[1] + scenario_name = f"{hostname}_dgemm_{args.mode}_{num_threads}_OMP_threads_{space_name}_execution_space_{square_matrix_width}_square_matrix_width_{args.league_size}_league_size" + cwd = os.getcwd() shutil.rmtree(os.path.join(cwd, "pk_cpp"), ignore_errors=True) @@ -49,25 +96,19 @@ df["scenario"] = df["scenario"].astype(str) print("df before trials:\n", df) + alpha = 1.0 + a, b = setup_data(mode=args.mode) expected = scipy_dgemm(alpha, a, b) counter = 0 for global_repeat in tqdm(range(1, num_global_repeats + 1)): - start = time.perf_counter() - #actual = pk_dgemm(alpha, view_a, view_b, beta=0.0, view_c=None) - #actual = pk_dgemm(alpha, view_a, view_b, beta=0.0, view_c=None, league_size=4, tile_width=2) - actual = scipy_dgemm(alpha, a, b) - end = time.perf_counter() - assert_allclose(actual, expected) - - dgemm_time_sec = end - start + dgemm_time_sec = time_dgemm(expected, mode=args.mode, league_size=args.league_size, tile_width=args.tile_width) df.iloc[counter, 0] = f"{scenario_name}" df.iloc[counter, 1] = dgemm_time_sec counter += 1 - print("df after trials:\n", df) - filename = f"{scenario_name}_square_matrix_width_{square_matrix_width}_{num_global_repeats}_trials.parquet.gzip" + filename = f"{scenario_name}.parquet.gzip" df.to_parquet(filename, engine="pyarrow", compression="gzip") From 876cc999716846931256f7ea590513e70605c253 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Sun, 26 Mar 2023 15:30:13 -0600 Subject: [PATCH 21/21] matplotlib wasn't being used in DGEMM bench. 
--- benchmarks/dgemm_compare.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/benchmarks/dgemm_compare.py b/benchmarks/dgemm_compare.py index faeb1d5d..5190cfba 100644 --- a/benchmarks/dgemm_compare.py +++ b/benchmarks/dgemm_compare.py @@ -14,9 +14,6 @@ import numpy as np from numpy.testing import assert_allclose from scipy.linalg.blas import dgemm as scipy_dgemm -import matplotlib -matplotlib.use("Agg") -import matplotlib.pyplot as plt import pandas as pd from tqdm import tqdm