diff --git a/yateto/arch.py b/yateto/arch.py index 334af1b..12a6e25 100644 --- a/yateto/arch.py +++ b/yateto/arch.py @@ -132,6 +132,7 @@ def getArchitectureIdentifiedBy(ident): 'a64fx': Architecture(name, precision, 64, True), 'neon': Architecture(name, precision, 16, False), 'apple-m1': Architecture(name, precision, 16, False), + 'apple-m2': Architecture(name, precision, 16, False), 'sve128': Architecture(name, precision, 16, False), 'sve256': Architecture(name, precision, 32, False), 'sve512': Architecture(name, precision, 64, False), diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 04d6264..1537802 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -399,7 +399,7 @@ def _kernel(self, routine_name): #flags += ["LIBXSMM_GEMM_FLAG_ALIGN_C"] libxsmm_flag_str = " | ".join(flags) - prefetch_flag = "LIBXSMM_GEMM_PREFETCH_SIGONLY" if not self._arch.enablePrefetch else "LIBXSMM_GEMM_PREFETCH_BL2_VIA_C" + prefetch_flag = "LIBXSMM_GEMM_PREFETCH_NONE" if not self._arch.enablePrefetch else "LIBXSMM_GEMM_PREFETCH_BL2_VIA_C" kernel_var_name = f'{routine_name}_var' return """ diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py index d3fbd64..54dc1ed 100644 --- a/yateto/gemm_configuration.py +++ b/yateto/gemm_configuration.py @@ -162,7 +162,7 @@ def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, ali return Preference.LOW def _archSupported(self): - supported_set = {'noarch', 'wsm', 'snb', 'hsw', 'skx', 'knc', 'knl', 'naples', 'rome', 'milan', 'bergamo', "a64fx", "thunderx2t99", 'neon', 'sve128', 'sve256', 'sve512', 'apple-m1'} + supported_set = {'noarch', 'wsm', 'snb', 'hsw', 'skx', 'knc', 'knl', 'naples', 'rome', 'milan', 'bergamo', "a64fx", "thunderx2t99", 'neon', 'sve128', 'sve256', 'sve512', 'apple-m1', "apple-m2"} if self._arch.name.lower() in supported_set: return True @@ -210,7 +210,7 @@ def __init__(self, arch, cmd: str = 'pspamm.py', threshold: int = 128): self._threshold = threshold def _archSupported(self): - supported_set = {'thunderx2t99', 'knl', 'skx', 'a64fx', 'hsw', 'naples', 'rome', 'milan', 'bergamo', 'neon', 'sve128', 'sve256', 'sve512', 'sve1024', 'sve2048', 'apple-m1'} + supported_set = {'thunderx2t99', 'knl', 'skx', 'a64fx', 'hsw', 'naples', 'rome', 'milan', 'bergamo', 'neon', 'sve128', 'sve256', 'sve512', 'sve1024', 'sve2048', 'apple-m1', 'apple-m2'} if self._arch.name.lower() in supported_set: return True else: @@ -303,6 +303,8 @@ def __init__(self, arch): 'knl' : [libxsmm_jit, libxsmm, pspamm, mkl, blis, eigen], 'skx' : [libxsmm_jit, libxsmm, pspamm, mkl, blis, eigen], 'thunderx2t99' : [libxsmm_jit, pspamm, openblas, blis, eigen], + 'apple-m1' : [libxsmm_jit, pspamm, openblas, blis, eigen], + 'apple-m2' : [libxsmm_jit, pspamm, openblas, blis, eigen], 'a64fx' : [libxsmm_jit, pspamm, openblas, blis, eigen], 'neon' : [libxsmm_jit, pspamm, openblas, blis, eigen], 'sve128' : [libxsmm_jit, pspamm, openblas, blis, eigen], diff --git a/yateto/util.py b/yateto/util.py index 9a28d20..108ecf1 100644 --- a/yateto/util.py +++ b/yateto/util.py @@ -27,7 +27,7 @@ def create_collection(matrices): def tensor_from_constant_expression(name: str, expression, target_indices: Indices = None, - dtype: dtype = np.float128, + dtype: dtype = np.longdouble, tensor_args: dict = dict()): """ Computes the result of an expression and returns @@ -55,7 +55,7 @@ def tensor_collection_from_constant_expression(base_name: str, expressions, group_indices, target_indices: Indices = None, - dtype=np.float128, + dtype=np.longdouble, tensor_args: dict = {}): """ Computes the result of an expression group and returns