Skip to content

Commit

Permalink
Fix isnan assertions for scalar families
Browse files Browse the repository at this point in the history
  • Loading branch information
davschneller committed Sep 5, 2023
1 parent 69b6ac8 commit 97142d9
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions yateto/codegen/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,13 @@ def generate_extra_offset_args(base_name_with_namespace, groups):
continue

with cpp.Function('{}::{}::{}'.format(self.NAMESPACE, name, executeName(index))):
sclrs = sorted(list(kernelOutline.scalars), key=str)
for scalar in sclrs:
cpp('assert(!std::isnan({}));'.format(scalar))
for base_name_with_namespace, groups in kernelOutline.scalars:
base_name = Tensor.splitBasename(base_name_with_namespace)[-1]
if len(next(iter(groups))) > 0:
for gis in groups:
cpp('assert(!std::isnan({}({})));'.format(base_name, ','.join(str(gi) for gi in gis)))
else:
cpp(f'assert(!std::isnan({base_name}));')
for base_name_with_namespace, groups in kernelOutline.tensors.items():
base_name = Tensor.splitBasename(base_name_with_namespace)[-1]
if len(next(iter(groups))) > 0:
Expand Down

0 comments on commit 97142d9

Please sign in to comment.