Support FP8 FA from Quark format (#388)
* Support FP8 FA from Quark format

* Support FP8 FA from Quark format

* nit: update comment
BowenBao authored Jan 28, 2025
1 parent 28b1ad9 commit 6b2147f
Showing 4 changed files with 44 additions and 24 deletions.
47 changes: 27 additions & 20 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -1,5 +1,4 @@
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast
 
 import torch
@@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
             for q_config in q_configs:
                 q_config["output_tensors"] = None
 
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,
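The added block assumes a Quark export whose layer_quant_config is keyed by glob-style patterns such as "*q_proj". A minimal sketch of that assumption, showing why the q_proj output entry is cleared; the keys and dtype strings below are illustrative placeholders, not the exact Quark schema:

```python
# Hypothetical, trimmed-down layer_quant_config; the real schema comes from
# the Quark export, so treat these keys and values as placeholders.
layer_quant_config = {
    "*q_proj": {"weight": {"dtype": "fp8_e4m3"}, "output_tensors": {"dtype": "fp8_e4m3"}},
    "*k_proj": {"weight": {"dtype": "fp8_e4m3"}, "output_tensors": {"dtype": "fp8_e4m3"}},
    "*v_proj": {"weight": {"dtype": "fp8_e4m3"}, "output_tensors": {"dtype": "fp8_e4m3"}},
}

# Same effect as the added lines: q_proj's output quantization is dropped so
# q/k/v projections all end up with output_tensors = None.
layer_quant_config["*q_proj"]["output_tensors"] = None
print(layer_quant_config["*q_proj"])  # {'weight': {...}, 'output_tensors': None}
```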
@@ -148,6 +153,19 @@ def _check_scheme_supported(self,
         else:
             return False
 
+    def is_fp8_w8a8(self) -> bool:
+        # Returns True if all quantized layers in model are fp8 w8a8
+        global_quant_config = cast(
+            Dict[str, Any], self.quant_config.get("global_quant_config"))
+        layer_quant_configs = cast(Dict[str, Any],
+                                   self.quant_config.get("layer_quant_config"))
+        for config in (global_quant_config, *layer_quant_configs.values()):
+            weight_config = cast(Dict[str, Any], config.get("weight"))
+            input_config = cast(Dict[str, Any], config.get("input_tensors"))
+            if not self._is_fp8_w8a8(weight_config, input_config):
+                return False
+        return True
+
     def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                      input_quant: Optional[Dict[str, Any]]) -> bool:
         # Confirm weights and input quantized.
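The new is_fp8_w8a8() walks the global config plus every per-layer override and defers to the existing _is_fp8_w8a8 check for each (weight, input_tensors) pair. A hedged sketch of the kind of config it would accept; the key names mirror what the method reads above, while the "fp8_e4m3" dtype strings and the stand-in predicate are assumptions, not the real Quark schema or vLLM code:

```python
# Illustrative quant_config fragment; only the keys used by is_fp8_w8a8()
# are shown, and the dtype/qscheme values are assumptions.
quant_config = {
    "global_quant_config": {
        "weight": {"dtype": "fp8_e4m3", "is_dynamic": False},
        "input_tensors": {"dtype": "fp8_e4m3", "is_dynamic": False},
    },
    "layer_quant_config": {
        "*q_proj": {
            "weight": {"dtype": "fp8_e4m3", "is_dynamic": False},
            "input_tensors": {"dtype": "fp8_e4m3", "is_dynamic": False},
        },
    },
}

def looks_like_fp8_w8a8(weight, inputs):
    # Rough stand-in for the real _is_fp8_w8a8 check: both the weight and the
    # activation configs must exist and be FP8.
    return (weight is not None and inputs is not None
            and weight.get("dtype") == "fp8_e4m3"
            and inputs.get("dtype") == "fp8_e4m3")

configs = (quant_config["global_quant_config"],
           *quant_config["layer_quant_config"].values())
print(all(looks_like_fp8_w8a8(c.get("weight"), c.get("input_tensors"))
          for c in configs))  # True
```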
@@ -286,25 +304,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
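The rewritten get_cache_scale remaps attention scale names purely by substring, independent of the kv_cache_group patterns, and the new q_scale/prob_scale branches are presumably what lets the FP8 attention path pick up query and softmax-probability scales in addition to k/v. For illustration, the mapping it performs on typical checkpoint parameter names (the paths below are hypothetical examples, and remap() is a standalone copy of the logic above, not the vLLM method itself):

```python
# Hypothetical parameter names -> the attention scale names produced by the
# new string-based remapping.
examples = {
    "model.layers.0.self_attn.k_proj.output_scale": "model.layers.0.self_attn.attn.k_scale",
    "model.layers.0.self_attn.v_proj.output_scale": "model.layers.0.self_attn.attn.v_scale",
    "model.layers.0.self_attn.q_proj.output_scale": "model.layers.0.self_attn.attn.q_scale",
    "model.layers.0.self_attn.prob_output_scale": "model.layers.0.self_attn.attn.prob_scale",
}

def remap(name):
    # Mirrors the new branches: q/k/v projection output scales plus the
    # softmax (prob) output scale; anything else returns None.
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None

assert all(remap(src) == dst for src, dst in examples.items())
```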
2 changes: 2 additions & 0 deletions (second changed file; path not shown)
@@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -134,6 +135,7 @@ def apply_weights(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=layer.input_scale,
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
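out_dtype is captured once at scheme construction and forwarded into the FP8 linear helper called in apply_weights, so the GEMM output comes back in the model's working dtype instead of an assumed float16. A minimal sketch of just that capture, under the assumption that the process-wide default dtype was set when the model was loaded; the surrounding apply-weights machinery is not reproduced here:

```python
import torch

# The scheme snapshots the default dtype at construction time; whatever the
# model was loaded with (e.g. bfloat16) becomes the out_dtype passed above.
torch.set_default_dtype(torch.bfloat16)
out_dtype = torch.get_default_dtype()
assert out_dtype == torch.bfloat16
```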
5 changes: 4 additions & 1 deletion vllm/model_executor/models/grok1.py
@@ -37,6 +37,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -197,7 +198,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
+        self.use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config.is_fp8_w8a8())
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
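The Grok-1 change above and the llama.py hunks below apply the same gate: FP8 paths that used to require Fp8Config now also accept a Quark config whose layers are uniformly FP8 W8A8. A sketch of that predicate factored out on its own; the helper name is hypothetical, since the real code inlines the expression as shown in the hunks, and in llama.py it additionally sits behind the ROCm/non-Navi check and, for attn_fp8_out, the VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT toggle:

```python
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig


def quant_config_is_fp8(quant_config) -> bool:
    # Hypothetical helper mirroring the inlined checks in grok1.py/llama.py:
    # native FP8 checkpoints, or Quark checkpoints that are FP8 w8a8 throughout.
    return isinstance(quant_config, Fp8Config) or (
        isinstance(quant_config, QuarkConfig) and quant_config.is_fp8_w8a8())
```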
14 changes: 11 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -40,6 +40,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -84,7 +85,9 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
         )
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         if hidden_act != "silu":
@@ -196,10 +199,13 @@ def __init__(self,
             sliding_window = None
 
         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
+        use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config.is_fp8_w8a8())
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
             and current_platform.is_rocm() \
             and not is_navi() \
-            and isinstance(quant_config, Fp8Config)
+            and use_fp8
 
         self.attn = Attention(
             self.num_heads,
@@ -240,7 +246,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         rope_theta = getattr(config, "rope_theta", 10000)
