diff --git a/vllm/config.py b/vllm/config.py
index a4a6ef05c900d..ee14822841548 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -17,7 +17,7 @@
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_neuron, is_openvino, is_xpu,
+                        is_hip, is_mi250, is_neuron, is_openvino, is_xpu,
                         print_warning_once)
 
 if TYPE_CHECKING:
@@ -949,6 +949,12 @@ def __init__(
         self._verify_args()
         self.rank: int = 0
 
+        if is_mi250() and self.tensor_parallel_size > 1:
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "working correctly on multi AMD MI250.")
+
     @property
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
diff --git a/vllm/utils.py b/vllm/utils.py
index af857ca315b38..35b84e4336657 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -487,6 +487,15 @@ def is_xpu() -> bool:
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
 
+@lru_cache(maxsize=None)
+def is_mi250() -> bool:
+    if not is_hip() or not torch.cuda.is_available():
+        return False
+    archName = torch.cuda.get_device_properties('cuda').gcnArchName
+    return (archName is not None) and \
+        ("gfx90a" in archName)
+
+
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
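
For reference, the detection the new is_mi250() helper relies on can be sanity-checked outside vLLM with a minimal standalone sketch. This is not part of the patch: the function name looks_like_mi250 and the default device index 0 are illustrative assumptions, and the snippet assumes a ROCm build of PyTorch (on CUDA builds the gcnArchName attribute is simply absent, so it returns False).

# Standalone sketch, not part of the patch: probes the ROCm architecture
# string the same way the new is_mi250() helper does. Assumes a ROCm build
# of PyTorch; the helper name and device index 0 are illustrative.
import torch


def looks_like_mi250(device: int = 0) -> bool:
    """True when the reported gcnArchName contains gfx90a (MI200 series)."""
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(device)
    arch = getattr(props, "gcnArchName", None)  # present only on ROCm builds
    return arch is not None and "gfx90a" in arch


if __name__ == "__main__":
    print("MI250-class GPU detected:", looks_like_mi250())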