Add deepseek_v3 fused gate #3191

Open · wants to merge 3 commits into base: main

Conversation

@NovTi commented Jan 28, 2025

Add deepseek v3 fused gate module

# Your module under test
output, indices_my = deepseekv3_fused_gate(tensor, bias, seq_length)

###### Reference Implementation ######
Collaborator

Please refactor this code into a standalone function, which can be directly used from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/topk.py#L111-L147.
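
For reference, a minimal sketch of what such a standalone reference function might look like, assuming the DeepSeek V3 style sigmoid + bias-corrected grouped top-k routing described in the linked topk.py; the function name, parameter defaults, and shapes below are illustrative, not the actual sglang code:

```python
import torch

def reference_deepseekv3_gate(logits, bias, topk=8, num_groups=8, topk_groups=4,
                              routed_scaling_factor=2.5):
    # logits: [num_tokens, num_experts], bias: [num_experts]
    scores = logits.sigmoid()
    scores_for_choice = scores + bias            # bias only influences expert selection
    num_tokens, num_experts = scores.shape

    # score each expert group by the sum of its top-2 biased scores
    group_scores = (
        scores_for_choice.view(num_tokens, num_groups, -1)
        .topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = group_scores.topk(topk_groups, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

    # mask out experts outside the selected groups, then pick the top-k experts
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, num_experts // num_groups)
        .reshape(num_tokens, num_experts)
    )
    masked = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    topk_idx = masked.topk(topk, dim=-1).indices

    # routing weights come from the unbiased scores, normalized and scaled
    topk_weights = scores.gather(1, topk_idx)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_idx
```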

Author

Do you mean I should separate the reference implementation into a standalone function?

Collaborator

Yeah.

Author

Got it, I will do that

output_ref = weights.type_as(scores)

# Assertions
output_check = torch.allclose(output_ref.sort()[0], output.sort()[0], rtol=1e-04, atol=1e-05)
Collaborator

Why not directly compare output and output_ref instead of sorting them?

Author

This is weird: the kernel sometimes produces exactly the same outputs, just in a different order. I checked the downstream steps and the output order does not matter there, so I used this approach for the unit test. Is this OK?
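
One possible order-insensitive check that still pairs each weight with its expert index (a sketch only; it assumes output and indices have shape [seq_length, topk], and the indices_ref tensor for the reference implementation is hypothetical here):

```python
def sort_by_expert(weights, indices):
    # reorder each row by ascending expert id so both implementations align
    order = indices.argsort(dim=-1)
    return weights.gather(-1, order), indices.gather(-1, order)

w_my, i_my = sort_by_expert(output, indices_my)
w_ref, i_ref = sort_by_expert(output_ref, indices_ref)  # indices_ref: assumed reference indices
assert torch.equal(i_my, i_ref)
assert torch.allclose(w_my, w_ref, rtol=1e-4, atol=1e-5)
```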

Collaborator
@BBuf commented Jan 28, 2025

We need to determine at which specific step of the fused kernel this inconsistency in order occurs. Additionally, we need to clarify whether running the PyTorch implementation twice with the same input would result in inconsistent output orders. Finally, if you believe that the current order inconsistency does not affect the fused MoE accuracy, you need to provide an end-to-end result, such as running the GSM8K test with the DeepSeek V3 model.

(image attached)

Author

I see, I will check the inconsistency inside the kernel. I cannot run the e2e test on my server; Yineng will help me run it.

from sgl_kernel import deepseekv3_fused_gate


@pytest.mark.parametrize("seq_length", range(1, 20000))
Collaborator

Can you add a benchmark script? Maybe refer to https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark
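
A minimal benchmark sketch in the spirit of the scripts under sgl-kernel/benchmark, assuming triton.testing.do_bench is available; the 256-expert shape follows DeepSeek V3 but the dtype and sequence lengths are illustrative:

```python
import torch
import triton
from sgl_kernel import deepseekv3_fused_gate

def bench(seq_length, num_experts=256):
    tensor = torch.randn(seq_length, num_experts, dtype=torch.float32, device="cuda")
    bias = torch.randn(num_experts, dtype=torch.float32, device="cuda")
    # latency in milliseconds over repeated timed runs
    ms = triton.testing.do_bench(lambda: deepseekv3_fused_gate(tensor, bias, seq_length))
    print(f"seq_length={seq_length}: {ms * 1e3:.2f} us")

if __name__ == "__main__":
    for s in (1, 64, 1024, 8192):
        bench(s)
```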

Author

Sure

@@ -3,6 +3,7 @@
bmm_fp8,
custom_dispose,
custom_reduce,
deepseekv3_fused_gate,
Collaborator
@BBuf commented Jan 28, 2025

It seems more appropriate to name it deepseek_fused_gate here, as models from the deepseek series can all go through this gate function.

Author

This is not a generalized kernel; it only works for the DeepSeek V3 671B model.

Collaborator

I see, thanks.

Member

I think it also works for DeepSeek V2 VL

input.data_ptr(), bias.data_ptr(), output.data_ptr(), indices.data_ptr<int64_t>(), num_rows, k, route_scale
);

CHECK_CUDA_SUCCESS(cudaDeviceSynchronize());
Collaborator
@BBuf commented Jan 28, 2025

Synchronization is not allowed in the CUDA kernel's host code, as it will cause CUDA graph capture to fail. Can you remove it?
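
For context, a small sketch of how graph-capture compatibility could be checked from Python once the sync is removed (shapes and dtype are illustrative; capture fails if the kernel's host code calls cudaDeviceSynchronize):

```python
import torch
from sgl_kernel import deepseekv3_fused_gate

seq_length, num_experts = 128, 256
tensor = torch.randn(seq_length, num_experts, device="cuda")
bias = torch.randn(num_experts, device="cuda")

# warm up on a side stream, as recommended before CUDA graph capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    deepseekv3_fused_gate(tensor, bias, seq_length)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):  # raises if the captured region issues a device-wide sync
    out, idx = deepseekv3_fused_gate(tensor, bias, seq_length)
g.replay()
```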

Author

Thanks, I will update these

@BBuf changed the title from "Add deepseek fused gate" to "Add deepseek_v3 fused gate" on Jan 28, 2025
@@ -0,0 +1,219 @@
#include <cfloat>
Collaborator

Please add "Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu#L231".

Collaborator
@BBuf commented Jan 29, 2025

In TensorRT-LLM, the fused MoE module includes not only the fused gate here but also trt_moe_expand_and_permute, the CUTLASS grouped GEMM, and trt_moe_unpermute_and_reduce. Compared to the MoE implemented in Triton, the advantage of TensorRT-LLM's approach is that it does not require padding, which saves some computational overhead, and the CUTLASS implementation may have greater performance potential, especially on the Hopper architecture.

I ran an experiment at https://github.com/sgl-project/sglang/tree/bbuf_tmp, where I connected trt_moe_expand_and_permute, trt_moe_unpermute_and_reduce, and FlashInfer's grouped GEMM in sgl-kernel and ran correctness comparisons against the Triton fused MoE operator for the bfloat16 dtype. However, the current performance is still significantly worse than Triton's, possibly due to performance issues with FlashInfer's grouped GEMM on specific shapes. Additionally, FlashInfer's GEMM does not currently support scaled FP8 or INT8 GEMM.

If anyone is interested, we can discuss whether to integrate TensorRT's fused MoE directly as a backend into sglang or to use FlashInfer's approach, which would require customizing FlashInfer's grouped GEMM. cc @zhyncs
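
A conceptual sketch of the unpadded permute, grouped GEMM, and unpermute-and-reduce flow described above, in plain PyTorch (not the TensorRT-LLM or FlashInfer API; names and shapes are illustrative):

```python
import torch

def moe_permute_grouped_gemm_unpermute(x, topk_ids, topk_weights, expert_weights):
    # x: [num_tokens, hidden], topk_ids / topk_weights: [num_tokens, topk]
    # expert_weights: list of [hidden, out_dim] tensors, one per expert
    num_tokens, topk = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)
    sort_order = torch.argsort(flat_ids)          # group token copies by expert id
    src_rows = sort_order // topk                 # which token each permuted row came from
    permuted = x[src_rows]                        # "expand and permute", no padding

    # grouped GEMM: one GEMM per expert over its contiguous slice of rows
    counts = torch.bincount(flat_ids, minlength=len(expert_weights))
    out = x.new_empty(permuted.shape[0], expert_weights[0].shape[1])
    start = 0
    for e, n in enumerate(counts.tolist()):
        if n:
            out[start:start + n] = permuted[start:start + n] @ expert_weights[e]
        start += n

    # "unpermute and reduce": scatter rows back to their tokens, weighted by the router
    w = topk_weights.reshape(-1)[sort_order].unsqueeze(1).to(out.dtype)
    result = x.new_zeros(num_tokens, out.shape[1])
    result.index_add_(0, src_rows, out * w)
    return result
```

The sorted layout makes each expert's rows contiguous, so the per-expert GEMMs need no padding; a real kernel would replace the Python loop with a single grouped GEMM.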

Member
@zhyncs commented Jan 29, 2025

> directly integrate TensorRT's fused MoE as a backend into sglang

sounds good @BBuf

Collaborator
@BBuf commented Jan 29, 2025

> directly integrate TensorRT's fused MoE as a backend into sglang
>
> sounds good @BBuf

Yeah, I can have a try.
