From 49e8f9c965118c5b4b0c4727643351e3ca2b7691 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Feb 2025 05:11:32 +0000 Subject: [PATCH] Revert "Add torch._scaled_mm for CPU (#139975)" This reverts commit 22fae4c5f94eb43f71a2eebc1904880740cb1d60. Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/huydhn due to third time is the charm ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2664622598)) --- aten/src/ATen/native/Blas.cpp | 83 --- aten/src/ATen/native/mkldnn/Linear.cpp | 127 +--- aten/src/ATen/native/mkldnn/Linear.h | 12 - aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp | 22 +- aten/src/ATen/native/native_functions.yaml | 2 - test/inductor/test_fp8.py | 113 ++-- test/test_matmul_cuda.py | 526 ++++++++++++++++ test/test_matmul_fp8.py | 561 ------------------ torch/_inductor/codegen/cpp_prefix.h | 4 - .../aoti_torch/generated/c_shim_cpu.h | 2 - torch/testing/_internal/common_device_type.py | 2 - .../_internal/common_methods_invocations.py | 47 +- 12 files changed, 586 insertions(+), 915 deletions(-) delete mode 100644 test/test_matmul_fp8.py diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index 3117f8f1136ab4..f62c3177782268 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -7,11 +7,6 @@ #include #include -#include -#include -#if !defined(__s390x__) && !defined(__powerpc__) -#include -#endif #ifndef AT_PER_OPERATOR_HEADERS #include @@ -29,9 +24,6 @@ #include #include #include -#include -#include -#include #endif namespace at::meta { @@ -230,79 +222,4 @@ Tensor vdot(const Tensor &self, const Tensor &other){ } -static Tensor& -_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { - TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); - TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); - TORCH_CHECK( - mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", - mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - - TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); - TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], - " but got ", bias->numel()); - - // Check types - TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); - TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); - TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); - - auto mat1_c = mat1.contiguous(); - auto mat2_c = mat2.contiguous(); - IntArrayRef mat1_sizes = mat1_c.sizes(); - IntArrayRef mat2_sizes = mat2_c.sizes(); - at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); - - float input_scale = scale_a.item(); - float weight_scale = scale_b.item(); - auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale); - auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale); - auto out_tmp = at::matmul(fp32_mat1, fp32_mat2); - if (bias) { - out_tmp.add_(bias.value()); - } - out_tmp = out_tmp.to(out.scalar_type()); - out.copy_(out_tmp); - return out; -} - -Tensor& -_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2, - const Tensor& 
scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { -#if AT_MKLDNN_ENABLED() && (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 5) - if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) { - return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); - } else -#endif - { - return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); - } -} - -Tensor -_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum) { - const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); - Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); - return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); -} - } // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index 1da3ce29cc9372..86304ccbb2a850 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -47,20 +46,9 @@ std::tuple mkldnn_linear_backward( TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support"); } -Tensor& -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { - TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support"); -} - } // namespace at::native + #else // AT_MKLDNN_ENABLED #include @@ -459,119 +447,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) { TORCH_FN(mkldnn_linear_pointwise_binary)); } -Tensor& -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { - TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); - TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); - TORCH_CHECK( - mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", - mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - - TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); - TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], - " but got ", bias->numel()); - - // Check types - TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); - TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); - TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); - // TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6. 
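The emulated CPU path removed from Blas.cpp above dequantizes each fp8 operand with its per-tensor scale, runs the matmul in float32, adds the optional bias, and casts to the requested output dtype. A minimal Python sketch of that behavior follows; the helper name scaled_mm_emulated and the sample sizes are illustrative only, and it assumes a PyTorch build that exposes torch.float8_e4m3fn.

import torch

def scaled_mm_emulated(mat1, mat2, scale_a, scale_b, bias=None, out_dtype=torch.bfloat16):
    # Dequantize with the per-tensor scales, compute in fp32, then cast back,
    # mirroring _scaled_mm_out_cpu_emulated above.
    fp32_mat1 = mat1.to(torch.float) * scale_a.item()
    fp32_mat2 = mat2.to(torch.float) * scale_b.item()
    out = torch.matmul(fp32_mat1, fp32_mat2)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

# Usage with fp8 inputs and unit scales:
a = torch.randn(16, 16).to(torch.float8_e4m3fn)
b = torch.randn(16, 16).to(torch.float8_e4m3fn)
res = scaled_mm_emulated(a, b, torch.tensor(1.0), torch.tensor(1.0))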
- TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type"); - - // Validation checks have passed lets resize the output to actual size - auto mat1_c = mat1.contiguous(); - auto mat2_c = mat2.contiguous(); - IntArrayRef mat1_sizes = mat1_c.sizes(); - IntArrayRef mat2_sizes = mat2_c.sizes(); - at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); - - float input_scale = scale_a.item(); - float weight_scale = scale_b.item(); - auto src = at::native::itensor_view_from_dense(mat1_c); - auto weight_t = at::native::itensor_view_from_dense(mat2_c); - bool with_bias = bias.has_value(); - int64_t K = mat1_sizes[1], M = mat1_sizes[0], - N = mat2_sizes[1]; - - std::vector src_dims = {M, K}; - std::vector weight_dims = {K, N}; - std::vector dst_dims = {M, N}; - - ideep::tensor dst = at::native::itensor_view_from_dense(out); - auto src_desc = ideep::tensor::desc( - src_dims, - get_mkldnn_dtype(mat1.scalar_type()), - ideep::format_tag::any); - auto weights_desc = ideep::tensor::desc( - weight_dims, - get_mkldnn_dtype(mat2.scalar_type()), - ideep::format_tag::any); - auto dst_desc = ideep::tensor::desc( - dst_dims, - get_mkldnn_dtype(out.scalar_type()), - ideep::format_tag::any); - ideep::tensor onednn_bias; - if (with_bias) { - auto bias_value = bias.value(); - if (bias_value.dim() == 1) { - auto b_reshape = bias_value.reshape({1, bias_value.size(0)}); - onednn_bias = at::native::itensor_view_from_dense(b_reshape); - } else { - onednn_bias = at::native::itensor_view_from_dense(bias_value); - } - } - auto bias_desc = ideep::tensor::desc(); - if (with_bias) { - bias_desc = ideep::tensor::desc(onednn_bias.get_dims(), - get_mkldnn_dtype(bias.value().scalar_type()), - ideep::format_tag::any); - } - auto op_attr = ideep::attr_t(); - if (input_scale != 1.0f) { - op_attr.set_scales_mask(DNNL_ARG_SRC, 0); - } - if (weight_scale != 1.0f) { - op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); - } - - op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - auto engine = ideep::engine::cpu_engine(); - dnnl::matmul::primitive_desc primitive_desc = with_bias - ? 
dnnl::matmul::primitive_desc( - engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) - : dnnl::matmul::primitive_desc( - engine, src_desc, weights_desc, dst_desc, op_attr); - auto primitive = dnnl::matmul(primitive_desc); - - // Prepare args and execute primitive - ideep::tensor scratchpad(primitive_desc.scratchpad_desc()); - ideep::exec_args args; - args.insert({DNNL_ARG_SRC, src}); - args.insert({DNNL_ARG_WEIGHTS, weight_t}); - args.insert({DNNL_ARG_DST, dst}); - args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); - if (with_bias) { - args.insert({DNNL_ARG_BIAS, onednn_bias}); - } - ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale)); - ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale)); - - if (input_scale != 1.0f) { - args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t}); - } - args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t}); - - primitive.execute(ideep::stream::default_stream(), args); - return out; -} - } // namespace at #endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/Linear.h b/aten/src/ATen/native/mkldnn/Linear.h index 1dc50c7c541673..6a7fcd60b0e6d4 100644 --- a/aten/src/ATen/native/mkldnn/Linear.h +++ b/aten/src/ATen/native/mkldnn/Linear.h @@ -35,15 +35,3 @@ C10_API Tensor mkl_linear( } // namespace at #endif // AT_MKLDNN_ENABLED() - -namespace at::native { -Tensor& -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out); -} // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index f26427a981f729..32daef37a5637f 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -57,10 +57,6 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) { return ideep::tensor::data_type::bf16; case ScalarType::Half: return ideep::tensor::data_type::f16; - case ScalarType::Float8_e4m3fn: - return ideep::tensor::data_type::f8_e4m3; - case ScalarType::Float8_e5m2: - return ideep::tensor::data_type::f8_e5m2; default: TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type"); } @@ -165,24 +161,8 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data const_cast(tensor.const_data_ptr()) : tensor.data_ptr()}; } - else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) { - return {{tensor.sizes().vec(), - ideep::tensor::data_type::f8_e4m3, - tensor.strides().vec()}, - from_const_data_ptr ? - const_cast(tensor.const_data_ptr()) : - tensor.data_ptr()}; - } - else if (tensor.scalar_type() == ScalarType::Float8_e5m2) { - return {{tensor.sizes().vec(), - ideep::tensor::data_type::f8_e5m2, - tensor.strides().vec()}, - from_const_data_ptr ? - const_cast(tensor.const_data_ptr()) : - tensor.data_ptr()}; - } else { - TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input"); + TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input"); } } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index fb86365fae79e6..52f9547d470d13 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7066,13 +7066,11 @@ - func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? 
bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor variants: function dispatch: - CPU: _scaled_mm_cpu CUDA: _scaled_mm_cuda - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: - CPU: _scaled_mm_out_cpu CUDA: _scaled_mm_out_cuda # NOTE [ Sparse: autograd and API ] diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index dac57df9930e21..9d71bb6a8f74a4 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -14,7 +14,7 @@ parametrize, TEST_WITH_ROCM, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA from torch.utils._triton import has_triton_tma_device @@ -117,10 +117,10 @@ def _fix_fp8_dtype_for_rocm( @instantiate_parametrized_tests class TestFP8Types(TestCase): + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) - @parametrize("device", ("cuda", "cpu")) - def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str): + def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): """ TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4 depends on the variant of fp8 type. @@ -129,34 +129,30 @@ def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str): We should not pick a XBLOCK larger than xnumel """ - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) def f(x): return x.to(dtype=float8_dtype) - x = torch.randn(1, device=device) + x = torch.randn(1, device="cuda") expected = f(x) actual = torch.compile(f)(x) torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") @parametrize("dtype", (torch.float16, torch.bfloat16)) - @parametrize("device", ("cuda", "cpu")) - def test_eager_fallback(self, dtype: torch.dtype, device: torch.device): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) + def test_eager_fallback(self, dtype: torch.dtype): weight_shape = (32, 16) e4m3_type = torch.float8_e4m3fn e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device="cuda") def fp8_matmul_unwrapped(x): - a_scale = torch.Tensor([1.0]).to(device=device) - b_scale = torch.Tensor([1.0]).to(device=device) + a_scale = torch.Tensor([1.0]).to(device="cuda") + b_scale = torch.Tensor([1.0]).to(device="cuda") output_scale = None - input_bias = torch.rand(32, device=device, dtype=dtype) - weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to( + input_bias = torch.rand(32, device="cuda", dtype=dtype) + weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to( e4m3_type ) a_inverse_scale = 1 / a_scale @@ -177,24 +173,19 @@ def fp8_matmul_unwrapped(x): ) x_shape = (16, 16) - x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) + x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type) y_fp8 = compiled_fp8_matmul(x) # noqa: F841 x_shape = (15, 16) - x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) + x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type) y_fp8 = compiled_fp8_matmul(x) # noqa: F841 + @unittest.skipIf(not 
PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("shape", ("15,3,13", "4,2048,4096")) @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)]) - @parametrize("device", ("cuda", "cpu")) - def test_valid_cast( - self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device - ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - if device == "cuda": - dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda") + def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple): + dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda") e4m3, e5m2 = dst_types def fp8_cast(x): @@ -205,7 +196,7 @@ def fp8_cast(x): compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True) shape = [int(dim) for dim in shape.split(",")] - x = torch.rand(*shape, device=device, dtype=dtype) + x = torch.rand(*shape, device="cuda", dtype=dtype) y0_fp8, y1_fp8 = compiled_fp8_cast(x) torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1) @@ -234,21 +225,14 @@ def fp8_cast(x, dtype): x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2) compiled_fp8_cast(x, torch.float8_e4m3fn) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("16,16,16", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) def test_to_fp8_saturated( - self, - src_dtype: torch.dtype, - dst_dtype: torch.dtype, - shape: str, - device: torch.device, + self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - if device == "cuda": - dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda") + dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda") def fp8_saturated(x, dtype): return _to_fp8_saturated(x, dtype) @@ -257,23 +241,18 @@ def fp8_saturated(x, dtype): fp8_saturated, backend="inductor", dynamic=True ) shape = [int(dim) for dim in shape.split(",")] - x = torch.rand(*shape, device=device, dtype=src_dtype) + x = torch.rand(*shape, device="cuda", dtype=src_dtype) y_compiled = compiled_fp8_cast(x, dst_dtype) y = fp8_saturated(x, dst_dtype) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1) @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) - def test_amax_fp8_quant( - self, float8_dtype: torch.dtype, shape: str, device: torch.device - ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest( - "FP8 is only supported on H100+ and sm_89 and MI300+ devices" - ) + def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str): + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -286,24 +265,19 @@ def amax_fp8(x: Tensor, scale: Tensor): compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device=device, dtype=torch.half) - scale = torch.tensor(0.2, device=device, 
dtype=torch.float) + x = torch.rand(*x_shape, device="cuda", dtype=torch.half) + scale = torch.tensor(0.2, device="cuda", dtype=torch.float) y_compiled = compiled_amax_fp8_quant(x, scale) y = amax_fp8(x, scale) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) - def test_amax_along_with_fp8_quant( - self, float8_dtype: torch.dtype, shape: str, device: torch.device - ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - if device == "cuda": - float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") + def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str): + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -316,12 +290,12 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device=device, dtype=torch.half) - scale = torch.tensor(1.0, device=device, dtype=torch.float) + x = torch.rand(*x_shape, device="cuda", dtype=torch.half) + scale = torch.tensor(1.0, device="cuda", dtype=torch.float) - amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half) + amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half) y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled) - amax_buffer = torch.zeros((1), device=device, dtype=torch.half) + amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half) y = amax_fp8(x, scale, amax_buffer) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1) @@ -330,21 +304,14 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): ) @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("amax_keep_dim", (True, False)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) def test_layernorm_fp8_quant( - self, - float8_dtype: torch.dtype, - amax_keep_dim: bool, - shape: str, - device: torch.device, + self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest( - "FP8 is only supported on H100+ and sm_89 and MI300+ devices" - ) + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -366,12 +333,12 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device=device, dtype=torch.half) - scale = torch.tensor(0.2, device=device, dtype=torch.float) + x = torch.rand(*x_shape, device="cuda", dtype=torch.half) + scale = torch.tensor(0.2, device="cuda", dtype=torch.float) - amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half) + amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half) 
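The quantization pattern these inductor tests compile is, roughly, amax tracking followed by a scale-and-saturated-cast. The sketch below expands _to_fp8_saturated inline; the function name amax_fp8_reference is illustrative and the exact test bodies differ.

import torch

def amax_fp8_reference(x, scale, amax_buffer, float8_dtype=torch.float8_e4m3fn):
    # Record the absolute max (used for delayed scaling), then scale and clamp
    # into the fp8 dtype's finite range before casting.
    amax_buffer.fill_(torch.amax(torch.abs(x)))
    max_pos = torch.finfo(float8_dtype).max
    return (x * scale).clamp(min=-max_pos, max=max_pos).to(float8_dtype)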
y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled) - amax_buffer = torch.zeros((1), device=device, dtype=torch.half) + amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half) y = ln_fp8(x, scale, amax_buffer) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1) @@ -785,5 +752,5 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): if __name__ == "__main__": - if HAS_CUDA or HAS_CPU: + if HAS_CUDA: run_tests() diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 6ba2230d973118..940b9a983578ef 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -1,8 +1,11 @@ # Owner(s): ["module: linear algebra"] +import contextlib import unittest from itertools import product from functools import partial +from typing import Optional +import re import torch @@ -14,7 +17,9 @@ from torch.testing import make_tensor from torch.testing._internal.common_cuda import ( SM53OrLater, + SM89OrLater, _get_torch_cuda_version, + PLATFORM_SUPPORTS_FP8 ) from torch.testing._internal.common_device_type import ( dtypes, @@ -244,6 +249,526 @@ def _expand_to_batch(t: torch.Tensor): self.assertEqual(out1_gpu, out2_gpu[0]) +f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" + +if torch.version.hip: + e4m3_type = torch.float8_e4m3fnuz + e5m2_type = torch.float8_e5m2fnuz + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max +else: + e4m3_type = torch.float8_e4m3fn + e5m2_type = torch.float8_e5m2 + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max + +# avoid division by zero when calculating scale +EPS = 1e-12 + +def amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +): + """ Converts the amax value of a tensor to the fp8 scale. + Args: + amax: The amax value of the tensor. + float8_dtype: the float8 dtype. + orig_dtype: The original dtype of the tensor. + """ + scale = torch.empty_like(amax, dtype=torch.float32) + if float8_dtype == e4m3_type: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + elif float8_dtype == e5m2_type: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: + raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") + + # Ensure the scale is representable in float16, + # this helps when amax is small. 
We are assuming that we don't need + # to care about this for float32/bfloat16 + if orig_dtype is torch.float16: + res = torch.clamp(res, max=torch.finfo(torch.float16).max) + + scale.copy_(res) + return scale + +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): + if dim is None: + amax = torch.max(torch.abs(x)) + else: + amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values + + return amax_to_scale(amax, float8_dtype, x.dtype) + +def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: + # naive implementation: dq -> op -> q + x_fp32 = x.to(torch.float) / x_scale + y_fp32 = y.to(torch.float) / y_scale + out_fp32 = torch.mm(x_fp32, y_fp32) + + return out_fp32.to(out_dtype) + +def addmm_float8_unwrapped( + a_data: torch.Tensor, + a_scale: torch.Tensor, + b_data: torch.Tensor, + b_scale: torch.tensor, + output_dtype: torch.dtype, + output_scale: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + a_inverse_scale = a_scale.reciprocal() + b_inverse_scale = b_scale.reciprocal() + if output_dtype == torch.float32 and bias is not None: + # Bias is not supported by _scaled_mm when output is fp32 + output = torch._scaled_mm( + a_data, + b_data, + scale_a=a_inverse_scale, + scale_b=b_inverse_scale, + scale_result=output_scale, + out_dtype=output_dtype, + ) + output += bias + return output + output = torch._scaled_mm( + a_data, + b_data, + bias=bias, + scale_a=a_inverse_scale, + scale_b=b_inverse_scale, + scale_result=output_scale, + out_dtype=output_dtype, + ) + return output + +def mm_float8( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + output_dtype: torch.dtype, # output dtype + output_scale: Optional[torch.Tensor] = None, # output scale, precomputed +) -> torch.Tensor: + return addmm_float8_unwrapped( + a, a_scale, b, b_scale, output_dtype, output_scale + ) + +def to_fp8_saturated( + x: torch.Tensor, + fp8_dtype: torch.dtype +): + if fp8_dtype == e4m3_type: + x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + elif fp8_dtype == e5m2_type: + x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) + else: + raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}") + + return x.to(fp8_dtype) + +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") +class TestFP8MatmulCuda(TestCase): + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + def _test_tautological_mm(self, device: str = "cuda", + x_dtype: torch.dtype = e4m3_type, + y_dtype: torch.dtype = e4m3_type, + out_dtype: Optional[torch.dtype] = None, + size: int = 16) -> None: + x_fp8 = torch.rand(size, size, device=device).to(x_dtype) + y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t() + out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) + if out_dtype is not None: + self.assertEqual(out_dtype, out_fp8.dtype) + self.assertEqual(out_fp32, out_fp8.to(torch.float)) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + def test_float8_basics(self, device) -> None: + self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) + # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported + # supported on ROCm but fails on CUDA + ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext() + with ctx: + 
self._test_tautological_mm(device, e5m2_type, e5m2_type) + + self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32) + self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48) + + self._test_tautological_mm(device, size=64, out_dtype=torch.float16) + self._test_tautological_mm(device, size=96, out_dtype=torch.float32) + self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16) + + with self.assertRaises(AssertionError if torch.version.hip else RuntimeError): + self._test_tautological_mm(device, out_dtype=e5m2_type) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + def test_float8_scale(self, device) -> None: + size = (16, 16) + x = torch.full(size, .5, device=device, dtype=e4m3_type) + # hipblaslt does not yet support mixed e4m3_type input + y_type = e4m3_type if torch.version.hip else e5m2_type + y = torch.full(size, .5, device=device, dtype=y_type).t() + scale_one = torch.tensor(1.0, device=device) + scale_a = torch.tensor(1.5, device=device) + scale_b = torch.tensor(0.66, device=device) + out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one) + self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device)) + out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) + self.assertEqual(out_fp8, out_fp8_s) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_scaled_mm_vs_emulated(self, base_dtype): + torch.manual_seed(42) + input_dtype = e4m3_type + output_dtype = base_dtype + compare_type = torch.float32 + + x = torch.randn(16, 16, device="cuda", dtype=base_dtype) + y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + + x_scale = tensor_to_scale(x, input_dtype).float() + y_scale = tensor_to_scale(y, input_dtype).float() + + x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) + y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) + + # Calculate actual F8 mm + out_scaled_mm = mm_float8( + x_fp8, + y_fp8, + a_scale=x_scale, + b_scale=y_scale, + output_dtype=output_dtype + ) + + # Calculate emulated F8 mm + out_emulated = mm_float8_emulated( + x_fp8, + x_scale, + y_fp8, + y_scale, + output_dtype + ) + + if output_dtype != base_dtype: + out_scaled_mm = out_scaled_mm.to(compare_type) + out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) + + out_emulated = out_emulated.to(compare_type) + out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 3e-3, 3e-3 + + torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_scaled_mm_change_stride(self, base_dtype): + torch.manual_seed(42) + input_dtype = e4m3_type + output_dtype = base_dtype + compare_type = torch.float32 + + x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype) + y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype) + + x.normal_() + y.normal_() + + x_scale = tensor_to_scale(x, input_dtype).float() + y_scale = tensor_to_scale(y, input_dtype).float() + + x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) + y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) + + # Calculate actual F8 mm + out_scaled_mm = mm_float8( + x_fp8, + y_fp8, + a_scale=x_scale, + b_scale=y_scale, + output_dtype=output_dtype + ) + + # Calculate 
emulated F8 mm + out_emulated = mm_float8_emulated( + x_fp8, + x_scale, + y_fp8, + y_scale, + output_dtype + ) + + if output_dtype != base_dtype: + out_scaled_mm = out_scaled_mm.to(compare_type) + out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) + + out_emulated = out_emulated.to(compare_type) + out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 3e-3, 3e-3 + + torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + def test_float8_bias(self, device) -> None: + (k, l, m) = (16, 48, 32) + x = torch.ones((k, l), device=device).to(e4m3_type) + y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() + bias = torch.full((m,), 4.0, device=device, dtype=torch.half) + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) + outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias) + # this fails on ROCm currently because hipblaslt doesn't have amax op + out_fp32 = out_fp8.to(torch.float32) + outb_fp32 = outb_fp8.to(torch.float32) + difference = torch.abs(out_fp32 - outb_fp32) + self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32)) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @parametrize("bias", [True, False]) + def test_non_divisible_leading_dim(self, device, bias: bool) -> None: + x = torch.rand((17, 16), device=device).to(e4m3_type) + y = torch.rand((16, 16), device=device).to(e4m3_type).t() + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + input_bias = None + if bias: + input_bias = torch.rand((16,), device=device).to(torch.half) + _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + def test_float8_bias_relu_edgecase(self, device) -> None: + (k, l, m) = (16, 48, 32) + x = torch.full((k, l), 0.0, device=device).to(e4m3_type) + y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t() + bias = torch.full((m,), -3.0, device=device, dtype=torch.half) + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias) + outb_fp32 = outb_fp8.to(torch.float32) + self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + def test_float32_output_errors_with_bias(self, device) -> None: + (k, l, m) = (16, 48, 32) + x = torch.rand((k, l), device=device).to(e4m3_type) + y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16) + self.assertRaisesRegex( + RuntimeError, + "Bias is not supported when out_dtype is set to Float32", + lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32), + ) + + @unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg) + def test_error_message_fp8_pre_sm89(self, device) -> None: + (k, l, m) = (16, 48, 32) + x = torch.rand((k, l), device=device).to(e4m3_type) + y = torch.rand((m, l), device=device).to(e4m3_type).t() + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + 
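The row-wise tests further below swap the scalar scales used here for per-row and per-column scale tensors. Schematically, with the sizes from those tests (a sketch; the variable names are illustrative):

import torch

M, K, N = 1024, 512, 2048
# Tensor-wise scaling: one scalar per operand.
scale_a_tensorwise = torch.tensor(1.0, device="cuda")
scale_b_tensorwise = torch.tensor(1.0, device="cuda")
# Row-wise scaling (sm89+ only): one scale per row of A and per column of B,
# i.e. scale_a of shape (M, 1) and scale_b of shape (1, N), both contiguous fp32.
scale_a_rowwise = torch.ones(M, 1, device="cuda", dtype=torch.float32)
scale_b_rowwise = torch.ones(1, N, device="cuda", dtype=torch.float32)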
self.assertRaisesRegex( + RuntimeError, + r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+", + lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32), + ) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + def test_float8_scale_fast_accum(self, device) -> None: + size = (16, 16) + x = torch.full(size, .5, device=device, dtype=e4m3_type) + # hipblaslt does not yet support mixed e4m3_type input + y_type = e4m3_type if torch.version.hip else e5m2_type + y = torch.full(size, .5, device=device, dtype=y_type).t() + scale_a = torch.tensor(1.5, device=device) + scale_b = torch.tensor(0.66, device=device) + out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True) + self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device)) + out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) + self.assertEqual(out_fp8, out_fp8_s) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific") + @parametrize("use_fast_accum", [True, False]) + def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32) + y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32) + + x_fp8 = x.to(e4m3_type) + y_fp8 = y.to(e4m3_type).t() + + out_fp8 = torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=x_scales, + scale_b=y_scales, + out_dtype=torch.bfloat16, + use_fast_accum=use_fast_accum, + ) + self.assertEqual( + out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) + ) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @skipIfRocm() + def test_float8_error_messages(self, device) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_fp8 = x.to(e4m3_type) + y_fp8 = y.to(e4m3_type).t() + + with self.assertRaisesRegex( + RuntimeError, + re.escape( + "For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" + ), + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((1, 1), device="cuda"), + scale_b=torch.ones((1, 2), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + re.escape( + " For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" + ), + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M, 1), device="cuda"), + scale_b=torch.ones((1, N + 1), device="cuda"), + out_dtype=torch.bfloat16, + ) + with self.assertRaisesRegex( + RuntimeError, + re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N, N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + re.escape( + "Both scale_a and scale_b must be contiguous for RowWise scaling." 
+ ), + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M, 1), device="cuda"), + scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2], + out_dtype=torch.bfloat16, + ) + + # Note re.compile is used, not re.escape. This is to accomodate fn vs fnuz type message. + with self.assertRaisesRegex( + RuntimeError, + r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.", + ): + torch._scaled_mm( + x_fp8, + y_fp8.to(e5m2_type), + scale_a=torch.ones((M, 1), device="cuda"), + scale_b=torch.ones((1, N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific") + @parametrize("base_dtype", [torch.bfloat16]) + def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): + torch.manual_seed(42) + input_dtype = e4m3_type + output_dtype = base_dtype + + x = torch.randn(16, 16, device="cuda", dtype=base_dtype) + y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + + x_scales = tensor_to_scale(x, input_dtype, dim=1).float() + y_scales = tensor_to_scale(y, input_dtype, dim=0).float() + + x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type) + y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type) + + # Calculate actual F8 mm + out_scaled_mm = mm_float8( + x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype + ) + + # Calculate emulated F8 mm + out_emulated = mm_float8_emulated( + x_fp8, x_scales, y_fp8, y_scales, output_dtype + ) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 2e-3, 2e-3 + + torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @parametrize("which_dim_zero", [0, 1, 2]) + @parametrize("use_torch_compile", [False, True]) + def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: + device = "cuda" + x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn + out_dtype = torch.bfloat16 + M, K, N = 32, 32, 32 + if which_dim_zero == 0: + M = 0 + elif which_dim_zero == 1: + K = 0 + elif which_dim_zero == 2: + N = 0 + + x_fp8 = torch.zeros(M, K, device=device).to(x_dtype) + y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t() + out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) + scale_a = torch.tensor(float('-inf'), device=device) + scale_b = torch.tensor(float('-inf'), device=device) + f = torch._scaled_mm + if use_torch_compile: + f = torch.compile(torch._scaled_mm) + out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) + self.assertEqual(out_dtype, out_fp8.dtype) + self.assertEqual(out_fp32, out_fp8.to(torch.float)) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x") @@ -365,6 +890,7 @@ def run_test( ) instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu") +instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu") instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu") if __name__ == '__main__': diff --git a/test/test_matmul_fp8.py b/test/test_matmul_fp8.py deleted file mode 100644 index 852b026bbc7491..00000000000000 --- a/test/test_matmul_fp8.py +++ /dev/null @@ -1,561 +0,0 @@ -# Owner(s): ["module: linear algebra"] - -import contextlib -import re -import 
unittest -from typing import Optional - -import torch -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM89OrLater -from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import ( - IS_WINDOWS, - parametrize, - run_tests, - skipIfRocm, - TestCase, -) - - -# Protects against includes accidentally setting the default dtype -assert torch.get_default_dtype() is torch.float32 - -f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" - -if torch.version.hip: - e4m3_type = torch.float8_e4m3fnuz - e5m2_type = torch.float8_e5m2fnuz - E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max - E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max -else: - e4m3_type = torch.float8_e4m3fn - e5m2_type = torch.float8_e5m2 - E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max - E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max - -# avoid division by zero when calculating scale -EPS = 1e-12 - - -def amax_to_scale( - amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype -): - """Converts the amax value of a tensor to the fp8 scale. - Args: - amax: The amax value of the tensor. - float8_dtype: the float8 dtype. - orig_dtype: The original dtype of the tensor. - """ - scale = torch.empty_like(amax, dtype=torch.float32) - if float8_dtype == e4m3_type: - res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) - elif float8_dtype == e5m2_type: - res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) - else: - raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") - - # Ensure the scale is representable in float16, - # this helps when amax is small. We are assuming that we don't need - # to care about this for float32/bfloat16 - if orig_dtype is torch.float16: - res = torch.clamp(res, max=torch.finfo(torch.float16).max) - - scale.copy_(res) - return scale - - -def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): - if dim is None: - amax = torch.max(torch.abs(x)) - else: - amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values - - return amax_to_scale(amax, float8_dtype, x.dtype) - - -def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: - # naive implementation: dq -> op -> q - x_fp32 = x.to(torch.float) / x_scale - y_fp32 = y.to(torch.float) / y_scale - out_fp32 = torch.mm(x_fp32, y_fp32) - - return out_fp32.to(out_dtype) - - -def addmm_float8_unwrapped( - a_data: torch.Tensor, - a_scale: torch.Tensor, - b_data: torch.Tensor, - b_scale: torch.tensor, - output_dtype: torch.dtype, - output_scale: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - a_inverse_scale = a_scale.reciprocal() - b_inverse_scale = b_scale.reciprocal() - if output_dtype == torch.float32 and bias is not None: - # Bias is not supported by _scaled_mm when output is fp32 - output = torch._scaled_mm( - a_data, - b_data, - scale_a=a_inverse_scale, - scale_b=b_inverse_scale, - scale_result=output_scale, - out_dtype=output_dtype, - ) - output += bias - return output - output = torch._scaled_mm( - a_data, - b_data, - bias=bias, - scale_a=a_inverse_scale, - scale_b=b_inverse_scale, - scale_result=output_scale, - out_dtype=output_dtype, - ) - return output - - -def mm_float8( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - output_dtype: torch.dtype, # output dtype - output_scale: Optional[torch.Tensor] = None, # output scale, precomputed -) -> torch.Tensor: - return addmm_float8_unwrapped(a, a_scale, b, b_scale, 
output_dtype, output_scale) - - -def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype): - if fp8_dtype == e4m3_type: - x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) - elif fp8_dtype == e5m2_type: - x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) - else: - raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}") - - return x.to(fp8_dtype) - - -class TestFP8Matmul(TestCase): - def _test_tautological_mm( - self, - device: str = "cuda", - x_dtype: torch.dtype = e4m3_type, - y_dtype: torch.dtype = e4m3_type, - out_dtype: Optional[torch.dtype] = None, - size: int = 16, - ) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - x_fp8 = torch.rand(size, size, device=device).to(x_dtype) - y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t() - out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) - scale_a = torch.tensor(1.0, device=device) - scale_b = torch.tensor(1.0, device=device) - out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) - if out_dtype is not None: - self.assertEqual(out_dtype, out_fp8.dtype) - self.assertEqual(out_fp32, out_fp8.to(torch.float)) - - def test_float8_basics(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) - # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported - # supported on ROCm but fails on CUDA - ctx = ( - self.assertRaises(RuntimeError) - if torch.version.hip is None and device != "cpu" - else contextlib.nullcontext() - ) - with ctx: - self._test_tautological_mm(device, e5m2_type, e5m2_type) - - if device != "cpu": - # TODO: The following 2 tests are mixed dtypes between src and weight, - # which will be enabled in oneDNN v3.6 in CPU. - self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32) - self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48) - - self._test_tautological_mm(device, size=64, out_dtype=torch.float16) - self._test_tautological_mm(device, size=96, out_dtype=torch.float32) - self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16) - - with self.assertRaises( - AssertionError if torch.version.hip or device == "cpu" else RuntimeError - ): - self._test_tautological_mm(device, out_dtype=e5m2_type) - - def test_float8_scale(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - size = (16, 16) - x = torch.full(size, 0.5, device=device, dtype=e4m3_type) - # hipblaslt does not yet support mixed e4m3_type input - # TODO: will use e5m2_type after upgrading oneDNN to v3.6. 
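A worked example of the amax_to_scale / to_fp8_saturated round trip defined above, with illustrative numbers (e4m3 has a finite max of 448, and EPS matches the 1e-12 used by these helpers):

import torch

x = torch.tensor([[0.1, -2.0], [3.5, 0.25]])
amax = torch.max(torch.abs(x))                                                # 3.5
scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)  # ~128
x_fp8 = (x * scale).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)
x_back = x_fp8.to(torch.float) / scale        # dequantize, as mm_float8_emulated does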
- y_type = e4m3_type if torch.version.hip or device == "cpu" else e5m2_type - y = torch.full(size, 0.5, device=device, dtype=y_type).t() - scale_one = torch.tensor(1.0, device=device) - scale_a = torch.tensor(1.5, device=device) - scale_b = torch.tensor(0.66, device=device) - out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one) - self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device)) - out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) - self.assertEqual(out_fp8, out_fp8_s) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) - def test_scaled_mm_vs_emulated(self, base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - compare_type = torch.float32 - - x = torch.randn(16, 16, device="cuda", dtype=base_dtype) - y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() - - x_scale = tensor_to_scale(x, input_dtype).float() - y_scale = tensor_to_scale(y, input_dtype).float() - - x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) - y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) - - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype) - - if output_dtype != base_dtype: - out_scaled_mm = out_scaled_mm.to(compare_type) - out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) - - out_emulated = out_emulated.to(compare_type) - out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 3e-3, 3e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) - def test_scaled_mm_change_stride(self, base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - compare_type = torch.float32 - - x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype) - y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype) - - x.normal_() - y.normal_() - - x_scale = tensor_to_scale(x, input_dtype).float() - y_scale = tensor_to_scale(y, input_dtype).float() - - x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) - y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) - - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype) - - if output_dtype != base_dtype: - out_scaled_mm = out_scaled_mm.to(compare_type) - out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) - - out_emulated = out_emulated.to(compare_type) - out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 3e-3, 3e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - def test_float8_bias(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - (k, l, m) = (16, 48, 32) - x = torch.ones((k, l), 
device=device).to(e4m3_type) - y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t() - bias = torch.full((m,), 4.0, device=device, dtype=torch.half) - scale_a = torch.tensor(1.0, device=device) - scale_b = torch.tensor(1.0, device=device) - out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) - outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias) - # this fails on ROCm currently because hipblaslt doesn't have amax op - out_fp32 = out_fp8.to(torch.float32) - outb_fp32 = outb_fp8.to(torch.float32) - difference = torch.abs(out_fp32 - outb_fp32) - self.assertEqual( - difference, torch.tensor(4.0, device=device).expand_as(out_fp32) - ) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @parametrize("bias", [True, False]) - def test_non_divisible_leading_dim(self, device, bias: bool) -> None: - x = torch.rand((17, 16), device=device).to(e4m3_type) - y = torch.rand((16, 16), device=device).to(e4m3_type).t() - scale_a = torch.tensor(1.0, device=device) - scale_b = torch.tensor(1.0, device=device) - input_bias = None - if bias: - input_bias = torch.rand((16,), device=device).to(torch.half) - _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - def test_float8_bias_relu_edgecase(self, device) -> None: - (k, l, m) = (16, 48, 32) - x = torch.full((k, l), 0.0, device=device).to(e4m3_type) - y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t() - bias = torch.full((m,), -3.0, device=device, dtype=torch.half) - scale_a = torch.tensor(1.0, device=device) - scale_b = torch.tensor(1.0, device=device) - outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias) - outb_fp32 = outb_fp8.to(torch.float32) - self.assertEqual( - outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32) - ) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - def test_float32_output_errors_with_bias(self, device) -> None: - (k, l, m) = (16, 48, 32) - x = torch.rand((k, l), device=device).to(e4m3_type) - y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t() - scale_a = torch.tensor(1.0, device=device) - scale_b = torch.tensor(1.0, device=device) - bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16) - self.assertRaisesRegex( - RuntimeError, - "Bias is not supported when out_dtype is set to Float32", - lambda: torch._scaled_mm( - x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32 - ), - ) - - @unittest.skipIf(PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(), f8_msg) - def test_error_message_fp8_pre_sm89(self, device) -> None: - (k, l, m) = (16, 48, 32) - x = torch.rand((k, l), device=device).to(e4m3_type) - y = torch.rand((m, l), device=device).to(e4m3_type).t() - scale_a = torch.tensor(1.0, device=device) - scale_b = torch.tensor(1.0, device=device) - self.assertRaisesRegex( - RuntimeError, - r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+", - lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32), - ) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - def test_float8_scale_fast_accum(self, device) -> None: - size = (16, 16) - x = torch.full(size, 0.5, device=device, dtype=e4m3_type) - # hipblaslt does not yet support mixed e4m3_type input - y_type = e4m3_type if torch.version.hip else e5m2_type - y = torch.full(size, 0.5, device=device, dtype=y_type).t() - scale_a = torch.tensor(1.5, device=device) - scale_b = torch.tensor(0.66, device=device) - out_fp8 
= torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True) - self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device)) - out_fp8_s = torch._scaled_mm( - x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True - ) - self.assertEqual(out_fp8, out_fp8_s) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) - @unittest.skipIf( - not SM89OrLater, "rowwise implementation is currently sm89+ specific" - ) - @parametrize("use_fast_accum", [True, False]) - def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: - M, K, N = (1024, 512, 2048) - fill_value = 0.5 - x = torch.full((M, K), fill_value, device=device) - y = torch.full((N, K), fill_value, device=device) - - x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32) - y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32) - - x_fp8 = x.to(e4m3_type) - y_fp8 = y.to(e4m3_type).t() - - out_fp8 = torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=x_scales, - scale_b=y_scales, - out_dtype=torch.bfloat16, - use_fast_accum=use_fast_accum, - ) - self.assertEqual( - out_fp8.to(torch.float32), - torch.full((M, N), K * (fill_value**2), device=device), - ) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) - @skipIfRocm() - def test_float8_error_messages(self, device) -> None: - M, K, N = (1024, 512, 2048) - fill_value = 0.5 - x = torch.full((M, K), fill_value, device=device) - y = torch.full((N, K), fill_value, device=device) - - x_fp8 = x.to(e4m3_type) - y_fp8 = y.to(e4m3_type).t() - - with self.assertRaisesRegex( - RuntimeError, - re.escape( - "For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" - ), - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((1, 1), device="cuda"), - scale_b=torch.ones((1, 2), device="cuda"), - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - re.escape( - " For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" - ), - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M, 1), device="cuda"), - scale_b=torch.ones((1, N + 1), device="cuda"), - out_dtype=torch.bfloat16, - ) - with self.assertRaisesRegex( - RuntimeError, - re.escape( - "For non-TensorWise scaling, scale tensors must be 2-dimensional" - ), - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, N), device="cuda"), - out_dtype=torch.bfloat16, - ) - - with self.assertRaisesRegex( - RuntimeError, - re.escape( - "Both scale_a and scale_b must be contiguous for RowWise scaling." - ), - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M, 1), device="cuda"), - scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2], - out_dtype=torch.bfloat16, - ) - - # Note re.compile is used, not re.escape. This is to accomodate fn vs fnuz type message. - with self.assertRaisesRegex( - RuntimeError, - re.escape( - "Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false." 
- ), - ): - torch._scaled_mm( - x_fp8, - y_fp8.to(e5m2_type), - scale_a=torch.ones((M, 1), device="cuda"), - scale_b=torch.ones((1, N), device="cuda"), - out_dtype=torch.bfloat16, - ) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) - @unittest.skipIf( - not SM89OrLater, "rowwise implementation is currently sm89+ specific" - ) - @parametrize("base_dtype", [torch.bfloat16]) - def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - - x = torch.randn(16, 16, device="cuda", dtype=base_dtype) - y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() - - x_scales = tensor_to_scale(x, input_dtype, dim=1).float() - y_scales = tensor_to_scale(y, input_dtype, dim=0).float() - - x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type) - y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type) - - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated( - x_fp8, x_scales, y_fp8, y_scales, output_dtype - ) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 2e-3, 2e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @parametrize("which_dim_zero", [0, 1, 2]) - @parametrize("use_torch_compile", [False, True]) - def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: - device = "cuda" - x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn - out_dtype = torch.bfloat16 - M, K, N = 32, 32, 32 - if which_dim_zero == 0: - M = 0 - elif which_dim_zero == 1: - K = 0 - elif which_dim_zero == 2: - N = 0 - - x_fp8 = torch.zeros(M, K, device=device).to(x_dtype) - y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t() - out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) - scale_a = torch.tensor(float("-inf"), device=device) - scale_b = torch.tensor(float("-inf"), device=device) - f = torch._scaled_mm - if use_torch_compile: - f = torch.compile(torch._scaled_mm) - out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) - self.assertEqual(out_dtype, out_fp8.dtype) - self.assertEqual(out_fp32, out_fp8.to(torch.float)) - - -instantiate_device_type_tests(TestFP8Matmul, globals()) - -if __name__ == "__main__": - TestCase._default_dtype_check_enabled = True - run_tests() diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index ce48746dde629c..850391191b1341 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -21,8 +21,6 @@ #include #include -#include -#include #include #include #include @@ -50,8 +48,6 @@ typedef at::BFloat16 bfloat16; typedef at::Float8_e4m3fn float8_e4m3fn; typedef at::Float8_e5m2 float8_e5m2; -typedef at::Float8_e4m3fnuz float8_e4m3fnuz; -typedef at::Float8_e5m2fnuz float8_e5m2fnuz; template struct Welford { diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 590b65237ce030..924e77b28c2980 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -35,8 +35,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, 
AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index e4de424118fb6b..3e712799d80917 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1002,8 +1002,6 @@ class OpDTypes(Enum): torch.int8, torch.uint8, torch.bool, - torch.float8_e4m3fn, - torch.float8_e5m2, ) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index bd59cdc153dc37..087c874c60cf7b 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -22,7 +22,7 @@ from torch.testing._internal.common_dtype import ( _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, - empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, + empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, ) from torch.testing._internal.common_device_type import \ (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -8743,21 +8743,18 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs): scale1 = make_scale((1,)) 
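To make the per-tensor samples built here (including the mixed e4m3/e5m2 pairs restored just below) concrete, here is a self-contained sketch. make_mat_e4m3, make_mat_e5m2, and make_scale are defined earlier in sample_inputs_scaled_mm and are not shown in this hunk, so the stand-ins below are assumptions for illustration; the sizes are arbitrary, and a CUDA device with FP8 support is assumed.

import torch

device = "cuda"

def make_mat_e4m3(shape):   # stand-in, not the helper from this file
    return torch.randn(shape, device=device).to(torch.float8_e4m3fn)

def make_mat_e5m2(shape):   # stand-in, not the helper from this file
    return torch.randn(shape, device=device).to(torch.float8_e5m2)

def make_scale(shape):      # stand-in, not the helper from this file
    return torch.rand(shape, device=device, dtype=torch.float32)

M, K, N = 16, 32, 48  # illustrative sizes
mat1 = make_mat_e4m3((M, K))
# (K, N) -> t() -> contiguous() -> t(): same shape, but column-major strides for mat2.
mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
scale1 = make_scale((1,))   # numel == 1, i.e. per-tensor scaling
scale2 = make_scale((1,))

# A direct call with these sample tensors (adding an explicit bf16 out_dtype for
# clarity) looks like this:
out = torch._scaled_mm(mat1, mat2, scale1, scale2, out_dtype=torch.bfloat16)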
scale2 = make_scale((1,)) samples.append(SampleInput(mat1, mat2, scale1, scale2)) - # TODO: Will remove this after oneDNN v3.6 - # now oneDNN v3.5.3 only supports mat1 * mat2 with the same data types. - if device != 'cpu': - # mat1 e4m3 mat2 e5m2 - mat1 = make_mat_e4m3((M, K)) - mat2 = make_mat_e5m2((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) - # mat1 e5m2 mat2 e4m3 - mat1 = make_mat_e5m2((M, K)) - mat2 = make_mat_e4m3((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e4m3 mat2 e5m2 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e5m2((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e5m2 mat2 e4m3 + mat1 = make_mat_e5m2((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) yield from samples @@ -16217,7 +16214,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): OpInfo( 'torch._scaled_mm', sample_inputs_func=sample_inputs_scaled_mm, - dtypes=float8_types(), + dtypes=empty_types(), dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), supports_out=True, supports_forward_ad=False, @@ -16225,20 +16222,12 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], skips=( # Sample inputs isn't really parametrized on dtype - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), - # "add_stub" not implemented for 'Float8_e4m3fn' - # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn' - # https://github.com/pytorch/pytorch/issues/107256 - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', + device_type='cuda'), # "mul_cuda" not implemented for float8_e4m3fn - # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' # https://github.com/pytorch/pytorch/issues/107256 - DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'), - # aten::_scaled_mm hit the vmap fallback which is currently disabled - DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), - DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), - DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', - dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.float8_e4m3fn,)), ) ), OpInfo(
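For what the dtype lists in the entry above expand to in practice: empty_types() is an empty dtype tuple, so after this change the OpInfo advertises no dtypes on CPU and only float8_e4m3fn on CUDA, which is consistent with the remaining skips being scoped to CUDA and to that dtype. A minimal sketch using the helper already imported at the top of this file:

import torch
from torch.testing._internal.common_dtype import empty_types

cpu_dtypes = empty_types()                            # ()
cuda_dtypes = empty_types() + (torch.float8_e4m3fn,)  # (torch.float8_e4m3fn,)
print(tuple(cpu_dtypes), tuple(cuda_dtypes))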