-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ROCm: enable trillion-parameter MoE models with INT4-FP8 single node #4152
Merged
+124
−23
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -460,7 +460,11 @@ def create_weights( | |
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported | ||
|
||
if self.quant_config.is_checkpoint_fp8_serialized: | ||
params_dtype = torch.float8_e4m3fn | ||
params_dtype = ( | ||
torch.int32 | ||
if get_bool_env_var("USE_INT4_WEIGHT") | ||
else torch.float8_e4m3fn | ||
) | ||
tp_size = get_tensor_model_parallel_world_size() | ||
if self.block_quant: | ||
block_n, block_k = ( | ||
|
@@ -485,21 +489,40 @@ def create_weights( | |
) | ||
|
||
# WEIGHTS | ||
w13_weight = torch.nn.Parameter( | ||
torch.empty( | ||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype | ||
), | ||
requires_grad=False, | ||
) | ||
if get_bool_env_var("USE_INT4_WEIGHT"): | ||
# INT4 MoE weight - INT32 packed | ||
w13_weight = torch.nn.Parameter( | ||
torch.empty( | ||
num_experts, | ||
2 * intermediate_size, | ||
hidden_size // 8, | ||
dtype=params_dtype, | ||
), | ||
requires_grad=False, | ||
) | ||
w2_weight = torch.nn.Parameter( | ||
torch.empty( | ||
num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype | ||
), | ||
requires_grad=False, | ||
) | ||
else: | ||
w13_weight = torch.nn.Parameter( | ||
torch.empty( | ||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype | ||
), | ||
requires_grad=False, | ||
) | ||
w2_weight = torch.nn.Parameter( | ||
torch.empty( | ||
num_experts, hidden_size, intermediate_size, dtype=params_dtype | ||
), | ||
requires_grad=False, | ||
) | ||
|
||
layer.register_parameter("w13_weight", w13_weight) | ||
set_weight_attrs(w13_weight, extra_weight_attrs) | ||
|
||
w2_weight = torch.nn.Parameter( | ||
torch.empty( | ||
num_experts, hidden_size, intermediate_size, dtype=params_dtype | ||
), | ||
requires_grad=False, | ||
) | ||
layer.register_parameter("w2_weight", w2_weight) | ||
set_weight_attrs(w2_weight, extra_weight_attrs) | ||
|
||
|
@@ -538,7 +561,9 @@ def create_weights( | |
layer.register_parameter("w13_weight_scale", w13_weight_scale) | ||
layer.register_parameter("w2_weight_scale", w2_weight_scale) | ||
|
||
if is_hip_ and get_bool_env_var("CK_MOE"): | ||
if ( | ||
is_hip_ | ||
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel | ||
HaiShaw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling | ||
w13_weight_scale1 = torch.nn.Parameter( | ||
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), | ||
|
@@ -565,6 +590,13 @@ def create_weights( | |
set_weight_attrs(w13_weight_scale, extra_weight_attrs) | ||
set_weight_attrs(w2_weight_scale, extra_weight_attrs) | ||
|
||
if get_bool_env_var("USE_INT4_WEIGHT"): | ||
extra_weight_attrs.update( | ||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} | ||
) | ||
set_weight_attrs(w13_weight_scale1, extra_weight_attrs) | ||
set_weight_attrs(w2_weight_scale1, extra_weight_attrs) | ||
|
||
# INPUT_SCALES | ||
if self.quant_config.activation_scheme == "static": | ||
if not self.quant_config.is_checkpoint_fp8_serialized: | ||
|
@@ -590,6 +622,53 @@ def create_weights( | |
layer.w2_input_scale = None | ||
|
||
def process_weights_after_loading(self, layer: Module) -> None: | ||
if get_bool_env_var("USE_INT4_WEIGHT"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move this part out into a separate function. |
||
# TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added | ||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute) | ||
# Weight Permutation | ||
layer.w13_weight = torch.nn.Parameter( | ||
permute_weight(layer.w13_weight.data), | ||
requires_grad=False, | ||
) | ||
torch.cuda.empty_cache() | ||
layer.w2_weight = torch.nn.Parameter( | ||
permute_weight(layer.w2_weight.data), | ||
requires_grad=False, | ||
) | ||
torch.cuda.empty_cache() | ||
|
||
# INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale | ||
# Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert. | ||
# We won't do requant each expert's fp8 weight (not direct available), | ||
# instead we adjust half of INT4 w13_weight_scale1 numbers | ||
assert layer.w13_weight_scale is not None | ||
shard_size = layer.intermediate_size_per_partition | ||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values | ||
for expert_id in range(layer.num_experts): | ||
start = 0 | ||
max_w13_scale_fp8 = max_w13_scales[expert_id] | ||
for shard_id in range(2): | ||
if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8: | ||
int4_rescale = ( | ||
layer.w13_weight_scale[expert_id][shard_id] | ||
/ max_w13_scale_fp8 | ||
) | ||
layer.w13_weight_scale1[expert_id][ | ||
start : start + shard_size | ||
] *= int4_rescale | ||
start += shard_size | ||
|
||
layer.w13_weight_scale = torch.nn.Parameter( | ||
max_w13_scales, requires_grad=False | ||
) | ||
|
||
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling | ||
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post | ||
for expert_id in range(layer.num_experts): | ||
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id] | ||
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id] | ||
return | ||
|
||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( | ||
padding_size, # Avoid circular import | ||
) | ||
|
@@ -823,8 +902,24 @@ def apply( | |
correction_bias=correction_bias, | ||
) | ||
|
||
if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu": | ||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"): | ||
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE") | ||
assert not no_combine, f"{no_combine=} is not supported." | ||
return asm_moe( | ||
x, | ||
layer.w13_weight, | ||
layer.w2_weight, | ||
topk_weights, | ||
topk_ids, | ||
layer.w13_weight_scale1, | ||
layer.w2_weight_scale1, | ||
activation=activation, | ||
) | ||
if is_hip_ and get_bool_env_var("CK_MOE"): | ||
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being. | ||
assert ( | ||
activation == "silu" | ||
), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE" | ||
assert not no_combine, f"{no_combine=} is not supported." | ||
if self.block_quant: | ||
return asm_moe( | ||
|
@@ -835,10 +930,6 @@ def apply( | |
topk_ids, | ||
layer.w13_weight_scale_inv, | ||
layer.w2_weight_scale_inv, | ||
None, | ||
None, | ||
False, | ||
None, | ||
block_shape=tuple(self.quant_config.weight_block_size), | ||
expert_mask=None, | ||
) | ||
|
@@ -851,9 +942,6 @@ def apply( | |
topk_ids, | ||
layer.w13_weight_scale1, | ||
layer.w2_weight_scale1, | ||
None, | ||
None, | ||
False, | ||
) | ||
else: | ||
# Expert fusion with FP8 quantization | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
USE_INT4_WEIGHT
->SGLANG_ROCM_USE_INT4_WEIGHTS