Refactor fused GEMMs for tinytc
Signed-off-by: Carsten Uphoff <[email protected]>
uphoffc committed Jun 26, 2024
1 parent 7870c71 commit a782c9c
Showing 5 changed files with 172 additions and 220 deletions.
29 changes: 28 additions & 1 deletion yateto/codegen/common.py
@@ -1,7 +1,7 @@
from .. import aspp
from ..ast.indices import BoundingBox
from ..ast.log import splitByDistance
from .tiny_tensor_language import Dump, Function
from .tiny_tensor_language import Dump, Function, IntegerType, MemrefType, GroupType, IntImmValue, DYNAMIC, SubviewInst, LoadInst
import hashlib


@@ -232,3 +232,30 @@ def call(self):
    def prototype(self):
        return f'void {self.name}({", ".join(self.wrapper_args)});'

def makeMemrefType(scalarTy, memoryLayout, needsBatchMode: bool):
    shape = tuple(r.size() for r in memoryLayout.bbox())
    stride = memoryLayout.stride()
    if needsBatchMode:
        shape = shape + (DYNAMIC, )
        stride = stride + (memoryLayout.requiredReals(), )
    return MemrefType(scalarTy, shape, stride)

def makeBatchType(scalarTy, memoryLayout, isComputeConstant: bool, isTemporary: bool):
    if isComputeConstant:
        return makeMemrefType(scalarTy, memoryLayout, False)
    elif isTemporary:
        return makeMemrefType(scalarTy, memoryLayout, True)
    else:
        return GroupType(makeMemrefType(scalarTy, memoryLayout, False), DYNAMIC)

def makeLoad(bb, operand, gid, isComputeConstant: bool, isTemporary: bool):
    if isComputeConstant:
        return operand
    elif isTemporary:
        offsetList = [IntImmValue(IntegerType.index, 0)] * (operand.type().order() - 1)
        sizeList = [IntImmValue(IntegerType.index, DYNAMIC)] * (operand.type().order() - 1)
        offsetList.append(gid)
        sizeList.append(None)
        return bb.add(SubviewInst(operand, offsetList, sizeList))
    else:
        return bb.add(LoadInst(operand, [gid]))
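
For orientation, here is a minimal usage sketch (not part of this commit) showing how a kernel generator is expected to wire the new helpers together; it mirrors the copyscaleadd changes further down. The function name build_kernel, the scalar_ty and descr parameters, and the absolute import paths are hypothetical; makeBatchType, makeLoad, LocalValue, Function, RegionBuilder, and GroupIdInst are taken from the diff, with the last four assumed to be exported by tiny_tensor_language, as the star import in copyscaleadd/tinytc.py suggests.

# Illustrative sketch only, not part of commit a782c9c; build_kernel, scalar_ty
# and descr are hypothetical stand-ins for a generator's self._ty and self._descr.
from yateto.codegen.common import makeBatchType, makeLoad
from yateto.codegen.tiny_tensor_language import (Function, GroupIdInst,
                                                 LocalValue, RegionBuilder)

def build_kernel(scalar_ty, descr):
    # Batch operands: a plain memref for compute constants, a memref with a
    # trailing dynamic mode for temporaries, otherwise a group of memrefs.
    Abatch = LocalValue(makeBatchType(scalar_ty, descr.term.memoryLayout,
                                      descr.term.is_compute_constant,
                                      descr.term.is_temporary), 'A')
    Bbatch = LocalValue(makeBatchType(scalar_ty, descr.result.memoryLayout,
                                      descr.result.is_compute_constant,
                                      descr.result.is_temporary), 'B')
    kernel = Function('example_kernel', [Abatch, Bbatch], None)

    # Per-work-item operand: passed through, subviewed at the group id, or
    # loaded from the group, depending on the descriptor flags.
    bb = RegionBuilder()
    gid = bb.add(GroupIdInst())
    A = makeLoad(bb, Abatch, gid, descr.term.is_compute_constant,
                 descr.term.is_temporary)
    B = makeLoad(bb, Bbatch, gid, descr.result.is_compute_constant,
                 descr.result.is_temporary)
    return kernel, bb, A, B

The design point of the refactor is that this memref/group plumbing now lives once in common.py instead of being duplicated in each kernel generator.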
49 changes: 13 additions & 36 deletions yateto/codegen/copyscaleadd/tinytc.py
@@ -1,7 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

from ..common import TensorDescription, IndexedTensorDescription, BatchedOperationsAux, TinytcKernelArgument, TinytcScalarKernelArgument, TinytcWrapper
from ..common import TensorDescription, IndexedTensorDescription, BatchedOperationsAux, TinytcKernelArgument, TinytcScalarKernelArgument, TinytcWrapper, makeMemrefType, makeBatchType, makeLoad
from ..cache import TinytcWriter
from ..tiny_tensor_language import *

@@ -19,37 +19,6 @@ def __init__(self, arch, descr):
    def generate(self, cpp, routineCache):
        d = self._descr

        def MakeMemrefType(ml, needsBatchMode):
            shape = tuple(r.size() for r in ml.bbox())
            stride = ml.stride()
            if needsBatchMode:
                shape = shape + (DYNAMIC, )
                stride = stride + (ml.requiredReals(), )
            return MemrefType(self._ty, shape, stride)

        def MakeBatchType(var):
            ml = var.memoryLayout
            if var.is_compute_constant:
                return MakeMemrefType(ml, False)
            elif var.is_temporary:
                return MakeMemrefType(ml, True)
            else:
                return GroupType(MakeMemrefType(ml, False), DYNAMIC)

        def MakeLoad(bb, var, operand, gid):
            if var.is_compute_constant:
                return operand
            elif var.is_temporary:
                offset_list = [IntImmValue(IntegerType.index, 0)
                               ] * (operand.type().order() - 1)
                size_list = [IntImmValue(IntegerType.index, DYNAMIC)
                             ] * (operand.type().order() - 1)
                offset_list.append(gid)
                size_list.append(None)
                return bb.add(SubviewInst(operand, offset_list, size_list))
            else:
                return bb.add(LoadInst(operand, [gid]))

        # Order can be 1 or 2
        def MakeLoopOverAxpby(d, order, transpose, A, B):
            beta = FloatImmValue(self._ty, d.beta)
@@ -93,14 +62,22 @@ def MakeLoopOverAxpby(d, order, transpose, A, B):
            return csa_region

        alpha = LocalValue(self._ty, 'alpha')
        Abatch = LocalValue(MakeBatchType(d.term), 'A')
        Bbatch = LocalValue(MakeBatchType(d.result), 'B')
        Abatch = LocalValue(
            makeBatchType(self._ty, d.term.memoryLayout,
                          d.term.is_compute_constant, d.term.is_temporary),
            'A')
        Bbatch = LocalValue(
            makeBatchType(self._ty, d.result.memoryLayout,
                          d.result.is_compute_constant, d.result.is_temporary),
            'B')
        kernel = Function('copyscaleadd', [alpha, Abatch, Bbatch], None)

        bb = RegionBuilder()
        gid = bb.add(GroupIdInst())
        A = MakeLoad(bb, d.term, Abatch, gid)
        B = MakeLoad(bb, d.result, Bbatch, gid)
        A = makeLoad(bb, Abatch, gid, d.term.is_compute_constant,
                     d.term.is_temporary)
        B = makeLoad(bb, Bbatch, gid, d.result.is_compute_constant,
                     d.result.is_temporary)

        trans = Transpose.n
        if len(d.result.indices) == 0:
(The remaining hunks of this file and the other three changed files are collapsed and not shown.)
