Commit

Merge branch 'main' into xiezhq-check
xiezhq-hermann authored Jan 27, 2025
2 parents b914897 + 351a72d commit d90e164
Showing 22 changed files with 1,742 additions and 21 deletions.
54 changes: 43 additions & 11 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -18,6 +18,9 @@
get_default_config,
get_moe_configs,
)
from sglang.srt.utils import is_hip

_is_hip_ = is_hip()


class BenchmarkConfig(TypedDict):
@@ -102,8 +105,8 @@ def benchmark_config(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)

w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)

input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

@@ -165,17 +168,15 @@ def run():
return avg


def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
Expand All @@ -184,11 +185,39 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs


def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip_:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
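
The new get_rocm_configs_compute_bound narrows num_stages to 2 and adds a waves_per_eu field, while the CUDA search space is unchanged. A hypothetical sketch of how such a search space is typically consumed, assuming benchmark_config returns an average latency as in the script above:

from typing import Callable, Dict, Optional

def brute_force_tune(benchmark: Callable[[Dict[str, int]], float]) -> Optional[Dict[str, int]]:
    # Try every candidate config and keep the one with the lowest measured latency.
    # get_configs_compute_bound is the helper defined in the tuning script above.
    best_config, best_latency = None, float("inf")
    for config in get_configs_compute_bound():
        latency = benchmark(config)
        if latency < best_latency:
            best_config, best_latency = config, latency
    return best_config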


@ray.remote(num_gpus=1)
class BenchmarkWorker:

@@ -297,6 +326,9 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
}
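
sort_config now carries waves_per_eu through only when the key exists, using conditional dict unpacking so CUDA configs are left untouched. A small standalone illustration of that idiom:

def passthrough(config: dict) -> dict:
    # Copy num_warps always; copy waves_per_eu only if the source config has it.
    return {
        "num_warps": config["num_warps"],
        **({"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}),
    }

assert "waves_per_eu" in passthrough({"num_warps": 4, "waves_per_eu": 0})
assert "waves_per_eu" not in passthrough({"num_warps": 4})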


3 changes: 3 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -114,6 +114,7 @@ def __init__(
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
super().__init__()

@@ -140,6 +141,7 @@ def __init__(
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.activation = activation

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +168,7 @@ def __init__(

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
assert self.activation == "silu"

if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
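
The EPMoE layer now accepts an activation argument, but its fused grouped-GEMM path only implements SiLU, hence the assert added to forward. A hypothetical minimal layer (not the real EPMoE API) showing the same fail-fast guard:

import torch

class TinyGatedLayer(torch.nn.Module):
    # Hypothetical stand-in: accepts an activation name but only wires up SiLU.
    def __init__(self, activation: str = "silu"):
        super().__init__()
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.activation == "silu", f"Unsupported activation: {self.activation}"
        return torch.nn.functional.silu(x)

TinyGatedLayer()(torch.randn(2, 4))                     # runs
# TinyGatedLayer(activation="gelu")(torch.randn(2, 4))  # raises AssertionError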
20 changes: 17 additions & 3 deletions python/sglang/srt/layers/moe/fused_moe_native.py
@@ -8,7 +8,7 @@
import torch
from torch.nn import functional as F

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts


@@ -23,6 +23,7 @@ def fused_moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -41,7 +42,12 @@
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
if activation == "silu":
x1 = F.silu(x1)
elif activation == "gelu":
x1 = F.gelu(x1)
else:
raise ValueError(f"Unsupported activation: {activation=}")
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:

topk_weights, topk_ids = select_experts(
Expand All @@ -84,6 +91,13 @@ def moe_forward_native(
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()

if activation == "silu":
act = SiluAndMul()
elif activation == "gelu":
act = GeluAndMul()
else:
raise ValueError(f"Unsupported activation: {activation=}")

outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@
layer_w2_weight = layer.w2_weight[i]

gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = SiluAndMul()(gate_up)
gate_up = act(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
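
Both native paths now dispatch on the activation name: fused_moe_forward_native applies F.silu or F.gelu to the gate projection, and moe_forward_native swaps SiluAndMul for GeluAndMul. A standalone sketch of the gated act-and-mul pattern, assuming the usual convention that the last dimension holds the gate and up halves back to back:

import torch
import torch.nn.functional as F

def act_and_mul(gate_up: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    # Split [..., 2 * d] into gate and up halves, activate the gate, multiply elementwise.
    gate, up = gate_up.chunk(2, dim=-1)
    if activation == "silu":
        return F.silu(gate) * up
    elif activation == "gelu":
        return F.gelu(gate) * up
    raise ValueError(f"Unsupported activation: {activation=}")

out = act_and_mul(torch.randn(4, 16), activation="gelu")
assert out.shape == (4, 8)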
@@ -0,0 +1,164 @@
{
"1": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
}
}
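
The new JSON file above stores one tuned Triton config per token count (the string keys). A hypothetical loader, not the actual get_moe_configs logic, that picks the closest available key for a given batch size (the path is a placeholder):

import json

def load_tuned_config(path: str, num_tokens: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Fall back to the nearest tuned token count if there is no exact match.
    nearest = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[nearest]

# load_tuned_config("moe_config.json", num_tokens=200) would return the "256" entry above.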