Skip to content

Commit

Permalink
Merge pull request #61 from SeisSol/davschneller/more-cpu-archs
Browse files Browse the repository at this point in the history
Add Some More CPU Architectures
  • Loading branch information
davschneller authored Sep 21, 2023
2 parents 274b27b + ba01a07 commit 0770647
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
14 changes: 12 additions & 2 deletions yateto/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _get_name_and_precision(ident):
def getArchitectureIdentifiedBy(ident):
name, precision = _get_name_and_precision(ident)

# NOTE: ibxsmm currently supports prefetch only for KNL kernels
# NOTE: libxsmm currently supports prefetch only for KNL kernels
arch = {
'noarch': Architecture(name, precision, 16, False),
'wsm': Architecture(name, precision, 16, False),
Expand All @@ -124,10 +124,20 @@ def getArchitectureIdentifiedBy(ident):
'skx': Architecture(name, precision, 64, True),
'knc': Architecture(name, precision, 64, False),
'knl': Architecture(name, precision, 64, True),
'naples': Architecture(name, precision, 32, False),
'rome': Architecture(name, precision, 32, False),
'milan': Architecture(name, precision, 32, False),
'bergamo': Architecture(name, precision, 64, True),
'thunderx2t99': Architecture(name, precision, 16, False),
'a64fx': Architecture(name, precision, 64, True),
'power9': Architecture(name, precision, 16, False)
'neon': Architecture(name, precision, 16, False),
'apple-m1': Architecture(name, precision, 16, False),
'sve128': Architecture(name, precision, 16, False),
'sve256': Architecture(name, precision, 32, False),
'sve512': Architecture(name, precision, 64, False),
'sve1024': Architecture(name, precision, 128, False),
'sve2048': Architecture(name, precision, 256, False),
'power9': Architecture(name, precision, 16, False),
}
return arch[name]

Expand Down
17 changes: 15 additions & 2 deletions yateto/codegen/gemm/gemmgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,17 @@ def __call__(self, routineName, fileName):
if self._mode == 'pspamm':
pspamm_arch = cpu_arch
if cpu_arch == 'a64fx':
pspamm_arch = 'arm_sve'
pspamm_arch = 'arm_sve512'
elif cpu_arch in ['apple-m1', 'thunderx2t99', 'neon']:
pspamm_arch = 'arm'
elif cpu_arch.startswith('sve'):
pspamm_arch = f'arm_{cpu_arch}' # TODO(David): rename to sveLEN only
elif cpu_arch in ['naples', 'rome', 'milan']:
# names are Zen1, Zen2, Zen3, respectively
# no explicit support for these archs yet, but they have the same instruction sets (AVX2+FMA3) that HSW also needs
pspamm_arch = 'hsw'
elif cpu_arch in ['bergamo']:
pspamm_arch = 'skx'
argList = [
self._cmd,
self._gemmDescr['M'],
Expand All @@ -268,6 +274,13 @@ def __call__(self, routineName, fileName):
for key, val in self._blockSize.items():
argList.extend(['--' + key, val])
else:
libxsmm_arch = cpu_arch
if cpu_arch in ['naples', 'rome', 'milan']:
# names are Zen1, Zen2, Zen3, respectively
# no explicit support for these archs yet, but they have the same instruction sets (AVX2+FMA3) that HSW also needs
libxsmm_arch = 'hsw'
elif cpu_arch in ['bergamo']:
libxsmm_arch = 'skx'
argList = [
self._cmd,
'dense',
Expand All @@ -283,7 +296,7 @@ def __call__(self, routineName, fileName):
self._gemmDescr['beta'],
self._gemmDescr['alignedA'],
self._gemmDescr['alignedC'],
'hsw' if cpu_arch == 'rome' else cpu_arch, # libxsmm has no support for rome, hsw works well in practice
libxsmm_arch, # libxsmm has no support for rome, hsw works well in practice
self._gemmDescr['prefetch'],
self._arch.precision + 'P'
]
Expand Down
19 changes: 14 additions & 5 deletions yateto/gemm_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, ali
return Preference.LOW

def _archSupported(self):
supported_set = {'noarch', 'wsm', 'snb', 'hsw', 'skx', 'knc', 'knl', 'rome', "a64fx", "thunderx2t99"}
supported_set = {'noarch', 'wsm', 'snb', 'hsw', 'skx', 'knc', 'knl', 'naples', 'rome', 'milan', 'bergamo', "a64fx", "thunderx2t99", 'neon', 'sve128', 'sve256', 'sve512', 'apple-m1'}

if self._arch.name.lower() in supported_set:
return True
Expand All @@ -184,7 +184,7 @@ def __init__(self, arch, cmd: str = 'libxsmm_gemm_generator', threshold: int = 1
self._threshold = threshold

def _archSupported(self):
supported_set = {'noarch', 'wsm', 'snb', 'hsw', 'skx', 'knc', 'knl', 'rome'}
supported_set = {'noarch', 'wsm', 'snb', 'hsw', 'skx', 'knc', 'knl', 'naples', 'rome', 'milan', 'bergamo'}

if self._arch.name.lower() in supported_set:
return True
Expand All @@ -210,7 +210,7 @@ def __init__(self, arch, cmd: str = 'pspamm.py', threshold: int = 128):
self._threshold = threshold

def _archSupported(self):
supported_set = {'thunderx2t99', 'knl', 'skx', 'a64fx'}
supported_set = {'thunderx2t99', 'knl', 'skx', 'a64fx', 'hsw', 'naples', 'rome', 'milan', 'bergamo', 'neon', 'sve128', 'sve256', 'sve512', 'sve1024', 'sve2048', 'apple-m1'}
if self._arch.name.lower() in supported_set:
return True
else:
Expand Down Expand Up @@ -295,12 +295,21 @@ def __init__(self, arch):
forge = GemmForge(arch)
defaults = {
'snb' : [libxsmm_jit, libxsmm, mkl, blis, eigen],
'hsw' : [libxsmm_jit, libxsmm, mkl, blis, eigen],
'rome' : [libxsmm_jit, libxsmm, blis, eigen],
'hsw' : [libxsmm_jit, libxsmm, pspamm, mkl, blis, eigen],
'naples' : [libxsmm_jit, libxsmm, pspamm, blis, eigen],
'rome' : [libxsmm_jit, libxsmm, pspamm, blis, eigen],
'milan' : [libxsmm_jit, libxsmm, pspamm, blis, eigen],
'bergamo' : [libxsmm_jit, libxsmm, pspamm, blis, eigen],
'knl' : [libxsmm_jit, libxsmm, pspamm, mkl, blis, eigen],
'skx' : [libxsmm_jit, libxsmm, pspamm, mkl, blis, eigen],
'thunderx2t99' : [libxsmm_jit, pspamm, openblas, blis, eigen],
'a64fx' : [libxsmm_jit, pspamm, openblas, blis, eigen],
'neon' : [libxsmm_jit, pspamm, openblas, blis, eigen],
'sve128' : [libxsmm_jit, pspamm, openblas, blis, eigen],
'sve256' : [libxsmm_jit, pspamm, openblas, blis, eigen],
'sve512' : [libxsmm_jit, pspamm, openblas, blis, eigen],
'sve1024' : [pspamm, openblas, blis, eigen],
'sve2048' : [pspamm, openblas, blis, eigen],
'power9' : [openblas, blis, eigen]
}

Expand Down

0 comments on commit 0770647

Please sign in to comment.