ROCm OCP FP8 Support #1677
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1677
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 475d2f5 with merge base ea7910e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Diff excerpt (from the utility-function changes in torchao/utils.py):

```python
for arch in mxArchName:
    if arch in archName:
        return True
return False


def is_MI350():
```
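For context, here is a minimal self-contained sketch of the pattern this diff implements. The `_gpu_arch_matches` helper name is hypothetical; the sketch assumes a ROCm build of PyTorch, where `torch.cuda.get_device_properties()` exposes the device's LLVM gfx code via `gcnArchName`:

```python
# Hypothetical sketch of the arch-check pattern shown in the diff above;
# the real helpers live in torchao/utils.py and may differ in detail.
import torch


def _gpu_arch_matches(gfx_codes):
    """Return True if the current GPU's LLVM gfx code matches any entry."""
    if not torch.cuda.is_available() or torch.version.hip is None:
        return False  # not a ROCm build, or no GPU present
    arch_name = torch.cuda.get_device_properties(0).gcnArchName
    return any(code in arch_name for code in gfx_codes)


def is_MI350():
    # gfx950 is the code this PR associates with MI350.
    return _gpu_arch_matches(["gfx950"])


def is_Navi4():
    # gfx1200/gfx1201 are the Navi4 codes named in pytorch#146632.
    return _gpu_arch_matches(["gfx1200", "gfx1201"])
```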
optional: if there is any public info (spec sheet, etc.) we can link to about this hardware from the docblock, I think that would be awesome! Same for below.
Thanks. I've added "Supported AMD GPU Models and their LLVM gfx Codes"; detailed documentation will follow when the time is right.
looks good if CI passes!
TLDR: Follow-up to / builds on top of pytorch#144476; adds OCP FP8 support for gfx950. Refer to pytorch/ao#1677.

This pull request includes several changes to improve compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates involve adding support for new ROCm versions and GPU architectures, updating data type handling, and removing outdated checks.

### Improvements to GPU Architecture and ROCm Version Support:
* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for new GPU architectures `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated architecture support in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)

### Updates to Data Type Handling:
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Enhanced data type conversion to include new float8 types for both CUDA and ROCm environments.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated the `HipDataTypeFor` template to handle new float8 types and added hard-coded enum values for ROCm versions prior to 6.3.

### Removal of Outdated Checks:
* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS`, as it is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes ensure better compatibility and performance on newer hardware and software environments, particularly for users leveraging ROCm and CUDA for deep learning and scientific computing tasks.

Pull Request resolved: pytorch#146632
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <[email protected]>
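As an illustration of the data-type changes above (not the PR's own code): PyTorch exposes both the OCP FP8 dtypes and the "fnuz" variants used by earlier AMD accelerators, and the gfx950 work is about making the OCP pair usable on ROCm as well. A quick check that runs on any recent PyTorch build:

```python
import torch

# OCP FP8 formats (the pair gfx950 adds support for on ROCm):
ocp_pair = (torch.float8_e4m3fn, torch.float8_e5m2)
# "fnuz" FP8 variants used by earlier ROCm hardware such as MI300:
fnuz_pair = (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)

x = torch.randn(4, 4)
print(x.to(torch.float8_e4m3fn).dtype)    # torch.float8_e4m3fn
print(x.to(torch.float8_e4m3fnuz).dtype)  # torch.float8_e4m3fnuz
```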
Add a comment documenting supported AMD GPU models and their corresponding LLVM gfx codes, including Navi4, MI300X, and MI350.
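For reference, a sketch of that mapping as a Python dict. The gfx950 and gfx1200/gfx1201 codes come from this PR's own description; gfx942 for MI300X is public AMD information; treat the table as illustrative rather than exhaustive:

```python
# Illustrative mapping of AMD GPU models to LLVM gfx codes; not exhaustive.
AMD_GPU_GFX_CODES = {
    "MI300X": "gfx942",
    "MI350": "gfx950",
    "Navi4": ("gfx1200", "gfx1201"),
}
```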
TLDR: Quick fix for ROCm device check. OCP FP8 support status update.
This pull request includes changes to improve the handling of imports, update configurations, and add new utility functions in the `torchao` library. The most important changes include removing comments to avoid circular imports, updating the configuration for supported float8 types, and adding utility functions to check for specific GPU architectures. Refer to pytorch/pytorch#146632.

Configuration updates:
* `torchao/float8/config.py`: Updated the configuration for selecting the preferred float8 type pair to include support for the OCP FP8 variants on MI350/Navi4 (see the sketch after this list).

New utility functions:
* `torchao/utils.py`: Added new utility functions `is_MI350` and `is_Navi4` to check for specific GPU architectures.

Improvements to import handling:
* `torchao/dtypes/uintx/marlin_qqq_tensor.py`: Removed comments to avoid circular imports in the `__tensor_unflatten__` and `from_plain` methods. [1] [2]
* `torchao/dtypes/uintx/marlin_sparse_layout.py`: Removed comments to avoid circular imports in the `__tensor_unflatten__` and `from_plain` methods. [1] [2]
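A minimal sketch of what the hardware-dependent float8 pair selection could look like, assuming the `is_MI350`/`is_Navi4` helpers added by this PR; the function and variable names here are hypothetical, and the real logic lives in `torchao/float8/config.py`:

```python
import torch

from torchao.utils import is_MI350, is_Navi4  # helpers added by this PR


def preferred_fp8_pair():
    """Pick the float8 (e4m3, e5m2) pair for the current device (sketch)."""
    if torch.version.hip is not None and not (is_MI350() or is_Navi4()):
        # Older ROCm parts (e.g. MI300) support only the fnuz variants.
        return torch.float8_e4m3fnuz, torch.float8_e5m2fnuz
    # CUDA, and ROCm parts with OCP FP8 support (MI350/Navi4), use OCP types.
    return torch.float8_e4m3fn, torch.float8_e5m2
```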