Add seq parallelism for attention and MoE MLP #1328

Open
wants to merge 40 commits into base: main

Changes from all commits (40 commits)
7b8f711
[exp] seq exp sharding
ZhiyuLi-goog Dec 29, 2024
802313a
update
ZhiyuLi-goog Dec 29, 2024
d8d4595
update
ZhiyuLi-goog Dec 29, 2024
b7f3225
merge for sp
suexu1025 Feb 27, 2025
7ed5fd1
fix merge parts
suexu1025 Feb 27, 2025
a1c6973
update merge confict base config
suexu1025 Feb 27, 2025
3f2d278
update to fix sharding mismatch
suexu1025 Feb 28, 2025
3e06ebb
update sub_seq for masks
suexu1025 Feb 27, 2025
d23d27b
update sharding axis
suexu1025 Feb 27, 2025
924ce77
update with reshape
suexu1025 Feb 28, 2025
b62812d
solve merge conflict
suexu1025 Mar 1, 2025
746f4a3
update for generate sharding
suexu1025 Feb 28, 2025
a6d345c
enable compute_axis configurable in mixtral model
suexu1025 Mar 4, 2025
e06c3d6
address output_logits sharding
suexu1025 Mar 5, 2025
65a64d4
clean up
suexu1025 Mar 5, 2025
23cd85f
Merge branch 'main' into qinwen/sharding_merge_main
suexu1025 Mar 5, 2025
10a9d82
update
suexu1025 Mar 5, 2025
0cca6df
update
suexu1025 Mar 6, 2025
cd005f3
Merge branch 'main' into qinwen/sharding_merge_main
suexu1025 Mar 6, 2025
ebae8e0
fix tests
suexu1025 Mar 6, 2025
2e0c459
added contition for non-sharded kernel for cp during inference only
suexu1025 Mar 6, 2025
37c843e
update
suexu1025 Mar 6, 2025
b63c63b
bug fix
suexu1025 Mar 7, 2025
9b32dc0
Merge branch 'main' into qinwen/sharding_merge_main
suexu1025 Mar 7, 2025
82d7fc3
Merge branch 'main' into qinwen/sharding_merge_main
suexu1025 Mar 7, 2025
4007e7c
fix tests
suexu1025 Mar 7, 2025
72f2a90
adddress comment
suexu1025 Mar 7, 2025
8da48f5
update
suexu1025 Mar 7, 2025
8a43dd5
address comments
suexu1025 Mar 7, 2025
56deeda
address comments
suexu1025 Mar 7, 2025
1c6be59
revert
suexu1025 Mar 7, 2025
bd0e199
address lint
suexu1025 Mar 7, 2025
44d646f
reformat for lint
suexu1025 Mar 7, 2025
5172068
update MOE test
suexu1025 Mar 7, 2025
d6787c3
add comment to explain grouping in generate_mask for moe model
suexu1025 Mar 7, 2025
f964acd
address the comments
suexu1025 Mar 8, 2025
930d77b
Merge branch 'main' into qinwen/sharding_merge_main
suexu1025 Mar 8, 2025
c5174de
update to fix tests
suexu1025 Mar 8, 2025
5c3fe75
Merge branch 'main' into qinwen/sharding_merge_main
suexu1025 Mar 8, 2025
b86e035
seperate yml for inference
suexu1025 Mar 8, 2025
3 changes: 3 additions & 0 deletions MaxText/common_types.py
@@ -36,13 +36,16 @@

BATCH = "activation_batch"
LENGTH = "activation_length"
KV_LENGTH = "activation_kv_length"
EMBED = "activation_embed"
HEAD = "activation_heads"
PREFILL_KV_BATCH = "activation_prefill_kv_batch"
KV_BATCH = "activation_kv_batch"
KV_HEAD = "activation_kv_heads"
KV_HEAD_DIM = "activation_kv_head_dim"
D_KV = "activation_kv"
DECODE_BATCH = "decode_batch"
DECODE_LENGTH = "decode_length"
CACHE_BATCH_PREFILL = "cache_batch_prefill"
CACHE_BATCH = "cache_batch"
CACHE_SEQUENCE = "cache_sequence"
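For context, the new names above (KV_LENGTH, DECODE_BATCH, DECODE_LENGTH, and friends) are logical axis labels; Flax resolves them to physical mesh axes through the logical_axis_rules in base.yml. A minimal sketch of that resolution follows; the rule set, including mapping 'activation_kv_length' to 'context', is an assumption for illustration rather than a quotation of the config:

```python
# Minimal sketch (not MaxText code): how logical axis names resolve to mesh axes.
from flax.linen import partitioning as nn_partitioning

# Assumed rules in the spirit of base.yml's logical_axis_rules.
rules = (
    ("activation_batch", ("data", "fsdp")),
    ("activation_kv_length", "context"),
    ("activation_heads", "tensor"),
    ("activation_kv", None),
)

# Resolve the per-dimension logical names of a [batch, kv_length, heads, head_dim] activation.
spec = nn_partitioning.logical_to_mesh_axes(
    ("activation_batch", "activation_kv_length", "activation_heads", "activation_kv"),
    rules,
)
print(spec)  # PartitionSpec(('data', 'fsdp'), 'context', 'tensor', None)
```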
12 changes: 8 additions & 4 deletions MaxText/configs/base.yml
@@ -257,7 +257,7 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -275,9 +275,11 @@ logical_axis_rules: [
['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose']],
['activation_vocab', 'tensor_sequence'],
['activation_vocab', 'sequence'],
['activation_vocab', ['sequence']],
['activation_stage', 'stage'],
['activation_exp', 'expert'],
['activation_exp', ['expert']],
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['decode_length', []],
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -308,7 +310,7 @@ logical_axis_rules: [
['paged_kv_head_dim_size', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]

# sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
sharding_tolerance: 0.02
@@ -321,6 +323,7 @@ dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1 # never recommended
dcn_tensor_transpose_parallelism: 1
dcn_tensor_sequence_parallelism: 1 # never recommended
@@ -331,6 +334,7 @@ ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
ici_tensor_transpose_parallelism: 1
ici_tensor_sequence_parallelism: 1
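For reference, a minimal sketch of how the per-axis parallelism knobs above translate into a device mesh that now includes a 'context' axis. The axis sizes and the 8-device assumption are illustrative only; this is not MaxText's actual mesh-construction code:

```python
from jax.sharding import Mesh
from jax.experimental import mesh_utils

# Illustrative ici_* parallelism values; their product must equal the device count (8 here).
ici_parallelism = {
    "data": 1, "stage": 1, "fsdp": 2, "fsdp_transpose": 1, "sequence": 1,
    "context": 4,   # ici_context_parallelism: splits the sequence / KV-length dimension
    "tensor": 1, "tensor_transpose": 1, "tensor_sequence": 1, "expert": 1, "autoregressive": 1,
}

devices = mesh_utils.create_device_mesh(list(ici_parallelism.values()))  # assumes 8 devices
mesh = Mesh(devices, axis_names=tuple(ici_parallelism))
print(mesh.shape)  # e.g. {'data': 1, ..., 'fsdp': 2, ..., 'context': 4, ...}
```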
626 changes: 626 additions & 0 deletions MaxText/configs/inference.yml

Large diffs are not rendered by default.

63 changes: 56 additions & 7 deletions MaxText/layers/attentions.py
@@ -20,6 +20,7 @@
from typing import Any, Optional, Tuple

from flax import linen as nn
from flax.linen import partitioning
import jax
from jax import lax
from jax.ad_checkpoint import checkpoint_name
@@ -68,7 +69,10 @@ class AttentionType(enum.Enum):
BATCH = common_types.BATCH
PREFILL_KV_BATCH = common_types.PREFILL_KV_BATCH
KV_BATCH = common_types.KV_BATCH
DECODE_BATCH = common_types.DECODE_BATCH
DECODE_LENGTH = common_types.DECODE_LENGTH
LENGTH = common_types.LENGTH
KV_LENGTH = common_types.KV_LENGTH
HEAD = common_types.HEAD
EMBED = common_types.EMBED
KV_HEAD = common_types.KV_HEAD
@@ -541,14 +545,26 @@ def compute_local_attention(
local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1))

local_out = self.wv_product(local_exps, value, model_mode)
if model_mode != common_types.MODEL_MODE_AUTOREGRESSIVE:
local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV))
else:
local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV))

if self.reshape_q and q_seq_len == 1:
local_max = local_max[:, 0:1, :, :]
local_sum = local_sum[:, 0:1, :, :]
local_out = local_out[:, 0:1, :, :]

if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
local_max = partitioning.with_sharding_constraint(local_max, (DECODE_BATCH, None, HEAD, D_KV))
local_sum = partitioning.with_sharding_constraint(local_sum, (DECODE_BATCH, None, HEAD, D_KV))
local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV))

return local_out, local_max, local_sum

def is_partition_in_decode(self, seq_len):
return self.config.ici_context_parallelism > 0 and seq_len == 1

def apply_attention_dot(
self,
query: Array,
@@ -568,7 +584,22 @@ def apply_attention_dot(
key = key.astype(jnp.float32)

q_seq_len = query.shape[1]

# special sharding for decode
if self.is_partition_in_decode(q_seq_len):
query = partitioning.with_sharding_constraint(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
key = partitioning.with_sharding_constraint(key, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
value = partitioning.with_sharding_constraint(value, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
else:
query = partitioning.with_sharding_constraint(query, (BATCH, LENGTH, HEAD, D_KV))
key = partitioning.with_sharding_constraint(key, (BATCH, KV_LENGTH, HEAD, D_KV))
value = partitioning.with_sharding_constraint(value, (BATCH, KV_LENGTH, HEAD, D_KV))

attn_weights = self.qk_product(query, key, q_seq_len, model_mode)
if self.is_partition_in_decode(q_seq_len):
attn_weights = partitioning.with_sharding_constraint(attn_weights, (KV_LENGTH, HEAD, None, None, None))
else:
attn_weights = partitioning.with_sharding_constraint(attn_weights, (BATCH, HEAD, None, LENGTH, KV_LENGTH))

if self.attn_logits_soft_cap:
attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap)
@@ -578,6 +609,10 @@
if self.float32_logits:
attn_weights = attn_weights.astype(jnp.float32)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode, previous_chunk)
if self.is_partition_in_decode(q_seq_len):
attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None))
else:
attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, LENGTH, KV_LENGTH))
if attn_mask is not None:
attn_weights = apply_mask_to_logits(attn_weights, attn_mask)
return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode)
@@ -1027,7 +1062,6 @@ def value_body(i, val):
cached_value_var.value = jax.lax.dynamic_update_index_in_dim(
cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis
)

cached_key_var.value = nn.with_logical_constraint(cached_key_var.value, ar_cache_axis_names)
cached_value_var.value = nn.with_logical_constraint(cached_value_var.value, ar_cache_axis_names)

@@ -1294,9 +1328,11 @@ class Attention(nn.Module):
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
input_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED)
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
decode_out_axis_names = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)

prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
@@ -1356,11 +1392,12 @@ def query_init(*args):
# pylint: disable=no-value-for-parameter
return self.kernel_init(*args) / depth_scaling

kernel_axes = (None, None, None) if self.config.ici_context_parallelism > 1 else ("embed", "q_heads", "kv")
query_proj = DenseGeneral(
features=(self.num_query_heads, self.head_dim),
axis=-1,
kernel_init=query_init,
kernel_axes=("embed", "q_heads", "kv"),
kernel_axes=kernel_axes,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="query",
@@ -1386,7 +1423,7 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
if self.num_query_heads % self.num_kv_heads != 0:
raise ValueError("Invalid num_kv_heads for GQA.")

kernel_axes = ("embed", "kv_heads", "kv_head_dim")
kernel_axes = (None, None, None) if self.config.ici_context_parallelism > 1 else ("embed", "kv_heads", "kv_head_dim")

kv_proj = DenseGeneral(
features=(self.num_kv_heads, self.head_dim),
@@ -1420,11 +1457,12 @@ def qkv_projection(self, inputs: Array, proj_name: str):
return query, key, value

def out_projection(self, output_dim: int, out: Array) -> Array:
out_kernel_axis = (None, None, None) if self.config.ici_context_parallelism > 1 else ("heads", "kv", "embed")
out_proj = DenseGeneral(
features=output_dim,
axis=(-2, -1),
kernel_init=self.kernel_init,
kernel_axes=("heads", "kv", "embed"),
kernel_axes=out_kernel_axis, # trade speed with memory
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="out",
@@ -1517,8 +1555,12 @@ def __call__(
Returns:
output of shape `[batch, length, q_features]`.
"""
inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)
if model_mode == common_types.MODEL_MODE_PREFILL:
inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)
else:
inputs_q = nn.with_logical_constraint(inputs_q, self.decode_input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.decode_input_axis_names)

# apply projection.
if self.config.fused_qkv:
@@ -1536,6 +1578,10 @@
query = nn.with_logical_constraint(query, self.prefill_query_axis_names)
key = nn.with_logical_constraint(key, self.prefill_key_axis_names)
value = nn.with_logical_constraint(value, self.prefill_value_axis_names)
elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
query = nn.with_logical_constraint(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
key = nn.with_logical_constraint(key, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
value = nn.with_logical_constraint(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
else:
query = nn.with_logical_constraint(query, self.query_axis_names)
key = nn.with_logical_constraint(key, self.key_axis_names)
@@ -1554,7 +1600,10 @@
else:
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)

out = nn.with_logical_constraint(out, self.out_axis_names)
if model_mode == common_types.MODEL_MODE_PREFILL:
out = nn.with_logical_constraint(out, self.out_axis_names)
else:
out = nn.with_logical_constraint(out, self.decode_out_axis_names)
out = self.out_projection(inputs_q.shape[-1], out)
out = checkpoint_name(out, "out_proj")
return out
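As a reference for the constraints added in compute_local_attention and apply_attention_dot above, here is a self-contained sketch of the pattern they rely on: partitioning.with_sharding_constraint only takes effect when the call runs under a device mesh and a set of logical axis rules, which MaxText installs around the train and decode steps. The two-axis mesh, the 8-device assumption, and the rules below are illustrative, not MaxText's configuration:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from flax.linen import partitioning

# A 2x4 mesh (assumes 8 devices) that includes the new 'context' axis, plus two assumed rules.
mesh = Mesh(mesh_utils.create_device_mesh((2, 4)), axis_names=("data", "context"))
rules = (("activation_batch", "data"), ("activation_kv_length", "context"))

@jax.jit
def constrain(x):
  # Shard the batch dimension over 'data' and the KV sequence dimension over 'context'.
  return partitioning.with_sharding_constraint(x, ("activation_batch", "activation_kv_length"))

with mesh, partitioning.axis_rules(rules):
  y = constrain(jnp.zeros((8, 1024)))
  print(y.sharding)  # NamedSharding spreading dim 0 over 'data' and dim 1 over 'context'
```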