From 20c8119915ca5a8b933ae8194e55e8e8650396f2 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Sat, 8 Mar 2025 04:11:18 +0800
Subject: [PATCH 01/27] Fix eagle hang issue for max_new_tokens=1 (#4185)

---
 python/sglang/srt/managers/scheduler.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 05bc8d730ea..cb3a4b5de39 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -957,7 +957,11 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                 self.batch_is_full = False
 
+            last_bs = self.last_batch.batch_size()
             self.last_batch.filter_batch()
+            if self.last_batch.batch_size() < last_bs:
+                self.batch_is_full = False
+
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch

From e1aaa79ac9954c705f839e8304d29eac452ce04b Mon Sep 17 00:00:00 2001
From: saienduri
Date: Fri, 7 Mar 2025 13:02:02 -0800
Subject: [PATCH 02/27] Update amd ci docker image to v0.4.3.post4-rocm630. (#4189)

---
 .github/workflows/pr-test-amd.yml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml
index 6e505957685..507590025e2 100644
--- a/.github/workflows/pr-test-amd.yml
+++ b/.github/workflows/pr-test-amd.yml
@@ -35,12 +35,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.3.post2-rocm630
+          docker pull lmsysorg/sglang:v0.4.3.post4-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.3.post2-rocm630
+            lmsysorg/sglang:v0.4.3.post4-rocm630
 
       - name: Install dependencies
         run: |
@@ -71,12 +71,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.3.post2-rocm630
+          docker pull lmsysorg/sglang:v0.4.3.post4-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.3.post2-rocm630
+            lmsysorg/sglang:v0.4.3.post4-rocm630
 
       - name: Install dependencies
         run: |

From d052f4c8a9fb7e135ca0f0b09f6feead93db9e01 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 7 Mar 2025 20:21:08 -0800
Subject: [PATCH 03/27] New clang format for sgl kernel (#4194)

---
 python/upload_pypi.sh                         |   6 -
 sgl-kernel/.clang-format                      |   7 +
 .../activation/fused_add_rms_norm_kernel.cu   |  13 +-
 .../csrc/allreduce/custom_all_reduce_hip.cuh  | 116 ++--
 .../csrc/allreduce/trt_reduce_internal.cu     |  43 +-
 .../csrc/allreduce/trt_reduce_kernel.cu       |  30 +-
 .../lightning_attention_decode_kernel.cu      |  46 +-
 .../epilogue/epilogue_per_row_per_col_scale.h |  63 +-
 .../gemm/dispatch_policy.hpp                  |  20 +-
 .../gemm/gemm_universal_base_compat.h         |  22 +-
 .../gemm/gemm_with_epilogue_visitor.h         |  72 ++-
 .../csrc/gemm/cublas_grouped_gemm.cu          |  64 +-
 .../csrc/gemm/fp8_blockwise_gemm_kernel.cu    |  75 ++-
 .../sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu   | 598 ++++++++++++------
 .../sgl-kernel/csrc/gemm/int8_gemm_kernel.cu  | 404 ++++++++----
 .../csrc/gemm/per_tensor_quant_fp8.cu         |  17 +-
 .../csrc/gemm/per_token_group_quant_fp8.cu    |  33 +-
.../csrc/gemm/per_token_quant_fp8.cu | 16 +- .../sgl-kernel/csrc/moe/moe_align_kernel.cu | 53 +- .../csrc/speculative/eagle_utils.cu | 86 ++- .../csrc/speculative/speculative_sampling.cu | 46 +- .../csrc/speculative/speculative_sampling.cuh | 95 ++- .../src/sgl-kernel/include/sgl_kernels_ops.h | 246 +++++-- .../include/trt_reduce_internal.cuh | 4 +- sgl-kernel/src/sgl-kernel/include/utils.h | 1 - 25 files changed, 1490 insertions(+), 686 deletions(-) delete mode 100644 python/upload_pypi.sh diff --git a/python/upload_pypi.sh b/python/upload_pypi.sh deleted file mode 100644 index 35616e1dad8..00000000000 --- a/python/upload_pypi.sh +++ /dev/null @@ -1,6 +0,0 @@ -cp ../README.md ../LICENSE . -rm -rf dist -python3 -m build -python3 -m twine upload dist/* - -rm -rf README.md LICENSE diff --git a/sgl-kernel/.clang-format b/sgl-kernel/.clang-format index 5e690c02885..afbd654a790 100644 --- a/sgl-kernel/.clang-format +++ b/sgl-kernel/.clang-format @@ -6,3 +6,10 @@ DerivePointerAlignment: false PointerAlignment: Left NamespaceIndentation: None SortIncludes: true +AllowShortLoopsOnASingleLine: false +BinPackParameters: false # Prevents packing parameters in declarations +BinPackArguments: false # Prevents packing arguments in function calls +AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis +AlignOperands: Align # Aligns arguments vertically +PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument +PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name diff --git a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu index a4ae14ae59d..41f4d2e7099 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu @@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T // support float16, bfloat16 and float32 DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { cudaError_t status = norm::FusedAddRMSNorm( - static_cast(input.data_ptr()), static_cast(residual.data_ptr()), - static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); + static_cast(input.data_ptr()), + static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), + batch_size, + hidden_size, + eps, + torch_current_stream); + TORCH_CHECK( + status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh b/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh index 06173bc4225..7baf5f01ef4 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh @@ -153,19 +153,20 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. 
template -DINLINE void start_sync(const RankSignals& sg, +DINLINE void start_sync( + const RankSignals& sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - int rank) { + Signal* self_sg, + int rank) { #ifdef USE_ROCM uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if (threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, - __MEMORY_SCOPE_SYSTEM); + __scoped_atomic_store_n( + &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < flag) @@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals& sg, +DINLINE void end_sync( + const RankSignals& sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - int rank) { + Signal* self_sg, + int rank) { #ifdef USE_ROCM __syncthreads(); // eliminate the case that prior writes are not visible after signals become @@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg, if (threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, - final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM); + __scoped_atomic_store_n( + &sg.signals[threadIdx.x]->end[blockIdx.x][rank], + flag, + final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, + __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks - while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], - final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag) + while (__scoped_atomic_load_n( + &self_sg->end[blockIdx.x][threadIdx.x], + final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, + __MEMORY_SCOPE_DEVICE) < flag) ; } __syncthreads(); @@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { } template -__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, RankSignals sg, +__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage( + RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - T* __restrict__ result, int rank, int size) { + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same @@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) { } template -__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg, +__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage( + RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - T* __restrict__ result, int rank, int size) { + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -357,8 +372,14 @@ class CustomAllreduce { * note: this class does not own any device memory. 
Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, const hipIpcMemHandle_t* handles, - const std::vector& offsets, int rank, bool full_nvlink = true) + CustomAllreduce( + Signal* meta, + void* rank_data, + size_t rank_data_sz, + const hipIpcMemHandle_t* handles, + const std::vector& offsets, + int rank, + bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), @@ -382,8 +403,8 @@ class CustomAllreduce { auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { char* ipc_ptr; - CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), - hipIpcMemLazyEnablePeerAccess)); + CUDACHECK(hipIpcOpenMemHandle( + (void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } return it->second; @@ -399,13 +420,14 @@ class CustomAllreduce { void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address - if (hipPointerGetAttribute(&base_ptr, + if (hipPointerGetAttribute( + &base_ptr, #ifdef USE_ROCM - HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, + HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, #else - CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, #endif - (hipDeviceptr_t)ptr) != hipSuccess) + (hipDeviceptr_t)ptr) != hipSuccess) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); offsets[i] = ((char*)ptr) - ((char*)base_ptr); @@ -415,8 +437,8 @@ class CustomAllreduce { void check_rank_data_capacity(size_t num = 1) { if (d_rank_data_base_ + num > d_rank_data_end_) - throw std::runtime_error("Rank data buffer is overflowed by " + - std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + throw std::runtime_error( + "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } void register_buffer(const std::vector& handles, const std::vector& offsets, void* self) { @@ -443,8 +465,8 @@ class CustomAllreduce { // rank 1 may get the same input address for the second allreduce, but rank 2 // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. - void register_graph_buffers(const std::vector& handles, - const std::vector>& offsets) { + void + register_graph_buffers(const std::vector& handles, const std::vector>& offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); @@ -474,11 +496,17 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(hipStream_t stream, T* input, T* output, int size, + void allreduce( + hipStream_t stream, + T* input, + T* output, + int size, #ifndef USE_ROCM - int threads = 512, int block_limit = 36){ + int threads = 512, + int block_limit = 36){ #else - int threads = 512, int block_limit = 16) { + int threads = 512, + int block_limit = 16) { #endif auto d = packed_t::P::size; if (size % d != 0) @@ -487,8 +515,8 @@ class CustomAllreduce { "of " + std::to_string(d)); if (block_limit > kMaxBlocks) - throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + - std::to_string(block_limit)); + throw std::runtime_error( + "max supported block limit is " + std::to_string(kMaxBlocks) + ". 
Got " + std::to_string(block_limit)); RankData* ptrs; hipStreamCaptureStatus status; @@ -499,17 +527,17 @@ class CustomAllreduce { } else { auto it = buffers_.find(input); if (it == buffers_.end()) - throw std::runtime_error("buffer address " + std::to_string(reinterpret_cast(input)) + - " is not registered!"); + throw std::runtime_error( + "buffer address " + std::to_string(reinterpret_cast(input)) + " is not registered!"); ptrs = it->second; } size /= d; auto bytes = size * sizeof(typename packed_t::P); int blocks = ::min(block_limit, (size + threads - 1) / threads); -#define KL(ngpus, name) \ - hipLaunchKernelGGL((name), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \ - size); +#define KL(ngpus, name) \ + hipLaunchKernelGGL( \ + (name), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size); #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu index fa9e3a2c5d2..f1ee5d40efd 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu @@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) { return c.packed; } -__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, - size_t const world_size, int const tidx, int const bidx) { +__inline__ __device__ void multi_gpu_barrier( + uint32_t** signals, + uint32_t const flag, + size_t const local_rank, + size_t const world_size, + int const tidx, + int const bidx) { // After this function, at least one block in each GPU has reached the barrier if (tidx < world_size) { // we can think of signals having the shape [world_size, world_size] @@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const } template -__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, - size_t const world_size, int const tidx, int const bidx, int const grid_size) { +__inline__ __device__ void block_barrier( + uint32_t** signals, + uint32_t const flag, + size_t const local_rank, + size_t const world_size, + int const tidx, + int const bidx, + int const grid_size) { if constexpr (!start) { __syncthreads(); } @@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc } } // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer - block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, - grid_size); + block_barrier( + params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size); // Each block accumulates the values from the different GPUs on the same node. for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) { @@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc } } } - block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, - grid_size); + block_barrier( + params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size); // Each block accumulates the values from the different GPUs on the same node. 
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { @@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc } } - block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, - bidx, grid_size); + block_barrier( + params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size); // Gather all needed elts from other intra-node ranks for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { @@ -459,8 +470,12 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar //////////////////////////////////////////////////////////////////////////////////////////////////// template -void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, - cudaStream_t stream) { +void dispatchARKernels( + AllReduceStrategyType algo, + AllReduceParams& param, + int blocks_per_grid, + int threads_per_block, + cudaStream_t stream) { switch (algo) { case AllReduceStrategyType::ONESHOT: { oneShotAllReduceKernel<<>>(param); @@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy CHECK_CUDA_SUCCESS(cudaGetLastError()); } -void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, - cudaStream_t stream) { +void trtCustomAllReduce( + AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) { if (params.elts_total == 0) { return; } diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu index af129de52ef..5c879255621 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu @@ -29,9 +29,14 @@ using IPC_KEY = std::array; class AllReduceMeta { public: - AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, - const std::vector& tmp_result_buffers, const std::vector& barrier_in, - const std::vector& barrier_out) { + AllReduceMeta( + int64_t rank_id, + int64_t world_size, + torch::Tensor& rank_data, + const std::vector& buffers, + const std::vector& tmp_result_buffers, + const std::vector& barrier_in, + const std::vector& barrier_out) { this->rank_id = (int)rank_id; this->world_size = (int)world_size; this->barrier_in = barrier_in; @@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0; } -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, - const std::vector& tmp_result_buffers, const std::vector& barrier_in, - const std::vector& barrier_out) { +fptr_t init_custom_ar( + int64_t rank_id, + int64_t world_size, + torch::Tensor& rank_data, + const std::vector& buffers, + const std::vector& tmp_result_buffers, + const std::vector& barrier_in, + const std::vector& barrier_out) { auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out); return (fptr_t)m; } @@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) { auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { char* 
ipc_ptr; - CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), - cudaIpcMemLazyEnablePeerAccess)); + CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle( + (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } return it->second; @@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) { // rank 1 may get the same input address for the second allreduce, but rank 2 // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. -void register_graph_buffers(fptr_t _fa, const std::vector>& handles, - const std::vector>& offsets) { +void register_graph_buffers( + fptr_t _fa, const std::vector>& handles, const std::vector>& offsets) { AllReduceMeta* m = reinterpret_cast(_fa); std::vector handle_bytes; handle_bytes.reserve(handles.size()); diff --git a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu index 02c50498eb9..f9d524f6001 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu @@ -23,15 +23,18 @@ limitations under the License. #define THREADS_PER_BLOCK 128 template -__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d] - const T* __restrict__ k, // [b, h, 1, d] - const T* __restrict__ v, // [b, h, 1, e] - const float* __restrict__ past_kv, // [b, h, d, e] - const float* __restrict__ slope, // [h, 1, 1] - T* __restrict__ output, // [b, h, 1, e] - float* __restrict__ new_kv, // [b, h, d, e] - const int batch_size, const int num_heads, const int qk_dim, - const int v_dim) { +__global__ void lightning_attention_decode_kernel( + const T* __restrict__ q, // [b, h, 1, d] + const T* __restrict__ k, // [b, h, 1, d] + const T* __restrict__ v, // [b, h, 1, e] + const float* __restrict__ past_kv, // [b, h, d, e] + const float* __restrict__ slope, // [h, 1, 1] + T* __restrict__ output, // [b, h, 1, e] + float* __restrict__ new_kv, // [b, h, d, e] + const int batch_size, + const int num_heads, + const int qk_dim, + const int v_dim) { extern __shared__ char smem[]; T* __restrict__ q_shared = reinterpret_cast(smem); T* __restrict__ k_shared = reinterpret_cast(smem + qk_dim * sizeof(T)); @@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q, } } -void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, - const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, - torch::Tensor new_kv) { +void lightning_attention_decode( + const torch::Tensor& q, + const torch::Tensor& k, + const torch::Tensor& v, + const torch::Tensor& past_kv, + const torch::Tensor& slope, + torch::Tensor output, + torch::Tensor new_kv) { TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); @@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] { size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float); lightning_attention_decode_kernel<<>>( - q.data_ptr(), k.data_ptr(), v.data_ptr(), 
past_kv.data_ptr(), - slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads, - qk_dim, v_dim); + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + past_kv.data_ptr(), + slope.data_ptr(), + output.data_ptr(), + new_kv.data_ptr(), + batch_size, + num_heads, + qk_dim, + v_dim); })); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h index f5cd4381563..9f85bee28b1 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -25,9 +25,15 @@ namespace cutlass { namespace epilogue { namespace threadblock { -template +template < + typename ThreadblockShape_, + int ThreadCount, + typename ScaleTileIterator_, + typename OutputTileIterator_, + typename ElementAccumulator_, + typename ElementCompute_, + typename ElementwiseFunctor_, + bool UseMasking_ = false> class EpilogueVisitorPerRowPerCol { public: using ThreadblockShape = ThreadblockShape_; @@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol { Arguments(typename ElementwiseFunctor::Params elementwise_) : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} - Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_, - int64_t batch_stride_D_) + Arguments( + typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_alpha_, + int64_t batch_stride_C_, + int64_t batch_stride_D_) : elementwise(elementwise_), batch_stride_alpha(batch_stride_alpha_), batch_stride_C(batch_stride_C_), @@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol { public: CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, - typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant, - bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), - int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + EpilogueVisitorPerRowPerCol( + Params const& params, + SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + bool with_bias, + bool per_token_quant, + bool per_channel_quant, + AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) : params_(params), shared_storage_(shared_storage), extent_(problem_size), @@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol { /// Helper to indicate split-K behavior CUTLASS_DEVICE - void 
set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) { ///< Total number of split-K slices + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices } /// Called to set the batch index @@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol { private: CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col, - AlphaScaleElementType const& scale_row) { + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) { ComputeFragment result; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ComputeFragment::kElements; ++i) { @@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol { } CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col, - AlphaScaleElementType const& scale_row) { + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) { ComputeFragment result; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ComputeFragment::kElements; ++i) { diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp index 48b0ad9490e..f62b51ee7ed 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp @@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp // specialized dynamic schedule For FP8 kernels with Block Scaling -template , class KernelSchedule = KernelTmaWarpSpecialized, - int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, - // while zero-value `ScaleGranularityM` indicates that scaling - // granularity is `size<0>(TileShape_MNK{})` along M. - > +template < + int Stages_, + class ClusterShape_ = Shape<_1, _1, _1>, + class KernelSchedule = KernelTmaWarpSpecialized, + int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, + // while zero-value `ScaleGranularityM` indicates that scaling + // granularity is `size<0>(TileShape_MNK{})` along M. 
+ > struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 : MainloopSm90TmaGmmaWarpSpecialized { - static_assert(cute::is_same_v>, - "KernelSchedule must be one of the warp specialized policies"); + static_assert( + cute:: + is_same_v>, + "KernelSchedule must be one of the warp specialized policies"); }; ////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h index 3de9ff078b6..b58d84318ba 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -159,8 +159,9 @@ class GemmUniversalBaseCompat { get_grid_shape_(grid_tiled_shape, gemm_k_size, args); dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); + CUTLASS_TRACE_HOST( + " grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); return result; } @@ -175,8 +176,8 @@ class GemmUniversalBaseCompat { CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); if (smem_size <= (48 << 10)) { - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, - GemmKernel::kThreadCount, smem_size); + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); if (result == cudaSuccess) { CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); @@ -184,12 +185,12 @@ class GemmUniversalBaseCompat { } } else { // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, - GemmKernel::kThreadCount, 0); + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " - << cudaGetErrorString(result)); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); return -1; } @@ -226,8 +227,9 @@ class GemmUniversalBaseCompat { /// Initializes GEMM state from arguments. Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + CUTLASS_TRACE_HOST( + "GemmUniversalBaseCompat::initialize() - workspace " << workspace + << ", stream: " << (stream ? "non-null" : "null")); size_t workspace_bytes = get_workspace_size(args); diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h index 11fc872505f..905d11ba2c6 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -32,10 +32,11 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template < + typename Mma_, ///! 
Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function + > struct GemmWithEpilogueVisitor { public: using Mma = Mma_; @@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor { Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} /// constructs an arguments structure - Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_, - TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, - typename EpilogueVisitor::Arguments epilogue_visitor_) + Arguments( + GemmCoord problem_size_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, + TensorRefC ref_C_, + TensorRefC ref_D_, + typename EpilogueVisitor::Arguments epilogue_visitor_) : mode(GemmUniversalMode::kGemm), problem_size(problem_size_), batch_count(1), @@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor { isAMisaligned = problem_size.k() % kAlignmentA; } else if (platform::is_same::value) { isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value || - platform::is_same>::value) { + } else if ( + platform::is_same>::value || + platform::is_same>::value) { isAMisaligned = problem_size.k() % kAlignmentA; } @@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor { isBMisaligned = problem_size.n() % kAlignmentB; } else if (platform::is_same::value) { isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value || - platform::is_same>::value) { + } else if ( + platform::is_same>::value || + platform::is_same>::value) { isBMisaligned = problem_size.k() % kAlignmentB; } @@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor { isCMisaligned = problem_size.n() % kAlignmentC; } else if (platform::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value || - platform::is_same>::value) { + } else if ( + platform::is_same>::value || + platform::is_same>::value) { isCMisaligned = problem_size.n() % kAlignmentC; } @@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor { int thread_idx = threadIdx.x; // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, - tb_offset_A); + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, - tb_offset_B); + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. 
@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor { threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // assume identity swizzle - MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN); + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); @@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor { with_bias = false; } - EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(), - thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, - params.params_D, with_bias, true, true, params.ptr_alpha_row, params.ptr_alpha_col, - params.ptr_C, params.ptr_D, threadblock_offset, - blockIdx.y * params.problem_size.m()); + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_alpha_col, + params.params_C, + params.params_D, + with_bias, + true, + true, + params.ptr_alpha_row, + params.ptr_alpha_col, + params.ptr_C, + params.ptr_D, + threadblock_offset, + blockIdx.y * params.problem_size.m()); if (params.mode == GemmUniversalMode::kGemm) { // Indicate which position in a serial reduction the output operator is currently updating diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu index ec899d33024..d0a80c7bff5 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu @@ -21,10 +21,13 @@ #include "utils.h" -static void check_group_count(const std::vector& inputs, const std::vector& weights, - const std::vector& outputs) { - TORCH_CHECK(((inputs.size() == weights.size()) && (inputs.size() == outputs.size())), - "The group count of inputs, weights and outputs should be the same."); +static void check_group_count( + const std::vector& inputs, + const std::vector& weights, + const std::vector& outputs) { + TORCH_CHECK( + ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())), + "The group count of inputs, weights and outputs should be the same."); } static void check_device_dtype(const torch::Dtype& dtype, const std::vector& tensors) { @@ -68,21 +71,26 @@ static std::vector get_tensor_ptrs(const std::vector& tens static torch::Tensor create_ptr_pointer(const std::vector& ptrs, cudaStream_t stream) { auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA); torch::Tensor gpu_ptrs = torch::empty({static_cast(ptrs.size())}, options); - TORCH_CHECK(cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, - stream) == CUBLAS_STATUS_SUCCESS); + TORCH_CHECK( + cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) == + CUBLAS_STATUS_SUCCESS); return gpu_ptrs; } // We want compute input @ weight^T in row major // This is equivalent to computing weight @ input^T in col major // Cublas only accepts matrix in column major, so this arrangement is needed -void cublas_grouped_gemm(const std::vector& inputs, // b: (m, k) row major = (k, m) col major - const std::vector& weights, // a: (n, k) row major = (n, k)^T col major - const std::vector& 
outputs, // c: (m, n) row major = (n, m) col major - const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream) { - TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, - "cublas grouped_gemm can" - "only be applied to float16 and bfloat16 dtype"); +void cublas_grouped_gemm( + const std::vector& inputs, // b: (m, k) row major = (k, m) col major + const std::vector& weights, // a: (n, k) row major = (n, k)^T col major + const std::vector& outputs, // c: (m, n) row major = (n, m) col major + const torch::Dtype& out_dtype, + int64_t cublas_handle, + int64_t cuda_stream) { + TORCH_CHECK( + out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, + "cublas grouped_gemm can" + "only be applied to float16 and bfloat16 dtype"); int group_count = inputs.size(); check_group_count(inputs, weights, outputs); @@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector& inputs, // b: (m, k torch::Tensor d_c = create_ptr_pointer(c_array, stream); #if defined CUDA_VERSION && CUDA_VERSION >= 12050 - auto status = cublasGemmGroupedBatchedEx(handle, transa_array.data(), transb_array.data(), m_array.data(), - n_array.data(), k_array.data(), alpha_array.data(), (void**)d_a.data_ptr(), - cuda_data_type, lda_array.data(), (void**)d_b.data_ptr(), cuda_data_type, - ldb_array.data(), beta_array.data(), (void**)d_c.data_ptr(), cuda_data_type, - ldc_array.data(), group_count, group_size.data(), CUBLAS_COMPUTE_32F); + auto status = cublasGemmGroupedBatchedEx( + handle, + transa_array.data(), + transb_array.data(), + m_array.data(), + n_array.data(), + k_array.data(), + alpha_array.data(), + (void**)d_a.data_ptr(), + cuda_data_type, + lda_array.data(), + (void**)d_b.data_ptr(), + cuda_data_type, + ldb_array.data(), + beta_array.data(), + (void**)d_c.data_ptr(), + cuda_data_type, + ldc_array.data(), + group_count, + group_size.data(), + CUBLAS_COMPUTE_32F); TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status)); TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization"); return; #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion()); + TORCH_CHECK_NOT_IMPLEMENTED( + false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion()); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu index 337a5ad69ac..a62a5c0ce6d 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu @@ -35,8 +35,12 @@ using namespace cute; template -void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b) { +void launch_sm90_fp8_blockwise_scaled_mm( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b) { using ElementAccumulator = float; using ElementCompute = float; using ElementBlockScale = float; @@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, 
ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, - LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp; + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, + StoreEpilogueCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, - TileShape, ClusterShape, + ArchTag, + OperatorClass, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; Gemm gemm_op; @@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor } template -void sm90_fp8_blockwise_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b) { +void sm90_fp8_blockwise_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; launch_sm90_fp8_blockwise_scaled_mm(out, a, b, scales_a, scales_b); } -torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const torch::Dtype& out_dtype) { +torch::Tensor fp8_blockwise_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype) { TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); @@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); - TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, - "mat_a must be multiple of 16 bytes for memory alignment"); - TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, - "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment"); TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); TORCH_CHECK(out_dtype == torch::kHalf || 
out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); @@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T #endif #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu index 36b9585f349..64731ebe4d2 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu @@ -53,10 +53,17 @@ limitations under the License. using namespace cute; #if defined CUDA_VERSION && CUDA_VERSION >= 12040 -template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, - typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> +template < + typename ElementType, + typename OutElementType, + typename AccumElementType, + typename CtaShape, + typename WarpShape, + int Stages, + bool WithBias, + typename FP8MathOperator = cutlass::arch::OpMultiplyAdd, + template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, + typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> struct DeviceGemmFp8RowwiseSm89 { static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); @@ -85,56 +92,86 @@ struct DeviceGemmFp8RowwiseSm89 { // Number of epilogue stages in EVT static constexpr int EVTEpilogueStages = 1; - using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + using OutputTileThreadMap = cutlass::epilogue::threadblock:: + OutputTileThreadLayout; // Definition of EVT using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; - using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + cutlass::multiplies, + ElementComputeEpilogue, + ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock:: + VisitorRowBroadcast>; using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; - using ComputeAScale = - cutlass::epilogue::threadblock::VisitorCompute; - using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; + using ComputeAScale = cutlass::epilogue::threadblock:: + VisitorCompute; + using aScaleSrc = cutlass::epilogue::threadblock:: + VisitorColBroadcast>; using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; // With bias using biasSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; - using ComputeAScaleWithBias = - cutlass::epilogue::threadblock::VisitorCompute; + using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, + ElementC, + ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; using EpilogueAScaleWithBias = cutlass::epilogue::threadblock::Sm80EVT; using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride>; - using EpilogueStore = - typename cutlass::platform::conditional, - cutlass::epilogue::threadblock::Sm80EVT>::type; + OutputTileThreadMap, + ElementC, + 
cutlass::FloatRoundStyle::round_to_nearest, + Stride>; + using EpilogueStore = typename cutlass::platform::conditional< + WithBias, + cutlass::epilogue::threadblock::Sm80EVT, + cutlass::epilogue::threadblock::Sm80EVT>::type; using EpilogueOp = EpilogueStore; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, - cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator, - ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp, - ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; + ElementA, + LayoutA, + cutlass::ComplexTransform::kNone, + AlignmentA, + ElementB, + LayoutB, + cutlass::ComplexTransform::kNone, + AlignmentB, + ElementC, + LayoutC, + AlignmentC, + ElementAccumulator, + ElementComputeEpilogue, + OperatorClass, + ArchTag, + CtaShape, + WarpShape, + InstructionShape, + EpilogueOp, + ThreadblockSwizzle, + Stages, + FP8MathOperator, + EVTEpilogueStages>::GemmKernel; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; template -typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +typename Gemm::Arguments prepare_sm89_fp8_args( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementT = typename Gemm::ElementA; using ElementOutput = typename Gemm::ElementD; using ElementComputeEpilogue = float; @@ -158,54 +195,61 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch:: ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); - typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode - {m, n, k}, // Problem size - 1, // Split-k factor - {}, // Epilogue args - ptr_a, // a pointer - ptr_b, // b pointer - nullptr, // c pointer (unused) - nullptr, // d pointer (unused) - m * k, // batch stride a (unused) - n * k, // batch stride b (unused) - m * n, // batch stride c (unused) - m * n, // batch stride d (unused) - lda, // stride a - ldb, // stride b - ldc, // stride c (unused) - ldc); // stride d (unused) + typename Gemm::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + ptr_a, // a pointer + ptr_b, // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) if constexpr (WithBias) { - args.epilogue = {{ - { - {}, // Accumulator - {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, - {} // Multiplies - }, - {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, - {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, - {} // Multiplies - }, - {ptr_d, {n, _1{}, _0{}}}}; + args.epilogue = { + { + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + 
{ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; } else { - args.epilogue = {{ - { - {}, // Accumulator - {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, - {} // Multiplies - }, - {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, - {} // Multiplies - }, - {ptr_d, {n, _1{}, _0{}}}}; + args.epilogue = { + { + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; } return args; } template -void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void launch_sm89_fp8_scaled_mm( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); Gemm gemm_op; @@ -222,109 +266,187 @@ void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const } template -void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm89_fp8_dispatch_bias( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementInput = cutlass::float_e4m3_t; using ElementOutput = OutType; using AccumElementType = float; if (bias) { - using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm89< + ElementInput, + ElementOutput, + AccumElementType, + CtaShape, + WarpShape, + Stages, + true>::Gemm; return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } else { - using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm89< + ElementInput, + ElementOutput, + AccumElementType, + CtaShape, + WarpShape, + Stages, + false>::Gemm; return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } } template -void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm89_fp8_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { uint32_t const m = a.size(0); uint32_t const n = out.size(1); if (m == 1) { if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 16) { // M in (1, 16] if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, 
scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 4>(out, a, b, scales_a, scales_b, bias); } else if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 64) { // M in (16, 64] if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 128) { // M in (64, 128] if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + 4>(out, a, b, scales_a, scales_b, bias); } else if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 256) { // M in (128, 256] if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } else if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, + 4>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 512) { // M in (256, 512) if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + 
cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 2>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 4>(out, a, b, scales_a, scales_b, bias); } } else { // M in (512, inf) if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 3>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 2>(out, a, b, scales_a, scales_b, bias); } } } #endif #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -template +template < + typename ElementType, + typename OutElementType, + typename AccumElementType, + typename CTAShape, + typename ClusterShape, + typename MainloopScheduleType, + typename EpilogueScheduleType, + typename TileSchedulerType = void, + bool WithBias = false> struct DeviceGemmFp8RowwiseSm90 { static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); @@ -374,44 +496,70 @@ struct DeviceGemmFp8RowwiseSm90 { using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default // setting in the Collective Builder // Implement rowwise scaling epilogue. - using XScale = - cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, - cute::Stride, cute::Int<0>, cute::Int<0>>>; - - using WScale = - cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, - cute::Stride, cute::Int<1>, cute::Int<0>>>; + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0, + TileShape, + ElementOutput, + ElementOutput, + cute::Stride, cute::Int<1>, cute::Int<0>>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; - using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementOutput, + ElementComputeEpilogue, // Second stage input types. 
+ cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; // With bias - using ComputeWithBias = - cutlass::epilogue::fusion::Sm90Compute; + using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, + ElementOutput, + ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; using EpilogueEVT = typename cutlass::platform::conditional::type; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC, - AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized, + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementC, + LayoutC, + AlignmentC, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, EpilogueEVT>::CollectiveOp; using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; @@ -423,22 +571,38 @@ struct DeviceGemmFp8RowwiseSm90 { using FastAccum = FastPongSchedule; // Default apply Pingpong using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, - TileShape, ClusterShape, + ArchTag, + OperatorClass, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduleType>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + TileSchedulerType>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; template -typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +typename Gemm::Arguments prepare_sm90_fp8_args( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementT = typename Gemm::ElementA; using ElementOutput = typename Gemm::ElementD; using ElementComputeEpilogue = float; @@ -465,14 +629,15 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); StrideC stride_c; StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); - typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, - {m, n, k, 1}, - {ptr_a, stride_a, ptr_b, stride_b}, - {{}, // epilogue.thread - nullptr, - stride_c, - ptr_d, - stride_d}}; + typename Gemm::Arguments args = { + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {ptr_a, stride_a, ptr_b, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + ptr_d, + stride_d}}; if 
constexpr (WithBias) { args.epilogue.thread = { {ptr_scales_a}, @@ -500,9 +665,13 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: } template -void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void launch_sm90_fp8_scaled_mm( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); Gemm gemm_op; @@ -519,66 +688,117 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const TORCH_CHECK(status == cutlass::Status::kSuccess) } -template -void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias, bool fast_accum = true, - bool use_persistent = false) { +template < + typename OutType, + typename CTAShape, + typename ClusterShape, + typename MainloopScheduleType, + typename TileSchedulerType> +void sm90_fp8_dispatch_bias( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias, + bool fast_accum = true, + bool use_persistent = false) { using ElementInput = cutlass::float_e4m3_t; using ElementOutput = OutType; using AccumElementType = float; using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; if (bias) { - using Gemm = - typename DeviceGemmFp8RowwiseSm90::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm90< + ElementInput, + ElementOutput, + AccumElementType, + CTAShape, + ClusterShape, + MainloopScheduleType, + EpilogueScheduleType, + TileSchedulerType, + true>::Gemm; return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } else { - using Gemm = - typename DeviceGemmFp8RowwiseSm90::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm90< + ElementInput, + ElementOutput, + AccumElementType, + CTAShape, + ClusterShape, + MainloopScheduleType, + EpilogueScheduleType, + TileSchedulerType, + false>::Gemm; return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } } template -void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm90_fp8_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { uint32_t const m = a.size(0); using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; using PersistentTileScheduler = cutlass::gemm::PersistentScheduler; using BasicTileScheduler = void; if (m <= 1) { - return sm90_fp8_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler, - BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_64, _64, _128>, + Shape<_1, _8, _1>, + FastBasicScheduler, + BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); } if (m <= 64) { // m in [1, 64] - return sm90_fp8_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, 
bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_64, _64, _128>, + Shape<_1, _4, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } else if (m <= 256) { // m in (64, 256] - return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_64, _64, _128>, + Shape<_1, _1, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } else if (m <= 1024) { // m in (256, 1024] - return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_128, _128, _128>, + Shape<_1, _1, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } else { // m in (1024, inf) - return sm90_fp8_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_128, _128, _128>, + Shape<_2, _1, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } } #endif -torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias) { +torch::Tensor fp8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias) { TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); @@ -587,10 +807,10 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); - TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, - "mat_a must be multiple of 16 bytes for memory alignment"); - TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, - "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment"); TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu index 4a8130d667e..86aa3b8f2f4 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu @@ -35,11 +35,20 @@ limitations under the License. 
using namespace cute; -template -void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +template < + typename ElementOutput, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + int NumStages> +void cutlass_int8_scaled_mm( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementAccumulator = int32_t; using ElementCompute = float; using ElementInputA = int8_t; @@ -48,30 +57,51 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons using OperatorClass = cutlass::arch::OpClassTensorOp; using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; - using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration; + using DefaultGemmConf = cutlass::gemm::device:: + DefaultGemmConfiguration; using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp; using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< - ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB, - cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel; + ElementInputA, + cutlass::layout::RowMajor, + DefaultGemmConf::kAlignmentA, + ElementInputB, + cutlass::layout::ColumnMajor, + DefaultGemmConf::kAlignmentB, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + NumStages, + true, + typename DefaultGemmConf::Operator>::GemmKernel; using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape, typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count, GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads, - GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits::value>, + GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, + cutlass::sizeof_bits::value>, ElementCompute>; using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol< - ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator, - typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>; + ThreadblockShape, + GemmKernel_::kThreadCount, + AlphaColTileIterator, + typename GemmKernel_::Epilogue::OutputTileIterator, + ElementAccumulator, + ElementCompute, + EpilogueOutputOp>; - using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< - EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpilogueWithVisitorFromExistingEpilogue::Epilogue; using GemmKernel = cutlass::gemm::kernel::GemmWithEpilogueVisitor; @@ -104,98 +134,164 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons typename EpilogueOutputOp::Params 
linearScalingParams; typename EpilogueVisitor::Arguments visitor_args{linearScalingParams}; - typename Gemm::Arguments args{{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, - {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args}; + typename Gemm::Arguments args{ + {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args}; - auto workspace = torch::empty(gemm_op.get_workspace_size(args), - torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + auto workspace = torch::empty( + gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); auto can_implement = gemm_op.can_implement(args); - TORCH_CHECK(can_implement == cutlass::Status::kSuccess, - "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + TORCH_CHECK( + can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", + cutlassGetStatusString(can_implement)); auto status = gemm_op(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); } template -void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm75_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { int m = mat_a.size(0); if (m <= 32) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<32, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } else if (m <= 64) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } else if (m <= 256) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } } template -void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm80_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { int m = mat_a.size(0); int n = 
mat_b.size(1); if (m <= 16) { if (n <= 4096) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, + 6>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 32) { if (n <= 4096) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 6>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 64) { if (n <= 4096) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 128 && n < 8192) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } -template -void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +template < + typename ElementOutput, + typename TileShape, + typename ClusterShape, + typename MainloopScheduleType, + bool WithBias> +void cutlass_int8_scaled_mm_sm90( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ArchTag = cutlass::arch::Sm90; using 
ElementAccumulator = int32_t; @@ -213,50 +309,75 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; using TileSchedulerType = cutlass::gemm::PersistentScheduler; - using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, - Stride, Int<0>, Int<0>>>; + using XScale = cutlass::epilogue::fusion:: + Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride, Int<0>, Int<0>>>; - using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, - Stride, Int<1>, Int<0>>>; + using WScale = cutlass::epilogue::fusion:: + Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride, Int<1>, Int<0>>>; - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, - Stride, Int<1>, Int<0>>>; + using Bias = cutlass::epilogue::fusion:: + Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride, Int<1>, Int<0>>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; // Scale - using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + using Compute0 = cutlass::epilogue::fusion:: + Sm90Compute; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; - using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + using Compute1 = cutlass::epilogue::fusion:: + Sm90Compute; using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; // With bias - using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute; + using ComputeWithBias = cutlass::epilogue::fusion:: + Sm90Compute; using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; using EpilogueEVT = typename cutlass::platform::conditional::type; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput, - cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementOutput, + cutlass::layout::RowMajor, + AlignmentC, + ElementOutput, + cutlass::layout::RowMajor, + AlignmentOutput, + EpilogueScheduleType, + EpilogueEVT>::CollectiveOp; using Stages = cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB, - cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages, + ArchTag, + OperatorClass, + ElementInputA, + cutlass::layout::RowMajor, + AlignmentA, + ElementInputB, + cutlass::layout::ColumnMajor, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + Stages, MainloopScheduleType>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + TileSchedulerType>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -283,14 +404,15 @@ void 
cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, StrideC stride_c; StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); - typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, - {m, n, k, 1}, - {a_ptr, stride_a, b_ptr, stride_b}, - {{}, // epilogue.thread - nullptr, - stride_c, - o_ptr, - stride_d}}; + typename Gemm::Arguments args = { + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {a_ptr, stride_a, b_ptr, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + o_ptr, + stride_d}}; if constexpr (WithBias) { ElementOutput* bias_ptr = static_cast(bias->data_ptr()); @@ -308,23 +430,29 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, }; } - auto workspace = torch::empty(gemm_op.get_workspace_size(args), - torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + auto workspace = torch::empty( + gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); auto can_implement = gemm_op.can_implement(args); - TORCH_CHECK(can_implement == cutlass::Status::kSuccess, - "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + TORCH_CHECK( + can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", + cutlassGetStatusString(can_implement)); auto status = gemm_op(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); } template -void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm90_dispatch_bias( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { if (bias) { cutlass_int8_scaled_mm_sm90( out, mat_a, mat_b, scales_a, scales_b, bias); @@ -335,45 +463,73 @@ void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const to } template -void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm90_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { int m = mat_a.size(0); int n = mat_b.size(1); if (m <= 32) { if (n < 8192) { - return sm90_dispatch_bias, Shape<_1, _8, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _128>, + Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - return sm90_dispatch_bias, Shape<_1, _8, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _128, _128>, + Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 64) { if (n < 8192) { - return sm90_dispatch_bias, Shape<_1, _4, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return 
sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _128>, + Shape<_1, _4, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - return sm90_dispatch_bias, Shape<_1, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _256>, + Shape<_1, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 128) { if (n <= 4096) { - return sm90_dispatch_bias, Shape<_2, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _128>, + Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - return sm90_dispatch_bias, Shape<_2, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _128, _128>, + Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else { - return sm90_dispatch_bias, Shape<_2, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, - bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_128, _128, _128>, + Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias); } } -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias) { +torch::Tensor int8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias) { TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index d9290fe012a..ea222c00150 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -8,8 +8,8 @@ #include "utils.h" template -__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, - const int64_t num_elements) { +__global__ void +per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) { float max_value = 0.0f; unsigned int tid = threadIdx.x; unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x; @@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r } template -__global__ void per_tensor_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output, - const float* __restrict__ scale, const int64_t num_elements) { +__global__ void per_tensor_quant_fp8_kernel( + const T* __restrict__ input, + FP8_TYPE* __restrict__ output, + const float* __restrict__ scale, + const int64_t num_elements) { const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int grid_size = blockDim.x * gridDim.x; const float scale_val = 1.0f / (*scale); @@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, 
torch::Tensor output_q, torch } per_tensor_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), static_cast(output_q.data_ptr()), - static_cast(output_s.data_ptr()), num_elements); + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + num_elements); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu index e5a14602a92..bb3135dad23 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu @@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) { } template -__global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, void* __restrict__ output_q, - float* __restrict__ output_s, const int group_size, - const int num_groups, const float eps, const float fp8_min, - const float fp8_max) { +__global__ void per_token_group_quant_fp8_kernel( + const T* __restrict__ input, + void* __restrict__ output_q, + float* __restrict__ output_s, + const int group_size, + const int num_groups, + const float eps, + const float fp8_min, + const float fp8_max) { const int groups_per_block = 16; const int local_group_id = threadIdx.x / 16; const int lane_id = threadIdx.x % 16; @@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo } } -void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, - int64_t group_size, double eps, double fp8_min, double fp8_max) { +void sgl_per_token_group_quant_fp8( + torch::Tensor input, + torch::Tensor output_q, + torch::Tensor output_s, + int64_t group_size, + double eps, + double fp8_min, + double fp8_max) { CHECK_INPUT(input); CHECK_INPUT(output_q); CHECK_INPUT(output_s); @@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { per_token_group_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), output_q.data_ptr(), static_cast(output_s.data_ptr()), - group_size, num_groups, (float)eps, (float)fp8_min, (float)fp8_max); + static_cast(input.data_ptr()), + output_q.data_ptr(), + static_cast(output_s.data_ptr()), + group_size, + num_groups, + (float)eps, + (float)fp8_min, + (float)fp8_max); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index 5528ad8c568..1491af126ef 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -7,9 +7,12 @@ #include "utils.h" template -__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q, - float* __restrict__ output_s, const int64_t hidden_dim, - const int64_t num_tokens) { +__global__ void per_token_quant_fp8_kernel( + const T* __restrict__ input, + FP8_TYPE* __restrict__ output_q, + float* __restrict__ output_s, + const int64_t hidden_dim, + const int64_t num_tokens) { const int token_idx = blockIdx.x; if (token_idx >= num_tokens) return; @@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { per_token_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), 
static_cast(output_q.data_ptr()), - static_cast(output_s.data_ptr()), hidden_dim, num_tokens); + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu index 473aae6f5ec..c5f37e55609 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -25,9 +25,11 @@ limitations under the License. #define WARP_SIZE 32 template -__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, size_t numel) { +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; @@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ } template -__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum) { __shared__ int32_t shared_counts[WARP_SIZE][8]; const int warp_id = threadIdx.x / WARP_SIZE; @@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id } } -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, - torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, + torch::Tensor cumsum_buffer) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto align_kernel = moe_align_block_size_kernel; - align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), - num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); + align_kernel<<<1, 1024, 0, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + cumsum_buffer.data_ptr()); const int block_threads = 256; const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; @@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b const int actual_blocks = std::min(num_blocks, max_blocks); auto 
sort_kernel = count_and_sort_expert_tokens_kernel; - sort_kernel<<>>(topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel()); + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + topk_ids.numel()); }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu index af44261cc18..1bfd6fd8481 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu @@ -23,10 +23,18 @@ // tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = // [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, // draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token] -__global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, - bool* tree_mask, int64_t* positions, int64_t* retrive_index, - int64_t* retrive_next_token, int64_t* retrive_next_sibling, int topk, int depth, - int draft_token_num) { +__global__ void build_tree_efficient( + int64_t* parent_list, + int64_t* selected_index, + int32_t* verified_seq_len, + bool* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int topk, + int depth, + int draft_token_num) { int bid = blockIdx.x; int tid = threadIdx.x; @@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind } } -void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, - at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, - at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk, - int64_t depth, int64_t draft_token_num) { +void build_tree_kernel_efficient( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + int64_t topk, + int64_t depth, + int64_t draft_token_num) { // TODO (ying) check shape // TODO (ying) check type int bs = parent_list.size(0); @@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); build_tree_efficient<<>>( - static_cast(parent_list.data_ptr()), static_cast(selected_index.data_ptr()), - static_cast(verified_seq_len.data_ptr()), static_cast(tree_mask.data_ptr()), - static_cast(positions.data_ptr()), static_cast(retrive_index.data_ptr()), - static_cast(retrive_next_token.data_ptr()), static_cast(retrive_next_sibling.data_ptr()), - int32_t(topk), int32_t(depth), int32_t(draft_token_num)); + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num)); } // parent_list [bs, topk * (depth - 1) + 1)] @@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind // tree_mask [draft_token*(seq_len[0]+draft_token) | 
draft_token*(seq_len[1]+draft_token) | ..] = // [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, // draft_token, depth + 2] -__global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, bool* tree_mask, - int64_t* positions, int64_t* retrive_index, int topk, int depth, int draft_token_num) { +__global__ void build_tree( + int64_t* parent_list, + int64_t* selected_index, + int32_t* verified_seq_len, + bool* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int topk, + int depth, + int draft_token_num) { int bid = blockIdx.x; int tid = threadIdx.x; @@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_ } } -void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, - at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, - int64_t depth, int64_t draft_token_num) { +void build_tree_kernel( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + int64_t topk, + int64_t depth, + int64_t draft_token_num) { // TODO (ying) check shape // TODO (ying) check type int bs = parent_list.size(0); @@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); build_tree<<>>( - static_cast(parent_list.data_ptr()), static_cast(selected_index.data_ptr()), - static_cast(verified_seq_len.data_ptr()), static_cast(tree_mask.data_ptr()), - static_cast(positions.data_ptr()), static_cast(retrive_index.data_ptr()), int32_t(topk), - int32_t(depth), int32_t(draft_token_num)); + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num)); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu index 379a2a22c0d..6eaafdb5be4 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu @@ -29,12 +29,19 @@ using namespace flashinfer; // retrive_next_sibling: [bs, num_draft_tokens] // uniform_samples: [bs, num_draft_tokens] // target_probs: [bs, num_draft_tokens, vocab_size] -void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, - at::Tensor accept_token_num, // mutable - at::Tensor candidates, at::Tensor retrive_index, - at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, - at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, - bool deterministic, int64_t cuda_stream = 0) { +void tree_speculative_sampling_target_only( + at::Tensor predicts, + at::Tensor accept_index, + at::Tensor accept_token_num, // mutable + at::Tensor candidates, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + at::Tensor uniform_samples, + at::Tensor target_probs, + at::Tensor draft_probs, + bool deterministic, + int64_t cuda_stream = 0) { CHECK_INPUT(candidates); CHECK_INPUT(retrive_index); CHECK_INPUT(retrive_next_token); @@ -108,13 +115,24 @@ void 
tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accep cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly( - static_cast(predicts.data_ptr()), static_cast(accept_index.data_ptr()), - static_cast(accept_token_num.data_ptr()), static_cast(candidates.data_ptr()), - static_cast(retrive_index.data_ptr()), static_cast(retrive_next_token.data_ptr()), - static_cast(retrive_next_sibling.data_ptr()), static_cast(uniform_samples.data_ptr()), - static_cast(target_probs.data_ptr()), static_cast(draft_probs.data_ptr()), batch_size, - num_spec_step, num_draft_tokens, vocab_size, deterministic, stream); + static_cast(predicts.data_ptr()), + static_cast(accept_index.data_ptr()), + static_cast(accept_token_num.data_ptr()), + static_cast(candidates.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + static_cast(uniform_samples.data_ptr()), + static_cast(target_probs.data_ptr()), + static_cast(draft_probs.data_ptr()), + batch_size, + num_spec_step, + num_draft_tokens, + vocab_size, + deterministic, + stream); - TORCH_CHECK(status == cudaSuccess, - "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status))); + TORCH_CHECK( + status == cudaSuccess, + "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status))); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh index b9a32d2a90e..bf7099231c1 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh @@ -27,15 +27,29 @@ namespace sampling { using namespace cub; -template -__global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* accept_index, - IdType* accept_token_num, // mutable - IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, - IdType* retrive_next_sibling, DType* uniform_samples, - DType* target_probs, DType* draft_probs, uint32_t batch_size, - uint32_t num_speculative_tokens, uint32_t num_draft_tokens, - uint32_t d) { +template < + uint32_t BLOCK_THREADS, + BlockScanAlgorithm SCAN_ALGORITHM, + BlockReduceAlgorithm REDUCE_ALGORITHM, + uint32_t VEC_SIZE, + bool DETERMINISTIC, + typename DType, + typename IdType> +__global__ void TreeSpeculativeSamplingTargetOnly( + IdType* predicts, + IdType* accept_index, + IdType* accept_token_num, // mutable + IdType* candidates, + IdType* retrive_index, + IdType* retrive_next_token, + IdType* retrive_next_sibling, + DType* uniform_samples, + DType* target_probs, + DType* draft_probs, + uint32_t batch_size, + uint32_t num_speculative_tokens, + uint32_t num_draft_tokens, + uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; extern __shared__ __align__(alignof(SamplingTempStorage)) @@ -140,37 +154,54 @@ __global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* acce } template -cudaError_t TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* output_token_ids, - IdType* output_accepted_token_num, // mutable - IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, - IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs, - DType* draft_probs, uint32_t batch_size, uint32_t num_speculative_tokens, - uint32_t num_draft_tokens, uint32_t d, bool deterministic, 
- cudaStream_t stream = 0) { +cudaError_t TreeSpeculativeSamplingTargetOnly( + IdType* predicts, + IdType* output_token_ids, + IdType* output_accepted_token_num, // mutable + IdType* candidates, + IdType* retrive_index, + IdType* retrive_next_token, + IdType* retrive_next_sibling, + DType* uniform_samples, + DType* target_probs, + DType* draft_probs, + uint32_t batch_size, + uint32_t num_speculative_tokens, + uint32_t num_draft_tokens, + uint32_t d, + bool deterministic, + cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&predicts, - &output_token_ids, - &output_accepted_token_num, - &candidates, - &retrive_index, - &retrive_next_token, - &retrive_next_sibling, - &uniform_samples, - &target_probs, - &draft_probs, - &batch_size, - &num_speculative_tokens, - &num_draft_tokens, - &d}; + void* args[] = { + &predicts, + &output_token_ids, + &output_accepted_token_num, + &candidates, + &retrive_index, + &retrive_next_token, + &retrive_next_sibling, + &uniform_samples, + &target_probs, + &draft_probs, + &batch_size, + &num_speculative_tokens, + &num_draft_tokens, + &d}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TreeSpeculativeSamplingTargetOnly; + auto kernel = TreeSpeculativeSamplingTargetOnly< + BLOCK_THREADS, + SCAN_ALGO, + REDUCE_ALGO, + VEC_SIZE, + DETERMINISTIC, + DType, + IdType>; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index f5ebffb1295..5bc5c7083b8 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -42,8 +42,8 @@ using fptr_t = int64_t; void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); -void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, - int64_t cuda_stream); +void gemma_fused_add_rmsnorm( + at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); @@ -53,113 +53,219 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); */ #ifdef USE_ROCM // ROCM custom allreduce -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int64_t rank, bool full_nvlink); +fptr_t init_custom_ar( + torch::Tensor& meta, + torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, + int64_t rank, + bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, 
torch::Tensor& out); void dispose(fptr_t _fa); int64_t meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, - const std::vector& offsets); +void register_buffer( + fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); std::tuple> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector& handles, - const std::vector>& offsets); +void register_graph_buffers( + fptr_t _fa, const std::vector& handles, const std::vector>& offsets); torch::Tensor allocate_meta_buffer(int64_t size); torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); #else // TRTLLM custom allreduce -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, - const std::vector& tmp_result_buffers, const std::vector& barrier_in, - const std::vector& barrier_out); +fptr_t init_custom_ar( + int64_t rank_id, + int64_t world_size, + torch::Tensor& rank_data, + const std::vector& buffers, + const std::vector& tmp_result_buffers, + const std::vector& barrier_in, + const std::vector& barrier_out); void dispose(fptr_t _fa); void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector>& handles, - const std::vector>& offsets); +void register_graph_buffers( + fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); #endif /* * From csrc/gemm */ -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); -torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); -torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const torch::Dtype& out_dtype); -void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, - double eps, double fp8_min, double fp8_max); +torch::Tensor int8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias); +torch::Tensor fp8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias); +torch::Tensor fp8_blockwise_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype); +void sgl_per_token_group_quant_fp8( + at::Tensor input, + at::Tensor output_q, + at::Tensor output_s, + int64_t group_size, + double eps, + double fp8_min, + double fp8_max); void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static); -void cublas_grouped_gemm(const std::vector& inputs, const std::vector& weights, - const std::vector& outputs, const torch::Dtype& out_dtype, - int64_t cublas_handle, int64_t cuda_stream); +void cublas_grouped_gemm( + const std::vector& inputs, + const std::vector& weights, + const std::vector& outputs, + const 
torch::Dtype& out_dtype, + int64_t cublas_handle, + int64_t cuda_stream); /* * From csrc/moe */ -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, - torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, + torch::Tensor cumsum_buffer); /* * From csrc/speculative */ -void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, - at::Tensor accept_token_num, // mutable - at::Tensor candidates, at::Tensor retrive_index, - at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, - at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, - bool deterministic = true, int64_t cuda_stream = 0); - -void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, - at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, - at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk, - int64_t depth, int64_t draft_token_num); - -void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, - at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, - int64_t depth, int64_t draft_token_num); +void tree_speculative_sampling_target_only( + at::Tensor predicts, + at::Tensor accept_index, + at::Tensor accept_token_num, // mutable + at::Tensor candidates, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + at::Tensor uniform_samples, + at::Tensor target_probs, + at::Tensor draft_probs, + bool deterministic = true, + int64_t cuda_stream = 0); + +void build_tree_kernel_efficient( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + int64_t topk, + int64_t depth, + int64_t draft_token_num); + +void build_tree_kernel( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + int64_t topk, + int64_t depth, + int64_t draft_token_num); /* * From FlashInfer */ -void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, - at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); -void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, - std::optional maybe_min_p_arr, double min_p_val, bool deterministic, - int64_t cuda_stream); +void bmm_fp8( + at::Tensor A, + at::Tensor B, + at::Tensor D, + at::Tensor A_scale, + at::Tensor B_scale, + at::Tensor workspace_buffer, + int64_t cublas_handle, + int64_t cuda_stream); +void min_p_sampling_from_probs( + at::Tensor probs, + at::Tensor uniform_samples, + at::Tensor samples, + std::optional maybe_min_p_arr, + double min_p_val, + bool deterministic, + int64_t cuda_stream); // top k renorm probs // patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. 
-void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, - unsigned int top_k_val, int64_t cuda_stream); +void top_k_renorm_probs( + at::Tensor probs, + at::Tensor renorm_probs, + std::optional maybe_top_k_arr, + unsigned int top_k_val, + int64_t cuda_stream); // patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. -inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, - std::optional maybe_top_k_arr, int64_t top_k_val, - int64_t cuda_stream) { +inline void top_k_renorm_probs_wrapper( + at::Tensor probs, + at::Tensor renorm_probs, + std::optional maybe_top_k_arr, + int64_t top_k_val, + int64_t cuda_stream) { top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); } -void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, - double top_p_val, int64_t cuda_stream); -void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, - at::Tensor success, std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic, - int64_t cuda_stream); -void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic, - int64_t cuda_stream); -void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, - int64_t cuda_stream); +void top_p_renorm_probs( + at::Tensor probs, + at::Tensor renorm_probs, + std::optional maybe_top_p_arr, + double top_p_val, + int64_t cuda_stream); +void top_k_top_p_sampling_from_probs( + at::Tensor probs, + at::Tensor uniform_samples, + at::Tensor samples, + at::Tensor success, + std::optional maybe_top_k_arr, + double top_k_val, + std::optional maybe_top_p_arr, + double top_p_val, + bool deterministic, + int64_t cuda_stream); +void top_p_sampling_from_probs( + at::Tensor probs, + at::Tensor uniform_samples, + at::Tensor samples, + at::Tensor success, + std::optional maybe_top_p_arr, + double top_p_val, + bool deterministic, + int64_t cuda_stream); +void apply_rope_pos_ids_cos_sin_cache( + at::Tensor q, + at::Tensor k, + at::Tensor q_rope, + at::Tensor k_rope, + at::Tensor cos_sin_cache, + at::Tensor pos_ids, + bool interleave, + int64_t cuda_stream); /* * Other */ -void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, - const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, - torch::Tensor new_kv); +void lightning_attention_decode( + const torch::Tensor& q, + const torch::Tensor& k, + const torch::Tensor& v, + const torch::Tensor& past_kv, + const torch::Tensor& slope, + torch::Tensor output, + torch::Tensor new_kv); // sgl_per_token_quant_fp8 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s); diff --git a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh index f4b01230cf3..c670c994db1 100644 --- a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh @@ -103,7 +103,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world return AllReduceStrategyType::TWOSHOT; } -void trtCustomAllReduce(AllReduceParams& 
params, at::ScalarType data_type, AllReduceStrategyType strat, - cudaStream_t stream); +void trtCustomAllReduce( + AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream); } // namespace trt_llm diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index 94bcefa7fb6..b2960954bcb 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -95,7 +95,6 @@ inline int getSMVersion() { AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) #define CEILDIV(x, y) (((x) + (y)-1) / (y)) - #define WARP_SIZE 32 #ifndef USE_ROCM From d4017a6b6339257888484f86d9d20a20546111fe Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 7 Mar 2025 22:12:13 -0800 Subject: [PATCH 04/27] [EAGLE] many fixes for eagle (#4195) Co-authored-by: SangBin Cho Co-authored-by: Sehoon Kim --- python/pyproject.toml | 8 +- python/sglang/srt/configs/model_config.py | 2 +- python/sglang/srt/entrypoints/engine.py | 2 + python/sglang/srt/layers/sampler.py | 62 +++++------ python/sglang/srt/managers/scheduler.py | 2 + .../srt/model_executor/cuda_graph_runner.py | 5 +- .../sglang/srt/model_executor/model_runner.py | 39 ------- .../sampling/penaltylib/frequency_penalty.py | 1 - .../sampling/penaltylib/presence_penalty.py | 1 - python/sglang/srt/speculative/eagle_worker.py | 104 +++++++++++------- test/lang/test_srt_backend.py | 2 +- test/srt/test_eagle_infer.py | 77 ++++++++++++- test/srt/test_eval_accuracy_large.py | 9 +- test/srt/test_mla.py | 8 ++ test/srt/test_penalty.py | 15 ++- 15 files changed, 202 insertions(+), 135 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 0dc0ef63dbb..6eaa6263bef 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -18,12 +18,15 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"] [project.optional-dependencies] runtime_common = [ "aiohttp", + "datasets", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", + "llguidance>=0.6.15", "modelscope", + "ninja", "orjson", "packaging", "pillow", @@ -33,13 +36,10 @@ runtime_common = [ "python-multipart", "pyzmq>=25.1.2", "torchao>=0.7.0", + "transformers==4.48.3", "uvicorn", "uvloop", "xgrammar==0.1.14", - "ninja", - "transformers==4.48.3", - "llguidance>=0.6.15", - "datasets" ] srt = [ diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 64ef15cf7fb..6f103bcc603 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -81,7 +81,7 @@ def __init__( if context_length is not None: if context_length > derived_context_len: if get_bool_env_var( - "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False" + "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True" ): logger.warning( f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). 
" diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 7c0f287b7d0..f8a6b4e431f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -106,6 +106,8 @@ def __init__(self, **kwargs): tokenizer_manager, scheduler_info = _launch_subprocesses( server_args=server_args ) + + self.server_args = server_args self.tokenizer_manager = tokenizer_manager self.scheduler_info = scheduler_info diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index f471626e1f7..ec041305c7b 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -42,7 +42,6 @@ def forward( return_logprob: bool, top_logprobs_nums: List[int], token_ids_logprobs: List[List[int]], - batch_next_token_ids: Optional[torch.Tensor] = None, ): """Run a sampler & compute logprobs and update logits_output accordingly. @@ -72,8 +71,7 @@ def forward( if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling - if batch_next_token_ids is None: - batch_next_token_ids = torch.argmax(logits, -1) + batch_next_token_ids = torch.argmax(logits, -1) if return_logprob: logprobs = torch.nn.functional.log_softmax(logits, dim=-1) else: @@ -94,43 +92,39 @@ def forward( top_p_normalize_probs_torch(probs, sampling_info.top_ps) ).clamp(min=torch.finfo(probs.dtype).min) - if batch_next_token_ids is None: - max_top_k_round, batch_size = 32, probs.shape[0] - uniform_samples = torch.rand( - (max_top_k_round, batch_size), device=probs.device + max_top_k_round, batch_size = 32, probs.shape[0] + uniform_samples = torch.rand( + (max_top_k_round, batch_size), device=probs.device + ) + if sampling_info.need_min_p_sampling: + probs = top_k_renorm_prob(probs, sampling_info.top_ks) + probs = top_p_renorm_prob(probs, sampling_info.top_ps) + batch_next_token_ids = min_p_sampling_from_probs( + probs, uniform_samples, sampling_info.min_ps ) - if sampling_info.need_min_p_sampling: - probs = top_k_renorm_prob(probs, sampling_info.top_ks) - probs = top_p_renorm_prob(probs, sampling_info.top_ps) - batch_next_token_ids = min_p_sampling_from_probs( - probs, uniform_samples, sampling_info.min_ps - ) - else: - batch_next_token_ids, success = top_k_top_p_sampling_from_probs( - probs, - uniform_samples, - sampling_info.top_ks, - sampling_info.top_ps, - filter_apply_order="joint", - ) - - if self.use_nan_detection and not torch.all(success): - logger.warning("Detected errors during sampling!") - batch_next_token_ids = torch.zeros_like( - batch_next_token_ids - ) - - elif global_server_args_dict["sampling_backend"] == "pytorch": - if batch_next_token_ids is None: - # A slower fallback implementation with torch native operations. - batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( + else: + batch_next_token_ids, success = top_k_top_p_sampling_from_probs( probs, + uniform_samples, sampling_info.top_ks, sampling_info.top_ps, - sampling_info.min_ps, - sampling_info.need_min_p_sampling, + filter_apply_order="joint", ) + if self.use_nan_detection and not torch.all(success): + logger.warning("Detected errors during sampling!") + batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + + elif global_server_args_dict["sampling_backend"] == "pytorch": + # A slower fallback implementation with torch native operations. 
+ batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( + probs, + sampling_info.top_ks, + sampling_info.top_ps, + sampling_info.min_ps, + sampling_info.need_min_p_sampling, + ) + if return_logprob: # clamp to avoid -inf logprobs = torch.log( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index cb3a4b5de39..a5c6a1dbdcd 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -957,11 +957,13 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: self.req_to_token_pool.free(self.chunked_req.req_pool_idx) self.batch_is_full = False + # Filter batch last_bs = self.last_batch.batch_size() self.last_batch.filter_batch() if self.last_batch.batch_size() < last_bs: self.batch_is_full = False + # Merge the new batch into the running batch if not self.last_batch.is_empty(): if self.running_batch is None: self.running_batch = self.last_batch diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 813fbf6fc1c..6a2bab22a46 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -300,10 +300,11 @@ def can_run(self, forward_batch: ForwardBatch): def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream + # Reverse the order to enable better memory sharing across cuda graphs. capture_range = ( - tqdm.tqdm(self.capture_bs) + tqdm.tqdm(reversed(self.capture_bs)) if get_tensor_model_parallel_rank() == 0 - else self.capture_bs + else reversed(self.capture_bs) ) for bs in capture_range: with patch_model( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 666b97e2b8e..6489ea6eddf 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -928,45 +928,6 @@ def _preprocess_logits( sampling_info.update_regex_vocab_mask() sampling_info.apply_logits_bias(logits_output.next_token_logits) - def update_output_logprobs( - self, - logits_output: LogitsProcessorOutput, - sampling_info: SamplingBatchInfo, - top_logprobs_nums: List[int], - token_ids_logprobs: List[int], - next_token_ids: torch.Tensor, - *, - num_tokens_per_req: List[int], - ): - """Update the logits_output's output logprob based on next_token_ids - - Args: - logits_output: The logits output from the model forward - sampling_info: Sampling info for logprob calculation - top_logprobs_nums: Number of logprobs per request. - next_token_ids: Next token ids. - num_tokens_per_req: The number of tokens per request. - - Returns: - A list of next_token_ids - """ - self._preprocess_logits(logits_output, sampling_info) - # We should repeat top_logprobs_nums to match num_tokens_per_req. 
- top_logprobs_nums_repeat_interleaved = [] - token_ids_logprobs_repeat_interleaved = [] - for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req): - top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens) - for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req): - token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens) - self.sampler( - logits_output, - sampling_info, - True, - top_logprobs_nums_repeat_interleaved, - token_ids_logprobs_repeat_interleaved, - batch_next_token_ids=next_token_ids, - ) - def sample( self, logits_output: LogitsProcessorOutput, diff --git a/python/sglang/srt/sampling/penaltylib/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py index 69153462731..893a1c3775a 100644 --- a/python/sglang/srt/sampling/penaltylib/frequency_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py @@ -56,7 +56,6 @@ def _filter(self, keep_indices: torch.Tensor): ] def _merge(self, their: "BatchedFrequencyPenalizer"): - print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}") self.frequency_penalties = torch.cat( [self.frequency_penalties, their.frequency_penalties], dim=0 ) diff --git a/python/sglang/srt/sampling/penaltylib/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/presence_penalty.py index 91266b352fb..4f3a6ace3a0 100644 --- a/python/sglang/srt/sampling/penaltylib/presence_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/presence_penalty.py @@ -56,7 +56,6 @@ def _filter(self, keep_indices: torch.Tensor): ] def _merge(self, their: "BatchedPresencePenalizer"): - print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}") self.presence_penalties = torch.cat( [self.presence_penalties, their.presence_penalties], dim=0 ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 12da787eb31..bd2fa600915 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -7,6 +7,7 @@ from huggingface_hub import snapshot_download from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.model_executor.forward_batch_info import ( @@ -302,13 +303,10 @@ def draft_forward(self, forward_batch: ForwardBatch): # Set inputs forward_batch.input_ids = input_ids + out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1) forward_batch.out_cache_loc = out_cache_loc[ - forward_batch.batch_size - * self.topk - * i : forward_batch.batch_size - * self.topk - * (i + 1) - ] + :, self.topk * i : self.topk * (i + 1) + ].flatten() forward_batch.positions.add_(1) forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i] spec_info.hidden_states = hidden_states @@ -353,42 +351,70 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): batch.spec_info = res.draft_input if batch.return_logprob: - # Compute output logprobs using the sampler. - num_tokens_per_req = [ - accept + 1 for accept in res.accept_length_per_req_cpu - ] - self.target_worker.model_runner.update_output_logprobs( - logits_output, - batch.sampling_info, - batch.top_logprobs_nums, - batch.token_ids_logprobs, - res.verified_id, - # +1 for bonus token. - num_tokens_per_req=num_tokens_per_req, - ) - - # Add output logprobs to the request. 
- pt = 0 - # NOTE: tolist() of these values are skipped when output is processed - next_token_logprobs = res.logits_output.next_token_logprobs.tolist() - verified_ids = res.verified_id.tolist() - for req, num_tokens in zip(batch.reqs, num_tokens_per_req): - for _ in range(num_tokens): - if req.return_logprob: - token_id = verified_ids[pt] - req.output_token_logprobs_val.append(next_token_logprobs[pt]) - req.output_token_logprobs_idx.append(token_id) - if req.top_logprobs_num > 0: - req.output_top_logprobs_val.append( - res.logits_output.next_token_top_logprobs_val[pt] - ) - req.output_top_logprobs_idx.append( - res.logits_output.next_token_top_logprobs_idx[pt] - ) - pt += 1 + self.add_logprob_values(batch, res, logits_output) return logits_output, res, model_worker_batch + def add_logprob_values( + self, + batch: ScheduleBatch, + res: EagleVerifyOutput, + logits_output: LogitsProcessorOutput, + ): + # Extract args + logits_output = res.logits_output + top_logprobs_nums = batch.top_logprobs_nums + token_ids_logprobs = batch.token_ids_logprobs + logprobs = torch.nn.functional.log_softmax( + logits_output.next_token_logits, dim=-1 + ) + batch_next_token_ids = res.verified_id + num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu] + + # We should repeat top_logprobs_nums to match num_tokens_per_req. + top_logprobs_nums_repeat_interleaved = [] + token_ids_logprobs_repeat_interleaved = [] + for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req): + top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens) + for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req): + token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens) + + # Extract logprobs + if any(x > 0 for x in top_logprobs_nums): + ( + logits_output.next_token_top_logprobs_val, + logits_output.next_token_top_logprobs_idx, + ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved) + + if any(x is not None for x in token_ids_logprobs): + ( + logits_output.next_token_token_ids_logprobs_val, + logits_output.next_token_token_ids_logprobs_idx, + ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved) + + logits_output.next_token_logprobs = logprobs[ + torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device), + batch_next_token_ids, + ] + + # Add output logprobs to the request. 
+ pt = 0 + next_token_logprobs = logits_output.next_token_logprobs.tolist() + verified_ids = batch_next_token_ids.tolist() + for req, num_tokens in zip(batch.reqs, num_tokens_per_req): + for _ in range(num_tokens): + if req.return_logprob: + req.output_token_logprobs_val.append(next_token_logprobs[pt]) + req.output_token_logprobs_idx.append(verified_ids[pt]) + if req.top_logprobs_num > 0: + req.output_top_logprobs_val.append( + res.logits_output.next_token_top_logprobs_val[pt] + ) + req.output_top_logprobs_idx.append( + res.logits_output.next_token_top_logprobs_idx[pt] + ) + pt += 1 + def forward_draft_extend( self, batch: ScheduleBatch, diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index 92e828c26f4..29f7a12a2b4 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -76,7 +76,7 @@ def test_hellaswag_select(self): # Run twice to capture more bugs for _ in range(2): accuracy, latency = test_hellaswag_select() - self.assertGreater(accuracy, 0.65) + self.assertGreater(accuracy, 0.60) def test_gen_min_new_tokens(self): test_gen_min_new_tokens() diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index cadca667b37..5b89071b65f 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -123,7 +123,7 @@ def _test_eos_token(self, engine): def _test_acc_length(self, engine): prompt = [ "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:" - ] + ] * 5 sampling_params = {"temperature": 0, "max_new_tokens": 512} output = engine.generate(prompt, sampling_params) output = output[0] @@ -141,10 +141,14 @@ def _test_acc_length(self, engine): / output["meta_info"]["e2e_latency"] ) print(f"{acc_length=}") - self.assertGreater(acc_length, 3.6) + + if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST: + self.assertGreater(acc_length, 3.6) + else: + self.assertGreater(acc_length, 2.6) -class TestEAGLEEngineTokenMap(unittest.TestCase): +class TestEAGLEEngineTokenMap(TestEAGLEEngine): BASE_CONFIG = { "model_path": "meta-llama/Meta-Llama-3-8B-Instruct", "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B", @@ -155,6 +159,7 @@ class TestEAGLEEngineTokenMap(unittest.TestCase): "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt", "mem_fraction_static": 0.7, "cuda_graph_max_bs": 5, + "dtype": "float16", } NUM_CONFIGS = 1 @@ -245,8 +250,25 @@ def test_request_abort(self): for p in threads: p.join() + def test_max_token_one(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=1, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + + # Just run and check it does not hang + metrics = run_eval(args) + self.assertGreater(metrics["output_throughput"], 50) + def test_gsm8k(self): - server_info = requests.get(self.base_url + "/flush_cache") + requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( num_shots=5, @@ -391,6 +413,53 @@ def test_logprob_mixed(self): with ThreadPoolExecutor(8) as executor: list(executor.map(func, args)) + def run_decode(self, sampling_params): + return_logprob = True + top_logprobs_num = 5 + return_text = True + n = 1 + + response = requests.post( + self.base_url + "/generate", + json={ + "text": "Human: Write a travel blog post to Hawaii.\n\nAssistant:", + "sampling_params": { + "max_new_tokens": 48, + "n": n, + "temperature": 0.7, + **sampling_params, + }, + 
"return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, + "logprob_start_len": 0, + }, + ) + self.assertEqual(response.status_code, 200) + print(json.dumps(response.json())) + print("=" * 100) + + def test_penalty_mixed(self): + args = [ + {}, + {}, + {}, + {"frequency_penalty": 2}, + {"presence_penalty": 1}, + {"min_new_tokens": 16}, + {"frequency_penalty": 0.2}, + {"presence_penalty": 0.4}, + {"min_new_tokens": 8}, + {"frequency_penalty": 0.4, "presence_penalty": 0.8}, + {"frequency_penalty": 0.4, "min_new_tokens": 12}, + {"presence_penalty": 0.8, "min_new_tokens": 12}, + {"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32}, + {"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32}, + ] + random.shuffle(args * 5) + with ThreadPoolExecutor(8) as executor: + list(executor.map(self.run_decode, args)) + class TestEAGLERetract(TestEAGLEServer): @classmethod diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index dd923777f4b..f5e0e3cdbde 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -44,11 +44,12 @@ def test_mmlu(self): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.71) if is_in_ci(): write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n') + self.assertGreater(metrics["score"], 0.71) + def test_human_eval(self): args = SimpleNamespace( base_url=self.base_url, @@ -59,13 +60,14 @@ def test_human_eval(self): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.64) if is_in_ci(): write_github_step_summary( f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n' ) + self.assertGreater(metrics["score"], 0.64) + def test_mgsm_en(self): args = SimpleNamespace( base_url=self.base_url, @@ -76,13 +78,14 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.835) if is_in_ci(): write_github_step_summary( f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n' ) + self.assertGreater(metrics["score"], 0.835) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index a019988ab98..b2a831f99ff 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -1,6 +1,7 @@ import unittest from types import SimpleNamespace +import requests import torch from sglang.srt.utils import kill_process_tree @@ -129,6 +130,8 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + args = SimpleNamespace( num_shots=5, data_path=None, @@ -143,6 +146,11 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.5) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_penalty.py b/test/srt/test_penalty.py index cb9b6b3dc06..e1d11a9ac54 100644 --- a/test/srt/test_penalty.py +++ b/test/srt/test_penalty.py @@ -42,7 +42,7 @@ def run_decode(self, sampling_params): # prompt that is supposed to generate < 32 tokens "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", "sampling_params": { - "max_new_tokens": 32, + "max_new_tokens": 48, "n": n, **sampling_params, }, @@ -68,19 +68,22 @@ def 
test_min_new_tokens(self): def test_presence_penalty(self): self.run_decode({"presence_penalty": 2}) - def test_mixed(self): + def test_penalty_mixed(self): args = [ {}, {}, {}, {"frequency_penalty": 2}, - {"min_new_tokens": 16}, {"presence_penalty": 1}, + {"min_new_tokens": 16}, {"frequency_penalty": 0.2}, - {"min_new_tokens": 8}, {"presence_penalty": 0.4}, - {"presence_penalty": 0.4, "frequency_penalty": 2}, - {"min_new_tokens": 12, "frequency_penalty": 2}, + {"min_new_tokens": 8}, + {"frequency_penalty": 0.4, "presence_penalty": 0.8}, + {"frequency_penalty": 0.4, "min_new_tokens": 12}, + {"presence_penalty": 0.8, "min_new_tokens": 12}, + {"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32}, + {"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32}, ] random.shuffle(args * 5) with ThreadPoolExecutor(8) as executor: From b93ef5e56d5ea0a4ecf6f79eba422b70c33384f9 Mon Sep 17 00:00:00 2001 From: lukec <118525388+sleepcoo@users.noreply.github.com> Date: Sat, 8 Mar 2025 14:42:16 +0800 Subject: [PATCH 05/27] Remove the vllm dependency from the moe_align function (#4164) Co-authored-by: Hongbosherlock --- .../sgl-kernel/csrc/moe/moe_align_kernel.cu | 18 ++++++++++-------- sgl-kernel/tests/test_moe_align.py | 8 +++++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu index c5f37e55609..83609a3294b 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -47,18 +47,18 @@ __global__ void moe_align_block_size_kernel( int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t experts_per_warp, int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { - __shared__ int32_t shared_counts[WARP_SIZE][8]; + extern __shared__ int32_t shared_counts[]; const int warp_id = threadIdx.x / WARP_SIZE; - const int experts_per_warp = 8; const int my_expert_start = warp_id * experts_per_warp; for (int i = 0; i < experts_per_warp; ++i) { if (my_expert_start + i < num_experts) { - shared_counts[warp_id][i] = 0; + shared_counts[warp_id * experts_per_warp + i] = 0; } } @@ -71,7 +71,7 @@ __global__ void moe_align_block_size_kernel( int expert_id = topk_ids[i]; int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx][expert_offset], 1); + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); } __syncthreads(); @@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel( int expert_count = 0; int warp_idx = (i - 1) / experts_per_warp; int expert_offset = (i - 1) % experts_per_warp; - expert_count = shared_counts[warp_idx][expert_offset]; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; } @@ -108,16 +108,18 @@ void moe_align_block_size( torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); - + TORCH_CHECK(num_experts % WARP_SIZE == 0); + int experts_per_warp = num_experts / WARP_SIZE; DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto align_kernel = moe_align_block_size_kernel; - align_kernel<<<1, 1024, 0, stream>>>( + size_t 
shared_mem_size = 32 * experts_per_warp * sizeof(int32_t); + align_kernel<<<1, 1024, shared_mem_size, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, + experts_per_warp, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 81d05ffa1fe..3d89c3406cc 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -138,18 +138,20 @@ def moe_align_block_size_triton( @pytest.mark.parametrize( - "block_size,num_tokens,topk", + "block_size,num_tokens,topk,num_experts", list( itertools.product( [32, 64, 128, 256], # block_size [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens [1, 2, 4, 8, 16, 32, 64], # topk + [64, 160, 256], # num_experts ) ), ) -def test_moe_align_block_size_compare_implementations(block_size, num_tokens, topk): +def test_moe_align_block_size_compare_implementations( + block_size, num_tokens, topk, num_experts +): # For DeepSeek V3, we have 256 experts - num_experts = 256 topk_ids = torch.stack( [ From 90bb2be27e498be472af40f5ace8b2d9cd817d1d Mon Sep 17 00:00:00 2001 From: Rex Date: Fri, 7 Mar 2025 22:52:12 -0800 Subject: [PATCH 06/27] Minor improvement to per_tensor_quant_fp8 (#4197) --- .../sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index ea222c00150..f1f7d14a92b 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -57,13 +57,9 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output template __global__ void per_tensor_quant_fp8_kernel( - const T* __restrict__ input, - FP8_TYPE* __restrict__ output, - const float* __restrict__ scale, - const int64_t num_elements) { + const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) { const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int grid_size = blockDim.x * gridDim.x; - const float scale_val = 1.0f / (*scale); constexpr uint32_t vec_size = 16 / sizeof(T); using vec_t = flashinfer::vec_t; @@ -125,12 +121,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch per_tensor_absmax_kernel<<>>( static_cast(input.data_ptr()), static_cast(output_s.data_ptr()), num_elements); } - + float scale_val = 1.0f / (*static_cast(output_s.data_ptr())); per_tensor_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), - static_cast(output_q.data_ptr()), - static_cast(output_s.data_ptr()), - num_elements); + static_cast(input.data_ptr()), static_cast(output_q.data_ptr()), scale_val, num_elements); return true; }); } From 96d0e37fa7621c37a130ec12f867c8f99c9ef878 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 7 Mar 2025 22:57:09 -0800 Subject: [PATCH 07/27] Revert "Minor improvement to per_tensor_quant_fp8 (#4197)" (#4198) --- .../sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index f1f7d14a92b..ea222c00150 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ 
b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -57,9 +57,13 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output template __global__ void per_tensor_quant_fp8_kernel( - const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) { + const T* __restrict__ input, + FP8_TYPE* __restrict__ output, + const float* __restrict__ scale, + const int64_t num_elements) { const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int grid_size = blockDim.x * gridDim.x; + const float scale_val = 1.0f / (*scale); constexpr uint32_t vec_size = 16 / sizeof(T); using vec_t = flashinfer::vec_t; @@ -121,9 +125,12 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch per_tensor_absmax_kernel<<>>( static_cast(input.data_ptr()), static_cast(output_s.data_ptr()), num_elements); } - float scale_val = 1.0f / (*static_cast(output_s.data_ptr())); + per_tensor_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), static_cast(output_q.data_ptr()), scale_val, num_elements); + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + num_elements); return true; }); } From 08c4d764a51e795515d49a2d8aaabdee1ba66ab7 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 8 Mar 2025 00:41:35 -0800 Subject: [PATCH 08/27] lazy import attn backends (#4200) --- .../srt/layers/attention/triton_backend.py | 4 +--- .../srt/model_executor/cuda_graph_runner.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 24 ++++++++++++++----- test/srt/test_eagle_infer.py | 2 +- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index a9d72618085..b942dee5cf5 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -6,9 +6,7 @@ import triton from sglang.srt.layers.attention.base_attn_backend import AttentionBackend -from sglang.srt.layers.attention.flashinfer_backend import ( - create_flashinfer_kv_indices_triton, -) +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 6a2bab22a46..83c2d88f01f 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -302,7 +302,7 @@ def capture(self): self.stream = graph_capture_context.stream # Reverse the order to enable better memory sharing across cuda graphs. 
capture_range = ( - tqdm.tqdm(reversed(self.capture_bs)) + tqdm.tqdm(list(reversed(self.capture_bs))) if get_tensor_model_parallel_rank() == 0 else reversed(self.capture_bs) ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6489ea6eddf..8040709a721 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -35,11 +35,6 @@ set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state -from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend -from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend -from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend -from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.dp_attention import ( get_attention_tp_group, get_attention_tp_size, @@ -77,7 +72,6 @@ set_cpu_offload_max_bytes, set_cuda_arch, ) -from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -779,6 +773,10 @@ def init_cublas(self): def init_attention_backend(self): """Init attention kernel backend.""" if self.server_args.attention_backend == "flashinfer": + from sglang.srt.layers.attention.flashinfer_backend import ( + FlashInferAttnBackend, + ) + # Init streams if self.server_args.speculative_algorithm == "EAGLE": self.plan_stream_for_flashinfer = torch.cuda.Stream() @@ -794,12 +792,26 @@ def init_attention_backend(self): "Please use `--attention-backend flashinfer`." ) if self.server_args.enable_double_sparsity: + from sglang.srt.layers.attention.double_sparsity_backend import ( + DoubleSparseAttnBackend, + ) + self.attn_backend = DoubleSparseAttnBackend(self) else: + from sglang.srt.layers.attention.triton_backend import TritonAttnBackend + self.attn_backend = TritonAttnBackend(self) elif self.server_args.attention_backend == "torch_native": + from sglang.srt.layers.attention.torch_native_backend import ( + TorchNativeAttnBackend, + ) + self.attn_backend = TorchNativeAttnBackend(self) elif self.server_args.attention_backend == "flashinfer_mla": + from sglang.srt.layers.attention.flashinfer_mla_backend import ( + FlashInferMLAAttnBackend, + ) + self.attn_backend = FlashInferMLAAttnBackend(self) else: raise ValueError( diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 5b89071b65f..a87b6e37b89 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -108,7 +108,7 @@ def _test_batch_generation(self, engine): def _test_eos_token(self, engine): prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]" params = { - "temperature": 0, + "temperature": 0.1, "max_new_tokens": 1024, "skip_special_tokens": False, } From 0fe7c13be18d1edd8682747ce558b430a1aa1c9e Mon Sep 17 00:00:00 2001 From: Mingshan Date: Sat, 8 Mar 2025 17:03:38 +0800 Subject: [PATCH 09/27] Fix bench_serving flush cache not recognizing OPENAI_API_KEY (#4181) Signed-off-by: Mingshan --- python/sglang/bench_serving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index db22f42c46c..036d1e86499 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -1006,7 +1006,7 @@ async def limited_request_func(request_func_input, 
pbar): # Flush cache if "sglang" in backend: - requests.post(base_url + "/flush_cache") + requests.post(base_url + "/flush_cache", headers=get_auth_headers()) time.sleep(1.0) From 8d323e95e4406d5663725b177571757c1d402e1e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 8 Mar 2025 01:28:10 -0800 Subject: [PATCH 10/27] Use clang format 18 in pr-test-sgl-kernel.yml (#4203) --- .github/workflows/pr-test-sgl-kernel.yml | 2 +- sgl-kernel/src/sgl-kernel/include/utils.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index df059c1f402..8c8d5ce97c7 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -27,7 +27,7 @@ jobs: with: source: sgl-kernel extensions: h,c,cpp,hpp,cu,cuh,cc - clangFormatVersion: 16 + clangFormatVersion: 18 style: file build-wheels: diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index b2960954bcb..79bf84671d0 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -94,7 +94,7 @@ inline int getSMVersion() { #define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) -#define CEILDIV(x, y) (((x) + (y)-1) / (y)) +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) #define WARP_SIZE 32 #ifndef USE_ROCM From 4a893d142ded28c20178faeaaee9f2bf0caa24a7 Mon Sep 17 00:00:00 2001 From: Kebe Date: Sat, 8 Mar 2025 19:01:13 +0800 Subject: [PATCH 11/27] Refactor Dockerfile: unify CUDA logic and reduce image size by ~2.6 GB (#3749) Signed-off-by: Kebe --- docker/Dockerfile | 39 +++++++-------------------------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 3ae74a8cccb..075b1e8d92c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -30,44 +30,19 @@ ARG CUDA_VERSION RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ && git clone --depth=1 https://github.com/sgl-project/sglang.git \ && if [ "$CUDA_VERSION" = "12.1.1" ]; then \ - python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \ + export CUINDEX=121; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ - python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + export CUINDEX=124; \ elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ - python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + export CUINDEX=124; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ - python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \ - python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ + export CUINDEX=118; \ + python3 -m pip install --no-cache-dir sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi \ + && python3 -m pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu${CUINDEX} \ && cd sglang \ - && if [ "$BUILD_TYPE" = "srt" ]; then \ - if [ "$CUDA_VERSION" = "12.1.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \ - elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ - elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ - python3 -m 
pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ - elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \ - python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ - else \ - echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ - fi; \ - else \ - if [ "$CUDA_VERSION" = "12.1.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \ - elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ - elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ - elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \ - python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ - else \ - echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ - fi; \ - fi + && python3 -m pip --no-cache-dir install -e "python[${BUILD_TYPE}]" --find-links https://flashinfer.ai/whl/cu${CUINDEX}/torch2.5/flashinfer-python ENV DEBIAN_FRONTEND=interactive From 2cadd51d11a7fddf7c15833f6fca617428af7ef2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 8 Mar 2025 05:23:06 -0800 Subject: [PATCH 12/27] Test no vllm custom allreduce (#4210) --- .github/workflows/pr-test.yml | 2 ++ test/srt/test_bench_one_batch.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 225c215c8c9..5ac06597327 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -269,6 +269,8 @@ jobs: cd test/srt python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1 + USE_VLLM_CUSTOM_ALLREDUCE=0 python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1 + - name: Benchmark single latency + torch.compile (TP=2) timeout-minutes: 10 run: | diff --git a/test/srt/test_bench_one_batch.py b/test/srt/test_bench_one_batch.py index 1d50b574744..f4140b89fce 100644 --- a/test/srt/test_bench_one_batch.py +++ b/test/srt/test_bench_one_batch.py @@ -11,7 +11,9 @@ class TestBenchOneBatch(unittest.TestCase): def test_bs1(self): - output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, []) + output_throughput = run_bench_one_batch( + DEFAULT_MODEL_NAME_FOR_TEST, ["--cuda-graph-max-bs", "2"] + ) if is_in_ci(): write_github_step_summary( @@ -22,7 +24,7 @@ def test_bs1(self): def test_moe_tp2_bs1(self): output_throughput = run_bench_one_batch( - DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"] + DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2", "--cuda-graph-max-bs", "2"] ) if is_in_ci(): From b3251e9f40b85159d52563b9ca8276fa0fa03703 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Sat, 8 Mar 2025 21:47:35 +0800 Subject: [PATCH 13/27] refine quant kernel code style (#4211) --- .../sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 13 +------------ .../sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 14 +------------- sgl-kernel/src/sgl-kernel/include/utils.h | 16 ++++++++++++++++ 3 files changed, 18 insertions(+), 25 
deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index ea222c00150..a95d5ea720a 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output max_value = fmaxf(max_value, fabsf(val)); } - static __shared__ float warpLevelMaxs[WARP_SIZE]; - const int laneId = threadIdx.x % WARP_SIZE; - const int warpId = threadIdx.x / WARP_SIZE; - - max_value = warpReduceMax(max_value); - - if (laneId == 0) warpLevelMaxs[warpId] = max_value; - __syncthreads(); - - max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0; - - if (warpId == 0) max_value = warpReduceMax(max_value); + max_value = blockReduceMax(max_value); if (tid == 0) { atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX); diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index 1491af126ef..12616ff441f 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel( max_value = fmaxf(max_value, fabsf(val)); } - max_value = warpReduceMax(max_value); - - static __shared__ float warpLevelMaxs[WARP_SIZE]; - const int laneId = threadIdx.x % WARP_SIZE; - const int warpId = threadIdx.x / WARP_SIZE; - - if (laneId == 0) warpLevelMaxs[warpId] = max_value; - __syncthreads(); - - if (warpId == 0) { - max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0; - max_value = warpReduceMax(max_value); - } + max_value = blockReduceMax(max_value); __shared__ float block_max; if (tid == 0) { diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index 79bf84671d0..c099bf5aa22 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) { max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); return max_value; } + +__device__ __forceinline__ float blockReduceMax(float max_value) { + static __shared__ float warpLevelMaxs[WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + + max_value = warpReduceMax(max_value); + + if (laneId == 0) warpLevelMaxs[warpId] = max_value; + __syncthreads(); + + max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? 
warpLevelMaxs[laneId] : 0; + if (warpId == 0) max_value = warpReduceMax(max_value); + + return max_value; +} #endif From 48473684cc3e3d080fca85b089375700788f2d7a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 8 Mar 2025 15:40:49 -0800 Subject: [PATCH 14/27] Split test_mla.py into two files (#4216) --- .github/workflows/pr-test-amd.yml | 4 +- test/srt/run_suite.py | 1 + test/srt/test_mla.py | 100 -------------------------- test/srt/test_mla_deepseek_v3.py | 113 ++++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 102 deletions(-) create mode 100644 test/srt/test_mla_deepseek_v3.py diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 507590025e2..03406ef86ff 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -90,11 +90,11 @@ jobs: - name: MLA TEST timeout-minutes: 20 run: | - docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py TestMLA + docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py finish: needs: [ - accuracy-test-1-gpu-amd + accuracy-test-1-gpu-amd, mla-test-1-gpu-amd ] runs-on: ubuntu-latest steps: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 754fe9a79fa..ebab2bf68b7 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -24,6 +24,7 @@ "test_gguf.py", "test_input_embeddings.py", "test_mla.py", + "test_mla_deepseek_v3.py", "test_mla_flashinfer.py", "test_mla_fp8.py", "test_json_constrained.py", diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index b2a831f99ff..b1f9d090d47 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -1,11 +1,7 @@ import unittest from types import SimpleNamespace -import requests -import torch - from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -56,101 +52,5 @@ def test_mgsm_en(self): self.assertGreater(metrics["score"], 0.8) -class TestDeepseekV3(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = "lmsys/sglang-ci-dsv3-test" - cls.base_url = DEFAULT_URL_FOR_TEST - other_args = ["--trust-remote-code"] - if torch.cuda.is_available() and torch.version.cuda: - other_args.extend(["--enable-torch-compile", "--cuda-graph-max-bs", "2"]) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.62) - - -class TestDeepseekV3MTP(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = "lmsys/sglang-ci-dsv3-test" - cls.base_url = DEFAULT_URL_FOR_TEST - other_args = ["--trust-remote-code"] - if torch.cuda.is_available() and torch.version.cuda: - other_args.extend( - [ - "--cuda-graph-max-bs", - "2", - "--disable-radix", - "--enable-torch-compile", - "--torch-compile-max-bs", - "1", - "--speculative-algorithm", - "EAGLE", - "--speculative-draft", - "lmsys/sglang-ci-dsv3-test-NextN", - "--speculative-num-steps", - "2", - "--speculative-eagle-topk", - "4", - "--speculative-num-draft-tokens", - "4", 
- ] - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - requests.get(self.base_url + "/flush_cache") - - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 2.5) - - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py new file mode 100644 index 00000000000..ba43c2ba14b --- /dev/null +++ b/test/srt/test_mla_deepseek_v3.py @@ -0,0 +1,113 @@ +import unittest +from types import SimpleNamespace + +import requests +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestDeepseekV3(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend(["--enable-torch-compile", "--cuda-graph-max-bs", "2"]) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + +class TestDeepseekV3MTP(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--disable-radix", + "--enable-torch-compile", + "--torch-compile-max-bs", + "1", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + "lmsys/sglang-ci-dsv3-test-NextN", + "--speculative-num-steps", + "2", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "4", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = 
requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.5) + + +if __name__ == "__main__": + unittest.main() From 6eec3cdce6a48d09c58872d5bd5569a8434252d8 Mon Sep 17 00:00:00 2001 From: Xihuai Wang Date: Sun, 9 Mar 2025 12:14:50 +0800 Subject: [PATCH 15/27] docs(reasoning content): :memo: deepseek-r1 parser support qwq (#4124) --- docs/backend/separate_reasoning.ipynb | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/backend/separate_reasoning.ipynb b/docs/backend/separate_reasoning.ipynb index d9a927c19de..756ecbaa995 100644 --- a/docs/backend/separate_reasoning.ipynb +++ b/docs/backend/separate_reasoning.ipynb @@ -11,7 +11,8 @@ "## Supported Models\n", "\n", "Currently, SGLang supports the following reasoning models:\n", - "- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `` and `` tags." + "- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `` and `` tags.\n", + "- [QwQ](https://huggingface.co/Qwen/QwQ-32B): The reasoning content is wrapped with `` and `` tags." ] }, { @@ -55,6 +56,15 @@ "wait_for_server(f\"http://localhost:{port}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that `--reasoning-parser` defines the parser used to interpret responses. Currently supported parsers include:\n", + "\n", + "- deepseek-r1: DeepSeek R1 series and QwQ (e.g. deepseek-ai/DeepSeek-R1, Qwen/QwQ-32B)." + ] + }, { "cell_type": "markdown", "metadata": {}, From 79a321af55aafc14d660f9cf1ace43b8f5edc8d3 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Sun, 9 Mar 2025 13:15:14 +0800 Subject: [PATCH 16/27] revert pr 3628 to pass test_mla ci (#4219) --- .../csrc/gemm/per_token_group_quant_fp8.cu | 76 ++++++++----------- 1 file changed, 30 insertions(+), 46 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu index bb3135dad23..3ad43e7c601 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu @@ -2,18 +2,17 @@ #include #include -#include #include "utils.h" using FP8_TYPE = c10::Float8_e4m3fn; -__device__ __forceinline__ float GroupReduce(float val, const int tid) { - val = fmaxf(val, __shfl_xor_sync(0xffff, val, 8)); - val = fmaxf(val, __shfl_xor_sync(0xffff, val, 4)); - val = fmaxf(val, __shfl_xor_sync(0xffff, val, 2)); - val = fmaxf(val, __shfl_xor_sync(0xffff, val, 1)); - return val; +__device__ __forceinline__ float GroupReduceMax(volatile float* smem, const int tid) { + smem[tid] = fmaxf(smem[tid], smem[tid + 8]); + if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]); + if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]); + if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]); + return smem[0]; } template @@ -27,60 +26,45 @@ __global__ void per_token_group_quant_fp8_kernel( const float fp8_min, const float fp8_max) { const int groups_per_block = 16; - const int local_group_id = threadIdx.x / 16; - const int lane_id = threadIdx.x % 16; - const int block_group_id = blockIdx.x * groups_per_block; - const int block_group_offset = (block_group_id + 
local_group_id) * group_size; + const int tid = threadIdx.x; + const int local_group_id = tid / 16; + const int local_tid = tid % 16; - __shared__ float s_absmax[16]; + __shared__ float s_absmax[16][17]; float local_absmax = eps; - const T* group_input = input + block_group_offset; - FP8_TYPE* group_output = static_cast(output_q) + block_group_offset; - float* scale_output = output_s + block_group_id + local_group_id; - - constexpr uint32_t vec_size = 16 / sizeof(T); - using vec_t = flashinfer::vec_t; - - const int32_t num_vec_elems = group_size / vec_size; + if (block_group_id + local_group_id < num_groups) { + const T* group_input = input + (block_group_id + local_group_id) * group_size; + FP8_TYPE* group_output = static_cast(output_q) + (block_group_id + local_group_id) * group_size; + float* scale_output = output_s + block_group_id + local_group_id; - for (int32_t i = lane_id; i < num_vec_elems; i += 16) { - vec_t input_vec; - input_vec.cast_load(group_input + i * vec_size); - -#pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - float val = static_cast(input_vec[j]); + for (int i = local_tid; i < group_size; i += 16) { + float val = static_cast(group_input[i]); float abs_val = fabsf(val); local_absmax = fmaxf(local_absmax, abs_val); } - } - local_absmax = GroupReduce(local_absmax, lane_id); + s_absmax[local_group_id][local_tid] = local_absmax; + __syncthreads(); - if (lane_id == 0) { - s_absmax[local_group_id] = local_absmax; - } - __syncthreads(); + if (local_tid < 8) { + GroupReduceMax(&s_absmax[local_group_id][0], local_tid); + } + __syncthreads(); - const float group_absmax = s_absmax[local_group_id]; - const float y_s = group_absmax / fp8_max; + const float group_absmax = s_absmax[local_group_id][0]; + const float y_s = group_absmax / fp8_max; - if (lane_id == 0) { - *scale_output = y_s; - } - - for (int32_t i = lane_id; i < num_vec_elems; i += 16) { - vec_t input_vec; - input_vec.cast_load(group_input + i * vec_size); + if (local_tid == 0) { + *scale_output = y_s; + } -#pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - float val = static_cast(input_vec[j]); + for (int i = local_tid; i < group_size; i += 16) { + float val = static_cast(group_input[i]); float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max); - group_output[i * vec_size + j] = FP8_TYPE(q_val); + group_output[i] = FP8_TYPE(q_val); } } } From ee132a45155149269aad823a974ff0cc51b54faf Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 8 Mar 2025 22:27:47 -0800 Subject: [PATCH 17/27] use latest sgl-kernel for mla test (#4222) --- .github/workflows/pr-test-sgl-kernel.yml | 33 +++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 8c8d5ce97c7..0c38901f05a 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -95,8 +95,39 @@ jobs: run: | pip3 uninstall sgl-kernel -y + mla-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + needs: build-wheels + runs-on: 1-gpu-runner + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Install + run: | + bash scripts/ci_install_dependency.sh + pip3 uninstall sgl-kernel -y || true + pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps + pip3 list | grep sgl-kernel + + - name: Run test + timeout-minutes: 30 + 
run: | + cd test/srt + python3 test_mla_deepseek_v3.py + + - name: Uninstall dependencies + run: | + pip3 uninstall sgl-kernel -y + finish: - needs: [unit-test, lint] + needs: [unit-test, mla-test, lint] runs-on: ubuntu-latest steps: - name: Finish From 8abf74e3c9353c2c33c83d156d5a69acf6274b72 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 8 Mar 2025 22:54:51 -0800 Subject: [PATCH 18/27] Rename files in sgl kernel to avoid nested folder structure (#4213) Co-authored-by: zhyncs --- .github/workflows/release-pypi-kernel.yml | 2 +- .github/workflows/release-whl-kernel.yml | 4 +- python/sglang/srt/_custom_ops.py | 30 +++++----- sgl-kernel/Makefile | 10 ++-- sgl-kernel/README.md | 13 ++-- .../csrc/allreduce/custom_all_reduce.hip | 0 .../csrc/allreduce/custom_all_reduce_hip.cuh | 0 .../csrc/allreduce/trt_reduce_internal.cu | 0 .../csrc/allreduce/trt_reduce_kernel.cu | 0 .../lightning_attention_decode_kernel.cu | 0 .../epilogue/epilogue_per_row_per_col_scale.h | 0 .../gemm/collective/collective_builder.hpp | 0 ..._warpspecialized_fp8_blockwise_scaling.hpp | 0 .../gemm/dispatch_policy.hpp | 0 .../gemm/gemm_universal_base_compat.h | 0 .../gemm/gemm_with_epilogue_visitor.h | 0 .../elementwise}/fused_add_rms_norm_kernel.cu | 0 .../csrc/gemm/cublas_grouped_gemm.cu | 0 .../csrc/gemm/fp8_blockwise_gemm_kernel.cu | 0 .../csrc/gemm/fp8_gemm_kernel.cu | 0 .../csrc/gemm/int8_gemm_kernel.cu | 0 .../csrc/gemm/per_tensor_quant_fp8.cu | 0 .../csrc/gemm/per_token_group_quant_fp8.cu | 0 .../csrc/gemm/per_token_quant_fp8.cu | 0 .../csrc/moe/moe_align_kernel.cu | 0 .../csrc/speculative/eagle_utils.cu | 0 .../csrc/speculative/speculative_sampling.cu | 0 .../csrc/speculative/speculative_sampling.cuh | 0 .../sgl-kernel => csrc}/torch_extension.cc | 60 +++++++++---------- .../torch_extension_rocm.cc | 4 +- .../sgl_kernel_ops.h} | 52 ++++++++-------- .../include/trt_reduce_internal.cuh | 0 .../{src/sgl-kernel => }/include/utils.h | 0 sgl-kernel/pyproject.toml | 4 -- .../sgl_kernel}/__init__.py | 15 ++--- .../ops => python/sgl_kernel}/allreduce.py | 31 +++++----- .../ops => python/sgl_kernel}/attention.py | 3 +- .../sgl_kernel/elementwise.py} | 19 +++--- .../ops => python/sgl_kernel}/gemm.py | 19 +++--- .../ops => python/sgl_kernel}/moe.py | 3 +- .../ops => python/sgl_kernel}/sampling.py | 13 ++-- .../ops => python/sgl_kernel}/speculative.py | 9 ++- .../ops => python/sgl_kernel}/utils.py | 0 .../sgl_kernel}/version.py | 0 sgl-kernel/setup.py | 42 ++++++------- sgl-kernel/setup_rocm.py | 48 +++++++-------- sgl-kernel/tests/test_trt_allreduce.py | 2 +- 47 files changed, 184 insertions(+), 199 deletions(-) rename sgl-kernel/{src/sgl-kernel => }/csrc/allreduce/custom_all_reduce.hip (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/allreduce/custom_all_reduce_hip.cuh (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/allreduce/trt_reduce_internal.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/allreduce/trt_reduce_kernel.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/attention/lightning_attention_decode_kernel.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/cutlass_extensions/gemm/dispatch_policy.hpp (100%) rename 
sgl-kernel/{src/sgl-kernel => }/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h (100%) rename sgl-kernel/{src/sgl-kernel/csrc/activation => csrc/elementwise}/fused_add_rms_norm_kernel.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/gemm/cublas_grouped_gemm.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/gemm/fp8_blockwise_gemm_kernel.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/gemm/fp8_gemm_kernel.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/gemm/int8_gemm_kernel.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/gemm/per_tensor_quant_fp8.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/gemm/per_token_group_quant_fp8.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/gemm/per_token_quant_fp8.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/moe/moe_align_kernel.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/speculative/eagle_utils.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/speculative/speculative_sampling.cu (100%) rename sgl-kernel/{src/sgl-kernel => }/csrc/speculative/speculative_sampling.cuh (100%) rename sgl-kernel/{src/sgl-kernel => csrc}/torch_extension.cc (98%) rename sgl-kernel/{src/sgl-kernel => csrc}/torch_extension_rocm.cc (97%) rename sgl-kernel/{src/sgl-kernel/include/sgl_kernels_ops.h => include/sgl_kernel_ops.h} (99%) rename sgl-kernel/{src/sgl-kernel => }/include/trt_reduce_internal.cuh (100%) rename sgl-kernel/{src/sgl-kernel => }/include/utils.h (100%) rename sgl-kernel/{src/sgl-kernel => python/sgl_kernel}/__init__.py (74%) rename sgl-kernel/{src/sgl-kernel/ops => python/sgl_kernel}/allreduce.py (62%) rename sgl-kernel/{src/sgl-kernel/ops => python/sgl_kernel}/attention.py (62%) rename sgl-kernel/{src/sgl-kernel/ops/activation.py => python/sgl_kernel/elementwise.py} (87%) rename sgl-kernel/{src/sgl-kernel/ops => python/sgl_kernel}/gemm.py (81%) rename sgl-kernel/{src/sgl-kernel/ops => python/sgl_kernel}/moe.py (83%) rename sgl-kernel/{src/sgl-kernel/ops => python/sgl_kernel}/sampling.py (94%) rename sgl-kernel/{src/sgl-kernel/ops => python/sgl_kernel}/speculative.py (88%) rename sgl-kernel/{src/sgl-kernel/ops => python/sgl_kernel}/utils.py (100%) rename sgl-kernel/{src/sgl-kernel => python/sgl_kernel}/version.py (100%) diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml index 495bf68c8b2..f589119e61a 100644 --- a/.github/workflows/release-pypi-kernel.yml +++ b/.github/workflows/release-pypi-kernel.yml @@ -5,7 +5,7 @@ on: branches: - main paths: - - sgl-kernel/src/sgl-kernel/version.py + - sgl-kernel/python/sgl_kernel/version.py workflow_dispatch: concurrency: diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml index 5eaa0127fa7..631551475fe 100644 --- a/.github/workflows/release-whl-kernel.yml +++ b/.github/workflows/release-whl-kernel.yml @@ -9,7 +9,7 @@ on: branches: - main paths: - - sgl-kernel/src/sgl-kernel/version.py + - sgl-kernel/python/sgl_kernel/version.py jobs: build-wheels: @@ -59,7 +59,7 @@ jobs: id: set_tag_name run: | if [ -z "${{ inputs.tag_name }}" ]; then - TAG_NAME="v$(cat sgl-kernel/src/sgl-kernel/version.py | cut -d'"' -f2)" + TAG_NAME="v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'"' -f2)" echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT else echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py 
index c5056ffc272..d06765c3a8c 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -75,42 +75,42 @@ def init_custom_ar( rank: int, full_nvlink: bool, ) -> int: - return sgl_kernel.ops.allreduce.init_custom_ar( + return sgl_kernel.allreduce.init_custom_ar( meta, rank_data, handles, offsets, rank, full_nvlink ) def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - sgl_kernel.ops.allreduce.all_reduce_reg(fa, inp, out) + sgl_kernel.allreduce.all_reduce_reg(fa, inp, out) def all_reduce_unreg( fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor ) -> None: - sgl_kernel.ops.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out) + sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out) def dispose(fa: int) -> None: - sgl_kernel.ops.allreduce.dispose(fa) + sgl_kernel.allreduce.dispose(fa) def meta_size() -> int: - return sgl_kernel.ops.allreduce.meta_size() + return sgl_kernel.allreduce.meta_size() def register_buffer( fa: int, t: torch.Tensor, handles: List[str], offsets: List[int] ) -> None: - return sgl_kernel.ops.allreduce.register_buffer(fa, t, handles, offsets) + return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets) def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]: - return sgl_kernel.ops.allreduce.get_graph_buffer_ipc_meta(fa) + return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa) def register_graph_buffers( fa: int, handles: List[str], offsets: List[List[int]] ) -> None: - sgl_kernel.ops.allreduce.register_graph_buffers(fa, handles, offsets) + sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets) def allocate_meta_buffer(size: int) -> torch.Tensor: - return sgl_kernel.ops.allreduce.allocate_meta_buffer(size) + return sgl_kernel.allreduce.allocate_meta_buffer(size) def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: - return sgl_kernel.ops.allreduce.get_meta_buffer_ipc_handle(inp) + return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp) else: # TRTLLM custom allreduce @@ -123,7 +123,7 @@ def init_custom_ar( barrier_in: List[int], barrier_out: List[int], ) -> int: - return sgl_kernel.ops.init_custom_reduce( + return sgl_kernel.init_custom_reduce( rank_id, world_size, rank_data_base, @@ -134,15 +134,15 @@ def init_custom_ar( ) def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - sgl_kernel.ops.custom_reduce(fa, inp, out) + sgl_kernel.custom_reduce(fa, inp, out) def dispose(fa: int) -> None: - sgl_kernel.ops.custom_dispose(fa) + sgl_kernel.custom_dispose(fa) def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: - return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) + return sgl_kernel.get_graph_buffer_ipc_meta(fa) def register_graph_buffers( fa: int, handles: List[List[int]], offsets: List[List[int]] ) -> None: - sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) + sgl_kernel.register_graph_buffers(fa, handles, offsets) diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index 986e424f403..53375fa0fa3 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -38,12 +38,12 @@ test: ## Run all tests format: check-deps ## Format all source files @echo "Formatting source files..." 
- @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i - @find src tests -name '*.py' | xargs isort - @find src tests -name '*.py' | xargs black + @find csrc tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i + @find python tests -name '*.py' | xargs isort + @find python tests -name '*.py' | xargs black @pre-commit run --all-files -FILES_TO_UPDATE = src/sgl-kernel/version.py \ +FILES_TO_UPDATE = python/sgl_kernel/version.py \ pyproject.toml update: ## Update version numbers across project files. Usage: make update @@ -51,7 +51,7 @@ update: ## Update version numbers across project files. Usage: make update "; \ exit 1; \ fi - @OLD_VERSION=$$(grep "version" src/sgl-kernel/version.py | cut -d '"' -f2); \ + @OLD_VERSION=$$(grep "version" python/sgl_kernel/version.py | cut -d '"' -f2); \ NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \ echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \ for file in $(FILES_TO_UPDATE); do \ diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 1f805cbd000..689f34be0b5 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -45,12 +45,11 @@ Third-party libraries: Steps to add a new kernel: -1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) -2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h) -3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc) -4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) -5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) -6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source +1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc) +2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h) +3. Create torch extension in [csrc/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/torch_extension.cc) +4. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source +5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel) ### Build & Install @@ -72,4 +71,4 @@ The `sgl-kernel` is rapidly evolving. 
If you experience a compilation failure, t ### Release new version -Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py) +Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py) diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip b/sgl-kernel/csrc/allreduce/custom_all_reduce.hip similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip rename to sgl-kernel/csrc/allreduce/custom_all_reduce.hip diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh b/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh rename to sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu b/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu rename to sgl-kernel/csrc/allreduce/trt_reduce_internal.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu b/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu rename to sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu b/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu rename to sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h rename to sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp rename to sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp rename to sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp similarity index 100% rename from 
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp rename to sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h rename to sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h rename to sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu b/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu rename to sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu b/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu rename to sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu rename to sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu rename to sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu rename to sgl-kernel/csrc/gemm/int8_gemm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu rename to sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu rename to sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu rename to sgl-kernel/csrc/gemm/per_token_quant_fp8.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu rename to sgl-kernel/csrc/moe/moe_align_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/csrc/speculative/eagle_utils.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu rename to 
sgl-kernel/csrc/speculative/eagle_utils.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu b/sgl-kernel/csrc/speculative/speculative_sampling.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu rename to sgl-kernel/csrc/speculative/speculative_sampling.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/csrc/speculative/speculative_sampling.cuh similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh rename to sgl-kernel/csrc/speculative/speculative_sampling.cuh diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc similarity index 98% rename from sgl-kernel/src/sgl-kernel/torch_extension.cc rename to sgl-kernel/csrc/torch_extension.cc index a8ee8770787..9fd32bf99fe 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -16,33 +16,9 @@ limitations under the License. #include #include -#include "sgl_kernels_ops.h" - -TORCH_LIBRARY_EXPAND(sgl_kernels, m) { - /* - * From csrc/activation - */ - m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("rmsnorm", torch::kCUDA, &rmsnorm); - - m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()"); - m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); - - m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); - - m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); - - m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - - m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - - m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); +#include "sgl_kernel_ops.h" +TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From csrc/allreduce */ @@ -67,6 +43,30 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { */ m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + /* + * From csrc/elementwise + */ + m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("rmsnorm", torch::kCUDA, &rmsnorm); + + m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()"); + m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); + + m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); + + m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); + + m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + m.def("gelu_and_mul(Tensor! 
out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + /* * From csrc/gemm */ @@ -93,6 +93,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()"); m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8); + m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()"); + m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8); + m.def( "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs," " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()"); @@ -171,9 +174,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); - - m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()"); - m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8); } -REGISTER_EXTENSION(_kernels) +REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc similarity index 97% rename from sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc rename to sgl-kernel/csrc/torch_extension_rocm.cc index 95adea90bb7..014e311cf2b 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "sgl_kernels_ops.h" +#include "sgl_kernel_ops.h" -TORCH_LIBRARY_EXPAND(sgl_kernels, m) { +TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From csrc/allreduce */ diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/include/sgl_kernel_ops.h similarity index 99% rename from sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h rename to sgl-kernel/include/sgl_kernel_ops.h index 5bc5c7083b8..82412b6e0a2 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -36,18 +36,6 @@ limitations under the License. 
using fptr_t = int64_t; -/* - * From csrc/activation - */ -void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); -void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); -void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); -void gemma_fused_add_rmsnorm( - at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); -void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); -void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); -void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); - /* * From csrc/allreduce */ @@ -88,6 +76,30 @@ void register_graph_buffers( fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); #endif +/* + * From csrc/attention + */ +void lightning_attention_decode( + const torch::Tensor& q, + const torch::Tensor& k, + const torch::Tensor& v, + const torch::Tensor& past_kv, + const torch::Tensor& slope, + torch::Tensor output, + torch::Tensor new_kv); + +/* + * From csrc/elementwise + */ +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); +void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); +void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); +void gemma_fused_add_rmsnorm( + at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + /* * From csrc/gemm */ @@ -120,6 +132,7 @@ void sgl_per_token_group_quant_fp8( double fp8_min, double fp8_max); void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static); +void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s); void cublas_grouped_gemm( const std::vector& inputs, const std::vector& weights, @@ -254,18 +267,3 @@ void apply_rope_pos_ids_cos_sin_cache( at::Tensor pos_ids, bool interleave, int64_t cuda_stream); - -/* - * Other - */ -void lightning_attention_decode( - const torch::Tensor& q, - const torch::Tensor& k, - const torch::Tensor& v, - const torch::Tensor& past_kv, - const torch::Tensor& slope, - torch::Tensor output, - torch::Tensor new_kv); - -// sgl_per_token_quant_fp8 -void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s); diff --git a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/include/trt_reduce_internal.cuh similarity index 100% rename from sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh rename to sgl-kernel/include/trt_reduce_internal.cuh diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h similarity index 100% rename from sgl-kernel/src/sgl-kernel/include/utils.h rename to sgl-kernel/include/utils.h diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 24325aeca3e..6c7eb3e60a0 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -20,10 +20,6 @@ dependencies = [] "Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" "Bug Tracker" = 
"https://github.com/sgl-project/sglang/issues" -[tool.setuptools] -package-dir = {"sgl_kernel" = "src/sgl-kernel"} -packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"] - [tool.wheel] exclude = [ "dist*", diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py similarity index 74% rename from sgl-kernel/src/sgl-kernel/__init__.py rename to sgl-kernel/python/sgl_kernel/__init__.py index ab7f673b027..c8cb0443d68 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -9,7 +9,10 @@ mode=ctypes.RTLD_GLOBAL, ) -from sgl_kernel.ops.activation import ( +from sgl_kernel import common_ops +from sgl_kernel.allreduce import * +from sgl_kernel.attention import lightning_attention_decode +from sgl_kernel.elementwise import ( apply_rope_with_cos_sin_cache_inplace, fused_add_rmsnorm, gelu_and_mul, @@ -19,9 +22,7 @@ rmsnorm, silu_and_mul, ) -from sgl_kernel.ops.allreduce import * -from sgl_kernel.ops.attention import lightning_attention_decode -from sgl_kernel.ops.gemm import ( +from sgl_kernel.gemm import ( bmm_fp8, cublas_grouped_gemm, fp8_blockwise_scaled_mm, @@ -31,15 +32,15 @@ sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8, ) -from sgl_kernel.ops.moe import moe_align_block_size -from sgl_kernel.ops.sampling import ( +from sgl_kernel.moe import moe_align_block_size +from sgl_kernel.sampling import ( min_p_sampling_from_probs, top_k_renorm_prob, top_k_top_p_sampling_from_probs, top_p_renorm_prob, top_p_sampling_from_probs, ) -from sgl_kernel.ops.speculative import ( +from sgl_kernel.speculative import ( build_tree_kernel, build_tree_kernel_efficient, tree_speculative_sampling_target_only, diff --git a/sgl-kernel/src/sgl-kernel/ops/allreduce.py b/sgl-kernel/python/sgl_kernel/allreduce.py similarity index 62% rename from sgl-kernel/src/sgl-kernel/ops/allreduce.py rename to sgl-kernel/python/sgl_kernel/allreduce.py index 05079e3f4e3..0924e7f3587 100644 --- a/sgl-kernel/src/sgl-kernel/ops/allreduce.py +++ b/sgl-kernel/python/sgl_kernel/allreduce.py @@ -1,6 +1,5 @@ from typing import List, Tuple -import sgl_kernel.ops._kernels import torch if torch.version.hip is not None: @@ -13,49 +12,49 @@ def init_custom_ar( rank: int, full_nvlink: bool, ) -> int: - return torch.ops.sgl_kernels.init_custom_ar( + return torch.ops.sgl_kernel.init_custom_ar( meta, rank_data, handles, offsets, rank, full_nvlink ) def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out) + torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out) def all_reduce_unreg( fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor ) -> None: - torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out) + torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out) def dispose(fa: int) -> None: - torch.ops.sgl_kernels.dispose(fa) + torch.ops.sgl_kernel.dispose(fa) def meta_size() -> int: - return torch.ops.sgl_kernels.meta_size() + return torch.ops.sgl_kernel.meta_size() def register_buffer( fa: int, t: torch.Tensor, handles: List[str], offsets: List[int] ) -> None: - return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets) + return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets) def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]: - return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) + return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa) def register_graph_buffers( fa: int, handles: List[str], offsets: 
List[List[int]] ) -> None: - torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) + torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets) def allocate_meta_buffer(size: int) -> torch.Tensor: - return torch.ops.sgl_kernels.allocate_meta_buffer(size) + return torch.ops.sgl_kernel.allocate_meta_buffer(size) def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: - return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp) + return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp) else: # TRTLLM custom allreduce def init_custom_reduce( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ): - return torch.ops.sgl_kernels.init_custom_ar( + return torch.ops.sgl_kernel.init_custom_ar( rank_id, num_devices, rank_data, @@ -66,13 +65,13 @@ def init_custom_reduce( ) def custom_dispose(fa): - torch.ops.sgl_kernels.dispose(fa) + torch.ops.sgl_kernel.dispose(fa) def custom_reduce(fa, inp, out): - torch.ops.sgl_kernels.all_reduce(fa, inp, out) + torch.ops.sgl_kernel.all_reduce(fa, inp, out) def get_graph_buffer_ipc_meta(fa): - return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) + return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa) def register_graph_buffers(fa, handles, offsets): - torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) + torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets) diff --git a/sgl-kernel/src/sgl-kernel/ops/attention.py b/sgl-kernel/python/sgl_kernel/attention.py similarity index 62% rename from sgl-kernel/src/sgl-kernel/ops/attention.py rename to sgl-kernel/python/sgl_kernel/attention.py index a4cb5fc0b4b..53fec4dd167 100644 --- a/sgl-kernel/src/sgl-kernel/ops/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -1,8 +1,7 @@ -import sgl_kernel.ops._kernels import torch def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): - torch.ops.sgl_kernels.lightning_attention_decode( + torch.ops.sgl_kernel.lightning_attention_decode( q, k, v, past_kv, slope, output, new_kv ) diff --git a/sgl-kernel/src/sgl-kernel/ops/activation.py b/sgl-kernel/python/sgl_kernel/elementwise.py similarity index 87% rename from sgl-kernel/src/sgl-kernel/ops/activation.py rename to sgl-kernel/python/sgl_kernel/elementwise.py index 08a65ec01e0..fc6d8ea00ae 100644 --- a/sgl-kernel/src/sgl-kernel/ops/activation.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -1,8 +1,7 @@ from typing import Optional -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import get_cuda_stream +from sgl_kernel.utils import get_cuda_stream # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer @@ -15,14 +14,14 @@ def rmsnorm( ) -> torch.Tensor: if out is None: out = torch.empty_like(input) - torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream()) + torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream()) return out def fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: - torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps) + torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps) def gemma_rmsnorm( @@ -33,14 +32,14 @@ def gemma_rmsnorm( ) -> torch.Tensor: if out is None: out = torch.empty_like(input) - torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream()) + torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream()) return out def 
gemma_fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: - torch.ops.sgl_kernels.gemma_fused_add_rmsnorm( + torch.ops.sgl_kernel.gemma_fused_add_rmsnorm( input, residual, weight, eps, get_cuda_stream() ) @@ -66,7 +65,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream()) return out @@ -81,7 +80,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream()) return out @@ -96,7 +95,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream()) return out @@ -141,7 +140,7 @@ def apply_rope_with_cos_sin_cache_inplace( raise ValueError("cos_sin_cache should be float32") positions = positions.int() - torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache( + torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache( q=query.view(query.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size), q_rope=query.view(query.shape[0], -1, head_size), diff --git a/sgl-kernel/src/sgl-kernel/ops/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py similarity index 81% rename from sgl-kernel/src/sgl-kernel/ops/gemm.py rename to sgl-kernel/python/sgl_kernel/gemm.py index 883894e966a..e5936da5677 100644 --- a/sgl-kernel/src/sgl-kernel/ops/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -1,12 +1,11 @@ from typing import List, Optional -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream +from sgl_kernel.utils import _get_cache_buf, get_cuda_stream def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return torch.ops.sgl_kernels.int8_scaled_mm( + return torch.ops.sgl_kernel.int8_scaled_mm( mat_a, mat_b, scales_a, @@ -17,7 +16,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype): - return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm( + return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm( mat_a, mat_b, scales_a, @@ -27,7 +26,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype): def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return torch.ops.sgl_kernels.fp8_scaled_mm( + return torch.ops.sgl_kernel.fp8_scaled_mm( mat_a, mat_b, scales_a, @@ -46,7 +45,7 @@ def _bmm_fp8_internal( B_scale: torch.Tensor, ) -> None: cublas_handle = torch.cuda.current_blas_handle() - torch.ops.sgl_kernels.bmm_fp8( + torch.ops.sgl_kernel.bmm_fp8( A, B, D, @@ -86,7 +85,7 @@ def sgl_per_token_group_quant_fp8( fp8_min: float, fp8_max: float, ) -> None: - torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8( + torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8( input, output_q, output_s, group_size, eps, fp8_min, fp8_max ) @@ -97,7 +96,7 @@ def sgl_per_tensor_quant_fp8( output_s: torch.Tensor, is_static: bool, ) -> None: - torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static) + 
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static) def cublas_grouped_gemm( @@ -110,7 +109,7 @@ def cublas_grouped_gemm( len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0 ), "Inputs/weights/outputs should not be empty!" cublas_handle = torch.cuda.current_blas_handle() - torch.ops.sgl_kernels.cublas_grouped_gemm( + torch.ops.sgl_kernel.cublas_grouped_gemm( inputs, weights, outputs, @@ -125,4 +124,4 @@ def sgl_per_token_quant_fp8( output_q: torch.Tensor, output_s: torch.Tensor, ) -> None: - torch.ops.sgl_kernels.sgl_per_token_quant_fp8(input, output_q, output_s) + torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s) diff --git a/sgl-kernel/src/sgl-kernel/ops/moe.py b/sgl-kernel/python/sgl_kernel/moe.py similarity index 83% rename from sgl-kernel/src/sgl-kernel/ops/moe.py rename to sgl-kernel/python/sgl_kernel/moe.py index 208198272f3..ad20da03611 100644 --- a/sgl-kernel/src/sgl-kernel/ops/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -1,4 +1,3 @@ -import sgl_kernel.ops._kernels import torch @@ -12,7 +11,7 @@ def moe_align_block_size( token_cnts_buffer, cumsum_buffer, ): - torch.ops.sgl_kernels.moe_align_block_size( + torch.ops.sgl_kernel.moe_align_block_size( topk_ids, num_experts, block_size, diff --git a/sgl-kernel/src/sgl-kernel/ops/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py similarity index 94% rename from sgl-kernel/src/sgl-kernel/ops/sampling.py rename to sgl-kernel/python/sgl_kernel/sampling.py index 1be42f8fd5f..7bf10bd4a48 100644 --- a/sgl-kernel/src/sgl-kernel/ops/sampling.py +++ b/sgl-kernel/python/sgl_kernel/sampling.py @@ -1,8 +1,7 @@ from typing import Optional, Tuple, Union -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream +from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream def _top_k_renorm_probs_internal( @@ -13,7 +12,7 @@ def _top_k_renorm_probs_internal( probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None renorm_probs = torch.empty_like(probs) - torch.ops.sgl_kernels.top_k_renorm_probs_wrapper( + torch.ops.sgl_kernel.top_k_renorm_probs_wrapper( probs, renorm_probs, maybe_top_k_arr, @@ -41,7 +40,7 @@ def _top_p_renorm_probs_internal( probs = probs.float() maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None renorm_probs = torch.empty_like(probs) - torch.ops.sgl_kernels.top_p_renorm_probs( + torch.ops.sgl_kernel.top_p_renorm_probs( probs, renorm_probs, maybe_top_p_arr, @@ -76,7 +75,7 @@ def _top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device) - torch.ops.sgl_kernels.top_p_sampling_from_probs( + torch.ops.sgl_kernel.top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -122,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device) - torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs( + torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -180,7 +179,7 @@ def _min_p_sampling_from_probs_internal( maybe_min_p_arr.float() if maybe_min_p_arr is not None else None ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) - torch.ops.sgl_kernels.min_p_sampling_from_probs( + 
torch.ops.sgl_kernel.min_p_sampling_from_probs( probs, uniform_samples, samples, diff --git a/sgl-kernel/src/sgl-kernel/ops/speculative.py b/sgl-kernel/python/sgl_kernel/speculative.py similarity index 88% rename from sgl-kernel/src/sgl-kernel/ops/speculative.py rename to sgl-kernel/python/sgl_kernel/speculative.py index f209f16a93d..53acb1d95e9 100644 --- a/sgl-kernel/src/sgl-kernel/ops/speculative.py +++ b/sgl-kernel/python/sgl_kernel/speculative.py @@ -1,6 +1,5 @@ -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import get_cuda_stream +from sgl_kernel.utils import get_cuda_stream def tree_speculative_sampling_target_only( @@ -16,7 +15,7 @@ def tree_speculative_sampling_target_only( draft_probs: torch.Tensor, deterministic: bool = True, ) -> None: - torch.ops.sgl_kernels.tree_speculative_sampling_target_only( + torch.ops.sgl_kernel.tree_speculative_sampling_target_only( predicts, accept_index, accept_token_num, @@ -45,7 +44,7 @@ def build_tree_kernel_efficient( depth: int, draft_token_num: int, ) -> None: - torch.ops.sgl_kernels.build_tree_kernel_efficient( + torch.ops.sgl_kernel.build_tree_kernel_efficient( parent_list, selected_index, verified_seq_len, @@ -71,7 +70,7 @@ def build_tree_kernel( depth: int, draft_token_num: int, ) -> None: - torch.ops.sgl_kernels.build_tree_kernel( + torch.ops.sgl_kernel.build_tree_kernel( parent_list, selected_index, verified_seq_len, diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/python/sgl_kernel/utils.py similarity index 100% rename from sgl-kernel/src/sgl-kernel/ops/utils.py rename to sgl-kernel/python/sgl_kernel/utils.py diff --git a/sgl-kernel/src/sgl-kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py similarity index 100% rename from sgl-kernel/src/sgl-kernel/version.py rename to sgl-kernel/python/sgl_kernel/version.py diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 545ff1bfc55..72d710b3dc3 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -48,16 +48,16 @@ def _get_version(): return line.split("=")[1].strip().strip('"') -operator_namespace = "sgl_kernels" +operator_namespace = "sgl_kernel" cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" turbomind = root / "3rdparty" / "turbomind" include_dirs = [ + root / "include", + root / "csrc", cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", - root / "src" / "sgl-kernel" / "include", - root / "src" / "sgl-kernel" / "csrc", flashinfer.resolve() / "include", flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "csrc", @@ -96,21 +96,21 @@ def _get_version(): ] sources = [ - "src/sgl-kernel/torch_extension.cc", - "src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu", - "src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu", - "src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu", - "src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu", - "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu", - "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu", - "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu", - "src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu", - "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu", - "src/sgl-kernel/csrc/moe/moe_align_kernel.cu", - "src/sgl-kernel/csrc/speculative/eagle_utils.cu", - 
"src/sgl-kernel/csrc/speculative/speculative_sampling.cu", + "csrc/allreduce/trt_reduce_internal.cu", + "csrc/allreduce/trt_reduce_kernel.cu", + "csrc/attention/lightning_attention_decode_kernel.cu", + "csrc/elementwise/fused_add_rms_norm_kernel.cu", + "csrc/gemm/cublas_grouped_gemm.cu", + "csrc/gemm/fp8_gemm_kernel.cu", + "csrc/gemm/fp8_blockwise_gemm_kernel.cu", + "csrc/gemm/int8_gemm_kernel.cu", + "csrc/gemm/per_token_group_quant_fp8.cu", + "csrc/gemm/per_token_quant_fp8.cu", + "csrc/gemm/per_tensor_quant_fp8.cu", + "csrc/moe/moe_align_kernel.cu", + "csrc/speculative/eagle_utils.cu", + "csrc/speculative/speculative_sampling.cu", + "csrc/torch_extension.cc", "3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/norm.cu", @@ -158,7 +158,7 @@ def _get_version(): ext_modules = [ CUDAExtension( - name="sgl_kernel.ops._kernels", + name="sgl_kernel.common_ops", sources=sources, include_dirs=include_dirs, extra_compile_args={ @@ -174,8 +174,8 @@ def _get_version(): setup( name="sgl-kernel", version=_get_version(), - packages=find_packages(), - package_dir={"": "src"}, + packages=find_packages(where="python"), + package_dir={"": "python"}, ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)}, options={"bdist_wheel": {"py_limited_api": "cp39"}}, diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 9185e4ae15a..25484ae7aab 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -13,12 +13,9 @@ # limitations under the License. # ============================================================================== -import multiprocessing -import os import sys from pathlib import Path -import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension @@ -35,16 +32,16 @@ def _get_version(): return line.split("=")[1].strip().strip('"') -operator_namespace = "sgl_kernels" +operator_namespace = "sgl_kernel" include_dirs = [ - root / "src" / "sgl-kernel" / "include", - root / "src" / "sgl-kernel" / "csrc", + root / "include", + root / "csrc", ] sources = [ - "src/sgl-kernel/torch_extension_rocm.cc", - "src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip", - "src/sgl-kernel/csrc/moe/moe_align_kernel.cu", + "csrc/allreduce/custom_all_reduce.hip", + "csrc/moe/moe_align_kernel.cu", + "csrc/torch_extension_rocm.cc", ] cxx_flags = ["-O3"] @@ -64,26 +61,27 @@ def _get_version(): "-DENABLE_FP8", ] +ext_modules = [ + CUDAExtension( + name="sgl_kernel.common_ops", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "nvcc": hipcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + py_limited_api=True, + ), +] + setup( name="sgl-kernel", version=_get_version(), packages=find_packages(), - package_dir={"": "src"}, - ext_modules=[ - CUDAExtension( - name="sgl_kernel.ops._kernels", - sources=sources, - include_dirs=include_dirs, - extra_compile_args={ - "nvcc": hipcc_flags, - "cxx": cxx_flags, - }, - libraries=libraries, - extra_link_args=extra_link_args, - py_limited_api=True, - ), - ], + package_dir={"": "python"}, + ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)}, options={"bdist_wheel": {"py_limited_api": "cp39"}}, - install_requires=["torch"], ) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 0387637ab67..9bbc4e76fa8 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ 
b/sgl-kernel/tests/test_trt_allreduce.py @@ -7,7 +7,7 @@ from typing import Any, List, Optional import ray -import sgl_kernel.ops.allreduce as custom_ops +import sgl_kernel.allreduce as custom_ops import torch import torch.distributed as dist from torch.distributed import ProcessGroup From 5c7dd14ba1547a2f8a4f6793d9f08daf4114a344 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 8 Mar 2025 23:01:59 -0800 Subject: [PATCH 19/27] chore: bump v0.0.4 for sgl-kernel (#4223) --- sgl-kernel/pyproject.toml | 2 +- sgl-kernel/python/sgl_kernel/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 6c7eb3e60a0..cdc7f936c1b 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.3.post7" +version = "0.0.4" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/python/sgl_kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py index 4c51b84c6e6..81f0fdeccf6 100644 --- a/sgl-kernel/python/sgl_kernel/version.py +++ b/sgl-kernel/python/sgl_kernel/version.py @@ -1 +1 @@ -__version__ = "0.0.3.post7" +__version__ = "0.0.4" From 1361ab9e0363b52160b7d363b49675d5e91f21ab Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 8 Mar 2025 23:39:26 -0800 Subject: [PATCH 20/27] Lazily import lora backends (#4225) --- python/sglang/srt/lora/backend/__init__.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py index 07fe11d23bf..7b76f90e52e 100644 --- a/python/sglang/srt/lora/backend/__init__.py +++ b/python/sglang/srt/lora/backend/__init__.py @@ -1,23 +1,20 @@ -from .base_backend import BaseLoRABackend -from .flashinfer_backend import FlashInferLoRABackend -from .triton_backend import TritonLoRABackend +from sglang.srt.lora.backend.base_backend import BaseLoRABackend def get_backend_from_name(name: str) -> BaseLoRABackend: """ Get corresponding backend class from backend's name """ - backend_mapping = { - "triton": TritonLoRABackend, - "flashinfer": FlashInferLoRABackend, - } + if name == "triton": + from sglang.srt.lora.backend.triton_backend import TritonLoRABackend - if name in backend_mapping: - return backend_mapping[name] + return TritonLoRABackend + elif name == "flashinfer": + from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend - raise Exception( - f"No supported lora backend called {name}. 
It should be one of {list(backend_mapping.keys())}" - ) + return FlashInferLoRABackend + else: + raise ValueError(f"Invalid backend: {name}") __all__ = [ From 0e90ae628a07499936295d19793cee102ddfea8e Mon Sep 17 00:00:00 2001 From: Peter Pan Date: Sun, 9 Mar 2025 15:41:20 +0800 Subject: [PATCH 21/27] [docker] Distributed Serving with k8s Statefulset ( good example for DeepSeek-R1) (#3631) Signed-off-by: Peter Pan Co-authored-by: Kebe --- docker/k8s-sglang-distributed-sts.yaml | 104 +++++++++++++++++++++++++ docs/start/install.md | 16 +++- 2 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 docker/k8s-sglang-distributed-sts.yaml diff --git a/docker/k8s-sglang-distributed-sts.yaml b/docker/k8s-sglang-distributed-sts.yaml new file mode 100644 index 00000000000..6b81d9b14df --- /dev/null +++ b/docker/k8s-sglang-distributed-sts.yaml @@ -0,0 +1,104 @@ +# Two Nodes Sglang example + +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: distributed-sglang +spec: + replicas: 2 # number of nodes/pods to run distributed sglang + selector: + matchLabels: + app: distributed-sglang + serviceName: "" + template: + metadata: + labels: + app: distributed-sglang + spec: + containers: + - name: sglang-container + image: docker.io/lmsysorg/sglang:latest + imagePullPolicy: Always # image may be replaced by official CI versioned image + command: + - /bin/bash + - -c + # please modify the sglang serving arguments below, as necessary. + # NOTE: the --expert-parallel-size and --enable-ep-moe are for MoE model like DeepSeek-R1 + args: + - | + python3 -m sglang.launch_server \ + --model /llm-folder \ + --dist-init-addr sglang-master-pod:5000 \ + --tensor-parallel-size 16 \ + --nnodes 2 \ + --node-rank $POD_INDEX \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 8000 \ + --enable-metrics \ + --enable-ep-moe \ + --expert-parallel-size 16 + env: + - name: POD_INDEX # reflects the node-rank + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.labels['apps.kubernetes.io/pod-index'] + - name: NCCL_DEBUG + value: INFO + resources: + limits: + nvidia.com/gpu: "8" + requests: + volumeMounts: + - mountPath: /dev/shm + name: dshm + - mountPath: /llm-folder + name: llm + securityContext: + privileged: true # to leverage RDMA/InfiniBand device, co-work with HostNetwork=true + hostNetwork: true + volumes: + - emptyDir: + medium: Memory + sizeLimit: 10Gi + name: dshm + - hostPath: + path: /llm-folder # replace with PVC or hostPath with your model weights + type: DirectoryOrCreate + name: llm + #- persistentVolumeClaim: + # claimName: llm-pvc + # name: llm +--- +apiVersion: v1 +kind: Service +metadata: + name: sglang-master-pod +spec: + type: ClusterIP + selector: + app: distributed-sglang + apps.kubernetes.io/pod-index: "0" + ports: + - name: dist-port + port: 5000 + targetPort: 5000 +--- +# the serving service +apiVersion: v1 +kind: Service +metadata: + name: sglang-serving-on-master +spec: + type: NodePort + selector: + app: distributed-sglang + apps.kubernetes.io/pod-index: "0" + ports: + - name: serving + port: 8000 + targetPort: 8000 + - name: metrics + port: 8080 + targetPort: 8080 diff --git a/docs/start/install.md b/docs/start/install.md index fe460e044b3..f7234c0a660 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -98,7 +98,21 @@ drun v0.4.3.post4-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --in 2. Execute the command `docker compose up -d` in your terminal. -## Method 5: Run on Kubernetes or Clouds with SkyPilot +## Method 5: Using Kubernetes + +
+<details>
+<summary>More</summary>
+
+1. Option 1: For single-node serving (typically when the model fits into the GPUs of one node),
+   execute `kubectl apply -f docker/k8s-sglang-service.yaml` to create the k8s deployment and service, with llama-31-8b as the example model.
+
+2. Option 2: For multi-node serving (usually when a large model such as `DeepSeek-R1` requires more than one GPU node),
+   modify the LLM model path and server arguments as necessary, then execute `kubectl apply -f docker/k8s-sglang-distributed-sts.yaml` to create the two-node k8s StatefulSet and its serving service.
+</details>
+
+
+
+## Method 6: Run on Kubernetes or Clouds with SkyPilot
More From dceb256f1ba137b365052b412f5b3c4e3a148de8 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Sat, 8 Mar 2025 23:41:40 -0800 Subject: [PATCH 22/27] [docs] Unhide production metrics page (#4193) --- docs/references/general.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/references/general.rst b/docs/references/general.rst index fedb2be764d..8ea335d84b1 100644 --- a/docs/references/general.rst +++ b/docs/references/general.rst @@ -11,3 +11,4 @@ General Guidance faq.md learn_more.md modelscope.md + production_metrics.md From 89ccb533ad390d9cf9c75dacf6eec4130901aede Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 8 Mar 2025 23:43:09 -0800 Subject: [PATCH 23/27] use sgl-kernel 0.0.4 (#4224) --- python/pyproject.toml | 2 +- scripts/ci_install_dependency.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 6eaa6263bef..6078abdd278 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -44,7 +44,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", - "sgl-kernel==0.0.3.post6", + "sgl-kernel==0.0.4", "flashinfer_python==0.2.2.post1", "torch==2.5.1", "vllm>=0.6.4.post1,<=0.7.2", diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index c187bf81b2e..408adbaf8f0 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -26,4 +26,4 @@ pip install transformers==4.45.2 sentence_transformers accelerate peft pandas da pip install cuda-python nvidia-cuda-nvrtc-cu12 # reinstall sgl-kernel -pip install sgl-kernel==0.0.3.post6 --force-reinstall --no-deps +pip install sgl-kernel==0.0.4 --force-reinstall --no-deps From 9fb48f951f8620d4ad255a2abb2b2aa779ab9d9a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 9 Mar 2025 00:01:54 -0800 Subject: [PATCH 24/27] Support nextn for flashinfer mla attention backend (#4218) --- docs/references/deepseek.md | 2 +- .../attention/flashinfer_mla_backend.py | 374 +++++++++++++++--- python/sglang/srt/models/deepseek_v2.py | 2 + python/sglang/srt/speculative/eagle_worker.py | 10 + test/srt/test_mla_flashinfer.py | 63 +++ 5 files changed, 393 insertions(+), 58 deletions(-) diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 2b6836d5c2d..6289fa35791 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -84,7 +84,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. +- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. 
Currently when using flashinfer mla wrapper and speculative decoding together, the `speculative_eagle_topk` parameter should be set to 1. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 9e81acc6f90..9af027bd1fd 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -11,9 +11,10 @@ from dataclasses import dataclass from functools import partial -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import torch +import triton from sglang.global_config import global_config from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -23,6 +24,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: @@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend): def __init__( self, model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + q_indptr_decode_buf: Optional[torch.Tensor] = None, ): super().__init__() # Parse constants self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device + self.skip_prefill = skip_prefill global_config.enable_flashinfer_mla = True @@ -78,35 +84,51 @@ def __init__( self.workspace_buffer = global_workspace_buffer max_bs = model_runner.req_to_token_pool.size - self.kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) + if kv_indptr_buf is None: + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + self.kv_indptr = kv_indptr_buf - self.qo_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) + if not self.skip_prefill: + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) - self.q_indptr_decode = torch.arange( - 0, max_bs + 1, dtype=torch.int32, device=model_runner.device - ) + if q_indptr_decode_buf is None: + self.q_indptr_decode = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + else: + self.q_indptr_decode = q_indptr_decode_buf self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( self.workspace_buffer, "NHD" ) - self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper( - self.workspace_buffer, - backend="auto", - ) + if not self.skip_prefill: + self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + backend="auto", + ) + + # FlashinferMLA backend uses mla wrapper for target verify + self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + backend="auto", + ) self.decode_wrapper = BatchMLAPagedAttentionWrapper( self.workspace_buffer, backend="auto" ) # Create indices updater - self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill( - model_runner, self - ) + if not skip_prefill: + 
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill( + model_runner, self + ) + self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode( model_runner, self ) @@ -114,7 +136,7 @@ def __init__( # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None self.decode_cuda_graph_metadata = {} - self.prefill_cuda_graph_metadata = {} + self.prefill_cuda_graph_metadata = {} # For verify def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode_or_idle(): @@ -126,6 +148,28 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): init_metadata_replay=False, ) self.forward_metadata = DecodeMetadata(self.decode_wrapper) + elif forward_batch.forward_mode.is_draft_extend(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=self.prefill_wrapper_paged, + use_ragged=False, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False) + elif forward_batch.forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=self.prefill_wrapper_verify, + use_ragged=False, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False) else: prefix_lens = forward_batch.extend_prefix_lens extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) @@ -202,10 +246,33 @@ def init_forward_metadata_capture_cuda_graph( seq_lens_sum, decode_wrapper=decode_wrapper, init_metadata_replay=False, + spec_info=spec_info, ) self.decode_cuda_graph_metadata[bs] = decode_wrapper self.forward_metadata = DecodeMetadata(decode_wrapper) decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper) + elif forward_mode.is_target_verify(): + verify_wrapper = BatchMLAPagedAttentionWrapper( + self.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.cuda_graph_qo_indptr[: bs + 1], + kv_indptr=self.cuda_graph_kv_indptr[: bs + 1], + kv_indices=self.cuda_graph_kv_indices, + kv_len_arr=self.cuda_graph_kv_lens[:bs], + backend="auto", + ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=verify_wrapper, + use_ragged=False, + spec_info=spec_info, + ) + self.prefill_cuda_graph_metadata[bs] = verify_wrapper + self.forward_metadata = PrefillMetadata(verify_wrapper, False) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -221,6 +288,7 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): + assert seq_lens_cpu is not None kv_len_arr_cpu = seq_lens_cpu[:bs] self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum( kv_len_arr_cpu, dim=0 @@ -239,8 +307,19 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_sum, decode_wrapper=self.decode_cuda_graph_metadata[bs], init_metadata_replay=True, + spec_info=spec_info, **self.fast_decode_kwargs, ) + elif forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + prefix_lens=None, + prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs], + use_ragged=False, + spec_info=spec_info, + ) else: raise ValueError(f"Invalid forward mode: {forward_mode=}") @@ 
-254,7 +333,7 @@ def forward_extend( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, - save_kv_cache=True, + save_kv_cache: bool = True, ): cache_loc = forward_batch.out_cache_loc @@ -297,7 +376,7 @@ def forward_decode( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, - save_kv_cache=True, + save_kv_cache: bool = True, ): decode_wrapper = self.forward_metadata.decode_wrapper cache_loc = forward_batch.out_cache_loc @@ -349,6 +428,7 @@ def update( seq_lens_sum: int, decode_wrapper: BatchMLAPagedAttentionWrapper, init_metadata_replay: bool = False, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, **fast_decode_kwargs, ): decode_wrapper = decode_wrapper or self.decode_wrapper @@ -360,6 +440,7 @@ def update( self.q_indptr, self.kv_indptr, init_metadata_replay, + spec_info, **fast_decode_kwargs, ) @@ -372,30 +453,33 @@ def call_begin_forward( q_indptr: torch.Tensor, kv_indptr: torch.Tensor, init_metadata_replay: bool = False, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, **fast_decode_kwargs, ): bs = len(req_pool_indices) q_indptr = q_indptr[: bs + 1] - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = ( - torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda") - if not init_metadata_replay - else fast_decode_kwargs["kv_indices"] - ) - kv_lens = paged_kernel_lens.to(torch.int32) sm_scale = self.scaling + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = ( + torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda") + if not init_metadata_replay + else fast_decode_kwargs["kv_indices"] + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.shape[1], + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.shape[1], - ) if not init_metadata_replay: wrapper.plan( q_indptr, @@ -457,6 +541,7 @@ def update( prefix_lens: torch.Tensor, prefill_wrapper_paged: BatchMLAPagedAttentionWrapper, use_ragged: bool, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, ): if use_ragged: paged_kernel_lens = prefix_lens @@ -476,6 +561,7 @@ def update( self.kv_indptr, self.qo_indptr, use_ragged, + spec_info, ) def call_begin_forward( @@ -490,29 +576,46 @@ def call_begin_forward( kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, ): - bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, - dtype=torch.int32, - device=req_pool_indices.device, - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.shape[1], - ) - - qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) - qo_indptr = qo_indptr[: bs + 1] + bs = len(seq_lens) sm_scale = self.scaling + if spec_info is None: + assert len(seq_lens) == len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + 
kv_indices = torch.empty( + paged_kernel_lens_sum, + dtype=torch.int32, + device=req_pool_indices.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.shape[1], + ) + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + else: + assert isinstance(spec_info, EagleDraftInput) or isinstance( + spec_info, EagleVerifyInput + ) + # TODO: Support topk > 1 with custom mask + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + self.req_to_token, + ) + ) + if use_ragged: # ragged prefill wrapper_ragged.begin_forward( @@ -543,6 +646,163 @@ def call_begin_forward( ) +class FlashInferMLAMultiStepDraftBackend: + """ + Wrap multiple flashinfer mla attention backends as one for multiple consecutive + draft decoding steps. + """ + + def __init__( + self, + model_runner: ModelRunner, + topk: int, + speculative_num_steps: int, + ): + from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + + if topk > 1: + raise ValueError( + f"Currently Flashinfer MLA only supports topk=1 for speculative decoding" + ) + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices + + max_bs = model_runner.req_to_token_pool.size * self.topk + self.kv_indptr = torch.zeros( + ( + self.speculative_num_steps, + max_bs + 1, + ), + dtype=torch.int32, + device=model_runner.device, + ) + self.q_indptr_decode = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + + self.attn_backends = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + FlashInferMLAAttnBackend( + model_runner, + skip_prefill=True, + kv_indptr_buf=self.kv_indptr[i], + q_indptr_decode_buf=self.q_indptr_decode, + ) + ) + + self.max_context_len = self.attn_backends[0].max_context_len + + # Cached variables for generate_draft_decode_kv_indices + self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] + + def common_template( + self, + forward_batch: ForwardBatch, + kv_indices_buffer: torch.Tensor, + call_fn: Callable, + ): + num_seqs = forward_batch.batch_size + bs = self.topk * num_seqs + seq_lens_sum = forward_batch.seq_lens_sum + + self.generate_draft_decode_kv_indices[ + (self.speculative_num_steps, num_seqs, self.topk) + ]( + forward_batch.req_pool_indices, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.seq_lens, + kv_indices_buffer, + self.kv_indptr, + forward_batch.positions, + num_seqs, + self.topk, + self.pool_len, + kv_indices_buffer.shape[1], + self.kv_indptr.shape[1], + triton.next_power_of_2(num_seqs), + triton.next_power_of_2(self.speculative_num_steps), + triton.next_power_of_2(bs), + ) + + assert forward_batch.spec_info is not None + assert isinstance(forward_batch.spec_info, EagleDraftInput) + + for i in range(self.speculative_num_steps - 1): + forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] + forward_batch.spec_info.kv_indices = kv_indices_buffer[i][ + : seq_lens_sum * self.topk + bs * (i + 1) + ] + call_fn(i, forward_batch) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + kv_indices = torch.zeros( + ( + self.speculative_num_steps, + forward_batch.batch_size * self.topk * self.max_context_len, + ), + dtype=torch.int32, + 
device="cuda", + ) + + def call_fn(i, forward_batch): + assert forward_batch.spec_info is not None + assert isinstance(forward_batch.spec_info, EagleDraftInput) + forward_batch.spec_info.kv_indptr = ( + forward_batch.spec_info.kv_indptr.clone() + ) + forward_batch.spec_info.kv_indices = ( + forward_batch.spec_info.kv_indices.clone() + ) + self.attn_backends[i].init_forward_metadata(forward_batch) + + self.common_template(forward_batch, kv_indices, call_fn) + + def init_cuda_graph_state(self, max_bs: int): + self.cuda_graph_kv_indices = torch.zeros( + (self.speculative_num_steps, max_bs * self.max_context_len), + dtype=torch.int32, + device="cuda", + ) + + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state( + max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + ) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) + + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + bs, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + seq_lens_sum=-1, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.decode_seq_lens_cpu, + ) + + self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) + + def fast_mla_decode_plan( self, qo_indptr_cpu: torch.Tensor, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 13544007e77..82c73ec94db 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -555,6 +555,8 @@ def no_absorb() -> bool: return ( not global_server_args_dict["flashinfer_mla_disable_ragged"] and forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() and forward_batch.extend_prefix_lens.sum() == 0 ) else: diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index bd2fa600915..90d47cc0fd3 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -123,6 +123,16 @@ def init_attention_backend(self): self.topk, self.speculative_num_steps, ) + elif self.server_args.attention_backend == "flashinfer_mla": + from sglang.srt.layers.attention.flashinfer_mla_backend import ( + FlashInferMLAMultiStepDraftBackend, + ) + + self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend( + self.model_runner, + self.topk, + self.speculative_num_steps, + ) else: raise ValueError( f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}" diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py index 04586acc5a9..e7113d03d8b 100644 --- a/test/srt/test_mla_flashinfer.py +++ b/test/srt/test_mla_flashinfer.py @@ -1,6 +1,7 @@ import unittest from types import SimpleNamespace +import requests import torch from sglang.srt.utils import kill_process_tree @@ -100,5 +101,67 @@ def 
test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.62) +class TestFlashinferMLAMTP(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--disable-radix", + "--enable-torch-compile", + "--torch-compile-max-bs", + "1", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + "lmsys/sglang-ci-dsv3-test-NextN", + "--speculative-num-steps", + "4", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--enable-flashinfer-mla", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.5) + + if __name__ == "__main__": unittest.main() From 0dd6cda2886c70761cd1d05161cbd76ccc21564a Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Sun, 9 Mar 2025 16:03:32 +0800 Subject: [PATCH 25/27] Apply sgl w8a8 fp8 kernel (#3148) --- python/sglang/srt/configs/model_config.py | 4 +- python/sglang/srt/layers/linear.py | 2 +- python/sglang/srt/layers/parameter.py | 10 ++ .../srt/layers/quantization/__init__.py | 2 + .../srt/layers/quantization/blockwise_int8.py | 3 +- python/sglang/srt/layers/quantization/fp8.py | 41 ++--- .../srt/layers/quantization/fp8_kernel.py | 133 ++++++++++++++- .../srt/layers/quantization/fp8_utils.py | 156 +++++++++++++++++- .../srt/layers/quantization/modelopt_quant.py | 6 +- .../srt/layers/quantization/w8a8_fp8.py | 126 ++++++++++++++ python/sglang/srt/server_args.py | 1 + python/sglang/srt/utils.py | 8 + python/sglang/test/test_block_fp8.py | 68 +++++++- 13 files changed, 523 insertions(+), 37 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/w8a8_fp8.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 6f103bcc603..489cc6d4b05 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -250,9 +250,11 @@ def _verify_quantization(self) -> None: "compressed-tensors", "experts_int8", "w8a8_int8", + "w8a8_fp8", ] compatible_quantization_methods = { - "w8a8_int8": ["compressed-tensors", "compressed_tensors"] + "w8a8_int8": ["compressed-tensors", "compressed_tensors"], + "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], } if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 919bcced3a8..85748fa7434 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -18,6 +18,7 @@ ) from sglang.srt.layers.parameter import ( BasevLLMParameter, + BlockQuantScaleParameter, 
PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, @@ -27,7 +28,6 @@ QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 78be6798254..b3fc6b440c4 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -16,6 +16,7 @@ "ModelWeightParameter", "ChannelQuantScaleParameter", "GroupQuantScaleParameter", + "BlockQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", ] @@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter): pass +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. + """ + + pass + + class PerTensorScaleParameter(BasevLLMParameter): """ Parameter class for scales where the number of scales is diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 1ef8f43816f..c09fb5a1a00 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config +from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { @@ -50,6 +51,7 @@ "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, "w8a8_int8": W8A8Int8Config, + "w8a8_fp8": W8A8Fp8Config, } diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index 1470ca427b5..ce526cd6a9b 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -13,12 +13,11 @@ LinearMethodBase, UnquantizedLinearMethod, ) -from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear from sglang.srt.utils import set_weight_attrs diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index e296756b54b..44a3cba8ad0 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, - apply_fp8_linear, convert_to_channelwise, - cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale, ) @@ -29,14 +27,21 @@ LinearMethodBase, UnquantizedLinearMethod, ) -from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, + 
ModelWeightParameter, + PerTensorScaleParameter, +) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_utils import ( - BlockQuantScaleParameter, + apply_fp8_linear, apply_w8a8_block_fp8_linear, + cutlass_fp8_supported, + input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.utils import ( @@ -305,15 +310,15 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) - - # If using marlin (w8a16), kernel uses channelwise weights, - # so extend the weight scales to be channelwise. - if self.use_marlin: - assert weight_scale.numel() == 1 - weight_scale = convert_to_channelwise( - weight_scale.expand(len(layer.logical_widths)), layer.logical_widths + if self.cutlass_fp8_supported or self.use_marlin: + # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale + qweight, weight_scale = per_token_group_quant_fp8( + layer.weight, layer.weight.shape[-1] ) + weight_scale = weight_scale.t().contiguous() + else: + # per-tensor quantization + qweight, weight_scale = input_to_float8(layer.weight) # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) @@ -330,23 +335,19 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.input_scale = torch.nn.Parameter( layer.input_scale.data, requires_grad=False ) - # If using marlin (w8a16), kernel uses channelwise weights, - # so extend the weight scales to be channelwise. - if self.use_marlin: + + # cutlass sgl-kernel and marlin only support per-channel scale + if self.cutlass_fp8_supported or self.use_marlin: weight = layer.weight weight_scale = convert_to_channelwise( layer.weight_scale, layer.logical_widths ) - - # If using w8a8, torch._scaled_mm needs per tensor, so - # requantize the logical shards as a single weight. else: # Dequant -> Quant with max scale so we can run per tensor. weight = layer.weight weight_scale = layer.weight_scale - # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip_: + if is_hip(): weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale, diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 47f310a24de..54c07f90940 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -29,7 +29,7 @@ _is_cuda = torch.cuda.is_available() and torch.version.cuda if _is_cuda: - from sgl_kernel import sgl_per_token_group_quant_fp8 + from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8 logger = logging.getLogger(__name__) @@ -70,7 +70,8 @@ def _per_token_group_quant_fp8( # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / fp8_max - y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + y_s_inv = 1.0 / y_s + y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) @@ -140,7 +141,7 @@ def per_token_group_quant_fp8( x: The input tenosr with ndim >= 2. 
group_size: The group size used for quantization. eps: The minimum to avoid dividing zero. - dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + dtype: The dype of output tensor. Returns: Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. @@ -241,6 +242,132 @@ def sglang_per_token_group_quant_fp8( return x_q, x_s +def sglang_per_token_quant_fp8( + x: torch.Tensor, + dtype: torch.dtype = fp8_type_, +): + assert x.is_contiguous(), "`x` is not contiguous" + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_s = torch.empty( + x.shape[0], + 1, + device=x.device, + dtype=torch.float32, + ) + + sgl_per_token_quant_fp8(x, x_q, x_s) + + return x_q, x_s + + +@triton.jit +def _static_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + y_s_repeat_ptr, + # Stride of input + y_stride, + # Collums of input + N, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, + REPEAT_SCALE: tl.constexpr, +): + """A Triton-accelerated function to perform quantization using the given scale on a + tensor + + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + if REPEAT_SCALE: + y_s_repeat_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + y_s = tl.load(y_s_ptr).to(tl.float32) + y_s_inv = 1.0 / y_s + y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + if REPEAT_SCALE: + tl.store(y_s_repeat_ptr, y_s) + + +def static_quant_fp8( + x: torch.Tensor, + x_s: torch.Tensor, + repeat_scale: bool = False, + dtype: torch.dtype = fp8_type_, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform static quantization using the given scale on an input tensor `x`. + + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + + Args: + x: The input tenosr with ndim >= 2. + x_s: The quantization scale. + repeat_scale: Whether to broadcast per-tensor scale to per-channel scale. + dtype: The dype of output tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
+ """ + assert x.is_contiguous(), "`x` is not contiguous" + assert x_s.numel() == 1, "only supports per-tensor scale" + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + if is_hip_: + fp8_max = 224.0 + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // x.shape[-1] + N = x.shape[-1] + if repeat_scale: + x_s_repeat = torch.empty( + (M, 1), + device=x.device, + dtype=torch.float32, + ) + else: + x_s_repeat = None + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _static_quant_fp8[(M,)]( + x, + x_q, + x_s, + x_s_repeat, + N, + N, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + REPEAT_SCALE=repeat_scale, + num_warps=num_warps, + num_stages=num_stages, + ) + x_s = x_s_repeat if repeat_scale else x_s + return x_q, x_s + + @triton.jit def _w8a8_block_fp8_matmul( # Pointers to inputs and output diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index ff10f0a5632..feaae26f6c7 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -2,13 +2,23 @@ from typing import List, Optional, Tuple import torch +from packaging.version import Version -from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, + static_quant_fp8, w8a8_block_fp8_matmul, ) -from sglang.srt.utils import get_bool_env_var, is_hip +from sglang.srt.utils import ( + get_bool_env_var, + get_cuda_version, + get_device_capability, + is_hip, +) + +use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get( + "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False +) is_hip_ = is_hip() if is_hip_ and get_bool_env_var("CK_MOE"): @@ -18,6 +28,25 @@ if _is_cuda: from sgl_kernel import fp8_blockwise_scaled_mm + from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8 + + if use_vllm_cutlass_w8a8_fp8_kernel: + from vllm import _custom_ops as ops + else: + from sgl_kernel import fp8_scaled_mm + + +def cutlass_fp8_supported(): + if not _is_cuda: + return False + major, minor = get_device_capability() + cuda_version = get_cuda_version() + if major >= 9: + return cuda_version >= (12, 0) + elif major == 8 and minor == 9: + return cuda_version >= (12, 4) + return False + def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, @@ -158,10 +187,121 @@ def block_quant_to_tensor_quant( return x_q_tensor, scale -class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): - """ - Parameter class for weight scales loaded for weights with - block-wise quantization. Uses both column and row parallelism. 
- """ +def apply_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_fp8_supported: bool = True, + use_per_token_if_dynamic: bool = False, +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[1]] + + # cutlass w8a8 fp8 sgl-kernel only supports per-token scale + if input_scale is not None: + assert input_scale.numel() == 1 + # broadcast per-tensor scale to per-token scale when supporting cutlass + qinput, x_scale = static_quant_fp8( + input_2d, input_scale, repeat_scale=cutlass_fp8_supported + ) + else: + # default use per-token quantization if dynamic + if _is_cuda: + qinput, x_scale = sglang_per_token_quant_fp8(input_2d) + else: + qinput, x_scale = per_token_group_quant_fp8( + input_2d, group_size=input_2d.shape[1] + ) + + if cutlass_fp8_supported: + if use_vllm_cutlass_w8a8_fp8_kernel: + # Fall back to vllm cutlass w8a8 fp8 kernel + output = ops.cutlass_scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + else: + assert ( + weight_scale.numel() == weight.shape[1] + ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" + output = fp8_scaled_mm( + qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias + ) + return output.view(*output_shape) + + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + else: + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 + + if per_tensor_weights and per_tensor_activations: + # Fused GEMM_DQ + output = torch._scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + + # GEMM + # This computes C = (X * W). 
+ # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm( + qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) - pass + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index a28e0aeea04..c26012da21e 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -7,7 +7,7 @@ from torch.nn.parameter import Parameter from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, + convert_to_channelwise, cutlass_fp8_supported, requantize_with_max_scale, ) @@ -19,6 +19,7 @@ QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear # Initialize logger for the module logger = logging.getLogger(__name__) @@ -161,6 +162,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight, layer.weight_scale, layer.logical_widths ) layer.weight = Parameter(quantized_weight.t(), requires_grad=False) + # cutlass sgl-kernel only supports per-channel scale + if self.cutlass_fp8_supported: + max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py new file mode 100644 index 00000000000..0adedc68fcd --- /dev/null +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -0,0 +1,126 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.linear import LinearMethodBase +from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.fp8_utils import ( + apply_fp8_linear, + cutlass_fp8_supported, + normalize_e4m3fn_to_e4m3fnuz, +) +from sglang.srt.utils import is_hip + + +class W8A8Fp8Config(QuantizationConfig): + """Config class for W8A8 FP8 Quantization. 
+ + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_name(self) -> str: + return "w8a8_fp8" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + return W8A8Fp8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W8A8Fp8LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: W8A8Fp8Config): + self.cutlass_fp8_supported = cutlass_fp8_supported() + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = layer.weight + weight_scale = layer.weight_scale.detach() + if is_hip(): + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + self.logical_widths = output_partition_sizes + + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + return apply_fp8_linear( + x, + layer.weight, + layer.weight_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c5b8b920e7f..4e6fbdd49ed 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -405,6 +405,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "gguf", "modelopt", "w8a8_int8", + "w8a8_fp8", ], help="The quantization method.", ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1ce2862f963..8bfdbc0ed26 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -52,11 +52,13 @@ import zmq from fastapi.responses import ORJSONResponse from packaging import version as pkg_version +from packaging.version import Version, parse from starlette.routing import Mount from torch import nn from torch.func import functional_call from torch.library import Library from torch.profiler import ProfilerActivity, profile, record_function +from torch.utils.cpp_extension import CUDA_HOME from triton.runtime.cache import ( FileCacheManager, 
default_cache_dir, @@ -1431,6 +1433,12 @@ def rank0_print(msg: str): print(msg, flush=True) +def get_cuda_version(): + if torch.version.cuda: + return tuple(map(int, torch.version.cuda.split("."))) + return (0, 0) + + def launch_dummy_health_check_server(host, port): import uvicorn from fastapi import FastAPI, Response diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index 3a02531e695..b3da7690ce7 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -7,6 +7,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, + static_quant_fp8, w8a8_block_fp8_matmul, ) @@ -63,7 +64,7 @@ def _per_token_group_quant_fp8(self, num_tokens, d, dtype, group_size, seed): out, scale = per_token_group_quant_fp8(x, group_size) self.assertTrue( - torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15) + torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20) ) self.assertTrue(torch.allclose(scale, ref_scale)) @@ -85,6 +86,71 @@ def test_per_token_group_quant_fp8(self): self._per_token_group_quant_fp8(*params) +# For test +def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn): + """Function to perform static quantization on an input tensor `x` using native torch. + + It converts the tensor values into float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + """ + assert x.is_contiguous(), "`x` is not contiguous" + assert x_s.numel() == 1, "only supports per-tensor scale" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1]) + x_s_inv = 1.0 / x_s + x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + + return x_q, x_s + + +class TestStaticQuantFP8(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16, torch.float32] + NUM_TOKENS = [7, 83, 2048] + D = [512, 4096, 5120, 13824] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _static_quant_fp8(self, num_tokens, d, dtype, seed): + torch.manual_seed(seed) + + x = torch.rand(num_tokens, d, dtype=dtype) + fp8_max = torch.finfo(torch.float8_e4m3fn).max + x_s = x.max() / fp8_max + + with torch.inference_mode(): + ref_out, _ = native_static_quant_fp8(x, x_s) + out, _ = static_quant_fp8(x, x_s, repeat_scale=True) + + self.assertTrue( + torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50) + ) + + def test_static_quant_fp8(self): + for params in itertools.product( + self.NUM_TOKENS, + self.D, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + d=params[1], + dtype=params[2], + seed=params[3], + ): + self._static_quant_fp8(*params) + + # For test def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): """This function performs matrix multiplication with block-wise quantization using native torch. 
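For reference, the fallback path of apply_fp8_linear above dequantizes with C = s_x * s_w * (X_q @ W_q) + bias, using a dynamic per-token activation scale and a static per-channel weight scale, which is also the scheme the new W8A8Fp8Config declares. The sketch below reproduces that math in plain PyTorch under a few assumptions: the helper names (quant_fp8_per_token, quant_fp8_per_channel, w8a8_fp8_linear_reference) are illustrative and not part of sglang, the quantized matmul is emulated in fp32 instead of calling torch._scaled_mm or the sgl-kernel cutlass path, and a torch build that ships the float8_e4m3fn dtype (>= 2.1) is assumed.

import torch

# Illustrative reference only; not sglang's implementation.
FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max
FP8_MIN = torch.finfo(FP8).min


def quant_fp8_per_token(x: torch.Tensor):
    # Dynamic, per-token, symmetric: one scale per row of the activation matrix.
    x_scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / FP8_MAX
    x_q = (x.float() / x_scale).clamp(FP8_MIN, FP8_MAX).to(FP8)
    return x_q, x_scale  # [M, K] fp8, [M, 1] fp32


def quant_fp8_per_channel(w: torch.Tensor):
    # Static, per-channel, symmetric: one scale per output channel of the weight.
    w_scale = w.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / FP8_MAX
    w_q = (w.float() / w_scale).clamp(FP8_MIN, FP8_MAX).to(FP8)
    return w_q, w_scale  # [N, K] fp8, [N, 1] fp32


def w8a8_fp8_linear_reference(x, w, bias=None, out_dtype=torch.bfloat16):
    # C = s_x * s_w * (X_q @ W_q^T) + bias; the fp8 matmul is emulated in fp32 here.
    x_q, x_scale = quant_fp8_per_token(x)
    w_q, w_scale = quant_fp8_per_channel(w)
    out = (x_q.float() @ w_q.float().t()) * x_scale * w_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)


if __name__ == "__main__":
    x = torch.randn(4, 64, dtype=torch.bfloat16)
    w = torch.randn(128, 64, dtype=torch.bfloat16)
    ref = x.float() @ w.float().t()
    out = w8a8_fp8_linear_reference(x, w)
    # Expect only a small quantization error relative to the fp32 reference.
    print((out.float() - ref).abs().max())

With the rest of this patch applied, the same scheme is selected at serving time through the new "w8a8_fp8" choice added to the --quantization flag in server_args.py above.
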
From 34c8898755be67de5d379cdceb0173f6a2b5265c Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Sun, 9 Mar 2025 01:10:43 -0800 Subject: [PATCH 26/27] Check eagle server args (#4217) --- python/sglang/srt/server_args.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4e6fbdd49ed..480a415e8cd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -284,9 +284,13 @@ def __post_init__(self): "Overlap scheduler are disabled because of using " "eagle speculative decoding." ) - # The token generated from the verify step is counted. + # The token generated from the verify step is counted in speculative_num_draft_tokens. # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded. - # assert self.speculative_num_steps < self.speculative_num_draft_tokens + assert self.speculative_num_steps < self.speculative_num_draft_tokens + assert ( + self.speculative_num_draft_tokens - 1 + <= self.speculative_num_steps * self.speculative_eagle_topk + ) # GGUF if ( From df84ab2a5b87f4e8490049beb74fab6e67bbe3df Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 9 Mar 2025 01:16:05 -0800 Subject: [PATCH 27/27] update sgl-kernel 3rdparty (#4228) --- .gitmodules | 3 --- sgl-kernel/3rdparty/cutlass | 2 +- sgl-kernel/3rdparty/turbomind | 1 - sgl-kernel/README.md | 1 - sgl-kernel/setup.py | 3 --- 5 files changed, 1 insertion(+), 9 deletions(-) delete mode 160000 sgl-kernel/3rdparty/turbomind diff --git a/.gitmodules b/.gitmodules index 97f3421449d..ed7603bfd3c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,6 +7,3 @@ [submodule "sgl-kernel/3rdparty/flashinfer"] path = sgl-kernel/3rdparty/flashinfer url = https://github.com/flashinfer-ai/flashinfer.git -[submodule "sgl-kernel/3rdparty/turbomind"] - path = sgl-kernel/3rdparty/turbomind - url = https://github.com/InternLM/turbomind diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass index ca4fdbea708..df18f5e4f5d 160000 --- a/sgl-kernel/3rdparty/cutlass +++ b/sgl-kernel/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit ca4fdbea708ad940c905359788372b8add9f85e0 +Subproject commit df18f5e4f5de76bed8be1de8e4c245f2f5ec3020 diff --git a/sgl-kernel/3rdparty/turbomind b/sgl-kernel/3rdparty/turbomind deleted file mode 160000 index 0c9d0c724a9..00000000000 --- a/sgl-kernel/3rdparty/turbomind +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0c9d0c724a99974ca3af0c12b24ef8a0444c4fd9 diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 689f34be0b5..e86c2625963 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -39,7 +39,6 @@ Third-party libraries: - [CCCL](https://github.com/NVIDIA/cccl) - [CUTLASS](https://github.com/NVIDIA/cutlass) - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) -- [TurboMind](https://github.com/InternLM/turbomind) ### Kernel Development diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 72d710b3dc3..d76a2668a88 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -52,7 +52,6 @@ def _get_version(): cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" -turbomind = root / "3rdparty" / "turbomind" include_dirs = [ root / "include", root / "csrc", @@ -62,8 +61,6 @@ def _get_version(): flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "csrc", "cublas", - turbomind.resolve(), - 
turbomind.resolve() / "src", ] nvcc_flags = [