From 18ef6f8ef62308fdcc0072b5e7b898283f9c8c4e Mon Sep 17 00:00:00 2001 From: David Schneller Date: Sun, 7 Apr 2024 04:12:10 +0200 Subject: [PATCH] Refine array lengths one more time --- yateto/codegen/gpukernel.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/yateto/codegen/gpukernel.py b/yateto/codegen/gpukernel.py index 6c109d9..bed4535 100644 --- a/yateto/codegen/gpukernel.py +++ b/yateto/codegen/gpukernel.py @@ -3,7 +3,7 @@ from common import * from .common import TensorDescription, IndexedTensorDescription, BatchedOperationsAux -from ..ast.indices import BoundingBox +from ..ast.indices import BoundingBox, Range from ..type import Scalar from .cache import RoutineGenerator, GpuRoutineGenerator from kernelforge.interface import YatetoInterface as yi @@ -47,6 +47,7 @@ def generate(self, cpp, routineCache): routineCache.addRoutine(routine_name, KernelForgeWriter(kernelforge_generator, context.get_vm().get_headers())) def _can_be_aligned(self, dest, ops, target, permute): + # TODO: useful? aligned = dest.memoryLayout.alignedStride() for i, op in enumerate(ops): if 0 in target[i]: @@ -62,16 +63,15 @@ def make_tensor(op, dims): entry = self._add_scalar(op) entry_name = op.name() else: - currentRangePre = BoundingBox.fromSpp(op.eqspp) - currentRange = list(currentRangePre) - currentShape = list(op.memoryLayout.shape()) + # TODO: refine + currentPreShape = list(BoundingBox.fromSpp(op.eqspp)) if can_be_aligned: for i, dim in enumerate(dims): - if dim == 0: - currentRange[i] = currentRange[i].aligned(self._arch) - - # unstable/incorrect? TODO: check (for now, it should work) - currentShape[i] = max(self._arch.alignedUpper(currentShape[i]), currentRange[i].stop) + if i == 0 and op.memoryLayout.alignedStride(): # previously: dim == 0 + currentPreShape[i] = currentPreShape[i].aligned(self._arch) + currentShape = [b.stop for b in currentPreShape] + currentRange = list(BoundingBox(Range(0, b) for b in currentShape)) + entry = self._get_kernelforge_matrix(tensor=op, tensor_variable=op, shape=currentShape,