Support for saving sharded checkpoints? #3209

Open

Kipok opened this issue Jan 30, 2025 · 3 comments

Labels: help wanted (Extra attention is needed), RLHF (Using SGLang for post training)

Comments


Kipok commented Jan 30, 2025

Does sglang support sharded checkpoints? I see here, https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_loader/loader.py#L492, that there is a loader that recommends using examples/save_sharded_state.py to save the sharded state, but that file doesn't exist.

Is it referring to this one from vLLM: https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/save_sharded_state.py?
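For context, if it is indeed that vLLM script, its core is roughly the following (paraphrased from memory, so the exact call may differ between vLLM versions; the model name and output path are placeholders):

from vllm import LLM

# Load the model at the tensor-parallel size the shards should be saved for;
# each rank then dumps its own slice of the weights as safetensors files.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1",
    tensor_parallel_size=8,
    trust_remote_code=True,
)

# Ask the executor to write the per-rank shards into the output directory.
llm.llm_engine.model_executor.save_sharded_state(path="/models/r1-sharded-tp8")

# The vLLM example additionally copies the non-weight files (config.json,
# tokenizer files, etc.) from the original model directory next to the shards.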

Also, the load-format argument doesn't have a choice for sharded_state (https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py#L315). Is that a typo, or is it not supposed to be used?

My real problem is that I'm trying to load DeepSeek-R1 and it takes a very long time. I have a sharded checkpoint that vLLM can load instantly, but sglang raises the following error (after I add "sharded_state" to the choices in the launcher to get past the immediate argument error):

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1773, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 239, in __init__
    self.tp_worker = TpWorkerClass(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in __init__
    self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in __init__
    self.model_runner = ModelRunner(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 185, in __init__
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 306, in load_model
    self.model = get_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 605, in load_model
    param_data = state_dict[key].data
KeyError: 'model.layers.10.mlp.experts.e_score_correction_bias'

Does this mean that sglang shards the model differently and I need to redo the sharding in some way? Or is it not supported at all? Or is there some other recommended way to load the R1/V3 model fast?
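For anyone hitting the same KeyError, a quick way to see the mismatch is to list what the shard files actually contain and search for the offending name. A minimal sketch, assuming the shards are safetensors files; the directory and file pattern are placeholders:

import glob
from safetensors import safe_open

# Print every tensor name in rank 0's shards that mentions the gate bias,
# so it can be compared against the parameter names sglang expects.
for path in sorted(glob.glob("/models/r1-sharded-tp8/model-rank-0-part-*.safetensors")):
    with safe_open(path, framework="pt", device="cpu") as f:
        for name in f.keys():
            if "correction_bias" in name:
                print(path, name)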


Kipok commented Jan 30, 2025

I think I was able to work around that issue by renaming experts.e_score_correction_bias to experts.correction_bias (it would be great if someone could double-check that this is the right way to handle it; see the sketch after the traceback below). But now I'm running into the following:

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in __init__
    self.tp_worker = TpWorkerClass(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in __init__
    self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in __init__
    self.model_runner = ModelRunner(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 214, in __init__
    self.init_cuda_graphs()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 730, in init_cuda_graphs
    self.cuda_graph_runner = CudaGraphRunner(self)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 226, in __init__
    self.capture()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 292, in capture
    ) = self.capture_one_batch_size(bs, forward)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 370, in capture_one_batch_size
    run_once()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 363, in run_once
    logits_output = forward(input_ids, forward_batch.positions, forward_batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 858, in forward
    hidden_states = self.model(input_ids, positions, forward_batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 819, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 757, in forward
    hidden_states = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 516, in forward
    return self.forward_absorb(positions, hidden_states, forward_batch)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 580, in forward_absorb
    if self.w_kc.dtype == torch.float8_e4m3fnuz:
AttributeError: 'NoneType' object has no attribute 'dtype'
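For reference, the renaming mentioned above can also be done offline against the shard files, instead of patching the loader. A minimal sketch, assuming safetensors shards; the directory and file pattern are placeholders, and it is worth backing up the files first:

import glob
from safetensors.torch import load_file, save_file

# Rewrite each shard in place so the MoE gate bias carries the name that
# sglang's DeepSeek-V2 implementation expects.
for path in glob.glob("/models/r1-sharded-tp8/model-rank-*-part-*.safetensors"):
    tensors = load_file(path)
    renamed = {
        k.replace("experts.e_score_correction_bias", "experts.correction_bias"): v
        for k, v in tensors.items()
    }
    if set(renamed) != set(tensors):  # only rewrite shards that changed
        save_file(renamed, path)

Whether correction_bias really is the name sglang expects should be confirmed against python/sglang/srt/models/deepseek_v2.py.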


Kipok commented Jan 30, 2025

OK, I was able to make it work with the patch below. Sharing it in case anyone else runs into the same problem; I hope this gets properly integrated in the future.

diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 9e6b094..2bc44ee 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -53,6 +53,18 @@ from sglang.srt.utils import (
     set_weight_attrs,
 )
 
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    input_to_float8,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+
+from vllm import _custom_ops as ops
+
+from sglang.srt.utils import is_cuda_available, is_hip
+
+is_hip_ = is_hip()
+
 
 @contextmanager
 def device_loading_context(module: torch.nn.Module, target_device: torch.device):
@@ -602,6 +614,8 @@ class ShardedStateLoader(BaseModelLoader):
                         # If loading with LoRA enabled, additional padding may
                         # be added to certain parameters. We only load into a
                         # narrowed view of the parameter data.
+                        if 'experts.e_score_correction_bias' in key:
+                            key = key.replace('experts.e_score_correction_bias', 'experts.correction_bias')
                         param_data = state_dict[key].data
                         param_shape = state_dict[key].shape
                         for dim, size in enumerate(tensor.shape):
@@ -617,6 +631,58 @@ class ShardedStateLoader(BaseModelLoader):
                             )
                         param_data.copy_(tensor)
                         state_dict.pop(key)
+
+            # patching for DS-v3/r1
+            for layer_id in range(model.config.num_hidden_layers):
+                self_attn = model.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(model.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = model.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if is_hip_:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if is_hip_:
+                        self_attn.w_scale *= 2.0
+
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
         return model.eval()
@@ -1225,3 +1291,4 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
         return LayeredModelLoader(load_config)
 
     return DefaultModelLoader(load_config)
+
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7bee346..4fab9d6 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -320,6 +320,7 @@ class ServerArgs:
                 "gguf",
                 "bitsandbytes",
                 "layered",
+                "sharded_state",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '

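For what it's worth, the second half of the patch appears to replicate the w_kc/w_vc post-processing that DeepSeek-V2's normal weight-loading path performs and that ShardedStateLoader otherwise skips, which is what caused the AttributeError above. With the patch applied, launching from the sharded checkpoint should then look something like this (the path and TP size are placeholders; the shard layout has to match the tensor-parallel size it was saved with, and the config/tokenizer files need to sit next to the shards):

python -m sglang.launch_server --model-path /models/r1-sharded-tp8 --load-format sharded_state --tp 8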
zhaochenyang20 self-assigned this on Feb 1, 2025
zhaochenyang20 added the help wanted (Extra attention is needed) and RLHF (Using SGLang for post training) labels on Feb 1, 2025
zhaochenyang20 (Collaborator) commented

Yeah, we do have a plan to support different shared-memory loading in SGLang. Stay tuned at:

#2569
