Skip to content

Commit

Permalink
ROCm: enable trillion-parameter MoE models with INT4-FP8 single node (s…
Browse files Browse the repository at this point in the history
  • Loading branch information
HaiShaw authored and ShenAo1111 committed Mar 10, 2025
1 parent 2b0dd3a commit 0cf43c7
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 23 deletions.
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ def weight_loader(

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0

# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)

Expand Down Expand Up @@ -551,6 +555,10 @@ def weight_loader(
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5

self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
Expand All @@ -570,6 +578,10 @@ def weight_loader(
tp_rank=tp_rank,
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0

self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
Expand Down
132 changes: 110 additions & 22 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,11 @@ def create_weights(
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
params_dtype = (
torch.int32
if get_bool_env_var("USE_INT4_WEIGHT")
else torch.float8_e4m3fn
)
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
Expand All @@ -485,21 +489,40 @@ def create_weights(
)

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
if get_bool_env_var("USE_INT4_WEIGHT"):
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size,
hidden_size // 8,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
),
requires_grad=False,
)
else:
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)

layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

Expand Down Expand Up @@ -538,7 +561,9 @@ def create_weights(
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

if is_hip_ and get_bool_env_var("CK_MOE"):
if (
is_hip_
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
Expand All @@ -565,6 +590,13 @@ def create_weights(
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

if get_bool_env_var("USE_INT4_WEIGHT"):
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
set_weight_attrs(w2_weight_scale1, extra_weight_attrs)

# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
Expand All @@ -590,6 +622,53 @@ def create_weights(
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
if get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()

# INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
# Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
# We won't do requant each expert's fp8 weight (not direct available),
# instead we adjust half of INT4 w13_weight_scale1 numbers
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
max_w13_scale_fp8 = max_w13_scales[expert_id]
for shard_id in range(2):
if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
int4_rescale = (
layer.w13_weight_scale[expert_id][shard_id]
/ max_w13_scale_fp8
)
layer.w13_weight_scale1[expert_id][
start : start + shard_size
] *= int4_rescale
start += shard_size

layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)

# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
for expert_id in range(layer.num_experts):
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
return

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
padding_size, # Avoid circular import
)
Expand Down Expand Up @@ -823,8 +902,24 @@ def apply(
correction_bias=correction_bias,
)

if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=activation,
)
if is_hip_ and get_bool_env_var("CK_MOE"):
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
return asm_moe(
Expand All @@ -835,10 +930,6 @@ def apply(
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
None,
None,
False,
None,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
Expand All @@ -851,9 +942,6 @@ def apply(
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
None,
None,
False,
)
else:
# Expert fusion with FP8 quantization
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
else:
return x_
# return x_
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)

x_ = x_.permute(0, 1, 3, 4, 2, 5)
x_ = x_.contiguous()
Expand Down

0 comments on commit 0cf43c7

Please sign in to comment.