Skip to content

Commit

Permalink
Custom all reduce fix mi250 (#247)
Browse files Browse the repository at this point in the history
* disable custom all reduce when running on multiple MI250

* formatting
  • Loading branch information
omirosh authored Oct 29, 2024
1 parent 4bba092 commit 5974cc3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
8 changes: 7 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_neuron, is_openvino, is_xpu,
is_hip, is_mi250, is_neuron, is_openvino, is_xpu,
print_warning_once)

if TYPE_CHECKING:
Expand Down Expand Up @@ -949,6 +949,12 @@ def __init__(
self._verify_args()
self.rank: int = 0

if is_mi250() and self.tensor_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"working correctly on multi AMD MI250.")

@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
Expand Down
9 changes: 9 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,15 @@ def is_xpu() -> bool:
return hasattr(torch, "xpu") and torch.xpu.is_available()


@lru_cache(maxsize=None)
def is_mi250() -> bool:
if not is_hip() or not torch.cuda.is_available():
return False
archName = torch.cuda.get_device_properties('cuda').gcnArchName
return (archName is not None) and \
("gfx90a" in archName)


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
Expand Down

0 comments on commit 5974cc3

Please sign in to comment.