
ROCm: enable trillion-parameter MoE models with INT4-FP8 single node #4152

Merged 2 commits into sgl-project:main on Mar 6, 2025

Conversation

@HaiShaw (Collaborator) commented Mar 6, 2025

INT4 MoE weights, FP8 compute

credits: @shengnxu, @coderfeli, @carlushuang, @kkHuang-amd , @valarLip, @HaiShaw

Motivation

  • Enable models with more than 1.2 trillion parameters on a single node of 8x MI300/MI308.
  • Speed up decoding by keeping MoE weights in INT4, which lowers memory-bandwidth pressure.
  • Use the latest FP8 Tensor Cores (available on MI300 and MI308) for computation.

The model used can be accessed at https://huggingface.co/amd/grok-1-W4A8KV8 (please request access at https://huggingface.co/amd). You can also contact us on the SGLang Slack for a temporary token.
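
Once access is granted, a minimal sketch for fetching the checkpoint locally with huggingface_hub; the token value and local directory below are placeholders, not part of this PR:

# Hypothetical download step; assumes access to amd/grok-1-W4A8KV8 has been granted.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="amd/grok-1-W4A8KV8",
    local_dir="/data/grok-1-W4A8KV8",  # matches the path used in the INT4-FP8 benchmark below
    token="hf_xxx",                    # placeholder: your own or a temporary token
)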

grok-1-W4A8KV8/config.json:

{
  "_name_or_path": "/group/amdneuralopt/huggingface/pretrained_models/grok-1-sglang-tp1",
  "architectures": [
    "Grok1ModelForCausalLM"
  ],
  "attn_output_multiplier": 0.08838834764831845,
  "auto_map": {
    "AutoConfig": "configuration_grok1.Grok1Config",
    "AutoModel": "modeling_grok1.Grok1Model",
    "AutoModelForCausalLM": "modeling_grok1.Grok1ModelForCausalLM"
  },
  "bos_token_id": 1,
  "embedding_multiplier_scale": 78.38367176906169,
  "eos_token_id": 2,
  "hidden_size": 6144,
  "intermediate_size": 32768,
  "max_attn_value": 30.0,
  "max_position_embeddings": 8192,
  "model_type": "grok-1",
  "num_attention_heads": 48,
  "num_local_experts": 8,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 64,
  "num_key_value_heads": 8,
  "output_multiplier_scale": 0.5773502691896257,
  "output_router_logits": false,
  "pad_token_id": 0,
  "quantization_config": {
    "activation_scheme": "static",
    "export": {
      "kv_cache_group": [
        "*k_proj",
        "*v_proj"
      ],
      "min_kv_scale": 1.0,
      "pack_method": "reorder",
      "weight_format": "real_quantized",
      "weight_merge_groups": null
    },
    "ignored_layers": [
      "model.layers.0.block_sparse_moe.gate",
      ... ... ... ...
      "model.layers.63.block_sparse_moe.gate",
      "lm_head"
    ],
    "kv_cache_scheme": "static",
    "quant_method": "fp8",
    "int4_experts": {
      "bits": 4,
      "sym": true,
      "group": "column"
    }
  },
  "rms_norm_eps": 1e-05,
  "router_aux_loss_coef": 0.001,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.47.1",
  "use_cache": true,
  "vocab_size": 131072
}
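
The quantization_config above marks the MoE experts for 4-bit symmetric quantization grouped per column, with FP8 activations and static KV-cache scales, while the router gates and lm_head stay unquantized (see ignored_layers). As a rough reference for what "bits": 4, "sym": true, "group": "column" implies (assuming one scale per weight column; this is an illustrative sketch, not the PR's fused ROCm kernels):

# Illustrative sketch only: symmetric per-column INT4 quantization of a 2-D expert
# weight, packing two 4-bit values into each byte. Not the PR's actual kernel path.
import torch

def quantize_int4_per_column(w: torch.Tensor):
    # "sym": true, "group": "column" -> one symmetric scale per output column, range [-8, 7].
    scale = w.abs().amax(dim=0).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)
    # Keep the low nibble (two's complement) and pack row pairs into one byte.
    nib = (q.to(torch.int16) & 0xF).to(torch.uint8)
    packed = nib[0::2] | (nib[1::2] << 4)  # shape: (rows // 2, cols)
    return packed, scale                   # dequantize by unpacking, then q * scale

packed, scale = quantize_int4_per_column(torch.randn(128, 64))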

Modifications

Accuracy is preserved, with less than a 1% margin on gsm8k scores.

  • Grok-1 FP8 performance (one measurement):
/sgl-workspace/sglang# python -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 512 --model /data/lmzheng-grok-1/ --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --quantization fp8
Benchmark ...
Prefill. latency: 1.70331 s, throughput:  19237.80 token/s
Decode.  latency: 0.01748 s, throughput:   1830.72 token/s
Decode.  latency: 0.01791 s, throughput:   1786.83 token/s
Decode.  latency: 0.01777 s, throughput:   1800.57 token/s
Decode.  latency: 0.01796 s, throughput:   1781.26 token/s
Decode.  latency: 0.01792 s, throughput:   1785.74 token/s
Decode.  median latency: 0.02416 s, median throughput:   1324.33 token/s
Total. latency: 13.594 s, throughput:   3615.73 token/s
  • Grok-1 INT4-FP8 quantized model performance (one measurement):
# CK_MOE=1 USE_INT4_WEIGHT=1 python -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 512 --model /data/grok-1-W4A8KV8 --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --quantization fp8 --trust-remote-code
Benchmark ...
Prefill. latency: 2.21035 s, throughput:  14824.80 token/s
Decode.  latency: 0.02072 s, throughput:   1544.74 token/s
Decode.  latency: 0.02016 s, throughput:   1587.36 token/s
Decode.  latency: 0.02007 s, throughput:   1594.26 token/s
Decode.  latency: 0.02013 s, throughput:   1589.62 token/s
Decode.  latency: 0.02016 s, throughput:   1587.66 token/s
Decode.  median latency: 0.02068 s, median throughput:   1547.29 token/s
Total. latency: 12.734 s, throughput:   3859.76 token/s
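
The same two environment flags gate the new path when serving rather than benchmarking. A hedged sketch of launching a server with this checkpoint (the launch_server flags follow common SGLang usage and are not taken from this PR):

# Hedged sketch: serve the INT4-FP8 checkpoint with the flags used by the benchmark above.
import os
import subprocess

env = dict(os.environ)
env["CK_MOE"] = "1"           # ROCm Composable Kernel fused-MoE path
env["USE_INT4_WEIGHT"] = "1"  # INT4 MoE weights with FP8 compute (this PR)

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "/data/grok-1-W4A8KV8",
        "--tokenizer-path", "Xenova/grok-1-tokenizer",
        "--tp", "8",
        "--quantization", "fp8",
        "--trust-remote-code",
    ],
    env=env,
    check=True,
)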

INT4-FP8 model architecture

[image: INT4-FP8 model architecture diagram]

Conclusion:

  • INT4-FP8 enables serving a much larger model on a single server.
  • The INT4-FP8 model yields better median decode throughput and latency than the FP8 baseline (≈1547 vs. ≈1324 token/s median decode throughput), which serves the purpose of this change.


@HaiShaw changed the title from "ROCm/AITER: enable trillion-parameter MoE models with INT4-FP8 (INT4 …" to "ROCm/AITER: enable trillion-parameter MoE models with INT4-FP8" on Mar 6, 2025
@HaiShaw changed the title from "ROCm/AITER: enable trillion-parameter MoE models with INT4-FP8" to "ROCm: enable trillion-parameter MoE models with INT4-FP8 single node" on Mar 6, 2025
@zhyncs merged commit 13bc39c into sgl-project:main on Mar 6, 2025
6 of 20 checks passed
@@ -513,6 +513,10 @@ def weight_loader(

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
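
The body of the new branch is not shown in the hunk above. For context, MI300's FP8 format is e4m3fnuz, whose maximum representable value (240) is about half that of OCP e4m3fn (448), so scales exported against e4m3fn are usually widened on AMD. A hedged sketch of that kind of adjustment (not the PR's exact code):

# Illustrative only; the actual adjustment in this PR is not shown in the hunk above.
import torch

def adjust_scale_for_e4m3fnuz(input_scale: torch.Tensor) -> torch.Tensor:
    # A given e4m3fnuz bit pattern represents half the value of the same e4m3fn
    # pattern, so imported scales are commonly multiplied by 2.0 on AMD GPUs.
    return input_scale * 2.0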
A contributor commented:
USE_INT4_WEIGHT -> SGLANG_ROCM_USE_INT4_WEIGHTS

@@ -590,6 +622,53 @@ def create_weights(
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        if get_bool_env_var("USE_INT4_WEIGHT"):
A contributor commented:
move this part out into a separate function.
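
A hedged sketch of that suggested refactor; the helper name and surrounding class are illustrative stand-ins, and the import location of get_bool_env_var is assumed:

# Sketch of the suggested refactor, not code from this PR.
from torch.nn import Module

from sglang.srt.utils import get_bool_env_var  # assumed location of the helper used above

class Fp8MoEMethod:  # stand-in for the class that owns process_weights_after_loading
    def process_weights_after_loading(self, layer: Module) -> None:
        if get_bool_env_var("USE_INT4_WEIGHT"):
            self._process_int4_weights_after_loading(layer)
            return
        # ... existing FP8 weight post-processing stays here ...

    def _process_int4_weights_after_loading(self, layer: Module) -> None:
        # INT4-FP8 specific packing and scale handling moved out of the main path.
        ...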

@HaiShaw (Collaborator, Author) commented Mar 6, 2025

@merrymercy let me handle your request soon.
