Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More flexible matmul test #476

Merged
merged 4 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod tests {

cubecl_core::testgen_all!(f32: [f16, bf16, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]);
cubecl_linalg::testgen_matmul_accelerated!([f16]);
cubecl_linalg::testgen_matmul_quantized!();
cubecl_linalg::testgen_matmul_simple!([f16, bf16, f32]);
cubecl_linalg::testgen_matmul_tiling2d!([f16, bf16, f32]);
cubecl_linalg::testgen_tensor_identity!([f16, bf16, f32, u32]);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,50 +1,36 @@
use std::fmt::Display;

use cubecl_core::prelude::*;
use cubecl_core::server::Handle;
use cubecl_core::tensor_line_size_parallel;
use cubecl_core::CubeElement;
use cubecl_core::Feature;

use crate::matmul::components::global::args::TensorInputsLaunch;
use crate::matmul::components::tile::accelerated::Accelerated;
use crate::matmul::components::tile::plane::PlaneMma;
use crate::matmul::components::Ident;
use crate::matmul::components::MatmulConfigFactory;
use crate::matmul::components::MatmulLaunch;
use crate::matmul::components::MatmulProblem;
use crate::matmul::components::MatrixLayout;
use crate::matmul::components::SingleMatmulSpec;
use crate::matmul::kernels::matmul;
use crate::matmul::kernels::matmul::Algorithm;
use crate::matmul::kernels::matmul::StandardSelector;
use crate::matmul::tests::test_utils::CastInto;
use crate::tensor::TensorHandle;

use crate::matmul::tests::test_utils::assert_equals_approx;
use crate::matmul::tests::test_utils::generate_random_data;
use crate::matmul::tests::test_utils::matmul_cpu_reference;
use crate::matmul::tests::test_utils::Sample;
use crate::matmul::tests::test_utils::TestPrecision;

struct TensorRawParts<F: Float + CubeElement> {
struct TensorRawParts<N: Numeric + CubeElement> {
handle: Handle,
shape: Vec<usize>,
strides: Vec<usize>,
original_data: Option<Vec<F>>,
original_data: Option<Vec<N>>,
}

type Spec<EG, ES> = SingleMatmulSpec<EG, ES, f32>;

/// Test the correctness of the specified Matmul on the given device,
/// against a naive CPU implementation over the given problem
pub fn test_matmul_algorithm<A, EG, ES, R>(
pub fn test_matmul_algorithm<A, P, R>(
client: ComputeClient<R::Server, R::Channel>,
mut problem: MatmulProblem,
input: <A::BatchMatmul as MatmulConfigFactory>::Input,
selection: A::Selection,
) where
A: Algorithm,
EG: Float + CubeElement + Display + CastInto<ES>,
ES: Float + CubeElement + Display + CastInto<EG>,
P: TestPrecision,
R: Runtime,
{
let env = std::env::var("MATMUL_TEST_MODE");
Expand All @@ -57,24 +43,24 @@ pub fn test_matmul_algorithm<A, EG, ES, R>(
},
Err(_) => false,
};
let lhs = tensor_raw_parts::<EG, R>(&client, &problem, Ident::Lhs);
let rhs = tensor_raw_parts::<EG, R>(&client, &problem, Ident::Rhs);
let out = tensor_raw_parts::<EG, R>(&client, &problem, Ident::Out);
let lhs = tensor_raw_parts::<P::EG, R>(&client, &problem, Ident::Lhs);
let rhs = tensor_raw_parts::<P::EG, R>(&client, &problem, Ident::Rhs);
let out = tensor_raw_parts::<P::EG, R>(&client, &problem, Ident::Out);

problem.lhs_line_size = tensor_line_size_parallel(
R::line_size_elem(&EG::as_elem_native_unchecked()),
R::line_size_elem(&P::EG::as_elem_native_unchecked()),
&lhs.shape,
&lhs.strides,
lhs.strides.len() - 1,
);
problem.rhs_line_size = tensor_line_size_parallel(
R::line_size_elem(&EG::as_elem_native_unchecked()),
R::line_size_elem(&P::EG::as_elem_native_unchecked()),
&rhs.shape,
&rhs.strides,
rhs.strides.len() - 1,
);
problem.out_line_size = tensor_line_size_parallel(
R::line_size_elem(&EG::as_elem_native_unchecked()),
R::line_size_elem(&P::EG::as_elem_native_unchecked()),
&out.shape,
&out.strides,
out.strides.len() - 1,
Expand All @@ -101,33 +87,33 @@ pub fn test_matmul_algorithm<A, EG, ES, R>(
}
};

if A::check_availability::<R, (EG, ES, f32)>(&client, &config).is_err() {
if A::check_availability::<R, (P::EG, P::ES, f32)>(&client, &config).is_err() {
// Can't execute the test.
println!("Skipped - not supported!");
client.flush();
return;
}

unsafe {
A::BatchMatmul::launch_unchecked::<Spec<EG, ES>, R>(
A::BatchMatmul::launch_unchecked::<SingleMatmulSpec<P::EG, P::ES, P::EA>, R>(
&client,
cube_dim,
cube_count,
TensorInputsLaunch::new(
TensorArg::<R>::from_raw_parts::<EG>(
TensorArg::<R>::from_raw_parts::<P::EG>(
&lhs.handle,
&lhs.strides,
&lhs.shape,
problem.lhs_line_size,
),
TensorArg::<R>::from_raw_parts::<EG>(
TensorArg::<R>::from_raw_parts::<P::EG>(
&rhs.handle,
&rhs.strides,
&rhs.shape,
problem.rhs_line_size,
),
),
TensorArg::<R>::from_raw_parts::<EG>(
TensorArg::<R>::from_raw_parts::<P::EG>(
&out.handle,
&out.strides,
&out.shape,
Expand All @@ -137,74 +123,23 @@ pub fn test_matmul_algorithm<A, EG, ES, R>(
);
}

assert_result::<EG, ES, R>(
&lhs.original_data.unwrap(),
&rhs.original_data.unwrap(),
&problem,
&client,
out.handle,
None,
);
}

/// Test the correctness of the high-level Matmul on the given device,
/// against a naive CPU implementation over the given problem
pub fn test_matmul_launch<EG: Float + CubeElement + Display + CastInto<EG>, R: Runtime>(
problem: MatmulProblem,
device: &R::Device,
) {
let client: ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel> = R::client(device);

if !(client.properties().feature_enabled(Feature::Plane)
&& client.properties().feature_enabled(Feature::Type(
EG::as_elem_native().expect("To be a native type"),
)))
{
// Can't execute the test.
return;
}

let lhs = tensor_raw_parts::<EG, R>(&client, &problem, Ident::Lhs);
let rhs = tensor_raw_parts::<EG, R>(&client, &problem, Ident::Rhs);
let out = tensor_raw_parts::<EG, R>(&client, &problem, Ident::Out);

let lhs_handle = TensorHandle::new(lhs.shape, lhs.strides, lhs.handle);
let rhs_handle = TensorHandle::new(rhs.shape, rhs.strides, rhs.handle);
let out_handle = TensorHandle::new(out.shape, out.strides, out.handle);

let out = matmul::launch::<R, EG, StandardSelector<Accelerated>>(
&client,
lhs_handle.clone(),
rhs_handle.clone(),
out_handle.clone(),
)
.unwrap_or_else(|_| {
matmul::launch::<R, EG, StandardSelector<PlaneMma>>(
&client, lhs_handle, rhs_handle, out_handle,
)
.unwrap()
});

assert_result::<EG, EG, R>(
P::assert_result::<R>(
&lhs.original_data.unwrap(),
&rhs.original_data.unwrap(),
&problem,
&client,
out.handle,
// We cannot assume the inner precision of the matmul, therefore we need a permissive epsilon
Some(10e-2),
);
}

fn tensor_raw_parts<EG: Float + CubeElement, R: Runtime>(
fn tensor_raw_parts<EG: Numeric + CubeElement + Sample, R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
problem: &MatmulProblem,
ident: Ident,
) -> TensorRawParts<EG> {
match ident {
Ident::Lhs => {
let original_data: Vec<EG> =
generate_random_data(tensor_size(problem, Ident::Lhs), 1234);
let original_data = EG::sample(tensor_size(problem, Ident::Lhs), 1234);
let data = match problem.lhs_layout {
MatrixLayout::RowMajor => original_data.clone(),
MatrixLayout::ColMajor => {
Expand All @@ -220,8 +155,7 @@ fn tensor_raw_parts<EG: Float + CubeElement, R: Runtime>(
}
}
Ident::Rhs => {
let original_data: Vec<EG> =
generate_random_data(tensor_size(problem, Ident::Rhs), 5678);
let original_data = EG::sample(tensor_size(problem, Ident::Rhs), 5678);
let data = match problem.rhs_layout {
MatrixLayout::RowMajor => original_data.clone(),
MatrixLayout::ColMajor => {
Expand Down Expand Up @@ -266,44 +200,6 @@ fn transpose<E: Copy>(array: &[E], batches: usize, rows: usize, cols: usize) ->
result
}

/// Compare the GPU matmul result stored in `out` against a naive CPU
/// reference computed from `lhs` and `rhs`, panicking with the comparison
/// report on mismatch.
///
/// `epsilon` optionally overrides the comparison tolerance. When `None`,
/// the tolerance is derived from whether the runtime advertises CMMA
/// support for the tested element types: CMMA paths go through an
/// intermediate f16/tf32 representation, so a wider tolerance is used.
fn assert_result<
EG: Float + CubeElement + Display + CastInto<ES>,
ES: Float + CubeElement + CastInto<EG>,
R: Runtime,
>(
lhs: &[EG],
rhs: &[EG],
problem: &MatmulProblem,
client: &ComputeClient<R::Server, R::Channel>,
out: Handle,
epsilon: Option<f32>,
) {
// Resolve the tolerance: an explicit caller-provided epsilon wins;
// otherwise pick one based on the device's capabilities.
let epsilon = match epsilon {
Some(epsilon) => epsilon,
None => {
// Probe for 16x16x16 CMMA with ES inputs accumulating into EG.
let maybe_cmma = client.properties().feature_enabled(Feature::Cmma {
a: ES::as_elem_native().expect("To be a native type"),
b: ES::as_elem_native().expect("To be a native type"),
c: EG::as_elem_native().expect("To be a native type"),
m: 16,
k: 16,
n: 16,
});

// Need to compensate for the temporary conversion to f16/tf32
match maybe_cmma {
true => 10e-5 / EG::EPSILON.to_f32().unwrap() * half::f16::EPSILON.to_f32(),
false => 10e-5,
}
}
};

// Compute the CPU reference and compare it approximately against the
// device buffer; any element outside the tolerance aborts the test.
let expected = matmul_cpu_reference(lhs, rhs, problem);
if let Err(e) = assert_equals_approx::<R, EG>(client, out, &expected, epsilon) {
panic!("{}", e);
}
}

/// Returns the total number of elements for the identified tensor, inferred by the problem definition
fn tensor_size(problem: &MatmulProblem, ident: Ident) -> usize {
match ident {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-linalg/src/matmul/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ mod test_utils;
pub mod tiling2d;

pub use test_macros::cmma::suite::*;
pub use test_utils::Quantized;
16 changes: 10 additions & 6 deletions crates/cubecl-linalg/src/matmul/tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use cubecl_core::{prelude::Float, CubeElement, Runtime};

use crate::{matmul::kernels::simple, tensor::TensorHandle};

use super::test_utils::{assert_equals_approx, MatmulTestCase};
use super::test_utils::{assert_equals_approx, MatmulTestCase, Sample};

pub fn test_small<R: Runtime, F: Float + CubeElement + Display>(device: &R::Device) {
pub fn test_small<R: Runtime, F: Float + CubeElement + Display + Sample>(device: &R::Device) {
let case = MatmulTestCase {
m: 64,
k: 64,
Expand All @@ -17,7 +17,7 @@ pub fn test_small<R: Runtime, F: Float + CubeElement + Display>(device: &R::Devi
test_simple::<R, F>(case, device);
}

pub fn test_large<R: Runtime, F: Float + CubeElement + Display>(device: &R::Device) {
pub fn test_large<R: Runtime, F: Float + CubeElement + Display + Sample>(device: &R::Device) {
let case = MatmulTestCase {
m: 256,
k: 256,
Expand All @@ -28,7 +28,9 @@ pub fn test_large<R: Runtime, F: Float + CubeElement + Display>(device: &R::Devi
test_simple::<R, F>(case, device);
}

pub fn test_with_check_bounds<R: Runtime, F: Float + CubeElement + Display>(device: &R::Device) {
pub fn test_with_check_bounds<R: Runtime, F: Float + CubeElement + Display + Sample>(
device: &R::Device,
) {
let case = MatmulTestCase {
m: 60,
k: 60,
Expand All @@ -39,7 +41,9 @@ pub fn test_with_check_bounds<R: Runtime, F: Float + CubeElement + Display>(devi
test_simple::<R, F>(case, device);
}

pub fn test_with_batches<R: Runtime, F: Float + CubeElement + Display>(device: &R::Device) {
pub fn test_with_batches<R: Runtime, F: Float + CubeElement + Display + Sample>(
device: &R::Device,
) {
let case = MatmulTestCase {
m: 64,
k: 64,
Expand All @@ -50,7 +54,7 @@ pub fn test_with_batches<R: Runtime, F: Float + CubeElement + Display>(device: &
test_simple::<R, F>(case, device);
}

fn test_simple<R: Runtime, F: Float + CubeElement + Display>(
fn test_simple<R: Runtime, F: Float + CubeElement + Display + Sample>(
case: MatmulTestCase,
device: &R::Device,
) {
Expand Down
22 changes: 18 additions & 4 deletions crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ pub mod suite;
#[macro_export]
macro_rules! testgen_matmul_accelerated {
($eg:ty, $es:ty) => {
type EG = $eg;
type ES = $es;
type Precision = ($eg, $es);

$crate::matmul_standard_tests!();
};
Expand All @@ -26,11 +25,26 @@ macro_rules! testgen_matmul_accelerated {
}
};
}

/// Generates the standard matmul test suite for the quantized test
/// precision, running on the accelerated (CMMA-backed) tile matmul.
#[macro_export]
macro_rules! testgen_matmul_quantized {
() => {
#[allow(non_snake_case)]
mod matmul_quantized {
use super::*;

// `Precision` and `TMM` are the type aliases that
// `matmul_standard_tests!` expects in scope to parameterize
// the generated test functions.
type Precision = $crate::matmul::tests::Quantized;
type TMM = $crate::matmul::components::tile::accelerated::Accelerated;

$crate::matmul_standard_tests!();
}
};
}

/// Generates the standard matmul test suite for a plane-MMA precision
/// where the same float type is used for both the global (EG) and the
/// stage (ES) element types.
#[macro_export]
macro_rules! testgen_matmul_plane {
($float:ident) => {
// The only metavariable bound by this arm is `$float` — referring
// to `$eg`/`$es` here would fail to expand. `Precision` pairs the
// same type for both roles, matching the `($eg, $es)` tuple shape
// used by `testgen_matmul_accelerated`.
type Precision = ($float, $float);

$crate::matmul_standard_tests!();
};
Expand Down
Loading