diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index 72a6a838..7a2a8360 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -382,5 +382,7 @@ pub(crate) fn __expand_new( val: C, ) -> ExpandElementTyped { let input: ExpandElementTyped = val.into(); - ::__expand_cast_from(scope, input) + let const_val = input.expand.as_const().unwrap(); + let var = Variable::constant(const_val.cast_to(Out::as_elem(scope))); + ExpandElement::Plain(var).into() } diff --git a/crates/cubecl-core/src/frontend/plane.rs b/crates/cubecl-core/src/frontend/plane.rs index 7334b581..7eea4911 100644 --- a/crates/cubecl-core/src/frontend/plane.rs +++ b/crates/cubecl-core/src/frontend/plane.rs @@ -87,6 +87,77 @@ pub mod plane_sum { } } +/// Perform an inclusive sum operation across all units in a plane. +/// This sums all values to the "left" of the unit, including this unit's value. +/// Also known as "prefix sum" or "inclusive scan". +/// +/// # Example +/// `inclusive_sum([1, 2, 3, 4, 5]) == [1, 3, 6, 10, 15]` +#[allow(unused_variables)] +pub fn plane_inclusive_sum(value: E) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_inclusive_sum()]. +pub mod plane_inclusive_sum { + use super::*; + + /// Expand method of [plane_inclusive_sum()]. + pub fn expand( + scope: &mut Scope, + elem: ExpandElementTyped, + ) -> ExpandElementTyped { + let elem: ExpandElement = elem.into(); + let output = scope.create_local(elem.item); + + let out = *output; + let input = *elem; + + scope.register(Instruction::new( + Plane::InclusiveSum(UnaryOperator { input }), + out, + )); + + output.into() + } +} + +/// Perform an exclusive sum operation across all units in a plane. +/// This sums all values to the "left" of the unit, excluding this unit's value. The 0th unit will +/// be set to `E::zero()`. +/// Also known as "exclusive prefix sum" or "exclusive scan". 
+/// +/// # Example +/// `exclusive_sum([1, 2, 3, 4, 5]) == [0, 1, 3, 6, 10]` +#[allow(unused_variables)] +pub fn plane_exclusive_sum(value: E) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_exclusive_sum()]. +pub mod plane_exclusive_sum { + use super::*; + + /// Expand method of [plane_exclusive_sum()]. + pub fn expand( + scope: &mut Scope, + elem: ExpandElementTyped, + ) -> ExpandElementTyped { + let elem: ExpandElement = elem.into(); + let output = scope.create_local(elem.item); + + let out = *output; + let input = *elem; + + scope.register(Instruction::new( + Plane::ExclusiveSum(UnaryOperator { input }), + out, + )); + + output.into() + } +} + /// Perform a reduce prod operation across all units in a plane. pub fn plane_prod(_elem: E) -> E { unexpanded!() @@ -113,6 +184,77 @@ pub mod plane_prod { } } +/// Perform an inclusive product operation across all units in a plane. +/// This multiplies all values to the "left" of the unit, including this unit's value. +/// Also known as "prefix product" or "inclusive scan". +/// +/// # Example +/// `inclusive_prod([1, 2, 3, 4, 5]) == [1, 2, 6, 24, 120]` +#[allow(unused_variables)] +pub fn plane_inclusive_prod(value: E) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_inclusive_prod()]. +pub mod plane_inclusive_prod { + use super::*; + + /// Expand method of [plane_inclusive_prod()]. + pub fn expand( + scope: &mut Scope, + elem: ExpandElementTyped, + ) -> ExpandElementTyped { + let elem: ExpandElement = elem.into(); + let output = scope.create_local(elem.item); + + let out = *output; + let input = *elem; + + scope.register(Instruction::new( + Plane::InclusiveProd(UnaryOperator { input }), + out, + )); + + output.into() + } +} + +/// Perform an exclusive product operation across all units in a plane. +/// This multiplies all values to the "left" of the unit, excluding this unit's value. The 0th unit +/// will be set to `E::one()`. 
+/// Also known as "exclusive prefix product" or "exclusive scan". +/// +/// # Example +/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 1, 2, 6, 24]` +#[allow(unused_variables)] +pub fn plane_exclusive_prod(value: E) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_exclusive_prod()]. +pub mod plane_exclusive_prod { + use super::*; + + /// Expand method of [plane_exclusive_prod()]. + pub fn expand( + scope: &mut Scope, + elem: ExpandElementTyped, + ) -> ExpandElementTyped { + let elem: ExpandElement = elem.into(); + let output = scope.create_local(elem.item); + + let out = *output; + let input = *elem; + + scope.register(Instruction::new( + Plane::ExclusiveProd(UnaryOperator { input }), + out, + )); + + output.into() + } +} + /// Perform a reduce max operation across all units in a plane. pub fn plane_max(_elem: E) -> E { unexpanded!() diff --git a/crates/cubecl-core/src/runtime_tests/plane.rs b/crates/cubecl-core/src/runtime_tests/plane.rs index b523e387..fb56099c 100644 --- a/crates/cubecl-core/src/runtime_tests/plane.rs +++ b/crates/cubecl-core/src/runtime_tests/plane.rs @@ -14,6 +14,22 @@ pub fn kernel_sum(output: &mut Tensor) { } } +#[cube(launch)] +pub fn kernel_inclusive_sum(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_inclusive_sum(val); + + output[UNIT_POS] = val2; +} + +#[cube(launch)] +pub fn kernel_exclusive_sum(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_exclusive_sum(val); + + output[UNIT_POS] = val2; +} + #[cube(launch)] pub fn kernel_prod(output: &mut Tensor) { let val = output[UNIT_POS]; @@ -24,6 +40,22 @@ pub fn kernel_prod(output: &mut Tensor) { } } +#[cube(launch)] +pub fn kernel_inclusive_prod(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_inclusive_prod(val); + + output[UNIT_POS] = val2; +} + +#[cube(launch)] +pub fn kernel_exclusive_prod(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_exclusive_prod(val); + + 
output[UNIT_POS] = val2; +} + #[cube(launch)] pub fn kernel_max(output: &mut Tensor) { let val = output[UNIT_POS]; @@ -122,6 +154,90 @@ pub fn test_plane_sum< ); } +pub fn test_plane_inclusive_sum< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| x as f32) + .collect(); + let mut expected = input.clone(); + + for k in 1..plane_size as usize { + let offs_out = k * vectorization as usize; + for k_1 in 0..k { + let offs_in = k_1 * vectorization as usize; + for v in 0..vectorization as usize { + expected[v + offs_out] += input[v + offs_in]; + } + } + } + + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_inclusive_sum::launch::( + &client, + cube_count, + CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + +pub fn test_plane_exclusive_sum< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| x as f32) + .collect(); + let mut expected = vec![0.0; input.len()]; + + for k in 1..plane_size as usize { + let offs_out = k * vectorization as usize; + for k_1 in 0..k { + let offs_in = k_1 * vectorization as usize; + for v in 0..vectorization as usize { + expected[v + offs_out] += input[v + offs_in]; + } + } + } + + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_exclusive_sum::launch::( + &client, + cube_count, + 
CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + pub fn test_plane_prod< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, @@ -164,6 +280,98 @@ pub fn test_plane_prod< ); } +pub fn test_plane_inclusive_prod< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| match x % 3 { + 0 => 0.5, + 1 => 1.25, + 2 => 1.75, + _ => unreachable!(), + }) // keep the values relatively close to 1 to avoid overflow. + .collect(); + let mut expected = input.clone(); + + for k in 1..plane_size as usize { + let offs_out = k * vectorization as usize; + for k_1 in 0..k { + let offs_in = k_1 * vectorization as usize; + for v in 0..vectorization as usize { + expected[v + offs_out] *= input[v + offs_in]; + } + } + } + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_inclusive_prod::launch::( + &client, + cube_count, + CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + +pub fn test_plane_exclusive_prod< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| match x % 3 { + 0 => 0.5, + 1 => 1.25, + 2 => 1.75, + _ => unreachable!(), + }) // keep the values relatively close to 1 to avoid overflow. 
+ .collect(); + let mut expected = vec![1.0; input.len()]; + + for k in 1..plane_size as usize { + let offs_out = k * vectorization as usize; + for k_1 in 0..k { + let offs_in = k_1 * vectorization as usize; + for v in 0..vectorization as usize { + expected[v + offs_out] *= input[v + offs_in]; + } + } + } + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_exclusive_prod::launch::( + &client, + cube_count, + CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + pub fn test_plane_max< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, @@ -474,6 +682,46 @@ macro_rules! testgen_plane { impl_test_plane_sum(4); } + fn impl_test_plane_inclusive_sum(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_inclusive_sum::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_inclusive_sum_vec1() { + impl_test_plane_inclusive_sum(1); + } + #[test] + fn test_plane_inclusive_sum_vec2() { + impl_test_plane_inclusive_sum(2); + } + #[test] + fn test_plane_inclusive_sum_vec4() { + impl_test_plane_inclusive_sum(4); + } + + fn impl_test_plane_exclusive_sum(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_exclusive_sum::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_exclusive_sum_vec1() { + impl_test_plane_exclusive_sum(1); + } + #[test] + fn test_plane_exclusive_sum_vec2() { + impl_test_plane_exclusive_sum(2); + } + #[test] + fn test_plane_exclusive_sum_vec4() { + impl_test_plane_exclusive_sum(4); + } + fn impl_test_plane_prod(vectorization: u8) { let client = TestRuntime::client(&Default::default()); cubecl_core::runtime_tests::plane::test_plane_prod::( @@ 
-494,6 +742,46 @@ macro_rules! testgen_plane { impl_test_plane_prod(4); } + fn impl_test_plane_inclusive_prod(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_inclusive_prod::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_inclusive_prod_vec1() { + impl_test_plane_inclusive_prod(1); + } + #[test] + fn test_plane_inclusive_prod_vec2() { + impl_test_plane_inclusive_prod(2); + } + #[test] + fn test_plane_inclusive_prod_vec4() { + impl_test_plane_inclusive_prod(4); + } + + fn impl_test_plane_exclusive_prod(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_exclusive_prod::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_exclusive_prod_vec1() { + impl_test_plane_exclusive_prod(1); + } + #[test] + fn test_plane_exclusive_prod_vec2() { + impl_test_plane_exclusive_prod(2); + } + #[test] + fn test_plane_exclusive_prod_vec4() { + impl_test_plane_exclusive_prod(4); + } + fn impl_test_plane_max(vectorization: u8) { let client = TestRuntime::client(&Default::default()); cubecl_core::runtime_tests::plane::test_plane_max::( diff --git a/crates/cubecl-cpp/src/cuda/dialect.rs b/crates/cubecl-cpp/src/cuda/dialect.rs index 31c152b4..3e32a69b 100644 --- a/crates/cubecl-cpp/src/cuda/dialect.rs +++ b/crates/cubecl-cpp/src/cuda/dialect.rs @@ -80,6 +80,9 @@ impl> Dialect for CudaDialect { fn warp_shuffle_xor(var: &str, offset: &str) -> String { format!("__shfl_xor_sync(-1, {var}, {offset})") } + fn warp_shuffle_up(var: &str, offset: &str) -> String { + format!("__shfl_up_sync(-1, {var}, {offset})") + } fn warp_shuffle_down(var: &str, offset: &str) -> String { format!("__shfl_down_sync(-1, {var}, {offset})") } diff --git a/crates/cubecl-cpp/src/hip/dialect.rs b/crates/cubecl-cpp/src/hip/dialect.rs index 8511dcdb..bf9dc0a8 100644 --- a/crates/cubecl-cpp/src/hip/dialect.rs +++ 
b/crates/cubecl-cpp/src/hip/dialect.rs @@ -81,6 +81,9 @@ impl> Dialect for HipDialect { fn warp_shuffle_xor(var: &str, offset: &str) -> String { format!("__shfl_xor({var}, {offset})") } + fn warp_shuffle_up(var: &str, offset: &str) -> String { + format!("__shfl_up({var}, {offset})") + } fn warp_shuffle_down(var: &str, offset: &str) -> String { format!("__shfl_down({var}, {offset})") } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index f55e68f5..ad3bf35e 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -32,6 +32,7 @@ pub trait Dialect: // warp instructions (all threads participating) fn warp_shuffle(var: &str, source: &str) -> String; fn warp_shuffle_xor(var: &str, offset: &str) -> String; + fn warp_shuffle_up(var: &str, offset: &str) -> String; fn warp_shuffle_down(var: &str, offset: &str) -> String; fn warp_all(var: &str) -> String; fn warp_any(var: &str) -> String; @@ -250,12 +251,40 @@ impl CppCompiler { out, })) } + gpu::Plane::InclusiveSum(op) => { + self.settings.idx_global = true; + instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum { + input: self.compile_variable(op.input), + out, + })) + } + gpu::Plane::ExclusiveSum(op) => { + self.settings.idx_global = true; + instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum { + input: self.compile_variable(op.input), + out, + })) + } gpu::Plane::Prod(op) => { instructions.push(Instruction::Warp(WarpInstruction::ReduceProd { input: self.compile_variable(op.input), out, })) } + gpu::Plane::InclusiveProd(op) => { + self.settings.idx_global = true; + instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd { + input: self.compile_variable(op.input), + out, + })) + } + gpu::Plane::ExclusiveProd(op) => { + self.settings.idx_global = true; + instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd { + input: self.compile_variable(op.input), + out, + })) + } gpu::Plane::Max(op) => { 
instructions.push(Instruction::Warp(WarpInstruction::ReduceMax { input: self.compile_variable(op.input), diff --git a/crates/cubecl-cpp/src/shared/warp.rs b/crates/cubecl-cpp/src/shared/warp.rs index 1a731268..8fc37de0 100644 --- a/crates/cubecl-cpp/src/shared/warp.rs +++ b/crates/cubecl-cpp/src/shared/warp.rs @@ -10,10 +10,26 @@ pub enum WarpInstruction { input: Variable, out: Variable, }, + InclusiveSum { + input: Variable, + out: Variable, + }, + ExclusiveSum { + input: Variable, + out: Variable, + }, ReduceProd { input: Variable, out: Variable, }, + InclusiveProd { + input: Variable, + out: Variable, + }, + ExclusiveProd { + input: Variable, + out: Variable, + }, ReduceMax { input: Variable, out: Variable, @@ -77,6 +93,14 @@ unsigned int leader = __ffs(mask) - 1; {out} = threadIdx.x % warpSize == leader; " ), + WarpInstruction::InclusiveSum { input, out } => reduce_inclusive(f, input, out, "+="), + WarpInstruction::InclusiveProd { input, out } => reduce_inclusive(f, input, out, "*="), + WarpInstruction::ExclusiveSum { input, out } => { + reduce_exclusive(f, input, out, "+=", "0") + } + WarpInstruction::ExclusiveProd { input, out } => { + reduce_exclusive(f, input, out, "*=", "1") + } } } } @@ -97,6 +121,69 @@ fn reduce_operator( }) } +fn reduce_inclusive( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + op: &str, +) -> core::fmt::Result { + let in_optimized = input.optimized(); + let acc_item = in_optimized.item(); + + reduce_with_loop(f, input, out, acc_item, |acc, index| { + let acc_indexed = maybe_index(acc, index); + let shfl_up = D::warp_shuffle_up(&acc_indexed, "offset"); + let tmp = Variable::tmp(Item::scalar(acc_item.elem)); + let lane_id = Variable::::ThreadIdxWarp; + format!( + " +{} = {shfl_up}; +if({lane_id} >= offset) {{ + {acc_indexed} {op} {tmp}; +}} +", + tmp.fmt_left() + ) + }) +} + +fn reduce_exclusive( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + op: &str, + default: &str, +) -> 
core::fmt::Result { + let in_optimized = input.optimized(); + let acc_item = in_optimized.item(); + + let inclusive = Variable::tmp(acc_item); + reduce_inclusive(f, input, &inclusive, op)?; + let shfl = Variable::tmp(acc_item); + writeln!(f, "{} = {{", shfl.fmt_left())?; + for k in 0..acc_item.vectorization { + let inclusive_indexed = maybe_index(&inclusive, k); + writeln!( + f, + "{},", + D::warp_shuffle_up(&inclusive_indexed.to_string(), "1") + )?; + } + writeln!(f, "}};")?; + let lane_id = Variable::::ThreadIdxWarp; + + write!( + f, + "{} = ({lane_id} == 0) ? {}{{", + out.fmt_left(), + out.item(), + )?; + for _ in 0..out.item().vectorization { + write!(f, "{default},")?; + } + writeln!(f, "}} : {};", cast(&shfl, out.item())) +} + fn reduce_comparison( f: &mut core::fmt::Formatter<'_>, input: &Variable, diff --git a/crates/cubecl-ir/src/plane.rs b/crates/cubecl-ir/src/plane.rs index ea1048b5..557b1c89 100644 --- a/crates/cubecl-ir/src/plane.rs +++ b/crates/cubecl-ir/src/plane.rs @@ -19,7 +19,11 @@ pub enum Plane { Ballot(UnaryOperator), Broadcast(BinaryOperator), Sum(UnaryOperator), + InclusiveSum(UnaryOperator), + ExclusiveSum(UnaryOperator), Prod(UnaryOperator), + InclusiveProd(UnaryOperator), + ExclusiveProd(UnaryOperator), Min(UnaryOperator), Max(UnaryOperator), } @@ -35,7 +39,11 @@ impl Display for Plane { writeln!(f, "plane_broadcast({}, {})", op.lhs, op.rhs) } Plane::Sum(op) => writeln!(f, "plane_sum({})", op.input), + Plane::InclusiveSum(op) => writeln!(f, "plane_inclusive_sum({})", op.input), + Plane::ExclusiveSum(op) => writeln!(f, "plane_exclusive_sum({})", op.input), Plane::Prod(op) => writeln!(f, "plane_product({})", op.input), + Plane::InclusiveProd(op) => writeln!(f, "plane_inclusive_product({})", op.input), + Plane::ExclusiveProd(op) => writeln!(f, "plane_exclusive_product({})", op.input), Plane::Min(op) => writeln!(f, "plane_min({})", op.input), Plane::Max(op) => writeln!(f, "plane_max({})", op.input), } diff --git 
a/crates/cubecl-opt/src/analyses/uniformity.rs b/crates/cubecl-opt/src/analyses/uniformity.rs index dee1d53d..fb702f2a 100644 --- a/crates/cubecl-opt/src/analyses/uniformity.rs +++ b/crates/cubecl-opt/src/analyses/uniformity.rs @@ -51,7 +51,12 @@ impl Uniformity { match &inst.operation { Operation::Plane(plane) => match plane { // Elect returns true on only one unit, so it's always non-uniform - Plane::Elect => self.mark_uniformity(out, false)?, + // Inclusive/exclusive scans are non-uniform by definition + Plane::Elect + | Plane::ExclusiveSum(_) + | Plane::InclusiveSum(_) + | Plane::ExclusiveProd(_) + | Plane::InclusiveProd(_) => self.mark_uniformity(out, false)?, // Reductions are always uniform if executed in uniform control flow Plane::Sum(_) | Plane::Prod(_) diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 8828eefe..a9d12d83 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -254,7 +254,11 @@ impl Optimizer { Plane::All(unary_operator) | Plane::Any(unary_operator) | Plane::Sum(unary_operator) + | Plane::InclusiveSum(unary_operator) + | Plane::ExclusiveSum(unary_operator) | Plane::Prod(unary_operator) + | Plane::InclusiveProd(unary_operator) + | Plane::ExclusiveProd(unary_operator) | Plane::Min(unary_operator) | Plane::Max(unary_operator) | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read), diff --git a/crates/cubecl-spirv/src/subgroup.rs b/crates/cubecl-spirv/src/subgroup.rs index 19f88e61..6d38dcb5 100644 --- a/crates/cubecl-spirv/src/subgroup.rs +++ b/crates/cubecl-spirv/src/subgroup.rs @@ -1,4 +1,4 @@ -use cubecl_core::ir::{Plane, Variable}; +use cubecl_core::ir::{Plane, UnaryOperator, Variable}; use rspirv::spirv::{Capability, GroupOperation, Scope, Word}; use crate::{item::Elem, SpirvCompiler, SpirvTarget}; @@ -104,52 +104,22 @@ impl SpirvCompiler { }); } Plane::Sum(op) => { - self.compile_unary_op(op, out, uniform, |b, out_ty, ty, 
input, out| { - match out_ty.elem() { - Elem::Int(_, _) => b.group_non_uniform_i_add( - ty, - Some(out), - subgroup, - GroupOperation::Reduce, - input, - None, - ), - Elem::Float(_) | Elem::Relaxed => b.group_non_uniform_f_add( - ty, - Some(out), - subgroup, - GroupOperation::Reduce, - input, - None, - ), - elem => unreachable!("{elem}"), - } - .unwrap(); - }); + self.plane_sum(op, out, GroupOperation::Reduce, uniform); + } + Plane::ExclusiveSum(op) => { + self.plane_sum(op, out, GroupOperation::ExclusiveScan, uniform); + } + Plane::InclusiveSum(op) => { + self.plane_sum(op, out, GroupOperation::InclusiveScan, uniform); } Plane::Prod(op) => { - self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| { - match out_ty.elem() { - Elem::Int(_, _) => b.group_non_uniform_i_mul( - ty, - Some(out), - subgroup, - GroupOperation::Reduce, - input, - None, - ), - Elem::Float(_) | Elem::Relaxed => b.group_non_uniform_f_mul( - ty, - Some(out), - subgroup, - GroupOperation::Reduce, - input, - None, - ), - _ => unreachable!(), - } - .unwrap(); - }); + self.plane_prod(op, out, GroupOperation::Reduce, uniform); + } + Plane::ExclusiveProd(op) => { + self.plane_prod(op, out, GroupOperation::ExclusiveScan, uniform); + } + Plane::InclusiveProd(op) => { + self.plane_prod(op, out, GroupOperation::InclusiveScan, uniform); } Plane::Min(op) => { self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| { @@ -218,6 +188,50 @@ impl SpirvCompiler { } } + fn plane_sum( + &mut self, + op: UnaryOperator, + out: Variable, + action: GroupOperation, + uniform: bool, + ) { + let subgroup = self.subgroup(); + self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| { + match out_ty.elem() { + Elem::Int(_, _) => { + b.group_non_uniform_i_add(ty, Some(out), subgroup, action, input, None) + } + Elem::Float(_) | Elem::Relaxed => { + b.group_non_uniform_f_add(ty, Some(out), subgroup, action, input, None) + } + elem => unreachable!("{elem}"), + } + .unwrap(); + }); + } + + fn 
plane_prod( + &mut self, + op: UnaryOperator, + out: Variable, + action: GroupOperation, + uniform: bool, + ) { + let subgroup = self.subgroup(); + self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| { + match out_ty.elem() { + Elem::Int(_, _) => { + b.group_non_uniform_i_mul(ty, Some(out), subgroup, action, input, None) + } + Elem::Float(_) | Elem::Relaxed => { + b.group_non_uniform_f_mul(ty, Some(out), subgroup, action, input, None) + } + _ => unreachable!(), + } + .unwrap(); + }); + } + fn subgroup(&mut self) -> Word { self.const_u32(Scope::Subgroup as u32) } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 24fea4d4..03b476e7 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -609,10 +609,26 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }, + cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }, + cube::Plane::InclusiveSum(op) => Subgroup::InclusiveSum { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }, cube::Plane::Prod(op) => Subgroup::Prod { input: self.compile_variable(op.input), out: self.compile_variable(out), }, + cube::Plane::ExclusiveProd(op) => Subgroup::ExclusiveProd { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }, + cube::Plane::InclusiveProd(op) => Subgroup::InclusiveProd { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }, cube::Plane::Min(op) => Subgroup::Min { input: self.compile_variable(op.input), out: self.compile_variable(out), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs b/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs index 52985b73..d934b9d1 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs +++ 
b/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs @@ -28,10 +28,26 @@ pub enum Subgroup { input: Variable, out: Variable, }, + ExclusiveSum { + input: Variable, + out: Variable, + }, + InclusiveSum { + input: Variable, + out: Variable, + }, Prod { input: Variable, out: Variable, }, + ExclusiveProd { + input: Variable, + out: Variable, + }, + InclusiveProd { + input: Variable, + out: Variable, + }, Min { input: Variable, out: Variable, @@ -115,10 +131,26 @@ impl Display for Subgroup { let out = out.fmt_left(); writeln!(f, "{out} = subgroupAdd({input});") } + Subgroup::ExclusiveSum { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = subgroupExclusiveAdd({input});") + } + Subgroup::InclusiveSum { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = subgroupInclusiveAdd({input});") + } Subgroup::Prod { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupMul({input});") } + Subgroup::ExclusiveProd { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = subgroupExclusiveMul({input});") + } + Subgroup::InclusiveProd { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = subgroupInclusiveMul({input});") + } Subgroup::Min { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupMin({input});")