
ROCm Sparse Marlin Kernels #1206

Open — wants to merge 32 commits into main from rocm_sparse_marlin
Conversation

petrex
Collaborator

@petrex petrex commented Oct 31, 2024

Built on top of #1201. This pull request introduces ROCm (Radeon Open Compute) support for the sparse Marlin kernel in addition to CUDA, enabling the code to run on AMD GPUs.

The main changes involve conditional compilation to handle differences between CUDA and ROCm, as well as adding ROCm-specific intrinsics for MI300x.

Co-author: @lcskrishna


Key changes include:

ROCm Support in setup.py:

  • HIP kernel generation

Conditional Compilation in CUDA Source Files:

  • Added conditional compilation directives to exclude certain code for ROCm and include ROCm-specific implementations.

ROCm-specific Implementations:

  • Implemented ROCm-specific versions of functions and macros that are different from their CUDA counterparts, ensuring compatibility and performance on AMD GPUs.

Next:

  • Validation and benchmarking across workloads on MIxxx GPUs
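The setup.py side of the changes above could be sketched roughly as follows. This is an illustration, not the PR's actual code: the function name, the `-DUSE_ROCM` define, and the `gfx942` (MI300X) offload target are assumptions about how such a flag switch is typically wired; in a real setup.py, `is_rocm` would usually be derived from `torch.version.hip is not None`.

```python
# Illustrative sketch (names and flags assumed, not the PR's actual setup.py):
# pick extension compile flags depending on a ROCm vs CUDA PyTorch build.
def extension_compile_args(is_rocm: bool) -> list:
    if is_rocm:
        # Define USE_ROCM so the .cu sources can select HIP-specific
        # intrinsics under #ifdef USE_ROCM, and target MI300X (gfx942).
        return ["-DUSE_ROCM", "--offload-arch=gfx942"]
    # CUDA path keeps the usual nvcc optimization flags.
    return ["-O3"]
```

The same macro then drives the conditional compilation in the CUDA source files, so one source tree serves both backends.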


pytorch-bot bot commented Oct 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1206

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures

As of commit f18043d with merge base 98c4e2e:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 31, 2024
@msaroufim msaroufim requested review from msaroufim and removed request for msaroufim November 2, 2024 22:51
@msaroufim
Member

Do you have performance numbers by any chance relative to fp16? wanna make sure the performance improvements are competitive with CUDA

@petrex
Collaborator Author

petrex commented Nov 5, 2024

Do you have performance numbers by any chance relative to fp16? wanna make sure the performance improvements are competitive with CUDA

Still WIP, but could you share the benchmark you're using? I'll try that on MI300X when the PR is ready.

@msaroufim
Member

Ok holler at me again whenever you need a review. Really excited to see this land

@drisspg
Contributor

drisspg commented Nov 5, 2024

For benchmarking it is a little ad hoc; the best place for this today would be to verify with: https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py

@jcaip jcaip mentioned this pull request Nov 11, 2024
1 task
@petrex petrex force-pushed the rocm_sparse_marlin branch from 00bc94d to d2c7ce4 Compare January 6, 2025 22:06
@petrex petrex added the topic: new feature Use this tag if this PR adds a new feature label Jan 6, 2025
@petrex
Collaborator Author

petrex commented Jan 6, 2025

@pytorchbot rebase


pytorch-bot bot commented Jan 6, 2025

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@petrex petrex self-assigned this Jan 6, 2025
@petrex petrex marked this pull request as ready for review January 7, 2025 17:22

pytorch-bot bot commented Jan 7, 2025

Unknown label ciflow/rocm.
Currently recognized labels are

  • ciflow/benchmark

@petrex petrex requested a review from msaroufim January 8, 2025 15:59
Member

@msaroufim msaroufim left a comment


Seems good to me. I'll lean on @atalman and @jcaip for the final merge, since the error you're seeing in CI does seem like an underlying infra issue. It's not a flake, though; I tried rerunning it and it still fails.

@petrex petrex requested review from jcaip and atalman January 8, 2025 20:18
@petrex petrex force-pushed the rocm_sparse_marlin branch from 662bfe7 to a4e8c30 Compare January 8, 2025 23:06
Contributor

@jcaip jcaip left a comment


Looks good! Did you get a chance to try this and get benchmarking numbers? Curious to see how it compares. We should probably update the testing framework too for AMD

@@ -19,6 +19,28 @@
#include "base.h"

namespace torchao {

#ifdef USE_ROCM
Contributor


Should we gate on a specific ROCm version like we do for CUDA?

Collaborator Author

@petrex petrex Jan 9, 2025


Good point! What we need is a GPU arch check rather than a ROCm version check. I have added a GPU architecture check in setup.py. As a result, the kernel will now only be built for the MI300X architecture.
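An architecture gate like the one described here could be sketched as below. The helper name is hypothetical; on ROCm builds of PyTorch, `torch.cuda.get_device_properties(0).gcnArchName` reports strings of the form `"gfx942:sramecc+:xnack-"`, so the check keeps only the base architecture token before comparing.

```python
# Hypothetical sketch of a setup.py architecture gate (name assumed):
# build the sparse Marlin kernel only for MI300-series (gfx942) targets.
def is_mi300_arch(gcn_arch_name: str) -> bool:
    """gcnArchName strings look like 'gfx942:sramecc+:xnack-';
    strip the feature suffixes and compare the base architecture."""
    return gcn_arch_name.split(":")[0] == "gfx942"
```

A gate like this keys on the GPU architecture rather than the ROCm version, matching the reasoning in the comment above.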

Contributor


Sounds good, I think setup.py was recently updated by #1490, so you may have to pull in the new changes.

#if defined(USE_ROCM)
#if ROCM_VERSION >= 60200
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
Contributor


What does BF16_ONE refer to here?

Collaborator Author


Thanks, let me clean up a little. I'd like this PR to focus on sparse_marlin; tensor_core_tile_layout.cu should go to #1201 instead.

Collaborator Author

@petrex petrex Jan 9, 2025


0x3F80 in BF16: sign bit (0) + exponent (01111111) + mantissa (0000000) = 1.0
Just renamed it in #1201 to reflect this.
see : 26fa19c
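The bit layout above can be checked directly: bfloat16 is simply the top 16 bits of an IEEE-754 float32 (1 sign bit, 8 exponent bits, 7 mantissa bits), so shifting the 16-bit pattern left by 16 and reinterpreting it as a float32 recovers the value. A small sketch (the helper name is my own, not from the PR):

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    """Decode a raw bfloat16 bit pattern by widening it to float32.

    bfloat16 keeps the float32 sign and exponent and truncates the
    mantissa to 7 bits, so bits << 16 is the matching float32 pattern.
    """
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

# The two constants from the diff above:
# 0x3F80 -> 1.0    (BF16_ONE)
# 0xC308 -> -136.0 (BF16_BIAS)
```

This confirms the naming: 0x3F80 is 1.0, and 0xC308 is the -136.0 bias constant.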

@petrex
Collaborator Author

petrex commented Jan 9, 2025

Looks good! Did you get a chance to try this and get benchmarking numbers? Curious to see how it compares. We should probably update the testing framework too for AMD

Thanks, it is planned. I will update the benchmark PR.

@petrex petrex force-pushed the rocm_sparse_marlin branch from 1f3b773 to 08d1cfb Compare January 9, 2025 22:34
@petrex petrex requested a review from jcaip January 10, 2025 00:25
Contributor

@jcaip jcaip left a comment


LGTM, should be good to merge once we fix the setup.py conflicts.

petrex and others added 2 commits January 15, 2025 15:03
@petrex petrex force-pushed the rocm_sparse_marlin branch from 3185c9d to aea9d81 Compare January 15, 2025 23:52
@petrex
Collaborator Author

petrex commented Jan 16, 2025

LGTM, should be good to merge once we fix the setup.py conflicts.

done.

Labels
ciflow/rocm CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: rocm topic: new feature Use this tag if this PR adds a new feature
6 participants