Skip to content

Commit

Permalink
Add inclusive/exclusive sum and prod (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Feb 3, 2025
1 parent a172f67 commit ff94be8
Show file tree
Hide file tree
Showing 13 changed files with 680 additions and 47 deletions.
4 changes: 3 additions & 1 deletion crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,5 +382,7 @@ pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
val: C,
) -> ExpandElementTyped<Out> {
let input: ExpandElementTyped<C> = val.into();
<Out as super::Cast>::__expand_cast_from(scope, input)
let const_val = input.expand.as_const().unwrap();
let var = Variable::constant(const_val.cast_to(Out::as_elem(scope)));
ExpandElement::Plain(var).into()
}
142 changes: 142 additions & 0 deletions crates/cubecl-core/src/frontend/plane.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,77 @@ pub mod plane_sum {
}
}

/// Perform an inclusive sum operation across all units in a plane.
/// This sums all values to the "left" of the unit, including this unit's value.
/// Also known as "prefix sum" or "inclusive scan".
///
/// # Example
/// `inclusive_sum([1, 2, 3, 4, 5]) == [1, 3, 6, 10, 15]`
#[allow(unused_variables)]
pub fn plane_inclusive_sum<E: CubePrimitive>(value: E) -> E {
unexpanded!()
}

/// Module containing the expand function for [plane_inclusive_sum()].
pub mod plane_inclusive_sum {
use super::*;

/// Expand method of [plane_inclusive_sum()].
pub fn expand<E: CubePrimitive>(
scope: &mut Scope,
elem: ExpandElementTyped<E>,
) -> ExpandElementTyped<E> {
let elem: ExpandElement = elem.into();
let output = scope.create_local(elem.item);

let out = *output;
let input = *elem;

scope.register(Instruction::new(
Plane::InclusiveSum(UnaryOperator { input }),
out,
));

output.into()
}
}

/// Perform an exclusive sum operation across all units in a plane.
/// This sums all values to the "left" of the unit, excluding this unit's value. The 0th unit will
/// be set to `E::zero()`.
/// Also known as "exclusive prefix sum" or "exclusive scan".
///
/// # Example
/// `exclusive_sum([1, 2, 3, 4, 5]) == [0, 1, 3, 6, 10]`
#[allow(unused_variables)]
pub fn plane_exclusive_sum<E: CubePrimitive>(value: E) -> E {
unexpanded!()
}

/// Module containing the expand function for [plane_exclusive_sum()].
pub mod plane_exclusive_sum {
use super::*;

/// Expand method of [plane_exclusive_sum()].
pub fn expand<E: CubePrimitive>(
scope: &mut Scope,
elem: ExpandElementTyped<E>,
) -> ExpandElementTyped<E> {
let elem: ExpandElement = elem.into();
let output = scope.create_local(elem.item);

let out = *output;
let input = *elem;

scope.register(Instruction::new(
Plane::ExclusiveSum(UnaryOperator { input }),
out,
));

output.into()
}
}

/// Perform a reduce prod operation across all units in a plane.
pub fn plane_prod<E: CubePrimitive>(_elem: E) -> E {
unexpanded!()
Expand All @@ -113,6 +184,77 @@ pub mod plane_prod {
}
}

/// Perform an inclusive product operation across all units in a plane.
/// This multiplies all values to the "left" of the unit, including this unit's value.
/// Also known as "prefix product" or "inclusive scan".
///
/// # Example
/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 2, 6, 24, 120]`
#[allow(unused_variables)]
pub fn plane_inclusive_prod<E: CubePrimitive>(value: E) -> E {
unexpanded!()
}

/// Module containing the expand function for [plane_inclusive_prod()].
pub mod plane_inclusive_prod {
use super::*;

/// Expand method of [plane_inclusive_prod()].
pub fn expand<E: CubePrimitive>(
scope: &mut Scope,
elem: ExpandElementTyped<E>,
) -> ExpandElementTyped<E> {
let elem: ExpandElement = elem.into();
let output = scope.create_local(elem.item);

let out = *output;
let input = *elem;

scope.register(Instruction::new(
Plane::InclusiveProd(UnaryOperator { input }),
out,
));

output.into()
}
}

/// Perform an exclusive product operation across all units in a plane.
/// This multiplies all values to the "left" of the unit, excluding this unit's value. The 0th unit
/// will be set to `E::one()`.
/// Also known as "exclusive prefix product" or "exclusive scan".
///
/// # Example
/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 1, 2, 6, 24]`
#[allow(unused_variables)]
pub fn plane_exclusive_prod<E: CubePrimitive>(value: E) -> E {
unexpanded!()
}

/// Module containing the expand function for [plane_exclusive_prod()].
pub mod plane_exclusive_prod {
use super::*;

/// Expand method of [plane_exclusive_prod()].
pub fn expand<E: CubePrimitive>(
scope: &mut Scope,
elem: ExpandElementTyped<E>,
) -> ExpandElementTyped<E> {
let elem: ExpandElement = elem.into();
let output = scope.create_local(elem.item);

let out = *output;
let input = *elem;

scope.register(Instruction::new(
Plane::ExclusiveProd(UnaryOperator { input }),
out,
));

output.into()
}
}

/// Perform a reduce max operation across all units in a plane.
pub fn plane_max<E: CubePrimitive>(_elem: E) -> E {
unexpanded!()
Expand Down
Loading

0 comments on commit ff94be8

Please sign in to comment.