Add fp8 support for llama model family on Navi4x (#245)
* Add fp8 support for Llama model family on Navi4x
qli88 authored Oct 25, 2024
1 parent c9fc160 commit 4bba092
Showing 7 changed files with 79 additions and 17 deletions.
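
Before the per-file diffs, the gist of the change: the FP8 storage format is now chosen by GPU architecture. MI300-class GPUs (gfx94x) keep torch.float8_e4m3fnuz, while Navi4x (gfx12xx) and CUDA devices use torch.float8_e4m3fn. A minimal sketch of that rule, for illustration only (the helper name below is not part of the commit):

    import torch

    def pick_fp8_dtype(hip: bool, navi4x: bool) -> torch.dtype:
        # ROCm GPUs other than Navi4x use the fnuz variant of e4m3;
        # CUDA devices and Navi4x use the standard e4m3fn format.
        return torch.float8_e4m3fnuz if hip and not navi4x else torch.float8_e4m3fn

    assert pick_fp8_dtype(hip=True, navi4x=False) is torch.float8_e4m3fnuz  # MI300 (gfx94x)
    assert pick_fp8_dtype(hip=True, navi4x=True) is torch.float8_e4m3fn     # Navi4x (gfx12xx)
    assert pick_fp8_dtype(hip=False, navi4x=False) is torch.float8_e4m3fn   # CUDA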
16 changes: 15 additions & 1 deletion CMakeLists.txt
@@ -37,7 +37,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1200")

#
# Supported/expected torch versions for CUDA/ROCm.
@@ -172,6 +172,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
#
get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})

#
# Get supported FP8 format based on GPU arches
#
get_supported_fp8_format(FP8_FORMAT ${VLLM_GPU_LANG} "${VLLM_GPU_ARCHES}")
if(${FP8_FORMAT} STREQUAL "E4M3FN")
message(STATUS "FP8 format: E4M3FN")
list(APPEND VLLM_GPU_FLAGS "-DUSE_CUDA_FP8_FORMAT")
elseif(${FP8_FORMAT} STREQUAL "E4M3FNUZ")
message(STATUS "FP8 format: E4M3FNUZ")
list(APPEND VLLM_GPU_FLAGS "-DUSE_HIP_FP8_FORMAT")
elseif(${FP8_FORMAT} STREQUAL "CONFLICT")
message(FATAL_ERROR "Target architectures support different types of FP8 formats!")
endif()

#
# Set nvcc parallelism.
#
30 changes: 30 additions & 0 deletions cmake/utils.cmake
@@ -435,3 +435,33 @@ function (define_gpu_extension_target GPU_MOD_NAME)

install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
endfunction()


# gfx12xx should not be compiled together with gfx94x (MI300) because they support different FP8 formats.
# FP8_FORMAT will be set to one of: E4M3FN / E4M3FNUZ / NONE / CONFLICT
macro (get_supported_fp8_format FP8_FORMAT GPU_LANG GPU_ARCHES)
set(_USING_CUDA_FP8_FORMAT "FALSE")
set(_USING_HIP_FP8_FORMAT "FALSE")

if (NOT (${GPU_LANG} STREQUAL "HIP"))
set(_USING_CUDA_FP8_FORMAT "TRUE")
else()
foreach (_ARCH ${GPU_ARCHES})
if (_ARCH MATCHES "gfx94.")
set(_USING_HIP_FP8_FORMAT "TRUE")
elseif(_ARCH MATCHES "gfx12..")
set(_USING_CUDA_FP8_FORMAT "TRUE")
endif()
endforeach()
endif()

if ((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE"))
set(FP8_FORMAT "NONE")
elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "TRUE"))
set(FP8_FORMAT "E4M3FNUZ")
elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "TRUE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE"))
set(FP8_FORMAT "E4M3FN")
else()
set(FP8_FORMAT "CONFLICT")
endif()
endmacro()
6 changes: 3 additions & 3 deletions csrc/quantization/fp8/common.cu
@@ -7,15 +7,15 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#ifndef USE_ROCM
#if defined(USE_CUDA_FP8_FORMAT)
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

#ifndef USE_ROCM
#if defined(USE_CUDA_FP8_FORMAT)
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
@@ -50,7 +50,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
#if defined(USE_CUDA_FP8_FORMAT)
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
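
For reference, a rough Python analogue of the clamp-then-cast step in scaled_fp8_conversion shown above (illustrative only; the kernel operates on scalars, and on ROCm it uses a hardware cvt instruction rather than a plain cast):

    import torch

    def clamp_and_cast(x: torch.Tensor,
                       fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
        # Mirrors r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)) followed by the cast.
        fp8_max = torch.finfo(fp8_dtype).max  # 448 for e4m3fn, 240 for e4m3fnuz
        return x.clamp(min=-fp8_max, max=fp8_max).to(fp8_dtype)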
4 changes: 2 additions & 2 deletions vllm/_custom_ops.py
@@ -9,7 +9,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
from vllm.utils import is_hip
from vllm.utils import is_hip, is_navi4x

logger = init_logger(__name__)

@@ -711,7 +711,7 @@ def scaled_fp8_quant(
assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() \
out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() and not is_navi4x() \
else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
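
A hedged usage sketch for the function touched above, assuming the usual calling pattern in which omitting the scale argument triggers dynamic quantization and the call returns a (quantized, scale) pair; with this change the output dtype is float8_e4m3fn on CUDA and Navi4x, and float8_e4m3fnuz on other ROCm GPUs:

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")  # 2-D input, as asserted above
    x_fp8, scale = ops.scaled_fp8_quant(x)  # dtype follows the is_hip()/is_navi4x() rule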
15 changes: 8 additions & 7 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -27,7 +27,7 @@
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once
from vllm.utils import is_hip, is_navi4x, print_warning_once

ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -227,8 +227,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale

# If rocm, use float8_e4m3fnuz.
if is_hip():
# If rocm (except Navi4x), use float8_e4m3fnuz.
if is_hip() and not is_navi4x():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
@@ -378,9 +378,9 @@ def process_weights_after_loading(self, layer: Module) -> None:

# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
# If rocm (except Navi4x), use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn
if is_hip() and not is_navi4x() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -427,8 +427,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
# If rocm (except Navi4x, which uses e4m3fn),
# normalize the weights and scales to e4m3fnuz
if is_hip() and not is_navi4x():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
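
The ROCm-only branches above call normalize_e4m3fn_to_e4m3fnuz, which Navi4x now skips. A minimal sketch of the idea behind that helper (not the exact vLLM implementation, which also remaps the bit pattern that encodes negative zero in e4m3fn but NaN in e4m3fnuz):

    import torch

    def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
        # The same bit pattern decodes to half the value in e4m3fnuz (exponent
        # bias 8 vs. 7), so reinterpret the bits and double the scales to keep
        # quantized * scale unchanged.
        weight = weight.view(torch.int8).view(torch.float8_e4m3fnuz)
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight, weight_scale, input_scale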
13 changes: 9 additions & 4 deletions vllm/model_executor/models/llama.py
@@ -54,7 +54,7 @@
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip
from vllm.utils import is_hip, is_navi4x

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@@ -87,7 +87,8 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.use_fp8 = isinstance(quant_config, Fp8Config)
self.use_fp8 = isinstance(quant_config, Fp8Config) \
if is_hip() and not is_navi4x() else False
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -189,8 +190,10 @@ def __init__(
cache_config=cache_config,
quant_config=quant_config,
)
# For CUDA devices and Navi4x, attn_fp8_out will be set to false.
self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
and is_hip() \
and not is_navi4x() \
and isinstance(quant_config, Fp8Config)

def forward(
@@ -225,7 +228,8 @@
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.use_fp8 = isinstance(quant_config, Fp8Config)
self.use_fp8 = isinstance(quant_config, Fp8Config) \
if is_hip() and not is_navi4x() else False
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
@@ -456,7 +460,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn

if is_hip():
# For Navi4x, quantization is handled the same way as on CUDA devices.
if is_hip() and not is_navi4x():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
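
A small numeric illustration of the scaling-factor convention quoted above (quantized_value * scaling_factor ~= true_value), independent of which FP8 format is in use:

    true_value = 3.2
    scaling_factor = 0.05
    quantized_value = round(true_value / scaling_factor)  # 64
    assert abs(quantized_value * scaling_factor - true_value) < scaling_factor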
12 changes: 12 additions & 0 deletions vllm/utils.py
@@ -8,6 +8,7 @@
import ipaddress
import os
import random
import re
import socket
import subprocess
import sys
@@ -425,6 +426,17 @@ def is_hip() -> bool:
return torch.version.hip is not None


@lru_cache(maxsize=None)
def is_navi4x() -> bool:
if not is_hip() or not torch.cuda.is_available():
return False
# All (visible) GPUs must be of the same type,
# otherwise FP8 results can't be guaranteed.
archName = torch.cuda.get_device_properties('cuda').gcnArchName
return (archName is not None) and \
(re.match("gfx12[0-9]{2}", archName) is not None)


@lru_cache(maxsize=None)
def is_cpu() -> bool:
from importlib.metadata import PackageNotFoundError, version
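
A quick check of the gfx12xx pattern used by is_navi4x() above; MI300 (gfx94x) and Navi3x (gfx11xx) arch names do not match:

    import re

    for arch in ("gfx1200", "gfx942", "gfx1100"):
        print(arch, bool(re.match("gfx12[0-9]{2}", arch)))
    # gfx1200 True, gfx942 False, gfx1100 False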
