From 96ddc2fd1470a091ac83615f17603247f13b20eb Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 14 Feb 2023 13:00:18 -0800 Subject: [PATCH 1/8] Hack in libsmm --- yateto/arch.py | 2 ++ yateto/codegen/gemm/gemmgen.py | 32 +++++++++++++++++++++++++++++++- yateto/gemm_configuration.py | 20 ++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/yateto/arch.py b/yateto/arch.py index 5c6f05c..0e3a3b9 100644 --- a/yateto/arch.py +++ b/yateto/arch.py @@ -146,6 +146,8 @@ def getHeterogeneousArchitectureIdentifiedBy(host_arch, device_arch, device_back alignment = 128 elif device_arch in ['dg1', 'bdw', 'skl', 'Gen8', 'Gen9', 'Gen11', 'Gen12LP']: alignment = 32 + elif device_arch in ['pvc']: + alignment = 64 else: raise ValueError(f'Unknown device arch: {device_arch}') diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 58182e3..9621e2f 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -4,7 +4,7 @@ from abc import ABC from ..cache import RoutineGenerator, GpuRoutineGenerator -from ...gemm_configuration import BLASlike, CodeGenerator, GemmForge +from ...gemm_configuration import BLASlike, CodeGenerator, GemmForge, libsmm from ..common import BatchedOperationsAux import importlib.util @@ -88,6 +88,7 @@ def generate(self, cpp, routineCache): else: flops = 2 * m.size() * n.size() * k.size() + print(self._gemm_cfg) if isinstance(self._gemm_cfg, BLASlike): ptr_a = self._pointer(term=d.leftTerm, offset2=(m.start, k.start), transpose=d.transA) ptr_b = self._pointer(term=d.rightTerm, offset2=(k.start, n.start), transpose=d.transB) @@ -150,6 +151,35 @@ def generate(self, cpp, routineCache): else: raise RuntimeError('gemmforge module is not found. You can install it with pip3. 
' 'e.g., pip3 install gemmforge') + elif isinstance(self._gemm_cfg, libsmm): + def address_mode(term): + if term.is_compute_constant: + return 'strided{0}' + if term.is_temporary: + return f'strided{{{term.memoryLayout.requiredReals()}}}' + else: + return 'pointers{}' + + pA = self._pointer(d.leftTerm, (m.start, k.start), d.transA) + pB = self._pointer(d.rightTerm, (k.start, n.start), d.transB) + pC = self._pointer(d.result, (m.start, n.start), False) + + aA = address_mode(d.leftTerm) + aB = address_mode(d.rightTerm) + aC = address_mode(d.result) + + cpp('[&](){') + cpp(f' static auto kernel = smm::gemm(gemm_configuration{{{m.size()}, {n.size()}, {k.size()}, {ldA}, {aA}, {ldB}, {aB}, {ldC}, {aC}, {self._alpha(d.alpha)}, {self._beta(d.beta)}}}, static_cast<::sycl::queue*>(streamPtr));') + cpp(f' kernel.execute({pA}, {pB}, {pC}).wait();') + cpp('}();') + + # cpp('{}({}, {}, {}, nullptr, {}, nullptr);'.format( + # routineName, + # self._pointer(d.leftTerm, (m.start, k.start), d.transA), + # self._pointer(d.rightTerm, (k.start, n.start), d.transB), + # self._pointer(d.result, (m.start, n.start), False), + # d.prefetchName if d.prefetchName is not None else 'nullptr' + # )) else: gemm = { 'M': m.size(), diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py index 31512a4..f6483ca 100644 --- a/yateto/gemm_configuration.py +++ b/yateto/gemm_configuration.py @@ -258,6 +258,24 @@ def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, ali return Preference.LOWEST return Preference.HIGH +class libsmm(CodeGenerator): + def __init__(self, arch): + super().__init__('', + ['smm/configuration.hpp', 'smm/gemm.hpp'], + '', + arch) + self._arch = arch + + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + return Preference.HIGHEST + + def _archSupported(self): + return self._arch.backend.lower() == 'oneapi' + + def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, + beta, alignedA, alignedC, target): + return self._archSupported() and not (sparseA or sparseB) and (not transA) and (not transB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'gpu' + class GeneratorCollection(object): def __init__(self, gemmTools: List[GemmTool]): @@ -309,6 +327,8 @@ def __init__(self, arch): elif arch.host_name in defaults: self.gemmTools = defaults[arch.host_name] if arch.is_accelerator: + if arch.backend == 'oneapi': + self.gemmTools.extend([libsmm(arch)]) self.gemmTools.extend([forge]) else: raise Exception("Default generator collection for architecture {} is missing.".format(arch)) From e47788d90de752561c06e15e3cf0ad67bc2bc664 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 13 Apr 2023 13:51:18 -0700 Subject: [PATCH 2/8] Update interface to new smm lib --- yateto/codegen/common.py | 17 +-- yateto/codegen/copyscaleadd/csa_gen.py | 6 +- yateto/codegen/fused_gemms/factory.py | 4 + yateto/codegen/fused_gemms/libsmm.py | 172 +++++++++++++++++++++++++ yateto/codegen/gemm/gemmgen.py | 128 +++++++++++++----- yateto/gemm_configuration.py | 5 +- yateto/generator.py | 2 +- 7 files changed, 284 insertions(+), 50 deletions(-) create mode 100644 yateto/codegen/fused_gemms/libsmm.py diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py index 3ac15be..17164ba 100644 --- a/yateto/codegen/common.py +++ b/yateto/codegen/common.py @@ -108,16 +108,17 @@ def deduce_addresing(self, term): else: return 'pointer_based' - def deduce_arg(self, term, as_const=False): - if term.is_compute_constant or term.is_temporary: - 
extra_offset = '0' - else: - extra_offset = f'{self.EXTRA_OFFSET_NAME}_{term.name}' - + def deduce_ptr_arg(self, term, as_const=False): if as_const: addressing = self.deduce_addresing(term) ptr = self._get_ptr_type(addressing) const_ptr_type = f'const {self.underlying_data_type} {ptr}' - return f'const_cast<{const_ptr_type}>({term.name}), {extra_offset}' + return f'const_cast<{const_ptr_type}>({term.name})' + else: + return f'{term.name}' + + def deduce_offset_arg(self, term): + if term.is_compute_constant or term.is_temporary: + return '0' else: - return f'{term.name}, {extra_offset}' + return f'{self.EXTRA_OFFSET_NAME}_{term.name}' diff --git a/yateto/codegen/copyscaleadd/csa_gen.py b/yateto/codegen/copyscaleadd/csa_gen.py index 20b85e7..f7553bb 100644 --- a/yateto/codegen/copyscaleadd/csa_gen.py +++ b/yateto/codegen/copyscaleadd/csa_gen.py @@ -90,8 +90,10 @@ def generate(self, cpp, routineCache): routine_name = forge_generator.get_base_name() args = [str(alpha), - aux.deduce_arg(d.term), - aux.deduce_arg(d.result), + aux.deduce_ptr_arg(d.term), + aux.deduce_offset_arg(d.term), + aux.deduce_ptr_arg(d.result), + aux.deduce_offset_arg(d.result), BatchedOperationsAux.NUM_ELEMENTS_NAME, BatchedOperationsAux.FLAGS_NAME, BatchedOperationsAux.STREAM_PTR_NAME] diff --git a/yateto/codegen/fused_gemms/factory.py b/yateto/codegen/fused_gemms/factory.py index 2515563..05636e4 100644 --- a/yateto/codegen/fused_gemms/factory.py +++ b/yateto/codegen/fused_gemms/factory.py @@ -7,6 +7,8 @@ except: raise ('Found chainforge spec but cannot load. Please, check installation of chainforge') +from .libsmm import FusedGemmsLibsmm + class Description(object): def __init__(self, node, result, arguments, add, scalar): @@ -37,5 +39,7 @@ def __next__(self): def generator(arch, descr, target): if target == 'gpu' and gb_spec: return FusedGemms(arch, descr) + elif target == 'gpu': + return FusedGemmsLibsmm(arch, descr) else: raise NotImplementedError(f'no implementation found for {target} target') diff --git a/yateto/codegen/fused_gemms/libsmm.py b/yateto/codegen/fused_gemms/libsmm.py new file mode 100644 index 0000000..5a06e31 --- /dev/null +++ b/yateto/codegen/fused_gemms/libsmm.py @@ -0,0 +1,172 @@ +from ..common import TensorDescription, IndexedTensorDescription, BatchedOperationsAux +from ...ast.indices import BoundingBox +from ..cache import RoutineGenerator, GpuRoutineGenerator +from ...ast.node import IndexedTensor +from ...type import Tensor + +import hashlib + + +class FusedGemmsLibsmm: + W_PREFIX = '_w_' + ARG_PREFIX = '_arg_' + OFF_PREFIX = '_offset_' + + def __init__(self, arch, descr): + self._arch = arch + self._descr = descr + self._batch_aux = BatchedOperationsAux(self._arch.typename) + self._cache = {} + self._tmp_matrices = {} + + def generate(self, cpp, routineCache, cfg): + input_matrices = dict() + is_constant = dict() + is_modified = dict() + var_name = dict() + + def store_matrix(var, node, is_result): + if not (var.is_temporary and is_result) and var not in var_name: + input_matrices[var] = node.memoryLayout() + is_constant[var] = node.tensor.is_compute_constant() if isinstance(node, IndexedTensor) else False + base_name = str(var) + if not base_name.startswith('_'): + base_name = Tensor.getBaseName(base_name) + name = base_name + counter = 1 + while name in var_name.values(): + name = f'{base_name}{counter}' + counter = counter + 1 + var_name[var] = name + if is_result: + is_modified[var] = True + + w_name = lambda x: f'{self.W_PREFIX}{var_name[x]}' + arg_name = lambda x: 
f'{self.ARG_PREFIX}{var_name[x]}' + off_name = lambda x: f'{self.OFF_PREFIX}{var_name[x]}' + memref_type = lambda ml: f'smm::ir::memref_type(real_t, {ml.bboxi(0).size()}, {ml.bboxi(1).size()}, {ml.stridei(1)})' + has_offset = lambda var: not (var.is_temporary or is_constant[var]) + + body = '' + flops = 0 + for item in self._descr: + node, args, add, scalar = item + res, op1, op2 = args + store_matrix(res, node, True) + store_matrix(op1, node.leftTerm(), False) + store_matrix(op2, node.rightTerm(), False) + + if res.is_temporary and res not in var_name: + var_name[res] = str(res) + body += f'auto {res} = bb.create_alloca({memref_type(node.memoryLayout())}, "{res}");\n' + + bbA = BoundingBox.fromSpp(node.leftTerm().eqspp()) + bbB = BoundingBox.fromSpp(node.rightTerm().eqspp()) + bbC = BoundingBox.fromSpp(node.eqspp()) + + if node.transA() or node.transB(): + raise NotImplementedError('Transposition not supported') + + k = bbA[1] & bbB[0] + m = bbA[0] + n = bbB[1] + + if node.leftTerm().memoryLayout().alignedStride() and node.memoryLayout().alignedStride(): + m = m.aligned(self._arch) + + slic = lambda r, i: f'smm::ir::slice{{{i.start-r.start}, {i.stop-r.start}}}' + name = lambda x: f'{w_name(x)}' if x in input_matrices else var_name[x] + sub = lambda x, ml, i, j: f' bb.create_submatrix({name(x)}, {slic(ml.bboxi(0), i)}, {slic(ml.bboxi(1), j)}),\n' + + body += f'bb.create_matmul(\n'; + body += sub(op1, node.leftTerm().memoryLayout(), m, k) + body += sub(op2, node.rightTerm().memoryLayout(), k, n) + body += sub(res, node.memoryLayout(), m, n) + body += f'{scalar}, {1.0 if add else 0.0});\n' + + flops += 2 * m.size() * n.size() * k.size() + + def batch_type(var): + ml = input_matrices[var] + if is_constant[var]: + return f'{memref_type(ml)}' + stride = f'smm::strided{{{ml.requiredReals()}}}' if var.is_temporary else 'smm::pointers{}' + return f'smm::ir::batch_type({memref_type(ml)}, {stride})' + + pre_body = 'fb.body([&](smm::ir::block_builder& bb) {\n' + for key in input_matrices.keys(): + if is_constant[key]: + pre_body += f'auto {w_name(key)} = {arg_name(key)};\n' + elif has_offset(key): + pre_body += f'auto {w_name(key)} = bb.create_get_work_item({arg_name(key)}, {off_name(key)});\n' + else: + pre_body += f'auto {w_name(key)} = bb.create_get_work_item({arg_name(key)});\n' + post_body = '});\n' + + args = f'constexpr auto real_t = smm::ir::to_scalar_type_v<{self._arch.typename}>;\n' + for key in input_matrices.keys(): + args += f'auto {arg_name(key)} = fb.argument({batch_type(key)}, "{var_name[key]}");\n' + if has_offset(key): + args += f'auto {off_name(key)} = fb.argument(smm::ir::data_type(smm::ir::scalar_type::i32), "offset");\n' + + pre_header = 'static auto kernel = smm::custom_kernel([](smm::ir::function_builder &fb) {\n' + post_header = f'}}, *static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME}));\n' + make_kernel = f'{pre_header}{args}{pre_body}{body}{post_body}{post_header}' + + def wrapper_type(key): + ptr2ptr = '*' if not is_constant[key] and not key.is_temporary else '' + const = ' const' if key not in is_modified and not key.is_temporary else '' + return f'{self._arch.typename}{const}*{ptr2ptr}' + + hasher = hashlib.sha512() + hasher.update(make_kernel.encode('utf-8')) + wrapper_name = f'libsmm_wrapper_{hasher.hexdigest()}' + wrapper_args = [f'unsigned {BatchedOperationsAux.NUM_ELEMENTS_NAME}', f'void* {BatchedOperationsAux.STREAM_PTR_NAME}'] + wrapper_call_args = [] + call_args = [] + for key in input_matrices.keys(): + ptr2ptr = '*' if not is_constant[key] and 
not key.is_temporary else '' + const = ' const' if key not in is_modified and not key.is_temporary else '' + wrapper_args += [f'{wrapper_type(key)} {var_name[key]}'] + wrapper_call_args += [var_name[key]] + call_args += [f'const_cast<{wrapper_type(key)}>({str(key)})'] + if has_offset(key): + offset_name = f'{BatchedOperationsAux.EXTRA_OFFSET_NAME}_{var_name[key]}' + wrapper_args += [f'int {offset_name}'] + wrapper_call_args += [offset_name] + call_args += [f'{BatchedOperationsAux.EXTRA_OFFSET_NAME}_{key}'] + wrapper_call_args = ', '.join(wrapper_call_args) + call_args = ', '.join(call_args) + wrapper_signature = f'void {wrapper_name}({", ".join(wrapper_args)});' + wrapper = f'{wrapper_signature[:-1]} {{\n' + wrapper += make_kernel + wrapper += f'kernel({BatchedOperationsAux.NUM_ELEMENTS_NAME}, {wrapper_call_args}).wait();\n' + wrapper += '}\n\n' + + cpp(f'{wrapper_name}({BatchedOperationsAux.NUM_ELEMENTS_NAME}, {BatchedOperationsAux.STREAM_PTR_NAME}, {call_args});') + + routineCache.addRoutine(wrapper_signature, LibsmmWriter(wrapper_signature, wrapper)) + + return flops + +class LibsmmWriter(GpuRoutineGenerator): + def __init__(self, signature, source): + self._source = source + self._signature = signature + + def __eq__(self, other): + return self._signature == other._signature + + def header(self, cpp): + cpp.include('smm/custom_kernel.hpp') + cpp.include('smm/ir/builder.hpp') + cpp.include('smm/ir/data_type.hpp') + cpp.include('smm/ir/scalar_type.hpp') + cpp.include('smm/ir/slice.hpp') + cpp.includeSys('CL/sycl.hpp') + + def __call__(self, routineName, fileName): + with open(fileName, 'a') as f: + f.write(self._source) + + return self._signature diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 9621e2f..6d76e2a 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -53,12 +53,15 @@ def generateRoutineName(self, gemm, spp): betaSubs=self._beta(gemm['beta']), **gemm ) - - def _pointer(self, term, offset2, transpose): + + def _offset(self, term, offset2, transpose): if transpose: # swaps elements of tuple if transpose offset2 = offset2[::-1] - o = term.memoryLayout.subtensorOffset(topLeftEntry=offset2) + return term.memoryLayout.subtensorOffset(topLeftEntry=offset2) + + def _pointer(self, term, offset2, transpose): + o = self._offset(term, offset2, transpose) if o > 0: return '{} + {}'.format(term.name, o) return term.name @@ -88,7 +91,6 @@ def generate(self, cpp, routineCache): else: flops = 2 * m.size() * n.size() * k.size() - print(self._gemm_cfg) if isinstance(self._gemm_cfg, BLASlike): ptr_a = self._pointer(term=d.leftTerm, offset2=(m.start, k.start), transpose=d.transA) ptr_b = self._pointer(term=d.rightTerm, offset2=(k.start, n.start), transpose=d.transB) @@ -130,9 +132,12 @@ def generate(self, cpp, routineCache): forge_generator.set(d.transA, d.transB, matrix_a, matrix_b, matrix_c, d.alpha, d.beta) routine_name = forge_generator.get_base_name() - args = [aux.deduce_arg(d.leftTerm, as_const=True), - aux.deduce_arg(d.rightTerm, as_const=True), - aux.deduce_arg(d.result, as_const=False), + args = [aux.deduce_ptr_arg(d.leftTerm, as_const=True), + aux.deduce_offset_arg(d.leftTerm), + aux.deduce_ptr_arg(d.rightTerm, as_const=True), + aux.deduce_offset_arg(d.rightTerm), + aux.deduce_ptr_arg(d.result, as_const=False), + aux.deduce_offset_arg(d.result), BatchedOperationsAux.NUM_ELEMENTS_NAME, BatchedOperationsAux.FLAGS_NAME, BatchedOperationsAux.STREAM_PTR_NAME] @@ -152,34 +157,44 @@ def generate(self, cpp, routineCache): raise 
RuntimeError('gemmforge module is not found. You can install it with pip3. ' 'e.g., pip3 install gemmforge') elif isinstance(self._gemm_cfg, libsmm): - def address_mode(term): - if term.is_compute_constant: - return 'strided{0}' - if term.is_temporary: - return f'strided{{{term.memoryLayout.requiredReals()}}}' - else: - return 'pointers{}' - - pA = self._pointer(d.leftTerm, (m.start, k.start), d.transA) - pB = self._pointer(d.rightTerm, (k.start, n.start), d.transB) - pC = self._pointer(d.result, (m.start, n.start), False) - - aA = address_mode(d.leftTerm) - aB = address_mode(d.rightTerm) - aC = address_mode(d.result) - - cpp('[&](){') - cpp(f' static auto kernel = smm::gemm(gemm_configuration{{{m.size()}, {n.size()}, {k.size()}, {ldA}, {aA}, {ldB}, {aB}, {ldC}, {aC}, {self._alpha(d.alpha)}, {self._beta(d.beta)}}}, static_cast<::sycl::queue*>(streamPtr));') - cpp(f' kernel.execute({pA}, {pB}, {pC}).wait();') - cpp('}();') - - # cpp('{}({}, {}, {}, nullptr, {}, nullptr);'.format( - # routineName, - # self._pointer(d.leftTerm, (m.start, k.start), d.transA), - # self._pointer(d.rightTerm, (k.start, n.start), d.transB), - # self._pointer(d.result, (m.start, n.start), False), - # d.prefetchName if d.prefetchName is not None else 'nullptr' - # )) + aux = BatchedOperationsAux(self._arch.typename) + gemm = { + 'M': m.size(), + 'N': n.size(), + 'K': k.size(), + 'LDA': ldA, + 'addrA': aux.deduce_addresing(d.leftTerm), + 'distA': d.leftTerm.memoryLayout.requiredReals(), + 'LDB': ldB, + 'addrB': aux.deduce_addresing(d.rightTerm), + 'distB': d.rightTerm.memoryLayout.requiredReals(), + 'LDC': ldC, + 'addrC': aux.deduce_addresing(d.result), + 'distC': d.result.memoryLayout.requiredReals(), + 'alpha': self._alpha(d.alpha), + 'beta': self._beta(d.beta), + 'transA': d.transA, + 'transB': d.transB, + } + routine_name = 'libsmm_wrapper_m{M}_n{N}_k{K}_ldA{LDA}_{addrA}_{distA}_ldB{LDB}_{addrB}_{distB}_ldC{LDC}_{addrC}_{distC}_alpha{alpha}_beta{beta}_transA{transA}_transB{transB}'.format(**gemm) + + + offset_a = self._offset(term=d.leftTerm, offset2=(m.start, k.start), transpose=d.transA) + offset_b = self._offset(term=d.rightTerm, offset2=(k.start, n.start), transpose=d.transB) + offset_c = self._offset(term=d.result, offset2=(m.start, n.start), transpose=False) + args = [aux.deduce_ptr_arg(d.leftTerm, as_const=True), + f'{aux.deduce_offset_arg(d.leftTerm)} + {offset_a}', + aux.deduce_ptr_arg(d.rightTerm, as_const=True), + f'{aux.deduce_offset_arg(d.rightTerm)} + {offset_b}', + aux.deduce_ptr_arg(d.result, as_const=False), + f'{aux.deduce_offset_arg(d.result)} + {offset_c}', + BatchedOperationsAux.NUM_ELEMENTS_NAME, + BatchedOperationsAux.STREAM_PTR_NAME] + args = ', '.join(args) + + cpp(f'{routine_name}({args});') + + routineCache.addRoutine(routine_name, LibsmmGemmGen(self._arch, gemm)) else: gemm = { 'M': m.size(), @@ -456,3 +471,46 @@ def __call__(self, routineName, fileName): file.write(f"{func_signature}") file.write(self._call(routineName)) return func_signature + ";" + +class LibsmmGemmGen(GpuRoutineGenerator): + def __init__(self, arch, gemm_descr): + self.arch = arch + self.gemm_descr = gemm_descr + + def __eq__(self, other): + return self.arch == other.arch and self.gemm_descr == other.gemm_descr + + def header(self, cpp): + cpp.include('smm/configuration.hpp') + cpp.include('smm/gemm.hpp') + cpp.includeSys('CL/sycl.hpp') + + def _functionSignature(self, routineName): + typ = self.arch.typename + stars = lambda x: '**' if x == 'pointer_based' else '*' + starsA = stars(self.gemm_descr['addrA']) + 
starsB = stars(self.gemm_descr['addrB']) + starsC = stars(self.gemm_descr['addrC']) + return f'void {routineName}({typ} const{starsA} A, int offsetA, {typ} const{starsB} B, int offsetB, {typ}{starsC} C, int offsetC, unsigned {BatchedOperationsAux.NUM_ELEMENTS_NAME}, void* {BatchedOperationsAux.STREAM_PTR_NAME})' + + def address_mode(self, addr, dist): + if addr == 'pointer_based': + return 'smm::pointers{}' + elif addr == 'none': + return 'smm::strided{0}' + elif addr == 'strided': + return f'smm::strided{{{dist}}}' + raise NameError(addr) + + def __call__(self, routineName, fileName): + func_signature = self._functionSignature(routineName) + with open(fileName, "a") as f: + aA = self.address_mode(self.gemm_descr['addrA'], self.gemm_descr['distA']) + aB = self.address_mode(self.gemm_descr['addrB'], self.gemm_descr['distB']) + aC = self.address_mode(self.gemm_descr['addrC'], self.gemm_descr['distC']) + + f.write(f'{func_signature} {{\n') + f.write(' static auto kernel = smm::make_gemm<{typ}>({M}, {N}, {K}, {LDA}, {aA}, {LDB}, {aB}, {LDC}, {aC}, {alpha}, {beta}, *static_cast<::sycl::queue*>({stream}));\n'.format(typ=self.arch.typename, aA=aA, aB=aB, aC=aC, stream=BatchedOperationsAux.STREAM_PTR_NAME, **self.gemm_descr)) + f.write(f' kernel(A, offsetA, B, offsetB, C, offsetC, {BatchedOperationsAux.NUM_ELEMENTS_NAME}).wait();\n') + f.write('}\n') + return func_signature + ";" diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py index f6483ca..8ec9e1e 100644 --- a/yateto/gemm_configuration.py +++ b/yateto/gemm_configuration.py @@ -260,10 +260,7 @@ def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, ali class libsmm(CodeGenerator): def __init__(self, arch): - super().__init__('', - ['smm/configuration.hpp', 'smm/gemm.hpp'], - '', - arch) + super().__init__('', [], '', arch) self._arch = arch def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): diff --git a/yateto/generator.py b/yateto/generator.py index ea59b3e..dd6d51f 100644 --- a/yateto/generator.py +++ b/yateto/generator.py @@ -102,7 +102,7 @@ def prepareUntilCodeGen(self, cost_estimator): self.cfg = SubstituteBackward().visit(self.cfg) self.cfg = RemoveEmptyStatements().visit(self.cfg) self.cfg = MergeActions().visit(self.cfg) - if self.target == 'gpu' and chainforge_spec: + if self.target == 'gpu':# and chainforge_spec: self.cfg = FindFusedGemms().visit(self.cfg) self.cfg = LivenessAnalysis().visit(self.cfg) From cc799f67e96447ca8c413d232325a53a0b31c146 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Thu, 13 Apr 2023 15:04:35 -0700 Subject: [PATCH 3/8] Fix transpose later Signed-off-by: Carsten Uphoff --- yateto/codegen/fused_gemms/libsmm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/yateto/codegen/fused_gemms/libsmm.py b/yateto/codegen/fused_gemms/libsmm.py index 5a06e31..7994faf 100644 --- a/yateto/codegen/fused_gemms/libsmm.py +++ b/yateto/codegen/fused_gemms/libsmm.py @@ -65,7 +65,8 @@ def store_matrix(var, node, is_result): bbC = BoundingBox.fromSpp(node.eqspp()) if node.transA() or node.transB(): - raise NotImplementedError('Transposition not supported') + #raise NotImplementedError('Transposition not supported') + print(f'WARNING: Transposition not supported yet in {res} = {op1} * {op2}') k = bbA[1] & bbB[0] m = bbA[0] From 52bbed0751dae47410046c4756e58b362bca8957 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 24 Apr 2023 08:19:16 -0700 Subject: [PATCH 4/8] Add transpose for libsmm Signed-off-by: Carsten 
Uphoff --- yateto/codegen/gemm/gemmgen.py | 6 +++++- yateto/gemm_configuration.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 6d76e2a..5a5b33b 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -509,8 +509,12 @@ def __call__(self, routineName, fileName): aB = self.address_mode(self.gemm_descr['addrB'], self.gemm_descr['distB']) aC = self.address_mode(self.gemm_descr['addrC'], self.gemm_descr['distC']) + T = lambda x: 'smm::transpose::T' if x else 'smm::transpose::N' + tA = T(self.gemm_descr['transA']) + tB = T(self.gemm_descr['transB']) + f.write(f'{func_signature} {{\n') - f.write(' static auto kernel = smm::make_gemm<{typ}>({M}, {N}, {K}, {LDA}, {aA}, {LDB}, {aB}, {LDC}, {aC}, {alpha}, {beta}, *static_cast<::sycl::queue*>({stream}));\n'.format(typ=self.arch.typename, aA=aA, aB=aB, aC=aC, stream=BatchedOperationsAux.STREAM_PTR_NAME, **self.gemm_descr)) + f.write(' static auto kernel = smm::make_gemm<{typ}>({tA}, {tB}, {M}, {N}, {K}, {LDA}, {aA}, {LDB}, {aB}, {LDC}, {aC}, {alpha}, {beta}, *static_cast<::sycl::queue*>({stream}));\n'.format(typ=self.arch.typename, aA=aA, aB=aB, aC=aC, tA=tA, tB=tB, stream=BatchedOperationsAux.STREAM_PTR_NAME, **self.gemm_descr)) f.write(f' kernel(A, offsetA, B, offsetB, C, offsetC, {BatchedOperationsAux.NUM_ELEMENTS_NAME}).wait();\n') f.write('}\n') return func_signature + ";" diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py index 8ec9e1e..3dfbc5a 100644 --- a/yateto/gemm_configuration.py +++ b/yateto/gemm_configuration.py @@ -271,7 +271,7 @@ def _archSupported(self): def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, target): - return self._archSupported() and not (sparseA or sparseB) and (not transA) and (not transB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'gpu' + return self._archSupported() and not (sparseA or sparseB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'gpu' class GeneratorCollection(object): From 143247d76036ea83975e4573b892fa2ae54dd5a9 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Mon, 22 May 2023 07:57:19 -0700 Subject: [PATCH 5/8] Adapt to new interface Signed-off-by: Carsten Uphoff --- yateto/codegen/gemm/gemmgen.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 5a5b33b..d9462a4 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -512,9 +512,11 @@ def __call__(self, routineName, fileName): T = lambda x: 'smm::transpose::T' if x else 'smm::transpose::N' tA = T(self.gemm_descr['transA']) tB = T(self.gemm_descr['transB']) + alpha = self.gemm_descr['alpha'] + beta = self.gemm_descr['beta'] f.write(f'{func_signature} {{\n') - f.write(' static auto kernel = smm::make_gemm<{typ}>({tA}, {tB}, {M}, {N}, {K}, {LDA}, {aA}, {LDB}, {aB}, {LDC}, {aC}, {alpha}, {beta}, *static_cast<::sycl::queue*>({stream}));\n'.format(typ=self.arch.typename, aA=aA, aB=aB, aC=aC, tA=tA, tB=tB, stream=BatchedOperationsAux.STREAM_PTR_NAME, **self.gemm_descr)) - f.write(f' kernel(A, offsetA, B, offsetB, C, offsetC, {BatchedOperationsAux.NUM_ELEMENTS_NAME}).wait();\n') + f.write(' static auto kernel = smm::make_gemm<{typ}>({tA}, {tB}, {M}, {N}, {K}, {LDA}, {aA}, {LDB}, {aB}, {LDC}, {aC}, *static_cast<::sycl::queue*>({stream}));\n'.format(typ=self.arch.typename, aA=aA, aB=aB, aC=aC, tA=tA, tB=tB, 
stream=BatchedOperationsAux.STREAM_PTR_NAME, **self.gemm_descr)) + f.write(f' kernel({alpha}, A, B, {beta}, C, {BatchedOperationsAux.NUM_ELEMENTS_NAME}).wait();\n') f.write('}\n') return func_signature + ";" From 5608d7f3859ebf11e8ee0c91d142a9c99777b868 Mon Sep 17 00:00:00 2001 From: Carsten Uphoff Date: Tue, 30 Apr 2024 00:51:35 -0700 Subject: [PATCH 6/8] Update Tiny Tensor Compiler plumbing Signed-off-by: Carsten Uphoff --- yateto/codegen/fused_gemms/factory.py | 4 +- yateto/codegen/fused_gemms/libsmm.py | 173 ---------------------- yateto/codegen/fused_gemms/tinytc.py | 205 ++++++++++++++++++++++++++ yateto/codegen/gemm/gemmgen.py | 120 ++++++++++++--- yateto/gemm_configuration.py | 7 +- 5 files changed, 307 insertions(+), 202 deletions(-) delete mode 100644 yateto/codegen/fused_gemms/libsmm.py create mode 100644 yateto/codegen/fused_gemms/tinytc.py diff --git a/yateto/codegen/fused_gemms/factory.py b/yateto/codegen/fused_gemms/factory.py index 984896e..0beb0eb 100644 --- a/yateto/codegen/fused_gemms/factory.py +++ b/yateto/codegen/fused_gemms/factory.py @@ -14,7 +14,7 @@ except: raise ('Found chainforge spec but cannot load. Please, check installation of chainforge') -from .libsmm import FusedGemmsLibsmm +from .tinytc import FusedGemmsTinytc class Description(object): @@ -47,6 +47,6 @@ def generator(arch, descr, target): if target == 'gpu' and gb_spec: return FusedGemms(arch, descr) elif target == 'gpu': - return FusedGemmsLibsmm(arch, descr) + return FusedGemmsTinytc(arch, descr) else: raise NotImplementedError(f'no implementation found for {target} target') diff --git a/yateto/codegen/fused_gemms/libsmm.py b/yateto/codegen/fused_gemms/libsmm.py deleted file mode 100644 index 7994faf..0000000 --- a/yateto/codegen/fused_gemms/libsmm.py +++ /dev/null @@ -1,173 +0,0 @@ -from ..common import TensorDescription, IndexedTensorDescription, BatchedOperationsAux -from ...ast.indices import BoundingBox -from ..cache import RoutineGenerator, GpuRoutineGenerator -from ...ast.node import IndexedTensor -from ...type import Tensor - -import hashlib - - -class FusedGemmsLibsmm: - W_PREFIX = '_w_' - ARG_PREFIX = '_arg_' - OFF_PREFIX = '_offset_' - - def __init__(self, arch, descr): - self._arch = arch - self._descr = descr - self._batch_aux = BatchedOperationsAux(self._arch.typename) - self._cache = {} - self._tmp_matrices = {} - - def generate(self, cpp, routineCache, cfg): - input_matrices = dict() - is_constant = dict() - is_modified = dict() - var_name = dict() - - def store_matrix(var, node, is_result): - if not (var.is_temporary and is_result) and var not in var_name: - input_matrices[var] = node.memoryLayout() - is_constant[var] = node.tensor.is_compute_constant() if isinstance(node, IndexedTensor) else False - base_name = str(var) - if not base_name.startswith('_'): - base_name = Tensor.getBaseName(base_name) - name = base_name - counter = 1 - while name in var_name.values(): - name = f'{base_name}{counter}' - counter = counter + 1 - var_name[var] = name - if is_result: - is_modified[var] = True - - w_name = lambda x: f'{self.W_PREFIX}{var_name[x]}' - arg_name = lambda x: f'{self.ARG_PREFIX}{var_name[x]}' - off_name = lambda x: f'{self.OFF_PREFIX}{var_name[x]}' - memref_type = lambda ml: f'smm::ir::memref_type(real_t, {ml.bboxi(0).size()}, {ml.bboxi(1).size()}, {ml.stridei(1)})' - has_offset = lambda var: not (var.is_temporary or is_constant[var]) - - body = '' - flops = 0 - for item in self._descr: - node, args, add, scalar = item - res, op1, op2 = args - store_matrix(res, node, True) - 
store_matrix(op1, node.leftTerm(), False) - store_matrix(op2, node.rightTerm(), False) - - if res.is_temporary and res not in var_name: - var_name[res] = str(res) - body += f'auto {res} = bb.create_alloca({memref_type(node.memoryLayout())}, "{res}");\n' - - bbA = BoundingBox.fromSpp(node.leftTerm().eqspp()) - bbB = BoundingBox.fromSpp(node.rightTerm().eqspp()) - bbC = BoundingBox.fromSpp(node.eqspp()) - - if node.transA() or node.transB(): - #raise NotImplementedError('Transposition not supported') - print(f'WARNING: Transposition not supported yet in {res} = {op1} * {op2}') - - k = bbA[1] & bbB[0] - m = bbA[0] - n = bbB[1] - - if node.leftTerm().memoryLayout().alignedStride() and node.memoryLayout().alignedStride(): - m = m.aligned(self._arch) - - slic = lambda r, i: f'smm::ir::slice{{{i.start-r.start}, {i.stop-r.start}}}' - name = lambda x: f'{w_name(x)}' if x in input_matrices else var_name[x] - sub = lambda x, ml, i, j: f' bb.create_submatrix({name(x)}, {slic(ml.bboxi(0), i)}, {slic(ml.bboxi(1), j)}),\n' - - body += f'bb.create_matmul(\n'; - body += sub(op1, node.leftTerm().memoryLayout(), m, k) - body += sub(op2, node.rightTerm().memoryLayout(), k, n) - body += sub(res, node.memoryLayout(), m, n) - body += f'{scalar}, {1.0 if add else 0.0});\n' - - flops += 2 * m.size() * n.size() * k.size() - - def batch_type(var): - ml = input_matrices[var] - if is_constant[var]: - return f'{memref_type(ml)}' - stride = f'smm::strided{{{ml.requiredReals()}}}' if var.is_temporary else 'smm::pointers{}' - return f'smm::ir::batch_type({memref_type(ml)}, {stride})' - - pre_body = 'fb.body([&](smm::ir::block_builder& bb) {\n' - for key in input_matrices.keys(): - if is_constant[key]: - pre_body += f'auto {w_name(key)} = {arg_name(key)};\n' - elif has_offset(key): - pre_body += f'auto {w_name(key)} = bb.create_get_work_item({arg_name(key)}, {off_name(key)});\n' - else: - pre_body += f'auto {w_name(key)} = bb.create_get_work_item({arg_name(key)});\n' - post_body = '});\n' - - args = f'constexpr auto real_t = smm::ir::to_scalar_type_v<{self._arch.typename}>;\n' - for key in input_matrices.keys(): - args += f'auto {arg_name(key)} = fb.argument({batch_type(key)}, "{var_name[key]}");\n' - if has_offset(key): - args += f'auto {off_name(key)} = fb.argument(smm::ir::data_type(smm::ir::scalar_type::i32), "offset");\n' - - pre_header = 'static auto kernel = smm::custom_kernel([](smm::ir::function_builder &fb) {\n' - post_header = f'}}, *static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME}));\n' - make_kernel = f'{pre_header}{args}{pre_body}{body}{post_body}{post_header}' - - def wrapper_type(key): - ptr2ptr = '*' if not is_constant[key] and not key.is_temporary else '' - const = ' const' if key not in is_modified and not key.is_temporary else '' - return f'{self._arch.typename}{const}*{ptr2ptr}' - - hasher = hashlib.sha512() - hasher.update(make_kernel.encode('utf-8')) - wrapper_name = f'libsmm_wrapper_{hasher.hexdigest()}' - wrapper_args = [f'unsigned {BatchedOperationsAux.NUM_ELEMENTS_NAME}', f'void* {BatchedOperationsAux.STREAM_PTR_NAME}'] - wrapper_call_args = [] - call_args = [] - for key in input_matrices.keys(): - ptr2ptr = '*' if not is_constant[key] and not key.is_temporary else '' - const = ' const' if key not in is_modified and not key.is_temporary else '' - wrapper_args += [f'{wrapper_type(key)} {var_name[key]}'] - wrapper_call_args += [var_name[key]] - call_args += [f'const_cast<{wrapper_type(key)}>({str(key)})'] - if has_offset(key): - offset_name = 
f'{BatchedOperationsAux.EXTRA_OFFSET_NAME}_{var_name[key]}' - wrapper_args += [f'int {offset_name}'] - wrapper_call_args += [offset_name] - call_args += [f'{BatchedOperationsAux.EXTRA_OFFSET_NAME}_{key}'] - wrapper_call_args = ', '.join(wrapper_call_args) - call_args = ', '.join(call_args) - wrapper_signature = f'void {wrapper_name}({", ".join(wrapper_args)});' - wrapper = f'{wrapper_signature[:-1]} {{\n' - wrapper += make_kernel - wrapper += f'kernel({BatchedOperationsAux.NUM_ELEMENTS_NAME}, {wrapper_call_args}).wait();\n' - wrapper += '}\n\n' - - cpp(f'{wrapper_name}({BatchedOperationsAux.NUM_ELEMENTS_NAME}, {BatchedOperationsAux.STREAM_PTR_NAME}, {call_args});') - - routineCache.addRoutine(wrapper_signature, LibsmmWriter(wrapper_signature, wrapper)) - - return flops - -class LibsmmWriter(GpuRoutineGenerator): - def __init__(self, signature, source): - self._source = source - self._signature = signature - - def __eq__(self, other): - return self._signature == other._signature - - def header(self, cpp): - cpp.include('smm/custom_kernel.hpp') - cpp.include('smm/ir/builder.hpp') - cpp.include('smm/ir/data_type.hpp') - cpp.include('smm/ir/scalar_type.hpp') - cpp.include('smm/ir/slice.hpp') - cpp.includeSys('CL/sycl.hpp') - - def __call__(self, routineName, fileName): - with open(fileName, 'a') as f: - f.write(self._source) - - return self._signature diff --git a/yateto/codegen/fused_gemms/tinytc.py b/yateto/codegen/fused_gemms/tinytc.py new file mode 100644 index 0000000..a4b1065 --- /dev/null +++ b/yateto/codegen/fused_gemms/tinytc.py @@ -0,0 +1,205 @@ +from ..common import TensorDescription, IndexedTensorDescription, BatchedOperationsAux +from ...ast.indices import BoundingBox +from ..cache import RoutineGenerator, GpuRoutineGenerator +from ...ast.node import IndexedTensor +from ...type import Tensor + +import hashlib + + +class FusedGemmsTinytc: + def __init__(self, arch, descr): + self._arch = arch + self._descr = descr + self._batch_aux = BatchedOperationsAux(self._arch.typename) + self._cache = {} + self._tmp_matrices = {} + self._scalar_type = 'f64' if self._arch.bytesPerReal == 8 else 'f32' + self._var_counter = 0 + + def next_var(self): + count = self._var_counter + self._var_counter += 1 + return count + + def generate(self, cpp, routineCache, cfg): + input_matrices = dict() + is_constant = dict() + is_modified = dict() + var_name = dict() + work_item_name = dict() + self._var_counter = 0 + + def store_matrix(var, node, is_result): + if var not in var_name: + if var.is_temporary and is_result: + var_name[res] = 'tmp' + else: + input_matrices[var] = node.memoryLayout() + is_constant[var] = node.tensor.is_compute_constant() if isinstance(node, IndexedTensor) else False + base_name = str(var) + if not base_name.startswith('_'): + base_name = Tensor.getBaseName(base_name) + name = base_name + counter = 1 + while name in var_name.values(): + name = f'{base_name}{counter}' + counter = counter + 1 + var_name[var] = name + if is_result: + is_modified[var] = True + + def batch_type(var): + ml = input_matrices[var] + if is_constant[var]: + return f'{memref_type(ml)}' + elif var.is_temporary: + return f'{batch_memref_type(ml)}' + else: + return f'group<{memref_type(ml)}, offset: ?>' + + + memref_type = lambda ml: f'memref<{self._scalar_type}x{ml.bboxi(0).size()}x{ml.bboxi(1).size()},strided<1,{ml.stridei(1)}>>' + batch_memref_type = lambda ml: f'memref<{self._scalar_type}x{ml.bboxi(0).size()}x{ml.bboxi(1).size()}x?,strided<1,{ml.stridei(1)},{ml.requiredReals()}>>' + + for item in self._descr: 
+ node, args, _, _ = item + res, op1, op2 = args + store_matrix(res, node, True) + store_matrix(op1, node.leftTerm(), False) + store_matrix(op2, node.rightTerm(), False) + + args = [f'%{var_name[key]}: {batch_type(key)}' for key in input_matrices.keys()] + args_str = ',\n '.join(args) + source = f'func @fused_gemm({args_str}) {{\n' + + source += f' %{self.next_var()} = group_id\n' + gid = self._var_counter-1 + for key in input_matrices.keys(): + if not is_constant[key]: + new_var = self.next_var() + if key.is_temporary: + source += f' %{new_var} = load %{var_name[key]}[:,:,%{gid}] : {batch_type(key)}\n' + else: + source += f' %{new_var} = load %{var_name[key]}[%{gid}] : {batch_type(key)}\n' + work_item_name[key] = str(new_var) + + flops = 0 + for item in self._descr: + node, args, add, scalar = item + res, op1, op2 = args + + if res.is_temporary: + var_name[res] = f'tmp{self.next_var()}' + source += f' %{var_name[res]} = alloca -> {memref_type(node.memoryLayout())}\n' + + bbA = BoundingBox.fromSpp(node.leftTerm().eqspp()) + bbB = BoundingBox.fromSpp(node.rightTerm().eqspp()) + bbC = BoundingBox.fromSpp(node.eqspp()) + + k_op1 = 0 if node.transA() else 1 + k_op2 = 1 if node.transB() else 0 + k = bbA[k_op1] & bbB[k_op2] + m = bbA[1 - k_op1] + n = bbB[1 - k_op2] + + if not node.transA() and node.leftTerm().memoryLayout().alignedStride() and node.memoryLayout().alignedStride(): + m = m.aligned(self._arch) + + slic = lambda r, i: f'{i.start-r.start}:{i.stop-i.start}' + name = lambda var: work_item_name[var] if var in work_item_name else var_name[var] + subview = lambda var, ml, range1, range2: (f' %{self.next_var()} = subview %{name(var)}[{slic(ml.bboxi(0), range1)},{slic(ml.bboxi(1), range2)}] : {memref_type(ml)}\n', f'memref<{self._scalar_type}x{range1.stop-range1.start}x{range2.stop-range2.start},strided<1,{ml.stridei(1)}>>') + trans = lambda t: 't' if t else 'n' + + op1_sub, op1_sub_ty = subview(op1, node.leftTerm().memoryLayout(), m, k) + op2_sub, op2_sub_ty = subview(op2, node.rightTerm().memoryLayout(), k, n) + res_sub, res_sub_ty = subview(res, node.memoryLayout(), m, n) + source += op1_sub + op2_sub + res_sub + source += f' gemm.{trans(node.transA())}.{trans(node.transB())} {scalar}, %{self._var_counter-3}, %{self._var_counter-2}, {1.0 if add else 0.0}, %{self._var_counter-1} : {self._scalar_type}, {op1_sub_ty}, {op2_sub_ty}, {self._scalar_type}, {res_sub_ty}\n'; + + flops += 2 * m.size() * n.size() * k.size() + + source += '}\n' + + make_kernel = """ struct custom_kernel { ::sycl::kernel kernel; ::sycl::range<3u> group_size; }; + static auto k = [&] (::sycl::queue const& queue) -> custom_kernel { + static const std::string source = R\"tinytc( +""" + make_kernel += source + make_kernel += """)tinytc\"; + auto source_ctx = tinytc::make_source_context(); + try { + auto program = tinytc::parse_string(source, source_ctx); + auto info = tinytc::make_core_info(queue.get_device()); + auto binary = tinytc::compile_to_binary(program, info, tinytc::bundle_format::native, source_ctx); + auto bundle = tinytc::make_kernel_bundle(queue.get_context(), queue.get_device(), binary); + auto kernel = tinytc::make_kernel(bundle, "fused_gemm"); + auto group_size = tinytc::get_group_size(kernel); + return {std::move(kernel), std::move(group_size)}; + } catch (tinytc::status const& st) { + throw std::runtime_error(source_ctx.get_error_log()); + } + }"""; + make_kernel += f'(*static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME}));\n' + + def wrapper_type(key): + ptr2ptr = '*' if not 
is_constant[key] and not key.is_temporary else '' + const = ' const' if key not in is_modified and not key.is_temporary else '' + return f'{self._arch.typename}{const}*{ptr2ptr}' + + hasher = hashlib.sha512() + hasher.update(make_kernel.encode('utf-8')) + wrapper_name = f'tinytc_wrapper_{hasher.hexdigest()}' + wrapper_args = [f'unsigned {BatchedOperationsAux.NUM_ELEMENTS_NAME}', f'void* {BatchedOperationsAux.STREAM_PTR_NAME}'] + wrapper_call_args = [] + call_args = [] + for key in input_matrices.keys(): + ptr2ptr = '*' if not is_constant[key] and not key.is_temporary else '' + const = ' const' if key not in is_modified and not key.is_temporary else '' + wrapper_args += [f'{wrapper_type(key)} {var_name[key]}'] + wrapper_call_args += [var_name[key]] + call_args += [f'const_cast<{wrapper_type(key)}>({str(key)})'] + if key.is_temporary: + wrapper_call_args.append(BatchedOperationsAux.NUM_ELEMENTS_NAME) + elif not is_constant[key]: + offset_name = f'{BatchedOperationsAux.EXTRA_OFFSET_NAME}_{var_name[key]}' + wrapper_args.append(f'int {offset_name}') + wrapper_call_args.append(offset_name) + call_args.append(f'{BatchedOperationsAux.EXTRA_OFFSET_NAME}_{key}') + wrapper_call_args = ', '.join(wrapper_call_args) + call_args = ', '.join(call_args) + wrapper_signature = f'void {wrapper_name}({", ".join(wrapper_args)});' + wrapper = f'{wrapper_signature[:-1]} {{\n' + wrapper += make_kernel + wrapper += f' static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME})->submit([&](::sycl::handler &h) {{\n'; + wrapper += f' h.set_args({wrapper_call_args});\n' + wrapper += f' h.parallel_for(::sycl::nd_range{{tinytc::get_global_size({BatchedOperationsAux.NUM_ELEMENTS_NAME}, k.group_size), k.group_size}}, k.kernel);\n' + wrapper += ' }).wait();\n' + wrapper += '}\n\n' + + cpp(f'{wrapper_name}({BatchedOperationsAux.NUM_ELEMENTS_NAME}, {BatchedOperationsAux.STREAM_PTR_NAME}, {call_args});') + + routineCache.addRoutine(wrapper_signature, LibsmmWriter(wrapper_signature, wrapper)) + + return flops + +class LibsmmWriter(GpuRoutineGenerator): + def __init__(self, signature, source): + self._source = source + self._signature = signature + + def __eq__(self, other): + return self._signature == other._signature + + def header(self, cpp): + cpp.include('tinytc/tinytc.hpp') + cpp.include('tinytc/tinytc_sycl.hpp') + cpp.includeSys('sycl/sycl.hpp') + cpp.includeSys('stdexcept') + cpp.includeSys('utility') + + def __call__(self, routineName, fileName): + with open(fileName, 'a') as f: + f.write(self._source) + + return self._signature diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 7d3ed5b..14180cf 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -2,9 +2,10 @@ import subprocess import tempfile from abc import ABC +from collections import namedtuple from ..cache import RoutineGenerator, GpuRoutineGenerator -from ...gemm_configuration import BLASlike, CodeGenerator, GemmForge, libsmm +from ...gemm_configuration import BLASlike, CodeGenerator, GemmForge, tinytc from ..common import BatchedOperationsAux import importlib.util @@ -156,7 +157,7 @@ def generate(self, cpp, routineCache): else: raise RuntimeError('gemmforge module is not found. You can install it with pip3. 
' 'e.g., pip3 install gemmforge')
-        elif isinstance(self._gemm_cfg, libsmm):
+        elif isinstance(self._gemm_cfg, tinytc):
             aux = BatchedOperationsAux(self._arch.typename)
             gemm = {
                 'M': m.size(),
@@ -176,7 +177,7 @@ def generate(self, cpp, routineCache):
                 'transA': d.transA,
                 'transB': d.transB,
             }
-            routine_name = 'libsmm_wrapper_m{M}_n{N}_k{K}_ldA{LDA}_{addrA}_{distA}_ldB{LDB}_{addrB}_{distB}_ldC{LDC}_{addrC}_{distC}_alpha{alpha}_beta{beta}_transA{transA}_transB{transB}'.format(**gemm)
+            routine_name = 'tinytc_wrapper_m{M}_n{N}_k{K}_ldA{LDA}_{addrA}_{distA}_ldB{LDB}_{addrB}_{distB}_ldC{LDC}_{addrC}_{distC}_alpha{alpha}_beta{beta}_transA{transA}_transB{transB}'.format(**gemm)
 
             offset_a = self._offset(term=d.leftTerm, offset2=(m.start, k.start), transpose=d.transA)
@@ -194,7 +195,7 @@ def generate(self, cpp, routineCache):
 
             cpp(f'{routine_name}({args});')
 
-            routineCache.addRoutine(routine_name, LibsmmGemmGen(self._arch, gemm))
+            routineCache.addRoutine(routine_name, TinytcGemmGen(self._arch, gemm))
         else:
             gemm = {
                 'M': m.size(),
@@ -499,7 +500,7 @@ def __call__(self, routineName, fileName):
         file.write(self._call(routineName))
         return func_signature + ";"
 
-class LibsmmGemmGen(GpuRoutineGenerator):
+class TinytcGemmGen(GpuRoutineGenerator):
     def __init__(self, arch, gemm_descr):
         self.arch = arch
         self.gemm_descr = gemm_descr
@@ -508,9 +509,11 @@ def __eq__(self, other):
         return self.arch == other.arch and self.gemm_descr == other.gemm_descr
 
     def header(self, cpp):
-        cpp.include('smm/configuration.hpp')
-        cpp.include('smm/gemm.hpp')
-        cpp.includeSys('CL/sycl.hpp')
+        cpp.include('tinytc/tinytc.hpp')
+        cpp.include('tinytc/tinytc_sycl.hpp')
+        cpp.includeSys('sycl/sycl.hpp')
+        cpp.includeSys('stdexcept')
+        cpp.includeSys('utility')
 
     def _functionSignature(self, routineName):
         typ = self.arch.typename
@@ -520,30 +523,99 @@ def _functionSignature(self, routineName):
         starsC = stars(self.gemm_descr['addrC'])
         return f'void {routineName}({typ} const{starsA} A, int offsetA, {typ} const{starsB} B, int offsetB, {typ}{starsC} C, int offsetC, unsigned {BatchedOperationsAux.NUM_ELEMENTS_NAME}, void* {BatchedOperationsAux.STREAM_PTR_NAME})'
 
-    def address_mode(self, addr, dist):
+    def memref_type(self, addr, M, N, stride, dist):
         if addr == 'pointer_based':
-            return 'smm::pointers{}'
+            return f'group<memref<{M}x{N},strided<1,{stride}>>, offset: ?>'
         elif addr == 'none':
-            return 'smm::strided{0}'
+            return f'memref<{M}x{N},strided<1,{stride}>>'
         elif addr == 'strided':
-            return f'smm::strided{{{dist}}}'
+            return f'memref<{M}x{N}x?,strided<1,{stride},{dist}>>'
         raise NameError(addr)
 
     def __call__(self, routineName, fileName):
         func_signature = self._functionSignature(routineName)
         with open(fileName, "a") as f:
-            aA = self.address_mode(self.gemm_descr['addrA'], self.gemm_descr['distA'])
-            aB = self.address_mode(self.gemm_descr['addrB'], self.gemm_descr['distB'])
-            aC = self.address_mode(self.gemm_descr['addrC'], self.gemm_descr['distC'])
-
-            T = lambda x: 'smm::transpose::T' if x else 'smm::transpose::N'
-            tA = T(self.gemm_descr['transA'])
-            tB = T(self.gemm_descr['transB'])
-            alpha = self.gemm_descr['alpha']
-            beta = self.gemm_descr['beta']
+            scalar_ty = 'f64' if self.arch.bytesPerReal == 8 else 'f32'
+            gd = self.gemm_descr
+
+            Operand = namedtuple('Operand', ['name', 'addr', 'rows', 'cols', 'ld', 'dist'])
+            def data_type(op):
+                if op.addr == 'pointer_based':
+                    return f'group<{mat_type(op)}, offset: ?>'
+                elif op.addr == 'strided':
+                    return f'memref<{scalar_ty}x{op.rows}x{op.cols}x?,strided<1,{op.ld},{op.dist}>>'
+                elif op.addr == 'none':
+                    return f'memref<{scalar_ty}x{op.rows}x{op.cols},strided<1,{op.ld}>>'
+                else:
+                    raise NameError(op.addr)
+            def load_inst(op):
+                if op.addr == 'pointer_based':
+                    return f'load %{op.name}[%gid] : {data_type(op)}'
+                elif op.addr == 'strided':
+                    return f'load %{op.name}[:,:,%gid] : {data_type(op)}'
+                elif op.addr == 'none':
+                    return f'load %{op.name}[:,:] : {data_type(op)}'
+                else:
+                    raise NameError(op.addr)
+            def mat_type(op):
+                return f'memref<{scalar_ty}x{op.rows}x{op.cols},strided<1,{op.ld}>>'
+            def call_args(op):
+                if op.addr == 'pointer_based':
+                    return [op.name, f'offset{op.name}']
+                elif op.addr == 'strided':
+                    return [op.name, f'{BatchedOperationsAux.NUM_ELEMENTS_NAME}']
+                elif op.addr == 'none':
+                    return [op.name]
+                else:
+                    raise NameError(op.addr)
+
+            A = Operand('A', gd['addrA'], gd['M'], gd['K'], gd['LDA'], gd['distA'])
+            B = Operand('B', gd['addrB'], gd['K'], gd['N'], gd['LDB'], gd['distB'])
+            C = Operand('C', gd['addrC'], gd['M'], gd['N'], gd['LDC'], gd['distC'])
+            ops = [A, B, C]
+
+            T = lambda x: 't' if x else 'n'
+            tA = T(gd['transA'])
+            tB = T(gd['transB'])
+            alpha = gd['alpha']
+            beta = gd['beta']
 
             f.write(f'{func_signature} {{\n')
-            f.write('  static auto kernel = smm::make_gemm<{typ}>({tA}, {tB}, {M}, {N}, {K}, {LDA}, {aA}, {LDB}, {aB}, {LDC}, {aC}, *static_cast<::sycl::queue*>({stream}));\n'.format(typ=self.arch.typename, aA=aA, aB=aB, aC=aC, tA=tA, tB=tB, stream=BatchedOperationsAux.STREAM_PTR_NAME, **self.gemm_descr))
-            f.write(f'  kernel({alpha}, A, B, {beta}, C, {BatchedOperationsAux.NUM_ELEMENTS_NAME}).wait();\n')
-            f.write('}\n')
+            f.write("""  struct custom_kernel { ::sycl::kernel kernel; ::sycl::range<3u> group_size; };
+  static auto k = [&](::sycl::queue const& queue) -> custom_kernel {
+    static const std::string source = R\"tinytc(
+func @gemm(""")
+            f.write(', '.join([f'%{op.name}: {data_type(op)}' for op in ops]))
+            f.write(""") {
+%gid = group_id
+""")
+            for op in ops:
+                f.write(f'%{op.name.lower()} = {load_inst(op)}\n')
+            f.write(f'gemm.{tA}.{tB} {alpha}, %a, %b, {beta}, %c : {scalar_ty}, {mat_type(A)}, {mat_type(B)}, {scalar_ty}, {mat_type(C)}\n')
+            f.write("""})tinytc\";
+    auto source_ctx = tinytc::make_source_context();
+    try {
+      auto program = tinytc::parse_string(source, source_ctx);
+      auto info = tinytc::make_core_info(queue.get_device());
+      auto binary = tinytc::compile_to_binary(program, info, tinytc::bundle_format::native, source_ctx);
+      auto bundle = tinytc::make_kernel_bundle(queue.get_context(), queue.get_device(), binary);
+      auto kernel = tinytc::make_kernel(bundle, "gemm");
+      auto group_size = tinytc::get_group_size(kernel);
+      return {std::move(kernel), std::move(group_size)};
+    } catch (tinytc::status const& st) {
+      throw std::runtime_error(source_ctx.get_error_log());
+    }
+  }""")
+            f.write(f'(*static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME}));\n')
+            args = []
+            for op in ops:
+                args += call_args(op)
+            args_str = ', '.join(args)
+            f.write(f"""  static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME})->submit([&](::sycl::handler &h) {{
+    h.set_args({args_str});
+    h.parallel_for(::sycl::nd_range{{tinytc::get_global_size({BatchedOperationsAux.NUM_ELEMENTS_NAME}, k.group_size), k.group_size}}, k.kernel);
+  }}).wait();
+}}
+""")
 
         return func_signature + ";"
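The wrapper emitted by TinytcGemmGen embeds one complete tinytc program per GEMM configuration. For orientation, the source assembled for a hypothetical f32 kernel (M=20, N=9, K=20, LDA=LDB=24, LDC=20, strided A/B with batch strides 480 and 216, pointer-based C) would read roughly as below; all sizes are made-up example values, the line layout is simplified, and the group type for C follows the form the fused generator uses:

    # Illustration only -- not part of the patch. The string mirrors what
    # data_type/load_inst/mat_type above produce for the assumed descriptor.
    example_source = r"""
    func @gemm(%A: memref<f32x20x20x?,strided<1,24,480>>,
               %B: memref<f32x20x9x?,strided<1,24,216>>,
               %C: group<memref<f32x20x9,strided<1,20>>, offset: ?>) {
    %gid = group_id
    %a = load %A[:,:,%gid] : memref<f32x20x20x?,strided<1,24,480>>
    %b = load %B[:,:,%gid] : memref<f32x20x9x?,strided<1,24,216>>
    %c = load %C[%gid] : group<memref<f32x20x9,strided<1,20>>, offset: ?>
    gemm.n.n 1.0, %a, %b, 0.0, %c : f32, memref<f32x20x20,strided<1,24>>, memref<f32x20x9,strided<1,24>>, f32, memref<f32x20x9,strided<1,20>>
    }"""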
diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py
index cafc0f7..aeb7c79 100644
--- a/yateto/gemm_configuration.py
+++ b/yateto/gemm_configuration.py
@@ -258,7 +258,7 @@ def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, ali
       return Preference.LOWEST
     return Preference.HIGH
 
-class libsmm(CodeGenerator):
+class tinytc(CodeGenerator):
   def __init__(self, arch):
     super().__init__('', [], '', arch)
     self._arch = arch
@@ -336,7 +336,8 @@ def __init__(self, arch):
       self.gemmTools = defaults[arch.host_name]
     if arch.is_accelerator:
       if arch.backend == 'oneapi':
-        self.gemmTools.extend([libsmm(arch)])
-      self.gemmTools.extend([forge])
+        self.gemmTools.extend([tinytc(arch)])
+      else:
+        self.gemmTools.extend([forge])
     else:
       raise Exception("Default generator collection for architecture {} is missing.".format(arch))
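With DefaultGeneratorCollection now preferring tinytc on the oneapi backend, a project can also select the tool explicitly. A minimal sketch, assuming an oneapi architecture object arch and a yateto Generator instance g (the generate keyword arguments are assumptions about the surrounding API, not part of this series):

    # Sketch only: route all batched GEMMs through the Tiny Tensor Compiler.
    # tinytc.supported() already limits the tool to dense operands,
    # alpha == 1.0, beta in {0.0, 1.0} and the 'gpu' target.
    from yateto.gemm_configuration import GeneratorCollection, tinytc

    gemm_cfg = GeneratorCollection([tinytc(arch)])  # requires arch.backend == 'oneapi'
    g.generate(outputDir='generated', gemm_cfg=gemm_cfg)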
From e5dd64bde973bece9b2cb0be1135956a431c4e0d Mon Sep 17 00:00:00 2001
From: Carsten Uphoff
Date: Tue, 30 Apr 2024 01:52:45 -0700
Subject: [PATCH 7/8] Update fused gemm generator selection

Signed-off-by: Carsten Uphoff
---
 yateto/codegen/factory.py             |  2 +-
 yateto/codegen/fused_gemms/factory.py | 16 +++++++++-------
 yateto/generator.py                   | 20 +++++++++++++-------
 3 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py
index 7065dfe..85c4b8a 100644
--- a/yateto/codegen/factory.py
+++ b/yateto/codegen/factory.py
@@ -110,7 +110,7 @@ def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName
 
   def create_FusedGEMMs(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
     description = fused_gemms.Description(node, result, arguments, add, scalar)
-    generator = fused_gemms.generator(self._arch, description, self._target)
+    generator = fused_gemms.generator(self._arch, description, gemm_cfg, self._target)
     return generator.generate(self._cpp, routineCache, gemm_cfg)
 
   def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
diff --git a/yateto/codegen/fused_gemms/factory.py b/yateto/codegen/fused_gemms/factory.py
index 0beb0eb..c2f2c8c 100644
--- a/yateto/codegen/fused_gemms/factory.py
+++ b/yateto/codegen/fused_gemms/factory.py
@@ -15,6 +15,7 @@
   raise ('Found chainforge spec but cannot load. Please, check installation of chainforge')
 
 from .tinytc import FusedGemmsTinytc
+from ...gemm_configuration import tinytc
 
 
 class Description(object):
@@ -43,10 +44,11 @@ def __next__(self):
     raise StopIteration
 
 
-def generator(arch, descr, target):
-  if target == 'gpu' and gb_spec:
-    return FusedGemms(arch, descr)
-  elif target == 'gpu':
-    return FusedGemmsTinytc(arch, descr)
-  else:
-    raise NotImplementedError(f'no implementation found for {target} target')
+def generator(arch, descr, gemm_cfg, target):
+  if target == 'gpu':
+    hasTinytc = any([isinstance(tool, tinytc) for tool in gemm_cfg.gemmTools])
+    if hasTinytc:
+      return FusedGemmsTinytc(arch, descr)
+    elif gb_spec:
+      return FusedGemms(arch, descr)
+  raise NotImplementedError(f'no implementation found for {target} target')
diff --git a/yateto/generator.py b/yateto/generator.py
index ac1d655..9cffea9 100644
--- a/yateto/generator.py
+++ b/yateto/generator.py
@@ -14,7 +14,7 @@
 from .codegen.visitor import *
 from .controlflow.visitor import AST2ControlFlow
 from .controlflow.transformer import *
-from .gemm_configuration import GeneratorCollection, DefaultGeneratorCollection, BLASlike
+from .gemm_configuration import GeneratorCollection, DefaultGeneratorCollection, BLASlike, tinytc
 from typing import List
 from io import StringIO
 import importlib.util
@@ -65,7 +65,7 @@ def prepareUntilUnitTest(self):
     self.cfg = ast2cf.cfg()
     self.cfg = LivenessAnalysis().visit(self.cfg)
 
-  def prepareUntilCodeGen(self, cost_estimator):
+  def prepareUntilCodeGen(self, cost_estimator, enableFusedGemm: bool):
     self.nonZeroFlops = 0
     for a in self.ast:
       ast = copy.deepcopy(a)
@@ -102,7 +102,7 @@ def prepareUntilCodeGen(self, cost_estimator):
     self.cfg = SubstituteBackward().visit(self.cfg)
     self.cfg = RemoveEmptyStatements().visit(self.cfg)
     self.cfg = MergeActions().visit(self.cfg)
-    if self.target == 'gpu':# and chainforge_spec:
+    if self.target == 'gpu' and enableFusedGemm:
       self.cfg = FindFusedGemms().visit(self.cfg)
     self.cfg = LivenessAnalysis().visit(self.cfg)
@@ -176,9 +176,9 @@ def prepareUntilUnitTest(self):
     for kernel in self._kernels.values():
       kernel.prepareUntilUnitTest()
 
-  def prepareUntilCodeGen(self, costEstimator):
+  def prepareUntilCodeGen(self, costEstimator, enableFusedGemm: bool):
     for kernel in self._kernels.values():
-      kernel.prepareUntilCodeGen(costEstimator)
+      kernel.prepareUntilCodeGen(costEstimator, enableFusedGemm)
 
 def simpleParameterSpace(*args):
   return list(itertools.product(*[list(range(i)) for i in args]))
@@ -268,6 +268,12 @@ def generate(self,
     if not gemm_cfg:
       gemm_cfg = DefaultGeneratorCollection(self._arch)
 
+    for tool in gemm_cfg.gemmTools:
+      print(tool, isinstance(tool, tinytc))
+
+    hasTinytc = any([isinstance(tool, tinytc) for tool in gemm_cfg.gemmTools])
+    enableFusedGemm = bool(chainforge_spec) or hasTinytc
+
     print('Deducing indices...')
     for kernel in self._kernels:
       kernel.prepareUntilUnitTest()
@@ -299,10 +305,10 @@ def unit_test_body(cpp, testFramework):
     print('Optimizing ASTs...')
     for kernel in self._kernels:
       print(f'{kernel.name} ({len(kernel.ast)} AST(s))')
-      kernel.prepareUntilCodeGen(cost_estimator)
+      kernel.prepareUntilCodeGen(cost_estimator, enableFusedGemm)
     for family in self._kernelFamilies.values():
       print(f'{family.name} ({sum(len(kernel.ast) for kernel in family.kernels())} AST(s))')
-      family.prepareUntilCodeGen(cost_estimator)
+      family.prepareUntilCodeGen(cost_estimator, enableFusedGemm)
 
     # Create mapping from namespace to kernel/family
     kernel_dict = {}
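Across the series, each operand is classified once by BatchedOperationsAux.deduce_addresing (PATCH 2), and both the plain and the fused tinytc generators (PATCH 6) translate that classification into a batch type. A compact summary for a hypothetical f64 operand of shape 20x9 with leading dimension 24 and batch stride 504 (example values only):

    # Illustration only: operand kind -> (addressing, tinytc batch type).
    # Compute constants are shared by all work items, temporaries are laid
    # out back-to-back with a fixed batch stride, and regular user tensors
    # arrive as pointer arrays with a per-kernel offset.
    batch_types = {
        'compute constant': ('none',          'memref<f64x20x9,strided<1,24>>'),
        'temporary':        ('strided',       'memref<f64x20x9x?,strided<1,24,504>>'),
        'user tensor':      ('pointer_based', 'group<memref<f64x20x9,strided<1,24>>, offset: ?>'),
    }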
From 4f248e7690f5847913c580757210b0715dacbdbd Mon Sep 17 00:00:00 2001
From: Carsten Uphoff
Date: Tue, 30 Apr 2024 01:57:06 -0700
Subject: [PATCH 8/8] pvc missing in gemmforge

Signed-off-by: Carsten Uphoff
---
 yateto/arch.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/yateto/arch.py b/yateto/arch.py
index f1d4c1b..12a6e25 100644
--- a/yateto/arch.py
+++ b/yateto/arch.py
@@ -157,8 +157,6 @@ def getHeterogeneousArchitectureIdentifiedBy(host_arch, device_arch, device_back
     alignment = 128
   elif device_arch in ['dg1', 'bdw', 'skl', 'Gen8', 'Gen9', 'Gen11', 'Gen12LP']:
     alignment = 32
-  elif device_arch in ['pvc']:
-    alignment = 64
   else:
     raise ValueError(f'Unknown device arch: {device_arch}')