diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml
index 6e505957685..03406ef86ff 100644
--- a/.github/workflows/pr-test-amd.yml
+++ b/.github/workflows/pr-test-amd.yml
@@ -35,12 +35,12 @@ jobs:
else
DEVICE_FLAG="--device /dev/dri"
fi
- docker pull lmsysorg/sglang:v0.4.3.post2-rocm630
+ docker pull lmsysorg/sglang:v0.4.3.post4-rocm630
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
--cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
-w /sglang-checkout --name ci_sglang \
- lmsysorg/sglang:v0.4.3.post2-rocm630
+ lmsysorg/sglang:v0.4.3.post4-rocm630
- name: Install dependencies
run: |
@@ -71,12 +71,12 @@ jobs:
else
DEVICE_FLAG="--device /dev/dri"
fi
- docker pull lmsysorg/sglang:v0.4.3.post2-rocm630
+ docker pull lmsysorg/sglang:v0.4.3.post4-rocm630
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
--cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
-w /sglang-checkout --name ci_sglang \
- lmsysorg/sglang:v0.4.3.post2-rocm630
+ lmsysorg/sglang:v0.4.3.post4-rocm630
- name: Install dependencies
run: |
@@ -90,11 +90,11 @@ jobs:
- name: MLA TEST
timeout-minutes: 20
run: |
- docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py TestMLA
+ docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py
finish:
needs: [
- accuracy-test-1-gpu-amd
+ accuracy-test-1-gpu-amd, mla-test-1-gpu-amd
]
runs-on: ubuntu-latest
steps:
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index df059c1f402..0c38901f05a 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -27,7 +27,7 @@ jobs:
with:
source: sgl-kernel
extensions: h,c,cpp,hpp,cu,cuh,cc
- clangFormatVersion: 16
+ clangFormatVersion: 18
style: file
build-wheels:
@@ -95,8 +95,39 @@ jobs:
run: |
pip3 uninstall sgl-kernel -y
+ mla-test:
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+ needs: build-wheels
+ runs-on: 1-gpu-runner
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Download artifacts
+ uses: actions/download-artifact@v4
+ with:
+ path: sgl-kernel/dist/
+ merge-multiple: true
+ pattern: wheel-*
+
+ - name: Install
+ run: |
+ bash scripts/ci_install_dependency.sh
+ pip3 uninstall sgl-kernel -y || true
+ pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
+ pip3 list | grep sgl-kernel
+
+ - name: Run test
+ timeout-minutes: 30
+ run: |
+ cd test/srt
+ python3 test_mla_deepseek_v3.py
+
+ - name: Uninstall dependencies
+ run: |
+ pip3 uninstall sgl-kernel -y
+
finish:
- needs: [unit-test, lint]
+ needs: [unit-test, mla-test, lint]
runs-on: ubuntu-latest
steps:
- name: Finish
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 225c215c8c9..5ac06597327 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -269,6 +269,8 @@ jobs:
cd test/srt
python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
+ USE_VLLM_CUSTOM_ALLREDUCE=0 python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
+
- name: Benchmark single latency + torch.compile (TP=2)
timeout-minutes: 10
run: |
diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml
index 495bf68c8b2..f589119e61a 100644
--- a/.github/workflows/release-pypi-kernel.yml
+++ b/.github/workflows/release-pypi-kernel.yml
@@ -5,7 +5,7 @@ on:
branches:
- main
paths:
- - sgl-kernel/src/sgl-kernel/version.py
+ - sgl-kernel/python/sgl_kernel/version.py
workflow_dispatch:
concurrency:
diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml
index 5eaa0127fa7..631551475fe 100644
--- a/.github/workflows/release-whl-kernel.yml
+++ b/.github/workflows/release-whl-kernel.yml
@@ -9,7 +9,7 @@ on:
branches:
- main
paths:
- - sgl-kernel/src/sgl-kernel/version.py
+ - sgl-kernel/python/sgl_kernel/version.py
jobs:
build-wheels:
@@ -59,7 +59,7 @@ jobs:
id: set_tag_name
run: |
if [ -z "${{ inputs.tag_name }}" ]; then
- TAG_NAME="v$(cat sgl-kernel/src/sgl-kernel/version.py | cut -d'"' -f2)"
+ TAG_NAME="v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'"' -f2)"
echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT
else
echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT
diff --git a/.gitmodules b/.gitmodules
index 97f3421449d..ed7603bfd3c 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,6 +7,3 @@
[submodule "sgl-kernel/3rdparty/flashinfer"]
path = sgl-kernel/3rdparty/flashinfer
url = https://github.com/flashinfer-ai/flashinfer.git
-[submodule "sgl-kernel/3rdparty/turbomind"]
- path = sgl-kernel/3rdparty/turbomind
- url = https://github.com/InternLM/turbomind
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 3ae74a8cccb..075b1e8d92c 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -30,44 +30,19 @@ ARG CUDA_VERSION
RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \
&& git clone --depth=1 https://github.com/sgl-project/sglang.git \
&& if [ "$CUDA_VERSION" = "12.1.1" ]; then \
- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \
+ export CUINDEX=121; \
elif [ "$CUDA_VERSION" = "12.4.1" ]; then \
- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \
+ export CUINDEX=124; \
elif [ "$CUDA_VERSION" = "12.5.1" ]; then \
- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \
+ export CUINDEX=124; \
elif [ "$CUDA_VERSION" = "11.8.0" ]; then \
- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \
- python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \
+ export CUINDEX=118; \
+ python3 -m pip install --no-cache-dir sgl-kernel -i https://docs.sglang.ai/whl/cu118; \
else \
echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \
fi \
+ && python3 -m pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu${CUINDEX} \
&& cd sglang \
- && if [ "$BUILD_TYPE" = "srt" ]; then \
- if [ "$CUDA_VERSION" = "12.1.1" ]; then \
- python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \
- elif [ "$CUDA_VERSION" = "12.4.1" ]; then \
- python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
- elif [ "$CUDA_VERSION" = "12.5.1" ]; then \
- python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
- elif [ "$CUDA_VERSION" = "11.8.0" ]; then \
- python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \
- python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \
- else \
- echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \
- fi; \
- else \
- if [ "$CUDA_VERSION" = "12.1.1" ]; then \
- python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \
- elif [ "$CUDA_VERSION" = "12.4.1" ]; then \
- python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
- elif [ "$CUDA_VERSION" = "12.5.1" ]; then \
- python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
- elif [ "$CUDA_VERSION" = "11.8.0" ]; then \
- python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \
- python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \
- else \
- echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \
- fi; \
- fi
+ && python3 -m pip --no-cache-dir install -e "python[${BUILD_TYPE}]" --find-links https://flashinfer.ai/whl/cu${CUINDEX}/torch2.5/flashinfer-python
ENV DEBIAN_FRONTEND=interactive
diff --git a/docker/k8s-sglang-distributed-sts.yaml b/docker/k8s-sglang-distributed-sts.yaml
new file mode 100644
index 00000000000..6b81d9b14df
--- /dev/null
+++ b/docker/k8s-sglang-distributed-sts.yaml
@@ -0,0 +1,104 @@
+# Two-node SGLang example
+
+apiVersion: apps/v1
+kind: StatefulSet
+metadata:
+ name: distributed-sglang
+spec:
+ replicas: 2 # number of nodes/pods to run distributed sglang
+ selector:
+ matchLabels:
+ app: distributed-sglang
+ serviceName: ""
+ template:
+ metadata:
+ labels:
+ app: distributed-sglang
+ spec:
+ containers:
+ - name: sglang-container
+ image: docker.io/lmsysorg/sglang:latest
+ imagePullPolicy: Always # image may be replaced by official CI versioned image
+ command:
+ - /bin/bash
+ - -c
+      # Modify the SGLang serving arguments below as necessary.
+      # NOTE: --expert-parallel-size and --enable-ep-moe are for MoE models such as DeepSeek-R1.
+ args:
+ - |
+ python3 -m sglang.launch_server \
+ --model /llm-folder \
+ --dist-init-addr sglang-master-pod:5000 \
+ --tensor-parallel-size 16 \
+ --nnodes 2 \
+ --node-rank $POD_INDEX \
+ --trust-remote-code \
+ --host 0.0.0.0 \
+ --port 8000 \
+ --enable-metrics \
+ --enable-ep-moe \
+ --expert-parallel-size 16
+ env:
+ - name: POD_INDEX # reflects the node-rank
+ valueFrom:
+ fieldRef:
+ apiVersion: v1
+ fieldPath: metadata.labels['apps.kubernetes.io/pod-index']
+ - name: NCCL_DEBUG
+ value: INFO
+ resources:
+ limits:
+ nvidia.com/gpu: "8"
+ requests:
+ volumeMounts:
+ - mountPath: /dev/shm
+ name: dshm
+ - mountPath: /llm-folder
+ name: llm
+ securityContext:
+          privileged: true # needed to use RDMA/InfiniBand devices; works together with hostNetwork: true
+ hostNetwork: true
+ volumes:
+ - emptyDir:
+ medium: Memory
+ sizeLimit: 10Gi
+ name: dshm
+ - hostPath:
+          path: /llm-folder # replace with a PVC or a hostPath that contains your model weights
+ type: DirectoryOrCreate
+ name: llm
+ #- persistentVolumeClaim:
+ # claimName: llm-pvc
+ # name: llm
+---
+apiVersion: v1
+kind: Service
+metadata:
+ name: sglang-master-pod
+spec:
+ type: ClusterIP
+ selector:
+ app: distributed-sglang
+ apps.kubernetes.io/pod-index: "0"
+ ports:
+ - name: dist-port
+ port: 5000
+ targetPort: 5000
+---
+# the serving service
+apiVersion: v1
+kind: Service
+metadata:
+ name: sglang-serving-on-master
+spec:
+ type: NodePort
+ selector:
+ app: distributed-sglang
+ apps.kubernetes.io/pod-index: "0"
+ ports:
+ - name: serving
+ port: 8000
+ targetPort: 8000
+ - name: metrics
+ port: 8080
+ targetPort: 8080
diff --git a/docs/backend/separate_reasoning.ipynb b/docs/backend/separate_reasoning.ipynb
index d9a927c19de..756ecbaa995 100644
--- a/docs/backend/separate_reasoning.ipynb
+++ b/docs/backend/separate_reasoning.ipynb
@@ -11,7 +11,8 @@
"## Supported Models\n",
"\n",
"Currently, SGLang supports the following reasoning models:\n",
-    "- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `<think>` and `</think>` tags."
+    "- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `<think>` and `</think>` tags.\n",
+    "- [QwQ](https://huggingface.co/Qwen/QwQ-32B): The reasoning content is wrapped with `<think>` and `</think>` tags."
]
},
{
@@ -55,6 +56,15 @@
"wait_for_server(f\"http://localhost:{port}\")"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that `--reasoning-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
+ "\n",
+    "- `deepseek-r1`: DeepSeek R1 series and QwQ (e.g., deepseek-ai/DeepSeek-R1, Qwen/QwQ-32B)."
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 2b6836d5c2d..6289fa35791 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -84,7 +84,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
-- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
+- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by FlashInfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long-input scenarios, FlashInfer MLA can improve performance significantly; optimized Triton kernels are used when FlashInfer MLA is turned off. Currently, when the FlashInfer MLA wrapper is used together with speculative decoding, the `speculative_eagle_topk` parameter must be set to 1 (see the example launch command below).
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
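+
+A minimal launch sketch combining the FlashInfer MLA wrapper with EAGLE speculative decoding (illustrative only: the draft model path is a placeholder, and the speculative-decoding flags may vary across SGLang versions):
+
+```bash
+python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code \
+  --enable-flashinfer-mla \
+  --speculative-algorithm EAGLE --speculative-draft-model-path <your-draft-model> \
+  --speculative-num-steps 3 --speculative-num-draft-tokens 4 \
+  --speculative-eagle-topk 1  # must be 1 when the flashinfer mla wrapper is enabled
+```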
diff --git a/docs/references/general.rst b/docs/references/general.rst
index fedb2be764d..8ea335d84b1 100644
--- a/docs/references/general.rst
+++ b/docs/references/general.rst
@@ -11,3 +11,4 @@ General Guidance
faq.md
learn_more.md
modelscope.md
+ production_metrics.md
diff --git a/docs/start/install.md b/docs/start/install.md
index fe460e044b3..f7234c0a660 100644
--- a/docs/start/install.md
+++ b/docs/start/install.md
@@ -98,7 +98,21 @@ drun v0.4.3.post4-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --in
2. Execute the command `docker compose up -d` in your terminal.
-## Method 5: Run on Kubernetes or Clouds with SkyPilot
+## Method 5: Using Kubernetes
+
+<details>
+<summary>More</summary>
+
+1. Option 1: Single-node serving (typically when the model fits into the GPUs of one node).
+   Run `kubectl apply -f docker/k8s-sglang-service.yaml` to create the Kubernetes deployment and service, using llama-31-8b as an example.
+
+2. Option 2: Multi-node serving (usually when a large model requires more than one GPU node, such as `DeepSeek-R1`).
+   Modify the model path and serving arguments as necessary, then run `kubectl apply -f docker/k8s-sglang-distributed-sts.yaml` to create a two-node StatefulSet and its serving service. You can verify the deployment with the quick check below.
+
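+A quick way to check the multi-node deployment once both pods are scheduled (a sketch; pod names, readiness timing, and the NodePort assigned to `sglang-serving-on-master` will differ in your cluster):
+
+```bash
+kubectl get pods -l app=distributed-sglang      # expect distributed-sglang-0 and distributed-sglang-1 to reach Running
+kubectl logs distributed-sglang-0 --tail=50     # rank-0 log; the server is ready once it listens on port 8000
+kubectl port-forward pod/distributed-sglang-0 8000:8000 &
+curl http://127.0.0.1:8000/get_model_info
+```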
+
+</details>
+
+## Method 6: Run on Kubernetes or Clouds with SkyPilot
More
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 31a9c9893bf..b084c45d07f 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -18,12 +18,15 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
[project.optional-dependencies]
runtime_common = [
"aiohttp",
+ "datasets",
"decord",
"fastapi",
"hf_transfer",
"huggingface_hub",
"interegular",
+ "llguidance>=0.6.15",
"modelscope",
+ "ninja",
"orjson",
"packaging",
"pillow",
@@ -33,18 +36,15 @@ runtime_common = [
"python-multipart",
"pyzmq>=25.1.2",
"torchao>=0.7.0",
+ "transformers @ git+https://github.com/huggingface/transformers.git@v4.49.0-AyaVision",
"uvicorn",
"uvloop",
"xgrammar==0.1.14",
- "ninja",
- "transformers @ git+https://github.com/huggingface/transformers.git@84f0186",
- "llguidance>=0.6.15",
- "datasets"
]
srt = [
"sglang[runtime_common]",
- "sgl-kernel==0.0.3.post6",
+ "sgl-kernel==0.0.4",
"flashinfer_python==0.2.2.post1",
"torch==2.5.1",
"vllm>=0.6.4.post1,<=0.7.2",
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index db22f42c46c..036d1e86499 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -1006,7 +1006,7 @@ async def limited_request_func(request_func_input, pbar):
# Flush cache
if "sglang" in backend:
- requests.post(base_url + "/flush_cache")
+ requests.post(base_url + "/flush_cache", headers=get_auth_headers())
time.sleep(1.0)
diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index c5056ffc272..d06765c3a8c 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -75,42 +75,42 @@ def init_custom_ar(
rank: int,
full_nvlink: bool,
) -> int:
- return sgl_kernel.ops.allreduce.init_custom_ar(
+ return sgl_kernel.allreduce.init_custom_ar(
meta, rank_data, handles, offsets, rank, full_nvlink
)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
- sgl_kernel.ops.allreduce.all_reduce_reg(fa, inp, out)
+ sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
def all_reduce_unreg(
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None:
- sgl_kernel.ops.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
+ sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
def dispose(fa: int) -> None:
- sgl_kernel.ops.allreduce.dispose(fa)
+ sgl_kernel.allreduce.dispose(fa)
def meta_size() -> int:
- return sgl_kernel.ops.allreduce.meta_size()
+ return sgl_kernel.allreduce.meta_size()
def register_buffer(
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
) -> None:
- return sgl_kernel.ops.allreduce.register_buffer(fa, t, handles, offsets)
+ return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
- return sgl_kernel.ops.allreduce.get_graph_buffer_ipc_meta(fa)
+ return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[str], offsets: List[List[int]]
) -> None:
- sgl_kernel.ops.allreduce.register_graph_buffers(fa, handles, offsets)
+ sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
def allocate_meta_buffer(size: int) -> torch.Tensor:
- return sgl_kernel.ops.allreduce.allocate_meta_buffer(size)
+ return sgl_kernel.allreduce.allocate_meta_buffer(size)
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
- return sgl_kernel.ops.allreduce.get_meta_buffer_ipc_handle(inp)
+ return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
else:
# TRTLLM custom allreduce
@@ -123,7 +123,7 @@ def init_custom_ar(
barrier_in: List[int],
barrier_out: List[int],
) -> int:
- return sgl_kernel.ops.init_custom_reduce(
+ return sgl_kernel.init_custom_reduce(
rank_id,
world_size,
rank_data_base,
@@ -134,15 +134,15 @@ def init_custom_ar(
)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
- sgl_kernel.ops.custom_reduce(fa, inp, out)
+ sgl_kernel.custom_reduce(fa, inp, out)
def dispose(fa: int) -> None:
- sgl_kernel.ops.custom_dispose(fa)
+ sgl_kernel.custom_dispose(fa)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
- return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+ return sgl_kernel.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
- sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+ sgl_kernel.register_graph_buffers(fa, handles, offsets)
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 64ef15cf7fb..489cc6d4b05 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -81,7 +81,7 @@ def __init__(
if context_length is not None:
if context_length > derived_context_len:
if get_bool_env_var(
- "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
+ "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
):
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
@@ -250,9 +250,11 @@ def _verify_quantization(self) -> None:
"compressed-tensors",
"experts_int8",
"w8a8_int8",
+ "w8a8_fp8",
]
compatible_quantization_methods = {
- "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
+ "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
+ "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
}
if self.quantization is not None:
self.quantization = self.quantization.lower()
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 7c0f287b7d0..f8a6b4e431f 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -106,6 +106,8 @@ def __init__(self, **kwargs):
tokenizer_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
)
+
+ self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.scheduler_info = scheduler_info
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 9e81acc6f90..9af027bd1fd 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -11,9 +11,10 @@
from dataclasses import dataclass
from functools import partial
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
+import triton
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -23,6 +24,7 @@
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
@@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
def __init__(
self,
model_runner: ModelRunner,
+ skip_prefill: bool = False,
+ kv_indptr_buf: Optional[torch.Tensor] = None,
+ q_indptr_decode_buf: Optional[torch.Tensor] = None,
):
super().__init__()
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
+ self.skip_prefill = skip_prefill
global_config.enable_flashinfer_mla = True
@@ -78,35 +84,51 @@ def __init__(
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
- self.kv_indptr = torch.zeros(
- (max_bs + 1,), dtype=torch.int32, device=model_runner.device
- )
+ if kv_indptr_buf is None:
+ self.kv_indptr = torch.zeros(
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+ )
+ else:
+ self.kv_indptr = kv_indptr_buf
- self.qo_indptr = torch.zeros(
- (max_bs + 1,), dtype=torch.int32, device=model_runner.device
- )
+ if not self.skip_prefill:
+ self.qo_indptr = torch.zeros(
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+ )
- self.q_indptr_decode = torch.arange(
- 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
- )
+ if q_indptr_decode_buf is None:
+ self.q_indptr_decode = torch.arange(
+ 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+ )
+ else:
+ self.q_indptr_decode = q_indptr_decode_buf
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
- self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- backend="auto",
- )
+ if not self.skip_prefill:
+ self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
+ self.workspace_buffer,
+ backend="auto",
+ )
+
+ # FlashinferMLA backend uses mla wrapper for target verify
+ self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
+ self.workspace_buffer,
+ backend="auto",
+ )
self.decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer, backend="auto"
)
# Create indices updater
- self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
- model_runner, self
- )
+ if not skip_prefill:
+ self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
+ model_runner, self
+ )
+
self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
model_runner, self
)
@@ -114,7 +136,7 @@ def __init__(
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
- self.prefill_cuda_graph_metadata = {}
+ self.prefill_cuda_graph_metadata = {} # For verify
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
@@ -126,6 +148,28 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
init_metadata_replay=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
+ elif forward_batch.forward_mode.is_draft_extend():
+ self.indices_updater_prefill.update(
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=self.prefill_wrapper_paged,
+ use_ragged=False,
+ spec_info=forward_batch.spec_info,
+ )
+ self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
+ elif forward_batch.forward_mode.is_target_verify():
+ self.indices_updater_prefill.update(
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=self.prefill_wrapper_verify,
+ use_ragged=False,
+ spec_info=forward_batch.spec_info,
+ )
+ self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
else:
prefix_lens = forward_batch.extend_prefix_lens
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
@@ -202,10 +246,33 @@ def init_forward_metadata_capture_cuda_graph(
seq_lens_sum,
decode_wrapper=decode_wrapper,
init_metadata_replay=False,
+ spec_info=spec_info,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
+ elif forward_mode.is_target_verify():
+ verify_wrapper = BatchMLAPagedAttentionWrapper(
+ self.workspace_buffer,
+ use_cuda_graph=True,
+ qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
+ kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
+ kv_indices=self.cuda_graph_kv_indices,
+ kv_len_arr=self.cuda_graph_kv_lens[:bs],
+ backend="auto",
+ )
+ seq_lens_sum = seq_lens.sum().item()
+ self.indices_updater_prefill.update(
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=verify_wrapper,
+ use_ragged=False,
+ spec_info=spec_info,
+ )
+ self.prefill_cuda_graph_metadata[bs] = verify_wrapper
+ self.forward_metadata = PrefillMetadata(verify_wrapper, False)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -221,6 +288,7 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
+ assert seq_lens_cpu is not None
kv_len_arr_cpu = seq_lens_cpu[:bs]
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
kv_len_arr_cpu, dim=0
@@ -239,8 +307,19 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_sum,
decode_wrapper=self.decode_cuda_graph_metadata[bs],
init_metadata_replay=True,
+ spec_info=spec_info,
**self.fast_decode_kwargs,
)
+ elif forward_mode.is_target_verify():
+ self.indices_updater_prefill.update(
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
+ use_ragged=False,
+ spec_info=spec_info,
+ )
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
@@ -254,7 +333,7 @@ def forward_extend(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
- save_kv_cache=True,
+ save_kv_cache: bool = True,
):
cache_loc = forward_batch.out_cache_loc
@@ -297,7 +376,7 @@ def forward_decode(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
- save_kv_cache=True,
+ save_kv_cache: bool = True,
):
decode_wrapper = self.forward_metadata.decode_wrapper
cache_loc = forward_batch.out_cache_loc
@@ -349,6 +428,7 @@ def update(
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -360,6 +440,7 @@ def update(
self.q_indptr,
self.kv_indptr,
init_metadata_replay,
+ spec_info,
**fast_decode_kwargs,
)
@@ -372,30 +453,33 @@ def call_begin_forward(
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
q_indptr = q_indptr[: bs + 1]
- kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
- kv_indptr = kv_indptr[: bs + 1]
- kv_indices = (
- torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
- if not init_metadata_replay
- else fast_decode_kwargs["kv_indices"]
- )
-
kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling
+ if spec_info is None:
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = (
+ torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+ if not init_metadata_replay
+ else fast_decode_kwargs["kv_indices"]
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.shape[1],
+ )
+ else:
+ kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
- create_flashinfer_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices,
- paged_kernel_lens,
- kv_indptr,
- None,
- kv_indices,
- self.req_to_token.shape[1],
- )
if not init_metadata_replay:
wrapper.plan(
q_indptr,
@@ -457,6 +541,7 @@ def update(
prefix_lens: torch.Tensor,
prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
use_ragged: bool,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
if use_ragged:
paged_kernel_lens = prefix_lens
@@ -476,6 +561,7 @@ def update(
self.kv_indptr,
self.qo_indptr,
use_ragged,
+ spec_info,
)
def call_begin_forward(
@@ -490,29 +576,46 @@ def call_begin_forward(
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
- bs = len(req_pool_indices)
- kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
- kv_indptr = kv_indptr[: bs + 1]
- kv_indices = torch.empty(
- paged_kernel_lens_sum,
- dtype=torch.int32,
- device=req_pool_indices.device,
- )
- create_flashinfer_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices,
- paged_kernel_lens,
- kv_indptr,
- None,
- kv_indices,
- self.req_to_token.shape[1],
- )
-
- qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
- qo_indptr = qo_indptr[: bs + 1]
+ bs = len(seq_lens)
sm_scale = self.scaling
+ if spec_info is None:
+ assert len(seq_lens) == len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum,
+ dtype=torch.int32,
+ device=req_pool_indices.device,
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.shape[1],
+ )
+ qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+ qo_indptr = qo_indptr[: bs + 1]
+ custom_mask = None
+ else:
+ assert isinstance(spec_info, EagleDraftInput) or isinstance(
+ spec_info, EagleVerifyInput
+ )
+ # TODO: Support topk > 1 with custom mask
+ kv_indices, kv_indptr, qo_indptr, custom_mask = (
+ spec_info.generate_attn_arg_prefill(
+ req_pool_indices,
+ paged_kernel_lens,
+ paged_kernel_lens_sum,
+ self.req_to_token,
+ )
+ )
+
if use_ragged:
# ragged prefill
wrapper_ragged.begin_forward(
@@ -543,6 +646,163 @@ def call_begin_forward(
)
+class FlashInferMLAMultiStepDraftBackend:
+ """
+ Wrap multiple flashinfer mla attention backends as one for multiple consecutive
+ draft decoding steps.
+ """
+
+ def __init__(
+ self,
+ model_runner: ModelRunner,
+ topk: int,
+ speculative_num_steps: int,
+ ):
+ from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+ if topk > 1:
+ raise ValueError(
+                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
+ )
+ self.topk = topk
+ self.speculative_num_steps = speculative_num_steps
+ self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+
+ max_bs = model_runner.req_to_token_pool.size * self.topk
+ self.kv_indptr = torch.zeros(
+ (
+ self.speculative_num_steps,
+ max_bs + 1,
+ ),
+ dtype=torch.int32,
+ device=model_runner.device,
+ )
+ self.q_indptr_decode = torch.arange(
+ 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+ )
+
+ self.attn_backends = []
+ for i in range(self.speculative_num_steps):
+ self.attn_backends.append(
+ FlashInferMLAAttnBackend(
+ model_runner,
+ skip_prefill=True,
+ kv_indptr_buf=self.kv_indptr[i],
+ q_indptr_decode_buf=self.q_indptr_decode,
+ )
+ )
+
+ self.max_context_len = self.attn_backends[0].max_context_len
+
+ # Cached variables for generate_draft_decode_kv_indices
+ self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+ def common_template(
+ self,
+ forward_batch: ForwardBatch,
+ kv_indices_buffer: torch.Tensor,
+ call_fn: Callable,
+ ):
+ num_seqs = forward_batch.batch_size
+ bs = self.topk * num_seqs
+ seq_lens_sum = forward_batch.seq_lens_sum
+
+ self.generate_draft_decode_kv_indices[
+ (self.speculative_num_steps, num_seqs, self.topk)
+ ](
+ forward_batch.req_pool_indices,
+ forward_batch.req_to_token_pool.req_to_token,
+ forward_batch.seq_lens,
+ kv_indices_buffer,
+ self.kv_indptr,
+ forward_batch.positions,
+ num_seqs,
+ self.topk,
+ self.pool_len,
+ kv_indices_buffer.shape[1],
+ self.kv_indptr.shape[1],
+ triton.next_power_of_2(num_seqs),
+ triton.next_power_of_2(self.speculative_num_steps),
+ triton.next_power_of_2(bs),
+ )
+
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ for i in range(self.speculative_num_steps - 1):
+ forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+ forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+ : seq_lens_sum * self.topk + bs * (i + 1)
+ ]
+ call_fn(i, forward_batch)
+
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
+ kv_indices = torch.zeros(
+ (
+ self.speculative_num_steps,
+ forward_batch.batch_size * self.topk * self.max_context_len,
+ ),
+ dtype=torch.int32,
+ device="cuda",
+ )
+
+ def call_fn(i, forward_batch):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+ forward_batch.spec_info.kv_indptr = (
+ forward_batch.spec_info.kv_indptr.clone()
+ )
+ forward_batch.spec_info.kv_indices = (
+ forward_batch.spec_info.kv_indices.clone()
+ )
+ self.attn_backends[i].init_forward_metadata(forward_batch)
+
+ self.common_template(forward_batch, kv_indices, call_fn)
+
+ def init_cuda_graph_state(self, max_bs: int):
+ self.cuda_graph_kv_indices = torch.zeros(
+ (self.speculative_num_steps, max_bs * self.max_context_len),
+ dtype=torch.int32,
+ device="cuda",
+ )
+
+ for i in range(self.speculative_num_steps):
+ self.attn_backends[i].init_cuda_graph_state(
+ max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+ )
+
+ def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+ def call_fn(i, forward_batch):
+ self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+ forward_batch.batch_size,
+ forward_batch.batch_size * self.topk,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ )
+
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+ def init_forward_metadata_replay_cuda_graph(
+ self, forward_batch: ForwardBatch, bs: int
+ ):
+ def call_fn(i, forward_batch):
+ self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+ bs,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ seq_lens_sum=-1,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+ )
+
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
def fast_mla_decode_plan(
self,
qo_indptr_cpu: torch.Tensor,
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index a9d72618085..b942dee5cf5 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -6,9 +6,7 @@
import triton
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
- create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 919bcced3a8..85748fa7434 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -18,6 +18,7 @@
)
from sglang.srt.layers.parameter import (
BasevLLMParameter,
+ BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
@@ -27,7 +28,6 @@
QuantizationConfig,
QuantizeMethodBase,
)
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
from sglang.srt.utils import set_weight_attrs
logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py
index 78be6798254..b3fc6b440c4 100644
--- a/python/sglang/srt/layers/parameter.py
+++ b/python/sglang/srt/layers/parameter.py
@@ -16,6 +16,7 @@
"ModelWeightParameter",
"ChannelQuantScaleParameter",
"GroupQuantScaleParameter",
+ "BlockQuantScaleParameter",
"PackedColumnParameter",
"RowvLLMParameter",
]
@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
pass
+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+ """
+ Parameter class for weight scales loaded for weights with
+ block-wise quantization. Uses both column and row parallelism.
+ """
+
+ pass
+
+
class PerTensorScaleParameter(BasevLLMParameter):
"""
Parameter class for scales where the number of scales is
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 1ef8f43816f..c09fb5a1a00 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -28,6 +28,7 @@
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -50,6 +51,7 @@
"qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
"w8a8_int8": W8A8Int8Config,
+ "w8a8_fp8": W8A8Fp8Config,
}
diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py
index 1470ca427b5..ce526cd6a9b 100644
--- a/python/sglang/srt/layers/quantization/blockwise_int8.py
+++ b/python/sglang/srt/layers/quantization/blockwise_int8.py
@@ -13,12 +13,11 @@
LinearMethodBase,
UnquantizedLinearMethod,
)
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.utils import set_weight_attrs
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index e296756b54b..44a3cba8ad0 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -16,9 +16,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
- apply_fp8_linear,
convert_to_channelwise,
- cutlass_fp8_supported,
per_tensor_dequantize,
requantize_with_max_scale,
)
@@ -29,14 +27,21 @@
LinearMethodBase,
UnquantizedLinearMethod,
)
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import (
+ BlockQuantScaleParameter,
+ ModelWeightParameter,
+ PerTensorScaleParameter,
+)
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_utils import (
- BlockQuantScaleParameter,
+ apply_fp8_linear,
apply_w8a8_block_fp8_linear,
+ cutlass_fp8_supported,
+ input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import (
@@ -305,15 +310,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
- qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-
- # If using marlin (w8a16), kernel uses channelwise weights,
- # so extend the weight scales to be channelwise.
- if self.use_marlin:
- assert weight_scale.numel() == 1
- weight_scale = convert_to_channelwise(
- weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+ if self.cutlass_fp8_supported or self.use_marlin:
+                # Apply per-channel quantization by default, as the cutlass sgl-kernel and marlin only support per-channel scales.
+ qweight, weight_scale = per_token_group_quant_fp8(
+ layer.weight, layer.weight.shape[-1]
)
+ weight_scale = weight_scale.t().contiguous()
+ else:
+ # per-tensor quantization
+ qweight, weight_scale = input_to_float8(layer.weight)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
@@ -330,23 +335,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
- # If using marlin (w8a16), kernel uses channelwise weights,
- # so extend the weight scales to be channelwise.
- if self.use_marlin:
+
+ # cutlass sgl-kernel and marlin only support per-channel scale
+ if self.cutlass_fp8_supported or self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
-
- # If using w8a8, torch._scaled_mm needs per tensor, so
- # requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
-
# If ROCm, normalize the weights and scales to e4m3fnuz
- if is_hip_:
+ if is_hip():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 47f310a24de..54c07f90940 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -29,7 +29,7 @@
_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
- from sgl_kernel import sgl_per_token_group_quant_fp8
+ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
logger = logging.getLogger(__name__)
@@ -70,7 +70,8 @@ def _per_token_group_quant_fp8(
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
- y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+ y_s_inv = 1.0 / y_s
+ y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@@ -140,7 +141,7 @@ def per_token_group_quant_fp8(
        x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
- dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+        dtype: The dtype of the output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -241,6 +242,132 @@ def sglang_per_token_group_quant_fp8(
return x_q, x_s
+def sglang_per_token_quant_fp8(
+ x: torch.Tensor,
+ dtype: torch.dtype = fp8_type_,
+):
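+    """Per-token dynamic FP8 quantization: one scale per row of `x`.
+
+    Thin wrapper around the `sgl_per_token_quant_fp8` kernel from sgl-kernel;
+    returns the quantized FP8 tensor and a float32 scale of shape (x.shape[0], 1).
+    """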
+ assert x.is_contiguous(), "`x` is not contiguous"
+
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+ x_s = torch.empty(
+ x.shape[0],
+ 1,
+ device=x.device,
+ dtype=torch.float32,
+ )
+
+ sgl_per_token_quant_fp8(x, x_q, x_s)
+
+ return x_q, x_s
+
+
+@triton.jit
+def _static_quant_fp8(
+ # Pointers to inputs and output
+ y_ptr,
+ y_q_ptr,
+ y_s_ptr,
+ y_s_repeat_ptr,
+ # Stride of input
+ y_stride,
+    # Columns of input
+ N,
+ # Information for float8
+ fp8_min,
+ fp8_max,
+ # Meta-parameters
+ BLOCK: tl.constexpr,
+ REPEAT_SCALE: tl.constexpr,
+):
+    """A Triton-accelerated function to quantize a tensor using a given per-tensor scale.
+
+ This function converts the tensor values into float8 values.
+ """
+ # Map the program id to the row of X and Y it should compute.
+ g_id = tl.program_id(0)
+ y_ptr += g_id * y_stride
+ y_q_ptr += g_id * y_stride
+ if REPEAT_SCALE:
+ y_s_repeat_ptr += g_id
+
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
+ mask = cols < N
+
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+ y_s = tl.load(y_s_ptr).to(tl.float32)
+ y_s_inv = 1.0 / y_s
+ y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
+ if REPEAT_SCALE:
+ tl.store(y_s_repeat_ptr, y_s)
+
+
+def static_quant_fp8(
+ x: torch.Tensor,
+ x_s: torch.Tensor,
+ repeat_scale: bool = False,
+ dtype: torch.dtype = fp8_type_,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Function to perform static quantization using the given scale on an input tensor `x`.
+
+ It converts the tensor values into signed float8 values and returns the
+ quantized tensor along with the scaling factor used for quantization.
+
+ Args:
+        x: The input tensor with ndim >= 2.
+        x_s: The quantization scale.
+        repeat_scale: Whether to broadcast the per-tensor scale to a per-token scale.
+        dtype: The dtype of the output tensor.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
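+
+    Example (an illustrative sketch; assumes a CUDA device with FP8 support):
+        x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
+        scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
+        x_q, x_s = static_quant_fp8(x, scale, repeat_scale=True)
+        # x_q: (4, 128) tensor in the FP8 dtype; x_s: (4, 1), the scale repeated per token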
+ """
+ assert x.is_contiguous(), "`x` is not contiguous"
+ assert x_s.numel() == 1, "only supports per-tensor scale"
+ finfo = torch.finfo(dtype)
+ fp8_max = finfo.max
+
+ if is_hip_:
+ fp8_max = 224.0
+
+ fp8_min = -fp8_max
+
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+ M = x.numel() // x.shape[-1]
+ N = x.shape[-1]
+ if repeat_scale:
+ x_s_repeat = torch.empty(
+ (M, 1),
+ device=x.device,
+ dtype=torch.float32,
+ )
+ else:
+ x_s_repeat = None
+
+ BLOCK = triton.next_power_of_2(N)
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK // 256, 1), 8)
+ num_stages = 1
+ _static_quant_fp8[(M,)](
+ x,
+ x_q,
+ x_s,
+ x_s_repeat,
+ N,
+ N,
+ fp8_min=fp8_min,
+ fp8_max=fp8_max,
+ BLOCK=BLOCK,
+ REPEAT_SCALE=repeat_scale,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ x_s = x_s_repeat if repeat_scale else x_s
+ return x_q, x_s
+
+
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index ff10f0a5632..feaae26f6c7 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -2,13 +2,23 @@
from typing import List, Optional, Tuple
import torch
+from packaging.version import Version
-from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
+ static_quant_fp8,
w8a8_block_fp8_matmul,
)
-from sglang.srt.utils import get_bool_env_var, is_hip
+from sglang.srt.utils import (
+ get_bool_env_var,
+ get_cuda_version,
+ get_device_capability,
+ is_hip,
+)
+
+use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get(
+ "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
+    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
+)
is_hip_ = is_hip()
if is_hip_ and get_bool_env_var("CK_MOE"):
@@ -18,6 +28,25 @@
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm
+ from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+
+ if use_vllm_cutlass_w8a8_fp8_kernel:
+ from vllm import _custom_ops as ops
+ else:
+ from sgl_kernel import fp8_scaled_mm
+
+
+def cutlass_fp8_supported():
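+    """Return whether cutlass FP8 GEMM kernels are supported on the current device,
+    based on compute capability and CUDA version: Hopper (SM90+) needs CUDA >= 12.0,
+    and Ada (SM89) needs CUDA >= 12.4.
+    """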
+ if not _is_cuda:
+ return False
+ major, minor = get_device_capability()
+ cuda_version = get_cuda_version()
+ if major >= 9:
+ return cuda_version >= (12, 0)
+ elif major == 8 and minor == 9:
+ return cuda_version >= (12, 4)
+ return False
+
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
@@ -158,10 +187,121 @@ def block_quant_to_tensor_quant(
return x_q_tensor, scale
-class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
- """
- Parameter class for weight scales loaded for weights with
- block-wise quantization. Uses both column and row parallelism.
- """
+def apply_fp8_linear(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ input_scale: Optional[torch.Tensor] = None,
+ input_scale_ub: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ cutlass_fp8_supported: bool = True,
+ use_per_token_if_dynamic: bool = False,
+) -> torch.Tensor:
+ # View input as 2D matrix for fp8 methods
+ input_2d = input.view(-1, input.shape[-1])
+ output_shape = [*input.shape[:-1], weight.shape[1]]
+
+ # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+ if input_scale is not None:
+ assert input_scale.numel() == 1
+ # broadcast per-tensor scale to per-token scale when supporting cutlass
+ qinput, x_scale = static_quant_fp8(
+ input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+ )
+ else:
+ # default use per-token quantization if dynamic
+ if _is_cuda:
+ qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+ else:
+ qinput, x_scale = per_token_group_quant_fp8(
+ input_2d, group_size=input_2d.shape[1]
+ )
+
+ if cutlass_fp8_supported:
+ if use_vllm_cutlass_w8a8_fp8_kernel:
+ # Fall back to vllm cutlass w8a8 fp8 kernel
+ output = ops.cutlass_scaled_mm(
+ qinput,
+ weight,
+ out_dtype=input.dtype,
+ scale_a=x_scale,
+ scale_b=weight_scale,
+ bias=bias,
+ )
+ else:
+ assert (
+ weight_scale.numel() == weight.shape[1]
+ ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+ output = fp8_scaled_mm(
+ qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
+ )
+ return output.view(*output_shape)
+
+ # torch.scaled_mm supports per tensor weights + activations only
+ # so fallback to naive if per channel or per token
+ else:
+ per_tensor_weights = weight_scale.numel() == 1
+ per_tensor_activations = x_scale.numel() == 1
+
+ if per_tensor_weights and per_tensor_activations:
+ # Fused GEMM_DQ
+ output = torch._scaled_mm(
+ qinput,
+ weight,
+ out_dtype=input.dtype,
+ scale_a=x_scale,
+ scale_b=weight_scale,
+ bias=bias,
+ )
+ # A fix for discrepancy in scaled_mm which returns tuple
+ # for torch < 2.5 and a single value in torch >= 2.5
+ if type(output) is tuple and len(output) == 2:
+ output = output[0]
+
+ return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+
+ else:
+ # Fallback for channelwise case, where we use unfused DQ
+ # due to limitations with scaled_mm
+
+ # Symmetric quantized GEMM by definition computes the following:
+ # C = (s_x * X) (s_w * W) + bias
+ # This is equivalent to dequantizing the weights and activations
+ # before applying a GEMM.
+ #
+ # In order to compute quantized operands, a quantized kernel
+ # will rewrite the above like so:
+ # C = s_w * s_x * (X * W) + bias
+ #
+ # For the scaled_mm fallback case, we break this down, since it
+ # does not support s_w being a vector.
+
+ # Making sure the dummy tensor is on the same device as the weight
+ global TORCH_DEVICE_IDENTITY
+ if TORCH_DEVICE_IDENTITY.device != weight.device:
+ TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
+ # GEMM
+ # This computes C = (X * W).
+ # Output in fp32 to allow subsequent ops to happen in-place
+ output = torch._scaled_mm(
+ qinput,
+ weight,
+ scale_a=TORCH_DEVICE_IDENTITY,
+ scale_b=TORCH_DEVICE_IDENTITY,
+ out_dtype=torch.float32,
+ )
+ # A fix for discrepancy in scaled_mm which returns tuple
+ # for torch < 2.5 and a single value in torch >= 2.5
+ if type(output) is tuple and len(output) == 2:
+ output = output[0]
+ # Unpad (undo num_token_padding)
+ output = torch.narrow(output, 0, 0, input_2d.shape[0])
+ x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
- pass
+ # DQ
+ # C = sw * sx * (X * W) + bias
+ output = output * x_scale * weight_scale.t()
+ if bias is not None:
+ output = output + bias
+ return output.to(dtype=input.dtype).view(*output_shape)
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index a28e0aeea04..c26012da21e 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -7,7 +7,7 @@
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- apply_fp8_linear,
+ convert_to_channelwise,
cutlass_fp8_supported,
requantize_with_max_scale,
)
@@ -19,6 +19,7 @@
QuantizationConfig,
QuantizeMethodBase,
)
+from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
# Initialize logger for the module
logger = logging.getLogger(__name__)
@@ -161,6 +162,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight, layer.weight_scale, layer.logical_widths
)
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+ # cutlass sgl-kernel only supports per-channel scale
+ if self.cutlass_fp8_supported:
+ max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py
new file mode 100644
index 00000000000..0adedc68fcd
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py
@@ -0,0 +1,126 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+ QuantizationConfig,
+ QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import (
+ apply_fp8_linear,
+ cutlass_fp8_supported,
+ normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.utils import is_hip
+
+
+class W8A8Fp8Config(QuantizationConfig):
+ """Config class for W8A8 FP8 Quantization.
+
+ - Weight: static, per-channel, symmetric
+ - Activation: dynamic, per-token, symmetric
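+
+    Typically selected via the `w8a8_fp8` entry in QUANTIZATION_METHODS (e.g. the
+    `--quantization w8a8_fp8` server argument). `get_min_capability` reports a
+    minimum compute capability of 8.9 (Ada or newer).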
+ """
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+ return [torch.float16, torch.bfloat16]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ return 89
+
+ @classmethod
+    def get_name(cls) -> str:
+ return "w8a8_fp8"
+
+ @classmethod
+ def get_config_filenames(cls) -> List[str]:
+ return []
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
+ return cls()
+
+ def get_quant_method(
+ self,
+ layer: torch.nn.Module,
+ prefix: str,
+ ) -> Optional["QuantizeMethodBase"]:
+ from sglang.srt.layers.linear import LinearBase
+
+ if isinstance(layer, LinearBase):
+ return W8A8Fp8LinearMethod(self)
+ return None
+
+ def get_scaled_act_names(self) -> List[str]:
+ return []
+
+
+class W8A8Fp8LinearMethod(LinearMethodBase):
+
+ def __init__(self, quantization_config: W8A8Fp8Config):
+ self.cutlass_fp8_supported = cutlass_fp8_supported()
+ self.quantization_config = quantization_config
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ weight = layer.weight
+ weight_scale = layer.weight_scale.detach()
+ if is_hip():
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=weight, weight_scale=weight_scale
+ )
+ layer.weight = Parameter(weight.t(), requires_grad=False)
+ layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: List[int],
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs
+ ):
+
+ weight_loader = extra_weight_attrs.get("weight_loader")
+ self.logical_widths = output_partition_sizes
+
+ weight = ModelWeightParameter(
+ data=torch.empty(
+ sum(output_partition_sizes),
+ input_size_per_partition,
+ dtype=torch.float8_e4m3fn,
+ ),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("weight", weight)
+
+ weight_scale = ChannelQuantScaleParameter(
+ data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("weight_scale", weight_scale)
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ ):
+ return apply_fp8_linear(
+ x,
+ layer.weight,
+ layer.weight_scale,
+ bias=bias,
+ cutlass_fp8_supported=self.cutlass_fp8_supported,
+ )
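
W8A8Fp8LinearMethod expects checkpoints with a static, per-channel, symmetric weight scale (weight_scale of shape [out_features, 1]) and quantizes activations per token at runtime. A hedged sketch of producing such a weight/scale pair offline; the helper name is illustrative, not part of the sglang API, and it requires a torch build with float8_e4m3fn:

```python
# Illustrative offline weight quantization matching the expected layout:
# symmetric per-channel scale of shape [out_features, 1].
import torch

def quantize_weight_per_channel_fp8(w: torch.Tensor):
    # w: [out_features, in_features], higher precision
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / finfo.max  # [out, 1]
    w_q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_q, scale.to(torch.float32)

w = torch.randn(128, 64)
w_q, w_s = quantize_weight_per_channel_fp8(w)
print(w_q.dtype, w_s.shape)  # torch.float8_e4m3fn torch.Size([128, 1])
```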
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index f471626e1f7..ec041305c7b 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -42,7 +42,6 @@ def forward(
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
- batch_next_token_ids: Optional[torch.Tensor] = None,
):
"""Run a sampler & compute logprobs and update logits_output accordingly.
@@ -72,8 +71,7 @@ def forward(
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
- if batch_next_token_ids is None:
- batch_next_token_ids = torch.argmax(logits, -1)
+ batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
@@ -94,43 +92,39 @@ def forward(
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
).clamp(min=torch.finfo(probs.dtype).min)
- if batch_next_token_ids is None:
- max_top_k_round, batch_size = 32, probs.shape[0]
- uniform_samples = torch.rand(
- (max_top_k_round, batch_size), device=probs.device
+ max_top_k_round, batch_size = 32, probs.shape[0]
+ uniform_samples = torch.rand(
+ (max_top_k_round, batch_size), device=probs.device
+ )
+ if sampling_info.need_min_p_sampling:
+ probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+ probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+ batch_next_token_ids = min_p_sampling_from_probs(
+ probs, uniform_samples, sampling_info.min_ps
)
- if sampling_info.need_min_p_sampling:
- probs = top_k_renorm_prob(probs, sampling_info.top_ks)
- probs = top_p_renorm_prob(probs, sampling_info.top_ps)
- batch_next_token_ids = min_p_sampling_from_probs(
- probs, uniform_samples, sampling_info.min_ps
- )
- else:
- batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
- probs,
- uniform_samples,
- sampling_info.top_ks,
- sampling_info.top_ps,
- filter_apply_order="joint",
- )
-
- if self.use_nan_detection and not torch.all(success):
- logger.warning("Detected errors during sampling!")
- batch_next_token_ids = torch.zeros_like(
- batch_next_token_ids
- )
-
- elif global_server_args_dict["sampling_backend"] == "pytorch":
- if batch_next_token_ids is None:
- # A slower fallback implementation with torch native operations.
- batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+ else:
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs,
+ uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
- sampling_info.min_ps,
- sampling_info.need_min_p_sampling,
+ filter_apply_order="joint",
)
+ if self.use_nan_detection and not torch.all(success):
+ logger.warning("Detected errors during sampling!")
+ batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
+ elif global_server_args_dict["sampling_backend"] == "pytorch":
+ # A slower fallback implementation with torch native operations.
+ batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+ probs,
+ sampling_info.top_ks,
+ sampling_info.top_ps,
+ sampling_info.min_ps,
+ sampling_info.need_min_p_sampling,
+ )
+
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(
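
With the batch_next_token_ids shortcut removed, the sampler always draws tokens itself; on the pytorch fallback backend that means top-k/top-p (and optionally min-p) filtering followed by multinomial sampling. A hedged torch-native sketch of that filter-then-sample step, not the flashinfer kernels used on the main path:

```python
# Torch-native top-k/top-p sampling sketch (per-request scalars for brevity).
import torch

def sample_top_k_top_p(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # top-k: drop everything past the k-th candidate
    sorted_probs[..., top_k:] = 0.0
    # top-p: drop candidates once the cumulative mass before them exceeds top_p
    cumsum = sorted_probs.cumsum(dim=-1)
    sorted_probs[(cumsum - sorted_probs) > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)

probs = torch.softmax(torch.randn(2, 32000), dim=-1)
print(sample_top_k_top_p(probs, top_k=50, top_p=0.9))
```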
diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py
index 07fe11d23bf..7b76f90e52e 100644
--- a/python/sglang/srt/lora/backend/__init__.py
+++ b/python/sglang/srt/lora/backend/__init__.py
@@ -1,23 +1,20 @@
-from .base_backend import BaseLoRABackend
-from .flashinfer_backend import FlashInferLoRABackend
-from .triton_backend import TritonLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
def get_backend_from_name(name: str) -> BaseLoRABackend:
"""
Get corresponding backend class from backend's name
"""
- backend_mapping = {
- "triton": TritonLoRABackend,
- "flashinfer": FlashInferLoRABackend,
- }
+ if name == "triton":
+ from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
- if name in backend_mapping:
- return backend_mapping[name]
+ return TritonLoRABackend
+ elif name == "flashinfer":
+ from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
- raise Exception(
- f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
- )
+ return FlashInferLoRABackend
+ else:
+ raise ValueError(f"Invalid backend: {name}")
__all__ = [
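
The eager backend registry is replaced by lazy imports so that merely importing the LoRA package no longer pulls in triton or flashinfer. A toy sketch of the same dispatch-on-demand pattern (the real function returns TritonLoRABackend or FlashInferLoRABackend):

```python
# Toy illustration of lazy, name-based backend dispatch; json/pickle stand in
# for the heavyweight backend modules.
def get_backend(name: str):
    if name == "json":
        import json as backend      # imported only when requested
    elif name == "pickle":
        import pickle as backend
    else:
        raise ValueError(f"Invalid backend: {name}")
    return backend

print(get_backend("json").dumps({"ok": True}))
```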
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 05bc8d730ea..a5c6a1dbdcd 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -957,7 +957,13 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.batch_is_full = False
+ # Filter batch
+ last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch()
+ if self.last_batch.batch_size() < last_bs:
+ self.batch_is_full = False
+
+ # Merge the new batch into the running batch
if not self.last_batch.is_empty():
if self.running_batch is None:
self.running_batch = self.last_batch
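
The added check resets batch_is_full whenever filtering shrank the last batch, because finished requests free scheduling capacity. A hedged sketch of that bookkeeping with a toy batch object standing in for ScheduleBatch:

```python
# Toy model of the batch_is_full reset added above.
class ToyBatch:
    def __init__(self, reqs):
        self.reqs = reqs

    def batch_size(self):
        return len(self.reqs)

    def filter_batch(self):
        self.reqs = [r for r in self.reqs if not r["finished"]]

batch_is_full = True
last_batch = ToyBatch([{"finished": True}, {"finished": False}])
last_bs = last_batch.batch_size()
last_batch.filter_batch()
if last_batch.batch_size() < last_bs:
    batch_is_full = False  # finished requests freed capacity
print(batch_is_full)  # False
```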
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 813fbf6fc1c..83c2d88f01f 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -300,10 +300,11 @@ def can_run(self, forward_batch: ForwardBatch):
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
+ # Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
- tqdm.tqdm(self.capture_bs)
+ tqdm.tqdm(list(reversed(self.capture_bs)))
if get_tensor_model_parallel_rank() == 0
- else self.capture_bs
+ else reversed(self.capture_bs)
)
for bs in capture_range:
with patch_model(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 666b97e2b8e..8040709a721 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -35,11 +35,6 @@
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
@@ -77,7 +72,6 @@
set_cpu_offload_max_bytes,
set_cuda_arch,
)
-from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -779,6 +773,10 @@ def init_cublas(self):
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
+ from sglang.srt.layers.attention.flashinfer_backend import (
+ FlashInferAttnBackend,
+ )
+
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
@@ -794,12 +792,26 @@ def init_attention_backend(self):
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
+ from sglang.srt.layers.attention.double_sparsity_backend import (
+ DoubleSparseAttnBackend,
+ )
+
self.attn_backend = DoubleSparseAttnBackend(self)
else:
+ from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
+ from sglang.srt.layers.attention.torch_native_backend import (
+ TorchNativeAttnBackend,
+ )
+
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
+ from sglang.srt.layers.attention.flashinfer_mla_backend import (
+ FlashInferMLAAttnBackend,
+ )
+
self.attn_backend = FlashInferMLAAttnBackend(self)
else:
raise ValueError(
@@ -928,45 +940,6 @@ def _preprocess_logits(
sampling_info.update_regex_vocab_mask()
sampling_info.apply_logits_bias(logits_output.next_token_logits)
- def update_output_logprobs(
- self,
- logits_output: LogitsProcessorOutput,
- sampling_info: SamplingBatchInfo,
- top_logprobs_nums: List[int],
- token_ids_logprobs: List[int],
- next_token_ids: torch.Tensor,
- *,
- num_tokens_per_req: List[int],
- ):
- """Update the logits_output's output logprob based on next_token_ids
-
- Args:
- logits_output: The logits output from the model forward
- sampling_info: Sampling info for logprob calculation
- top_logprobs_nums: Number of logprobs per request.
- next_token_ids: Next token ids.
- num_tokens_per_req: The number of tokens per request.
-
- Returns:
- A list of next_token_ids
- """
- self._preprocess_logits(logits_output, sampling_info)
- # We should repeat top_logprobs_nums to match num_tokens_per_req.
- top_logprobs_nums_repeat_interleaved = []
- token_ids_logprobs_repeat_interleaved = []
- for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
- top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
- for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
- token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
- self.sampler(
- logits_output,
- sampling_info,
- True,
- top_logprobs_nums_repeat_interleaved,
- token_ids_logprobs_repeat_interleaved,
- batch_next_token_ids=next_token_ids,
- )
-
def sample(
self,
logits_output: LogitsProcessorOutput,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 13544007e77..82c73ec94db 100755
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -555,6 +555,8 @@ def no_absorb() -> bool:
return (
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
diff --git a/python/sglang/srt/sampling/penaltylib/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py
index 69153462731..893a1c3775a 100644
--- a/python/sglang/srt/sampling/penaltylib/frequency_penalty.py
+++ b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py
@@ -56,7 +56,6 @@ def _filter(self, keep_indices: torch.Tensor):
]
def _merge(self, their: "BatchedFrequencyPenalizer"):
- print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
self.frequency_penalties = torch.cat(
[self.frequency_penalties, their.frequency_penalties], dim=0
)
diff --git a/python/sglang/srt/sampling/penaltylib/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/presence_penalty.py
index 91266b352fb..4f3a6ace3a0 100644
--- a/python/sglang/srt/sampling/penaltylib/presence_penalty.py
+++ b/python/sglang/srt/sampling/penaltylib/presence_penalty.py
@@ -56,7 +56,6 @@ def _filter(self, keep_indices: torch.Tensor):
]
def _merge(self, their: "BatchedPresencePenalizer"):
- print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
self.presence_penalties = torch.cat(
[self.presence_penalties, their.presence_penalties], dim=0
)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c5b8b920e7f..480a415e8cd 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -284,9 +284,13 @@ def __post_init__(self):
"Overlap scheduler are disabled because of using "
"eagle speculative decoding."
)
- # The token generated from the verify step is counted.
+ # The token generated from the verify step is counted in speculative_num_draft_tokens.
 # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
- # assert self.speculative_num_steps < self.speculative_num_draft_tokens
+ assert self.speculative_num_steps < self.speculative_num_draft_tokens
+ assert (
+ self.speculative_num_draft_tokens - 1
+ <= self.speculative_num_steps * self.speculative_eagle_topk
+ )
# GGUF
if (
@@ -405,6 +409,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"gguf",
"modelopt",
"w8a8_int8",
+ "w8a8_fp8",
],
help="The quantization method.",
)
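
The two new assertions bound valid speculative-decoding configurations: the draft must run fewer steps than it emits draft tokens, and a tree with speculative_num_steps levels and speculative_eagle_topk branches per level can supply at most steps * topk draft tokens plus the verified bonus token. A quick numeric check under assumed example values:

```python
# Numeric sanity check of the two assertions, with example settings.
speculative_eagle_topk = 4
speculative_num_steps = 3
# At most topk candidates per draft step, plus 1 bonus token from verify:
max_draft_tokens = speculative_num_steps * speculative_eagle_topk + 1  # 13
for speculative_num_draft_tokens in (4, 8, 13):
    assert speculative_num_steps < speculative_num_draft_tokens
    assert (
        speculative_num_draft_tokens - 1
        <= speculative_num_steps * speculative_eagle_topk
    )
print("valid up to", max_draft_tokens, "draft tokens")
```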
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 12da787eb31..90d47cc0fd3 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -7,6 +7,7 @@
from huggingface_hub import snapshot_download
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
@@ -122,6 +123,16 @@ def init_attention_backend(self):
self.topk,
self.speculative_num_steps,
)
+ elif self.server_args.attention_backend == "flashinfer_mla":
+ from sglang.srt.layers.attention.flashinfer_mla_backend import (
+ FlashInferMLAMultiStepDraftBackend,
+ )
+
+ self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+ self.model_runner,
+ self.topk,
+ self.speculative_num_steps,
+ )
else:
raise ValueError(
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
@@ -302,13 +313,10 @@ def draft_forward(self, forward_batch: ForwardBatch):
# Set inputs
forward_batch.input_ids = input_ids
+ out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
forward_batch.out_cache_loc = out_cache_loc[
- forward_batch.batch_size
- * self.topk
- * i : forward_batch.batch_size
- * self.topk
- * (i + 1)
- ]
+ :, self.topk * i : self.topk * (i + 1)
+ ].flatten()
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
@@ -353,42 +361,70 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
batch.spec_info = res.draft_input
if batch.return_logprob:
- # Compute output logprobs using the sampler.
- num_tokens_per_req = [
- accept + 1 for accept in res.accept_length_per_req_cpu
- ]
- self.target_worker.model_runner.update_output_logprobs(
- logits_output,
- batch.sampling_info,
- batch.top_logprobs_nums,
- batch.token_ids_logprobs,
- res.verified_id,
- # +1 for bonus token.
- num_tokens_per_req=num_tokens_per_req,
- )
-
- # Add output logprobs to the request.
- pt = 0
- # NOTE: tolist() of these values are skipped when output is processed
- next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
- verified_ids = res.verified_id.tolist()
- for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
- for _ in range(num_tokens):
- if req.return_logprob:
- token_id = verified_ids[pt]
- req.output_token_logprobs_val.append(next_token_logprobs[pt])
- req.output_token_logprobs_idx.append(token_id)
- if req.top_logprobs_num > 0:
- req.output_top_logprobs_val.append(
- res.logits_output.next_token_top_logprobs_val[pt]
- )
- req.output_top_logprobs_idx.append(
- res.logits_output.next_token_top_logprobs_idx[pt]
- )
- pt += 1
+ self.add_logprob_values(batch, res, logits_output)
return logits_output, res, model_worker_batch
+ def add_logprob_values(
+ self,
+ batch: ScheduleBatch,
+ res: EagleVerifyOutput,
+ logits_output: LogitsProcessorOutput,
+ ):
+ # Extract args
+ logits_output = res.logits_output
+ top_logprobs_nums = batch.top_logprobs_nums
+ token_ids_logprobs = batch.token_ids_logprobs
+ logprobs = torch.nn.functional.log_softmax(
+ logits_output.next_token_logits, dim=-1
+ )
+ batch_next_token_ids = res.verified_id
+ num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
+
+ # We should repeat top_logprobs_nums to match num_tokens_per_req.
+ top_logprobs_nums_repeat_interleaved = []
+ token_ids_logprobs_repeat_interleaved = []
+ for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+ top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+ for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+ token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+
+ # Extract logprobs
+ if any(x > 0 for x in top_logprobs_nums):
+ (
+ logits_output.next_token_top_logprobs_val,
+ logits_output.next_token_top_logprobs_idx,
+ ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
+
+ if any(x is not None for x in token_ids_logprobs):
+ (
+ logits_output.next_token_token_ids_logprobs_val,
+ logits_output.next_token_token_ids_logprobs_idx,
+ ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
+
+ logits_output.next_token_logprobs = logprobs[
+ torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
+ batch_next_token_ids,
+ ]
+
+ # Add output logprobs to the request.
+ pt = 0
+ next_token_logprobs = logits_output.next_token_logprobs.tolist()
+ verified_ids = batch_next_token_ids.tolist()
+ for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+ for _ in range(num_tokens):
+ if req.return_logprob:
+ req.output_token_logprobs_val.append(next_token_logprobs[pt])
+ req.output_token_logprobs_idx.append(verified_ids[pt])
+ if req.top_logprobs_num > 0:
+ req.output_top_logprobs_val.append(
+ res.logits_output.next_token_top_logprobs_val[pt]
+ )
+ req.output_top_logprobs_idx.append(
+ res.logits_output.next_token_top_logprobs_idx[pt]
+ )
+ pt += 1
+
def forward_draft_extend(
self,
batch: ScheduleBatch,
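
add_logprob_values recomputes logprobs at token granularity, so per-request settings such as top_logprobs_nums must be repeat-interleaved: each request contributes accept_length + 1 tokens (its accepted drafts plus the bonus token). A small sketch of that expansion with made-up numbers:

```python
# Expand per-request top-k counts to per-token counts, as in add_logprob_values.
accept_length_per_req_cpu = [2, 0, 3]   # accepted draft tokens per request
top_logprobs_nums = [5, 0, 1]           # per-request top-k logprob counts
num_tokens_per_req = [a + 1 for a in accept_length_per_req_cpu]  # [3, 1, 4]

per_token = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
    per_token.extend([num] * num_tokens)
print(per_token)  # [5, 5, 5, 0, 1, 1, 1, 1]
```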
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 1ce2862f963..8bfdbc0ed26 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -52,11 +52,13 @@
import zmq
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
+from packaging.version import Version, parse
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils.cpp_extension import CUDA_HOME
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
@@ -1431,6 +1433,12 @@ def rank0_print(msg: str):
print(msg, flush=True)
+def get_cuda_version():
+ if torch.version.cuda:
+ return tuple(map(int, torch.version.cuda.split(".")))
+ return (0, 0)
+
+
def launch_dummy_health_check_server(host, port):
import uvicorn
from fastapi import FastAPI, Response
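
get_cuda_version turns torch.version.cuda into an integer tuple so callers can do ordinary tuple comparisons against a minimum toolkit version. A hedged usage sketch; the comparison threshold is an example, not something the diff prescribes:

```python
# Usage sketch of get_cuda_version(): tuple comparison against a minimum.
import torch

def get_cuda_version():
    if torch.version.cuda:
        return tuple(map(int, torch.version.cuda.split(".")))
    return (0, 0)

if get_cuda_version() >= (12, 1):
    print("CUDA 12.1+ toolkit detected by torch")
else:
    print("older CUDA build, or CPU-only torch")
```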
diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py
index 3a02531e695..b3da7690ce7 100644
--- a/python/sglang/test/test_block_fp8.py
+++ b/python/sglang/test/test_block_fp8.py
@@ -7,6 +7,7 @@
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
+ static_quant_fp8,
w8a8_block_fp8_matmul,
)
@@ -63,7 +64,7 @@ def _per_token_group_quant_fp8(self, num_tokens, d, dtype, group_size, seed):
out, scale = per_token_group_quant_fp8(x, group_size)
self.assertTrue(
- torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
+ torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
)
self.assertTrue(torch.allclose(scale, ref_scale))
@@ -85,6 +86,71 @@ def test_per_token_group_quant_fp8(self):
self._per_token_group_quant_fp8(*params)
+# For test
+def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
+ """Function to perform static quantization on an input tensor `x` using native torch.
+
+ It converts the tensor values into float8 values and returns the
+ quantized tensor along with the scaling factor used for quantization.
+ """
+ assert x.is_contiguous(), "`x` is not contiguous"
+ assert x_s.numel() == 1, "only supports per-tensor scale"
+
+ finfo = torch.finfo(dtype)
+ fp8_min = finfo.min
+ fp8_max = finfo.max
+
+ x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
+ x_s_inv = 1.0 / x_s
+ x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
+ x_q = x_q.reshape(x.shape)
+
+ return x_q, x_s
+
+
+class TestStaticQuantFP8(unittest.TestCase):
+ DTYPES = [torch.half, torch.bfloat16, torch.float32]
+ NUM_TOKENS = [7, 83, 2048]
+ D = [512, 4096, 5120, 13824]
+ SEEDS = [0]
+
+ @classmethod
+ def setUpClass(cls):
+ if not torch.cuda.is_available():
+ raise unittest.SkipTest("CUDA is not available")
+ torch.set_default_device("cuda")
+
+ def _static_quant_fp8(self, num_tokens, d, dtype, seed):
+ torch.manual_seed(seed)
+
+ x = torch.rand(num_tokens, d, dtype=dtype)
+ fp8_max = torch.finfo(torch.float8_e4m3fn).max
+ x_s = x.max() / fp8_max
+
+ with torch.inference_mode():
+ ref_out, _ = native_static_quant_fp8(x, x_s)
+ out, _ = static_quant_fp8(x, x_s, repeat_scale=True)
+
+ self.assertTrue(
+ torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+ )
+
+ def test_static_quant_fp8(self):
+ for params in itertools.product(
+ self.NUM_TOKENS,
+ self.D,
+ self.DTYPES,
+ self.SEEDS,
+ ):
+ with self.subTest(
+ num_tokens=params[0],
+ d=params[1],
+ dtype=params[2],
+ seed=params[3],
+ ):
+ self._static_quant_fp8(*params)
+
+
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
diff --git a/python/upload_pypi.sh b/python/upload_pypi.sh
deleted file mode 100644
index 35616e1dad8..00000000000
--- a/python/upload_pypi.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-cp ../README.md ../LICENSE .
-rm -rf dist
-python3 -m build
-python3 -m twine upload dist/*
-
-rm -rf README.md LICENSE
diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index bccbd01df47..e4badb08d21 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -26,4 +26,4 @@ pip install "transformers @ git+https://github.com/huggingface/transformers.git@
pip install cuda-python nvidia-cuda-nvrtc-cu12
# reinstall sgl-kernel
-pip install sgl-kernel==0.0.3.post6 --force-reinstall --no-deps
+pip install sgl-kernel==0.0.4 --force-reinstall --no-deps
diff --git a/sgl-kernel/.clang-format b/sgl-kernel/.clang-format
index 5e690c02885..afbd654a790 100644
--- a/sgl-kernel/.clang-format
+++ b/sgl-kernel/.clang-format
@@ -6,3 +6,10 @@ DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
+AllowShortLoopsOnASingleLine: false
+BinPackParameters: false # Prevents packing parameters in declarations
+BinPackArguments: false # Prevents packing arguments in function calls
+AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
+AlignOperands: Align # Aligns arguments vertically
+PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
+PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass
index ca4fdbea708..df18f5e4f5d 160000
--- a/sgl-kernel/3rdparty/cutlass
+++ b/sgl-kernel/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit ca4fdbea708ad940c905359788372b8add9f85e0
+Subproject commit df18f5e4f5de76bed8be1de8e4c245f2f5ec3020
diff --git a/sgl-kernel/3rdparty/turbomind b/sgl-kernel/3rdparty/turbomind
deleted file mode 160000
index 0c9d0c724a9..00000000000
--- a/sgl-kernel/3rdparty/turbomind
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 0c9d0c724a99974ca3af0c12b24ef8a0444c4fd9
diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile
index 986e424f403..53375fa0fa3 100644
--- a/sgl-kernel/Makefile
+++ b/sgl-kernel/Makefile
@@ -38,12 +38,12 @@ test: ## Run all tests
format: check-deps ## Format all source files
@echo "Formatting source files..."
- @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
- @find src tests -name '*.py' | xargs isort
- @find src tests -name '*.py' | xargs black
+ @find csrc tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
+ @find python tests -name '*.py' | xargs isort
+ @find python tests -name '*.py' | xargs black
@pre-commit run --all-files
-FILES_TO_UPDATE = src/sgl-kernel/version.py \
+FILES_TO_UPDATE = python/sgl_kernel/version.py \
pyproject.toml
update: ## Update version numbers across project files. Usage: make update
@@ -51,7 +51,7 @@ update: ## Update version numbers across project files. Usage: make update "; \
exit 1; \
fi
- @OLD_VERSION=$$(grep "version" src/sgl-kernel/version.py | cut -d '"' -f2); \
+ @OLD_VERSION=$$(grep "version" python/sgl_kernel/version.py | cut -d '"' -f2); \
NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
for file in $(FILES_TO_UPDATE); do \
diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md
index 1f805cbd000..e86c2625963 100644
--- a/sgl-kernel/README.md
+++ b/sgl-kernel/README.md
@@ -39,18 +39,16 @@ Third-party libraries:
- [CCCL](https://github.com/NVIDIA/cccl)
- [CUTLASS](https://github.com/NVIDIA/cutlass)
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
-- [TurboMind](https://github.com/InternLM/turbomind)
### Kernel Development
Steps to add a new kernel:
-1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
-2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
-3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
-4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
-5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
-6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
+1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc)
+2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h)
+3. Create torch extension in [csrc/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/torch_extension.cc)
+4. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
+5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)
### Build & Install
@@ -72,4 +70,4 @@ The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, t
### Release new version
-Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
+Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py)
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip b/sgl-kernel/csrc/allreduce/custom_all_reduce.hip
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip
rename to sgl-kernel/csrc/allreduce/custom_all_reduce.hip
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh b/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
similarity index 84%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
rename to sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
index 06173bc4225..7baf5f01ef4 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
+++ b/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
@@ -153,19 +153,20 @@ DINLINE O downcast(array_t val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template
-DINLINE void start_sync(const RankSignals& sg,
+DINLINE void start_sync(
+ const RankSignals& sg,
#ifndef USE_ROCM
- volatile
+ volatile
#endif
- Signal* self_sg,
- int rank) {
+ Signal* self_sg,
+ int rank) {
#ifdef USE_ROCM
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
- __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED,
- __MEMORY_SCOPE_SYSTEM);
+ __scoped_atomic_store_n(
+ &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
flag)
@@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template
-DINLINE void end_sync(const RankSignals& sg,
+DINLINE void end_sync(
+ const RankSignals& sg,
#ifndef USE_ROCM
- volatile
+ volatile
#endif
- Signal* self_sg,
- int rank) {
+ Signal* self_sg,
+ int rank) {
#ifdef USE_ROCM
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
@@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg,
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
- __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
- final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM);
+ __scoped_atomic_store_n(
+ &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
+ flag,
+ final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
+ __MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
- while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
- final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag)
+ while (__scoped_atomic_load_n(
+ &self_sg->end[blockIdx.x][threadIdx.x],
+ final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
+ __MEMORY_SCOPE_DEVICE) < flag)
;
}
__syncthreads();
@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
}
template
-__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
+__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
+ RankData* _dp,
+ RankSignals sg,
#ifndef USE_ROCM
- volatile
+ volatile
#endif
- Signal* self_sg,
- T* __restrict__ result, int rank, int size) {
+ Signal* self_sg,
+ T* __restrict__ result,
+ int rank,
+ int size) {
using P = typename packed_t::P;
using A = typename packed_t::A;
// note: we don't reorder the address so the accumulation order is the same
@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
}
template
-__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
+__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
+ RankData* _dp,
+ RankSignals sg,
#ifndef USE_ROCM
- volatile
+ volatile
#endif
- Signal* self_sg,
- T* __restrict__ result, int rank, int size) {
+ Signal* self_sg,
+ T* __restrict__ result,
+ int rank,
+ int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t::P;
@@ -357,8 +372,14 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
- CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, const hipIpcMemHandle_t* handles,
- const std::vector& offsets, int rank, bool full_nvlink = true)
+ CustomAllreduce(
+ Signal* meta,
+ void* rank_data,
+ size_t rank_data_sz,
+ const hipIpcMemHandle_t* handles,
+ const std::vector& offsets,
+ int rank,
+ bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
@@ -382,8 +403,8 @@ class CustomAllreduce {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
- CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle),
- hipIpcMemLazyEnablePeerAccess));
+ CUDACHECK(hipIpcOpenMemHandle(
+ (void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
@@ -399,13 +420,14 @@ class CustomAllreduce {
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
- if (hipPointerGetAttribute(&base_ptr,
+ if (hipPointerGetAttribute(
+ &base_ptr,
#ifdef USE_ROCM
- HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+ HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
- CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+ CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
- (hipDeviceptr_t)ptr) != hipSuccess)
+ (hipDeviceptr_t)ptr) != hipSuccess)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
@@ -415,8 +437,8 @@ class CustomAllreduce {
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
- throw std::runtime_error("Rank data buffer is overflowed by " +
- std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
+ throw std::runtime_error(
+ "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
void register_buffer(const std::vector& handles, const std::vector& offsets, void* self) {
@@ -443,8 +465,8 @@ class CustomAllreduce {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
- void register_graph_buffers(const std::vector& handles,
- const std::vector>& offsets) {
+ void
+ register_graph_buffers(const std::vector& handles, const std::vector>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector rank_data(num_buffers);
@@ -474,11 +496,17 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
*/
template
- void allreduce(hipStream_t stream, T* input, T* output, int size,
+ void allreduce(
+ hipStream_t stream,
+ T* input,
+ T* output,
+ int size,
#ifndef USE_ROCM
- int threads = 512, int block_limit = 36){
+ int threads = 512,
+ int block_limit = 36){
#else
- int threads = 512, int block_limit = 16) {
+ int threads = 512,
+ int block_limit = 16) {
#endif
auto d = packed_t::P::size;
if (size % d != 0)
@@ -487,8 +515,8 @@ class CustomAllreduce {
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
- throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " +
- std::to_string(block_limit));
+ throw std::runtime_error(
+ "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
hipStreamCaptureStatus status;
@@ -499,17 +527,17 @@ class CustomAllreduce {
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
- throw std::runtime_error("buffer address " + std::to_string(reinterpret_cast(input)) +
- " is not registered!");
+ throw std::runtime_error(
+ "buffer address " + std::to_string(reinterpret_cast(input)) + " is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t::P);
int blocks = ::min(block_limit, (size + threads - 1) / threads);
-#define KL(ngpus, name) \
- hipLaunchKernelGGL((name), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \
- size);
+#define KL(ngpus, name) \
+ hipLaunchKernelGGL( \
+ (name), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu b/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
similarity index 94%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
rename to sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
index fa9e3a2c5d2..f1ee5d40efd 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
+++ b/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
return c.packed;
}
-__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
- size_t const world_size, int const tidx, int const bidx) {
+__inline__ __device__ void multi_gpu_barrier(
+ uint32_t** signals,
+ uint32_t const flag,
+ size_t const local_rank,
+ size_t const world_size,
+ int const tidx,
+ int const bidx) {
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, world_size]
@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
}
template
-__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
- size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+__inline__ __device__ void block_barrier(
+ uint32_t** signals,
+ uint32_t const flag,
+ size_t const local_rank,
+ size_t const world_size,
+ int const tidx,
+ int const bidx,
+ int const grid_size) {
if constexpr (!start) {
__syncthreads();
}
@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
}
}
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
- block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
- grid_size);
+ block_barrier(
+ params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
}
- block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
- grid_size);
+ block_barrier(
+ params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
- block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
- bidx, grid_size);
+ block_barrier(
+ params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -459,8 +470,12 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
////////////////////////////////////////////////////////////////////////////////////////////////////
template
-void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
- cudaStream_t stream) {
+void dispatchARKernels(
+ AllReduceStrategyType algo,
+ AllReduceParams& param,
+ int blocks_per_grid,
+ int threads_per_block,
+ cudaStream_t stream) {
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
oneShotAllReduceKernel<<>>(param);
@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
CHECK_CUDA_SUCCESS(cudaGetLastError());
}
-void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
- cudaStream_t stream) {
+void trtCustomAllReduce(
+ AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu b/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
similarity index 89%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
rename to sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
index af129de52ef..5c879255621 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
+++ b/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
@@ -29,9 +29,14 @@ using IPC_KEY = std::array;
class AllReduceMeta {
public:
- AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers,
- const std::vector& tmp_result_buffers, const std::vector& barrier_in,
- const std::vector& barrier_out) {
+ AllReduceMeta(
+ int64_t rank_id,
+ int64_t world_size,
+ torch::Tensor& rank_data,
+ const std::vector& buffers,
+ const std::vector& tmp_result_buffers,
+ const std::vector& barrier_in,
+ const std::vector& barrier_out) {
this->rank_id = (int)rank_id;
this->world_size = (int)world_size;
this->barrier_in = barrier_in;
@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers,
- const std::vector& tmp_result_buffers, const std::vector& barrier_in,
- const std::vector& barrier_out) {
+fptr_t init_custom_ar(
+ int64_t rank_id,
+ int64_t world_size,
+ torch::Tensor& rank_data,
+ const std::vector& buffers,
+ const std::vector& tmp_result_buffers,
+ const std::vector& barrier_in,
+ const std::vector& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
return (fptr_t)m;
}
@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
- CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle),
- cudaIpcMemLazyEnablePeerAccess));
+ CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
+ (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
-void register_graph_buffers(fptr_t _fa, const std::vector>& handles,
- const std::vector>& offsets) {
+void register_graph_buffers(
+ fptr_t _fa, const std::vector>& handles, const std::vector>& offsets) {
AllReduceMeta* m = reinterpret_cast(_fa);
std::vector handle_bytes;
handle_bytes.reserve(handles.size());
diff --git a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu b/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
similarity index 74%
rename from sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
rename to sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
index 02c50498eb9..f9d524f6001 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
+++ b/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
@@ -23,15 +23,18 @@ limitations under the License.
#define THREADS_PER_BLOCK 128
template
-__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d]
- const T* __restrict__ k, // [b, h, 1, d]
- const T* __restrict__ v, // [b, h, 1, e]
- const float* __restrict__ past_kv, // [b, h, d, e]
- const float* __restrict__ slope, // [h, 1, 1]
- T* __restrict__ output, // [b, h, 1, e]
- float* __restrict__ new_kv, // [b, h, d, e]
- const int batch_size, const int num_heads, const int qk_dim,
- const int v_dim) {
+__global__ void lightning_attention_decode_kernel(
+ const T* __restrict__ q, // [b, h, 1, d]
+ const T* __restrict__ k, // [b, h, 1, d]
+ const T* __restrict__ v, // [b, h, 1, e]
+ const float* __restrict__ past_kv, // [b, h, d, e]
+ const float* __restrict__ slope, // [h, 1, 1]
+ T* __restrict__ output, // [b, h, 1, e]
+ float* __restrict__ new_kv, // [b, h, d, e]
+ const int batch_size,
+ const int num_heads,
+ const int qk_dim,
+ const int v_dim) {
extern __shared__ char smem[];
T* __restrict__ q_shared = reinterpret_cast(smem);
T* __restrict__ k_shared = reinterpret_cast(smem + qk_dim * sizeof(T));
@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
}
}
-void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
- const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
- torch::Tensor new_kv) {
+void lightning_attention_decode(
+ const torch::Tensor& q,
+ const torch::Tensor& k,
+ const torch::Tensor& v,
+ const torch::Tensor& past_kv,
+ const torch::Tensor& slope,
+ torch::Tensor output,
+ torch::Tensor new_kv) {
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
@@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
lightning_attention_decode_kernel<<>>(
- q.data_ptr(), k.data_ptr(), v.data_ptr(), past_kv.data_ptr(),
- slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads,
- qk_dim, v_dim);
+ q.data_ptr(),
+ k.data_ptr(),
+ v.data_ptr(),
+ past_kv.data_ptr(),
+ slope.data_ptr(),
+ output.data_ptr(),
+ new_kv.data_ptr(),
+ batch_size,
+ num_heads,
+ qk_dim,
+ v_dim);
}));
}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
similarity index 81%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
rename to sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
index f5cd4381563..9f85bee28b1 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
+++ b/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
@@ -25,9 +25,15 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {
-template
+template <
+ typename ThreadblockShape_,
+ int ThreadCount,
+ typename ScaleTileIterator_,
+ typename OutputTileIterator_,
+ typename ElementAccumulator_,
+ typename ElementCompute_,
+ typename ElementwiseFunctor_,
+ bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
using ThreadblockShape = ThreadblockShape_;
@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
- Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_,
- int64_t batch_stride_D_)
+ Arguments(
+ typename ElementwiseFunctor::Params elementwise_,
+ int64_t batch_stride_alpha_,
+ int64_t batch_stride_C_,
+ int64_t batch_stride_D_)
: elementwise(elementwise_),
batch_stride_alpha(batch_stride_alpha_),
batch_stride_C(batch_stride_C_),
@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
public:
CUTLASS_DEVICE
- EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
- cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
- typename ScaleTileIterator::Params params_alpha_col,
- typename OutputTileIterator::Params params_C,
- typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant,
- bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row,
- AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
- typename OutputTileIterator::Element* ptr_D,
- cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
- int column_offset = 0,
- cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
+ EpilogueVisitorPerRowPerCol(
+ Params const& params,
+ SharedStorage& shared_storage,
+ cutlass::MatrixCoord const& problem_size,
+ int thread_idx,
+ int warp_idx,
+ int lane_idx,
+ typename ScaleTileIterator::Params params_alpha_col,
+ typename OutputTileIterator::Params params_C,
+ typename OutputTileIterator::Params params_D,
+ bool with_bias,
+ bool per_token_quant,
+ bool per_channel_quant,
+ AlphaScaleElementType* ptr_alpha_row,
+ AlphaScaleElementType* ptr_alpha_col,
+ typename OutputTileIterator::Element* ptr_C,
+ typename OutputTileIterator::Element* ptr_D,
+ cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
+ int column_offset = 0,
+ cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
- void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
- int split_k_slices) { ///< Total number of split-K slices
+ void set_k_partition(
+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
+ int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
private:
CUTLASS_DEVICE
- ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col,
- AlphaScaleElementType const& scale_row) {
+ ComputeFragment per_token_channel_scale_accumulator_(
+ ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
}
CUTLASS_DEVICE
- ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col,
- AlphaScaleElementType const& scale_row) {
+ ComputeFragment per_token_scale_accumulator_(
+ ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
rename to sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
rename to sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
similarity index 62%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
rename to sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
index 48b0ad9490e..f62b51ee7ed 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
+++ b/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
-template , class KernelSchedule = KernelTmaWarpSpecialized,
- int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
- // while zero-value `ScaleGranularityM` indicates that scaling
- // granularity is `size<0>(TileShape_MNK{})` along M.
- >
+template <
+ int Stages_,
+ class ClusterShape_ = Shape<_1, _1, _1>,
+ class KernelSchedule = KernelTmaWarpSpecialized,
+ int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
+ // while zero-value `ScaleGranularityM` indicates that scaling
+ // granularity is `size<0>(TileShape_MNK{})` along M.
+ >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
: MainloopSm90TmaGmmaWarpSpecialized {
- static_assert(cute::is_same_v>,
- "KernelSchedule must be one of the warp specialized policies");
+ static_assert(
+ cute::
+ is_same_v>,
+ "KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
similarity index 93%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
rename to sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
index 3de9ff078b6..b58d84318ba 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
+++ b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
- CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
- << " result = {" << result << "}");
+ CUTLASS_TRACE_HOST(
+ " grid_tiled_shape: " << grid_tiled_shape << "\n"
+ << " result = {" << result << "}");
return result;
}
@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10)) {
-      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
-                                                                         GemmKernel::kThreadCount, smem_size);
+      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
}
} else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
-      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
-                                                                         GemmKernel::kThreadCount, 0);
+      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess) {
- CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
- << cudaGetErrorString(result));
+ CUTLASS_TRACE_HOST(
+ " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
- CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
- << workspace << ", stream: " << (stream ? "non-null" : "null"));
+ CUTLASS_TRACE_HOST(
+ "GemmUniversalBaseCompat::initialize() - workspace " << workspace
+ << ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
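The occupancy query reformatted above follows the stock CUDA pattern: ask the runtime how many blocks of the kernel fit per SM for a given block size and dynamic shared-memory budget, and bail out with -1 on error. A standalone sketch of that pattern (the kernel and helper names are placeholders, not CUTLASS symbols):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void MyKernel() {}

// Mirrors the error handling above: return the active-block count, or -1 on failure.
int maximum_active_blocks(int threads_per_block, size_t dynamic_smem_bytes) {
  int max_active_blocks = 0;
  cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks, MyKernel, threads_per_block, dynamic_smem_bytes);
  if (result != cudaSuccess) {
    std::printf("cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error %s\n", cudaGetErrorString(result));
    return -1;
  }
  return max_active_blocks;
}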
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
similarity index 88%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
rename to sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
index 11fc872505f..905d11ba2c6 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
+++ b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
@@ -32,10 +32,11 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
-template <typename Mma_,                 ///! Threadblock-scoped matrix multiply-accumulate
-          typename Epilogue_,            ///! Epilogue
-          typename ThreadblockSwizzle_   ///! Threadblock swizzling function
-          >
+template <
+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
+ typename Epilogue_, ///! Epilogue
+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function
+ >
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
/// constructs an arguments structure
- Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_,
- TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_,
- typename EpilogueVisitor::Arguments epilogue_visitor_)
+ Arguments(
+ GemmCoord problem_size_,
+ TensorRefA ref_A_,
+ TensorRefB ref_B_,
+ TensorRefAlphaCol ref_alpha_col_,
+ TensorRefAlphaRow ref_alpha_row_,
+ TensorRefC ref_C_,
+ TensorRefC ref_D_,
+ typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(GemmUniversalMode::kGemm),
problem_size(problem_size_),
batch_count(1),
@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
- } else if (platform::is_same>::value ||
- platform::is_same>::value) {
+ } else if (
+ platform::is_same>::value ||
+ platform::is_same>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
- } else if (platform::is_same>::value ||
- platform::is_same>::value) {
+ } else if (
+ platform::is_same>::value ||
+ platform::is_same>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
- } else if (platform::is_same>::value ||
- platform::is_same>::value) {
+ } else if (
+ platform::is_same>::value ||
+ platform::is_same>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
- typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx,
- tb_offset_A);
+ typename Mma::IteratorA iterator_A(
+ params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
- typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx,
- tb_offset_B);
+ typename Mma::IteratorB iterator_B(
+ params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
- MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
- threadblock_tile_offset.n() * Mma::Shape::kN);
+ MatrixCoord threadblock_offset(
+ threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
with_bias = false;
}
- EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(),
- thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
- params.params_D, with_bias, true, true, params.ptr_alpha_row, params.ptr_alpha_col,
- params.ptr_C, params.ptr_D, threadblock_offset,
- blockIdx.y * params.problem_size.m());
+ EpilogueVisitor epilogue_visitor(
+ params.epilogue_visitor,
+ shared_storage.epilogue.visitor,
+ params.problem_size.mn(),
+ thread_idx,
+ warp_idx,
+ lane_idx,
+ params.params_alpha_col,
+ params.params_C,
+ params.params_D,
+ with_bias,
+ true,
+ true,
+ params.ptr_alpha_row,
+ params.ptr_alpha_col,
+ params.ptr_C,
+ params.ptr_D,
+ threadblock_offset,
+ blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
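The alignment checks reformatted in this file gate kernel selection on whether each operand's contiguous extent divides evenly by its required vector width. Stripped of the layout dispatch, the rule is a modulo test; a hedged illustration (the names below are placeholders, not CUTLASS symbols):

// For a row-major operand the contiguous extent is the trailing dimension,
// e.g. K for A(M, K); alignment_elements is the vector width in elements.
inline bool is_misaligned(long contiguous_extent, int alignment_elements) {
  return (contiguous_extent % alignment_elements) != 0;
}
// is_misaligned(problem_k, kAlignmentA) plays the role of `problem_size.k() % kAlignmentA` above.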
diff --git a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu b/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
similarity index 82%
rename from sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu
rename to sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
index a4ae14ae59d..41f4d2e7099 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu
+++ b/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
// support float16, bfloat16 and float32
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::FusedAddRMSNorm(
-        static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
-        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
-    TORCH_CHECK(status == cudaSuccess,
-                "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
+        static_cast<c_type*>(input.data_ptr()),
+        static_cast<c_type*>(residual.data_ptr()),
+        static_cast<c_type*>(weight.data_ptr()),
+ batch_size,
+ hidden_size,
+ eps,
+ torch_current_stream);
+ TORCH_CHECK(
+ status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
return true;
});
}
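For readers unfamiliar with the op being wrapped here: a scalar reference of what a fused add + RMSNorm computes, assuming the flashinfer-style semantics (the residual buffer accumulates the sum in place and the normalized result is written back into the input buffer). This is an editorial sketch for clarity, not the kernel's implementation:

#include <cmath>
#include <vector>

void fused_add_rms_norm_reference(std::vector<float>& input, std::vector<float>& residual,
                                  const std::vector<float>& weight, int batch_size, int hidden_size, float eps) {
  for (int b = 0; b < batch_size; ++b) {
    float sum_sq = 0.f;
    for (int h = 0; h < hidden_size; ++h) {
      const int i = b * hidden_size + h;
      residual[i] += input[i];  // fused residual add, updated in place
      sum_sq += residual[i] * residual[i];
    }
    const float inv_rms = 1.f / std::sqrt(sum_sq / hidden_size + eps);
    for (int h = 0; h < hidden_size; ++h) {
      const int i = b * hidden_size + h;
      input[i] = residual[i] * inv_rms * weight[h];  // normalized output overwrites input
    }
  }
}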
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu b/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
similarity index 71%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
rename to sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
index ec899d33024..d0a80c7bff5 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
+++ b/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
@@ -21,10 +21,13 @@
#include "utils.h"
-static void check_group_count(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
-                              const std::vector<torch::Tensor>& outputs) {
-  TORCH_CHECK(((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
-              "The group count of inputs, weights and outputs should be the same.");
+static void check_group_count(
+    const std::vector<torch::Tensor>& inputs,
+    const std::vector<torch::Tensor>& weights,
+    const std::vector<torch::Tensor>& outputs) {
+ TORCH_CHECK(
+ ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
+ "The group count of inputs, weights and outputs should be the same.");
}
 static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
@@ -68,21 +71,26 @@ static std::vector get_tensor_ptrs(const std::vector& tens
 static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
   auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
   torch::Tensor gpu_ptrs = torch::empty({static_cast<int64_t>(ptrs.size())}, options);
- TORCH_CHECK(cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice,
- stream) == CUBLAS_STATUS_SUCCESS);
+ TORCH_CHECK(
+ cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
+ CUBLAS_STATUS_SUCCESS);
return gpu_ptrs;
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
-void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
-                         const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
-                         const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
- const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream) {
- TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
- "cublas grouped_gemm can"
- "only be applied to float16 and bfloat16 dtype");
+void cublas_grouped_gemm(
+    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
+    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
+    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
+ const torch::Dtype& out_dtype,
+ int64_t cublas_handle,
+ int64_t cuda_stream) {
+ TORCH_CHECK(
+ out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
+ "cublas grouped_gemm can"
+ "only be applied to float16 and bfloat16 dtype");
int group_count = inputs.size();
check_group_count(inputs, weights, outputs);
@@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector& inputs, // b: (m, k
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
- auto status = cublasGemmGroupedBatchedEx(handle, transa_array.data(), transb_array.data(), m_array.data(),
- n_array.data(), k_array.data(), alpha_array.data(), (void**)d_a.data_ptr(),
- cuda_data_type, lda_array.data(), (void**)d_b.data_ptr(), cuda_data_type,
- ldb_array.data(), beta_array.data(), (void**)d_c.data_ptr(), cuda_data_type,
- ldc_array.data(), group_count, group_size.data(), CUBLAS_COMPUTE_32F);
+ auto status = cublasGemmGroupedBatchedEx(
+ handle,
+ transa_array.data(),
+ transb_array.data(),
+ m_array.data(),
+ n_array.data(),
+ k_array.data(),
+ alpha_array.data(),
+ (void**)d_a.data_ptr(),
+ cuda_data_type,
+ lda_array.data(),
+ (void**)d_b.data_ptr(),
+ cuda_data_type,
+ ldb_array.data(),
+ beta_array.data(),
+ (void**)d_c.data_ptr(),
+ cuda_data_type,
+ ldc_array.data(),
+ group_count,
+ group_size.data(),
+ CUBLAS_COMPUTE_32F);
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
return;
#endif
- TORCH_CHECK_NOT_IMPLEMENTED(false,
- "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
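The comment block above ("input @ weight^T in row major equals weight @ input^T in col major") is the crux of how this wrapper drives cuBLAS. A hedged sketch of the per-group shape and leading-dimension bookkeeping that follows from it, kept free of actual cuBLAS calls so it stays self-contained (the struct and function below are illustrative only):

struct ColMajorGemmDesc {
  int m, n, k;        // problem size as cuBLAS sees it (column major)
  int lda, ldb, ldc;  // leading dimensions of weight, input, output
};

// input:  (m, k) row major  -> same memory as a (k, m) column-major matrix, ld = k
// weight: (n, k) row major  -> same memory as a (k, n) column-major matrix, ld = k
// output: (m, n) row major  -> written as its (n, m) column-major transpose, ld = n
// With transa = CUBLAS_OP_T on the weight view and transb = CUBLAS_OP_N on the input view,
// cuBLAS computes C^T = weight * input^T, which is exactly input @ weight^T in row major.
ColMajorGemmDesc describe_group(int m, int n, int k) {
  return {/*m=*/n, /*n=*/m, /*k=*/k, /*lda=*/k, /*ldb=*/k, /*ldc=*/n};
}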
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
similarity index 81%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
rename to sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
index 337a5ad69ac..a62a5c0ce6d 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
+++ b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
@@ -35,8 +35,12 @@
using namespace cute;
 template <typename OutType, typename TileShape, typename ClusterShape>
-void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
+void launch_sm90_fp8_blockwise_scaled_mm(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b) {
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
- ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC,
- LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;
+ ArchTag,
+ OperatorClass,
+ TileShape,
+ ClusterShape,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ LayoutC,
+ AlignmentC,
+ ElementD,
+ LayoutD,
+ AlignmentD,
+ EpilogueSchedule,
+ StoreEpilogueCompute>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
- ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
- TileShape, ClusterShape,
+ ArchTag,
+ OperatorClass,
+ ElementA,
+ LayoutA,
+ AlignmentA,
+ ElementB,
+ LayoutB,
+ AlignmentB,
+ ElementAccumulator,
+ TileShape,
+ ClusterShape,
       cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
           sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
- using GemmKernel =
-      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int>,  // Indicates ProblemShape
-                                           CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>;
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int>,  // Indicates ProblemShape
+ CollectiveMainloop,
+ CollectiveEpilogue,
+ cutlass::gemm::PersistentScheduler>;
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Gemm gemm_op;
@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
}
 template <typename OutType>
-void sm90_fp8_blockwise_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
+void sm90_fp8_blockwise_dispatch_shape(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b) {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
   launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}
-torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const torch::Dtype& out_dtype) {
+torch::Tensor fp8_blockwise_scaled_mm(
+ const torch::Tensor& mat_a,
+ const torch::Tensor& mat_b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const torch::Dtype& out_dtype) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
- TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
- "mat_a must be multiple of 16 bytes for memory alignment");
- TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
- "mat_b must be multiple of 16 bytes for memory alignment");
+ TORCH_CHECK(
+ (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
+ TORCH_CHECK(
+ (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
#endif
#endif
- TORCH_CHECK_NOT_IMPLEMENTED(false,
- "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
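The checks above define the calling contract for `fp8_blockwise_scaled_mm`: row-major A, column-major B, matching inner dimensions, 16-byte-aligned contiguous extents, and e4m3 inputs. The same contract collected into a single standalone helper, as a hedged illustration (this helper is not part of sgl-kernel):

#include <torch/torch.h>

void check_fp8_mm_inputs(const torch::Tensor& mat_a, const torch::Tensor& mat_b) {
  TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2, "expected 2D operands");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be row major (contiguous along K)");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be column major (contiguous along K)");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a rows must be a multiple of 16 bytes");
  TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b columns must be a multiple of 16 bytes");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
}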
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
similarity index 52%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
rename to sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
index 36b9585f349..64731ebe4d2 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
+++ b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
@@ -53,10 +53,17 @@ limitations under the License.
using namespace cute;
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
-template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
- typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
+template <
+ typename ElementType,
+ typename OutElementType,
+ typename AccumElementType,
+ typename CtaShape,
+ typename WarpShape,
+ int Stages,
+ bool WithBias,
+ typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
+ template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
+ typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 {
   static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
@@ -85,56 +92,86 @@ struct DeviceGemmFp8RowwiseSm89 {
// Number of epilogue stages in EVT
static constexpr int EVTEpilogueStages = 1;
- using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout;
+ using OutputTileThreadMap = cutlass::epilogue::threadblock::
+ OutputTileThreadLayout;
// Definition of EVT
using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;
using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
- using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>;
+ cutlass::multiplies,
+ ElementComputeEpilogue,
+ ElementComputeEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+ using bScaleSrc = cutlass::epilogue::threadblock::
+ VisitorRowBroadcast>;
using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT;
- using ComputeAScale =
- cutlass::epilogue::threadblock::VisitorCompute;
- using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>;
+ using ComputeAScale = cutlass::epilogue::threadblock::
+ VisitorCompute;
+ using aScaleSrc = cutlass::epilogue::threadblock::
+ VisitorColBroadcast>;
using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT;
// With bias
using biasSrc =
cutlass::epilogue::threadblock::VisitorRowBroadcast>;
- using ComputeAScaleWithBias =
- cutlass::epilogue::threadblock::VisitorCompute;
+ using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiply_add,
+ ElementC,
+ ElementComputeEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest>;
using EpilogueAScaleWithBias =
cutlass::epilogue::threadblock::Sm80EVT;
using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
- OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride>;
- using EpilogueStore =
- typename cutlass::platform::conditional,
- cutlass::epilogue::threadblock::Sm80EVT>::type;
+ OutputTileThreadMap,
+ ElementC,
+ cutlass::FloatRoundStyle::round_to_nearest,
+ Stride>;
+ using EpilogueStore = typename cutlass::platform::conditional<
+ WithBias,
+ cutlass::epilogue::threadblock::Sm80EVT,
+ cutlass::epilogue::threadblock::Sm80EVT>::type;
using EpilogueOp = EpilogueStore;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
- ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB,
- cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator,
- ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp,
- ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel;
+ ElementA,
+ LayoutA,
+ cutlass::ComplexTransform::kNone,
+ AlignmentA,
+ ElementB,
+ LayoutB,
+ cutlass::ComplexTransform::kNone,
+ AlignmentB,
+ ElementC,
+ LayoutC,
+ AlignmentC,
+ ElementAccumulator,
+ ElementComputeEpilogue,
+ OperatorClass,
+ ArchTag,
+ CtaShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOp,
+ ThreadblockSwizzle,
+ Stages,
+ FP8MathOperator,
+ EVTEpilogueStages>::GemmKernel;
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
 template <typename Gemm, bool WithBias>
-typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+typename Gemm::Arguments prepare_sm89_fp8_args(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
using ElementT = typename Gemm::ElementA;
using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float;
@@ -158,54 +195,61 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::
   ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
   ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
- typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode
- {m, n, k}, // Problem size
- 1, // Split-k factor
- {}, // Epilogue args
- ptr_a, // a pointer
- ptr_b, // b pointer
- nullptr, // c pointer (unused)
- nullptr, // d pointer (unused)
- m * k, // batch stride a (unused)
- n * k, // batch stride b (unused)
- m * n, // batch stride c (unused)
- m * n, // batch stride d (unused)
- lda, // stride a
- ldb, // stride b
- ldc, // stride c (unused)
- ldc); // stride d (unused)
+ typename Gemm::Arguments args(
+ cutlass::gemm::GemmUniversalMode::kGemm, // Mode
+ {m, n, k}, // Problem size
+ 1, // Split-k factor
+ {}, // Epilogue args
+ ptr_a, // a pointer
+ ptr_b, // b pointer
+ nullptr, // c pointer (unused)
+ nullptr, // d pointer (unused)
+ m * k, // batch stride a (unused)
+ n * k, // batch stride b (unused)
+ m * n, // batch stride c (unused)
+ m * n, // batch stride d (unused)
+ lda, // stride a
+ ldb, // stride b
+ ldc, // stride c (unused)
+ ldc); // stride d (unused)
if constexpr (WithBias) {
- args.epilogue = {{
- {
- {}, // Accumulator
- {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
- {} // Multiplies
- },
- {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
- {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
- {} // Multiplies
- },
- {ptr_d, {n, _1{}, _0{}}}};
+ args.epilogue = {
+ {
+ {
+ {}, // Accumulator
+ {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
+ {} // Multiplies
+ },
+ {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
+ {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
+ {} // Multiplies
+ },
+ {ptr_d, {n, _1{}, _0{}}}};
} else {
- args.epilogue = {{
- {
- {}, // Accumulator
- {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
- {} // Multiplies
- },
- {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
- {} // Multiplies
- },
- {ptr_d, {n, _1{}, _0{}}}};
+ args.epilogue = {
+ {
+ {
+ {}, // Accumulator
+ {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
+ {} // Multiplies
+ },
+ {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
+ {} // Multiplies
+ },
+ {ptr_d, {n, _1{}, _0{}}}};
}
return args;
}
 template <typename Gemm, bool WithBias>
-void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+void launch_sm89_fp8_scaled_mm(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
   auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
Gemm gemm_op;
@@ -222,109 +266,187 @@ void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
}
 template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
-void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+void sm89_fp8_dispatch_bias(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
if (bias) {
-    using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
-                                                   Stages, true>::Gemm;
+ using Gemm = typename DeviceGemmFp8RowwiseSm89<
+ ElementInput,
+ ElementOutput,
+ AccumElementType,
+ CtaShape,
+ WarpShape,
+ Stages,
+ true>::Gemm;
return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
} else {
-    using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
-                                                   Stages, false>::Gemm;
+ using Gemm = typename DeviceGemmFp8RowwiseSm89<
+ ElementInput,
+ ElementOutput,
+ AccumElementType,
+ CtaShape,
+ WarpShape,
+ Stages,
+ false>::Gemm;
return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
}
}
 template <typename OutType>
-void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+void sm89_fp8_dispatch_shape(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
uint32_t const m = a.size(0);
uint32_t const n = out.size(1);
if (m == 1) {
if (n <= 8192) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<16, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 7>(out, a, b, scales_a, scales_b, bias);
} else {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<32, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 16) {
// M in (1, 16]
if (n <= 8192) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<16, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 4>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<32, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 5>(out, a, b, scales_a, scales_b, bias);
} else {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<16, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 64) {
// M in (16, 64]
if (n <= 16384) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<32, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 7>(out, a, b, scales_a, scales_b, bias);
} else {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<16, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 128) {
// M in (64, 128]
if (n <= 8192) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<64, 64, 128>,
+ cutlass::gemm::GemmShape<32, 64, 64>,
+ 4>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<64, 64, 128>,
+ cutlass::gemm::GemmShape<32, 64, 64>,
+ 5>(out, a, b, scales_a, scales_b, bias);
} else {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<32, 64, 128>,
+ cutlass::gemm::GemmShape<16, 64, 64>,
+ 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 256) {
// M in (128, 256]
if (n <= 8192) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<128, 64, 64>,
+ cutlass::gemm::GemmShape<64, 32, 64>,
+ 5>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<64, 128, 64>,
+ cutlass::gemm::GemmShape<64, 32, 64>,
+ 7>(out, a, b, scales_a, scales_b, bias);
} else {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<128, 64, 128>,
+ cutlass::gemm::GemmShape<64, 32, 128>,
+ 4>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 512) {
// M in (256, 512)
if (n <= 16384) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<128, 128, 64>,
+ cutlass::gemm::GemmShape<64, 32, 64>,
+ 2>(out, a, b, scales_a, scales_b, bias);
} else {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<128, 128, 64>,
+ cutlass::gemm::GemmShape<64, 32, 64>,
+ 4>(out, a, b, scales_a, scales_b, bias);
}
} else {
// M in (512, inf)
if (n <= 8192) {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<128, 128, 64>,
+ cutlass::gemm::GemmShape<64, 32, 64>,
+ 3>(out, a, b, scales_a, scales_b, bias);
} else {
- return sm89_fp8_dispatch_bias,
- cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
+ return sm89_fp8_dispatch_bias<
+ OutType,
+ cutlass::gemm::GemmShape<128, 128, 64>,
+ cutlass::gemm::GemmShape<64, 32, 64>,
+ 2>(out, a, b, scales_a, scales_b, bias);
}
}
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
-template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
-          typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
-          typename TileSchedulerType = void, bool WithBias = false>
+template <
+ typename ElementType,
+ typename OutElementType,
+ typename AccumElementType,
+ typename CTAShape,
+ typename ClusterShape,
+ typename MainloopScheduleType,
+ typename EpilogueScheduleType,
+ typename TileSchedulerType = void,
+ bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
   static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
@@ -374,44 +496,70 @@ struct DeviceGemmFp8RowwiseSm90 {
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default
// setting in the Collective Builder
// Implement rowwise scaling epilogue.
- using XScale =
- cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
- cute::Stride, cute::Int<0>, cute::Int<0>>>;
-
- using WScale =
- cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
- cute::Stride, cute::Int<1>, cute::Int<0>>>;
-
- using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
- cute::Stride, cute::Int<1>, cute::Int<0>>>;
+ using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
+ 0,
+ TileShape,
+ ElementComputeEpilogue,
+ ElementComputeEpilogue,
+ cute::Stride, cute::Int<0>, cute::Int<0>>>;
+
+ using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
+ 0,
+ TileShape,
+ ElementComputeEpilogue,
+ ElementComputeEpilogue,
+ cute::Stride, cute::Int<1>, cute::Int<0>>>;
+
+ using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
+ 0,
+ TileShape,
+ ElementOutput,
+ ElementOutput,
+ cute::Stride, cute::Int<1>, cute::Int<0>>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
- using Compute0 = cutlass::epilogue::fusion::Sm90Compute;
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies,
+ ElementComputeEpilogue, // First stage output type.
+ ElementComputeEpilogue, // First stage input types.
+ cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT;
- using Compute1 = cutlass::epilogue::fusion::Sm90Compute;
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies,
+ ElementOutput,
+ ElementComputeEpilogue, // Second stage input types.
+ cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT;
// With bias
- using ComputeWithBias =
- cutlass::epilogue::fusion::Sm90Compute;
+ using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiply_add,
+ ElementOutput,
+ ElementComputeEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT;
using EpilogueEVT = typename cutlass::platform::conditional::type;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
- cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC,
- AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized,
+ cutlass::arch::Sm90,
+ cutlass::arch::OpClassTensorOp,
+ TileShape,
+ ClusterShape,
+ cutlass::epilogue::collective::EpilogueTileAuto,
+ ElementAccumulator,
+ ElementComputeEpilogue,
+ ElementC,
+ LayoutC,
+ AlignmentC,
+ ElementOutput,
+ LayoutOutput,
+ AlignmentOutput,
+ cutlass::epilogue::TmaWarpSpecialized,
EpilogueEVT>::CollectiveOp;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
@@ -423,22 +571,38 @@ struct DeviceGemmFp8RowwiseSm90 {
using FastAccum = FastPongSchedule; // Default apply Pingpong
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
- ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
- TileShape, ClusterShape,
+ ArchTag,
+ OperatorClass,
+ ElementA,
+ LayoutA,
+ AlignmentA,
+ ElementB,
+ LayoutB,
+ AlignmentB,
+ ElementAccumulator,
+ TileShape,
+ ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;
- using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape
- CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+ Shape, // Indicates ProblemShape
+ CollectiveMainloop,
+ CollectiveEpilogue,
+ TileSchedulerType>;
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
 template <typename Gemm, bool WithBias>
-typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+typename Gemm::Arguments prepare_sm90_fp8_args(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
using ElementT = typename Gemm::ElementA;
using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float;
@@ -465,14 +629,15 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
- typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
- {m, n, k, 1},
- {ptr_a, stride_a, ptr_b, stride_b},
- {{}, // epilogue.thread
- nullptr,
- stride_c,
- ptr_d,
- stride_d}};
+ typename Gemm::Arguments args = {
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ {m, n, k, 1},
+ {ptr_a, stride_a, ptr_b, stride_b},
+ {{}, // epilogue.thread
+ nullptr,
+ stride_c,
+ ptr_d,
+ stride_d}};
if constexpr (WithBias) {
args.epilogue.thread = {
{ptr_scales_a},
@@ -500,9 +665,13 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
}
 template <typename Gemm, bool WithBias>
-void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+void launch_sm90_fp8_scaled_mm(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
   auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
Gemm gemm_op;
@@ -519,66 +688,117 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
TORCH_CHECK(status == cutlass::Status::kSuccess)
}
-template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType,
-          typename TileSchedulerType>
-void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias, bool fast_accum = true,
- bool use_persistent = false) {
+template <
+ typename OutType,
+ typename CTAShape,
+ typename ClusterShape,
+ typename MainloopScheduleType,
+ typename TileSchedulerType>
+void sm90_fp8_dispatch_bias(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias,
+ bool fast_accum = true,
+ bool use_persistent = false) {
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
if (bias) {
-    using Gemm =
-        typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
-                                          MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm;
+ using Gemm = typename DeviceGemmFp8RowwiseSm90<
+ ElementInput,
+ ElementOutput,
+ AccumElementType,
+ CTAShape,
+ ClusterShape,
+ MainloopScheduleType,
+ EpilogueScheduleType,
+ TileSchedulerType,
+ true>::Gemm;
return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
} else {
-    using Gemm =
-        typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
-                                          MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm;
+ using Gemm = typename DeviceGemmFp8RowwiseSm90<
+ ElementInput,
+ ElementOutput,
+ AccumElementType,
+ CTAShape,
+ ClusterShape,
+ MainloopScheduleType,
+ EpilogueScheduleType,
+ TileSchedulerType,
+ false>::Gemm;
return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
}
}
 template <typename OutType>
-void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+void sm90_fp8_dispatch_shape(
+ torch::Tensor& out,
+ const torch::Tensor& a,
+ const torch::Tensor& b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
uint32_t const m = a.size(0);
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
using BasicTileScheduler = void;
if (m <= 1) {
- return sm90_fp8_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler,
- BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
+ return sm90_fp8_dispatch_bias<
+ OutType,
+ Shape<_64, _64, _128>,
+ Shape<_1, _8, _1>,
+ FastBasicScheduler,
+ BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
if (m <= 64) {
// m in [1, 64]
- return sm90_fp8_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler,
- PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+ return sm90_fp8_dispatch_bias<
+ OutType,
+ Shape<_64, _64, _128>,
+ Shape<_1, _4, _1>,
+ FastPingpongScheduler,
+ PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 256) {
// m in (64, 256]
- return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler,
- PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+ return sm90_fp8_dispatch_bias<
+ OutType,
+ Shape<_64, _64, _128>,
+ Shape<_1, _1, _1>,
+ FastPingpongScheduler,
+ PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 1024) {
// m in (256, 1024]
- return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler,
- PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+ return sm90_fp8_dispatch_bias<
+ OutType,
+ Shape<_128, _128, _128>,
+ Shape<_1, _1, _1>,
+ FastPingpongScheduler,
+ PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else {
// m in (1024, inf)
- return sm90_fp8_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler,
- PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+ return sm90_fp8_dispatch_bias<
+ OutType,
+ Shape<_128, _128, _128>,
+ Shape<_2, _1, _1>,
+ FastPingpongScheduler,
+ PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
}
#endif
-torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
- const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
- const c10::optional& bias) {
+torch::Tensor fp8_scaled_mm(
+ const torch::Tensor& mat_a,
+ const torch::Tensor& mat_b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const torch::Dtype& out_dtype,
+ const c10::optional& bias) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -587,10 +807,10 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
- TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
- "mat_a must be multiple of 16 bytes for memory alignment");
- TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
- "mat_b must be multiple of 16 bytes for memory alignment");
+ TORCH_CHECK(
+ (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
+ TORCH_CHECK(
+ (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
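
`sm90_fp8_dispatch_shape` above picks a tile shape, cluster shape, and schedule purely from the runtime M. The same bucketing expressed at the value level, with the cute shape types reduced to plain integers and thresholds copied from the code (the struct and function names are illustrative only):

struct Sm90Fp8Config {
  int tile_m, tile_n, tile_k;
  int cluster_m, cluster_n, cluster_k;
  bool fast_basic_schedule;  // only the m <= 1 bucket uses the basic fast-accum schedule
};

Sm90Fp8Config pick_sm90_fp8_config(int m) {
  if (m <= 1) return {64, 64, 128, 1, 8, 1, true};
  if (m <= 64) return {64, 64, 128, 1, 4, 1, false};      // m in [1, 64]
  if (m <= 256) return {64, 64, 128, 1, 1, 1, false};     // m in (64, 256]
  if (m <= 1024) return {128, 128, 128, 1, 1, 1, false};  // m in (256, 1024]
  return {128, 128, 128, 2, 1, 1, false};                 // m in (1024, inf)
}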
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
similarity index 50%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
rename to sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
index 4a8130d667e..86aa3b8f2f4 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
+++ b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
@@ -35,11 +35,20 @@ limitations under the License.
using namespace cute;
-template <typename ElementOutput, typename ArchTag, typename ThreadblockShape, typename WarpShape,
-          typename InstructionShape, int NumStages>
-void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+template <
+ typename ElementOutput,
+ typename ArchTag,
+ typename ThreadblockShape,
+ typename WarpShape,
+ typename InstructionShape,
+ int NumStages>
+void cutlass_int8_scaled_mm(
+ torch::Tensor& out,
+ const torch::Tensor& mat_a,
+ const torch::Tensor& mat_b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
using ElementAccumulator = int32_t;
using ElementCompute = float;
using ElementInputA = int8_t;
@@ -48,30 +57,51 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
using OperatorClass = cutlass::arch::OpClassTensorOp;
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
- using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration;
+ using DefaultGemmConf = cutlass::gemm::device::
+ DefaultGemmConfiguration;
using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
- ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB,
- cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor,
- ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
- ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel;
+ ElementInputA,
+ cutlass::layout::RowMajor,
+ DefaultGemmConf::kAlignmentA,
+ ElementInputB,
+ cutlass::layout::ColumnMajor,
+ DefaultGemmConf::kAlignmentB,
+ ElementOutput,
+ cutlass::layout::RowMajor,
+ ElementAccumulator,
+ OperatorClass,
+ ArchTag,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOutputOp,
+ ThreadblockSwizzle,
+ NumStages,
+ true,
+ typename DefaultGemmConf::Operator>::GemmKernel;
using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
- GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits::value>,
+ GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
+ cutlass::sizeof_bits::value>,
ElementCompute>;
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
- ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator,
- typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>;
+ ThreadblockShape,
+ GemmKernel_::kThreadCount,
+ AlphaColTileIterator,
+ typename GemmKernel_::Epilogue::OutputTileIterator,
+ ElementAccumulator,
+ ElementCompute,
+ EpilogueOutputOp>;
- using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
- EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;
+  using Epilogue = typename cutlass::epilogue::threadblock::
+      EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;
   using GemmKernel =
       cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;
@@ -104,98 +134,164 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
typename EpilogueOutputOp::Params linearScalingParams;
typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};
- typename Gemm::Arguments args{{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0},
- {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};
+ typename Gemm::Arguments args{
+ {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};
- auto workspace = torch::empty(gemm_op.get_workspace_size(args),
- torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
+ auto workspace = torch::empty(
+ gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
auto can_implement = gemm_op.can_implement(args);
- TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
- "gemm cannot implement, error: ", cutlassGetStatusString(can_implement));
+ TORCH_CHECK(
+ can_implement == cutlass::Status::kSuccess,
+ "gemm cannot implement, error: ",
+ cutlassGetStatusString(can_implement));
auto status = gemm_op(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
}
template
-void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
- const torch::Tensor& scales_a, const torch::Tensor& scales_b,
- const c10::optional& bias) {
+void sm75_dispatch_shape(
+ torch::Tensor& out,
+ const torch::Tensor& mat_a,
+ const torch::Tensor& mat_b,
+ const torch::Tensor& scales_a,
+ const torch::Tensor& scales_b,
+ const c10::optional& bias) {
int m = mat_a.size(0);
if (m <= 32) {
- cutlass_int8_scaled_mm,
- cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
- scales_b, bias);
+ cutlass_int8_scaled_mm<
+ ElementOutput,
+ ArchTag,
+ cutlass::gemm::GemmShape<32, 128, 64>,
+ cutlass::gemm::GemmShape<32, 64, 64>,
+ InstructionShape,
+ 2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else if (m <= 64) {
- cutlass_int8_scaled_mm,
- cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
- scales_b, bias);
+ cutlass_int8_scaled_mm<
+ ElementOutput,
+ ArchTag,
+ cutlass::gemm::GemmShape<64, 128, 128>,
+ cutlass::gemm::GemmShape<64, 64, 64>,
+ InstructionShape,
+ 2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else if (m <= 256) {
- cutlass_int8_scaled_mm,
- cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
- scales_b, bias);
+ cutlass_int8_scaled_mm<
+ ElementOutput,
+ ArchTag,
+ cutlass::gemm::GemmShape<128, 128, 128>,
+ cutlass::gemm::GemmShape<64, 64, 64>,
+ InstructionShape,
+ 2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
- cutlass_int8_scaled_mm