Add CUTLASS-based row-wise scaled sparse FP8 kernel #1671

Draft · alexsamardzic wants to merge 1 commit into main from rowwise-scaled-sparse-fp8-cutlass

Conversation

alexsamardzic (Collaborator)

No description provided.

pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1671

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit d1d96f7 with merge base d00ee41, 3 new CI job failures were reported.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Feb 5, 2025
alexsamardzic (Collaborator Author) commented Feb 5, 2025

The kernel is ready and passes a smoke test.

Remaining tasks:

  • Write a converter to the SM90 sparse semi-structured format
  • Validate the kernel on proper test inputs
  • Write the benchmark
  • Write the Python-side code: sparsify/quantize method, Llama generator extension, etc.
  • Ensure the kernel is built with SM90a flags when torchao detects an H100 card as SM90
  • Further unify the CUDA code with the rowwise_scaled_linear_cutlass code
  • Implement meaningful config selection

@cpuhrsch @drisspg

@cpuhrsch requested a review from jcaip on February 5, 2025
@alexsamardzic added the float8, sparsity, and topic: new feature labels on Feb 6, 2025
@alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 2 times, most recently from 5bbcc49 to 6d34b7e, on February 6, 2025
@alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 8 times, most recently from bd7288a to f11fae4, on February 13, 2025
@alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 10 times, most recently from bf65c83 to c0368e3, on February 19, 2025
alexsamardzic (Collaborator Author)

Testing this PR revealed that the CUTLASS sparse compressor does not treat -0.0 values as zeros. An upstream fix is proposed here.
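
In the meantime, the dense-to-sparse conversion needs negative zeros canonicalized before compression. A minimal sketch of that idea (my illustration, not the exact workaround code in this PR), applied to the dense weight before the FP8 cast:

```python
import torch

def canonicalize_zeros(w: torch.Tensor) -> torch.Tensor:
    # IEEE 754 treats -0.0 == +0.0 as equal, so this rewrites every
    # negative zero in the pruned weight to a positive zero before the
    # tensor reaches the CUTLASS sparse compressor.
    return torch.where(w == 0, torch.zeros_like(w), w)

w = torch.tensor([1.0, -0.0, 0.0, -2.0])
# After canonicalization, no zero entry carries a sign bit anymore.
assert not torch.signbit(canonicalize_zeros(w)[w == 0]).any()
```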

@alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch from c0368e3 to 4c63c65, on February 20, 2025
@alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 2 times, most recently from 05ed2d4 to 6fb6165, on February 23, 2025
alexsamardzic (Collaborator Author) commented Feb 24, 2025

This PR is ready for review. It contains:

  1. An implementation of two new CUTLASS-based operators:
    • A converter to the sparse semi-structured format, for FP8 data and the SM9x arch, in torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x.
    • A row-wise scaled linear operator for sparse FP8 weights and FP8 activations, in torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass. For parallel compilation, each operator template instantiation is in a separate .cu file.
  2. The test for the latter operator, in test/test_rowwise_scaled_linear_sparse_cutlass.py (not all tests pass at the moment because of [QST] About NaNs generated during FP16->FP8 quantization #1766), and the micro-benchmark, in benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py.
  3. The corresponding layout and TensorImpl class implementations in torchao/dtypes/floatx/cutlass_semi_sparse_layout.py. Because of a CUTLASS issue with handling negative zero values when compressing a dense tensor to sparse, the from_plain() method here contains a temporary workaround (the fix for this CUTLASS issue is in the works: Treat negative zero as equivalent to positive zero in sm90_sparse_gemm_compressor.hpp NVIDIA/cutlass#2110).
  4. The remaining glue code on the Python side, in torchao/ops.py, torchao/dtypes/affine_quantized_tensor.py and torchao/quantization/quant_api.py, including the definition of the new config Float8DynamicActivationFloat8SemiSparseWeightConfig for the quantize_() method (see the usage sketch after this list).
  5. An update to the torchao/_models/llama/generate.py script, to make it possible to test the new quantization and linear operator in the context of Llama - run with python generate.py --compile --sparsity semi -q float8dq.
  6. Some minor updates to the CUTLASS-based integer W4A4/W4A8 code.
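
For reference, a minimal usage sketch of the new config (my illustration, not code from this PR), assuming an SM90 GPU, a build that includes the new kernels, and the config's default arguments:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
)

# Toy model; for the sparse kernel to apply, the linear weights must
# satisfy the 2:4 semi-structured sparsity pattern (e.g. after pruning).
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

# Quantize activations dynamically and the semi-sparse weights to FP8,
# with row-wise scales, dispatching to the new CUTLASS operator.
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

out = model(torch.randn(16, 4096, dtype=torch.half, device="cuda"))
```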

I'll address performance tuning (through CUTLASS run-time config selection), mentioned as a remaining task above, in a separate PR.

@drisspg The setup.py changes are about activating gencode flags for SM90a when the build targets SM90 (a sketch of the general idea follows below) - it's clumsy, but it works, so hopefully we can use this approach until eventually switching to CMake builds for the extensions. I'm adding you as a reviewer because of this; also, please add whoever may be the most appropriate reviewer(s) for the Python side of the code.
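
The general idea, sketched below under my own assumptions (this is not the exact setup.py diff), is to add the architecture-specific compute_90a/sm_90a gencode flags when an SM90 device is detected, since the FP8 sparse tensor-core paths require sm_90a:

```python
import torch

nvcc_extra_args = ["-O3"]
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) >= (9, 0):
        # Plain sm_90 code cannot use the instructions (e.g. wgmma) that
        # these CUTLASS kernels rely on; the "a" variant exposes them.
        nvcc_extra_args.append("-gencode=arch=compute_90a,code=sm_90a")
```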

@jcaip If you think there is a need, we may eventually discuss exposing the new operators mentioned above through SparseSemiStructuredTensor.

@gau-nernst With this PR, it's possible to try the CUTLASS-based W4A4 operator from the Llama generator - run with python generate.py --compile --sparsity semi -q int4dq-4 (be sure to fetch the model beforehand; instructions are here). The output is not meaningful - maybe the quantization is too tight - but we may want to investigate further.

@alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 3 times, most recently from ad04e4b to a7197f7, on February 25, 2025
@alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch from a7197f7 to d1d96f7, on February 26, 2025