ROCm: enable trillion-parameter MoE models with INT4-FP8 single node #4152

Merged: 2 commits, Mar 6, 2025
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -513,6 +513,10 @@ def weight_loader(

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
Reviewer comment (Contributor): USE_INT4_WEIGHT -> SGLANG_ROCM_USE_INT4_WEIGHTS

loaded_weight = loaded_weight * 2.0

# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
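The `* 2.0` above is the standard e4m3fn to e4m3fnuz adjustment on AMD GPUs: e4m3fnuz uses an exponent bias of 8 instead of 7, so the same bit pattern decodes to half the e4m3fn value, and FP8 scales taken from an e4m3fn checkpoint must be doubled to keep dequantized values unchanged. Here only the scale doubling is relevant, since the weights themselves are INT4 and `input_scale` is the FP8 activation scale. A minimal sketch of that normalization (an illustrative helper, not code from this PR; it assumes a PyTorch build with the fp8 dtypes):

```python
import torch

def normalize_fp8_scale_for_fnuz(weight: torch.Tensor, scale: torch.Tensor):
    """Illustrative only: reinterpret e4m3fn data as e4m3fnuz and double its scale
    so that scale * weight dequantizes to the same values."""
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.clone().view(torch.int8)
    # Bit pattern 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; remap it to 0.
    as_int8[as_int8 == -128] = 0
    return as_int8.view(torch.float8_e4m3fnuz), scale * 2.0
```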

@@ -551,6 +555,10 @@ def weight_loader(
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust the INT4 column-wise scaling factor for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5

self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
@@ -570,6 +578,10 @@
tp_rank=tp_rank,
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust the FP8 per-tensor scaling factor for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0

self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
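One way to read the paired weight-scale adjustments above (an interpretation; the PR comments only say they target e4m3fnuz): as noted in the fp8.py changes below, the ASM MoE kernel applies the product of the per-column INT4 scale (the CHANNEL case) and the per-tensor FP8 scale (the TENSOR case) as its post-GEMM scaling. Doubling the FP8 scale for e4m3fnuz while halving the INT4 column scale therefore leaves that product, and the dequantized output, unchanged. A toy check with hypothetical scale values:

```python
# Hypothetical checkpoint scales, chosen only to illustrate the invariance.
int4_col_scale = 0.02      # per-channel (CHANNEL) scale for the INT4 weight
fp8_tensor_scale = 1.5     # per-tensor (TENSOR) scale for the FP8 compute path

reference = int4_col_scale * fp8_tensor_scale                  # e4m3fn product
adjusted = (int4_col_scale * 0.5) * (fp8_tensor_scale * 2.0)   # e4m3fnuz product
assert abs(reference - adjusted) < 1e-12
```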
132 changes: 110 additions & 22 deletions python/sglang/srt/layers/quantization/fp8.py
@@ -460,7 +460,11 @@ def create_weights(
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
params_dtype = (
torch.int32
if get_bool_env_var("USE_INT4_WEIGHT")
else torch.float8_e4m3fn
)
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
@@ -485,21 +489,40 @@
)

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
if get_bool_env_var("USE_INT4_WEIGHT"):
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size,
hidden_size // 8,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
),
requires_grad=False,
)
else:
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)

layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
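In the INT4 branch above, `hidden_size // 8` and `intermediate_size // 8` come from packing eight 4-bit values into each int32 along the last dimension. A rough sketch of one such packing convention (the nibble order and the unsigned 0..15 range are assumptions; the real layout is whatever the checkpoint and kernel agree on):

```python
import torch

def pack_int4_to_int32(w_int4: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit values (0..15 here) into int32 words, eight per word, low nibble first."""
    assert w_int4.shape[-1] % 8 == 0
    nibbles = (w_int4.to(torch.int32) & 0xF).view(*w_int4.shape[:-1], -1, 8)
    packed = torch.zeros(nibbles.shape[:-1], dtype=torch.int32)
    for i in range(8):  # place nibble i at bits [4*i, 4*i + 4)
        packed |= nibbles[..., i] << (4 * i)
    return packed

# A (num_experts, 2 * intermediate_size, hidden_size) INT4 tensor packs down to hidden_size // 8 words.
w = torch.randint(0, 16, (2, 4, 16))
assert pack_int4_to_int32(w).shape == (2, 4, 2)
```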

@@ -538,7 +561,9 @@ def create_weights(
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

if is_hip_ and get_bool_env_var("CK_MOE"):
if (
is_hip_
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
# ROCm: using column scaling; duplicate the scaling factors for the per-tensor scaling case
w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -565,6 +590,13 @@
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

if get_bool_env_var("USE_INT4_WEIGHT"):
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
set_weight_attrs(w2_weight_scale1, extra_weight_attrs)

# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -590,6 +622,53 @@
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
if get_bool_env_var("USE_INT4_WEIGHT"):
Reviewer comment (Contributor): move this part out into a separate function.

# TODO: also check get_bool_env_var("CK_MOE") once the Triton kernel is added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()

# INT4-FP8: fold the INT4 w13_weight_scale1 values into a single w13_weight_scale.
# The FP8 MoE kernel needs a single FP8 w13_weight_scale for w13 per expert.
# We don't requantize each expert's FP8 weight (not directly available);
# instead we adjust half of the INT4 w13_weight_scale1 values.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
max_w13_scale_fp8 = max_w13_scales[expert_id]
for shard_id in range(2):
if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
int4_rescale = (
layer.w13_weight_scale[expert_id][shard_id]
/ max_w13_scale_fp8
)
layer.w13_weight_scale1[expert_id][
start : start + shard_size
] *= int4_rescale
start += shard_size

layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)

# Special hack for asm_moe, which applies (weight_scale1 * weight_scale) as the post-GEMM scaling.
# The optimal design would apply per-column weight_scale1 before the GEMM and weight_scale after it.
for expert_id in range(layer.num_experts):
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
return
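To make the rescaling loop above concrete, here is a toy version of the same folding with made-up shapes and values: the two w13 shards can arrive with different per-tensor FP8 scales, the fused kernel wants a single scale per expert, so the maximum is kept and the ratio is folded into the per-column INT4 scales of the shard that had the smaller scale:

```python
import torch

num_experts, shard_size = 2, 4                                  # toy sizes
w13_weight_scale = torch.tensor([[1.0, 0.5],                    # per expert, per shard (FP8)
                                 [2.0, 2.0]])
w13_weight_scale1 = torch.ones(num_experts, 2 * shard_size)     # per expert, per column (INT4)

max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(num_experts):
    start = 0
    for shard_id in range(2):
        if w13_weight_scale[expert_id][shard_id] != max_w13_scales[expert_id]:
            int4_rescale = w13_weight_scale[expert_id][shard_id] / max_w13_scales[expert_id]
            w13_weight_scale1[expert_id][start : start + shard_size] *= int4_rescale
        start += shard_size

# Expert 0's second shard had the smaller FP8 scale (0.5 vs. 1.0), so its columns were halved.
assert torch.allclose(w13_weight_scale1[0], torch.tensor([1.0] * 4 + [0.5] * 4))
```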

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
padding_size, # Avoid circular import
)
@@ -823,8 +902,24 @@ def apply(
correction_bias=correction_bias,
)

if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add a Triton kernel and restore the get_bool_env_var("CK_MOE") check
assert not no_combine, f"{no_combine=} is not supported."
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=activation,
)
if is_hip_ and get_bool_env_var("CK_MOE"):
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time being.
assert (
activation == "silu"
), f"CK_MOE: FP8 and/or FP8 block_quant {activation=} will be supported later, unset CK_MOE"
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
return asm_moe(
@@ -835,10 +930,6 @@
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
None,
None,
False,
None,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
@@ -851,9 +942,6 @@
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
None,
None,
False,
)
else:
# Expert fusion with FP8 quantization
3 changes: 2 additions & 1 deletion python/sglang/srt/utils.py
@@ -1269,7 +1269,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
else:
return x_
# return x_
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)

x_ = x_.permute(0, 1, 3, 4, 2, 5)
x_ = x_.contiguous()
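The new fallback branch covers element types not handled by the earlier dtype branches (the fp8/int8 case is the one visible in this hunk), which in this PR means the int32-packed INT4 MoE weights: it retiles each 16 x 8 block of the (n, k) matrix before the permute, presumably into the layout the ROCm ASM MoE kernel expects. A quick sanity check of the reshape arithmetic with toy sizes (the int32 dtype and the target kernel are assumptions based on the rest of the PR):

```python
import torch

b, n, k = 1, 32, 16   # for packed INT4 weights, k corresponds to hidden_size // 8
x = torch.arange(b * n * k, dtype=torch.int32).reshape(b, n, k)

# Mirror the new else branch: split n into groups of 16 and k into groups of 8 (as 2 x 4),
# then permute so each 16 x 8 tile becomes contiguous.
x_ = x.view(b, n // 16, 16, k // 8, 2, 4).permute(0, 1, 3, 4, 2, 5).contiguous()
assert x_.numel() == x.numel()
assert x_.shape == (b, n // 16, k // 8, 2, 16, 4)
```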