From 97142d9a84ea625b13ab74a76373bc9bbfe4c588 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Tue, 5 Sep 2023 12:06:25 +0200 Subject: [PATCH] Fix isnan assertions for scalar families --- yateto/codegen/visitor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index 933acd6..ee200ce 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -381,9 +381,13 @@ def generate_extra_offset_args(base_name_with_namespace, groups): continue with cpp.Function('{}::{}::{}'.format(self.NAMESPACE, name, executeName(index))): - sclrs = sorted(list(kernelOutline.scalars), key=str) - for scalar in sclrs: - cpp('assert(!std::isnan({}));'.format(scalar)) + for base_name_with_namespace, groups in kernelOutline.scalars: + base_name = Tensor.splitBasename(base_name_with_namespace)[-1] + if len(next(iter(groups))) > 0: + for gis in groups: + cpp('assert(!std::isnan({}({})));'.format(base_name, ','.join(str(gi) for gi in gis))) + else: + cpp(f'assert(!std::isnan({base_name}));') for base_name_with_namespace, groups in kernelOutline.tensors.items(): base_name = Tensor.splitBasename(base_name_with_namespace)[-1] if len(next(iter(groups))) > 0: