[Bugfix][Quantization] Fix FP8 + EP (#13784)
Signed-off-by: Tyler Michael Smith <[email protected]>
tlrmchlsmth authored Feb 25, 2025
1 parent 51010a1 commit 1e15aae
Showing 6 changed files with 22 additions and 22 deletions.
30 changes: 15 additions & 15 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -260,7 +260,7 @@ class FusedMoE(torch.nn.Module):

     def __init__(
         self,
-        num_experts: int,
+        num_experts: int,  # Global number of experts
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
@@ -291,7 +291,8 @@ def __init__(
         else:
             self.ep_size = 1
         self.top_k = top_k
-        self.num_experts = num_experts  # Global number of experts
+        self.global_num_experts = num_experts
+        self.local_num_experts = self.global_num_experts // self.ep_size
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
@@ -308,39 +309,38 @@ def __init__(

         if self.ep_size > 1:
             # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.num_experts, ),
+            self.expert_map = torch.full((self.global_num_experts, ),
                                          -1,
                                          dtype=torch.int32)
             # Create a expert map for the local experts
-            local_num_experts = num_experts // self.ep_size
             ep_rank = get_tensor_model_parallel_rank()
             if ep_rank < (self.ep_size - 1):
                 # Each non-last rank gets local_num_experts experts.
-                self.expert_map[ep_rank * local_num_experts:
-                                (ep_rank + 1) * local_num_experts] = \
-                    torch.arange(0, local_num_experts, dtype=torch.int32)
+                self.expert_map[ep_rank * self.local_num_experts:
+                                (ep_rank + 1) * self.local_num_experts] = \
+                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
             else:
                 # All remaining experts are assigned to the last rank.
-                local_num_experts = num_experts - ep_rank * local_num_experts
-                self.expert_map[-local_num_experts:] = \
-                    torch.arange(0, local_num_experts, dtype=torch.int32)
+                self.local_num_experts = (self.global_num_experts -
+                                          ep_rank * self.local_num_experts)
+                self.expert_map[-self.local_num_experts:] = \
+                    torch.arange(0, self.local_num_experts, dtype=torch.int32)

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")

+        # Note: get_quant_method will look at the layer's local_num_experts
+        # for heuristic purposes, so it must be initialized first.
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
                 UnquantizedFusedMoEMethod())
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None

-        local_num_experts = torch.sum(self.expert_map != -1) \
-            if self.expert_map is not None else num_experts
-
         moe_quant_params = {
-            "num_experts": local_num_experts,
+            "num_experts": self.local_num_experts,
             "hidden_size": hidden_size,
             "intermediate_size_per_partition":
             self.intermediate_size_per_partition,
@@ -647,7 +647,7 @@ def forward(self, hidden_states: torch.Tensor,
             top_k=self.top_k,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
-            global_num_experts=self.num_experts,
+            global_num_experts=self.global_num_experts,
             expert_map=self.expert_map,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
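
For reference, here is a minimal standalone sketch of the expert-map construction introduced above: each EP rank owns a contiguous block of experts, expert_map sends a global expert id to a rank-local id (or -1 for experts owned by other ranks), and the last rank absorbs any remainder. The helper name and the example sizes are illustrative, not part of this commit:

import torch

def build_expert_map(global_num_experts: int, ep_size: int, ep_rank: int):
    """Illustrative sketch of the logic in FusedMoE.__init__ above."""
    # Global expert id -> local expert id on this rank, or -1 if not local.
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    local_num_experts = global_num_experts // ep_size
    if ep_rank < ep_size - 1:
        # Each non-last rank owns a contiguous block of local_num_experts.
        expert_map[ep_rank * local_num_experts:
                   (ep_rank + 1) * local_num_experts] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    else:
        # The last rank picks up whatever is left over.
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    return expert_map, local_num_experts

# e.g. 10 experts over 4 EP ranks: ranks 0-2 own 2 experts each, rank 3 owns 4.
expert_map, local_num_experts = build_expert_map(10, ep_size=4, ep_rank=3)
assert local_num_experts == 4
assert int((expert_map != -1).sum()) == local_num_experts
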
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/awq_marlin.py
@@ -136,7 +136,7 @@ def get_quant_method(self, layer: torch.nn.Module,
                     self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            if layer.num_experts > 32:
+            if layer.local_num_experts > 32:
                 # For MoEs with many experts the moe_wna16 kernel is faster
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
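
The rename also matters for the kernel-selection heuristic above: under expert parallelism each rank only materializes its local share of the experts, and (per the note added in layer.py) get_quant_method consults the layer's local_num_experts, so the >32 threshold is checked against the per-rank count. A small illustrative calculation (the numbers are hypothetical):

global_num_experts, ep_size = 64, 4
local_num_experts = global_num_experts // ep_size  # 16 experts per rank
use_moe_wna16 = local_num_experts > 32  # False with these numbers
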
@@ -190,7 +190,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.local_num_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -573,11 +573,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
             layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-                layer.num_experts,
+                layer.local_num_experts,
                 dtype=torch.float32,
                 device=w13_weight.device),
                                                         requires_grad=False)
-            for expert in range(layer.num_experts):
+            for expert in range(layer.local_num_experts):
                 w13_weight[expert, :, :], layer.w13_weight_scale[
                     expert] = ops.scaled_fp8_quant(
                         layer.w13_weight.data[expert, :, :])
@@ -644,7 +644,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.local_num_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
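
The change above matters because on an EP rank every per-expert tensor (the weights and their scales) is allocated with the rank-local expert count, so looping with the global count would run past the locally stored experts. Below is a plain-PyTorch sketch of the per-expert, per-tensor FP8 quantization idea; it is not vLLM's ops.scaled_fp8_quant, and the shapes are illustrative:

import torch

local_num_experts, rows, cols = 4, 128, 256
w13 = torch.randn(local_num_experts, rows, cols, dtype=torch.float16)

fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max
w13_q = torch.empty_like(w13, dtype=fp8)
# One scale per *local* expert: this rank only holds the experts that
# expert_map assigns to it.
w13_scale = torch.ones(local_num_experts, dtype=torch.float32)

for expert in range(local_num_experts):
    scale = w13[expert].abs().amax().float().clamp(min=1e-12) / fp8_max
    w13_scale[expert] = scale
    w13_q[expert] = (w13[expert].float() / scale).clamp(-fp8_max, fp8_max).to(fp8)
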
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -153,7 +153,7 @@ def override_quantization_method(cls, hf_quant_cfg,
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            if layer.num_experts > 32:
+            if layer.local_num_experts > 32:
                 # For MoEs with many experts the moe_wna16 kernel is faster
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -174,7 +174,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.local_num_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
