diff --git a/examples/deepseek_v3/configs/base.py b/examples/deepseek_v3/configs/base.py new file mode 100644 index 000000000..e63b2aafc --- /dev/null +++ b/examples/deepseek_v3/configs/base.py @@ -0,0 +1,102 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base configuration for the model.""" + +import dataclasses +from typing import Literal + + +@dataclasses.dataclass +class ModelArgs: + """Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE + routing. + route_scale (float): Scaling factor for routing scores. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without + positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary + embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + mscale (float): Scaling factor for extended attention. + world_size (int): World size. + rank (int): Rank. + block_size (int): Block size. + gemm_impl (Literal["bf16", "fp8"] | None): Implementation for GEMM + operations. + attn_impl (Literal["naive", "absorb"]): Implementation for attention + operations. + """ + + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + # moe + n_routed_experts: int = 64 + n_shared_experts: int = 2 + n_activated_experts: int = 6 + n_expert_groups: int = 1 + n_limited_groups: int = 1 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 1.0 + # mla + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1.0 + # misc + world_size: int = 1 + rank: int = 0 + block_size: int = 128 + gemm_impl: Literal["bf16", "fp8"] | None = "bf16" + attn_impl: Literal["naive", "absorb"] = "absorb" diff --git a/examples/deepseek_v3/configs/config_16B.py b/examples/deepseek_v3/configs/config_16B.py new file mode 100644 index 000000000..dd46c6bf8 --- /dev/null +++ b/examples/deepseek_v3/configs/config_16B.py @@ -0,0 +1,40 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Configuration for the 16B model.""" + +from configs.base import ModelArgs + + +def get_config(): + """Returns the configuration for the model.""" + return ModelArgs( + vocab_size=102400, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + n_dense_layers=1, + n_heads=16, + n_routed_experts=64, + n_shared_experts=2, + n_activated_experts=6, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.707, + ) diff --git a/examples/deepseek_v3/configs/config_236B.py b/examples/deepseek_v3/configs/config_236B.py new file mode 100644 index 000000000..977602ca2 --- /dev/null +++ b/examples/deepseek_v3/configs/config_236B.py @@ -0,0 +1,41 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Configuration for the 236B model.""" + +from configs.base import ModelArgs + + +def get_config(): + """Returns the configuration for the model.""" + return ModelArgs( + vocab_size=102400, + dim=5120, + inter_dim=12288, + moe_inter_dim=1536, + n_layers=60, + n_dense_layers=1, + n_heads=128, + n_routed_experts=160, + n_shared_experts=2, + n_activated_experts=6, + n_expert_groups=8, + n_limited_groups=3, + route_scale=16.0, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) diff --git a/examples/deepseek_v3/configs/config_671B.py b/examples/deepseek_v3/configs/config_671B.py new file mode 100644 index 000000000..ef6d9290b --- /dev/null +++ b/examples/deepseek_v3/configs/config_671B.py @@ -0,0 +1,43 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Configuration for the 671B model.""" + +from configs.base import ModelArgs + + +def get_config(): + """Returns the configuration for the model.""" + return ModelArgs( + vocab_size=129280, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=61, + n_dense_layers=3, + n_heads=128, + n_routed_experts=256, + n_shared_experts=1, + n_activated_experts=8, + n_expert_groups=8, + n_limited_groups=4, + route_scale=2.5, + score_func="sigmoid", + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + dtype="fp8", + ) diff --git a/examples/deepseek_v3/kernel.py b/examples/deepseek_v3/kernel.py new file mode 100644 index 000000000..ff5346ff6 --- /dev/null +++ b/examples/deepseek_v3/kernel.py @@ -0,0 +1,258 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Triton kernels for PyTorch.""" + +from typing import Tuple + +import jax +import jax.numpy as jnp +import torch +import triton +import triton.language as tl + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized + values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors + will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each + program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + +def act_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and + its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for + quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous() + assert x.size(-1) % block_size == 0 + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +def act_quant_jax(x: jax.Array): + s = jnp.max(jnp.abs(x)) / 448.0 + y = x / s + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 +) -> torch.Tensor: + """Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. + Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions + are not 2. + """ + assert x.is_contiguous() and s.is_contiguous() + assert x.dim() == 2 and s.dim() == 2 + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: ( + triton.cdiv(M, meta['BLOCK_SIZE']), + triton.cdiv(N, meta['BLOCK_SIZE']), + ) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +def weight_dequant_jax(x: jax.Array, s: jax.Array): + y = x * s + return y + + +fp8_gemm_configs = [ + triton.Config( + {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [16, 32, 64] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] +] + + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def fp8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """Performs a matrix multiplication operation on FP8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def fp8_gemm( + a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor +): + """Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be + contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must + be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c + + +def fp8_gemm_jax(a: jax.Array, a_s: jax.Array, b: jax.Array, b_s: jax.Array): + a = a.astype(jnp.float32) + b = b.astype(jnp.float32) + a_s = a_s.astype(jnp.float32) + b_s = b_s.astype(jnp.float32) + a = a / a_s + b = b / b_s + c = jnp.dot(a, b) + return c diff --git a/examples/deepseek_v3/kernel_pytorch.py b/examples/deepseek_v3/kernel_pytorch.py new file mode 100644 index 000000000..253c8882d --- /dev/null +++ b/examples/deepseek_v3/kernel_pytorch.py @@ -0,0 +1,235 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Triton kernels for PyTorch.""" + +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized + values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors + will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each + program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def act_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and + its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for + quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous() + assert x.size(-1) % block_size == 0 + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 +) -> torch.Tensor: + """Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. + Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions + are not 2. + """ + assert x.is_contiguous() and s.is_contiguous() + assert x.dim() == 2 and s.dim() == 2 + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: ( + triton.cdiv(M, meta['BLOCK_SIZE']), + triton.cdiv(N, meta['BLOCK_SIZE']), + ) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +fp8_gemm_configs = [ + triton.Config( + {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [16, 32, 64] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] +] + + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def fp8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """Performs a matrix multiplication operation on FP8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def fp8_gemm( + a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor +): + """Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be + contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must + be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c diff --git a/examples/deepseek_v3/model.py b/examples/deepseek_v3/model.py new file mode 100644 index 000000000..715a58207 --- /dev/null +++ b/examples/deepseek_v3/model.py @@ -0,0 +1,910 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Model in PyTorch.""" +import math +from typing import Literal, Optional, Tuple +from absl import app +from absl import flags +from flax import nnx +from configs.base import ModelArgs +from kernel import act_quant_jax, fp8_gemm_jax, weight_dequant_jax +import jax +import jax.numpy as jnp +import torch +from torch import nn +import torch.distributed as dist +import torch.nn.functional as F + + +# world_size = 1 +# rank = 0 +# block_size = 128 +# gemm_impl: Literal["bf16", "fp8"] = "bf16" +# attn_impl: Literal["naive", "absorb"] = "absorb" + + +class ParallelEmbedding(nnx.Module): + """Embedding layer with parallelism support across distributed processes. + + Args: + vocab_size (int): Vocabulary size. + dim (int): Embedding dimension. + """ + + def __init__(self, config: ModelArgs, dim: int): + super().__init__() + self.vocab_size = config.vocab_size + self.dim = dim + self.world_size = config.world_size + self.rank = config.rank + assert config.vocab_size % config.world_size == 0 + self.part_vocab_size = config.vocab_size // config.world_size + self.vocab_start_idx = config.rank * self.part_vocab_size + self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size + self.weight = nnx.Param(jnp.empty((self.part_vocab_size, self.dim))) + + def __call__( + self, x: jax.Array + ) -> jax.Array: # [batch_size, seq_len] -> [batch_size, seq_len, dim] + """Forward pass for parallel embedding layer. + + Args: + x (jax.Array): Input tensor containing token indices. + + Returns: + jax.Array: Embedded representations. + + Raises: + ValueError: If `world_size` is not defined. + """ + # x: [batch_size, seq_len] + # mask: [batch_size, seq_len] + mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) + if self.world_size > 1: + x = x - self.vocab_start_idx # [batch_size, seq_len] + x = jnp.where(mask, 0, x) # [batch_size, seq_len] + y = jnp.take(self.weight.value, x, axis=0) # [batch_size, seq_len, dim] + # NOTE: for now let the XLA compiler do this + # if world_size > 1: + # y = jnp.where(mask[..., None], 0, y) # [batch_size, seq_len, dim] + # y = jax.lax.psum(y, "data") # [batch_size, seq_len, dim] + return y + + +def linear( + x: jax.Array, + weight: jax.Array, + bias: jax.Array | None = None, + *, + weight_scale: jax.Array | None = None, + config: ModelArgs, +) -> jax.Array: + """Applies a linear transformation to the incoming data: y = xA^T + b. + + This function supports specialized implementations based on quantization and + tensor formats. + + Args: + x (jax.Array): The input tensor. + weight (jax.Array): The weight tensor. It may be quantized and requires + dequantization for certain cases. + bias (Optional[jax.Array]): The bias tensor to be added. Default is None. + + Returns: + jax.Array: The result of the linear transformation, which may involve + quantization-aware computations depending on the input parameters. + + Notes: + - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized + version + is used for computation. + - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are + applied. + - For other cases, the function applies quantization to `x` and uses + `fp8_gemm` for computation. + """ + if weight.dtype.itemsize > 1: + y = jnp.dot(x, weight) + if bias is not None: + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + return y + elif config.gemm_impl == "bf16": + assert weight_scale is not None + weight = weight_dequant_jax(weight, weight_scale) + y = jnp.dot(x, weight) + if bias is not None: + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + return y + else: + assert weight_scale is not None + x, scale = act_quant_jax(x) + y = fp8_gemm_jax(x, scale, weight, weight_scale) + if bias is not None: + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + return y + + +class Linear(nn.Module): + """Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + part_out_features: int + def __init__( + self, + in_features: int, + out_features: int, + *, + bias: bool = False, + dtype=torch.bfloat16, + config: ModelArgs, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nnx.Param(jnp.empty((out_features, in_features), dtype=dtype)) + block_size = config.block_size + self.config = config + if self.weight.value.dtype.itemsize == 1: + scale_out_features = (out_features + block_size - 1) // block_size + scale_in_features = (in_features + block_size - 1) // block_size + self.scale = nnx.Param( + jnp.empty( + (scale_out_features, scale_in_features), dtype=torch.float32 + ) + ) + else: + self.scale = None + if bias: + self.bias = nnx.Param(jnp.empty((self.part_out_features,))) + else: + self.bias = None + + def __call__(self, x: jax.Array) -> jax.Array: + """Forward pass for the custom linear layer. + + Args: + x (jax.Array): Input tensor. + + Returns: + jax.Array: Transformed tensor after linear computation. + """ + weight_scale = self.scale.value if self.scale is not None else None + return linear( + x, self.weight, self.bias, weight_scale=weight_scale, config=self.config + ) + + +class ColumnParallelLinear(Linear): + """Linear layer with column parallelism, splitting output features across distributed processes. + + Args: + in_features (int): Number of input features. + out_features (int): Total number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__( + self, + in_features: int, + out_features: int, + *, + bias: bool = False, + dtype=None, + config: ModelArgs, + ): + assert out_features % config.world_size == 0 + self.part_out_features = out_features // config.world_size + super().__init__( + in_features, + self.part_out_features, + bias=bias, + dtype=dtype, + config=config, + ) + + # def __call__(self, x: jax.Array) -> jax.Array: + # """Forward pass for column parallel linear layer. + + # Args: + # x (jax.Array): Input tensor. + + # Returns: + # jax.Array: Transformed tensor with column-parallel computation. + # """ + # y = linear(x, self.weight, self.bias) + # return y + + +# class RowParallelLinear(Linear): +# """Linear layer with row parallelism, splitting input features across distributed processes. + +# Args: +# in_features (int): Total number of input features. +# out_features (int): Number of output features. +# bias (bool): Whether to include a bias term. Defaults to False. +# dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. +# """ + +# def __init__( +# self, in_features: int, out_features: int, bias: bool = False, dtype=None +# ): +# assert in_features % world_size == 0 +# self.part_in_features = in_features // world_size +# super().__init__(self.part_in_features, out_features, bias, dtype) + +# def forward(self, x: jax.Array) -> jax.Array: +# """Forward pass for row parallel linear layer. + +# Args: +# x (jax.Array): Input tensor. + +# Returns: +# jax.Array: Transformed tensor with row-parallel computation. +# """ +# y = linear(x, self.weight) +# if world_size > 1: +# dist.all_reduce(y) +# if self.bias is not None: +# y += self.bias +# return y + + +# class RMSNorm(nn.Module): +# """Root Mean Square Layer Normalization (RMSNorm). + +# Args: +# dim (int): Dimension of the input tensor. +# eps (float): Epsilon value for numerical stability. Defaults to 1e-6. +# """ + +# def __init__(self, dim: int, eps: float = 1e-6): +# super().__init__() +# self.dim = dim +# self.eps = eps +# self.weight = nn.Parameter(torch.ones(dim)) + +# def forward(self, x: jax.Array): +# """Forward pass for RMSNorm. + +# Args: +# x (jax.Array): Input tensor. + +# Returns: +# jax.Array: Normalized tensor with the same shape as input. +# """ +# return F.rms_norm(x, (self.dim,), self.weight, self.eps) + + +# def precompute_freqs_cis(args: ModelArgs) -> jax.Array: +# """Precomputes frequency-based complex exponential values for rotary positional embeddings. + +# Args: +# args (ModelArgs): Model arguments containing positional embedding +# parameters. + +# Returns: +# jax.Array: Precomputed complex exponential values for positional +# embeddings. +# """ +# dim = args.qk_rope_head_dim +# seqlen = args.max_seq_len +# beta_fast = args.beta_fast +# beta_slow = args.beta_slow +# base = args.rope_theta +# factor = args.rope_factor + +# def find_correction_dim(num_rotations, dim, base, max_seq_len): +# """Computes the correction dimension for a given number of rotations in the rotary positional embedding. + +# Args: +# num_rotations (float): Number of rotations to compute the correction +# for. +# dim (int): Dimensionality of the embedding space. +# base (float): Base value for the exponential computation. +# max_seq_len (int): Maximum sequence length. + +# Returns: +# float: The correction dimension based on the input parameters. +# """ +# return ( +# dim +# * math.log(max_seq_len / (num_rotations * 2 * math.pi)) +# / (2 * math.log(base)) +# ) + +# def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): +# """Computes the range of correction dimensions for rotary positional embeddings. + +# Args: +# low_rot (float): Lower bound for the number of rotations. +# high_rot (float): Upper bound for the number of rotations. +# dim (int): Dimensionality of the embedding space. +# base (float): Base value for the exponential computation. +# max_seq_len (int): Maximum sequence length. + +# Returns: +# Tuple[int, int]: The range of correction dimensions (low, high), clamped +# to valid indices. +# """ +# low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) +# high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) +# return max(low, 0), min(high, dim - 1) + +# def linear_ramp_factor(min, max, dim): +# """Computes a linear ramp function used to smooth values between a minimum and maximum range. + +# Args: +# min (float): Minimum value for the ramp function. +# max (float): Maximum value for the ramp function. +# dim (int): Dimensionality of the ramp tensor. + +# Returns: +# jax.Array: A tensor of shape (dim,) with values linearly interpolated +# between 0 and 1, +# clamped to the range [0, 1]. +# """ +# if min == max: +# max += 0.001 +# linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) +# ramp_func = torch.clamp(linear_func, 0, 1) +# return ramp_func + +# freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) +# if seqlen > args.original_seq_len: +# low, high = find_correction_range( +# beta_fast, beta_slow, dim, base, args.original_seq_len +# ) +# smooth = 1 - linear_ramp_factor(low, high, dim // 2) +# freqs = freqs / factor * (1 - smooth) + freqs * smooth + +# t = torch.arange(seqlen) +# freqs = torch.outer(t, freqs) +# freqs_cis = torch.polar(torch.ones_like(freqs), freqs) +# return freqs_cis + + +# def apply_rotary_emb(x: jax.Array, freqs_cis: jax.Array) -> jax.Array: +# """Applies rotary positional embeddings to the input tensor. + +# Args: +# x (jax.Array): Input tensor with positional embeddings to be applied. +# freqs_cis (jax.Array): Precomputed complex exponential values for +# positional embeddings. + +# Returns: +# jax.Array: Tensor with rotary embeddings applied. +# """ +# dtype = x.dtype +# x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) +# freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) +# y = torch.view_as_real(x * freqs_cis).flatten(3) +# return y.to(dtype) + + +# class MLA(nn.Module): +# """Multi-Headed Attention Layer (MLA). + +# Attributes: +# dim (int): Dimensionality of the input features. +# n_heads (int): Number of attention heads. +# n_local_heads (int): Number of local attention heads for distributed +# systems. +# q_lora_rank (int): Rank for low-rank query projection. +# kv_lora_rank (int): Rank for low-rank key/value projection. +# qk_nope_head_dim (int): Dimensionality of non-positional query/key +# projections. +# qk_rope_head_dim (int): Dimensionality of rotary-positional query/key +# projections. +# qk_head_dim (int): Total dimensionality of query/key projections. +# v_head_dim (int): Dimensionality of value projections. +# softmax_scale (float): Scaling factor for softmax in attention +# computation. +# """ + +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.dim = args.dim +# self.n_heads = args.n_heads +# self.n_local_heads = args.n_heads // world_size +# self.q_lora_rank = args.q_lora_rank +# self.kv_lora_rank = args.kv_lora_rank +# self.qk_nope_head_dim = args.qk_nope_head_dim +# self.qk_rope_head_dim = args.qk_rope_head_dim +# self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim +# self.v_head_dim = args.v_head_dim + +# if self.q_lora_rank == 0: +# self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim) +# else: +# self.wq_a = Linear(self.dim, self.q_lora_rank) +# self.q_norm = RMSNorm(self.q_lora_rank) +# self.wq_b = ColumnParallelLinear( +# self.q_lora_rank, self.n_heads * self.qk_head_dim +# ) +# self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) +# self.kv_norm = RMSNorm(self.kv_lora_rank) +# self.wkv_b = ColumnParallelLinear( +# self.kv_lora_rank, +# self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), +# ) +# self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim) +# self.softmax_scale = self.qk_head_dim**-0.5 +# if args.max_seq_len > args.original_seq_len: +# mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 +# self.softmax_scale = self.softmax_scale * mscale * mscale + +# if attn_impl == "naive": +# self.register_buffer( +# "k_cache", +# torch.zeros( +# args.max_batch_size, +# args.max_seq_len, +# self.n_local_heads, +# self.qk_head_dim, +# ), +# persistent=False, +# ) +# self.register_buffer( +# "v_cache", +# torch.zeros( +# args.max_batch_size, +# args.max_seq_len, +# self.n_local_heads, +# self.v_head_dim, +# ), +# persistent=False, +# ) +# else: +# self.register_buffer( +# "kv_cache", +# torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), +# persistent=False, +# ) +# self.register_buffer( +# "pe_cache", +# torch.zeros( +# args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim +# ), +# persistent=False, +# ) + +# def forward( +# self, +# x: jax.Array, +# start_pos: int, +# freqs_cis: jax.Array, +# mask: Optional[jax.Array], +# ): +# """Forward pass for the Multi-Headed Attention Layer (MLA). + +# Args: +# x (jax.Array): Input tensor of shape (batch_size, seq_len, dim). +# start_pos (int): Starting position in the sequence for caching. +# freqs_cis (jax.Array): Precomputed complex exponential values for +# rotary embeddings. +# mask (Optional[jax.Array]): Mask tensor to exclude certain positions +# from attention. + +# Returns: +# jax.Array: Output tensor with the same shape as the input. +# """ +# bsz, seqlen, _ = x.size() +# end_pos = start_pos + seqlen +# if self.q_lora_rank == 0: +# q = self.wq(x) +# else: +# q = self.wq_b(self.q_norm(self.wq_a(x))) +# q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) +# q_nope, q_pe = torch.split( +# q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 +# ) +# q_pe = apply_rotary_emb(q_pe, freqs_cis) +# kv = self.wkv_a(x) +# kv, k_pe = torch.split( +# kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 +# ) +# k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) +# if attn_impl == "naive": +# q = torch.cat([q_nope, q_pe], dim=-1) +# kv = self.wkv_b(self.kv_norm(kv)) +# kv = kv.view( +# bsz, +# seqlen, +# self.n_local_heads, +# self.qk_nope_head_dim + self.v_head_dim, +# ) +# k_nope, v = torch.split( +# kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 +# ) +# k = torch.cat( +# [k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1 +# ) +# self.k_cache[:bsz, start_pos:end_pos] = k +# self.v_cache[:bsz, start_pos:end_pos] = v +# scores = ( +# torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) +# * self.softmax_scale +# ) +# else: +# wkv_b = ( +# self.wkv_b.weight +# if self.wkv_b.scale is None +# else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) +# ) +# wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) +# q_nope = torch.einsum( +# "bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim] +# ) +# self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) +# self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) +# scores = ( +# torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +# + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos]) +# ) * self.softmax_scale +# if mask is not None: +# scores += mask.unsqueeze(1) +# scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) +# if attn_impl == "naive": +# x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) +# else: +# x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) +# x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :]) +# x = self.wo(x.flatten(2)) +# return x + + +# class MLP(nn.Module): +# """Multi-Layer Perceptron (MLP) used as a feed-forward layer. + +# Attributes: +# w1 (nn.Module): Linear layer for input-to-hidden transformation. +# w2 (nn.Module): Linear layer for hidden-to-output transformation. +# w3 (nn.Module): Additional linear layer for feature transformation. +# """ + +# def __init__(self, dim: int, inter_dim: int): +# """Initializes the MLP layer. + +# Args: +# dim (int): Input and output dimensionality. +# inter_dim (int): Hidden layer dimensionality. +# """ +# super().__init__() +# self.w1 = ColumnParallelLinear(dim, inter_dim) +# self.w2 = RowParallelLinear(inter_dim, dim) +# self.w3 = ColumnParallelLinear(dim, inter_dim) + +# def forward(self, x: jax.Array) -> jax.Array: +# """Forward pass for the MLP layer. + +# Args: +# x (jax.Array): Input tensor. + +# Returns: +# jax.Array: Output tensor after MLP computation. +# """ +# return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +# class Gate(nn.Module): +# """Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. + +# Attributes: +# dim (int): Dimensionality of input features. +# topk (int): Number of top experts activated for each input. +# n_groups (int): Number of groups for routing. +# topk_groups (int): Number of groups to route inputs to. +# score_func (str): Scoring function ('softmax' or 'sigmoid'). +# route_scale (float): Scaling factor for routing weights. +# weight (torch.nn.Parameter): Learnable weights for the gate. +# bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. +# """ + +# def __init__(self, args: ModelArgs): +# """Initializes the Gate module. + +# Args: +# args (ModelArgs): Model arguments containing gating parameters. +# """ +# super().__init__() +# self.dim = args.dim +# self.topk = args.n_activated_experts +# self.n_groups = args.n_expert_groups +# self.topk_groups = args.n_limited_groups +# self.score_func = args.score_func +# self.route_scale = args.route_scale +# self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) +# self.bias = ( +# nn.Parameter(torch.empty(args.n_routed_experts)) +# if self.dim == 7168 +# else None +# ) + +# def forward(self, x: jax.Array) -> Tuple[jax.Array, jax.Array]: +# """Forward pass for the gating mechanism. + +# Args: +# x (jax.Array): Input tensor. + +# Returns: +# Tuple[jax.Array, jax.Array]: Routing weights and selected expert +# indices. +# """ +# scores = linear(x, self.weight) +# if self.score_func == "softmax": +# scores = scores.softmax(dim=-1, dtype=torch.float32) +# else: +# scores = scores.sigmoid() +# original_scores = scores +# if self.bias is not None: +# scores = scores + self.bias +# if self.n_groups > 1: +# scores = scores.view(x.size(0), self.n_groups, -1) +# if self.bias is None: +# group_scores = scores.amax(dim=-1) +# else: +# group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) +# indices = group_scores.topk(self.topk_groups, dim=-1)[1] +# mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) +# scores = (scores * mask.unsqueeze(-1)).flatten(1) +# indices = torch.topk(scores, self.topk, dim=-1)[1] +# weights = original_scores.gather(1, indices) +# if self.score_func == "sigmoid": +# weights /= weights.sum(dim=-1, keepdim=True) +# weights *= self.route_scale +# return weights.type_as(x), indices + + +# class Expert(nn.Module): +# """Expert layer for Mixture-of-Experts (MoE) models. + +# Attributes: +# w1 (nn.Module): Linear layer for input-to-hidden transformation. +# w2 (nn.Module): Linear layer for hidden-to-output transformation. +# w3 (nn.Module): Additional linear layer for feature transformation. +# """ + +# def __init__(self, dim: int, inter_dim: int): +# """Initializes the Expert layer. + +# Args: +# dim (int): Input and output dimensionality. +# inter_dim (int): Hidden layer dimensionality. +# """ +# super().__init__() +# self.w1 = Linear(dim, inter_dim) +# self.w2 = Linear(inter_dim, dim) +# self.w3 = Linear(dim, inter_dim) + +# def forward(self, x: jax.Array) -> jax.Array: +# """Forward pass for the Expert layer. + +# Args: +# x (jax.Array): Input tensor. + +# Returns: +# jax.Array: Output tensor after expert computation. +# """ +# return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +# class MoE(nn.Module): +# """Mixture-of-Experts (MoE) module. + +# Attributes: +# dim (int): Dimensionality of input features. +# n_routed_experts (int): Total number of experts in the model. +# n_local_experts (int): Number of experts handled locally in distributed +# systems. +# n_activated_experts (int): Number of experts activated for each input. +# gate (nn.Module): Gating mechanism to route inputs to experts. +# experts (nn.ModuleList): List of expert modules. +# shared_experts (nn.Module): Shared experts applied to all inputs. +# """ + +# def __init__(self, args: ModelArgs): +# """Initializes the MoE module. + +# Args: +# args (ModelArgs): Model arguments containing MoE parameters. +# """ +# super().__init__() +# self.dim = args.dim +# assert args.n_routed_experts % world_size == 0 +# self.n_routed_experts = args.n_routed_experts +# self.n_local_experts = args.n_routed_experts // world_size +# self.n_activated_experts = args.n_activated_experts +# self.experts_start_idx = rank * self.n_local_experts +# self.experts_end_idx = self.experts_start_idx + self.n_local_experts +# self.gate = Gate(args) +# self.experts = nn.ModuleList([ +# Expert(args.dim, args.moe_inter_dim) +# if self.experts_start_idx <= i < self.experts_end_idx +# else None +# for i in range(self.n_routed_experts) +# ]) +# self.shared_experts = MLP( +# args.dim, args.n_shared_experts * args.moe_inter_dim +# ) + +# def forward(self, x: jax.Array) -> jax.Array: +# """Forward pass for the MoE module. + +# Args: +# x (jax.Array): Input tensor. + +# Returns: +# jax.Array: Output tensor after expert routing and computation. +# """ +# shape = x.size() +# x = x.view(-1, self.dim) +# weights, indices = self.gate(x) +# y = torch.zeros_like(x) +# counts = torch.bincount( +# indices.flatten(), minlength=self.n_routed_experts +# ).tolist() +# for i in range(self.experts_start_idx, self.experts_end_idx): +# if counts[i] == 0: +# continue +# expert = self.experts[i] +# idx, top = torch.where(indices == i) +# y[idx] += expert(x[idx]) * weights[idx, top, None] +# z = self.shared_experts(x) +# if world_size > 1: +# dist.all_reduce(y) +# return (y + z).view(shape) + + +# class Block(nn.Module): +# """Transformer block combining attention and feed-forward layers. + +# Attributes: +# attn (nn.Module): Attention layer (MLA). +# ffn (nn.Module): Feed-forward network (MLP or MoE). +# attn_norm (nn.Module): Layer normalization for attention. +# ffn_norm (nn.Module): Layer normalization for feed-forward network. +# """ + +# def __init__(self, layer_id: int, args: ModelArgs): +# """Initializes the Transformer block. + +# Args: +# layer_id (int): Layer index in the transformer. +# args (ModelArgs): Model arguments containing block parameters. +# """ +# super().__init__() +# self.attn = MLA(args) +# self.ffn = ( +# MLP(args.dim, args.inter_dim) +# if layer_id < args.n_dense_layers +# else MoE(args) +# ) +# self.attn_norm = RMSNorm(args.dim) +# self.ffn_norm = RMSNorm(args.dim) + +# def forward( +# self, +# x: jax.Array, +# start_pos: int, +# freqs_cis: jax.Array, +# mask: Optional[jax.Array], +# ) -> jax.Array: +# """Forward pass for the Transformer block. + +# Args: +# x (jax.Array): Input tensor. +# start_pos (int): Starting position in the sequence. +# freqs_cis (jax.Array): Precomputed complex exponential values for +# rotary embeddings. +# mask (Optional[jax.Array]): Mask tensor to exclude certain positions +# from attention. + +# Returns: +# jax.Array: Output tensor after block computation. +# """ +# x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask) +# x = x + self.ffn(self.ffn_norm(x)) +# return x + + +# class Transformer(nn.Module): +# """Transformer model with positional embeddings, multiple layers, and output projection. + +# Attributes: +# max_seq_len (int): Maximum sequence length for the transformer. +# embed (nn.Module): Embedding layer for input tokens. +# layers (torch.nn.ModuleList): List of transformer blocks. +# norm (nn.Module): Layer normalization applied after all blocks. +# head (nn.Module): Output projection layer mapping to vocabulary size. +# freqs_cis (jax.Array): Precomputed complex exponential values for +# rotary embeddings. +# """ + +# def __init__(self, args: ModelArgs): +# """Initializes the Transformer model. + +# Args: +# args (ModelArgs): Model arguments containing transformer parameters. +# """ +# global world_size, rank +# world_size = dist.get_world_size() if dist.is_initialized() else 1 +# rank = dist.get_rank() if dist.is_initialized() else 0 +# Linear.dtype = ( +# torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 +# ) +# super().__init__() +# self.max_seq_len = args.max_seq_len +# self.embed = ParallelEmbedding(args.vocab_size, args.dim) +# self.layers = torch.nn.ModuleList() +# for layer_id in range(args.n_layers): +# self.layers.append(Block(layer_id, args)) +# self.norm = RMSNorm(args.dim) +# self.head = ColumnParallelLinear( +# args.dim, args.vocab_size, dtype=torch.get_default_dtype() +# ) +# self.register_buffer( +# "freqs_cis", precompute_freqs_cis(args), persistent=False +# ) + +# @torch.inference_mode() +# def forward(self, tokens: jax.Array, start_pos: int = 0): +# """Forward pass for the Transformer model. + +# Args: +# tokens (jax.Array): Input tensor of token IDs with shape (batch_size, +# seq_len). +# start_pos (int, optional): Starting position in the sequence for rotary +# embeddings. Defaults to 0. + +# Returns: +# jax.Array: Logits tensor of shape (batch_size, vocab_size). +# """ +# seqlen = tokens.size(1) +# h = self.embed(tokens) +# freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] +# mask = None +# if seqlen > 1: +# mask = torch.full( +# (seqlen, seqlen), float("-inf"), device=tokens.device +# ).triu_(1) +# for layer in self.layers: +# h = layer(h, start_pos, freqs_cis, mask) +# h = self.norm(h)[:, -1] +# logits = self.head(h) +# if world_size > 1: +# all_logits = [torch.empty_like(logits) for _ in range(world_size)] +# dist.all_gather(all_logits, logits) +# logits = torch.cat(all_logits, dim=-1) +# return logits + + +# def main(argv) -> None: +# torch.set_default_dtype(torch.bfloat16) +# torch.set_default_device("cuda") +# torch.manual_seed(0) +# args = ModelArgs() +# x = torch.randint(0, args.vocab_size, (2, 128)) +# model = Transformer(args) +# print(model(x).size()) + + +# if __name__ == "__main__": +# app.run(main) diff --git a/examples/deepseek_v3/model_pytorch.py b/examples/deepseek_v3/model_pytorch.py new file mode 100644 index 000000000..6794bbaee --- /dev/null +++ b/examples/deepseek_v3/model_pytorch.py @@ -0,0 +1,869 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Model in PyTorch.""" +import math +from typing import Literal, Optional, Tuple +from absl import app +from absl import flags +from configs.base import ModelArgs +from kernel_pytorch import act_quant, fp8_gemm, weight_dequant +import torch +from torch import nn +import torch.distributed as dist +import torch.nn.functional as F + + +world_size = 1 +rank = 0 +block_size = 128 +gemm_impl: Literal["bf16", "fp8"] = "bf16" +attn_impl: Literal["naive", "absorb"] = "absorb" + + +class ParallelEmbedding(nn.Module): + """Embedding layer with parallelism support across distributed processes. + + Args: + vocab_size (int): Vocabulary size. + dim (int): Embedding dimension. + """ + + def __init__(self, vocab_size: int, dim: int): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + assert vocab_size % world_size == 0 + self.part_vocab_size = vocab_size // world_size + self.vocab_start_idx = rank * self.part_vocab_size + self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size + self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for parallel embedding layer. + + Args: + x (torch.Tensor): Input tensor containing token indices. + + Returns: + torch.Tensor: Embedded representations. + + Raises: + ValueError: If `world_size` is not defined. + """ + if world_size > 1: + mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) + x = x - self.vocab_start_idx + x[mask] = 0 + y = F.embedding(x, self.weight) + if world_size > 1: + y[mask] = 0 + dist.all_reduce(y) + return y + + +def linear( + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Applies a linear transformation to the incoming data: y = xA^T + b. + + This function supports specialized implementations based on quantization and + tensor formats. + + Args: + x (torch.Tensor): The input tensor. + weight (torch.Tensor): The weight tensor. It may be quantized and requires + dequantization for certain cases. + bias (Optional[torch.Tensor]): The bias tensor to be added. Default is + None. + + Returns: + torch.Tensor: The result of the linear transformation, which may involve + quantization-aware computations depending on the input parameters. + + Notes: + - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized + version + is used for computation. + - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are + applied. + - For other cases, the function applies quantization to `x` and uses + `fp8_gemm` for computation. + """ + if weight.element_size() > 1: + return F.linear(x, weight, bias) + elif gemm_impl == "bf16": + weight = weight_dequant(weight, weight.scale) + return F.linear(x, weight, bias) + else: + x, scale = act_quant(x, block_size) + y = fp8_gemm(x, scale, weight, weight.scale) + if bias is not None: + y += bias + return y + + +class Linear(nn.Module): + """Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + dtype = torch.bfloat16 + + def __init__( + self, in_features: int, out_features: int, bias: bool = False, dtype=None + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty(out_features, in_features, dtype=dtype or Linear.dtype) + ) + if self.weight.element_size() == 1: + scale_out_features = (out_features + block_size - 1) // block_size + scale_in_features = (in_features + block_size - 1) // block_size + self.weight.scale = self.scale = nn.Parameter( + torch.empty( + scale_out_features, scale_in_features, dtype=torch.float32 + ) + ) + else: + self.register_parameter("scale", None) + if bias: + self.bias = nn.Parameter(torch.empty(self.part_out_features)) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. + """ + return linear(x, self.weight, self.bias) + + +class ColumnParallelLinear(Linear): + """Linear layer with column parallelism, splitting output features across distributed processes. + + Args: + in_features (int): Number of input features. + out_features (int): Total number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__( + self, in_features: int, out_features: int, bias: bool = False, dtype=None + ): + assert out_features % world_size == 0 + self.part_out_features = out_features // world_size + super().__init__(in_features, self.part_out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for column parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with column-parallel computation. + """ + y = linear(x, self.weight, self.bias) + return y + + +class RowParallelLinear(Linear): + """Linear layer with row parallelism, splitting input features across distributed processes. + + Args: + in_features (int): Total number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__( + self, in_features: int, out_features: int, bias: bool = False, dtype=None + ): + assert in_features % world_size == 0 + self.part_in_features = in_features // world_size + super().__init__(self.part_in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for row parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with row-parallel computation. + """ + y = linear(x, self.weight) + if world_size > 1: + dist.all_reduce(y) + if self.bias is not None: + y += self.bias + return y + + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization (RMSNorm). + + Args: + dim (int): Dimension of the input tensor. + eps (float): Epsilon value for numerical stability. Defaults to 1e-6. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor): + """Forward pass for RMSNorm. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor with the same shape as input. + """ + return F.rms_norm(x, (self.dim,), self.weight, self.eps) + + +def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: + """Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (ModelArgs): Model arguments containing positional embedding + parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional + embeddings. + """ + dim = args.qk_rope_head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + + def find_correction_dim(num_rotations, dim, base, max_seq_len): + """Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction + for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): + """Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped + to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + """Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated + between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if seqlen > args.original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, args.original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + t = torch.arange(seqlen) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for + positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + + +class MLA(nn.Module): + """Multi-Headed Attention Layer (MLA). + + Attributes: + dim (int): Dimensionality of the input features. + n_heads (int): Number of attention heads. + n_local_heads (int): Number of local attention heads for distributed + systems. + q_lora_rank (int): Rank for low-rank query projection. + kv_lora_rank (int): Rank for low-rank key/value projection. + qk_nope_head_dim (int): Dimensionality of non-positional query/key + projections. + qk_rope_head_dim (int): Dimensionality of rotary-positional query/key + projections. + qk_head_dim (int): Total dimensionality of query/key projections. + v_head_dim (int): Dimensionality of value projections. + softmax_scale (float): Scaling factor for softmax in attention + computation. + """ + + def __init__(self, args: ModelArgs): + super().__init__() + self.dim = args.dim + self.n_heads = args.n_heads + self.n_local_heads = args.n_heads // world_size + self.q_lora_rank = args.q_lora_rank + self.kv_lora_rank = args.kv_lora_rank + self.qk_nope_head_dim = args.qk_nope_head_dim + self.qk_rope_head_dim = args.qk_rope_head_dim + self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim + self.v_head_dim = args.v_head_dim + + if self.q_lora_rank == 0: + self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim) + else: + self.wq_a = Linear(self.dim, self.q_lora_rank) + self.q_norm = RMSNorm(self.q_lora_rank) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, self.n_heads * self.qk_head_dim + ) + self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) + self.kv_norm = RMSNorm(self.kv_lora_rank) + self.wkv_b = ColumnParallelLinear( + self.kv_lora_rank, + self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), + ) + self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim) + self.softmax_scale = self.qk_head_dim**-0.5 + if args.max_seq_len > args.original_seq_len: + mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 + self.softmax_scale = self.softmax_scale * mscale * mscale + + if attn_impl == "naive": + self.register_buffer( + "k_cache", + torch.zeros( + args.max_batch_size, + args.max_seq_len, + self.n_local_heads, + self.qk_head_dim, + ), + persistent=False, + ) + self.register_buffer( + "v_cache", + torch.zeros( + args.max_batch_size, + args.max_seq_len, + self.n_local_heads, + self.v_head_dim, + ), + persistent=False, + ) + else: + self.register_buffer( + "kv_cache", + torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), + persistent=False, + ) + self.register_buffer( + "pe_cache", + torch.zeros( + args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim + ), + persistent=False, + ) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + """Forward pass for the Multi-Headed Attention Layer (MLA). + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + start_pos (int): Starting position in the sequence for caching. + freqs_cis (torch.Tensor): Precomputed complex exponential values for + rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions + from attention. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + end_pos = start_pos + seqlen + if self.q_lora_rank == 0: + q = self.wq(x) + else: + q = self.wq_b(self.q_norm(self.wq_a(x))) + q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + kv = self.wkv_a(x) + kv, k_pe = torch.split( + kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + if attn_impl == "naive": + q = torch.cat([q_nope, q_pe], dim=-1) + kv = self.wkv_b(self.kv_norm(kv)) + kv = kv.view( + bsz, + seqlen, + self.n_local_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + k_nope, v = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + k = torch.cat( + [k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1 + ) + self.k_cache[:bsz, start_pos:end_pos] = k + self.v_cache[:bsz, start_pos:end_pos] = v + scores = ( + torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) + * self.softmax_scale + ) + else: + wkv_b = ( + self.wkv_b.weight + if self.wkv_b.scale is None + else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) + ) + wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) + q_nope = torch.einsum( + "bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim] + ) + self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) + self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) + scores = ( + torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) + + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos]) + ) * self.softmax_scale + if mask is not None: + scores += mask.unsqueeze(1) + scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) + if attn_impl == "naive": + x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) + else: + x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) + x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :]) + x = self.wo(x.flatten(2)) + return x + + +class MLP(nn.Module): + """Multi-Layer Perceptron (MLP) used as a feed-forward layer. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ + + def __init__(self, dim: int, inter_dim: int): + """Initializes the MLP layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ + super().__init__() + self.w1 = ColumnParallelLinear(dim, inter_dim) + self.w2 = RowParallelLinear(inter_dim, dim) + self.w3 = ColumnParallelLinear(dim, inter_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for the MLP layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after MLP computation. + """ + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Gate(nn.Module): + """Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. + + Attributes: + dim (int): Dimensionality of input features. + topk (int): Number of top experts activated for each input. + n_groups (int): Number of groups for routing. + topk_groups (int): Number of groups to route inputs to. + score_func (str): Scoring function ('softmax' or 'sigmoid'). + route_scale (float): Scaling factor for routing weights. + weight (torch.nn.Parameter): Learnable weights for the gate. + bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. + """ + + def __init__(self, args: ModelArgs): + """Initializes the Gate module. + + Args: + args (ModelArgs): Model arguments containing gating parameters. + """ + super().__init__() + self.dim = args.dim + self.topk = args.n_activated_experts + self.n_groups = args.n_expert_groups + self.topk_groups = args.n_limited_groups + self.score_func = args.score_func + self.route_scale = args.route_scale + self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) + self.bias = ( + nn.Parameter(torch.empty(args.n_routed_experts)) + if self.dim == 7168 + else None + ) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for the gating mechanism. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert + indices. + """ + scores = linear(x, self.weight) + if self.score_func == "softmax": + scores = scores.softmax(dim=-1, dtype=torch.float32) + else: + scores = scores.sigmoid() + original_scores = scores + if self.bias is not None: + scores = scores + self.bias + if self.n_groups > 1: + scores = scores.view(x.size(0), self.n_groups, -1) + if self.bias is None: + group_scores = scores.amax(dim=-1) + else: + group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) + indices = group_scores.topk(self.topk_groups, dim=-1)[1] + mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) + scores = (scores * mask.unsqueeze(-1)).flatten(1) + indices = torch.topk(scores, self.topk, dim=-1)[1] + weights = original_scores.gather(1, indices) + if self.score_func == "sigmoid": + weights /= weights.sum(dim=-1, keepdim=True) + weights *= self.route_scale + return weights.type_as(x), indices + + +class Expert(nn.Module): + """Expert layer for Mixture-of-Experts (MoE) models. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ + + def __init__(self, dim: int, inter_dim: int): + """Initializes the Expert layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ + super().__init__() + self.w1 = Linear(dim, inter_dim) + self.w2 = Linear(inter_dim, dim) + self.w3 = Linear(dim, inter_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for the Expert layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert computation. + """ + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class MoE(nn.Module): + """Mixture-of-Experts (MoE) module. + + Attributes: + dim (int): Dimensionality of input features. + n_routed_experts (int): Total number of experts in the model. + n_local_experts (int): Number of experts handled locally in distributed + systems. + n_activated_experts (int): Number of experts activated for each input. + gate (nn.Module): Gating mechanism to route inputs to experts. + experts (nn.ModuleList): List of expert modules. + shared_experts (nn.Module): Shared experts applied to all inputs. + """ + + def __init__(self, args: ModelArgs): + """Initializes the MoE module. + + Args: + args (ModelArgs): Model arguments containing MoE parameters. + """ + super().__init__() + self.dim = args.dim + assert args.n_routed_experts % world_size == 0 + self.n_routed_experts = args.n_routed_experts + self.n_local_experts = args.n_routed_experts // world_size + self.n_activated_experts = args.n_activated_experts + self.experts_start_idx = rank * self.n_local_experts + self.experts_end_idx = self.experts_start_idx + self.n_local_experts + self.gate = Gate(args) + self.experts = nn.ModuleList([ + Expert(args.dim, args.moe_inter_dim) + if self.experts_start_idx <= i < self.experts_end_idx + else None + for i in range(self.n_routed_experts) + ]) + self.shared_experts = MLP( + args.dim, args.n_shared_experts * args.moe_inter_dim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for the MoE module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert routing and computation. + """ + shape = x.size() + x = x.view(-1, self.dim) + weights, indices = self.gate(x) + y = torch.zeros_like(x) + counts = torch.bincount( + indices.flatten(), minlength=self.n_routed_experts + ).tolist() + for i in range(self.experts_start_idx, self.experts_end_idx): + if counts[i] == 0: + continue + expert = self.experts[i] + idx, top = torch.where(indices == i) + y[idx] += expert(x[idx]) * weights[idx, top, None] + z = self.shared_experts(x) + if world_size > 1: + dist.all_reduce(y) + return (y + z).view(shape) + + +class Block(nn.Module): + """Transformer block combining attention and feed-forward layers. + + Attributes: + attn (nn.Module): Attention layer (MLA). + ffn (nn.Module): Feed-forward network (MLP or MoE). + attn_norm (nn.Module): Layer normalization for attention. + ffn_norm (nn.Module): Layer normalization for feed-forward network. + """ + + def __init__(self, layer_id: int, args: ModelArgs): + """Initializes the Transformer block. + + Args: + layer_id (int): Layer index in the transformer. + args (ModelArgs): Model arguments containing block parameters. + """ + super().__init__() + self.attn = MLA(args) + self.ffn = ( + MLP(args.dim, args.inter_dim) + if layer_id < args.n_dense_layers + else MoE(args) + ) + self.attn_norm = RMSNorm(args.dim) + self.ffn_norm = RMSNorm(args.dim) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ) -> torch.Tensor: + """Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position in the sequence. + freqs_cis (torch.Tensor): Precomputed complex exponential values for + rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions + from attention. + + Returns: + torch.Tensor: Output tensor after block computation. + """ + x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask) + x = x + self.ffn(self.ffn_norm(x)) + return x + + +class Transformer(nn.Module): + """Transformer model with positional embeddings, multiple layers, and output projection. + + Attributes: + max_seq_len (int): Maximum sequence length for the transformer. + embed (nn.Module): Embedding layer for input tokens. + layers (torch.nn.ModuleList): List of transformer blocks. + norm (nn.Module): Layer normalization applied after all blocks. + head (nn.Module): Output projection layer mapping to vocabulary size. + freqs_cis (torch.Tensor): Precomputed complex exponential values for + rotary embeddings. + """ + + def __init__(self, args: ModelArgs): + """Initializes the Transformer model. + + Args: + args (ModelArgs): Model arguments containing transformer parameters. + """ + global world_size, rank + world_size = dist.get_world_size() if dist.is_initialized() else 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + Linear.dtype = ( + torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 + ) + super().__init__() + self.max_seq_len = args.max_seq_len + self.embed = ParallelEmbedding(args.vocab_size, args.dim) + self.layers = torch.nn.ModuleList() + for layer_id in range(args.n_layers): + self.layers.append(Block(layer_id, args)) + self.norm = RMSNorm(args.dim) + self.head = ColumnParallelLinear( + args.dim, args.vocab_size, dtype=torch.get_default_dtype() + ) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(args), persistent=False + ) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int = 0): + """Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, + seq_len). + start_pos (int, optional): Starting position in the sequence for rotary + embeddings. Defaults to 0. + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). + """ + seqlen = tokens.size(1) + h = self.embed(tokens) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + mask = None + if seqlen > 1: + mask = torch.full( + (seqlen, seqlen), float("-inf"), device=tokens.device + ).triu_(1) + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h)[:, -1] + logits = self.head(h) + if world_size > 1: + all_logits = [torch.empty_like(logits) for _ in range(world_size)] + dist.all_gather(all_logits, logits) + logits = torch.cat(all_logits, dim=-1) + return logits + + +def main(argv) -> None: + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.manual_seed(0) + args = ModelArgs() + x = torch.randint(0, args.vocab_size, (2, 128)) + model = Transformer(args) + print(model(x).size()) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/deepseek_v3/model_test.py b/examples/deepseek_v3/model_test.py new file mode 100644 index 000000000..7d69cd7a2 --- /dev/null +++ b/examples/deepseek_v3/model_test.py @@ -0,0 +1,124 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================å +"""Tests for the DeepSeek model.""" + +from typing import Literal + +from absl.testing import absltest +from absl.testing import parameterized +from flax import nnx +import model as model_lib +import model_pytorch as model_lib_pytorch +import jax +import jax.numpy as jnp +import numpy as np +import torch + + +def bf16_pt_to_jax(pt_tensor: torch.Tensor) -> jax.Array: + return jnp.array( + pt_tensor.detach().to(torch.float32).numpy().astype(jnp.bfloat16) + ) + + +class ModelTest(parameterized.TestCase): + + def test_parallel_embedding_parity(self): + vocab_size = 100 + dim = 64 + batch_size = 2 + seq_len = 10 + config = model_lib.ModelArgs(vocab_size=vocab_size, dim=dim) + + # Create PyTorch embedding + embedding_pytorch = model_lib_pytorch.ParallelEmbedding(vocab_size, dim) + input_pytorch = torch.randint(0, vocab_size, (batch_size, seq_len)) + output_pytorch = embedding_pytorch(input_pytorch) + + # Create Flax embedding + embedding_flax = model_lib.ParallelEmbedding(config, dim) + input_flax = jnp.array(input_pytorch.detach().numpy()) + + # Copy weights from PyTorch to Flax + embedding_flax.weight.value = jnp.array( + embedding_pytorch.weight.detach().numpy() + ) + + # Run Flax embedding + output_flax = embedding_flax(input_flax) + + # Compare outputs + self.assertTrue( + jnp.allclose(output_flax, jnp.array(output_pytorch.detach().numpy())) + ) + + def test_linear_parity_bf16(self): + gemm_impl = "bf16" + in_features = 6 + out_features = 7 + config = model_lib.ModelArgs(gemm_impl=gemm_impl) + + x_pytorch = torch.rand(in_features).to(torch.bfloat16) + linear_pytorch = model_lib_pytorch.ColumnParallelLinear( + in_features=in_features, out_features=out_features, bias=True + ) + linear_pytorch.weight.data = torch.rand(out_features, in_features).to( + torch.bfloat16 + ) + linear_pytorch.bias.data = torch.rand(out_features).to(torch.bfloat16) + y_pytorch = linear_pytorch(x_pytorch) + + # jax + x_jax = bf16_pt_to_jax(x_pytorch) + linear_jax = model_lib.ColumnParallelLinear( + in_features=in_features, + out_features=out_features, + bias=True, + config=config, + ) + linear_jax.weight.value = bf16_pt_to_jax(linear_pytorch.weight).T + linear_jax.bias.value = bf16_pt_to_jax(linear_pytorch.bias) + y_jax = linear_jax(x_jax) + + np.testing.assert_allclose( + y_jax, + bf16_pt_to_jax(y_pytorch), + rtol=0.0072, + ) + + def test_linear_parity_bf16_fp8(self): + gemm_impl = "fp8" + in_features = 6 + out_features = 7 + config = model_lib.ModelArgs(gemm_impl=gemm_impl) + + # pytorch + dtype_pytorch = torch.float8_e4m3fn + x_pytorch = torch.rand(in_features).to(torch.bfloat16) + linear_pytorch = model_lib_pytorch.ColumnParallelLinear( + in_features=in_features, + out_features=out_features, + dtype=dtype_pytorch, + bias=True, + ) + linear_pytorch.weight.data = torch.rand(out_features, in_features).to( + dtype_pytorch + ) + linear_pytorch.bias.data = torch.rand(out_features).to(dtype_pytorch) + y_pytorch = linear_pytorch(x_pytorch) + + +if __name__ == "__main__": + absltest.main()