From 18368ab860b90e2551d44b3f27657c4f3eb6d4ad Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 11 Feb 2025 19:57:46 +0100 Subject: [PATCH 1/2] Fix indexing code for locals --- crates/cubecl-spirv/src/variable.rs | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 3e317dc2c..ecea52697 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -526,25 +526,19 @@ impl SpirvCompiler { Variable::Versioned { id, item: Item::Vector(elem, vec), - } => IndexedVariable::Composite( + } if index.as_const().is_some() => IndexedVariable::Composite( self.get_versioned(*id), - index - .as_const() - .expect("Index into vector must be constant") - .as_u32(), + index.as_const().unwrap().as_u32(), + Item::Vector(*elem, *vec), + ), + Variable::Versioned { + id, + item: Item::Vector(elem, vec), + } => IndexedVariable::DynamicComposite( + self.get_versioned(*id), + self.read(index), Item::Vector(*elem, *vec), ), - Variable::LocalBinding { .. } | Variable::Local { .. } => { - let index = index - .as_const() - .expect("Index into vector must be constant") - .as_u32(); - if index > 0 { - panic!("Tried accessing {index}th element of scalar!"); - } else { - IndexedVariable::Scalar(variable.clone()) - } - } Variable::Slice { ptr, offset, .. } => { let item = Item::Scalar(Elem::Int(32, false)); let int = item.id(self); From ea02c250d8b336649a830d1c91f362324d4cc8f6 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 11 Feb 2025 20:09:29 +0100 Subject: [PATCH 2/2] Fix scalar indexing --- crates/cubecl-spirv/src/variable.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index ecea52697..65d650918 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -539,6 +539,9 @@ impl SpirvCompiler { self.read(index), Item::Vector(*elem, *vec), ), + Variable::Local { .. } | Variable::LocalBinding { .. } | Variable::Versioned { .. } => { + IndexedVariable::Scalar(variable.clone()) + } Variable::Slice { ptr, offset, .. } => { let item = Item::Scalar(Elem::Int(32, false)); let int = item.id(self);