From f2a5aa05b510f3910e15343af1560c0f938b94a2 Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <9387186+and-ivanov@users.noreply.github.com> Date: Fri, 12 May 2023 12:08:37 +0000 Subject: [PATCH] n:m:g format implementation --- pyproject.toml | 1 + src/sten/__init__.py | 3 + src/sten/grouped_nm/__init__.py | 0 src/sten/grouped_nm/dace_gnm_mult.py | 1207 ++++++++++++++++ src/sten/grouped_nm/grouped_nm_tensor.py | 1091 ++++++++++++++ src/sten/grouped_nm/matmul_generator.py | 1659 ++++++++++++++++++++++ src/sten/grouped_nm/sten_impls.py | 174 +++ tests/test_nmg.py | 91 ++ 8 files changed, 4226 insertions(+) create mode 100644 src/sten/grouped_nm/__init__.py create mode 100644 src/sten/grouped_nm/dace_gnm_mult.py create mode 100644 src/sten/grouped_nm/grouped_nm_tensor.py create mode 100644 src/sten/grouped_nm/matmul_generator.py create mode 100644 src/sten/grouped_nm/sten_impls.py create mode 100644 tests/test_nmg.py diff --git a/pyproject.toml b/pyproject.toml index c869b56..3bfe5ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "coverage", "pytest-cov", "pytest-xdist", + "py-cpuinfo", ] [project.urls] diff --git a/src/sten/__init__.py b/src/sten/__init__.py index bb4ba4d..28f526b 100644 --- a/src/sten/__init__.py +++ b/src/sten/__init__.py @@ -1,3 +1,6 @@ from .sten import * from .patches import patch + +from .grouped_nm.sten_impls import GroupedNMSparsifier +from .grouped_nm.grouped_nm_tensor import GroupedNMTensor, PerfectNMTensor diff --git a/src/sten/grouped_nm/__init__.py b/src/sten/grouped_nm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/sten/grouped_nm/dace_gnm_mult.py b/src/sten/grouped_nm/dace_gnm_mult.py new file mode 100644 index 0000000..f57ec88 --- /dev/null +++ b/src/sten/grouped_nm/dace_gnm_mult.py @@ -0,0 +1,1207 @@ +from dataclasses import dataclass +import torch +from math import factorial as fact +import math +import itertools +from native_scripting import compile +import ctypes +import numpy as np +import copy +import functools +import time +from sympy import factorint +from typing import Tuple, List, Union +import argparse +from .grouped_nm_tensor import GroupedNMTensor, is_correct_nm, make_n_m, int_ceil +import dace +from dace import nodes +import jinja2 +from textwrap import dedent +from .matmul_generator import make_load_mask + + +def generate_sparse_dense_microkernel_avx2( + ptr_a_val, + ptr_a_idx, + ptr_b, + ptr_c, + M, + K, + Nblocks, + SAMS, + SAKS, + SAMI, + SAKI, + SBK, + SCM, + SCB, + SCV, + SBB, + SBV, + n, + m, + g, + num_vectors_b, + v, +): + DK2 = fact(m) // fact(n) // fact(m - n) + DK3 = g + DM3 = m + DM3S = n + DN4 = num_vectors_b # accumulator tile size + DN5 = v # vector size + + bti = make_n_m(n, m) + + DM3S_padded = int_ceil(M, DM3) * DM3S + DK23_padded = int_ceil(K, DK2 * DK3) * DK2 * DK3 + # DN5R = N % DN5 + DN5R = 0 + + assert M <= DM3 + assert K <= DK2 * DK3 + # assert N <= DN4 * DN5 + assert SCV == 1 + assert SBV == 1 + + template = jinja2.Template( + dedent( + """\ + // SAMS, SAKS, SAMI, SAKI, SBK, SCM, SCB, SCV, SBB, SBV = {{SAMS, SAKS, SAMI, SAKI, SBK, SCM, SCB, SCV, SBB, SBV}} + // init accumulators + {% for dm3s in range(DM3S_padded) %} + {% for dn4 in range(DN4) %} + __m256 acc_{{ dm3s }}_{{ dn4 }} = _mm256_setzero_ps(); + {% endfor %} + {% endfor %} + {% if DN5R != 0 %}__m256i load_mask = _mm256_setr_epi32({{ make_load_mask(DN5R, DN5) }});{% endif %} + {% for dk2 in range(DK2) %}{ + {% set si = bti[dk2] %} + {% set in_remainder_DK2 = (DK23_padded != K) %} + // in_remainder_DK2 = 
{{ in_remainder_DK2 }} dk2 {{ dk2 }} DK2 {{ DK2 }} DK23_padded {{ DK23_padded }} K {{ K }} + FOR(dk3, {{DK3}}) { + int dk23 = {{ dk2 }} * {{ DK3 }} + dk3; + int16_t dk23_in_B = {{ ptr_a_idx }}[dk23 * {{ SAKI }}]; + //printf("A LOAD [offset from A = %d, dk23 = %d, dk3=%d, DK3=%d, SAKI = %d]\\n", &({{ ptr_a_idx }}[dk23 * {{ SAKI }}]) - A_idx, dk23, dk3, {{DK3}} , {{ SAKI }}); + //printf("dk23 in B = %d\\n", (int)dk23_in_B); + {% if in_remainder_DK2 %} + // in dk2 remainder + if ((dk23_in_B < 0) || (dk23_in_B >= {{K}})) break; + {% endif %} + // SBK = {{SBK}} + float* B2 = &{{ ptr_b }}[dk23_in_B * {{ SBK }}]; + //printf("B2 offset [offset from B = %d, dk23_in_b = %d, SBK = %d]\\n", B2 - B, dk23_in_B, {{SBK}}); + {% for dm3s in range(DM3S_padded) %} + __m256 va_{{ dm3s }} = _mm256_broadcast_ss(&{{ ptr_a_val }}[{{ dm3s }} * {{ SAMS }} + dk23 * {{ SAKS }}]); + {% endfor %} + {% for dn4 in range(DN4) %}{ + {% set in_remainder_DN5 = (dn4 == DN4 - 1) and (DN5R != 0) %} + {% if not in_remainder_DN5 %} + //printf("B2 LOAD [offset from B = %d, dn4 = %d, SBB = %d]\\n", &B2[{{ dn4 }} * {{ SBB }}] - B, {{ dn4}} , {{ SBB }}); + __m256 vb_{{ dn4 }} = _mm256_loadu_ps(&B2[{{dn4}} * {{SBB}}]); + {% else %} + __m256 vb_{{ dn4 }} = _mm256_maskload_ps(&B2[{{ dn4 }} * {{ SBB }}], load_mask); + {% endif %} + {% for dm3s in range(DM3S_padded) %} + acc_{{ dm3s }}_{{ dn4 }} = _mm256_fmadd_ps(va_{{ dm3s }}, vb_{{ dn4 }}, acc_{{ dm3s }}_{{ dn4 }}); + {% endfor %} + }{% endfor %} + } + {% for dm3s in range(DM3S_padded) %} + {% if (dk2 == DK2 - 1) or (bti[dk2][dm3s] != bti[dk2+1][dm3s]) %} + {% for dn4 in range(DN4) %}{ + {% set in_remainder_DN5 = (dn4 == DN4 - 1) and (DN5R != 0) %} + {% set dm3 = si[dm3s] %} + {% if dm3 < DM3 %} // cut out remainder DM3 + float* c_addr = &{{ ptr_c }}[{{ dm3 }} * {{ SCM }} + {{ dn4 }} * {{ SCB }}]; + {% if not in_remainder_DN5 %} + _mm256_storeu_ps(c_addr, _mm256_add_ps(_mm256_loadu_ps(c_addr), acc_{{ dm3s }}_{{ dn4 }})); + {% else %} + _mm256_maskstore_ps(c_addr, load_mask, _mm256_add_ps(_mm256_maskload_ps(c_addr, load_mask), acc_{{ dm3s }}_{{ dn4 }})); + {% endif %} + {% endif %} + {% if (dk2 != DK2 - 1) %} + acc_{{ dm3s }}_{{ dn4 }} = _mm256_setzero_ps(); + {% endif %} + }{% endfor %} + {% endif %} + {% endfor %} + }{% endfor %} + """ + ), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + code = template.render({**globals(), **locals()}) + + return code + + +def generate_sparse_dense_microkernel_avx512( + ptr_a_val, + ptr_a_idx, + ptr_b, + ptr_c, + M, + K, + Nblocks, + SAMS, + SAKS, + SAMI, + SAKI, + SBK, + SCM, + SCB, + SCV, + SBB, + SBV, + n, + m, + g, + num_vectors_b, + v, +): + DK2 = fact(m) // fact(n) // fact(m - n) + DK3 = g + DM3 = m + DM3S = n + DN4 = num_vectors_b # accumulator tile size + DN5 = v # vector size + + bti = make_n_m(n, m) + + DM3S_padded = int_ceil(M, DM3) * DM3S + DK23_padded = int_ceil(K, DK2 * DK3) * DK2 * DK3 + # DN5R = N % DN5 + DN5R = 0 + + assert M <= DM3 + assert K <= DK2 * DK3 + # assert N <= DN4 * DN5 + assert SCV == 1 + assert SBV == 1 + + template = jinja2.Template( + dedent( + """\ + // SAMS, SAKS, SAMI, SAKI, SBK, SCM, SCB, SCV, SBB, SBV = {{SAMS, SAKS, SAMI, SAKI, SBK, SCM, SCB, SCV, SBB, SBV}} + // init accumulators + {% for dm3s in range(DM3S_padded) %} + {% for dn4 in range(DN4) %} + __m512 acc_{{ dm3s }}_{{ dn4 }} = _mm512_setzero_ps(); + {% endfor %} + {% endfor %} + {% if DN5R != 0 %}__mmask16 load_mask = {{ 2 ** DN5R - 1 }};{% endif %} + {% for dk2 in range(DK2) %}{ + {% set si = bti[dk2] %} + {% set 
in_remainder_DK2 = (DK23_padded != K) %} + // in_remainder_DK2 = {{ in_remainder_DK2 }} dk2 {{ dk2 }} DK2 {{ DK2 }} DK23_padded {{ DK23_padded }} K {{ K }} + FOR(dk3, {{DK3}}) { + int dk23 = {{ dk2 }} * {{ DK3 }} + dk3; + int16_t dk23_in_B = {{ ptr_a_idx }}[dk23 * {{ SAKI }}]; + //printf("A LOAD [offset from A = %d, dk23 = %d, dk3=%d, DK3=%d, SAKI = %d]\\n", &({{ ptr_a_idx }}[dk23 * {{ SAKI }}]) - A_idx, dk23, dk3, {{DK3}} , {{ SAKI }}); + //printf("dk23 in B = %d\\n", (int)dk23_in_B); + {% if in_remainder_DK2 %} + // in dk2 remainder + if ((dk23_in_B < 0) || (dk23_in_B >= {{K}})) break; + {% endif %} + // SBK = {{SBK}} + float* B2 = &{{ ptr_b }}[dk23_in_B * {{ SBK }}]; + //printf("B2 offset [offset from B = %d, dk23_in_b = %d, SBK = %d]\\n", B2 - B, dk23_in_B, {{SBK}}); + {% for dm3s in range(DM3S_padded) %} + __m512 va_{{ dm3s }} = _mm512_set1_ps({{ ptr_a_val }}[{{ dm3s }} * {{ SAMS }} + dk23 * {{ SAKS }}]); + {% endfor %} + {% for dn4 in range(DN4) %}{ + {% set in_remainder_DN5 = (dn4 == DN4 - 1) and (DN5R != 0) %} + {% if not in_remainder_DN5 %} + //printf("B2 LOAD [offset from B = %d, dn4 = %d, SBB = %d]\\n", &B2[{{ dn4 }} * {{ SBB }}] - B, {{ dn4}} , {{ SBB }}); + __m512 vb_{{ dn4 }} = _mm512_loadu_ps(&B2[{{dn4}} * {{SBB}}]); + {% else %} + __m512 vb_{{ dn4 }} = _mm512_maskz_load_ps(load_mask, &B2[{{ dn4 }} * {{ SBB }}]); + {% endif %} + {% for dm3s in range(DM3S_padded) %} + acc_{{ dm3s }}_{{ dn4 }} = _mm512_fmadd_ps(va_{{ dm3s }}, vb_{{ dn4 }}, acc_{{ dm3s }}_{{ dn4 }}); + {% endfor %} + }{% endfor %} + } + {% for dm3s in range(DM3S_padded) %} + {% if (dk2 == DK2 - 1) or (bti[dk2][dm3s] != bti[dk2+1][dm3s]) %} + {% for dn4 in range(DN4) %}{ + {% set in_remainder_DN5 = (dn4 == DN4 - 1) and (DN5R != 0) %} + {% set dm3 = si[dm3s] %} + {% if dm3 < DM3 %} // cut out remainder DM3 + float* c_addr = &{{ ptr_c }}[{{ dm3 }} * {{ SCM }} + {{ dn4 }} * {{ SCB }}]; + {% if not in_remainder_DN5 %} + _mm512_storeu_ps(c_addr, _mm512_add_ps(_mm512_loadu_ps(c_addr), acc_{{ dm3s }}_{{ dn4 }})); + {% else %} + _mm512_mask_store_ps(c_addr, load_mask, _mm512_add_ps(_mm512_maskz_load_ps(load_mask, c_addr), acc_{{ dm3s }}_{{ dn4 }})); + {% endif %} + {% endif %} + {% if (dk2 != DK2 - 1) %} + acc_{{ dm3s }}_{{ dn4 }} = _mm512_setzero_ps(); + {% endif %} + }{% endfor %} + {% endif %} + {% endfor %} + }{% endfor %} + """ + ), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + code = template.render({**globals(), **locals()}) + + return code + + +@dataclass +class Loop: + dim: str = "m" + sequential: bool = False + size: int = 1 # Relative tile size + buffer_a: bool = False + buffer_b: bool = False + buffer_c: bool = False + + +def generate_configuration( + n: int, + m: int, + g: int, + M: int, + N: int, + K: int, + B_trans: bool, + C_trans: bool, + *, + vector_size: int = 8, + num_vectors_b: int = 4, # Microkernel parameters + loops: Union[List[Loop], List[List[Loop]]] = [ + Loop("k", sequential=True), + Loop("m"), + Loop("n"), + ], + local_c: bool = False, +) -> dace.SDFG: + if vector_size not in (8, 16): + raise ValueError("vector_size must be 8 or 16") + + num_vector_registers = vector_size * 2 # 16 for AVX2 and AVX512, 32 for AVX512_VL + + assert num_vectors_b >= 1 and num_vectors_b * n < num_vector_registers + + # Ensure integers are 32-bit by default + dace.Config.set("compiler", "default_data_types", value="C") + dace.Config.set("compiler.allow_view_arguments", value=True) + + # Create a new SDFG + sdfg = dace.SDFG("gnm_mult") + state = sdfg.add_state() + + # Tile 
sizes + DK2 = math.factorial(m) // math.factorial(n) // math.factorial(m - n) + DK2_g = DK2 * g + DMI_padded = math.ceil(M / m) + DKI_padded = math.ceil(K / DK2_g) + Nblocks = N // (num_vectors_b * vector_size) + + # Create C array + C_desc = dace.float32[DMI_padded, m, Nblocks, num_vectors_b, vector_size] + if C_trans: + C_desc.set_strides_from_layout(2, 3, 4, 0, 1) + sdfg.add_datadesc("C", C_desc) + + # Initialize C to zero + if not local_c: + state.add_mapped_tasklet( + "init", + dict( + i=f"0:{DMI_padded}", + j=f"0:{m}", + k=f"0:{Nblocks}", + l=f"0:{num_vectors_b}", + z=f"0:{vector_size}", + ), + {}, + "c = 0", + {"c": dace.Memlet("C[i, j, k, l, z]")}, + external_edges=True, + ) + + if isinstance(loops[0], Loop): + all_loops: List[List[Loop]] = [loops] * 8 + else: + # Specialized loops for every remainder + all_loops: List[List[Loop]] = loops + + num_tiles_k = DKI_padded + if K % DK2_g != 0: + num_tiles_k -= 1 + k_remainder = True + else: + k_remainder = False + + # Create microkernel states for each remainder configuration + for loops, (rem_m, rem_n, rem_k) in zip( + all_loops, + [ + (0, 0, 0), + (0, 0, 1), + (0, 1, 0), + (0, 1, 1), + (1, 0, 0), + (1, 0, 1), + (1, 1, 0), + (1, 1, 1), + ], + ): + if rem_n == 1: + continue + if rem_k == 1 and k_remainder is False: + continue + state = sdfg.add_state_after(state, f"microkernel_{rem_m}_{rem_n}_{rem_k}") + + # Skip remainder loops + loops = [loop for loop in loops if loop.dim != "m" or rem_m == 0] + loops = [loop for loop in loops if loop.dim != "n" or rem_n == 0] + loops = [loop for loop in loops if loop.dim != "k" or rem_k == 0] + + generate_single_kernel( + sdfg, + state, + n, + m, + g, + M, + N, + K, + B_trans, + C_trans, + vector_size, + num_vectors_b, + loops, + local_c, + rem_m, + rem_n, + rem_k, + num_tiles_k, + ) + + sdfg.append_global_code( + """ +#include +#define FOR(i, n) for (int i = 0; i < n; ++i) +""" + ) + sdfg.openmp_sections = False + + # sdfg.simplify() + + return sdfg + + +def generate_single_kernel( + sdfg: dace.SDFG, + state: dace.SDFGState, + n: int, + m: int, + g: int, + M: int, + N: int, + K: int, + B_trans: bool, + C_trans: bool, + v: int, + num_vectors_b: int, + loops: List[Loop], + local_c: bool, + rem_m: bool, + rem_n: bool, + rem_k: bool, + num_tiles_k: int, +): + kernel_generator = ( + generate_sparse_dense_microkernel_avx2 + if v == 8 + else generate_sparse_dense_microkernel_avx512 + ) + assert N % v == 0 + assert N % (num_vectors_b * v) == 0 + assert rem_n == 0 + + # Number of permutations + DK2 = math.factorial(m) // math.factorial(n) // math.factorial(m - n) + DK2_g = DK2 * g + + # Sizes + DMI_padded = math.ceil(M / m) + DKI_padded = math.ceil(K / DK2_g) + Nblocks = N // (num_vectors_b * v) + map_dims = {"m": DMI_padded, "n": Nblocks, "k": num_tiles_k} + + # Tile sizes + tile_M = m + tile_K = DK2_g + + # Data descriptors + A_val_desc = dace.float32[DKI_padded, DMI_padded, DK2, g, n] + A_idx_desc = dace.int16[DKI_padded, DMI_padded, DK2, g, 1] + B_desc = dace.float32[K, N] + C_desc = dace.float32[DMI_padded, m, Nblocks, num_vectors_b, v] + if B_trans: + B_desc.set_strides_from_layout(1, 0) + if C_trans: + C_desc.set_strides_from_layout(2, 3, 4, 0, 1) + + if "A_val" not in sdfg.arrays: + sdfg.add_datadesc("A_val", A_val_desc) + sdfg.add_datadesc("A_idx", A_idx_desc) + sdfg.add_datadesc("B", B_desc) + + # Add loops to graph + map_entries: List[nodes.MapEntry] = [] + map_exits: List[nodes.MapExit] = [] + for loop in loops: + map_entry, map_exit = state.add_map( + f"{loop.dim}_map", + {f"block_{loop.dim}": 
f"0:{map_dims[loop.dim]}"}, + schedule=( + dace.ScheduleType.Sequential + if loop.sequential + else dace.ScheduleType.CPU_Multicore + ), + ) + map_entries.append(map_entry) + map_exits.append(map_exit) + + # Get innermost local storage strides + SBV, SBB, SBK = 1, v, B_desc.strides[0] + SCM, SCB, SCV = C_desc.strides[-4], C_desc.strides[-2], C_desc.strides[-1] + + # Generate tasklet code with innermost local storage strides + autogenerated_code = kernel_generator( + ptr_a_val="ptr_a_val", + ptr_a_idx="ptr_a_idx", + ptr_b="ptr_b", + ptr_c="ptr_c", + M=tile_M if (rem_m == 0) else (M % tile_M), + K=tile_K if (rem_k == 0) else (K % tile_K), + SAMS=A_val_desc.strides[-1], + SAKS=A_val_desc.strides[-2], + SAMI=A_idx_desc.strides[-1], + SAKI=A_idx_desc.strides[-2], + SBV=SBV, + SBB=SBB, + SBK=SBK, + SCM=SCM, + SCB=SCB, + SCV=SCV, + n=n, + m=m, + g=g, + Nblocks=Nblocks, + num_vectors_b=num_vectors_b, + v=v, + ) + + a_index_k = "block_k" if (rem_k == 0) else f"{DKI_padded-1}" + a_index_m = "block_m" if (rem_m == 0) else f"{DMI_padded-1}" + b_index_k = ( + f"block_k*{DK2_g}:(block_k+1)*{DK2_g}" + if (rem_k == 0) + else f"{num_tiles_k*DK2*g}:{K}" + ) + b_index_n = f"block_n*{num_vectors_b*v}:(block_n+1)*{num_vectors_b*v}" + + # Add tasklet + tasklet = state.add_tasklet( + "kernel", + {"ptr_a_val", "ptr_a_idx", "ptr_b", "ptr_cin"}, + {"ptr_c"}, + autogenerated_code, + language=dace.Language.CPP, + ) + + # Add access nodes + a_val = state.add_read("A_val") + a_idx = state.add_read("A_idx") + b = state.add_read("B") + cin = state.add_read("C") + cout = state.add_write("C") + + # Add memlet paths + state.add_memlet_path( + a_val, + *map_entries, + tasklet, + dst_conn="ptr_a_val", + memlet=dace.Memlet(f"A_val[{a_index_k}, {a_index_m}, 0:{DK2}, 0:{g}, 0:{n}]"), + ) + state.add_memlet_path( + a_idx, + *map_entries, + tasklet, + dst_conn="ptr_a_idx", + memlet=dace.Memlet(f"A_idx[{a_index_k}, {a_index_m}, 0:{DK2}, 0:{g}, 0:1]"), + ) + state.add_memlet_path( + b, + *map_entries, + tasklet, + dst_conn="ptr_b", + memlet=dace.Memlet(f"B[{b_index_k}, {b_index_n}]"), + ) + state.add_memlet_path( + cin, + *map_entries, + tasklet, + dst_conn="ptr_cin", + memlet=dace.Memlet(f"C[{a_index_m}, 0:{m}, block_n, 0:{num_vectors_b}, 0:{v}]"), + ) + state.add_memlet_path( + tasklet, + *map_exits[::-1], + cout, + src_conn="ptr_c", + memlet=dace.Memlet(f"C[{a_index_m}, 0:{m}, block_n, 0:{num_vectors_b}, 0:{v}]"), + ) + + +def nmg_mult( + outshape: Tuple[int], + m: int, + n: int, + g: int, + transpose_b: bool, + transpose_c: bool, + tile: int, + tile_2: int, + local_b: bool, + local_c: bool, + name: str = None, + kernel: str = "avx2", +): + assert kernel in ("avx2", "avx512") + kernel_generator = ( + generate_sparse_dense_microkernel_avx2 + if kernel == "avx2" + else generate_sparse_dense_microkernel_avx512 + ) + + import dace + + # Ensure integers are 32-bit by default + dace.Config.set("compiler", "default_data_types", value="C") + dace.Config.set("compiler.allow_view_arguments", value=True) + + M, K, N = outshape + # Number of permutations + DK2 = math.factorial(m) // math.factorial(n) // math.factorial(m - n) + DK2_g = DK2 * g + v = 8 if kernel == "avx2" else 16 + num_vectors_b = 4 + assert N % v == 0 + assert N % (num_vectors_b * v) == 0 + + DMI_padded = math.ceil(M / m) + DM_padded = DMI_padded * m + DKI_padded = math.ceil(K / DK2_g) + DK_padded = DKI_padded * DK2_g + + Nblocks = N // (num_vectors_b * v) + + A_val_desc = dace.float32[DKI_padded, DMI_padded, DK2, g, n] + A_idx_desc = dace.int16[DKI_padded, DMI_padded, DK2, 
g, 1] + B_desc = dace.float32[DKI_padded, DK2, g, Nblocks, num_vectors_b, v] + C_desc = dace.float32[DMI_padded, m, Nblocks, num_vectors_b, v] + + local_B_desc = dace.float32[DK2, g, num_vectors_b, v] + local_C_desc = dace.float32[DMI_padded, m, num_vectors_b, v] + + # Change strides accordingly + if transpose_b: + B_desc.set_strides_from_layout(2, 1, 0, 5, 4, 3) + if transpose_c: + C_desc.set_strides_from_layout(1, 0, 4, 3, 2) + + # Handle strides in local storage + if local_b: + SBV, SBB, SBK = ( + local_B_desc.strides[-1], + local_B_desc.strides[-2], + local_B_desc.strides[-3], + ) + else: + SBV, SBB, SBK = B_desc.strides[-1], B_desc.strides[-2], B_desc.strides[-4] + + if local_c: + SCM, SCB, SCV = ( + local_C_desc.strides[-3], + local_C_desc.strides[-2], + local_C_desc.strides[-1], + ) + else: + SCM, SCB, SCV = C_desc.strides[-4], C_desc.strides[-2], C_desc.strides[-1] + + AUTOGENERATED_CODE = [[[";", ";"], [";", ";"]], [[";", ";"], [";", ";"]]] + + tile_M = m + tile_K = DK2_g + tile_N = v * num_vectors_b + + for mrem in range(2): + for nrem in range(2): + for krem in range(2): + if mrem and (M % tile_M) < 0: + raise ValueError("Invalid tile_M") + if mrem and (M % tile_M) == 0: + continue + if nrem and (N % tile_N) < 0: + raise ValueError("Invalid tile_N") + if nrem and (N % tile_N) == 0: + continue + if krem and (K % tile_K) < 0: + raise ValueError("Invalid tile_K") + if krem and (K % tile_K) == 0: + continue + + AUTOGENERATED_CODE[mrem][nrem][krem] = kernel_generator( + ptr_a_val="ptr_a_val", + ptr_a_idx="ptr_a_idx", + ptr_b="ptr_b", + ptr_c="ptr_c", + M=tile_M if (mrem == 0) else (M % tile_M), + K=tile_K if (krem == 0) else (K % tile_K), + SAMS=A_val_desc.strides[-1], + SAKS=A_val_desc.strides[-2], + SAMI=A_idx_desc.strides[-1], + SAKI=A_idx_desc.strides[-2], + SBV=SBV, + SBB=SBB, + SBK=SBK, + SCM=SCM, + SCB=SCB, + SCV=SCV, + n=n, + m=m, + g=g, + Nblocks=Nblocks, + num_vectors_b=num_vectors_b, + v=v, + ) + AUTOGENERATED_CODE_0_0_0 = "/* XXX 000 XXX */\n" + AUTOGENERATED_CODE[0][0][0] + AUTOGENERATED_CODE_0_0_1 = "/* XXX 00K XXX */\n" + AUTOGENERATED_CODE[0][0][1] + AUTOGENERATED_CODE_0_1_0 = "/* XXX 0N0 XXX */\n" + AUTOGENERATED_CODE[0][1][0] + AUTOGENERATED_CODE_0_1_1 = "/* XXX 0NK XXX */\n" + AUTOGENERATED_CODE[0][1][1] + AUTOGENERATED_CODE_1_0_0 = "/* XXX M00 XXX */\n" + AUTOGENERATED_CODE[1][0][0] + AUTOGENERATED_CODE_1_0_1 = "/* XXX M0K XXX */\n" + AUTOGENERATED_CODE[1][0][1] + AUTOGENERATED_CODE_1_1_0 = "/* XXX MN0 XXX */\n" + AUTOGENERATED_CODE[1][1][0] + AUTOGENERATED_CODE_1_1_1 = "/* XXX MNK XXX */\n" + AUTOGENERATED_CODE[1][1][1] + + num_tiles_k = DKI_padded + if K % tile_K != 0: + num_tiles_k -= 1 + k_remainder = True + else: + k_remainder = False + num_tiles_m = DMI_padded + # if M % tile_M != 0: + # num_tiles_m -= 1 + # m_remainder = True + # else: + m_remainder = False + + # print('K remainder:', k_remainder, 'M remainder:', (M % tile_M) != 0) + + B_desc = dace.float32[K, N] + if transpose_b: + B_desc.set_strides_from_layout(0, 1) + + block_n = dace.symbol("block_n") + + @dace.program(auto_optimize=True) + def bla(A_val: A_val_desc, A_idx: A_idx_desc, B: B_desc, C: C_desc): + Clocal = np.ndarray([DMI_padded, m, num_vectors_b, v], dtype=np.float32) + Clocal[:] = 0 + + Blocal = np.ndarray([K, num_vectors_b * v], dtype=np.float32) + if transpose_b: + # for i, j in dace.map[0:K, 0:num_vectors_b*v]: + # Blocal[i, j] = B[i, block_n*num_vectors_b*v + j] + with dace.tasklet: + ( + inB + << B[ + :, + block_n * num_vectors_b * v : (block_n + 1) * num_vectors_b * v, + ] + ) + 
outB >> Blocal[:, :] + f""" + copy_with_transpose(outB, inB, {B_desc.strides[1]}, {num_vectors_b*v}, {num_vectors_b*v}, {K}); + """ + else: + Blocal[:, :] = B[ + :, block_n * num_vectors_b * v : (block_n + 1) * num_vectors_b * v + ] + + for block_k in range(num_tiles_k): + for block_m in dace.map[0:num_tiles_m]: + with dace.tasklet(dace.Language.CPP): + ptr_a_val << A_val[block_k, block_m, 0:DK2, 0:g, 0:n] + ptr_a_idx << A_idx[block_k, block_m, 0:DK2, 0:g, 0:1] + ptr_b << Blocal[block_k * DK2 * g : (block_k + 1) * DK2 * g, :] + ptr_cin << Clocal[block_m, :, :, :] + AUTOGENERATED_CODE_0_0_0 + ptr_c >> Clocal[block_m, :, :, :] + + # m remainder + if m_remainder: + with dace.tasklet(dace.Language.CPP): + ptr_a_val << A_val[block_k, -1, 0:DK2, 0:g, 0:n] + ptr_a_idx << A_idx[block_k, -1, 0:DK2, 0:g, 0:1] + ptr_b << Blocal[block_k * DK2 * g : (block_k + 1) * DK2 * g, :] + ptr_cin << Clocal[-1, :, :, :] + AUTOGENERATED_CODE_1_0_0 + ptr_c >> Clocal[-1, :, :, :] + + # k remainder + if k_remainder: + for block_m in dace.map[0:num_tiles_m]: + with dace.tasklet(dace.Language.CPP): + ptr_a_val << A_val[-1, block_m, 0:DK2, 0:g, 0:n] + ptr_a_idx << A_idx[-1, block_m, 0:DK2, 0:g, 0:1] + ptr_b << Blocal[num_tiles_k * DK2 * g :, :] + ptr_cin << Clocal[block_m, :, :, :] + AUTOGENERATED_CODE_0_0_1 + ptr_c >> Clocal[block_m, :, :, :] + + # k/m remainder + if m_remainder: + with dace.tasklet(dace.Language.CPP): + ptr_a_val << A_val[-1, -1, 0:DK2, 0:g, 0:n] + ptr_a_idx << A_idx[-1, -1, 0:DK2, 0:g, 0:1] + ptr_b << Blocal[num_tiles_k * DK2 * g :, :] + ptr_cin << Clocal[-1, :, :, :] + AUTOGENERATED_CODE_1_0_1 + ptr_c >> Clocal[-1, :, :, :] + + # if transpose_c: + # # for i, j, k, l in dace.map[0:DMI_padded, 0:m, 0:num_vectors_b, 0:v]: + # # C[i, j, block_n, k, l] = Clocal[i, j, k, l] + # else: + C[:, :, block_n, :, :] = Clocal[:, :, :, :] + + @dace.program(auto_optimize=True) + def gnm_mult( + A_val: A_val_desc, + A_idx: A_idx_desc, + B: B_desc, + C: C_desc, + ): + for block_n in dace.map[0:Nblocks]: + bla(A_val, A_idx, B, C, block_n=block_n) + + sdfg = gnm_mult.to_sdfg() + sdfg.append_global_code( + """ + #include + #define FOR(i, n) for (int i = 0; i < n; ++i) + +inline void copy_with_transpose( + float* dst, + const float* src, + size_t s, + size_t d, + size_t m, + size_t n +) { + /* + src stride [s, 1] shape [m, n] + dst stride [1, d] shape [n, m] + */ + size_t i = 0; + for (; i < m / 4 * 4; i += 4) { + size_t j = 0; + for (; j < n / 4 * 4; j += 4) { + __m128 x0 = _mm_loadu_ps(&src[(i + 0) * s + j]); + __m128 x1 = _mm_loadu_ps(&src[(i + 1) * s + j]); + __m128 x2 = _mm_loadu_ps(&src[(i + 2) * s + j]); + __m128 x3 = _mm_loadu_ps(&src[(i + 3) * s + j]); + __m128 y0 = _mm_unpacklo_ps(x0, x1); + __m128 y1 = _mm_unpackhi_ps(x0, x1); + __m128 y2 = _mm_unpacklo_ps(x2, x3); + __m128 y3 = _mm_unpackhi_ps(x2, x3); + __m128 z0 = _mm_movelh_ps(y0, y2); + __m128 z1 = _mm_movehl_ps(y2, y0); + __m128 z2 = _mm_movelh_ps(y1, y3); + __m128 z3 = _mm_movehl_ps(y3, y1); + _mm_store_ps(&dst[i + (j + 0) * d], z0); + _mm_store_ps(&dst[i + (j + 1) * d], z1); + _mm_store_ps(&dst[i + (j + 2) * d], z2); + _mm_store_ps(&dst[i + (j + 3) * d], z3); + } + for (; j < n; j += 4) { + __m128i mask_load = (n - j == 3) ? _mm_setr_epi32(-1, -1, -1, 0) : + (n - j == 2) ? 
_mm_setr_epi32(-1, -1, 0, 0) : + /*n - j == 1*/ _mm_setr_epi32(-1, 0, 0, 0); + __m128 x0 = _mm_maskload_ps(&src[(i + 0) * s + j], mask_load); + __m128 x1 = _mm_maskload_ps(&src[(i + 1) * s + j], mask_load); + __m128 x2 = _mm_maskload_ps(&src[(i + 2) * s + j], mask_load); + __m128 x3 = _mm_maskload_ps(&src[(i + 3) * s + j], mask_load); + __m128 y0 = _mm_unpacklo_ps(x0, x1); + __m128 y1 = _mm_unpackhi_ps(x0, x1); + __m128 y2 = _mm_unpacklo_ps(x2, x3); + __m128 y3 = _mm_unpackhi_ps(x2, x3); + __m128 z0 = _mm_movelh_ps(y0, y2); + __m128 z1 = _mm_movehl_ps(y2, y0); + __m128 z2 = _mm_movelh_ps(y1, y3); + __m128 z3 = _mm_movehl_ps(y3, y1); + _mm_store_ps(&dst[i + (j + 0) * d], z0); + if (n - j == 1) continue; + _mm_store_ps(&dst[i + (j + 1) * d], z1); + if (n - j == 2) continue; + _mm_store_ps(&dst[i + (j + 2) * d], z2); + } + } + for (; i < m; i += 4) { + __m128i mask_store = (m - i == 3) ? _mm_setr_epi32(-1, -1, -1, 0) : + (m - i == 2) ? _mm_setr_epi32(-1, -1, 0, 0) : + /*m - i == 1*/ _mm_setr_epi32(-1, 0, 0, 0); + size_t j = 0; + for (; j < n / 4 * 4; j += 4) { + __m128 x0 = _mm_loadu_ps(&src[(i + 0) * s + j]); + __m128 x1 = (m - i > 1) ? _mm_loadu_ps(&src[(i + 1) * s + j]) : _mm_setzero_ps(); + __m128 x2 = (m - i > 2) ? _mm_loadu_ps(&src[(i + 2) * s + j]) : _mm_setzero_ps(); + __m128 x3 = (m - i > 3) ? _mm_loadu_ps(&src[(i + 3) * s + j]) : _mm_setzero_ps(); + __m128 y0 = _mm_unpacklo_ps(x0, x1); + __m128 y1 = _mm_unpackhi_ps(x0, x1); + __m128 y2 = _mm_unpacklo_ps(x2, x3); + __m128 y3 = _mm_unpackhi_ps(x2, x3); + __m128 z0 = _mm_movelh_ps(y0, y2); + __m128 z1 = _mm_movehl_ps(y2, y0); + __m128 z2 = _mm_movelh_ps(y1, y3); + __m128 z3 = _mm_movehl_ps(y3, y1); + _mm_maskstore_ps(&dst[i + (j + 0) * d], mask_store, z0); + _mm_maskstore_ps(&dst[i + (j + 1) * d], mask_store, z1); + _mm_maskstore_ps(&dst[i + (j + 2) * d], mask_store, z2); + _mm_maskstore_ps(&dst[i + (j + 3) * d], mask_store, z3); + } + for (; j < n; j += 4) { + __m128i mask_load = (n - j == 3) ? _mm_setr_epi32(-1, -1, -1, 0) : + (n - j == 2) ? _mm_setr_epi32(-1, -1, 0, 0) : + /*n - j == 1*/ _mm_setr_epi32(-1, 0, 0, 0); + __m128 x0 = _mm_maskload_ps(&src[(i + 0) * s + j], mask_load); + __m128 x1 = (m - i > 1) ? _mm_maskload_ps(&src[(i + 1) * s + j], mask_load) : _mm_setzero_ps(); + __m128 x2 = (m - i > 2) ? _mm_maskload_ps(&src[(i + 2) * s + j], mask_load) : _mm_setzero_ps(); + __m128 x3 = (m - i > 3) ? 
_mm_maskload_ps(&src[(i + 3) * s + j], mask_load) : _mm_setzero_ps(); + __m128 y0 = _mm_unpacklo_ps(x0, x1); + __m128 y1 = _mm_unpackhi_ps(x0, x1); + __m128 y2 = _mm_unpacklo_ps(x2, x3); + __m128 y3 = _mm_unpackhi_ps(x2, x3); + __m128 z0 = _mm_movelh_ps(y0, y2); + __m128 z1 = _mm_movehl_ps(y2, y0); + __m128 z2 = _mm_movelh_ps(y1, y3); + __m128 z3 = _mm_movehl_ps(y3, y1); + _mm_maskstore_ps(&dst[i + (j + 0) * d], mask_store, z0); + if (n - j == 1) continue; + _mm_maskstore_ps(&dst[i + (j + 1) * d], mask_store, z1); + if (n - j == 2) continue; + _mm_maskstore_ps(&dst[i + (j + 2) * d], mask_store, z2); + } + } +} + + + """ + ) + for sd in sdfg.all_sdfgs_recursive(): + sd.openmp_sections = False + + if name is not None: + sdfg.name = name + + from dace.transformation.auto import auto_optimize + + # sdfg = auto_optimize.auto_optimize(sdfg, dace.DeviceType.CPU) + + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, dace.nodes.MapEntry): + if len(node.map.params) > 1: + node.map.collapse = 2 + + def find_map_by_param(sdfg: dace.SDFG, pname: str) -> dace.nodes.MapEntry: + return next( + n + for n, _ in sdfg.all_nodes_recursive() + if isinstance(n, dace.nodes.MapEntry) and pname in n.params + ) + + def find_all_maps_by_param(sdfg: dace.SDFG, pname: str) -> dace.nodes.MapEntry: + yield from ( + (n, p) + for n, p in sdfg.all_nodes_recursive() + if isinstance(n, dace.nodes.MapEntry) and pname in n.params + ) + + from dace.transformation import helpers as xfutil + from dace.transformation.dataflow import MapExpansion, InLocalStorage + + map = find_map_by_param(sdfg, "block_m") + # map_m, map_n = MapExpansion.apply_to(sdfg, map_entry=map) + if tile > 0: + xfutil.tile(sdfg, map, (N % tile == 0), False, block_n=tile) + if tile_2 > 0: + xfutil.tile(sdfg, map, (N % tile_2 == 0), False, block_n=tile_2) + # xfutil.permute_map(map, [1, 0]) + # tile_map_m = find_map_by_param(sdfg, 'tile_block_n') + + # Hardcoded above + # if local_b: + # state: dace.SDFGState + # for map, state in find_all_maps_by_param(sdfg, 'block_n'): + # sd = state.parent + # # For each n map, add a new transient and copy to it + # name = sd.add_datadesc('local_B', copy.deepcopy(local_B_desc), find_new_name=True) + # sd.arrays[name].transient = True + # new_node = state.add_access(name) + # # Redirect B edge through it + # edge = next(e for e in state.out_edges(map) if e.data.data == 'B') + # state.remove_edge(edge) + # state.add_edge(edge.src, edge.src_conn, new_node, None, edge.data) + # state.add_edge(new_node, None, edge.dst, edge.dst_conn, dace.Memlet(data=name)) + + # Hardcoded above + # if local_c: + # state: dace.SDFGState + # for map, state in find_all_maps_by_param(sdfg, 'block_n'): + # exitnode = state.exit_node(map) + # sd = state.parent + # # For each n map, add a new transient and copy to it + # name = sd.add_datadesc('local_C', copy.deepcopy(local_C_desc), find_new_name=True) + # sd.arrays[name].transient = True + # # Redirect input through it + # edge = next(e for e in state.out_edges(map) if e.data.data == 'C') + # new_input_node = state.add_access(name) + # state.remove_edge(edge) + # state.add_edge(edge.src, edge.src_conn, new_input_node, None, edge.data) + # state.add_edge(new_input_node, None, edge.dst, edge.dst_conn, dace.Memlet(data=name)) + # # Redirect output edge through it + # edge = next(e for e in state.in_edges(exitnode) if e.data.data == 'C') + # new_output_node = state.add_access(name) + # state.remove_edge(edge) + # state.add_edge(edge.src, edge.src_conn, new_output_node, None, 
dace.Memlet(data=name)) + # state.add_edge(new_output_node, None, edge.dst, edge.dst_conn, edge.data) + + # InLocalStorage.apply_to(sdfg, dict(array='A_val'), node_a=tile_map_m, node_b=map) + # InLocalStorage.apply_to(sdfg, dict(array='A_idx'), node_a=tile_map_m, node_b=map) + dace.Config.set("compiler", "max_stack_array_size", value=999999999999) + + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, dace.nodes.MapEntry): + if len(node.map.params) > 1: + node.map.collapse = 2 + + return sdfg.compile() + + +def test_dace(kernel=None): + parser = argparse.ArgumentParser() + + parser.add_argument("-n", type=int, default=3) + parser.add_argument("-m", type=int, default=6) + parser.add_argument("-g", type=int, default=12) + parser.add_argument("--trans_b", action="store_true", default=False) + parser.add_argument("--trans_c", action="store_true", default=False) + parser.add_argument("--tile", type=int, default=0) + parser.add_argument("--tile_2", type=int, default=0) + parser.add_argument( + "--kernel", type=str, default="best", choices=["best", "avx2", "avx512"] + ) + + parser.add_argument("--local_b", action="store_true", default=False) + parser.add_argument("--local_c", action="store_true", default=False) + # TODO: loop order + + args = parser.parse_args() + + kernel: str = kernel or args.kernel + if kernel == "best": + try: + import cpuinfo + + flags = cpuinfo.get_cpu_info()["flags"] + if "avx512f" in flags: + kernel = "avx512" + elif "avx2" in flags: + kernel = "avx2" + else: + raise RuntimeError( + "This CPU does not support AVX2, which is the minimum requirement for the kernel to run" + ) + except (ImportError, ModuleNotFoundError): + print("To select best kernel automatically, use `pip install py-cpuinfo`.") + kernel = "avx2" + + # torch.manual_seed(1) + # shape = (4-1, 6-1) + # shape_b = (6-1, 32-1) + M, K, N = 768, 3072, 4096 + # M, K, N = 768, 768, 4096 + # M, K, N = 768, 20*12*10, 4096 + shape = (M, K) + shape_b = (K, N) + m, n, g = args.m, args.n, args.g + # m, n, g = 6, 3, 12 + # m, n, g = 6, 3, 7 + # m, n, g = 10, 1, 24 + # m, n, g = 4, 2, 1 + # m, n, g = 2, 1, 1 + # dense_ten = torch.randint(5, 100, shape, device="cpu", dtype=torch.float32) + + # Hardcoded for now + args.local_b = True + args.local_c = True + + print(f"Dimensions: {shape[0]} x {shape[1]} x {shape_b[1]}") + print(f"B transposed: {args.trans_b}, C transposed: {args.trans_c}") + print(f"n:m:g microkernel/dace hybrid with {n}:{m}:{g} ({kernel.upper()})") + + # Optimizations + # Transpose always means storing local tiles + local_b = args.local_b if not args.trans_b else True + local_c = args.local_c if not args.trans_c else True + print("Optimizations:") + if args.tile > 1: + print("Tile size:", args.tile) + if args.tile_2 > 1: + print("Tile 2 size:", args.tile_2) + print(f"Local storage: (B: {local_b}, C: {local_c})") + + dense_ten = torch.rand(shape, device="cpu", dtype=torch.float32) + sparse_dim = 0 + group_dim = 1 + t1 = time.time() + nm_ten = GroupedNMTensor.from_dense( + dense_ten, + n=n, + m=m, + sparse_dim=sparse_dim, + group_size=g, + group_dim=group_dim, + ) + t2 = time.time() + print(f"dense->sparse total time {t2-t1:.2f}") + + densified = nm_ten.to_dense() + + assert is_correct_nm(dense_ten, nm_ten.to_dense(), sparse_dim, n, m) + + # Dims: parts of matrix x number of groups x G x N + # print("val shape:", nm_ten.val.shape) + # print("idx shape:", nm_ten.idx.shape) + # print("original", functools.reduce(lambda x, y: x * y, dense_ten.shape, 1)) + # print("sparse", 
functools.reduce(lambda x, y: x * y, nm_ten.val.shape, 1)) + print( + "sparsity", + functools.reduce(lambda x, y: x * y, nm_ten.val.shape, 1) + / functools.reduce(lambda x, y: x * y, dense_ten.shape, 1), + ) + + # other = torch.randint(200, 300, shape_b, device="cpu", dtype=torch.float32) + # other = torch.zeros_like(other) + # other[1,0] = 1 + other = torch.rand(shape_b, device="cpu", dtype=torch.float32) + + # groups = np.array(nm_ten.nm_strides["order"], dtype=np.int32) + + # print('de', densified) + # print('ot', other) + + expected = densified @ other + # print('ex', expected) + + print("Compiling kernel...") + + v = 8 if kernel == "avx2" else 16 + # sdfg = generate_configuration(n, m, g, M, N, K, args.trans_b, args.trans_c, vector_size=v, + # loops=[Loop('n'), Loop('k', sequential=True), Loop('m', sequential=True)]) + # compiled = sdfg.compile() + compiled = nmg_mult( + (M, K, N), + m, + n, + g, + args.trans_b, + args.trans_c, + args.tile, + args.tile_2, + local_b, + local_c, + None, + kernel, + ) + + print("Compilation complete.") + + # tensor.val is of shape (number of groups=DMI*DKI, chunk_size, group_size, n) + # tensor.idx is of shape (loop_outer_size, chunk_size, group_size, 1) + A_val = nm_ten.val + A_idx = nm_ten.idx + + # Problem dimensions + perms = math.factorial(m) // math.factorial(n) // math.factorial(m - n) + DKI_padded = int_ceil(K, perms * g) + DK_padded = DKI_padded * perms * g + DMI_padded = math.ceil(M / m) + DM_padded = DMI_padded * m + num_vectors_b = 4 + v = 8 if kernel == "avx2" else 16 + Nblocks = N // (num_vectors_b * v) + + if args.trans_b: + # other_padded = torch.zeros([DK_padded, N], dtype=other.dtype) + # other_padded[:K, :] = other + # other_reshaped = other_padded.reshape([DKI_padded, perms, g, Nblocks, num_vectors_b, v]) + # other_permuted = other_reshaped.permute(3, 4, 5, 0, 1, 2) + B = other.permute(1, 0) + else: + # other_padded = torch.zeros([DK_padded, N], dtype=other.dtype) + # other_padded[:K, :] = other + # B = other_padded + B = other + + output = torch.empty([DM_padded, N], dtype=other.dtype) + if args.trans_c: + output = output.permute(1, 0).contiguous() + + compiled( + A_val=A_val.contiguous(), + A_idx=A_idx.contiguous(), + B=B.contiguous(), + C=output.contiguous(), + ) + + if args.trans_c: # Transpose back for correctness testing + output = output.permute(1, 0).contiguous() + result = output[:M, :N] + + avgdiff = (result - expected).abs().sum() / result.numel() + maxdiff = (result - expected).abs().max() + median_diff = (result - expected).abs().median() + diffcount = np.sum(np.where((result - expected).abs() > 1e-2, 1, 0)) + print( + f"avgdiff {avgdiff:.3f} maxdiff {maxdiff:.3f} median_diff {median_diff:.3f}. 
Count: {diffcount} / {result.numel()}" + ) + + assert torch.allclose(result, expected) + + +if __name__ == "__main__": + test_dace() + # test_dace('avx2') + # test_dace('avx512') + # test_dense_nm_conversion() + # print("ok") diff --git a/src/sten/grouped_nm/grouped_nm_tensor.py b/src/sten/grouped_nm/grouped_nm_tensor.py new file mode 100644 index 0000000..c078889 --- /dev/null +++ b/src/sten/grouped_nm/grouped_nm_tensor.py @@ -0,0 +1,1091 @@ +import torch +from math import factorial as fact +import math +import itertools +from native_scripting import compile +import ctypes +import numpy as np +import copy +import functools +import time +from sympy import factorint +from typing import Tuple +import warnings +import cpuinfo + +try: + cache = functools.cache +except AttributeError: + cache = functools.lru_cache(maxsize=None) + +# ++++++++++++++ n:m order generator ++++++++++++++ + + +def is_valid_n_m(l, n, m): + def is_increasing(a): + return all([e1 < e2 for e1, e2 in zip(a[:-1], a[1:])]) + + def is_adjacent(a1, a2): + return sum([i1 != i2 for i1, i2 in zip(a1, a2)]) == 1 + + all_increasing = all([is_increasing(a) for a in l]) + all_adjacent = all([is_adjacent(a, b) for a, b in zip(l[:-1], l[1:])]) + correct_size = len(l) == fact(m) // fact(n) // fact(m - n) + return all_increasing and all_adjacent and correct_size + + +@cache +def make_n_m(n, m, special=False): + # special -- True: "..xx.->...xx" False: "xx...->...xx" + assert 0 <= n and n <= m + if n == m or n == 0: + return [tuple(range(n))] + first = make_n_m(n, m - 1, True) + if special: + first = list(reversed(first)) + second = make_n_m(n - 1, m - 1, not special) + second = [tpl + (m - 1,) for tpl in second] + result = first + second + assert is_valid_n_m(result, n, m) + return result + + +def make_n_m_mask(m, nnz_indices): + res = [0] * m + for idx in nnz_indices: + res[idx] = 1 + return tuple(res) + + +def make_n_m_order_c(nnz_indices_list): + elems = [] + for tpl in nnz_indices_list: + elem = "{" + ", ".join([str(t) for t in tpl]) + "}" + elems.append(elem) + return "{" + ", ".join(elems) + "}" + + +# ============== n:m order generator ============== + + +@cache +def compute_nm_strides(dense_shape, n, m, sparse_dim, group_dim, group_size): + chunk_size = fact(m) // fact(n) // fact(m - n) + num_chunks = math.ceil(dense_shape[group_dim] / (chunk_size * group_size)) + num_blocks = math.ceil(dense_shape[sparse_dim] / m) + sparse_dim_expanded = [num_blocks, m] + group_dim_expanded = [num_chunks, chunk_size, group_size] + + padded_sparse_dim = math.prod(sparse_dim_expanded) + padded_group_dim = math.prod(group_dim_expanded) + + # pad sparse dim to fit into n:m blocks + padded_dense_shape = list(dense_shape) + padded_dense_shape[sparse_dim] = padded_sparse_dim + padded_dense_shape[group_dim] = padded_group_dim + + if sparse_dim < group_dim: + exp_sparse_dim = sparse_dim + exp_group_dim = group_dim + len(sparse_dim_expanded) - 1 + + smaller_dim_from = exp_sparse_dim + smaller_dim_to = exp_sparse_dim + len(sparse_dim_expanded) + larger_dim_from = exp_group_dim + larger_dim_to = exp_group_dim + len(group_dim_expanded) + + expanded_dense_shape = ( + padded_dense_shape[:sparse_dim] + + sparse_dim_expanded + + padded_dense_shape[sparse_dim + 1 : group_dim] + + group_dim_expanded + + padded_dense_shape[group_dim + 1 :] + ) + + else: + exp_sparse_dim = sparse_dim + len(group_dim_expanded) - 1 + exp_group_dim = group_dim + + smaller_dim_from = exp_group_dim + smaller_dim_to = exp_group_dim + len(group_dim_expanded) + larger_dim_from = 
exp_sparse_dim + larger_dim_to = exp_sparse_dim + len(sparse_dim_expanded) + + expanded_dense_shape = ( + padded_dense_shape[:group_dim] + + group_dim_expanded + + padded_dense_shape[group_dim + 1 : sparse_dim] + + sparse_dim_expanded + + padded_dense_shape[sparse_dim + 1 :] + ) + + sparse_val_shape = ( + expanded_dense_shape[: exp_sparse_dim + 1] + + [n] + + expanded_dense_shape[exp_sparse_dim + 2 :] + ) + sparse_idx_shape = ( + expanded_dense_shape[: exp_sparse_dim + 1] + + [1] + + expanded_dense_shape[exp_sparse_dim + 2 :] + ) + + if sparse_dim < group_dim: + dense_block_stride = math.prod(expanded_dense_shape[smaller_dim_to:]) + dense_group_stride = math.prod(expanded_dense_shape[larger_dim_to:]) + sparse_block_stride = math.prod(sparse_val_shape[smaller_dim_to:]) + sparse_group_stride = math.prod(sparse_val_shape[larger_dim_to:]) + idx_block_stride = math.prod(sparse_idx_shape[smaller_dim_to:]) + idx_group_stride = math.prod(sparse_idx_shape[larger_dim_to:]) + else: + dense_block_stride = math.prod(expanded_dense_shape[larger_dim_to:]) + dense_group_stride = math.prod(expanded_dense_shape[smaller_dim_to:]) + sparse_block_stride = math.prod(sparse_val_shape[larger_dim_to:]) + sparse_group_stride = math.prod(sparse_val_shape[smaller_dim_to:]) + idx_block_stride = math.prod(sparse_idx_shape[larger_dim_to:]) + idx_group_stride = math.prod(sparse_idx_shape[smaller_dim_to:]) + + dense_loop_outer_stride = math.prod(expanded_dense_shape[smaller_dim_from + 1 :]) + dense_loop_middle_stride = math.prod(expanded_dense_shape[larger_dim_from + 1 :]) + sparse_loop_outer_stride = math.prod(sparse_val_shape[smaller_dim_from + 1 :]) + sparse_loop_middle_stride = math.prod(sparse_val_shape[larger_dim_from + 1 :]) + idx_loop_outer_stride = math.prod(sparse_idx_shape[smaller_dim_from + 1 :]) + idx_loop_middle_stride = math.prod(sparse_idx_shape[larger_dim_from + 1 :]) + + expanded_ndim = len(expanded_dense_shape) + + last_dims = [ + exp_group_dim, + exp_sparse_dim, + exp_group_dim + 1, + exp_group_dim + 2, + exp_sparse_dim + 1, + ] + dim_permutation = [ + i for i in range(expanded_ndim) if i not in last_dims + ] + last_dims + inverse_dim_permutation = sorted( + range(expanded_ndim), key=lambda x: dim_permutation[x] + ) + + permuted_shape = [expanded_dense_shape[i] for i in dim_permutation] + merged_shape = [math.prod(permuted_shape[:-3]), *permuted_shape[-3:]] + + return { + "n": n, + "m": m, + "sparse_dim": sparse_dim, + "group_dim": group_dim, + "group_size": group_size, + "dense_shape": dense_shape, + "padded_dense_shape": padded_dense_shape, + "expanded_dense_shape": expanded_dense_shape, + "chunk_size": chunk_size, + "num_chunks": num_chunks, + "num_blocks": num_blocks, + "exp_sparse_dim": exp_sparse_dim, + "exp_group_dim": exp_group_dim, + "smaller_dim_from": smaller_dim_from, + "smaller_dim_to": smaller_dim_to, + "larger_dim_from": larger_dim_from, + "larger_dim_to": larger_dim_to, + "sparse_val_shape": sparse_val_shape, + "sparse_idx_shape": sparse_idx_shape, + "loop_outer_size": math.prod(expanded_dense_shape[: smaller_dim_from + 1]), + "loop_middle_size": math.prod( + expanded_dense_shape[smaller_dim_to : larger_dim_from + 1] + ), + "loop_inner_size": math.prod(expanded_dense_shape[larger_dim_to:]), + "order": make_n_m(n, m), + "dense_block_stride": dense_block_stride, + "dense_group_stride": dense_group_stride, + "sparse_block_stride": sparse_block_stride, + "sparse_group_stride": sparse_group_stride, + "idx_block_stride": idx_block_stride, + "idx_group_stride": idx_group_stride, + 
"dense_loop_middle_stride": dense_loop_middle_stride, + "dense_loop_outer_stride": dense_loop_outer_stride, + "sparse_loop_middle_stride": sparse_loop_middle_stride, + "sparse_loop_outer_stride": sparse_loop_outer_stride, + "idx_loop_middle_stride": idx_loop_middle_stride, + "idx_loop_outer_stride": idx_loop_outer_stride, + "padded_sparse_dim": padded_sparse_dim, + "padded_group_dim": padded_group_dim, + "dim_permutation": dim_permutation, + "inverse_dim_permutation": inverse_dim_permutation, + "permuted_shape": permuted_shape, + "merged_shape": merged_shape, + } + + +@cache +def get_dense_to_grouped_nm_impl_cpu(dense_dtype, dense_shape, n): + assert len(dense_shape) == 4 # (batch_dim, chunk=m!/n!/(m-n)!, group, block=m) + chunk_size = dense_shape[1] + group_size = dense_shape[2] + m = dense_shape[3] + assert chunk_size == fact(m) // fact(n) // fact(m - n) + + # check that group_size * chunk_size fits int16_t + assert chunk_size * group_size < 2**15 + + order = make_n_m(n, m) + loop_outer_size = dense_shape[0] + + dense_group_stride = m + sparse_group_stride = n + idx_group_stride = 1 + + dense_loop_outer_stride = chunk_size * group_size * m + sparse_loop_outer_stride = chunk_size * group_size * n + idx_loop_outer_stride = chunk_size * group_size * 1 + + assert dense_dtype in (torch.float32, torch.float64) + dtype = "float" if dense_dtype == torch.float32 else "double" + lib = compile( + f""" + #include + #include + #include + #include + #include + #include + #include + #include + static const int8_t nnz_order[{chunk_size}][{n}] = {make_n_m_order_c(order)}; + extern "C" void func({dtype}* dense, {dtype}* sparse, int16_t* idx) {{ + const int64_t acc_size = {chunk_size} * {chunk_size} * {group_size}; + std::vector> accs(acc_size); + for (int64_t os_idx = 0; os_idx < {loop_outer_size}; os_idx++) {{ + {dtype}* dense_base = &dense[os_idx * {dense_loop_outer_stride}]; + {dtype}* sparse_base = &sparse[os_idx * {sparse_loop_outer_stride}]; + int16_t* idx_base = &idx[os_idx * {idx_loop_outer_stride}]; + for (int64_t cs_idx = 0; cs_idx < {chunk_size}; cs_idx++) {{ + for (int64_t gs_idx = 0; gs_idx < {group_size}; gs_idx++) {{ + for (int64_t nnz_idx = 0; nnz_idx < {chunk_size}; nnz_idx++) {{ + {dtype} blk_acc = 0; + int16_t original_idx = cs_idx * {group_size} + gs_idx; + for (int64_t n_idx = 0; n_idx < {n}; n_idx++) {{ + int16_t m_idx = nnz_order[nnz_idx][n_idx]; + blk_acc += std::abs(dense_base[m_idx + original_idx * {dense_group_stride}]); + }} + int64_t acc_idx = nnz_idx + original_idx * {chunk_size}; + accs[acc_idx] = std::make_tuple(blk_acc, original_idx, nnz_idx); + }} + }} + }} + std::sort(accs.begin(), accs.end(), std::greater<>()); + // extract elements from sorted accumulators into resulting tensor + int16_t group_elems[{chunk_size}] = {{ 0 }}; + bool is_taken[{chunk_size} * {group_size}] = {{ 0 }}; + for (int64_t acc_idx = 0; acc_idx < acc_size; acc_idx++) {{ + {dtype} val = 0; + int16_t orig_idx = -1, nnz_idx = -1; + std::tie(val, orig_idx, nnz_idx) = accs[acc_idx]; + if (!is_taken[orig_idx] && group_elems[nnz_idx] < {group_size}) {{ + is_taken[orig_idx] = true; + int16_t group_idx = group_elems[nnz_idx]++; + int16_t sparse_idx = (group_idx + nnz_idx * {group_size}); + // put into the output + idx_base[sparse_idx * {idx_group_stride}] = orig_idx; + for (int64_t n_idx = 0; n_idx < {n}; n_idx++) {{ + int64_t m_idx = nnz_order[nnz_idx][n_idx]; + sparse_base[n_idx + sparse_idx * {sparse_group_stride}] = dense_base[m_idx + orig_idx * {dense_group_stride}]; + }} + }} + }} + }} + }} + """, + ) 
+ lib.func.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ] + return lib.func + + +def int_ceil(x, y): + return (x - 1) // y + 1 + + +def find_opt_cuda_block(nx, ny, nz, max_block=1024, max_z=64): + # prioritize powers of 2 to avoid remainders everywhere + dims = [nx, ny, nz] + pows2 = [factorint(d).get(2, 0) for d in dims] + block_sizes = [2**p for p in pows2] + dims = [d // b for d, b in zip(dims, block_sizes)] + + # fit z dimension if it exceedes limits + while block_sizes[-1] > max_z: + block_sizes[-1] //= 2 + dims[-1] *= 2 + + # remove some powers if max block size is exceeded + while math.prod(block_sizes) > max_block: + for i in range(len(dims)): + if block_sizes[i] == 1: + continue + block_sizes[i] //= 2 + dims[i] *= 2 + break + + assert math.prod(block_sizes) <= max_block + # now distribute some powers of 2 while max block is not reached + while math.prod(block_sizes) < max_block: + # print(f'block sizes {block_sizes}') + exitting = False + for i, d in enumerate(dims): + if d != max(dims): + continue + if i == 2 and block_sizes[i] == 64: + exitting = True + break + block_sizes[i] *= 2 + dims[i] //= 2 + break + if exitting: + break + + # print(f"optimal block {block_sizes} for inputs of size {[nx, ny, nz]}") + assert math.prod(block_sizes) <= max_block + assert block_sizes[-1] <= max_z + return block_sizes + + +@cache +def get_dense_to_grouped_nm_impl_cuda(dense_dtype, dense_shape, n): + assert len(dense_shape) == 4 # (batch_dim, chunk=m!/n!/(m-n)!, group, block=m) + chunk_size = dense_shape[1] + group_size = dense_shape[2] + m = dense_shape[3] + assert chunk_size == fact(m) // fact(n) // fact(m - n) + + # check that group_size * chunk_size fits int16_t + assert chunk_size * group_size < 2**15 + + order = make_n_m(n, m) + loop_outer_size = dense_shape[0] + + dense_group_stride = m + sparse_group_stride = n + idx_group_stride = 1 + + dense_loop_outer_stride = chunk_size * group_size * m + sparse_loop_outer_stride = chunk_size * group_size * n + idx_loop_outer_stride = chunk_size * group_size * 1 + + acc_size = chunk_size * chunk_size * group_size + + block_x, block_y, block_z = find_opt_cuda_block( + chunk_size, loop_outer_size, group_size + ) + + grid_x = int_ceil(chunk_size, block_x) + grid_y = int_ceil(loop_outer_size, block_y) + grid_z = int_ceil(group_size, block_z) + + N = chunk_size + if N % 2 == 0: + Nx = N - 1 + Ny = N // 2 + else: + Nx = N + Ny = (N - 1) // 2 + + block_x_pp, block_y_pp, block_z_pp = find_opt_cuda_block( + Nx * group_size, loop_outer_size, Ny * group_size + ) + + grid_x_pp = int_ceil(Nx * group_size, block_x_pp) + grid_y_pp = int_ceil(loop_outer_size, block_y_pp) + grid_z_pp = int_ceil(Ny * group_size, block_z_pp) + + assert dense_dtype in (torch.float32, torch.float64) + dtype = "float" if dense_dtype == torch.float32 else "double" + lib = compile( + f""" + #include + #include + #include + #include + #include + #include + #include + #include + #include + #define CUDA_CHECK(expr) do {{\\ + cudaError_t err = (expr);\\ + if (err != cudaSuccess) {{\\ + std::cerr << "CUDA ERROR: " << __FILE__ << ":" << __LINE__ << ": " << #expr << " <" << cudaGetErrorName(err) << "> " << cudaGetErrorString(err) << "\\n"; \\ + abort(); \\ + }}\\ + }} while(0) + __device__ __managed__ int something_swapped; + __device__ static const int8_t nnz_order[{chunk_size}][{n}] = {make_n_m_order_c(order)}; + __global__ void kernel_preprocess({dtype}* dense, {dtype}* all_accs, int* is_taken_all, int* groups_all) {{ + int64_t cs_idx = threadIdx.x + blockDim.x * 
blockIdx.x; + int64_t os_idx = threadIdx.y + blockDim.y * blockIdx.y; + int64_t gs_idx = threadIdx.z + blockDim.z * blockIdx.z; + if (cs_idx >= {chunk_size}) return; + if (os_idx >= {loop_outer_size}) return; + if (gs_idx >= {group_size}) return; + + int* groups = groups_all + os_idx * {chunk_size}; + + {dtype}* accs = all_accs + os_idx * {acc_size}; + {dtype}* dense_base = &dense[os_idx * {dense_loop_outer_stride}]; + int max_nnz_idx = -1; + {dtype} max_val = 1e-3; + for (int16_t nnz_idx = 0; nnz_idx < {chunk_size}; nnz_idx++) {{ + {dtype} blk_acc = 0; + int16_t original_idx = cs_idx * {group_size} + gs_idx; + for (int64_t n_idx = 0; n_idx < {n}; n_idx++) {{ + int16_t m_idx = nnz_order[nnz_idx][n_idx]; + {dtype} abs_val = std::abs(dense_base[m_idx + original_idx * {dense_group_stride}]); + blk_acc += abs_val; + }} + int64_t acc_idx = nnz_idx + original_idx * {chunk_size}; + accs[acc_idx] = blk_acc; + if (blk_acc > max_val) {{ + max_nnz_idx = nnz_idx; + max_val = blk_acc; + }} + }} + int group_idx = {group_size}; + if (max_nnz_idx != -1) {{ + group_idx = atomicAdd(groups + max_nnz_idx, 1); + }} + // is_taken maps index in the original array to the index in new array + int* is_taken = is_taken_all + os_idx * {chunk_size} * {group_size}; + if (group_idx < {group_size}) {{ + is_taken[cs_idx * {group_size} + gs_idx] = max_nnz_idx * {group_size} + group_idx; + }} else {{ + something_swapped = 1; // we didn't find optimal distribution on the first try + is_taken[cs_idx * {group_size} + gs_idx] = -1; + }} + }} + __device__ void vset(volatile int* ptr, int val) {{ *ptr = val; }} + __device__ int vget(volatile int* ptr) {{ return *ptr; }} + __global__ void kernel_preprocess2(int* is_taken_all, int* groups_all) {{ + int64_t cs_idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t os_idx = threadIdx.y + blockDim.y * blockIdx.y; + int64_t gs_idx = threadIdx.z + blockDim.z * blockIdx.z; + if (cs_idx >= {chunk_size}) return; + if (os_idx >= {loop_outer_size}) return; + if (gs_idx >= {group_size}) return; + int* is_taken = is_taken_all + os_idx * {chunk_size} * {group_size}; + int* groups = groups_all + os_idx * {chunk_size}; + if (is_taken[cs_idx * {group_size} + gs_idx] == -1) {{ + for (int16_t nnz_idx = 0; nnz_idx < {chunk_size}; nnz_idx++) {{ + if (vget(groups + nnz_idx) < {group_size}) {{ + int group_idx = atomicAdd(groups + nnz_idx, 1); + if (group_idx < {group_size}) {{ + is_taken[cs_idx * {group_size} + gs_idx] = nnz_idx * {group_size} + group_idx; + break; + }} + }} + }} + }} + }} + __global__ void kernel_postprocess({dtype}* all_accs, int* is_taken_all) {{ + int64_t xgs_idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t os_idx = threadIdx.y + blockDim.y * blockIdx.y; + int64_t ygs_idx = threadIdx.z + blockDim.z * blockIdx.z; + if (xgs_idx >= {Nx} * {group_size}) return; + if (os_idx >= {loop_outer_size}) return; + if (ygs_idx >= {Ny} * {group_size}) return; + + {dtype}* accs = all_accs + os_idx * {acc_size}; + // extract elements from sorted accumulators into resulting tensor + int* is_taken = is_taken_all + os_idx * {chunk_size} * {group_size}; + + int x = xgs_idx / {group_size}; + int l_old_gs_idx = xgs_idx % {group_size}; + + int y = ygs_idx / {group_size}; + int r_old_gs_idx = ygs_idx % {group_size}; + + int l_old_cs_idx = -1; + int r_old_cs_idx = -1; + + // init l_old_cs_idx and r_old_cs_idx + if ({N} % 2 == 0) {{ + if (y > x) {{ + l_old_cs_idx = ({N} - 1) - x; + r_old_cs_idx = ({N} - 1) - y; + }} else {{ + l_old_cs_idx = x + 1; + r_old_cs_idx = y; + }} + }} else {{ + if (y >= x) {{ + 
l_old_cs_idx = ({N} - 1) - x; + r_old_cs_idx = ({N} - 2) - y; + }} else {{ + l_old_cs_idx = x; + r_old_cs_idx = y; + }} + }} + + for (int rep = 0; rep < 1; rep++) {{ + + int l_old_idx = l_old_cs_idx * {group_size} + l_old_gs_idx; + int r_old_idx = r_old_cs_idx * {group_size} + r_old_gs_idx; + + if (l_old_idx > r_old_idx) {{ + // guarantee locking in the uniform order + int tmp = l_old_idx; + l_old_idx = r_old_idx; + r_old_idx = tmp; + }} + + int* l_is_taken = is_taken + l_old_idx; + int* r_is_taken = is_taken + r_old_idx; + + int l_new_idx = vget(l_is_taken); + if (l_new_idx < 0) continue; + int r_new_idx = vget(r_is_taken); + if (r_new_idx < 0) continue; + + int l_new_cs_idx = l_new_idx / {group_size}; + int r_new_cs_idx = r_new_idx / {group_size}; + + {dtype} l_cur_acc = accs[l_new_cs_idx + l_old_idx * {chunk_size}]; + {dtype} r_cur_acc = accs[r_new_cs_idx + r_old_idx * {chunk_size}]; + + {dtype} l_swap_acc = accs[r_new_cs_idx + l_old_idx * {chunk_size}]; + {dtype} r_swap_acc = accs[l_new_cs_idx + r_old_idx * {chunk_size}]; + + if (l_swap_acc + r_swap_acc > l_cur_acc + r_cur_acc) {{ + if (l_new_idx != atomicCAS(l_is_taken, l_new_idx, -1)) continue; + if (r_new_idx != atomicCAS(r_is_taken, r_new_idx, l_new_idx)) {{ + vset(l_is_taken, l_new_idx); // rollback + continue; + }} + vset(l_is_taken, r_new_idx); + something_swapped = 1; + }} + + }} + }} + __global__ void kernel_postprocess2({dtype}* dense, {dtype}* sparse, int16_t* idx, int* is_taken_all) {{ + int64_t cs_idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t os_idx = threadIdx.y + blockDim.y * blockIdx.y; + int64_t gs_idx = threadIdx.z + blockDim.z * blockIdx.z; + if (cs_idx >= {chunk_size}) return; + if (os_idx >= {loop_outer_size}) return; + if (gs_idx >= {group_size}) return; + + int* is_taken = is_taken_all + os_idx * {chunk_size} * {group_size}; + int16_t* idx_base = &idx[os_idx * {idx_loop_outer_stride}]; + {dtype}* dense_base = &dense[os_idx * {dense_loop_outer_stride}]; + {dtype}* sparse_base = &sparse[os_idx * {sparse_loop_outer_stride}]; + + //printf("Applying final permutation...\\n"); + int old_idx = cs_idx * {group_size} + gs_idx; + int new_idx = is_taken[old_idx]; + + int new_cs_idx = new_idx / {group_size}; + + // put into the output + idx_base[new_idx * {idx_group_stride}] = old_idx; + for (int64_t n_idx = 0; n_idx < {n}; n_idx++) {{ + int64_t m_idx = nnz_order[new_cs_idx][n_idx]; + sparse_base[n_idx + new_idx * {sparse_group_stride}] = dense_base[m_idx + old_idx * {dense_group_stride}]; + }} + }} + extern "C" void func({dtype}* dense, {dtype}* sparse, int16_t* idx, int* is_taken, {dtype}* accs, int* groups) {{ + something_swapped = 0; + kernel_preprocess<<>>(dense, accs, is_taken, groups); + CUDA_CHECK(cudaPeekAtLastError()); + CUDA_CHECK(cudaStreamSynchronize(0)); + + if (something_swapped) {{ + kernel_preprocess2<<>>(is_taken, groups); + CUDA_CHECK(cudaPeekAtLastError()); + }} + + while (something_swapped) {{ + something_swapped = 0; + kernel_postprocess<<>>(accs, is_taken); + CUDA_CHECK(cudaPeekAtLastError()); + CUDA_CHECK(cudaStreamSynchronize(0)); + }} + + kernel_postprocess2<<>>(dense, sparse, idx, is_taken); + CUDA_CHECK(cudaPeekAtLastError()); + CUDA_CHECK(cudaStreamSynchronize(0)); + }} + """, + lang="cu", + # opts=["-G"], + # opts=['--expt-relaxed-constexpr', '--expt-extended-lambda'] + ) + lib.func.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ] + return lib.func + + +def dense_to_grouped_n_m_merged(inp, n): + 
impl_bulder = ( + get_dense_to_grouped_nm_impl_cpu + if inp.device.type == "cpu" + else get_dense_to_grouped_nm_impl_cuda + ) + func = impl_bulder( + inp.dtype, + inp.shape, + n, + ) + loop_outer_size = inp.shape[0] + chunk_size = inp.shape[1] + group_size = inp.shape[2] + sparse_val_shape = (*inp.shape[:-1], n) + sparse_idx_shape = (*inp.shape[:-1], 1) + out = torch.zeros(sparse_val_shape, dtype=inp.dtype, device=inp.device) + idx = torch.full( + sparse_idx_shape, fill_value=-1, dtype=torch.int16, device=inp.device + ) + cont_inp = inp.contiguous() + if inp.device.type == "cpu": + func(cont_inp.data_ptr(), out.data_ptr(), idx.data_ptr()) + else: + is_taken = torch.empty( + loop_outer_size * chunk_size * group_size, + dtype=torch.int, + device=inp.device, + ) + accs = torch.empty( + loop_outer_size * chunk_size * chunk_size * group_size, + dtype=inp.dtype, + device=inp.device, + ) + groups = torch.zeros( + loop_outer_size * chunk_size, dtype=torch.int, device=inp.device + ) + with torch.cuda.device(inp.device): + # t1 = time.time() + func( + cont_inp.data_ptr(), + out.data_ptr(), + idx.data_ptr(), + is_taken.data_ptr(), + accs.data_ptr(), + groups.data_ptr(), + ) + # t2 = time.time() + # print(f"dense->sparse kernel time {t2-t1:.2f}") + return out, idx + + +def pad_to(tensor, new_shape): + padding = [(0, p - s) for s, p in zip(tensor.shape, new_shape)] + padding = [elem for pair in reversed(padding) for elem in pair] + return torch.nn.functional.pad(tensor, padding) + + +def unpad_to(tensor, new_shape): + assert tensor.ndim == len(new_shape) + for d, l in enumerate(new_shape): + tensor = tensor.narrow(dim=d, start=0, length=l) + return tensor + + +def dense_to_grouped_n_m(tensor, n, m, sparse_dim, group_size, group_dim): + nm_strides = compute_nm_strides( + tensor.shape, n, m, sparse_dim, group_dim, group_size + ) + + padded = pad_to(tensor, nm_strides["padded_dense_shape"]) + expanded = padded.reshape(nm_strides["expanded_dense_shape"]) + permuted = expanded.permute(nm_strides["dim_permutation"]) + merged = permuted.reshape(nm_strides["merged_shape"]) + return (*dense_to_grouped_n_m_merged(merged, n), nm_strides) + + +@cache +def get_grouped_n_m_to_dense_impl_cpu(dense_dtype, val_shape, m): + assert len(val_shape) == 4 # (batch_dim, chunk=m!/n!/(m-n)!, group, block=n) + chunk_size = val_shape[1] + group_size = val_shape[2] + n = val_shape[3] + assert chunk_size == fact(m) // fact(n) // fact(m - n) + + order = make_n_m(n, m) + loop_outer_size = val_shape[0] + + dense_loop_outer_stride = chunk_size * group_size * m + sparse_loop_outer_stride = chunk_size * group_size * n + idx_loop_outer_stride = chunk_size * group_size * 1 + + dense_group_stride = m + sparse_group_stride = n + idx_group_stride = 1 + + assert dense_dtype in (torch.float32, torch.float64) + dtype = "float" if dense_dtype == torch.float32 else "double" + chunk_size = len(order) + impl = compile( + f""" + #include + #include + #include + #include + static const int8_t nnz_order[{chunk_size}][{n}] = {make_n_m_order_c(order)}; + extern "C" void func({dtype}* sparse, {dtype}* dense, int16_t* idx) {{ + for (int64_t os_idx = 0; os_idx < {loop_outer_size}; os_idx++) {{ + {dtype}* dense_base = &dense[os_idx * {dense_loop_outer_stride}]; + {dtype}* sparse_base = &sparse[os_idx * {sparse_loop_outer_stride}]; + int16_t* idx_base = &idx[os_idx * {idx_loop_outer_stride}]; + for (int64_t cs_idx = 0; cs_idx < {chunk_size}; cs_idx++) {{ + for (int64_t gs_idx = 0; gs_idx < {group_size}; gs_idx++) {{ + int16_t sparse_idx = (gs_idx + cs_idx * 
{group_size}); + int16_t original_idx = idx_base[sparse_idx * {idx_group_stride}]; + for (int64_t n_idx = 0; n_idx < {n}; n_idx++) {{ + int64_t m_idx = nnz_order[cs_idx][n_idx]; + dense_base[m_idx + original_idx * {dense_group_stride}] = sparse_base[n_idx + sparse_idx * {sparse_group_stride}]; + }} + }} + }} + }} + }} + """, + ) + impl.func.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ] + return impl.func + + +@cache +def get_grouped_n_m_to_dense_impl_cuda(dense_dtype, val_shape, m): + assert len(val_shape) == 4 # (batch_dim, chunk=m!/n!/(m-n)!, group, block=n) + chunk_size = val_shape[1] + group_size = val_shape[2] + n = val_shape[3] + assert chunk_size == fact(m) // fact(n) // fact(m - n) + + order = make_n_m(n, m) + loop_outer_size = val_shape[0] + + dense_loop_outer_stride = chunk_size * group_size * m + sparse_loop_outer_stride = chunk_size * group_size * n + idx_loop_outer_stride = chunk_size * group_size * 1 + + dense_group_stride = m + sparse_group_stride = n + idx_group_stride = 1 + + assert dense_dtype in (torch.float32, torch.float64) + dtype = "float" if dense_dtype == torch.float32 else "double" + chunk_size = len(order) + + block_x = 8 + block_y = 8 + block_z = 4 + + grid_x = int_ceil(loop_outer_size, block_x) + grid_y = int_ceil(chunk_size, block_y) + grid_z = int_ceil(group_size, block_z) + + impl = compile( + f""" + #include + #include + #include + #include + #include + #define CUDA_CHECK(expr) do {{\\ + cudaError_t err = (expr);\\ + if (err != cudaSuccess) {{\\ + std::cerr << "CUDA ERROR: " << __FILE__ << ":" << __LINE__ << ": " << #expr << " <" << cudaGetErrorName(err) << "> " << cudaGetErrorString(err) << "\\n"; \\ + abort(); \\ + }}\\ + }} while(0) + __device__ static const int8_t nnz_order[{chunk_size}][{n}] = {make_n_m_order_c(order)}; + __global__ void kernel({dtype}* sparse, {dtype}* dense, int16_t* idx) {{ + int64_t os_idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t cs_idx = threadIdx.y + blockDim.y * blockIdx.y; + int64_t gs_idx = threadIdx.z + blockDim.z * blockIdx.z; + if (os_idx >= {loop_outer_size}) return; + if (cs_idx >= {chunk_size}) return; + if (gs_idx >= {group_size}) return; + {dtype}* dense_base = &dense[os_idx * {dense_loop_outer_stride}]; + {dtype}* sparse_base = &sparse[os_idx * {sparse_loop_outer_stride}]; + int16_t* idx_base = &idx[os_idx * {idx_loop_outer_stride}]; + int16_t sparse_idx = (gs_idx + cs_idx * {group_size}); + int16_t original_idx = idx_base[sparse_idx * {idx_group_stride}]; + for (int64_t n_idx = 0; n_idx < {n}; n_idx++) {{ + int64_t m_idx = nnz_order[cs_idx][n_idx]; + dense_base[m_idx + original_idx * {dense_group_stride}] = sparse_base[n_idx + sparse_idx * {sparse_group_stride}]; + }} + }} + extern "C" void func({dtype}* sparse, {dtype}* dense, int16_t* idx) {{ + kernel<<>>(sparse, dense, idx); + CUDA_CHECK(cudaPeekAtLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + }} + """, + lang="cu", + ) + impl.func.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ] + return impl.func + + +def grouped_n_m_to_dense(nm_strides, val, idx): + impl_bulder = ( + get_grouped_n_m_to_dense_impl_cpu + if val.device.type == "cpu" + else get_grouped_n_m_to_dense_impl_cuda + ) + func = impl_bulder( + val.dtype, + val.shape, + nm_strides["m"], + ) + + out = torch.zeros(nm_strides["merged_shape"], dtype=val.dtype, device=val.device) + if val.device.type == "cpu": + func(val.data_ptr(), out.data_ptr(), idx.data_ptr()) + else: + with torch.cuda.device(val.device): + # t1 = time.time() + 
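# the generated converter ends with cudaDeviceSynchronize(), so the call below
+            # returns only after the kernel has finished +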
func(val.data_ptr(), out.data_ptr(), idx.data_ptr()) + # t2 = time.time() + # print(f"sparse->dense kernel time {t2-t1:.2f}") + unmerged = out.reshape(nm_strides["permuted_shape"]) + unpermuted = unmerged.permute(nm_strides["inverse_dim_permutation"]) + unexpanded = unpermuted.reshape(nm_strides["padded_dense_shape"]) + unpadded = unpad_to(unexpanded, nm_strides["dense_shape"]) + + return unpadded + + +class GroupedNMTensor: + def __init__(self, val, idx, nm_strides): + self.val = val + self.idx = idx + self.nm_strides = nm_strides + + @staticmethod + def from_dense(tensor, n, m, sparse_dim, group_size, group_dim): + val, idx, nm_strides = dense_to_grouped_n_m( + tensor, n, m, sparse_dim, group_size, group_dim + ) + ten = GroupedNMTensor(val, idx, nm_strides) + ten.set_spmm_opt() + return ten + + def set_spmm_opt(self, tile_a=None, acc_width=None, tile_b=None, kernel=None): + flags = cpuinfo.get_cpu_info()["flags"] + if "avx512f" in flags: + kernel_autodetect = "avx512" + else: + kernel_autodetect = "avx2" + + self.nm_strides["tile_a"] = tile_a or 3 + self.nm_strides["acc_width"] = acc_width or 4 + self.nm_strides["tile_b"] = tile_b or 16 + self.nm_strides["kernel"] = kernel or kernel_autodetect + + def clone_layout(self, dense_tensor): + raise NotImplementedError( + "This function doesn't support proper device handling and should be removed" + ) + # copies format and locations of nonzeros + if dense_tensor.shape != self.nm_strides["dense_shape"]: + raise ValueError( + f"Shape mismatch: expected {self.nm_strides['dense_shape']}, received {dense_tensor.shape}" + ) + + exp_shape = self.nm_strides["expanded_dense_shape"] + exp_group_dim = self.nm_strides["exp_group_dim"] + exp_sparse_dim = self.nm_strides["exp_sparse_dim"] + + # step 1: reorder values but keep them non-sparsified + padded_dense = pad_to(dense_tensor, self.nm_strides["padded_dense_shape"]) + exp_indexed_shape = ( + *exp_shape[: exp_group_dim + 1], + exp_shape[exp_group_dim + 1] * exp_shape[exp_group_dim + 2], + *exp_shape[exp_group_dim + 3 :], + ) + broadcasted_idx = self.idx.to(device=dense_tensor.device).expand(exp_shape) + padded_dense_for_reordering = padded_dense.reshape(exp_indexed_shape) + broadcasted_idx_for_reordering = broadcasted_idx.reshape(exp_indexed_shape).to( + torch.int64 + ) + reordered_dense = padded_dense_for_reordering.gather( + dim=exp_group_dim + 1, index=broadcasted_idx_for_reordering + ) + reordered_dense = reordered_dense.reshape(exp_shape) + # step 2: drop values to match the nonzero mask + order = torch.tensor(self.nm_strides["order"], device=dense_tensor.device) + if exp_sparse_dim < exp_group_dim: + order = order.t() + singular_idx_shape = [1 for _ in exp_shape] + singular_idx_shape[exp_group_dim + 1] = self.nm_strides["chunk_size"] + singular_idx_shape[exp_sparse_dim + 1] = self.nm_strides["n"] + sparse_indices = order.reshape(singular_idx_shape).expand_as( + self.val.to(device=dense_tensor.device) + ) + reordered_sparsified = reordered_dense.gather( + dim=exp_sparse_dim + 1, index=sparse_indices + ) + return GroupedNMTensor( + reordered_sparsified.to(device="cpu"), + self.idx.to(device="cpu"), + self.nm_strides, + ) + + def to_dense(self): + return grouped_n_m_to_dense( + self.nm_strides, + self.val, + self.idx, + ) + + +class FixedMaskTensor: + def __init__(self, val, mask, n, m, g): + assert torch.all( + torch.isclose(mask, torch.zeros_like(mask)) + | torch.isclose(mask, torch.ones_like(mask)) + ) + self.val = val + self.mask = mask + self.n = n + self.m = m + self.g = g + + @staticmethod + 
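# from_dense derives the mask by thresholding abs(x) < 1e-6; n, m, g appear to be
+    # kept as metadata only; the n:m:g pattern itself is not re-validated here +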
def from_dense(tensor, n, m, g): + mask = torch.where( + tensor.abs() < 1e-6, + torch.zeros_like(tensor, dtype=torch.bool), + torch.ones_like(tensor, dtype=torch.bool), + ) + return FixedMaskTensor(tensor * mask, mask, n, m, g) + + def to_dense(self): + return copy.deepcopy(self.val) + + def numel(self): + return self.val.numel() + + def to(self, device=None, dtype=None, non_blocking=False, copy=False): + return FixedMaskTensor( + self.val.to(device=device, dtype=dtype, copy=True), + self.mask.to(device=device, dtype=dtype, copy=True), + self.n, + self.m, + self.g, + ) + + @property + def shape(self): + return self.val.shape + + @property + def device(self): + return self.val.device + + @property + def dtype(self): + return self.val.dtype + + +class PerfectNMTensor: + def __init__(self, val, idx, n, m, sparse_dim, sparse_dim_size): + self.val = val + self.idx = idx + self.n = n + self.m = m + self.sparse_dim = sparse_dim + self.sparse_dim_size = sparse_dim_size + + @staticmethod + def from_dense(dense_tensor, n, m, sparse_dim): + num_blocks = math.ceil(dense_tensor.shape[sparse_dim] / m) + padded_sparse_dim = num_blocks * m + padded_shape = [ + (padded_sparse_dim if idx == sparse_dim else dim) + for idx, dim in enumerate(dense_tensor.shape) + ] + padded_tensor = pad_to(dense_tensor, padded_shape) + expanded_shape = ( + dense_tensor.shape[:sparse_dim] + + (num_blocks, m) + + dense_tensor.shape[sparse_dim + 1 :] + ) + expanded_tensor = padded_tensor.reshape(expanded_shape) + sorted_indices = expanded_tensor.abs().argsort( + dim=sparse_dim + 1, descending=True + ) + sorted_vals = expanded_tensor.gather(dim=sparse_dim + 1, index=sorted_indices) + sparse_vals = sorted_vals.narrow(dim=sparse_dim + 1, start=0, length=n) + sparse_indices = sorted_indices.narrow(dim=sparse_dim + 1, start=0, length=n) + return PerfectNMTensor( + sparse_vals, + sparse_indices, + n, + m, + sparse_dim, + dense_tensor.shape[sparse_dim], + ) + + def to_dense(self): + exp_shape = self.val.shape + exp_shape = ( + exp_shape[: self.sparse_dim + 1] + + (self.m,) + + exp_shape[self.sparse_dim + 2 :] + ) + dense = torch.zeros(exp_shape, dtype=self.val.dtype, device=self.val.device) + dense.scatter_(dim=self.sparse_dim + 1, index=self.idx, src=self.val) + padded_shape = ( + exp_shape[: self.sparse_dim] + + (exp_shape[self.sparse_dim] * self.m,) + + exp_shape[self.sparse_dim + 2 :] + ) + unpadded = dense.reshape(padded_shape).narrow( + dim=self.sparse_dim, start=0, length=self.sparse_dim_size + ) + return unpadded + + +def is_correct_nm(original_dense, sparsified_dense, sparse_dim, n, m): + is_sparsified = (original_dense == sparsified_dense) | ( + sparsified_dense == torch.zeros_like(sparsified_dense) + ) + if not is_sparsified.all(): + return False + shape = original_dense.shape + padded_shape = ( + shape[:sparse_dim] + + (math.ceil(shape[sparse_dim] / m) * m,) + + shape[sparse_dim + 1 :] + ) + padded_sparsified = pad_to(sparsified_dense, padded_shape) + expanded_shape = ( + shape[:sparse_dim] + + (math.ceil(shape[sparse_dim] / m), m) + + shape[sparse_dim + 1 :] + ) + expanded_sparsified = padded_sparsified.reshape(expanded_shape) + nnz_per_block = expanded_sparsified.bool().sum(dim=sparse_dim + 1) + if not (nnz_per_block <= n).all(): + return False + return True diff --git a/src/sten/grouped_nm/matmul_generator.py b/src/sten/grouped_nm/matmul_generator.py new file mode 100644 index 0000000..0a31d3a --- /dev/null +++ b/src/sten/grouped_nm/matmul_generator.py @@ -0,0 +1,1659 @@ +#!/usr/bin/env python + +from textwrap import 
dedent +import subprocess +import jinja2 +import itertools +import hashlib +import json +from math import factorial as fact +import math +import numpy as np +import sys +import pathlib +import argparse + + +def strides_from_shape(shape, custom={}): + result = [custom[len(shape) - 1] if (len(shape) - 1) in custom else "1"] + for i in reversed(range(0, len(shape) - 1)): + current_stride = ( + custom[i] if (i in custom) else (shape[i + 1] + " * " + result[0]) + ) + result.insert(0, current_stride) + return result + + +def size_from_shape(shape): + return " * ".join(shape) + + +def linear_index(index, strides): + assert len(index) == len(strides) + return " + ".join(f"{i} * {s}" for i, s in zip(reversed(index), reversed(strides))) + + +def modify_access(access, replacements): + for k, v in replacements.items(): + access = access.replace(k, str(v)) + return access + + +def is_increasing(l): + return all([a < b for a, b in zip(l[:-1], l[1:])]) + + +def is_adjacent(n1, n2): + return sum([i1 != i2 for i1, i2 in zip(n1, n2)]) == 1 + + +def find_path(sources, remaining, nodes): + # nodes - list of available nodes + # sources - indices of nodes which can be taken first + # remaining - indices of nodes that can be used as targets + targets_per_src = {} + for src in sources: + next_rem = remaining - {src} + targets = set(tgt for tgt in next_rem if is_adjacent(nodes[src], nodes[tgt])) + targets_per_src[src] = targets + # heuristic: first try sources with the most targets + targets_per_src = { + k: v for k, v in sorted(targets_per_src.items(), key=lambda x: len(x[1])) + } + for src, targets in targets_per_src.items(): + next_rem = remaining - {src} + for next_path in find_path(targets, next_rem, nodes): + yield [nodes[src]] + next_path + if not remaining: + yield [] + + +def make_blk_to_idx_list(m, n): + assert m < n + range_1 = list(range(n)) + range_n = [range_1[:] for _ in range(m)] + nodes = [idx for idx in itertools.product(*range_n) if is_increasing(idx)] + path = next(find_path(set(range(len(nodes))), set(range(len(nodes))), nodes)) + return path + + +def infer_strides(shape, stride): + assert len(shape) == len(stride) + # replaces None in stride assuming that current stride is equal to the current dimension + stride = list(stride) + for i in range(len(shape) - 1, -1, -1): + if stride[i] is None: + if i == len(shape) - 1: + stride[i] = "1" + else: + stride[i] = f"{shape[i + 1]} * {stride[i + 1]}" + return stride + + +def infer_reordered_shape(dims, pos_to_idx): + n = len(dims) + strides = [None for _ in range(n)] + idx_to_pos = [None for i in range(n)] + for pos, idx in enumerate(pos_to_idx): + idx_to_pos[idx] = pos + for idx in range(n - 1, -1, -1): + pos = idx_to_pos[idx] + if idx == n - 1: + strides[pos] = "1" + else: + pos_next = idx_to_pos[idx + 1] + strides[pos] = f"{dims[pos_next]} * {strides[pos_next]}" + return list(zip(dims, strides)) + + +class Array: + def __init__(self, dtype, name, shape, align=None, base_offset=0): + self.dtype = dtype + self.name = name + shape = [(x if isinstance(x, tuple) else (x, None)) for x in shape] + self.dims, self.strides = zip(*shape) + self.strides = infer_strides(self.dims, self.strides) + self.align = align + self.base_offset = base_offset + + def __getitem__(self, key): + if isinstance(key, (int, str)): + key = (key,) + if isinstance(key, tuple): + if all(isinstance(k, (str, int)) for k in key): + if "" in key: + return self.subarray([(k if k != "" else None) for k in key]) + else: + return self.access(*key) + raise IndexError(f"Can't get the item 
from array by key {key}") + + def __call__(self, *key): + assert all(isinstance(k, tuple) for k in key) + return self.tiled([(k if k != () else None) for k in key]) + + def size(self): + return math.prod(self.dims) + + def decl(self): + if self.align is None: + self.align = 1 + align_str = f" __attribute__ ((aligned ({self.align})))" if self.align else "" + return f"{self.dtype} {self.name}[{self.size()}]{align_str};" + return result + + def lin_idx(self, index): + assert all([d != 0 for d in self.dims]) + if len(index) != len(self.strides): + raise ValueError("Shape mismatch") + res = " + ".join([f"{s} * {i}" for s, i in zip(self.strides, index)]) + res = f"{res} + {self.base_offset}" + return res + + def access(self, *index): + assert all([d != 0 for d in self.dims]) + return f"{self.name}[{self.lin_idx(index)}]" + + def subarray(self, index): + assert all([d != 0 for d in self.dims]) + # index = ('dn2', 'dn3', 0, 'dk23_in_B', None, 0) + assert len(index) == len(self.strides) + partial_idx_list = [ + (i, s) for i, s in zip(index, self.strides) if i is not None + ] + partial_idx = " + ".join([f"{s} * {i}" for i, s in partial_idx_list]) + new_shape = [ + (d, s) for i, d, s in zip(index, self.dims, self.strides) if i is None + ] + return Array( + self.dtype, + self.name, + new_shape, + align=self.align, + base_offset=f"{self.base_offset} + {partial_idx}", + ) + + def materialize(self, name): + subarray_decl = f"{self.dtype}* {name} = &{self.name}[{self.base_offset}];" + new_array = Array( + self.dtype, + name, + zip(self.dims, self.strides), + align=self.align, + base_offset=0, + ) + return subarray_decl, new_array + + def tiled(self, tiles): + # original dim [M, N, K] + # tiles [(M1, M2), (N1, N2, N3), None] + assert len(tiles) == len(self.strides) + new_shape = [] + for tile_shapes, last_stride, original_dim in zip( + tiles, self.strides, self.dims + ): + if tile_shapes is None: + tile_shapes = (original_dim,) + tile_strides = [None for _ in tile_shapes] + tile_strides[-1] = last_stride + tile_strides = infer_strides(tile_shapes, tile_strides) + new_shape += [(d, s) for d, s in zip(tile_shapes, tile_strides)] + return Array( + self.dtype, + self.name, + new_shape, + align=self.align, + base_offset=self.base_offset, + ) + + def offset(self, *offsets): + assert len(offsets) == len(self.dims) + new_base_offset = self.lin_idx(offsets) + new_shape = [ + (f"({d} - {o})", s) for d, o, s in zip(self.dims, offsets, self.strides) + ] + return Array( + self.dtype, + self.name, + new_shape, + align=self.align, + base_offset=new_base_offset, + ) + + def indexed_ofset(self, dim_idx, offset): + offsets = [0 for _ in self.dims] + offsets[dim_idx] = offset + return self.offset(*offsets) + + def gtile(self, dim_idx, tile_size): + assert all([d != 0 for d in self.dims]) + N = self.dims[dim_idx] + N2 = tile_size + N1 = N // N2 + N2H = N % N2 + + A_shape = list(zip(self.dims, self.strides)) + A_stride = self.strides[dim_idx] + + AF_shape = ( + A_shape[:dim_idx] + + [(N1, f"{A_stride} * {N2}"), (N2, A_stride)] + + A_shape[dim_idx + 1 :] + ) + AR_shape = ( + A_shape[:dim_idx] + + [(1, f"{A_stride} * {N2}"), (N2H, A_stride)] + + A_shape[dim_idx + 1 :] + ) + + AF = Array( + self.dtype, + self.name, + AF_shape, + align=self.align, + base_offset=self.base_offset, + ) + + AR = self.indexed_ofset(dim_idx, N1 * N2) + AR = Array( + self.dtype, + self.name, + AR_shape, + align=self.align, + base_offset=AR.base_offset, + ) + + # assert all([d != 0 for d in AF.dims]) + # assert all([d != 0 for d in AR.dims]) + + # return 
full array, remainder array, full size, remainder size + return AF, AR, N1, N2H + + def shorten(self, *new_dims): + # new_dims: [None, 10, None, 20, None] + assert len(new_dims) == len(self.dims) + + new_dims = list(new_dims) + for i, (old_dim, new_dim) in enumerate(zip(self.dims, new_dims)): + if new_dim is not None: + assert new_dim <= old_dim + else: + new_dims[i] = old_dim + return Array( + self.dtype, + self.name, + list(zip(new_dims, self.strides)), + align=self.align, + base_offset=self.base_offset, + ) + + +class IdGenerator: + def __init__(self): + self.counters = {} + + def make(self, name): + id = self.counters.get(name, 0) + self.counters[name] = id + 1 + return f"{name}_{id}" + + +def loop_nest(dims, reorder=None): + # dims = [('dn2', 'DN2'), ('dm2', 'DM2') ('dn3', 'DN3')] + # reorder = [0, 2, 1] + if reorder is None: + reorder = list(range(len(dims))) + reordered_dims = [dims[i] for i in reorder] + return " ".join([f"FOR({idx}, {num})" for idx, num in reordered_dims]) + + +def stringify(params): + if isinstance(params, dict): + return "_".join([stringify(k) + "_" + stringify(v) for k, v in params.items()]) + elif isinstance(params, str): + return params + elif isinstance(params, list): + return "_".join([stringify(p) for p in params]) + else: + res = str(params) + if " " in res: + raise ValueError(f"Stringified <{res}> contains not allowed symbols") + return res + + +def make_template(x): + return jinja2.Template( + dedent(x), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + + +def transpose_4x4(A, B): + return make_template( + """\ + __m128 x0 = _mm_loadu_ps(&{{ A.access(0, 0) }}); + __m128 x1 = _mm_loadu_ps(&{{ A.access(1, 0) }}); + __m128 x2 = _mm_loadu_ps(&{{ A.access(2, 0) }}); + __m128 x3 = _mm_loadu_ps(&{{ A.access(3, 0) }}); + __m128 y0 = _mm_unpacklo_ps(x0, x1); + __m128 y1 = _mm_unpackhi_ps(x0, x1); + __m128 y2 = _mm_unpacklo_ps(x2, x3); + __m128 y3 = _mm_unpackhi_ps(x2, x3); + __m128 z0 = _mm_movelh_ps(y0, y2); + __m128 z1 = _mm_movehl_ps(y2, y0); + __m128 z2 = _mm_movelh_ps(y1, y3); + __m128 z3 = _mm_movehl_ps(y3, y1); + _mm_store_ps(&{{ B.access(0, 0) }}, z0); + _mm_store_ps(&{{ B.access(1, 0) }}, z1); + _mm_store_ps(&{{ B.access(2, 0) }}, z2); + _mm_store_ps(&{{ B.access(3, 0) }}, z3); + """ + ).render({**globals(), **locals()}) + + +def dense_gemm_kernel(A1, B1, C, params): + # requires DK2, DK3, DN4, DN5, DM3 + # A1 [DK2, DK3, DM3] + # B1 [DK2, DK3, DN4, DN5] + # C [DM3, DN4, DN5] + DK2 = params[A1.dims[0]] + DK3 = params[A1.dims[1]] + DN4 = params[B1.dims[2]] + DN5 = params[B1.dims[3]] + DM3 = params[A1.dims[2]] + assert DM3 % 2 == 0 + return make_template( + """\ + {% for dm3 in range(DM3) %} + {% for dn4 in range(DN4) %} + __m256 acc_{{ dm3 }}_{{ dn4 }} = _mm256_setzero_ps(); + {% endfor %} + {% endfor %} + FOR (dk2, DK2) { + {% for dk3 in range(DK3) %} + { + {% for dn4 in range(DN4) %} + __m256 vb_{{ dn4 }} = _mm256_loadu_ps(&{{ B1.access('dk2', dk3, dn4, 0) }}); + __m256 vb_0_{{ dn4 }} = _mm256_moveldup_ps(vb_{{ dn4 }}); + __m256 vb_1_{{ dn4 }} = _mm256_movehdup_ps(vb_{{ dn4 }}); + {% endfor %} + {% for dm3 in range(DM3 // 2) %} + { + __m256 va = (__m256)_mm256_broadcast_sd((double*)&{{ A1.access('dk2', dk3, dm3 * 2) }}); + {% for dn4 in range(DN4) %} + acc_{{ dm3 * 2 + 0 }}_{{ dn4 }} = _mm256_fmadd_ps(va, vb_0_{{ dn4 }}, acc_{{ dm3 * 2 + 0 }}_{{ dn4 }}); + acc_{{ dm3 * 2 + 1 }}_{{ dn4 }} = _mm256_fmadd_ps(va, vb_1_{{ dn4 }}, acc_{{ dm3 * 2 + 1 }}_{{ dn4 }}); + {% endfor %} + } + {% endfor %} + } + {% endfor %} + } + {% 
for dm3 in range(DM3 // 2) %} + {% for dn4 in range(DN4) %} + { + float* addr0 = &{{ C.access(dm3 * 2 + 0, dn4, 0) }}; + float* addr1 = &{{ C.access(dm3 * 2 + 1, dn4, 0) }}; + + __m256 x0 = acc_{{ dm3 * 2 + 0 }}_{{ dn4 }}; + __m256 x1 = acc_{{ dm3 * 2 + 1 }}_{{ dn4 }}; + __m256 y0 = _mm256_unpacklo_ps(x0, x1); + __m256 y1 = _mm256_unpackhi_ps(x0, x1); + __m256 z0 = (__m256)_mm256_unpacklo_pd((__m256d) y0, (__m256d) y1); + __m256 z1 = (__m256)_mm256_unpackhi_pd((__m256d) y0, (__m256d) y1); + + _mm256_storeu_ps(addr0, _mm256_add_ps(_mm256_loadu_ps(addr0), z0)); + _mm256_storeu_ps(addr1, _mm256_add_ps(_mm256_loadu_ps(addr1), z1)); + } + {% endfor %} + {% endfor %} + """ + ).render({**globals(), **locals()}) + + +def make_load_mask(m, n): + assert m < n + return ", ".join([("-1" if i < m else "0") for i in range(n)]) + + +def params_hash(params): + name = stringify(params) + return hashlib.sha256(name.encode()).hexdigest()[:7] + + +def benchmark_parametrized_dense_dense(params): + name = stringify(params) + + encoded_name = params_hash(params) + + print(f"Configuration hash: {encoded_name} Params: {json.dumps(params)}") + + signature = f"void dense_dense_{encoded_name}(int64_t M, int64_t N, int64_t K, float* __restrict__ A, int64_t lda, float* __restrict__ B, int64_t ldb, float* __restrict__ C, int64_t ldc)" + + target_M = params["target_M"] + target_N = params["target_N"] + + DK = params["DK"] + DK23 = params["DK2"] * params["DK3"] + DK3 = params["DK3"] + + DM3 = params["DM3"] + + DM4C = 4 + assert DM3 % DM4C == 0 + DM3C = DM3 // DM4C + + DK4C = 4 + DK3C = 1 + # assert (DK23) % (DK3C * DK4C) == 0 + DK2C = (DK23) // (DK3C * DK4C) + + DM2 = params["DM2"] + + DN5 = params["DN5"] + DN4 = params["DN4"] + DN3 = params["DN3"] + DN2 = params["DN2"] + + DM1_ = DM2 * DM3 + DN1_ = DN2 * DN3 * DN4 * DN5 + + DM1 = max(round(target_M * 1.0 / DM1_), 1) + DN1 = max(round(target_N * 1.0 / DN1_), 1) + + M = DM1 * DM2 * DM3 + N = DN1 * DN2 * DN3 * DN4 * DN5 + K = params["DK"] + + loop_order_copy_b = params["loop_order_copy_b"] + loop_order_copy_a = params["loop_order_copy_a"] + loop_order_inner = params["loop_order_inner"] + + B1_layout = params["B1_layout"] + A1_layout = params["A1_layout"] + + id_gen = IdGenerator() + + template = jinja2.Template( + dedent( + """\ + // WARNING! THIS IS GENERATED FILE! DO NOT EDIT! 
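+            // Overview of the generated dense x dense routine (AVX2 sketch): B is packed
+            // panel-by-panel into the B1 scratch buffer, A is packed into A1 while the first
+            // panel is processed, and the inner macro accumulates a DM3 x (DN4 * 8) register
+            // tile with _mm256_fmadd_ps, using masked loads/stores for the N remainder.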
+ + #include "dense_dense_{{ encoded_name }}.h" + + #include + #include + #include + #include + + #define FOR(i, n) for (int64_t i = 0; i < (n); i++) + + {% macro transpose_4x4_(A, B) %} + __m128 x0 = _mm_loadu_ps(&{{ A[0, 0] }}); + __m128 x1 = _mm_loadu_ps(&{{ A[1, 0] }}); + __m128 x2 = _mm_loadu_ps(&{{ A[2, 0] }}); + __m128 x3 = _mm_loadu_ps(&{{ A[3, 0] }}); + __m128 y0 = _mm_unpacklo_ps(x0, x1); + __m128 y1 = _mm_unpackhi_ps(x0, x1); + __m128 y2 = _mm_unpacklo_ps(x2, x3); + __m128 y3 = _mm_unpackhi_ps(x2, x3); + __m128 z0 = _mm_movelh_ps(y0, y2); + __m128 z1 = _mm_movehl_ps(y2, y0); + __m128 z2 = _mm_movelh_ps(y1, y3); + __m128 z3 = _mm_movehl_ps(y3, y1); + _mm_store_ps(&{{ B[0, 0] }}, z0); + _mm_store_ps(&{{ B[1, 0] }}, z1); + _mm_store_ps(&{{ B[2, 0] }}, z2); + _mm_store_ps(&{{ B[3, 0] }}, z3); + {% endmacro %} + + {% macro kernel_acc_mn_vec_(A, B) %}{ + {% set DM, DN, VEC = A.dims[0], B.dims[0], 8 %} + {% set DN1, DN2R = math.ceil(DN / VEC), DN % VEC %} + {% for dn1 in range(DN1) %} + {% if (dn1 < DN1 - 1) or (DN2R == 0) %} + __m256 vb_{{ dn1 }} = _mm256_loadu_ps(&{{ B[dn1 * VEC] }}); + {% else %} + __m256 vb_{{ dn1 }} = _mm256_maskload_ps(&{{ B[dn1 * VEC] }}, load_mask); + {% endif %} + {% endfor %} + {% for dm in range(DM) %}{ + __m256 va = _mm256_broadcast_ss(&{{ A[dm] }}); + {% for dn1 in range(DN1) %}{ + acc_{{ dm }}_{{ dn1 }} = _mm256_fmadd_ps(va, vb_{{ dn1 }}, acc_{{ dm }}_{{ dn1 }}); + }{% endfor %} + }{% endfor %} + }{% endmacro %} + + {% macro dense_gemm_kernel_(A1, B1, C) %}{ + {% set DK23, DN45, DM3, VEC = A1.dims[0], B1.dims[1], A1.dims[1], 8 %} + {% set DN4, DN5R = math.ceil(DN45 / VEC), (DN45 % VEC) %} + // init accumulators + {% for dm3 in range(DM3) %} + {% for dn4 in range(DN4) %} + __m256 acc_{{ dm3 }}_{{ dn4 }} = _mm256_setzero_ps(); + {% endfor %} + {% endfor %} + {% if DN5R != 0 %}__m256i load_mask = _mm256_setr_epi32({{ make_load_mask(DN5R, VEC) }});{% endif %} + {% set A1F, A1R, DK2F, DK3R = A1.gtile(0, params["DK3"]) %} + {% set B1F, B1R, DK2F, DK3R = B1.gtile(0, params["DK3"]) %} + {% set dk2 = id_gen.make("dk2") %} + FOR ({{ dk2 }}, {{ DK2F }}) { + {% for dk3 in range(DK3) %}{ + {{ kernel_acc_mn_vec_(A1F[dk2, dk3, ''], B1F[dk2, dk3, '']) }} + }{% endfor %} + } + {% if DK3R != 0 %}{ + {% for dk3 in range(DK3R) %}{ + {{ kernel_acc_mn_vec_(A1R[0, dk3, ''], B1R[0, dk3, '']) }} + }{% endfor %} + }{% endif %} + {% for dm3 in range(DM3) %} + {% for dn4 in range(DN4) %}{ + float* addr = &{{ C[dm3, dn4 * VEC] }}; + {% if (dn4 < DN4 - 1) or (DN5R == 0) %} + _mm256_storeu_ps(addr, _mm256_add_ps(_mm256_loadu_ps(addr), acc_{{ dm3 }}_{{ dn4 }})); + {% else %} + _mm256_maskstore_ps(addr, load_mask, _mm256_add_ps(_mm256_maskload_ps(addr, load_mask), acc_{{ dm3 }}_{{ dn4 }})); + {% endif %} + }{% endfor %} + {% endfor %} + }{% endmacro %} + + {% macro dense_gemm_kernel_simple_(A1, B1, C) %}{ + {% set DK2, DK3, DN45, DM3 = A1.dims[0], A1.dims[1], B1.dims[2], A1.dims[2] %} + FOR(dk2, {{ DK2 }}) FOR(dk3, {{ DK3 }}) FOR(dm3, {{ DM3 }}) FOR(dn45, {{ DN45 }}) { + {{ C['dm3', 0, 'dn45'] }} += {{ A1['dk2', 'dk3', 'dm3'] }} * {{ B1['dk2', 'dk3', 'dn45'] }}; + } + }{% endmacro %} + + {{ signature }} { + + const int64_t DN5 = {{ DN5 }}; + const int64_t DN4 = {{ DN4 }}; + const int64_t DN3 = {{ DN3 }}; + const int64_t DN2 = {{ DN2 }}; + + const int64_t DM4C = {{ DM4C }}; + const int64_t DM3C = {{ DM3C }}; + + const int64_t DM3 = {{ DM3 }}; + + const int64_t DM2 = {{ DM2 }}; + + const int64_t DK4C = {{ DK4C }}; + const int64_t DK3C = {{ DK3C }}; + const int64_t DK2C = {{ DK2C }}; + + {% 
set A = Array('float', 'A', [(M, 'lda'), K]) %} + + {% set B = Array('float', 'B', [(K, 'ldb'), N]) %} + {% set C = Array('float', 'C', [DM1 * DM2, (DM3, 'ldc'), DN1, DN2 * DN3, DN4 * DN5]) %} + + {% set B1 = Array('float', 'B1', [DN2 * DN3, DK23, DN4 * DN5], align=32) %} + {% set A1 = Array('float', 'A1', [DM1 * DM2, DK23, DM3], align=32) %} + + {{ A1.decl() }} + {{ B1.decl() }} + + {% macro process_dk1_(A, B, C, A1, B1) %} + {% set DK23 = A.dims[1] %} + {% set dn1 = id_gen.make("dn1") %} + FOR ({{ dn1 }}, {{ DN1 }}) { + // copy tile of B into higher level of cache + {% macro copy_B(B, B1) %} + {% set N23, K23, N45 = B1.dims %} + {% set dk23, dn23, dn45 = id_gen.make("dk23"), id_gen.make("dn23"), id_gen.make("dn45") %} + #pragma omp for + FOR({{ dk23 }}, {{ K23 }}) FOR({{ dn23 }}, {{ N23 }}) FOR({{ dn45 }}, {{ N45 }}) { + {{ B1[dn23, dk23, dn45] }} = {{ B((), (N23, N45))[dk23, dn23, dn45] }}; + } + {% endmacro %} + #pragma omp parallel + { + {{ copy_B(B((), (DN1, DN1_))['', dn1, ''], B1) }} + if ({{ dn1 }} == 0) { + {% set dm1 = id_gen.make("dm1") %} + #pragma omp for + FOR ({{ dm1 }}, {{ DM1 }}) { + // copy tile of A into higher level of cache + {% set dm2, dm3, dk23 = id_gen.make("dm2"), id_gen.make("dm3"), id_gen.make("dk23") %} + FOR({{ dm2 }}, {{ DM2 }}) FOR({{ dm3 }}, {{ DM3 }}) FOR({{ dk23 }}, {{ DK23 }}) { + {{ A1((DM1, DM2), (), ())[dm1, dm2, dk23, dm3] }} = {{ A((DM1, DM2, DM3), ())[dm1, dm2, dm3, dk23] }}; + } + // compute + {% set dm2, dn23 = id_gen.make("dm2"), id_gen.make("dn23") %} + FOR({{ dm2 }}, {{ DM2 }}) FOR({{ dn23 }}, {{ DN2 * DN3 }}) { + {{ dense_gemm_kernel_(A1((DM1, DM2), (), ())[dm1, dm2, '', ''], + B1[dn23, '', ''], + C((DM1, DM2), (), (), (), ())[dm1, dm2, '', dn1, dn23, '']) }} + } + } + } else { + // compute + {% set dm12, dn23 = id_gen.make("dm12"), id_gen.make("dn23") %} + #pragma omp for + FOR({{ dm12 }}, {{ DM1 * DM2 }}) FOR({{ dn23 }}, {{ DN2 * DN3 }}) { + {{ dense_gemm_kernel_(A1[dm12, '', ''], + B1[dn23, '', ''], + C[dm12, '', dn1, dn23, '']) }} + } + } + } // pragma omp parallel + } + {% endmacro %} + {% set AF, AR, DK1, DK2R = A.gtile(1, DK23) %} + {% set BF, BR, DK1, DK2R = B.gtile(0, DK23) %} + {% set dk1 = id_gen.make("dk1") %} + FOR({{ dk1 }}, {{ DK1 }}) { + {{ process_dk1_(AF['', dk1, ''], BF[dk1, '', ''], C, A1, B1) }} + } + {% if DK2R != 0 %}{ + {{ process_dk1_(AR['', 0, ''], BR[0, '', ''], C, A1.shorten(None, DK2R, None), B1.shorten(None, DK2R, None)) }} + }{% endif %} + } + """ + ), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + code = template.render({**globals(), **locals()}) + + template = jinja2.Template( + dedent( + """\ + // WARNING! THIS IS GENERATED FILE! DO NOT EDIT! 
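+            // The generated header only re-exports the extern "C" entry point; the full
+            // parameter set used for generation is embedded below as a JSON comment.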
+ + #pragma once + + #include + + #ifdef __cplusplus + extern "C" { + #endif + + /* + Configuration: + {{ json.dumps(params) }} + */ + + {{ signature }}; + + #ifdef __cplusplus + } // extern "C" + #endif + """ + ) + ) + header = template.render({**globals(), **locals()}) + + with open(f"generated/dense_dense_{encoded_name}.c", "w") as f: + f.write(code) + with open(f"generated/dense_dense_{encoded_name}.h", "w") as f: + f.write(header) + + main_template = jinja2.Template( + dedent( + """\ + #include "openblas/cblas.h" + + #include "common.hpp" + + #include "generated/dense_dense_{{ encoded_name }}.h" + + int main() { + int64_t M = {{ DM1 * DM2 * DM3 }}; + int64_t K = {{ DK }}; + int64_t N = {{ DN1 * DN2 * DN3 * DN4 * DN5 }}; + + std::vector dA = rand_gen(M * K); + std::vector dB = rand_gen(K * N); + + std::vector dCd(M * N); + + std::vector dCd_ref(M * N); + + moment t1, t2; + + { + t1 = timer::now(); + cblas_sgemm( + CblasRowMajor, // CBLAS_LAYOUT layout, + CblasNoTrans, // CBLAS_TRANSPOSE TransA, + CblasNoTrans, // CBLAS_TRANSPOSE TransB, + M, // const CBLAS_INDEX M, + N, // const CBLAS_INDEX N, + K, // const CBLAS_INDEX K, + 1.0, // const float alpha, + dA.data(), // const float *A, + K, // const CBLAS_INDEX lda, + dB.data(), // const float *B, + N, // const CBLAS_INDEX ldb, + 0.0, // const float beta, + dCd_ref.data(), // float *C, + N // const CBLAS_INDEX ldc + ); + t2 = timer::now(); + printf("openblas_dense_dense seconds %.3g ns_per_fma %.3g\\n", seconds(t2 - t1), seconds(t2 - t1) / (M * N * K) * 1e9); + } + + { + std::fill(dCd.begin(), dCd.end(), 0.0f); + t1 = timer::now(); + dense_dense_{{ encoded_name }}( + M, // const CBLAS_INDEX M, + N, // const CBLAS_INDEX N, + K, // const CBLAS_INDEX K, + dA.data(), // const float *A, + K, // const CBLAS_INDEX lda, + dB.data(), // const float *B, + N, // const CBLAS_INDEX ldb, + dCd.data(), // float *C, + N // const CBLAS_INDEX ldc + ); + t2 = timer::now(); + printf("my_dense_dense seconds %.3g ns_per_fma %.3g\\n", seconds(t2 - t1), seconds(t2 - t1) / (M * N * K) * 1e9); + CHECK(vector_allclose(dCd, dCd_ref)); + } + } + """ + ), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + main = main_template.render({**globals(), **locals()}) + + with open(f"generated/dense_dense_{encoded_name}_main.cpp", "w") as f: + f.write(main) + + subprocess.run( + f"gcc -march=native -g -O3 -fopenmp -lpthread generated/dense_dense_{encoded_name}.c -c -o generated/dense_dense_{encoded_name}.o".split(), + check=True, + ) + subprocess.run( + f"g++ -march=native -g -O3 -fopenmp -lpthread -I. 
-I./OpenBLAS/install/include generated/dense_dense_{encoded_name}_main.cpp ./OpenBLAS/install/lib/libopenblas.a -lpthread generated/dense_dense_{encoded_name}.o -o generated/dense_dense_{encoded_name}_main".split(), + check=True, + ) + subprocess.run(f"generated/dense_dense_{encoded_name}_main", check=True) + + +def generate_parametrized_sparse_dense(params): + params = derive_params(params) + encoded_name = params_hash(params) + + print( + f"Configuration hash: {encoded_name} Params: {json.dumps(params)}", + file=sys.stderr, + ) + + signature = f"void sparse_dense_{encoded_name}(float* __restrict__ A_val, int16_t* __restrict__ A_idx, float* __restrict__ B, float* __restrict__ C)" + + DK3 = params["DK3"] # + DK2 = ( + fact(params["DM3"]) + // fact(params["DM3S"]) + // fact(params["DM3"] - params["DM3S"]) + ) # number of sparse groups + params["DK2"] = DK2 + + DM3 = params["DM3"] # accumulator tile size (should be 4 for 4:2 pattern) + DM3S = params["DM3S"] # sparse tile size + DM2 = params["DM2"] # cache tile size + + DN5 = params["DN5"] # vector size + DN4 = params["DN4"] # accumulator tile size + DN3 = params["DN3"] # unused for now + DN2 = params["DN2"] # cache tile size + params["DN45"] = DN4 * DN5 + params["DM23S"] = DM2 * DM3S + params["DM23"] = DM2 * DM3 + params["DN2345"] = DN2 * DN3 * DN4 * DN5 + params["DK23"] = DK2 * DK3 + + DN = params["DN"] + DK = params["DK"] + DM = params["DM"] + bti = make_blk_to_idx_list(DM3S, DM3) + + params["DMI_padded"] = math.ceil(DM / DM3) + params["DM_padded"] = params["DMI_padded"] * DM3 + params["DMS_padded"] = params["DMI_padded"] * DM3S + + params["DK_padded"] = math.ceil(DK / (DK2 * DK3)) * (DK2 * DK3) + + bti_str = ( + "{" + + ", ".join( + "{" + ", ".join(str(idx) for idx in idx_tuple) + "}" for idx_tuple in bti + ) + + "}" + ) + + # strides + SAMS = params["SAMS"] + SAKS = params["SAKS"] + SAMI = params["SAMI"] + SAKI = params["SAKI"] + SBK = params["SBK"] + SBN = params["SBN"] + SCM = params["SCM"] + SCN = params["SCN"] + + id_gen = IdGenerator() + + template = jinja2.Template( + dedent( + """\ + // WARNING! THIS IS GENERATED FILE! DO NOT EDIT! 
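+            // Generated sparse (grouped n:m:g) x dense routine. A_val stores the n kept values
+            // of every m-block, A_idx maps each packed k position to its source row of B
+            // (negative entries mark padding in the K remainder and stop the inner loop).
+            // The microkernel broadcasts A values, FMAs them against DN5-wide vectors of B
+            // (8 floats for AVX2) and flushes the accumulators to C whenever the active
+            // nonzero pattern (bti) changes between consecutive chunks.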
+ + #include "sparse_dense_{{ encoded_name }}.h" + + #include + #include + #include + #include + + #define FOR(i, n) for (int64_t i = 0; i < (n); i++) + + {{ signature }} { + + {% set A_val = Array('float', 'A_val', [(params["DMS_padded"], SAMS), (params["DK_padded"], SAKS)]) %} + {% set A_idx = Array('int16_t', 'A_idx', [(params["DMI_padded"], SAMI), (params["DK_padded"], SAKI)]) %} + // A_val.dims {{ A_val.dims }} A_idx.dims {{ A_idx.dims }} + + {% set B = Array('float', 'B', [(DK, SBK), (DN, SBN)]) %} + {% set C = Array('float', 'C', [(DM, SCM), (DN, SCN)]) %} + // C.dims {{ C.dims }} C.strides {{ C.strides }} + + {% set B1 = Array('float', 'B1', [DN2 * DN3, DK2 * DK3, DN4 * DN5], align=32) %} + {% set A1_val = Array('float', 'A1_val', [params["DMS_padded"], DK2 * DK3], align=32) %} + {% set A1_idx = Array('int16_t', 'A1_idx', [params["DMI_padded"], DK2 * DK3], align=32) %} + + {% macro sparse_dense_gemm_kernel_(A1_val, A1_idx, B1, C, in_remainder_DK2) %} + {% set DM3S_padded, DK23_padded = A1_val.dims %} + {% set (DK23_padded,) = A1_idx.dims %} + {% set DK23, DN45 = B1.dims %} + {% set DM3, DN45 = C.dims %} + {% set DK2 = DK23_padded // params['DK3'] %} + {% set DN4, DN5R = math.ceil(DN45 / params['DN5']), DN45 % params['DN5'] %} + // A1_val.dims {{ A1_val.dims }} A1_idx.dims {{ A1_idx.dims }} B1.dims {{ B1.dims }} + // init accumulators + {% for dm3s in range(DM3S_padded) %} + {% for dn4 in range(DN4) %} + __m256 acc_{{ dm3s }}_{{ dn4 }} = _mm256_setzero_ps(); + {% endfor %} + {% endfor %} + {% if DN5R != 0 %}__m256i load_mask = _mm256_setr_epi32({{ make_load_mask(DN5R, DN5) }});{% endif %} + {% for dk2 in range(DK2) %}{ + {% set si = bti[dk2] %} + {% set dk3 = id_gen.make("dk3") %} + {% set in_remainder_DK2 = (DK23_padded != DK23) %} + // in_remainder_DK2 = {{ in_remainder_DK2 }} dk2 {{ dk2 }} DK2 {{ DK2 }} DK23_padded {{ DK23_padded }} DK23 {{ DK23 }} + FOR({{dk3}}, {{DK3}}) { + {% set dk23 = id_gen.make("dk23") %} + int {{dk23}} = {{ dk2 }} * {{ DK3 }} + {{ dk3 }}; + {% set dk23_in_B = id_gen.make("dk23_in_B") %} + int16_t {{dk23_in_B}} = {{ A1_idx[dk23] }}; + {% if in_remainder_DK2 %} + // in dk2 remainder + if ({{dk23_in_B}} < 0) break; + {% endif %} + {%+ set DECL, B2 = B1[dk23_in_B, ''].materialize('B2') %}{{ DECL }} + {% for dm3s in range(DM3S_padded) %} + __m256 va_{{ dm3s }} = _mm256_broadcast_ss(&{{ A1_val[dm3s, dk23] }}); + {% endfor %} + {% for dn4 in range(DN4) %}{ + {% set in_remainder_DN5 = (dn4 == DN4 - 1) and (DN5R != 0) %} + {% if not in_remainder_DN5 %} + __m256 vb_{{ dn4 }} = _mm256_loadu_ps(&{{ B2[dn4 * DN5] }}); + {% else %} + __m256 vb_{{ dn4 }} = _mm256_maskload_ps(&{{ B2[dn4 * DN5] }}, load_mask); + {% endif %} + {% for dm3s in range(DM3S_padded) %} + acc_{{ dm3s }}_{{ dn4 }} = _mm256_fmadd_ps(va_{{ dm3s }}, vb_{{ dn4 }}, acc_{{ dm3s }}_{{ dn4 }}); + {% endfor %} + }{% endfor %} + } + {% for dm3s in range(DM3S_padded) %} + {% if (dk2 == DK2 - 1) or (bti[dk2][dm3s] != bti[dk2+1][dm3s]) %} + {% for dn4 in range(DN4) %}{ + {% set in_remainder_DN5 = (dn4 == DN4 - 1) and (DN5R != 0) %} + {% set dm3 = si[dm3s] %} + {% if dm3 < DM3 %} // cut out remainder DM3 + {% if SCN == 1 %} + float* c_addr = &{{ C[dm3, dn4 * DN5] }}; + {% if not in_remainder_DN5 %} + _mm256_storeu_ps(c_addr, _mm256_add_ps(_mm256_loadu_ps(c_addr), acc_{{ dm3s }}_{{ dn4 }})); + {% else %} + _mm256_maskstore_ps(c_addr, load_mask, _mm256_add_ps(_mm256_maskload_ps(c_addr, load_mask), acc_{{ dm3s }}_{{ dn4 }})); + {% endif %} + {% else %} + {% set dn5 = id_gen.make("dn5") %} + FOR ({{dn5}}, {{DN5R if 
in_remainder_DN5 else DN5}}) { + float* acc_ptr = (float*)&acc_{{ dm3s }}_{{ dn4 }}; + {{ C((), (DN4, DN5))[dm3, dn4, dn5] }} += acc_ptr[{{dn5}}]; + } + {% endif %} + {% endif %} + {% if (dk2 != DK2 - 1) %} + acc_{{ dm3s }}_{{ dn4 }} = _mm256_setzero_ps(); + {% endif %} + }{% endfor %} + {% endif %} + {% endfor %} + }{% endfor %} + {% endmacro %} + + {% macro copy_A(A1, A) %} + {% set DM2, DK23 = A.dims %} + // copy tile of A into higher level of cache + {% set dm2 = id_gen.make("dm2") %} + {% set dk23 = id_gen.make("dk23") %} + FOR({{dm2}}, {{DM2}}) FOR({{dk23}}, {{DK23}}) { + {{ A1[dm2, dk23] }} = {{ A[dm2, dk23] }}; + } + {% endmacro %} + + {% macro copy_B(B1, B) %} + {% set DN23A, DK23_padded, DN45 = B1.dims %} + {% set DK23, DN2345 = B.dims %} + {% set BF, BR, DN23, DN45R = B.gtile(1, DN45) %} + // B1.dims {{ B1.dims }} B.dims {{ B.dims }} + {% set dk23 = id_gen.make("dk23") %} + {% set dn23 = id_gen.make("dn23") %} + {% set dn45 = id_gen.make("dn45") %} + #pragma omp for + FOR({{dk23}}, {{DK23}}) { + {% if DN23 != 0 %} + FOR({{dn23}}, {{DN23}}) { + FOR({{dn45}}, {{DN45}}) { + {{ B1[dn23, dk23, dn45] }} = {{ BF[dk23, dn23, dn45] }}; + } + } + {% endif %} + {% if DN45R != 0 %} + FOR({{dn45}}, {{DN45R}}) { + {{ B1[DN23, dk23, dn45] }} = {{ BR[dk23, 0, dn45] }}; + } + {% endif %} + } + {% endmacro %} + + {% macro compute_loop_N_(A1_val, A1_idx, B1, C) %} + {% set DM3S, DK23 = A1_val.dims %} + {% set (DK23,) = A1_idx.dims %} + {% set DN23, DK23, DN45 = B1.dims %} + {% set DM3, DN2345 = C.dims %} + {% set CF, CR, DN23, DN45R = C.gtile(1, params['DN45']) %} + {% set dn23 = id_gen.make("dn23") %} + {% if DN23 != 0 %} + FOR({{dn23}}, {{DN23}}) { + {{ sparse_dense_gemm_kernel_(A1_val, A1_idx, B1[dn23, '', ''], CF['', dn23, '']) }} + } + {% endif %} + {% if DN45R != 0 %}{ + {{ sparse_dense_gemm_kernel_(A1_val, A1_idx, B1[DN23, '', ''], CR['', 0, '']) }} + }{% endif %} + {% endmacro %} + + {% macro compute_loop_(A1_val, A1_idx, B1, C, parallel) %} + {% set DM2A, DK23 = A1_idx.dims %} + {% set DM23SA, DK23 = A1_val.dims %} + {% set DM3S = DM23SA // DM2A %} // never has remainder + {% set DN23, DK23, DN45 = B1.dims %} + {% set DM23, DN2345 = C.dims %} + + {% set CF, CR, DM2, DM3R = C.gtile(0, params['DM3']) %} + + {% set dm2 = id_gen.make("dm2") %} + {% if DM2 != 0 %}{ + {% if parallel %} + #pragma omp for nowait + {% endif %} + FOR({{dm2}}, {{DM2}}) { + {{ compute_loop_N_(A1_val((DM2A, DM3S), ())[dm2, '', ''], A1_idx[dm2, ''], B1, CF[dm2, '', '']) }} + } + }{% endif %} + {% if DM3R != 0 %}{ + {% if parallel %} + #pragma omp single nowait + {% endif %} + { + {{ compute_loop_N_(A1_val((DM2A, DM3S), ())[DM2, '', ''], A1_idx[DM2, ''], B1, CR[0, '', '']) }} + } + }{% endif %} + {% endmacro %} + + {% macro compute_head_(A1_val, A1_idx, A_val, A_idx, B1, C) %} + {% set DM123S, DK23 = A1_val.dims %} + {% set DM12, DK23 = A1_idx.dims %} + // A1_val.dims {{ A1_val.dims }} A1_idx.dims {{ A1_idx.dims }} + {% set CF, CR, DM1, DM23R = C.gtile(0, params["DM2"] * params["DM3"]) %} + {% set DM1A = DM1 + (DM23R != 0) %} + {% set A1_val_exp = A1_val((DM1A, params["DM2"] * params["DM3S"]), ()) %} + {% set A1_idx_exp = A1_idx((DM1A, params["DM2"]), ()) %} + {% set A_val_exp = A_val((DM1A, params["DM2"] * params["DM3S"]), ()) %} + {% set A_idx_exp = A_idx((DM1A, params["DM2"]), ()) %} + // CF.dims {{ CF.dims }} CR.dims {{ CR.dims }} DM1 {{ DM1 }} DM23R {{ DM23R }} + {% set dm1 = id_gen.make("dm1") %} + {% if DM1 != 0 %}{ + #pragma omp for nowait + FOR ({{dm1}}, {{DM1}}) { + {{ copy_A(A1_val_exp[dm1, '', ''], 
A_val_exp[dm1, '', '']) }} + {{ copy_A(A1_idx_exp[dm1, '', ''], A_idx_exp[dm1, '', '']) }} + {{ compute_loop_( + A1_val_exp[dm1, '', ''], + A1_idx_exp[dm1, '', ''], + B1, + CF[dm1, '', ''], + parallel=False) }} + } + }{% endif %} + {% if DM23R != 0 %}{ + {% set A1_idx_rem = A1_idx_exp[DM1, '', ''].shorten(((DM12 % params["DM2"]) if (DM12 % params["DM2"]) else (params["DM2"])), None) %} + {% set A1_val_rem = A1_val_exp[DM1, '', ''].shorten(A1_idx_rem.dims[0] * params["DM3S"], None) %} + {% set A_idx_rem = A_idx_exp[DM1, '', ''].shorten(((DM12 % params["DM2"]) if (DM12 % params["DM2"]) else (params["DM2"])), None) %} + {% set A_val_rem = A_val_exp[DM1, '', ''].shorten(A_idx_rem.dims[0] * params["DM3S"], None) %} + #pragma omp single nowait + { + {{ copy_A(A1_val_rem, A_val_rem) }} + {{ copy_A(A1_idx_rem, A_idx_rem) }} + {{ compute_loop_( + A1_val_rem, + A1_idx_rem, + B1, + CR[0, '', ''], + parallel=False) }} + } + }{% endif %} + {% endmacro %} + + {% macro compute_head_or_tail_(B1, B, A1_val, A1_idx, A_val, A_idx, C, is_head) %} + {% set DN23A, DK23F, DN45 = B1.dims %} + {% set DK23, DN2345 = B.dims %} + {% set B1 = B1.shorten(None, DK23, None) %} + #pragma omp parallel + { + {{ copy_B(B1, B) }} + {% if is_head %} + {{ compute_head_(A1_val, A1_idx, A_val, A_idx, B1, C) }} + {% else %} + {{ compute_loop_(A1_val, A1_idx, B1, C, parallel=True) }} + {% endif %} + } + {% endmacro %} + + {% macro compute_head_or_tail_2_(B1, B, A1_val, A1_idx, A_val, A_idx, C) %} + {% set BF, BR, DN1, DN2345R = B.gtile(1, params["DN2345"]) %} + {% set CF, CR, DN1, DN2345R = C.gtile(1, params["DN2345"]) %} + {% set dn1 = id_gen.make("dn1") %} + {% if DN1 == 0 and DN2345R != 0 %} + { + int64_t {{dn1}} = 0; + {{ B1.decl() }} + {{ compute_head_or_tail_(B1, BR['', 0, ''], A1_val, A1_idx, A_val, A_idx, CR['', 0, ''], is_head=True) }} + } + {% else %} + {% if DN1 != 0 %} + { + int64_t {{dn1}} = 0; + {{ B1.decl() }} + {{ compute_head_or_tail_(B1, BF['', dn1, ''], A1_val, A1_idx, A_val, A_idx, CF['', dn1, ''], is_head=True) }} + } + for (int64_t {{dn1}} = 1; {{dn1}} < {{DN1}}; {{dn1}}++) { + {{ B1.decl() }} + {{ compute_head_or_tail_(B1, BF['', dn1, ''], A1_val, A1_idx, A_val, A_idx, CF['', dn1, ''], is_head=False) }} + } + {% endif %} + {% if DN2345R != 0 %}{ + {{ B1.decl() }} + {{ compute_head_or_tail_(B1, BR['', 0, ''], A1_val, A1_idx, A_val, A_idx, CR['', 0, ''], is_head=False) }} + }{% endif %} + {% endif %} + {% endmacro %} + + {% set BF, BR, DK1, DK23R = B.gtile(0, params["DK23"]) %} + {% set DK1A = DK1 + (DK23R != 0) %} + + // BF.dims {{ BF.dims }} BR.dims {{ BR.dims }} + + {% set dk1 = id_gen.make("dk1") %} + {% if DK1 != 0 %} + FOR({{dk1}}, {{DK1}}) { + // dk1 main BF.dims {{ BF.dims }} + {{ A1_val.decl() }} + {{ A1_idx.decl() }} + // A1_val.dims {{ A1_val.dims }} A1_idx.dims {{ A1_idx.dims }} + {{ compute_head_or_tail_2_(B1, BF[dk1, '', ''], A1_val, A1_idx, A_val((), (DK1A, params["DK23"]))['', dk1, ''], A_idx((), (DK1A, params["DK23"]))['', dk1, ''], C) }} + } + {% endif %} + {% if DK23R != 0 %} + { + // dk1 remainder BR.dims {{ BR.dims }} + {{ A1_val.decl() }} + {{ A1_idx.decl() }} + {{ compute_head_or_tail_2_(B1, BR[0, '', ''], A1_val, A1_idx, A_val((), (DK1A, params["DK23"]))['', DK1, ''], A_idx((), (DK1A, params["DK23"]))['', DK1, ''], C) }} + } + {% endif %} + } + """ + ), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + code = template.render({**globals(), **locals()}) + + template = jinja2.Template( + dedent( + """\ + // WARNING! THIS IS GENERATED FILE! DO NOT EDIT! 
+ + #pragma once + + #include + + #ifdef __cplusplus + extern "C" { + #endif + + /* + Configuration: + {{ json.dumps(params) }} + */ + + {{ signature }}; + + #ifdef __cplusplus + } // extern "C" + #endif + """ + ) + ) + header = template.render({**globals(), **locals()}) + + pathlib.Path("generated").mkdir(exist_ok=True) + + with open(f"generated/sparse_dense_{encoded_name}.c", "w") as f: + f.write(code) + with open(f"generated/sparse_dense_{encoded_name}.h", "w") as f: + f.write(header) + + main_template = jinja2.Template( + dedent( + """\ + #include "openblas/cblas.h" + + #include "common.hpp" + + #include "generated/sparse_dense_{{ encoded_name }}.h" + + int main() { + int64_t M = {{ DM }}; + int64_t K = {{ DK }}; + int64_t N = {{ DN }}; + int64_t DK3 = {{ DK3 }}; + + std::vector dA = rand_gen(M * K); + std::vector dB = rand_gen(K * N); + + std::vector dAsv; + std::vector dAsi; + std::vector dAsr; + std::tie(dA, dAsv, dAsi, dAsr) = drop_m_n_non_leading(dA, K, DK3, {{ DM3S }}, {{ DM3 }}, {{ bti_str }}); + + int K_padded = {{ params["DK_padded"] }}; + int M_padded = {{ params["DM_padded"] }}; + FAIL_CHECK(dA.size() == K_padded * M_padded); + + std::vector dA_ref = dA; + std::vector dB_ref = dB; + + int trans_a = {{ 1 * params["trans_a"] }}; + int trans_b = {{ 1 * params["trans_b"] }}; + int trans_c = {{ 1 * params["trans_c"] }}; + if (trans_a) { + dA = transpose(dA, {{ params["DK_padded"] }}); + dAsv = transpose(dAsv, {{ params["DK_padded"] }}); + dAsi = transpose(dAsi, {{ params["DK_padded"] }}); + dAsr = transpose(dAsr, {{ params["DK_padded"] }}); + } + if (trans_b) { + dB = transpose(dB, {{ params["DN"] }}); + } + + std::vector dCd(M * N); + + std::vector dCd_ref(M * N); + + moment t1, t2; + + { + t1 = timer::now(); + cblas_sgemm( + CblasRowMajor, // CBLAS_LAYOUT layout, + CblasNoTrans, // CBLAS_TRANSPOSE TransA, + CblasNoTrans, // CBLAS_TRANSPOSE TransB, + M, // const CBLAS_INDEX M, + N, // const CBLAS_INDEX N, + K, // const CBLAS_INDEX K, + 1.0, // const float alpha, + dA_ref.data(), // const float *A, + K_padded, // const CBLAS_INDEX lda, + dB_ref.data(), // const float *B, + N, // const CBLAS_INDEX ldb, + 0.0, // const float beta, + dCd_ref.data(), // float *C, + N // const CBLAS_INDEX ldc + ); + t2 = timer::now(); + printf("openblas_dense_dense seconds %.3g ns_per_fma %.3g\\n", seconds(t2 - t1), seconds(t2 - t1) / (M * N * K) * 1e9); + } + + if (trans_c) { + dCd_ref = transpose(dCd_ref, {{ params["DN"] }}); + } + + { + std::fill(dCd.begin(), dCd.end(), 0.0f); + t1 = timer::now(); + sparse_dense_{{ encoded_name }}( + dAsv.data(), // const float *A, + dAsr.data(), + dB.data(), // const float *B, + dCd.data() // float *C, + ); + t2 = timer::now(); + printf("my_sparse_dense seconds %.3g ns_per_fma %.3g sparsity %.2f pattern %d:%d\\n", seconds(t2 - t1), seconds(t2 - t1) / (M * N * K * {{ DM3S / DM3 }}) * 1e9, {{ 1.0 - DM3S / DM3 }}, {{ DM3S }}, {{ DM3 }}); + CHECK(vector_allclose(dCd, dCd_ref)); + } + } + """ + ), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + main = main_template.render({**globals(), **locals()}) + + with open(f"generated/sparse_dense_{encoded_name}_main.cpp", "w") as f: + f.write(main) + + return encoded_name + + +def generate_grouped_m_n_converter(DK3, DM3, DM3S): + bti = make_blk_to_idx_list(DM3S, DM3) + bti_str = ( + "{" + + ", ".join( + "{" + ", ".join(str(idx) for idx in idx_tuple) + "}" for idx_tuple in bti + ) + + "}" + ) + + converter_template = jinja2.Template( + dedent( + """\ + #include "common.hpp" + + extern "C" { + 
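// Exported with C linkage: converts a DM x DK dense matrix (assumed row-major)
+            // into the grouped layout via drop_m_n_non_leading, writing values to sparse_vals,
+            // int16 indices to sparse_indices and, when the last argument is non-null, the
+            // sparsified dense copy back to sparsified_dense_vals. +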
void dense_to_sparse_GRP{{DK3}}_M{{DM3S}}_N{{DM3}}(float* sparse_vals, int16_t* sparse_indices, float* dense_vals, int64_t DM, int64_t DK, float* sparsified_dense_vals) { + std::vector dA(dense_vals, dense_vals + DM * DK); + std::vector sdA; + std::vector dAsv; + std::vector dAsi; + std::vector dAsr; + std::tie(sdA, dAsv, dAsi, dAsr) = drop_m_n_non_leading(dA, DK, {{ DK3 }}, {{ DM3S }}, {{ DM3 }}, {{ bti_str }}); + std::copy(dAsv.begin(), dAsv.end(), sparse_vals); + std::copy(dAsr.begin(), dAsr.end(), sparse_indices); + if (sparsified_dense_vals) { + std::copy(sdA.begin(), sdA.end(), sparsified_dense_vals); + } + } + } + + """ + ), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) + + with open( + f"generated/sparse_dense_grouped{DK3}_m{DM3S}_n{DM3}_converter.cpp", "w" + ) as f: + f.write(converter_template.render({**globals(), **locals()})) + + +def benchmark_parametrized_sparse_dense(params): + encoded_name = generate_parametrized_sparse_dense(params) + + print("Compilation started...") + subprocess.run( + f"gcc -march=native -g -O3 -fopenmp -lpthread generated/sparse_dense_{encoded_name}.c -c -o generated/sparse_dense_{encoded_name}.o".split(), + check=True, + ) + subprocess.run( + f"g++ -march=native -g -O3 -fopenmp -lpthread -I. -I./OpenBLAS/install/include generated/sparse_dense_{encoded_name}_main.cpp ./OpenBLAS/install/lib/libopenblas.a generated/sparse_dense_{encoded_name}.o -o generated/sparse_dense_{encoded_name}_main".split(), + check=True, + ) + print("Compilation completed") + subprocess.run(f"generated/sparse_dense_{encoded_name}_main", check=True) + + +def derive_params(params_base): + params = dict(params_base) + + # A is always sparse, B is always dense + params["DMI_padded"] = math.ceil(params["DM"] / params["DM3"]) + params["DMS_padded"] = params["DMI_padded"] * params["DM3S"] + params["DM_padded"] = params["DMI_padded"] * params["DM3"] + params["DK2"] = ( + math.factorial(params["DM3"]) + // math.factorial(params["DM3S"]) + // math.factorial(params["DM3"] - params["DM3S"]) + ) + params["DK_padded"] = math.ceil(params["DK"] / (params["DK2"] * params["DK3"])) * ( + params["DK2"] * params["DK3"] + ) + + if params["trans_a"]: + params["SAM"] = 1 + params["SAK"] = params["DM_padded"] + params["SAMS"] = 1 + params["SAKS"] = params["DMS_padded"] + params["SAMI"] = 1 + params["SAKI"] = params["DMI_padded"] + else: + params["SAM"] = params["DK_padded"] + params["SAK"] = 1 + params["SAMS"] = params["DK_padded"] + params["SAKS"] = 1 + params["SAMI"] = params["DK_padded"] + params["SAKI"] = 1 + + if params["trans_b"]: + params["SBK"] = 1 + params["SBN"] = params["DK"] + else: + params["SBK"] = params["DN"] + params["SBN"] = 1 + + if params["trans_c"]: + params["SCM"] = 1 + params["SCN"] = params["DM"] + else: + params["SCM"] = params["DN"] + params["SCN"] = 1 + + return params + + +def test_baseline_sparse_dense(): + for trans_c in (True, False): + for trans_a in (False, True): + for trans_b in (False, True): + params = derive_params( + { + "trans_a": trans_a, + "trans_b": trans_b, + "trans_c": trans_c, + "DM": 4000, + "DK": 3000, + "DN": 2000, + "DK3": 16, + "DM3": 6, + "DM3S": 3, + "DM2": 3, + "DN5": 8, + "DN4": 4, + "DN3": 1, + "DN2": 16, + } + ) + + benchmark_parametrized_sparse_dense(params) + + +def baseline_sparse_dense(): + params = derive_params( + { + "trans_a": False, + "trans_b": False, + "trans_c": False, + "DM": 4000, + "DK": 3000, + "DN": 2000, + "DK3": 16, + "DM3": 6, + "DM3S": 3, + "DM2": 3, + "DN5": 8, + "DN4": 4, + "DN3": 1, + 
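# parameter glosses (cf. test_bert_sizes1): DK3 = g and DM3/DM3S = m/n of the n:m:g
+            # pattern, DM2/DN2 = cache tile sizes of A/B, DN5 = vector width, DN4 = accumulator width. +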
"DN2": 16, + } + ) + + benchmark_parametrized_sparse_dense(params) + + +def baseline_dense_dense(): + params = { + "target_M": 2048, + "target_N": 2048, + "target_K": 2048, + "DK3": 8, + "DK2": 40, + "DM3": 12, + "DM2": 2, + "DN5": 8, + "DN4": 1, + "DN3": 20, + "DN2": 1, + "loop_order_copy_b": list(range(6)), + "loop_order_copy_a": list(range(4)), + "loop_order_inner": list(range(3)), + "B1_layout": list(range(6)), + "A1_layout": list(range(4)), + } + + benchmark_parametrized_dense_dense(params) + + +def test_sparse_dense(): + DK3 = 16 + DM3 = 6 + DM3S = 3 + DM2 = 3 + DM1 = 5 + DN5 = 8 + DN4 = 4 + DN3 = 1 + DN2 = 16 + DN1 = 13 + DK2 = math.factorial(DM3) // math.factorial(DM3S) // math.factorial(DM3 - DM3S) + DK1 = 7 + + DMs = [ + 1, + DM3 - 1, + DM3, + DM2 * DM3 - 1, + DM2 * DM3, + DM1 * DM2 * DM3 - 1, + DM1 * DM2 * DM3, + ] + DNs = [ + 1, + DN5 - 1, + DN5, + DN4 * DN5 - 1, + DN4 * DN5, + DN2 * DN3 * DN4 * DN5 - 1, + DN2 * DN3 * DN4 * DN5, + DN1 * DN2 * DN3 * DN4 * DN5 - 1, + DN1 * DN2 * DN3 * DN4 * DN5, + ] + DKs = [ + 1, + DK3 - 1, + DK3, + DK2 * DK3 - 1, + DK2 * DK3, + DK1 * DK2 * DK3 - 1, + DK1 * DK2 * DK3, + ] + + MNK = [] + for DM in DMs: + MNK.append((DM, DNs[0], DKs[0])) + # MNK.append((DM, DNs[-1], DKs[-1])) + for DN in DNs: + MNK.append((DMs[0], DN, DKs[0])) + # MNK.append((DMs[-1], DN, DKs[-1])) + for DK in DKs: + MNK.append((DMs[0], DNs[0], DK)) + # MNK.append((DMs[-1], DNs[-1], DK)) + + for DM, DN, DK in MNK: + params = { + "DM": DM, + "DK": DK, + "DN": DN, + "DK3": DK3, + "DM3": DM3, + "DM3S": DM3S, + "DM2": DM2, + "DN5": DN5, + "DN4": DN4, + "DN3": DN3, + "DN2": DN2, + "trans_a": False, + "trans_b": False, + "trans_c": False, + } + benchmark_parametrized_sparse_dense(params) + + +def get_tile_used_memory(M, K, N): + return (M * K + K * N + M * N) * 4 + + +def test_bert_sizes1(): + M = 768 + K = 3072 + N = 4096 + + params = { + "target_M": M, + "target_N": N, + "DK": K, + "DK3": 8, + "DK2": 40, + "DM3": 12, + "DM2": 2, + "DN5": 8, + "DN4": 1, + "DN3": 20, + "DN2": 1, + "loop_order_copy_b": list(range(6)), + "loop_order_copy_a": list(range(4)), + "loop_order_inner": list(range(3)), + "B1_layout": list(range(6)), + "A1_layout": list(range(4)), + } + benchmark_parametrized_dense_dense(params) + + return + + acc_size = 4 + + n = 3 + m = 6 + + chunks = fact(m) // fact(n) // fact(m - n) # DK2 + + g = 16 + + tile_b_size = 16 + tile_a_size = 4 + + params = { + "DM": 768, + "DK": 3072, + "DN": 4096, + "DK3": g, # g in n:m:g + "DM3": m, # m in n:m:g + "DM3S": n, # n in n:m:g + "DM2": tile_a_size, # tile size A + "DN5": 8, # vector size + "DN4": acc_size, # accumulator DN4 x DM3S + "DN3": 1, # unused tile dim + "DN2": tile_b_size, # tile size B + "trans_a": False, + "trans_b": False, + "trans_c": False, + } + + changes_per_tile = [ + ("N", 8), # DN5 + ("M", n), # DM3S + ("N", acc_size), # DN4 + ("K", g), # DK3 + ("K", chunks), # DK2 + ("N", tile_b_size), # DN23 + ("M", tile_a_size), # DM2 + ("M", M // (tile_a_size * n)), # DM1 + ("N", N // (tile_b_size * acc_size * 8)), # DN1 + ("K", K // (chunks * g)), # DK1 + ] + sizes_per_tile = [] + for i, e in enumerate(changes_per_tile): + sizes = {"N": 1, "K": 1, "M": 1} + for j in range(i): + dim, size = changes_per_tile[j] + sizes[dim] *= size + sizes_per_tile.append(get_tile_used_memory(**sizes)) + + print(sizes_per_tile) + + cores = 8 + # sudo dmidecode -t cache / lscpu + total_l1 = 256 * 1024 # 32 * 1024 per core + total_l2 = 1024 * 1024 # 256 * 1024 per pair of cores + total_l3 = 8182 * 1024 + l1 = total_l1 // cores + l2 = total_l2 // 
+    l3 = total_l3 // cores
+    print(f"acc_size {acc_size} * n {n} = {acc_size * n} < 16")
+    assert acc_size * n <= 16  # try to reduce acc_size
+    print(f"sizes_per_tile[5] {sizes_per_tile[5]} < {l1} (g)")
+    # assert sizes_per_tile[5] <= l1  # L1d cache: 32K, try to reduce group size g
+    print(f"sizes_per_tile[6] {sizes_per_tile[6]} < {l2} (tile_b_size)")
+    # assert sizes_per_tile[6] <= l2  # L2 cache: 256K, try to reduce tile_b_size
+    print(f"sizes_per_tile[7] {sizes_per_tile[7]} < {l3} (tile_a_size)")
+    # assert sizes_per_tile[7] <= l3  # L3 cache: 8182K, try to reduce tile_a_size
+
+    benchmark_parametrized_sparse_dense(params)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--run-sparse-baseline", action=argparse.BooleanOptionalAction)
+    parser.add_argument("--run-dense-baseline", action=argparse.BooleanOptionalAction)
+    parser.add_argument("--run-tests", action=argparse.BooleanOptionalAction)
+    args = parser.parse_args()
+
+    if args.run_sparse_baseline:
+        baseline_sparse_dense()
+    if args.run_dense_baseline:
+        baseline_dense_dense()
+    if args.run_tests:
+        test_sparse_dense()
+        test_baseline_sparse_dense()
+        test_bert_sizes1()
diff --git a/src/sten/grouped_nm/sten_impls.py b/src/sten/grouped_nm/sten_impls.py
new file mode 100644
index 0000000..2733692
--- /dev/null
+++ b/src/sten/grouped_nm/sten_impls.py
@@ -0,0 +1,174 @@
+import sten
+import torch
+import math
+from pathlib import Path
+import sys
+import ctypes
+
+from .grouped_nm_tensor import GroupedNMTensor
+from . import matmul_generator
+from .dace_gnm_mult import nmg_mult
+
+
+class GroupedNMSparsifier:
+    def __init__(self, n, m, g):
+        self.n = n
+        self.m = m
+        self.g = g
+
+
+@sten.register_sparsifier_implementation(
+    sparsifier=GroupedNMSparsifier, inp=torch.Tensor, out=GroupedNMTensor
+)
+def dense_to_grouped_nm(sparsifier, tensor, grad_fmt=None):
+    gnm = GroupedNMTensor.from_dense(
+        tensor,
+        sparsifier.n,
+        sparsifier.m,
+        sparse_dim=tensor.ndim - 2,
+        group_size=sparsifier.g,
+        group_dim=tensor.ndim - 1,
+    )
+    res = sten.SparseTensorWrapper.wrapped_from_dense(
+        gnm,
+        tensor,
+        grad_fmt,
+    )
+    return res
+
+
+LOADED_LIBS = {}  # Library keepalive
+LOADED_FUNCS = {}  # Preloaded functions
+
+
+def sparse_dense_mul_dispatch(
+    nm_strides, DM, DK, sparse_values, sparse_indices, dense, trans_a, trans_b, trans_c
+):
+    assert len(dense.shape) == 2
+    if trans_a:
+        DM, DK = DK, DM
+    DN = dense.shape[0 if trans_b else 1]
+
+    # Override trans_b based on contiguity of B (if b.T is contiguous, trans_b=True)
+    trans_b = not dense.is_contiguous()
+
+    kernel = nm_strides["kernel"]
+
+    DK3 = nm_strides["group_size"]
+    DM3S = nm_strides["n"]  # this is also the first dim of the accumulator size
+    DM3 = nm_strides["m"]
+
+    DM2 = nm_strides["tile_a"]  # cache tile size of A
+    DN4 = nm_strides["acc_width"]  # second dim of the accumulator size
+    DN2 = nm_strides["tile_b"]  # cache tile size of B
+    DN3 = 1  # redundant tile dimension, unused at the moment (always 1)
+    DN5 = 8 if kernel == "avx2" else 16  # vector size (in floats)
+
+    DK2 = math.factorial(DM3) // math.factorial(DM3S) // math.factorial(DM3 - DM3S)
+
+    params = matmul_generator.derive_params(
+        {
+            "DK3": DK3,
+            "DM3": DM3,
+            "DM3S": DM3S,
+            "DM2": DM2,
+            "DN5": DN5,
+            "DN4": DN4,
+            "DN3": DN3,
+            "DN2": DN2,
+            "DM": DM,
+            "DK": DK,
+            "DN": DN,
+            "trans_a": trans_a,
+            "trans_b": trans_b,
+            "trans_c": trans_c,
+        }
+    )
+
+    encoded_name = "nmg_" + matmul_generator.params_hash(params)
+    if encoded_name in LOADED_FUNCS:
+        sparse_dense_impl = LOADED_FUNCS[encoded_name]
+    else:
+        path = f".dacecache/{encoded_name}/build/lib{encoded_name}.so"
+        if not Path(path).is_file():
+            print("Compilation started...", file=sys.stderr)
+            nmg_mult(
+                (DM, DK, DN),
+                m=DM3,
+                n=DM3S,
+                g=DK3,
+                transpose_b=trans_b,
+                transpose_c=trans_c,
+                kernel=kernel,
+                tile=0,
+                tile_2=0,
+                local_b=True,
+                local_c=True,
+                name=encoded_name,
+            )
+            print("Compilation completed", file=sys.stderr)
+
+        lib = ctypes.CDLL(path)
+        sparse_dense_impl = getattr(lib, f"__program_{encoded_name}")
+        sparse_dense_impl.argtypes = [
+            ctypes.c_void_p,  # void* state = NULL,
+            ctypes.c_void_p,  # int16_t* __restrict__ A_idx,
+            ctypes.c_void_p,  # float* __restrict__ A_val,
+            ctypes.c_void_p,  # float* __restrict__ B,
+            ctypes.c_void_p,  # float* __restrict__ C,
+            ctypes.c_void_p,  # int* __restrict__ groups = NULL
+        ]
+
+        LOADED_LIBS[encoded_name] = lib
+        LOADED_FUNCS[encoded_name] = sparse_dense_impl
+
+    DM_padded = math.ceil(DM / DM3) * DM3
+
+    output = torch.empty(DN, DM_padded) if trans_c else torch.empty(DM_padded, DN)
+    svc = sparse_values.contiguous()
+    sic = sparse_indices.contiguous()
+
+    sparse_dense_impl(
+        None,
+        sic.data_ptr(),
+        svc.data_ptr(),
+        dense.data_ptr(),
+        output.data_ptr(),
+        None,
+    )
+
+    return output
+
+
+@sten.register_fwd_op_impl(
+    operator=torch.nn.functional.linear,
+    inp=(torch.Tensor, GroupedNMTensor, torch.Tensor),
+    out=[(sten.KeepAll, torch.Tensor)],
+)
+def sparse_torch_nn_functional_linear(ctx, inputs, output_sparsifiers):
+    input, weight, bias = inputs
+    ctx.save_for_backward(input, weight, bias)
+
+    flattened_input = torch.flatten(input, start_dim=0, end_dim=-2)
+
+    sparse_values = weight.wrapped_tensor.val
+    sparse_indices = weight.wrapped_tensor.idx
+    DM, DK = weight.wrapped_tensor.nm_strides["dense_shape"]
+
+    output = sparse_dense_mul_dispatch(
+        weight.wrapped_tensor.nm_strides,
+        DM,
+        DK,
+        sparse_values,
+        sparse_indices,
+        flattened_input,
+        trans_a=False,
+        trans_b=True,
+        trans_c=True,
+    )  # this is supposed to be more efficient, but unfortunately it is slower
+
+    output = output.reshape((*input.shape[0:-1], -1))[..., :DM]
+
+    if bias is not None:
+        output += bias.unsqueeze(0).expand_as(output)
+    return output
diff --git a/tests/test_nmg.py b/tests/test_nmg.py
new file mode 100644
index 0000000..a51c5ad
--- /dev/null
+++ b/tests/test_nmg.py
@@ -0,0 +1,91 @@
+import torch
+import sten
+import itertools
+
+
+def test_bert_inference():
+    model = torch.hub.load(
+        "huggingface/pytorch-transformers", "model", "bert-base-uncased"
+    )
+    input = torch.randint(low=0, high=100, size=(8, 512))
+
+    weights_to_sparsify = [
+        module_name + ".weight"
+        for module_name, module in model.named_modules()
+        if (
+            isinstance(module, torch.nn.modules.linear.Linear)
+            and "encoder.layer" in module_name
+        )
+    ]
+    assert weights_to_sparsify
+    sb = sten.SparsityBuilder()
+    for weight in weights_to_sparsify:
+        sb.set_weight(
+            name=weight,
+            initial_sparsifier=sten.GroupedNMSparsifier(n=3, m=6, g=4),
+            out_format=sten.GroupedNMTensor,
+        )
+    sparse_model = sb.get_sparse_model(model)
+
+    output = sparse_model(input)
+
+
+def test_dense_nm_conversion():
+    if torch.cuda.is_available():
+        device = torch.device("cuda:" + str(torch.cuda.device_count() - 1))
+    else:
+        device = torch.device("cpu")
+
+    torch.manual_seed(123)
+    dims = list(reversed([1, 5, 17]))
+    for shape in itertools.product(dims, dims):
+        shape = (3, *shape)
+        base_layout = "abc"
+        for layout in itertools.permutations(base_layout):
+            layout = "".join(layout)
+            dense_ten = torch.einsum(
f"{base_layout}->{layout}", torch.rand(shape, device=device) + ) + sparse_dim = layout.index("c") + group_dim = layout.index("b") + + n = 5 + m = 12 + g = 2 + + nm_ten = sten.GroupedNMTensor.from_dense( + dense_ten, + n=n, + m=m, + sparse_dim=sparse_dim, + group_size=g, + group_dim=group_dim, + ) + + sparsified_dense = nm_ten.to_dense() + + assert sten.grouped_nm.grouped_nm_tensor.is_correct_nm( + dense_ten, sparsified_dense, sparse_dim=sparse_dim, n=n, m=m + ) + preserved_magnitude = sparsified_dense.abs().sum() / dense_ten.abs().sum() + assert preserved_magnitude > n / m # it should be better than random + + perfect_nm_ten = sten.PerfectNMTensor.from_dense( + dense_ten, n=n, m=m, sparse_dim=sparse_dim + ) + perfect_sparsified_dense = perfect_nm_ten.to_dense() + assert sten.grouped_nm.grouped_nm_tensor.is_correct_nm( + dense_ten, perfect_sparsified_dense, sparse_dim=sparse_dim, n=n, m=m + ) + perfect_preserved_magnitude = ( + perfect_sparsified_dense.abs().sum() / dense_ten.abs().sum() + ) + assert perfect_preserved_magnitude >= preserved_magnitude - 1e-5 + print( + f"shape {list(dense_ten.shape)} magnitude {n / m:.3f} < {preserved_magnitude:.3f} < {perfect_preserved_magnitude:.3f} ok" + ) + + +if __name__ == "__main__": + test_bert_inference() + test_dense_nm_conversion()