Skip to content

Commit

Permalink
Fix reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Feb 3, 2025
1 parent 276b8db commit 2437c74
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 23 deletions.
29 changes: 8 additions & 21 deletions crates/cubecl-reduce/src/shared_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::ReduceError;
///
/// // Create input and output handles.
/// let input_handle = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0]));
/// let output_handle = client.empty(size_of::<f32>());
/// let input = unsafe {
/// TensorHandleRef::<R>::from_raw_parts(
/// &input_handle,
Expand All @@ -29,9 +30,12 @@ use crate::ReduceError;
/// size_f32,
/// )
/// };
/// let output = unsafe {
/// TensorHandleRef::<R>::from_raw_parts(&output_handle, &[1], &[1], size_of::<f32>())
/// };
///
/// // Here `R` is a `cubecl::Runtime`.
/// let result = shared_sum::<R, f32>(&client, input, cube_count);
/// let result = shared_sum::<R, f32>(&client, input, output, cube_count);
///
/// if result.is_ok() {
/// let binding = output_handle.binding();
Expand All @@ -43,8 +47,9 @@ use crate::ReduceError;
pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
client: &ComputeClient<R::Server, R::Channel>,
input: TensorHandleRef<R>,
output: TensorHandleRef<R>,
cube_count: u32,
) -> Result<N, ReduceError> {
) -> Result<(), ReduceError> {
// Check that the client supports atomic addition.
let atomic_elem = Atomic::<N>::as_elem_native_unchecked();
if !client
Expand Down Expand Up @@ -74,19 +79,6 @@ pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
let num_lines_per_unit = input_len.div_ceil(num_units * line_size);
let cube_count = CubeCount::new_1d(cube_count);

// Generate output tensor.
let output_handle = client.create(N::as_bytes(&[N::from_int(0)]));
let output_shape = vec![1];
let output_stride = vec![1];
let output = unsafe {
TensorHandleRef::<R>::from_raw_parts(
&output_handle,
&output_stride,
&output_shape,
size_of::<N>(),
)
};

// Launch kernel
unsafe {
shared_sum_kernel::launch_unchecked::<N, R>(
Expand All @@ -101,12 +93,7 @@ pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
);
}

// Extract sum from output.
let binding = output_handle.binding();
let bytes = client.read_one(binding);
let output_values = N::from_bytes(&bytes);

Ok(output_values[0])
Ok(())
}

#[cube(launch_unchecked)]
Expand Down
10 changes: 8 additions & 2 deletions crates/cubecl-reduce/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ impl TestCase {
let client = R::client(device);

let input_handle = client.create(F::as_bytes(&input_values));
let output_handle = client.empty(size_of::<F>());

let input = unsafe {
TensorHandleRef::<R>::from_raw_parts(
Expand All @@ -513,14 +514,19 @@ impl TestCase {
size_of::<F>(),
)
};
let output = unsafe {
TensorHandleRef::<R>::from_raw_parts(&output_handle, &[1], &[1], size_of::<F>())
};

let cube_count = 3;
let result = shared_sum::<R, F>(&client, input, cube_count);
let result = shared_sum::<R, F>(&client, input, output, cube_count);

if result.is_err() {
return; // don't execute the test in that case since atomic adds are not supported.
}
assert_approx_equal(&[result.unwrap()], &[expected]);
let bytes = client.read_one(output_handle.binding());
let actual = F::from_bytes(&bytes);
assert_approx_equal(actual, &[expected]);
}

fn num_output_values(&self) -> usize {
Expand Down

0 comments on commit 2437c74

Please sign in to comment.