ROCm Sparse Marlin Kernels #1206
base: main
Conversation
ROCm build infrastructure
[ROCm] Enable Tiled layout extension and minor changes to setup
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1206
Note: links to docs will display an error until the docs builds have been completed.
❌ 6 new failures as of commit f18043d with merge base 98c4e2e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Do you have performance numbers relative to fp16 by any chance? I want to make sure the performance improvements are competitive with CUDA.
Still WIP, but would you share the benchmark you are using? I will try it on the MI300X when the PR is ready.
OK, holler at me again whenever you need a review. Really excited to see this land.
Benchmarking is a little ad hoc at the moment; the best place to verify today is https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py
Fixes builds for non-rocm.
Force-pushed from 00bc94d to d2c7ce4
@pytorchbot rebase
You don't have permissions to rebase this PR since you are a first-time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.
Force-pushed from 662bfe7 to a4e8c30
Looks good! Did you get a chance to try this and collect benchmark numbers? Curious to see how it compares. We should probably update the testing framework for AMD, too.
@@ -19,6 +19,28 @@
#include "base.h"

namespace torchao {

#ifdef USE_ROCM
Should we gate on a specific ROCm version like we do for CUDA?
Good point! What we need is a GPU architecture check rather than a ROCm version check. I have added a GPU architecture check in setup.py. As a result, the kernel will now only be built for the MI300X architecture.
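The architecture gate described above could look roughly like the sketch below. This is a hypothetical illustration, not the exact code in this PR: the function name is mine, and it assumes the ROCm-style arch string (e.g. "gfx942:sramecc+:xnack-") is available at build time.

```python
# Hypothetical sketch of the setup.py GPU architecture gate (illustrative names).
def should_build_rocm_sparse_marlin(gcn_arch_name: str) -> bool:
    # ROCm reports arch strings like "gfx942:sramecc+:xnack-";
    # the base architecture token comes before the first ':'.
    base_arch = gcn_arch_name.split(":")[0]
    # gfx942 is the MI300-series architecture, which includes the MI300X.
    return base_arch == "gfx942"

print(should_build_rocm_sparse_marlin("gfx942:sramecc+:xnack-"))  # True
print(should_build_rocm_sparse_marlin("gfx90a:sramecc+:xnack-"))  # False
```

In a real setup.py this check would typically obtain the arch string from the installed ROCm PyTorch build (e.g. the device properties of GPU 0) before deciding whether to compile the ROCm sources.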
Sounds good. I think setup.py was recently updated by #1490, so you may have to pull in the new changes.
#if defined(USE_ROCM)
#if ROCM_VERSION >= 60200
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
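For context on the constants in the hunk above: a bfloat16 value is the upper 16 bits of an IEEE float32, so the raw bit patterns can be decoded with the standard library. 0x3F80 decodes to 1.0 (hence BF16_ONE) and 0xC308 to -136.0, which appears to serve as the magic bias constant in Marlin-style int4-to-bf16 dequantization. A quick stdlib sketch (the decoder function name is mine):

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    # A bfloat16 is the top 16 bits of a float32: shift the pattern into the
    # high half and reinterpret the resulting 32 bits as a float.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

print(bf16_bits_to_float(0x3F80))  # 1.0    -> BF16_ONE
print(bf16_bits_to_float(0xC308))  # -136.0 -> BF16_BIAS
```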
What does BF16_ONE refer to here?
Thanks, let me clean up a little bit. I'd like this PR to focus on sparse_marlin; tensor_core_tile_layout.cu should go to #1201 instead.
Thanks, it is planned. I will update the benchmark PR.
Force-pushed from 1f3b773 to 08d1cfb
LGTM, should be good to merge once we fix the setup.py conflicts.
Force-pushed from 3185c9d to aea9d81
Done.
Built on top of #1201. This pull request introduces support for ROCm (Radeon Open Compute) for the sparse Marlin kernel in addition to CUDA, enabling the code to run on AMD GPUs.
The main changes involve conditional compilation to handle differences between CUDA and ROCm, as well as adding ROCm-specific intrinsics for the MI300X.
Co-authored-by: @lcskrishna
Key changes include:
- ROCm support in setup.py
- Conditional compilation in CUDA source files
- ROCm-specific implementations
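The setup.py and conditional-compilation changes can be sketched together as backend-dependent compile flags. This is a hypothetical illustration, not the PR's exact flag set; the function name and flags are mine.

```python
# Hypothetical sketch: choose extra compile args for the HIP vs. CUDA toolchain.
def gpu_extra_compile_args(is_rocm: bool) -> list:
    common = ["-O3"]
    if is_rocm:
        # hipcc path: define USE_ROCM so the sources select the HIP intrinsics.
        return common + ["-DUSE_ROCM"]
    # nvcc path: a CUDA-only flag, as an example.
    return common + ["--expt-relaxed-constexpr"]

print(gpu_extra_compile_args(True))   # ['-O3', '-DUSE_ROCM']
print(gpu_extra_compile_args(False))  # ['-O3', '--expt-relaxed-constexpr']
```

Gating on a preprocessor define like USE_ROCM keeps a single source tree buildable by both toolchains, which is the pattern the hunks above rely on.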
Next: