Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn authored Feb 5, 2025
2 parents 1ecc9ef + 7ab8494 · commit 635c8a3
Showing 2 changed files with 17 additions and 3 deletions.
13 changes: 10 additions & 3 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -22,7 +22,7 @@
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name, is_hip
+from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 
 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@@ -450,9 +450,16 @@ def grid(META):
             triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         )
 
-    # Use manually unrolledx4 kernel on AMD GPU.
+    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+    # Empirical testing shows the sweet spot lies when it's less than the # of
+    # compute units available on the device.
+    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
+        N, config["BLOCK_SIZE_N"]
+    )
     kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4 if is_hip_ == True else _w8a8_block_fp8_matmul
+        _w8a8_block_fp8_matmul_unrolledx4
+        if (is_hip_ == True and num_workgroups <= get_device_core_count())
+        else _w8a8_block_fp8_matmul
     )
 
     kernel[grid](
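For readers skimming the hunk above: the heuristic selects the unrolled-by-4 Triton kernel only when the launch grid is small enough that every workgroup can be resident on the device at once. Below is a minimal standalone sketch of that selection logic, not code from the repository: cdiv mirrors triton.cdiv, and the kernel names are returned as strings purely for illustration (in fp8_kernel.py they are actual Triton kernels).

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as triton.cdiv computes it.
    return (a + b - 1) // b

def select_kernel(M: int, N: int, config: dict, is_hip_: bool, device_core_count: int) -> str:
    # One workgroup per (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of the output.
    num_workgroups = cdiv(M, config["BLOCK_SIZE_M"]) * cdiv(N, config["BLOCK_SIZE_N"])
    # On AMD (HIP) devices, prefer the unrolled kernel only while the whole
    # grid fits within the available compute units.
    if is_hip_ and num_workgroups <= device_core_count:
        return "_w8a8_block_fp8_matmul_unrolledx4"
    return "_w8a8_block_fp8_matmul"

# Example: a 4096x4096 output with 128x128 tiles launches 32 * 32 = 1024
# workgroups, which exceeds the 304 compute units of an MI300X, so the
# default kernel is chosen even on HIP.
assert select_kernel(
    4096, 4096, {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128},
    is_hip_=True, device_core_count=304,
) == "_w8a8_block_fp8_matmul"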
7 changes: 7 additions & 0 deletions python/sglang/srt/utils.py
@@ -1046,6 +1046,13 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)
 
 
+def get_device_core_count(device_id: int = 0) -> int:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        return torch.cuda.get_device_properties(device_id).multi_processor_count
+
+    return 0
+
+
 def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     major, minor = None, None
     if hasattr(torch, "cuda") and torch.cuda.is_available():
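On the utils side, torch.cuda.get_device_properties(...).multi_processor_count reports the number of streaming multiprocessors on NVIDIA GPUs; ROCm builds of PyTorch expose the same torch.cuda API, where the field reports compute units instead. A minimal usage sketch, assuming an environment where sglang is importable:

import torch
from sglang.srt.utils import get_device_core_count

# Prints the SM count on NVIDIA (e.g. 108 on an A100) or the CU count on
# AMD under ROCm; falls back to 0 on hosts with no visible GPU, which in
# the kernel-selection heuristic above simply disables the unrolled path.
print(get_device_core_count())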
