diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 3e317dc2..65d65091 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -526,24 +526,21 @@ impl SpirvCompiler { Variable::Versioned { id, item: Item::Vector(elem, vec), - } => IndexedVariable::Composite( + } if index.as_const().is_some() => IndexedVariable::Composite( self.get_versioned(*id), - index - .as_const() - .expect("Index into vector must be constant") - .as_u32(), + index.as_const().unwrap().as_u32(), Item::Vector(*elem, *vec), ), - Variable::LocalBinding { .. } | Variable::Local { .. } => { - let index = index - .as_const() - .expect("Index into vector must be constant") - .as_u32(); - if index > 0 { - panic!("Tried accessing {index}th element of scalar!"); - } else { - IndexedVariable::Scalar(variable.clone()) - } + Variable::Versioned { + id, + item: Item::Vector(elem, vec), + } => IndexedVariable::DynamicComposite( + self.get_versioned(*id), + self.read(index), + Item::Vector(*elem, *vec), + ), + Variable::Local { .. } | Variable::LocalBinding { .. } | Variable::Versioned { .. } => { + IndexedVariable::Scalar(variable.clone()) } Variable::Slice { ptr, offset, .. } => { let item = Item::Scalar(Elem::Int(32, false));