From 7b8f71132970f69848cf66b9dd97d3a955fc6f95 Mon Sep 17 00:00:00 2001 From: ZhiyuLi-goog Date: Sun, 29 Dec 2024 02:16:30 +0000 Subject: [PATCH 01/34] [exp] seq exp sharding --- MaxText/configs/base.yml | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 6fbed741a..58fb3be53 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -273,21 +273,16 @@ logical_axis_rules: [ ['activation_vocab', 'tensor_sequence'], ['activation_vocab', 'sequence'], ['activation_stage', 'stage'], - ['activation_exp', 'expert'], - ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], - ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']], - ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], - ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']], - ['embed_no_exp', ['fsdp', 'sequence']], - ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_exp', 'context'], + ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], + ['vocab', ['tensor', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], + ['embed', ['fsdp', 'sequence', 'context', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_no_exp', ['fsdp', 'sequence', 'context']], + ['norm', 'tensor'], + ['q_heads', ['tensor', 'autoregressive']], + ['heads', ['tensor', 'autoregressive']], ['layers', 'stage'], ['kv', []], ['kv_head_dim', []], @@ -297,7 +292,7 @@ logical_axis_rules: [ ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], ['cache_sequence', []], - ['exp', 'expert'], + ['exp', 'context'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] From 802313aa9c597ec1e23587bbe7f7ec84a5627782 Mon Sep 17 00:00:00 2001 From: ZhiyuLi-goog Date: Sun, 29 Dec 2024 04:46:28 +0000 Subject: [PATCH 02/34] update --- MaxText/layers/attentions.py | 8 +++++--- MaxText/layers/linears.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index b89c65998..1fb9a7066 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -1170,7 +1170,7 @@ def query_init(*args): features=(self.num_query_heads, self.head_dim), axis=-1, kernel_init=query_init, - kernel_axes=("embed", "q_heads", "kv"), + kernel_axes=(None, None, None), # attn_q_weight_ndh=(None, zero_axes, None), dtype=self.dtype, weight_dtype=self.weight_dtype, name="query", @@ -1196,7 +1196,9 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: if self.num_query_heads % self.num_kv_heads != 0: raise ValueError("Invalid num_kv_heads for GQA.") - 
kernel_axes = ("embed", "kv_heads", "kv_head_dim") + # attn_k_weight_kdh=(None, zero_axes, None), + # kernel_axes = ("embed", "kv_heads", "kv_head_dim") + kernel_axes = (None, None, None) kv_proj = DenseGeneral( features=(self.num_kv_heads, self.head_dim), @@ -1234,7 +1236,7 @@ def out_projection(self, output_dim: int, out: Array) -> Array: features=output_dim, axis=(-2, -1), kernel_init=self.kernel_init, - kernel_axes=("heads", "kv", "embed"), + kernel_axes=(None, None, None), # trade speed with memory dtype=self.dtype, weight_dtype=self.weight_dtype, name="out", diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 27b9fb032..c3fc11751 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -623,6 +623,7 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel, kernel_ax def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): # gate_logits: batch, length, expert + # follow router_logits = shd.shard(router_logits, (None, None, None)) gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None)) # shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok) top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) From d8d45958a6f2cabf336b1d8de0a69114f7a971e7 Mon Sep 17 00:00:00 2001 From: ZhiyuLi-goog Date: Sun, 29 Dec 2024 05:06:47 +0000 Subject: [PATCH 03/34] update --- MaxText/layers/linears.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index c3fc11751..512ebab61 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -640,8 +640,8 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): # token dropping if needed dispatch_mask, combine_mask = self.generate_masks(top_k_indices, weights) mask_axes = ("activation_batch", "activation_length", None, None) - dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) - combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) + # dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) + # combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) if self.config.model_call_mode != "inference": loss = self.load_balance_loss(top_k_indices, weights) else: From b7f3225cbda05a6e3caed5324f6b7d8ddbf1e390 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 27 Feb 2025 02:34:13 +0000 Subject: [PATCH 04/34] merge for sp --- MaxText/configs/base.yml | 5 +- MaxText/layers/attentions.py | 15 ++++- MaxText/layers/linears.py | 123 +++++++++++++++++++++++++++-------- 3 files changed, 112 insertions(+), 31 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 58fb3be53..f7ace1462 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -263,8 +263,11 @@ logical_axis_rules: [ ['activation_length', ['sequence']], ['activation_norm_length', ['tensor_sequence', 'sequence']], ['activation_embed', ['tensor', 'tensor_transpose']], - ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_mlp', ['context''tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_length', ['context']], + ['activation_length_q', ['context']], + ['activation_kv_length', ['context']] ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 
'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 1fb9a7066..edca78fdc 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -20,6 +20,7 @@ from typing import Any, Optional, Tuple from flax import linen as nn +from flax.linen import partitioning import jax from jax import lax from jax.ad_checkpoint import checkpoint_name @@ -462,6 +463,8 @@ def compute_local_attention( local_out = self.wv_product(local_exps, value, model_mode) + local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) + if self.reshape_q and q_seq_len == 1: local_max = local_max[:, 0:1, :, :] local_sum = local_sum[:, 0:1, :, :] @@ -487,16 +490,22 @@ def apply_attention_dot( key = key.astype(jnp.float32) q_seq_len = query.shape[1] + + query = partitioning.with_sharding_constraint(query, (BATCH, LENGTH, HEAD, D_KV)) + key = partitioning.with_sharding_constraint(key, (BATCH, KV_LENGTH, HEAD, D_KV)) + value = partitioning.with_sharding_constraint(value, (BATCH, KV_LENGTH, HEAD, D_KV)) + attn_weights = self.qk_product(query, key, q_seq_len, model_mode) + attn_weights = partitioning.with_sharding_constraint(attn_weights, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) + if self.attn_logits_soft_cap: attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap) attn_weights = attn_weights * self.attn_logits_soft_cap - - # Casting softmaxt computation for float32 for model stability. - if self.float32_logits: + if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits: attn_weights = attn_weights.astype(jnp.float32) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) if attn_mask is not None: attn_weights = apply_mask_to_logits(attn_weights, attn_mask) return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 512ebab61..09d8b6127 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Iterable, Sequence, Tuple, Union, Optional import flax +from flax.linen import partitioning import flax.linen as nn import jax from jax import lax @@ -529,6 +530,9 @@ def generate_masks(self, top_k_indices, softmax_probs): self.config.capacity_factor, ) ) + cp = self.config.ici_context_parallelism + if seq_len % cp == 0: + expert_capacity_per_batch = max(math.ceil(expert_capacity_per_batch / cp), self.config.capacity_factor) max_logging.log(f"Applying potential token dropping with a batch expert_capacity of {expert_capacity_per_batch}") # calculate expert mask and drop tokens if needed @@ -576,6 +580,12 @@ def generate_masks(self, top_k_indices, softmax_probs): combine_mask = softmax_probs[..., None] * expert_token_position_in_capacity combine_mask = combine_mask[..., 1:] dispatch_mask = combine_mask.astype(bool) + + #ici_context_parallelism + if seq_len % cp == 0: + dispatch_mask = jnp.reshape(dispatch_mask, (batch_size, cp, seq_len//cp, self.num_experts, expert_capacity_per_batch)) + combine_mask = jnp.reshape(combine_mask, (batch_size, cp, seq_len//cp, self.num_experts, expert_capacity_per_batch)) + return dispatch_mask, combine_mask # See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. 
@@ -636,49 +646,91 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) matmul_precision = lax.Precision(self.config.matmul_precision) + cp = self.config.ici_context_parallelism + batch_size = inputs.shape[0] + seq_len = inputs.shape[1] + + do_cp = seq_len % cp == 0 if self.config.capacity_factor > 0: # token dropping if needed dispatch_mask, combine_mask = self.generate_masks(top_k_indices, weights) - mask_axes = ("activation_batch", "activation_length", None, None) - # dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) - # combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) + mask_axes = ("activation_batch", "activation_length", None, None, None) + if do_cp: + mask_axes = ("activation_batch", "activation_length", None, None, None) + input_axis = ("activation_batch", "activation_length", None, "activation_embed") + dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") + mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") + else: + mask_axes = ("activation_batch", "activation_length", None, None) + input_axis = ("activation_batch", "activation_length", "activation_embed") + dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, "activation_embed") + mlp_axis = ("activation_exp", "activation_batch_no_exp", None, "activation_mlp") + + dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) + combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) if self.config.model_call_mode != "inference": loss = self.load_balance_loss(top_k_indices, weights) else: loss = None - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + + if do_cp: + inputs = jnp.reshape(inputs,(batch_size, cp, seq_len//cp, inputs.shape[2])) + + inputs = nn.with_logical_constraint(inputs, input_axis) + with jax.named_scope("dispatch"): - dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( - "BSM,BSEC -> EBCM", inputs, dispatch_mask, precision=matmul_precision - ) - dispatch = nn.with_logical_constraint( + if do_cp: + dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( + "BNSM,BNSEC -> EBNCM", inputs, dispatch_mask, precision=matmul_precision + ) + dispatch = nn.with_logical_constraint( dispatch, - ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), - ) + (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), + ) + else: + dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( + "BSM,BSEC -> EBCM", inputs, dispatch_mask, precision=matmul_precision + ) + dispatch = nn.with_logical_constraint( + dispatch, + dispatch_axis, + ) + #print("dispatch", dispatch.shape) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, "mlp") w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes) - layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( - "EBCM,EMH -> EBCH", dispatch, w0_kernel, precision=matmul_precision - ) + if do_cp: + layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( + "EBNCM,EMH -> EBNCH", dispatch, w0_kernel, precision=matmul_precision + ) + else: + layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( + "EBCM,EMH -> EBCH", dispatch, w0_kernel, precision=matmul_precision + ) + if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) 
layer_w0 = nn.with_logical_constraint( layer_w0, - ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"), + mlp_axis, ) layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, "mlp") w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes) - layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( - "EBCM,EMH -> EBCH", dispatch, w1_kernel, precision=matmul_precision - ) + if do_cp: + layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( + "EBNCM,EMH -> EBNCH", dispatch, w1_kernel, precision=matmul_precision + ) + else: + layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( + "EBCM,EMH -> EBCH", dispatch, w1_kernel, precision=matmul_precision + ) if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) layer_w1 = nn.with_logical_constraint( layer_w1, - ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"), + mlp_axis, ) layer_w1 = checkpoint_name(layer_w1, "mlpwi_1") layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) @@ -686,9 +738,14 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): with jax.named_scope("wo"): wo_kernel_axes = ("exp", "mlp", None) wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes) - intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( - "EBCH,EHM -> EBCM", layer_multiply, wo_kernel, precision=matmul_precision - ) + if do_cp: + intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( + "EBNCH,EHM -> EBNCM", layer_multiply, wo_kernel, precision=matmul_precision + ) + else: + intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( + "EBCH,EHM -> EBCM", layer_multiply, wo_kernel, precision=matmul_precision + ) intermediate_layer = nn.with_logical_constraint( intermediate_layer, ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), @@ -698,14 +755,26 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): intermediate_layer = checkpoint_name(intermediate_layer, "mlpwo") with jax.named_scope("combine"): # Matmul & element wise operation - output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( - "EBCM,BSEC -> BSM", - intermediate_layer, - combine_mask, - precision=matmul_precision, - ).astype(self.dtype) + if do_cp: + output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( + "EBNCM,BNSEC -> BNSM", + intermediate_layer, + combine_mask, + precision=matmul_precision, + ) + output = jnp.reshape(output, (output.shape[0], output.shape[1]*output.shape[2], output.shape[3])) + else: + output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( + "EBCM,BSEC -> BSM", + intermediate_layer, + combine_mask, + precision=matmul_precision, + ) .astype(self.dtype) return output, loss else: + top_k_weights /= top_k_weights.sum(-1, keepdims=True) + weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) with jax.named_scope("wi_0"): layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( From 7ed5fd1da0fef0eeb045c07fef4d3e344c5ea3cd Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 27 Feb 2025 06:15:43 +0000 Subject: [PATCH 05/34] fix merge parts --- MaxText/common_types.py | 1 + MaxText/layers/attentions.py | 1 + MaxText/max_utils.py | 1 - MaxText/maxtext_utils.py | 3 ++- 
MaxText/pyconfig.py | 3 +++ 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/MaxText/common_types.py b/MaxText/common_types.py index c96bcaeef..bdca4f380 100644 --- a/MaxText/common_types.py +++ b/MaxText/common_types.py @@ -36,6 +36,7 @@ BATCH = "activation_batch" LENGTH = "activation_length" +KV_LENGTH = "activation_kv_length" EMBED = "activation_embed" HEAD = "activation_heads" PREFILL_KV_BATCH = "activation_prefill_kv_batch" diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index edca78fdc..6ccb753a6 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -68,6 +68,7 @@ class AttentionType(enum.Enum): PREFILL_KV_BATCH = common_types.PREFILL_KV_BATCH KV_BATCH = common_types.KV_BATCH LENGTH = common_types.LENGTH +KV_LENGTH = common_types.KV_LENGTH HEAD = common_types.HEAD EMBED = common_types.EMBED KV_HEAD = common_types.KV_HEAD diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index ae0e49563..a8e29b589 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -640,7 +640,6 @@ def create_device_mesh(config, devices=None): num_devices_per_slice = num_devices // num_slices multi_slice_env = num_slices > 1 - # Find possible unspecified parallelisms ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index f547c5b12..7c355bc22 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -239,8 +239,9 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance): """ total_num_params = max_utils.calculate_num_params_from_pytree(params) product_num_devices_for_weight_sharding = 1 - for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_transpose", "tensor_sequence", "stage", "expert"]: + for axis in ["fsdp", "fsdp_transpose", "sequence", "context", "tensor", "tensor_transpose", "tensor_sequence", "stage", "expert"]: product_num_devices_for_weight_sharding *= mesh.shape[axis] + print(product_num_devices_for_weight_sharding) total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, ( diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 5c0d5c007..e2b50f8cd 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -210,6 +210,7 @@ def validate_model_name(s: str) -> bool: "llama3.1-70b", "llama3.1-405b", "llama3.3-70b", + "subsup", "mistral-7b", "mixtral-8x7b", "mixtral-8x22b", @@ -532,6 +533,7 @@ def create_parallelisms_list(raw_keys): raw_keys["ici_fsdp_parallelism"], raw_keys["ici_fsdp_transpose_parallelism"], raw_keys["ici_sequence_parallelism"], + raw_keys["ici_context_parallelism"], raw_keys["ici_tensor_parallelism"], raw_keys["ici_tensor_transpose_parallelism"], raw_keys["ici_tensor_sequence_parallelism"], @@ -544,6 +546,7 @@ def create_parallelisms_list(raw_keys): raw_keys["dcn_fsdp_parallelism"], raw_keys["dcn_fsdp_transpose_parallelism"], raw_keys["dcn_sequence_parallelism"], + raw_keys["dcn_context_parallelism"], raw_keys["dcn_tensor_parallelism"], raw_keys["dcn_tensor_transpose_parallelism"], raw_keys["dcn_tensor_sequence_parallelism"], From a1c697350dd1eb662d053db659620e68d0890885 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 27 Feb 2025 17:56:01 +0000 Subject: [PATCH 06/34] update merge confict base config --- MaxText/configs/base.yml | 39 
+++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index f7ace1462..4ab7f2399 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -253,21 +253,21 @@ jax_cache_dir: "~/jax_cache" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # Parallelism -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], - ['activation_length', ['sequence']], + ['activation_length', ['sequence', 'context']], + ['activation_length', ['context']], + ['activation_length_q', ['context']], + ['activation_kv_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'sequence']], ['activation_embed', ['tensor', 'tensor_transpose']], - ['activation_mlp', ['context''tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_length', ['context']], - ['activation_length_q', ['context']], - ['activation_kv_length', ['context']] ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], @@ -276,16 +276,21 @@ logical_axis_rules: [ ['activation_vocab', 'tensor_sequence'], ['activation_vocab', 'sequence'], ['activation_stage', 'stage'], - ['activation_exp', 'context'], - ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], - ['vocab', ['tensor', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['embed', ['fsdp', 'sequence', 'context', 'expert']], + ['activation_exp', ['expert', 'context']], + ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']], + ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], + ['embed', ['fsdp', 'sequence', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose']], + ['embed_no_exp', ['fsdp', 'sequence', 'context', 'tensor_transpose']], ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], ['embed_no_exp', ['fsdp', 'sequence', 'context']], - ['norm', 'tensor'], - ['q_heads', ['tensor', 'autoregressive']], - ['heads', ['tensor', 'autoregressive']], + ['norm', 
['tensor', 'tensor_transpose', 'tensor_sequence']], ['layers', 'stage'], ['kv', []], ['kv_head_dim', []], @@ -295,10 +300,10 @@ logical_axis_rules: [ ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], ['cache_sequence', []], - ['exp', 'context'], + ['exp', ['expert', 'context']], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'conext', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. sharding_tolerance: 0.02 @@ -311,6 +316,7 @@ dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 dcn_fsdp_transpose_parallelism: 1 dcn_sequence_parallelism: 1 # never recommended +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 # never recommended dcn_tensor_transpose_parallelism: 1 dcn_tensor_sequence_parallelism: 1 # never recommended @@ -321,6 +327,7 @@ ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 +ici_context_parallelism: 1 ici_tensor_parallelism: 1 ici_tensor_transpose_parallelism: 1 ici_tensor_sequence_parallelism: 1 From 3f2d2788e40ea55b9a0148f4743a021676b9b07f Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 28 Feb 2025 22:16:05 +0000 Subject: [PATCH 07/34] update to fix sharding mismatch --- MaxText/configs/base.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 4ab7f2399..a678fc449 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -299,7 +299,7 @@ logical_axis_rules: [ ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], - ['cache_sequence', []], + ['cache_sequence', ['context']], ['exp', ['expert', 'context']], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details From 3e06ebb64c64eaaf3bf12dc74ca915d1e5949f70 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 27 Feb 2025 22:36:30 +0000 Subject: [PATCH 08/34] update sub_seq for masks --- MaxText/layers/linears.py | 131 +++++++++++++++----------------------- 1 file changed, 50 insertions(+), 81 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 09d8b6127..d84f194c1 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -522,7 +522,14 @@ def reshape_and_update_weights(self, weights, indices): def generate_masks(self, top_k_indices, softmax_probs): # calculate expert_capacity = (tokens_per_batch / num_experts) * capacity_factor batch_size, seq_len, _ = top_k_indices.shape - tokens_per_batch = seq_len * self.num_experts_per_tok + cp = self.config.ici_context_parallelism + if seq_len % cp != 0: + cp = 1 + sub_seq = seq_len // cp + + top_k_indices = jnp.reshape(top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2])) + + tokens_per_batch = sub_seq * self.num_experts_per_tok # this is to avoid expert_capacity_per_batch = 0 expert_capacity_per_batch = int( max( @@ -530,9 +537,6 @@ def generate_masks(self, top_k_indices, softmax_probs): 
self.config.capacity_factor, ) ) - cp = self.config.ici_context_parallelism - if seq_len % cp == 0: - expert_capacity_per_batch = max(math.ceil(expert_capacity_per_batch / cp), self.config.capacity_factor) max_logging.log(f"Applying potential token dropping with a batch expert_capacity of {expert_capacity_per_batch}") # calculate expert mask and drop tokens if needed @@ -546,29 +550,33 @@ def generate_masks(self, top_k_indices, softmax_probs): # trunc_expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of updated_expert_mask is [[[1, 1],[0, 1]]]. expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) - expert_mask_fused = jnp.reshape(expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts)) - expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None)) - expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) + expert_mask_fused = expert_mask + expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None, None)) + expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = jnp.reshape( expert_token_count_fused, - ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)), + ((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)), ) expert_token_count = nn.with_logical_constraint( - expert_token_count, ("activation_batch", "activation_length", None, None) + expert_token_count, ("activation_batch", None, "activation_length", None, None) ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) - combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) + combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3) # reshape & update weights + softmax_probs = jnp.reshape( + softmax_probs, + ((batch_size, cp, sub_seq, self.num_experts)), + ) softmax_probs *= combined_expert_mask # calculate token position in expert capacity dimension expert_token_position_fused = expert_mask_fused * expert_token_count_fused expert_token_position = jnp.reshape( expert_token_position_fused, - (batch_size, seq_len, self.num_experts_per_tok, self.num_experts), + (batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts), ) - combined_expert_token_position = jnp.sum(expert_token_position, axis=2) * combined_expert_mask + combined_expert_token_position = jnp.sum(expert_token_position, axis=3) * combined_expert_mask expert_token_position_in_capacity = jax.nn.one_hot( combined_expert_token_position, num_classes=expert_capacity_per_batch + 1, @@ -582,9 +590,8 @@ def generate_masks(self, top_k_indices, softmax_probs): dispatch_mask = combine_mask.astype(bool) #ici_context_parallelism - if seq_len % cp == 0: - dispatch_mask = jnp.reshape(dispatch_mask, (batch_size, cp, seq_len//cp, self.num_experts, expert_capacity_per_batch)) - combine_mask = jnp.reshape(combine_mask, (batch_size, cp, seq_len//cp, self.num_experts, expert_capacity_per_batch)) + dispatch_mask = jnp.reshape(dispatch_mask, (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch)) + combine_mask = jnp.reshape(combine_mask, (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch)) return dispatch_mask, combine_mask @@ -649,22 +656,16 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): cp = self.config.ici_context_parallelism batch_size = inputs.shape[0] seq_len = 
inputs.shape[1] - - do_cp = seq_len % cp == 0 + if seq_len % cp != 0: + cp = 1 + sub_seq = seq_len // cp if self.config.capacity_factor > 0: # token dropping if needed - dispatch_mask, combine_mask = self.generate_masks(top_k_indices, weights) + dispatch_mask, combine_mask = self.generate_masks(top_k_indices, softmax_probs) mask_axes = ("activation_batch", "activation_length", None, None, None) - if do_cp: - mask_axes = ("activation_batch", "activation_length", None, None, None) - input_axis = ("activation_batch", "activation_length", None, "activation_embed") - dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") - mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") - else: - mask_axes = ("activation_batch", "activation_length", None, None) - input_axis = ("activation_batch", "activation_length", "activation_embed") - dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, "activation_embed") - mlp_axis = ("activation_exp", "activation_batch_no_exp", None, "activation_mlp") + input_axis = ("activation_batch", "activation_length", None, "activation_embed") + dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") + mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) @@ -672,41 +673,28 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): loss = self.load_balance_loss(top_k_indices, weights) else: loss = None - - if do_cp: - inputs = jnp.reshape(inputs,(batch_size, cp, seq_len//cp, inputs.shape[2])) + inputs = jnp.reshape(inputs,(batch_size, cp, sub_seq, inputs.shape[2])) inputs = nn.with_logical_constraint(inputs, input_axis) with jax.named_scope("dispatch"): - if do_cp: - dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( + dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( "BNSM,BNSEC -> EBNCM", inputs, dispatch_mask, precision=matmul_precision ) - dispatch = nn.with_logical_constraint( - dispatch, - (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), - ) - else: - dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( - "BSM,BSEC -> EBCM", inputs, dispatch_mask, precision=matmul_precision - ) + dispatch = nn.with_logical_constraint( + dispatch, + (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), + ) dispatch = nn.with_logical_constraint( dispatch, dispatch_axis, ) - #print("dispatch", dispatch.shape) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, "mlp") w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes) - if do_cp: - layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( + layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( "EBNCM,EMH -> EBNCH", dispatch, w0_kernel, precision=matmul_precision ) - else: - layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( - "EBCM,EMH -> EBCH", dispatch, w0_kernel, precision=matmul_precision - ) if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) @@ -718,14 +706,9 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, "mlp") w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes) - 
if do_cp: - layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( + layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( "EBNCM,EMH -> EBNCH", dispatch, w1_kernel, precision=matmul_precision ) - else: - layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( - "EBCM,EMH -> EBCH", dispatch, w1_kernel, precision=matmul_precision - ) if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) layer_w1 = nn.with_logical_constraint( @@ -738,38 +721,24 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): with jax.named_scope("wo"): wo_kernel_axes = ("exp", "mlp", None) wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes) - if do_cp: - intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( + intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( "EBNCH,EHM -> EBNCM", layer_multiply, wo_kernel, precision=matmul_precision ) - else: - intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( - "EBCH,EHM -> EBCM", layer_multiply, wo_kernel, precision=matmul_precision - ) - intermediate_layer = nn.with_logical_constraint( - intermediate_layer, - ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), - ) - if self.config.activations_in_float32: - intermediate_layer = intermediate_layer.astype(jnp.float32) + + # intermediate_layer = nn.with_logical_constraint( + # intermediate_layer, + # ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), + # ) intermediate_layer = checkpoint_name(intermediate_layer, "mlpwo") with jax.named_scope("combine"): # Matmul & element wise operation - if do_cp: - output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( - "EBNCM,BNSEC -> BNSM", - intermediate_layer, - combine_mask, - precision=matmul_precision, - ) - output = jnp.reshape(output, (output.shape[0], output.shape[1]*output.shape[2], output.shape[3])) - else: - output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( - "EBCM,BSEC -> BSM", - intermediate_layer, - combine_mask, - precision=matmul_precision, - ) .astype(self.dtype) + output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( + "EBNCM,BNSEC -> BNSM", + intermediate_layer, + combine_mask, + precision=matmul_precision, + ) + output = jnp.reshape(output, (output.shape[0], output.shape[1]*output.shape[2], output.shape[3])) return output, loss else: top_k_weights /= top_k_weights.sum(-1, keepdims=True) From d23d27b7d46bf5d7c070b3bc560f3b2f36f52634 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 27 Feb 2025 23:05:06 +0000 Subject: [PATCH 09/34] update sharding axis --- MaxText/layers/linears.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index d84f194c1..56fe26f3d 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -558,7 +558,7 @@ def generate_masks(self, top_k_indices, softmax_probs): ((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)), ) expert_token_count = nn.with_logical_constraint( - expert_token_count, ("activation_batch", None, "activation_length", None, None) + expert_token_count, ("activation_batch", "activation_length", None, None, None) ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3) From 924ce77fc8def8661076c86d79580cc58f092d2b Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 28 Feb 2025 00:42:43 +0000 
Subject: [PATCH 10/34] update with reshape --- MaxText/layers/linears.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 56fe26f3d..062d0f047 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -550,7 +550,7 @@ def generate_masks(self, top_k_indices, softmax_probs): # trunc_expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of updated_expert_mask is [[[1, 1],[0, 1]]]. expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) - expert_mask_fused = expert_mask + expert_mask_fused = jnp.reshape(expert_mask, (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts)) expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = jnp.reshape( From b62812dc860436234c635aa682458ab79273acd5 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Sat, 1 Mar 2025 00:11:11 +0000 Subject: [PATCH 11/34] solve merge conflict --- MaxText/layers/linears.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 062d0f047..fb43384ab 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -642,6 +642,7 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): # gate_logits: batch, length, expert # follow router_logits = shd.shard(router_logits, (None, None, None)) gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None)) + softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) # shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok) top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) From 746f4a3a36d903e2d92384586a6074c1772c5c7c Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 28 Feb 2025 01:44:38 +0000 Subject: [PATCH 12/34] update for generate sharding --- MaxText/layers/attentions.py | 24 +++++++++++++++++------- MaxText/layers/linears.py | 15 ++++++++++----- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 6ccb753a6..6cbd4c4b3 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -491,14 +491,21 @@ def apply_attention_dot( key = key.astype(jnp.float32) q_seq_len = query.shape[1] - - query = partitioning.with_sharding_constraint(query, (BATCH, LENGTH, HEAD, D_KV)) - key = partitioning.with_sharding_constraint(key, (BATCH, KV_LENGTH, HEAD, D_KV)) - value = partitioning.with_sharding_constraint(value, (BATCH, KV_LENGTH, HEAD, D_KV)) + # special sharding for decode + if self.config.ici_context_parallelism > 0 and q_seq_len == 1: + query = partitioning.with_sharding_constraint(query, (KV_LENGTH, None, HEAD, D_KV)) + key = partitioning.with_sharding_constraint(key, (KV_LENGTH, None, HEAD, D_KV)) + value = partitioning.with_sharding_constraint(value, (KV_LENGTH, None, HEAD, D_KV)) + else: + query = partitioning.with_sharding_constraint(query, (BATCH, LENGTH, HEAD, D_KV)) + key = partitioning.with_sharding_constraint(key, (BATCH, KV_LENGTH, HEAD, D_KV)) + value = partitioning.with_sharding_constraint(value, (BATCH, KV_LENGTH, HEAD, D_KV)) attn_weights = self.qk_product(query, key, q_seq_len, model_mode) - - 
attn_weights = partitioning.with_sharding_constraint(attn_weights, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) + if self.config.ici_context_parallelism > 0 and q_seq_len == 1: + attn_weights = partitioning.with_sharding_constraint(attn_weights, (KV_LENGTH, HEAD, None, None, None)) + else: + attn_weights = partitioning.with_sharding_constraint(attn_weights, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) if self.attn_logits_soft_cap: attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap) @@ -506,7 +513,10 @@ def apply_attention_dot( if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits: attn_weights = attn_weights.astype(jnp.float32) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) + if self.config.ici_context_parallelism > 0 and q_seq_len == 1: + attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None)) + else: + attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) if attn_mask is not None: attn_weights = apply_mask_to_logits(attn_weights, attn_mask) return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index fb43384ab..0a9ab16c7 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -663,11 +663,16 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): if self.config.capacity_factor > 0: # token dropping if needed dispatch_mask, combine_mask = self.generate_masks(top_k_indices, softmax_probs) - mask_axes = ("activation_batch", "activation_length", None, None, None) - input_axis = ("activation_batch", "activation_length", None, "activation_embed") - dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") - mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") - + if self.config.ici_context_parallelism > 0 and cp == 1: + mask_axes = ( "activation_length", "activation_batch", None, None, None) + input_axis = ("activation_length", "activation_batch", None, "activation_embed") + dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") + mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") + else: + mask_axes = ("activation_batch", "activation_length", None, None, None) + input_axis = ("activation_batch", "activation_length", None, "activation_embed") + dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") + mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) if self.config.model_call_mode != "inference": From a6d345c72dcc3f4bb4cc0a54825b822b4bcc4706 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Tue, 4 Mar 2025 03:08:49 +0000 Subject: [PATCH 13/34] enable compute_axis configurable in mixtral model --- MaxText/layers/mistral.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index c3fe2b5bf..aba706777 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -101,6 +101,9 @@ def __call__( float32_logits=cfg.float32_logits, quant=self.quant, kv_quant=quantizations.configure_kv_quant(cfg), + 
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]), + ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]), + compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]), ) attention_lnx = attention_layer( From e06c3d67cfbca403d943b609b47b77fef546a9fb Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Wed, 5 Mar 2025 20:02:54 +0000 Subject: [PATCH 14/34] address output_logits sharding --- MaxText/common_types.py | 2 ++ MaxText/configs/base.yml | 10 +++++--- MaxText/layers/attentions.py | 50 ++++++++++++++++++++++++++++++------ MaxText/layers/linears.py | 10 +++++--- MaxText/layers/models.py | 13 +++++++--- 5 files changed, 66 insertions(+), 19 deletions(-) diff --git a/MaxText/common_types.py b/MaxText/common_types.py index bdca4f380..dfbacb9e6 100644 --- a/MaxText/common_types.py +++ b/MaxText/common_types.py @@ -44,6 +44,8 @@ KV_HEAD = "activation_kv_heads" KV_HEAD_DIM = "activation_kv_head_dim" D_KV = "activation_kv" +DECODE_BATCH = "decode_batch" +DECODE_LENGTH = "decode_length" CACHE_BATCH_PREFILL = "cache_batch_prefill" CACHE_BATCH = "cache_batch" CACHE_SEQUENCE = "cache_sequence" diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index a678fc449..f857971d2 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -269,16 +269,18 @@ logical_axis_rules: [ ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], - ['activation_vocab', 'sequence'], + ['activation_vocab', ['sequence', 'context']], ['activation_stage', 'stage'], ['activation_exp', ['expert', 'context']], + ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], + ['decode_length', []], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], - ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -295,7 +297,7 @@ logical_axis_rules: [ ['kv', []], ['kv_head_dim', []], ['cache_batch_prefill', []], - ['cache_batch', []], + ['cache_batch', ['context']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 6cbd4c4b3..94687f15f 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -67,6 +67,8 @@ class AttentionType(enum.Enum): BATCH = common_types.BATCH PREFILL_KV_BATCH = common_types.PREFILL_KV_BATCH KV_BATCH = common_types.KV_BATCH +DECODE_BATCH = common_types.DECODE_BATCH +DECODE_LENGTH = common_types.DECODE_LENGTH LENGTH = 
common_types.LENGTH KV_LENGTH = common_types.KV_LENGTH HEAD = common_types.HEAD @@ -463,14 +465,29 @@ def compute_local_attention( local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1)) local_out = self.wv_product(local_exps, value, model_mode) + if model_mode == common_types.MODEL_MODE_PREFILL: + print("prefill") + print(local_out.shape) + local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) + #local_out = partitioning.with_sharding_constraint(local_out, DECODE_BATCH, HEAD, D_KV) + else: + print("decode") + print(local_out.shape) + local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) - local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) if self.reshape_q and q_seq_len == 1: local_max = local_max[:, 0:1, :, :] local_sum = local_sum[:, 0:1, :, :] local_out = local_out[:, 0:1, :, :] + if model_mode != common_types.MODEL_MODE_PREFILL: + local_max = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) + local_sum = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) + local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) + + print(local_sum.shape) + print(local_max.shape) return local_out, local_max, local_sum def apply_attention_dot( @@ -493,9 +510,9 @@ def apply_attention_dot( q_seq_len = query.shape[1] # special sharding for decode if self.config.ici_context_parallelism > 0 and q_seq_len == 1: - query = partitioning.with_sharding_constraint(query, (KV_LENGTH, None, HEAD, D_KV)) - key = partitioning.with_sharding_constraint(key, (KV_LENGTH, None, HEAD, D_KV)) - value = partitioning.with_sharding_constraint(value, (KV_LENGTH, None, HEAD, D_KV)) + query = partitioning.with_sharding_constraint(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) + key = partitioning.with_sharding_constraint(key, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) + value = partitioning.with_sharding_constraint(value, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) else: query = partitioning.with_sharding_constraint(query, (BATCH, LENGTH, HEAD, D_KV)) key = partitioning.with_sharding_constraint(key, (BATCH, KV_LENGTH, HEAD, D_KV)) @@ -546,6 +563,9 @@ def qk_product(self, query: Array, key: Array | KVTensor, q_seq_len: int, model_ b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads + print(query.shape) + print(key.shape) + print(self.compute_axis_order) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3): query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) if self.reshape_q and q_seq_len == 1: @@ -846,6 +866,7 @@ def update_ar_key_value( one_token_value_shaped_for_cache = jnp.transpose(one_token_value, self.ar_cache_axis_order) ar_cache_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.ar_cache_axis_order) + print(ar_cache_axis_names) if self.kv_quant: one_token_key_shaped_for_cache, one_token_key_scale_shaped_for_cache = self.kv_quant.quantize( one_token_key_shaped_for_cache, ar_cache_axis_names @@ -890,7 +911,7 @@ def value_body(i, val): cached_value_var.value = jax.lax.dynamic_update_index_in_dim( cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis ) - + print("cv", cached_value_var.value.shape) cached_key_var.value = nn.with_logical_constraint(cached_key_var.value, ar_cache_axis_names) 
cached_value_var.value = nn.with_logical_constraint(cached_value_var.value, ar_cache_axis_names) @@ -1139,9 +1160,11 @@ class Attention(nn.Module): prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) input_axis_names: AxisNames = (BATCH, LENGTH, EMBED) + decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED) key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + decode_out_axis_names = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV) prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3) ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3) @@ -1347,8 +1370,12 @@ def __call__( Returns: output of shape `[batch, length, q_features]`. """ - inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names) - inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names) + if model_mode == common_types.MODEL_MODE_PREFILL: + inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names) + inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names) + else: + inputs_q = nn.with_logical_constraint(inputs_q, self.decode_input_axis_names) + inputs_kv = nn.with_logical_constraint(inputs_kv, self.decode_input_axis_names) # apply projection. if self.config.fused_qkv: @@ -1366,6 +1393,10 @@ def __call__( query = nn.with_logical_constraint(query, self.prefill_query_axis_names) key = nn.with_logical_constraint(key, self.prefill_key_axis_names) value = nn.with_logical_constraint(value, self.prefill_value_axis_names) + elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + query = nn.with_logical_constraint(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) + key = nn.with_logical_constraint(key,(DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) + value = nn.with_logical_constraint(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) else: query = nn.with_logical_constraint(query, self.query_axis_names) key = nn.with_logical_constraint(key, self.key_axis_names) @@ -1378,7 +1409,10 @@ def __call__( out = self.attention_op(query, key, value, decoder_segment_ids, model_mode) - out = nn.with_logical_constraint(out, self.out_axis_names) + if model_mode == common_types.MODEL_MODE_PREFILL: + out = nn.with_logical_constraint(out, self.out_axis_names) + else: + out = nn.with_logical_constraint(out, self.decode_out_axis_names) # apply output projection, output dim is set to the input dim. 
out = self.out_projection(inputs_q.shape[-1], out) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 0a9ab16c7..96b876d2a 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -684,13 +684,15 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): inputs = nn.with_logical_constraint(inputs, input_axis) with jax.named_scope("dispatch"): + # only cp during prefill dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( "BNSM,BNSEC -> EBNCM", inputs, dispatch_mask, precision=matmul_precision ) - dispatch = nn.with_logical_constraint( - dispatch, - (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), - ) + if cp > 1: + dispatch = nn.with_logical_constraint( + dispatch, + (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), + ) dispatch = nn.with_logical_constraint( dispatch, dispatch_axis, diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 8f423e712..ea998098f 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -504,9 +504,16 @@ def __call__( )( y ) # We do not quantize the logits matmul. - logits = nn.with_logical_constraint( - logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") - ) + + if model_mode in [common_types.MODEL_MODE_PREFILL, common_types.MODEL_MODE_AUTOREGRESSIVE]: + logits = nn.with_logical_constraint( + logits, (None, None, "activation_vocab") + ) + else: + logits = nn.with_logical_constraint( + logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") + ) + if self.config.cast_logits_to_fp32: logits = logits.astype(jnp.float32) return logits From 65a64d443659ca1606228273b2b102195c7d81c5 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Wed, 5 Mar 2025 21:18:10 +0000 Subject: [PATCH 15/34] clean up --- MaxText/layers/attentions.py | 18 +++--------------- MaxText/max_utils.py | 1 + MaxText/maxtext_utils.py | 1 - 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 94687f15f..8b19e8d06 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -465,14 +465,9 @@ def compute_local_attention( local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1)) local_out = self.wv_product(local_exps, value, model_mode) - if model_mode == common_types.MODEL_MODE_PREFILL: - print("prefill") - print(local_out.shape) - local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) - #local_out = partitioning.with_sharding_constraint(local_out, DECODE_BATCH, HEAD, D_KV) + if model_mode != common_types.MODEL_MODE_AUTOREGRESSIVE: + local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) else: - print("decode") - print(local_out.shape) local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) @@ -481,13 +476,11 @@ def compute_local_attention( local_sum = local_sum[:, 0:1, :, :] local_out = local_out[:, 0:1, :, :] - if model_mode != common_types.MODEL_MODE_PREFILL: + if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: local_max = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) local_sum = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) - 
print(local_sum.shape) - print(local_max.shape) return local_out, local_max, local_sum def apply_attention_dot( @@ -563,9 +556,6 @@ def qk_product(self, query: Array, key: Array | KVTensor, q_seq_len: int, model_ b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads - print(query.shape) - print(key.shape) - print(self.compute_axis_order) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3): query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) if self.reshape_q and q_seq_len == 1: @@ -866,7 +856,6 @@ def update_ar_key_value( one_token_value_shaped_for_cache = jnp.transpose(one_token_value, self.ar_cache_axis_order) ar_cache_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.ar_cache_axis_order) - print(ar_cache_axis_names) if self.kv_quant: one_token_key_shaped_for_cache, one_token_key_scale_shaped_for_cache = self.kv_quant.quantize( one_token_key_shaped_for_cache, ar_cache_axis_names @@ -911,7 +900,6 @@ def value_body(i, val): cached_value_var.value = jax.lax.dynamic_update_index_in_dim( cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis ) - print("cv", cached_value_var.value.shape) cached_key_var.value = nn.with_logical_constraint(cached_key_var.value, ar_cache_axis_names) cached_value_var.value = nn.with_logical_constraint(cached_value_var.value, ar_cache_axis_names) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index a8e29b589..ae0e49563 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -640,6 +640,7 @@ def create_device_mesh(config, devices=None): num_devices_per_slice = num_devices // num_slices multi_slice_env = num_slices > 1 + # Find possible unspecified parallelisms ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 7c355bc22..607970d22 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -241,7 +241,6 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance): product_num_devices_for_weight_sharding = 1 for axis in ["fsdp", "fsdp_transpose", "sequence", "context", "tensor", "tensor_transpose", "tensor_sequence", "stage", "expert"]: product_num_devices_for_weight_sharding *= mesh.shape[axis] - print(product_num_devices_for_weight_sharding) total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, ( From 10a9d82957283767c8eb78e0243cafed0dd22ab9 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Wed, 5 Mar 2025 22:42:07 +0000 Subject: [PATCH 16/34] update --- MaxText/configs/base.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index c408565f2..c85009f2a 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -303,7 +303,7 @@ logical_axis_rules: [ ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], ['cache_sequence', ['context']], - ['exp', ['expert', 'context'], + ['exp', ['expert', 'context']], ['paged_kv_heads', []], ['num_pages', ['tensor']], ['tokens_per_page', []], From 0cca6df0c78665e117aee076d66ef5a49b92a726 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 6 Mar 2025 00:38:11 +0000 Subject: [PATCH 17/34] update --- MaxText/configs/base.yml | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index c85009f2a..9736ae0d3 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -310,7 +310,7 @@ logical_axis_rules: [ ['paged_kv_head_dim_size', []], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'conext', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. sharding_tolerance: 0.02 From ebae8e065f5488f52c603349a7fe95e0adfae118 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 6 Mar 2025 01:03:07 +0000 Subject: [PATCH 18/34] fix tests --- MaxText/pyconfig.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index f758df89f..e11f62125 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -626,6 +626,7 @@ def pipeline_first_axis(raw_keys): raw_keys["ici_fsdp_parallelism"], raw_keys["ici_fsdp_transpose_parallelism"], raw_keys["ici_sequence_parallelism"], + raw_keys["ici_context_parallelism"], raw_keys["ici_tensor_parallelism"], raw_keys["ici_tensor_transpose_parallelism"], raw_keys["ici_tensor_sequence_parallelism"], @@ -638,6 +639,7 @@ def pipeline_first_axis(raw_keys): raw_keys["dcn_fsdp_parallelism"], raw_keys["dcn_fsdp_transpose_parallelism"], raw_keys["dcn_sequence_parallelism"], + raw_keys["dnc_context_parallelism"], raw_keys["dcn_tensor_parallelism"], raw_keys["dcn_tensor_transpose_parallelism"], raw_keys["dcn_tensor_sequence_parallelism"], From 2e0c45963dbd5109174ec961b578b48c28c18058 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 6 Mar 2025 22:25:08 +0000 Subject: [PATCH 19/34] added contition for non-sharded kernel for cp during inference only --- MaxText/layers/attentions.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index fa6b769ed..d97262011 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -1386,12 +1386,13 @@ def query_projection(self, inputs_q: Array) -> Array: def query_init(*args): # pylint: disable=no-value-for-parameter return self.kernel_init(*args) / depth_scaling - + + kernel_axes = (None, None, None) if self.config.ici_context_parallelism > 1 else ("embed", "q_heads", "kv") query_proj = DenseGeneral( features=(self.num_query_heads, self.head_dim), axis=-1, kernel_init=query_init, - kernel_axes=(None, None, None), # attn_q_weight_ndh=(None, zero_axes, None), + kernel_axes=kernel_axes, dtype=self.dtype, weight_dtype=self.weight_dtype, name="query", @@ -1418,8 +1419,7 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: raise ValueError("Invalid num_kv_heads for GQA.") # attn_k_weight_kdh=(None, zero_axes, None), - # kernel_axes = ("embed", "kv_heads", "kv_head_dim") - kernel_axes = (None, None, None) + kernel_axes = (None, None, None) if self.config.ici_context_parallelism > 1 else ("embed", "kv_heads", "kv_head_dim") kv_proj = DenseGeneral( features=(self.num_kv_heads, self.head_dim), @@ -1453,11 +1453,12 @@ def qkv_projection(self, inputs: Array, proj_name: str): return query, key, value def out_projection(self, output_dim: int, out: Array) -> Array: + 
out_kernel_axis = (None, None, None) if self.config.ici_context_parallelism > 1 else ("heads", "kv", "embed") out_proj = DenseGeneral( features=output_dim, axis=(-2, -1), kernel_init=self.kernel_init, - kernel_axes=(None, None, None), # trade speed with memory + kernel_axes=out_kernel_axis, # trade speed with memory dtype=self.dtype, weight_dtype=self.weight_dtype, name="out", From 37c843e7410b57d7b42c636343f1b9a7a056ce67 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Thu, 6 Mar 2025 23:03:26 +0000 Subject: [PATCH 20/34] update --- MaxText/pyconfig.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index e11f62125..e59b00ed0 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -639,7 +639,7 @@ def pipeline_first_axis(raw_keys): raw_keys["dcn_fsdp_parallelism"], raw_keys["dcn_fsdp_transpose_parallelism"], raw_keys["dcn_sequence_parallelism"], - raw_keys["dnc_context_parallelism"], + raw_keys["dcn_context_parallelism"], raw_keys["dcn_tensor_parallelism"], raw_keys["dcn_tensor_transpose_parallelism"], raw_keys["dcn_tensor_sequence_parallelism"], From b63c63b7e3bdad122477399baa68738de3353281 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 02:38:27 +0000 Subject: [PATCH 21/34] bug fix --- MaxText/layers/attentions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index d97262011..14af088ce 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -557,8 +557,8 @@ def compute_local_attention( local_out = local_out[:, 0:1, :, :] if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - local_max = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) - local_sum = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) + local_max = partitioning.with_sharding_constraint(local_max, (DECODE_BATCH, None, HEAD, D_KV)) + local_sum = partitioning.with_sharding_constraint(local_sum, (DECODE_BATCH, None, HEAD, D_KV)) local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) return local_out, local_max, local_sum From 4007e7cc5df0f88522a7586ea9dd02adb330c467 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 19:03:04 +0000 Subject: [PATCH 22/34] fix tests --- MaxText/pyconfig.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index e59b00ed0..9fc895aca 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -593,6 +593,7 @@ def validate_multiple_slices(raw_keys): raw_keys["dcn_tensor_parallelism"], raw_keys["dcn_tensor_sequence_parallelism"], raw_keys["dcn_expert_parallelism"], + raw_keys["dcn_context_parallelism"], raw_keys["dcn_autoregressive_parallelism"], ] ) @@ -652,6 +653,7 @@ def pipeline_first_axis(raw_keys): "fsdp", "fsdp_transpose", "sequence", + "context", "tensor", "tensor_transpose", "tensor_sequence", @@ -665,6 +667,7 @@ def pipeline_first_axis(raw_keys): "fsdp", "fsdp_transpose", "sequence", + "context", "tensor", "tensor_transpose", "tensor_sequence", From 72f2a90cb8bec9e209f7790ae3914616b8d36be5 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 19:31:17 +0000 Subject: [PATCH 23/34] adddress comment --- MaxText/layers/linears.py | 19 +++++++++++-------- MaxText/pyconfig.py | 1 - 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 96b876d2a..4039b8815 100644 --- 
a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -519,13 +519,17 @@ def reshape_and_update_weights(self, weights, indices): update_weights = update_weights.at[index_update].set(weights) return update_weights + def get_context_partition_and_sub_seq(self, seq_len): + cp = self.config.ici_context_parallelism + if seq_len % cp != 0: + cp = 1 + sub_seq = seq_len // cp + return cp, sub_seq + def generate_masks(self, top_k_indices, softmax_probs): # calculate expert_capacity = (tokens_per_batch / num_experts) * capacity_factor batch_size, seq_len, _ = top_k_indices.shape - cp = self.config.ici_context_parallelism - if seq_len % cp != 0: - cp = 1 - sub_seq = seq_len // cp + cp, sub_seq = self.get_context_partition_and_sub_seq(seq_len) top_k_indices = jnp.reshape(top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2])) @@ -654,12 +658,11 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) matmul_precision = lax.Precision(self.config.matmul_precision) - cp = self.config.ici_context_parallelism batch_size = inputs.shape[0] seq_len = inputs.shape[1] - if seq_len % cp != 0: - cp = 1 - sub_seq = seq_len // cp + + cp, sub_seq = self.get_context_partition_and_sub_seq(seq_len) + if self.config.capacity_factor > 0: # token dropping if needed dispatch_mask, combine_mask = self.generate_masks(top_k_indices, softmax_probs) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 9fc895aca..d5af6af5e 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -212,7 +212,6 @@ def validate_model_name(s: str) -> bool: "llama3.1-70b", "llama3.1-405b", "llama3.3-70b", - "subsup", "mistral-7b", "mixtral-8x7b", "mixtral-8x22b", From 8da48f5a93d8a1765ea539c2d4d29ceabbac817c Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 19:34:38 +0000 Subject: [PATCH 24/34] update --- MaxText/layers/linears.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 4039b8815..331e05d05 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -735,7 +735,8 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( "EBNCH,EHM -> EBNCM", layer_multiply, wo_kernel, precision=matmul_precision ) - + if self.config.activations_in_float32: + intermediate_layer = intermediate_layer.astype(jnp.float32) # intermediate_layer = nn.with_logical_constraint( # intermediate_layer, # ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), From 8a43dd567c41d3dcada3b843d03fe5d700cca12b Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 19:43:42 +0000 Subject: [PATCH 25/34] address comments --- MaxText/layers/attentions.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 14af088ce..05d864822 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -180,7 +180,7 @@ def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Arr assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." assert key.shape[-3] == value.shape[-3], "k, v lengths must match." assert query.shape[-1] == key.shape[-1], "q, k depths must match." 
- + # Attention mask is generated in the same way # as mentioned in SARATHI - https://arxiv.org/abs/2308.16369 def generate_attention_mask_for_chunk(self, query, key, previous_chunk: Any = None) -> Array | None: @@ -550,19 +550,16 @@ def compute_local_attention( else: local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) - if self.reshape_q and q_seq_len == 1: local_max = local_max[:, 0:1, :, :] local_sum = local_sum[:, 0:1, :, :] local_out = local_out[:, 0:1, :, :] - if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - local_max = partitioning.with_sharding_constraint(local_max, (DECODE_BATCH, None, HEAD, D_KV)) - local_sum = partitioning.with_sharding_constraint(local_sum, (DECODE_BATCH, None, HEAD, D_KV)) - local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) - return local_out, local_max, local_sum + def is_partition_in_decode(self, seq_len): + return self.config.ici_context_parallelism > 0 and seq_len == 1 + def apply_attention_dot( self, query: Array, @@ -582,8 +579,9 @@ def apply_attention_dot( key = key.astype(jnp.float32) q_seq_len = query.shape[1] + # special sharding for decode - if self.config.ici_context_parallelism > 0 and q_seq_len == 1: + if self.is_partition_in_decode(q_seq_len): query = partitioning.with_sharding_constraint(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) key = partitioning.with_sharding_constraint(key, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) value = partitioning.with_sharding_constraint(value, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) @@ -593,7 +591,7 @@ def apply_attention_dot( value = partitioning.with_sharding_constraint(value, (BATCH, KV_LENGTH, HEAD, D_KV)) attn_weights = self.qk_product(query, key, q_seq_len, model_mode) - if self.config.ici_context_parallelism > 0 and q_seq_len == 1: + if self.is_partition_in_decode(q_seq_len): attn_weights = partitioning.with_sharding_constraint(attn_weights, (KV_LENGTH, HEAD, None, None, None)) else: attn_weights = partitioning.with_sharding_constraint(attn_weights, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) @@ -601,10 +599,12 @@ def apply_attention_dot( if self.attn_logits_soft_cap: attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap) attn_weights = attn_weights * self.attn_logits_soft_cap - if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits: + + # Casting softmaxt computation for float32 for model stability. 
+ if self.float32_logits: attn_weights = attn_weights.astype(jnp.float32) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode, previous_chunk) - if self.config.ici_context_parallelism > 0 and q_seq_len == 1: + if self.is_partition_in_decode(q_seq_len): attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None)) else: attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, LENGTH, KV_LENGTH)) @@ -1386,7 +1386,6 @@ def query_projection(self, inputs_q: Array) -> Array: def query_init(*args): # pylint: disable=no-value-for-parameter return self.kernel_init(*args) / depth_scaling - kernel_axes = (None, None, None) if self.config.ici_context_parallelism > 1 else ("embed", "q_heads", "kv") query_proj = DenseGeneral( features=(self.num_query_heads, self.head_dim), @@ -1418,7 +1417,6 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: if self.num_query_heads % self.num_kv_heads != 0: raise ValueError("Invalid num_kv_heads for GQA.") - # attn_k_weight_kdh=(None, zero_axes, None), kernel_axes = (None, None, None) if self.config.ici_context_parallelism > 1 else ("embed", "kv_heads", "kv_head_dim") kv_proj = DenseGeneral( From 56deeda329f2f735d934e140b88943c562536541 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 19:52:37 +0000 Subject: [PATCH 26/34] address comments --- MaxText/pyconfig.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index d5af6af5e..6ef8953f1 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -154,6 +154,9 @@ def validate_keys(keys): "Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period," " use_replicator_service and replicator_backup_interval_minutes" ) + assert ( + keys["ici_context_parallelism"] == 1 or keys["quantize_kvcache"] == False + ), "currently context parallelism doesn't support quantized kv cache" validate_multiple_slices(keys) if keys["num_experts"] > 1: From 1c6be59459132f260c8875d860cbf50ae12345f2 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 19:59:09 +0000 Subject: [PATCH 27/34] revert --- MaxText/layers/attentions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 05d864822..0efdb90b2 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -180,7 +180,6 @@ def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Arr assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." assert key.shape[-3] == value.shape[-3], "k, v lengths must match." assert query.shape[-1] == key.shape[-1], "q, k depths must match." 
- # Attention mask is generated in the same way # as mentioned in SARATHI - https://arxiv.org/abs/2308.16369 def generate_attention_mask_for_chunk(self, query, key, previous_chunk: Any = None) -> Array | None: @@ -555,6 +554,11 @@ def compute_local_attention( local_sum = local_sum[:, 0:1, :, :] local_out = local_out[:, 0:1, :, :] + if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + local_max = partitioning.with_sharding_constraint(local_max, (DECODE_BATCH, None, HEAD, D_KV)) + local_sum = partitioning.with_sharding_constraint(local_sum, (DECODE_BATCH, None, HEAD, D_KV)) + local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) + return local_out, local_max, local_sum def is_partition_in_decode(self, seq_len): From bd0e1998a19b2ba9551241b79ed280481be05bcb Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 21:14:52 +0000 Subject: [PATCH 28/34] address lint --- MaxText/pyconfig.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 6ef8953f1..8944f707c 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -155,7 +155,7 @@ def validate_keys(keys): " use_replicator_service and replicator_backup_interval_minutes" ) assert ( - keys["ici_context_parallelism"] == 1 or keys["quantize_kvcache"] == False + keys["ici_context_parallelism"] == 1 or keys["quantize_kvcache"] is False ), "currently context parallelism doesn't support quantized kv cache" validate_multiple_slices(keys) From 44d646f12d86fe0db66e7f6381132adbfce25a03 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 21:18:45 +0000 Subject: [PATCH 29/34] reformat for lint --- MaxText/layers/attentions.py | 14 +++++----- MaxText/layers/linears.py | 52 ++++++++++++++++++------------------ MaxText/layers/models.py | 12 ++++----- MaxText/maxtext_utils.py | 12 ++++++++- 4 files changed, 50 insertions(+), 40 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 0efdb90b2..d056ae790 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -180,6 +180,7 @@ def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Arr assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." assert key.shape[-3] == value.shape[-3], "k, v lengths must match." assert query.shape[-1] == key.shape[-1], "q, k depths must match." 
+ # Attention mask is generated in the same way # as mentioned in SARATHI - https://arxiv.org/abs/2308.16369 def generate_attention_mask_for_chunk(self, query, key, previous_chunk: Any = None) -> Array | None: @@ -545,7 +546,7 @@ def compute_local_attention( local_out = self.wv_product(local_exps, value, model_mode) if model_mode != common_types.MODEL_MODE_AUTOREGRESSIVE: - local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) + local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) else: local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, None, HEAD, D_KV)) @@ -1390,6 +1391,7 @@ def query_projection(self, inputs_q: Array) -> Array: def query_init(*args): # pylint: disable=no-value-for-parameter return self.kernel_init(*args) / depth_scaling + kernel_axes = (None, None, None) if self.config.ici_context_parallelism > 1 else ("embed", "q_heads", "kv") query_proj = DenseGeneral( features=(self.num_query_heads, self.head_dim), @@ -1455,12 +1457,12 @@ def qkv_projection(self, inputs: Array, proj_name: str): return query, key, value def out_projection(self, output_dim: int, out: Array) -> Array: - out_kernel_axis = (None, None, None) if self.config.ici_context_parallelism > 1 else ("heads", "kv", "embed") + out_kernel_axis = (None, None, None) if self.config.ici_context_parallelism > 1 else ("heads", "kv", "embed") out_proj = DenseGeneral( features=output_dim, axis=(-2, -1), kernel_init=self.kernel_init, - kernel_axes=out_kernel_axis, # trade speed with memory + kernel_axes=out_kernel_axis, # trade speed with memory dtype=self.dtype, weight_dtype=self.weight_dtype, name="out", @@ -1555,7 +1557,7 @@ def __call__( """ if model_mode == common_types.MODEL_MODE_PREFILL: inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names) - inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names) + inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names) else: inputs_q = nn.with_logical_constraint(inputs_q, self.decode_input_axis_names) inputs_kv = nn.with_logical_constraint(inputs_kv, self.decode_input_axis_names) @@ -1578,8 +1580,8 @@ def __call__( value = nn.with_logical_constraint(value, self.prefill_value_axis_names) elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: query = nn.with_logical_constraint(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) - key = nn.with_logical_constraint(key,(DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) - value = nn.with_logical_constraint(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) + key = nn.with_logical_constraint(key, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) + value = nn.with_logical_constraint(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) else: query = nn.with_logical_constraint(query, self.query_axis_names) key = nn.with_logical_constraint(key, self.key_axis_names) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 331e05d05..558e61c4d 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -520,12 +520,12 @@ def reshape_and_update_weights(self, weights, indices): return update_weights def get_context_partition_and_sub_seq(self, seq_len): - cp = self.config.ici_context_parallelism - if seq_len % cp != 0: - cp = 1 - sub_seq = seq_len // cp - return cp, sub_seq - + cp = self.config.ici_context_parallelism + if seq_len % cp != 0: + cp = 1 + sub_seq = seq_len // cp + return cp, sub_seq + def generate_masks(self, top_k_indices, softmax_probs): # calculate 
expert_capacity = (tokens_per_batch / num_experts) * capacity_factor batch_size, seq_len, _ = top_k_indices.shape @@ -593,7 +593,7 @@ def generate_masks(self, top_k_indices, softmax_probs): combine_mask = combine_mask[..., 1:] dispatch_mask = combine_mask.astype(bool) - #ici_context_parallelism + # ici_context_parallelism dispatch_mask = jnp.reshape(dispatch_mask, (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch)) combine_mask = jnp.reshape(combine_mask, (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch)) @@ -667,8 +667,8 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): # token dropping if needed dispatch_mask, combine_mask = self.generate_masks(top_k_indices, softmax_probs) if self.config.ici_context_parallelism > 0 and cp == 1: - mask_axes = ( "activation_length", "activation_batch", None, None, None) - input_axis = ("activation_length", "activation_batch", None, "activation_embed") + mask_axes = ("activation_length", "activation_batch", None, None, None) + input_axis = ("activation_length", "activation_batch", None, "activation_embed") dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") else: @@ -682,44 +682,44 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): loss = self.load_balance_loss(top_k_indices, weights) else: loss = None - inputs = jnp.reshape(inputs,(batch_size, cp, sub_seq, inputs.shape[2])) + inputs = jnp.reshape(inputs, (batch_size, cp, sub_seq, inputs.shape[2])) - inputs = nn.with_logical_constraint(inputs, input_axis) + inputs = nn.with_logical_constraint(inputs, input_axis) with jax.named_scope("dispatch"): # only cp during prefill dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( - "BNSM,BNSEC -> EBNCM", inputs, dispatch_mask, precision=matmul_precision - ) + "BNSM,BNSEC -> EBNCM", inputs, dispatch_mask, precision=matmul_precision + ) if cp > 1: dispatch = nn.with_logical_constraint( - dispatch, - (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), - ) - dispatch = nn.with_logical_constraint( dispatch, - dispatch_axis, + (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), ) + dispatch = nn.with_logical_constraint( + dispatch, + dispatch_axis, + ) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, "mlp") w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes) layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( - "EBNCM,EMH -> EBNCH", dispatch, w0_kernel, precision=matmul_precision - ) + "EBNCM,EMH -> EBNCH", dispatch, w0_kernel, precision=matmul_precision + ) if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) layer_w0 = nn.with_logical_constraint( layer_w0, - mlp_axis, + mlp_axis, ) layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, "mlp") w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes) layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( - "EBNCM,EMH -> EBNCH", dispatch, w1_kernel, precision=matmul_precision - ) + "EBNCM,EMH -> EBNCH", dispatch, w1_kernel, precision=matmul_precision + ) if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) layer_w1 = nn.with_logical_constraint( @@ -733,8 +733,8 @@ def dense_matmul(self, inputs, 
gate_logits, w0_kernel, w1_kernel, wo_kernel): wo_kernel_axes = ("exp", "mlp", None) wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes) intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( - "EBNCH,EHM -> EBNCM", layer_multiply, wo_kernel, precision=matmul_precision - ) + "EBNCH,EHM -> EBNCM", layer_multiply, wo_kernel, precision=matmul_precision + ) if self.config.activations_in_float32: intermediate_layer = intermediate_layer.astype(jnp.float32) # intermediate_layer = nn.with_logical_constraint( @@ -750,7 +750,7 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): combine_mask, precision=matmul_precision, ) - output = jnp.reshape(output, (output.shape[0], output.shape[1]*output.shape[2], output.shape[3])) + output = jnp.reshape(output, (output.shape[0], output.shape[1] * output.shape[2], output.shape[3])) return output, loss else: top_k_weights /= top_k_weights.sum(-1, keepdims=True) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 24c38ae94..b4c30b3df 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -512,16 +512,14 @@ def __call__( )( y ) # We do not quantize the logits matmul. - - if model_mode in [common_types.MODEL_MODE_PREFILL, common_types.MODEL_MODE_AUTOREGRESSIVE]: - logits = nn.with_logical_constraint( - logits, (None, None, "activation_vocab") - ) - else: + + if model_mode in [common_types.MODEL_MODE_PREFILL, common_types.MODEL_MODE_AUTOREGRESSIVE]: + logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab")) + else: logits = nn.with_logical_constraint( logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") ) - + if self.config.cast_logits_to_fp32: logits = logits.astype(jnp.float32) return logits diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index a73cc0daa..e474c770c 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -309,7 +309,17 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance): """ total_num_params = max_utils.calculate_num_params_from_pytree(params) product_num_devices_for_weight_sharding = 1 - for axis in ["fsdp", "fsdp_transpose", "sequence", "context", "tensor", "tensor_transpose", "tensor_sequence", "stage", "expert"]: + for axis in [ + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "tensor", + "tensor_transpose", + "tensor_sequence", + "stage", + "expert", + ]: product_num_devices_for_weight_sharding *= mesh.shape[axis] total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding From 51720685bd4979282b0fe412349015d77c053fea Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 22:53:21 +0000 Subject: [PATCH 30/34] update MOE test --- MaxText/tests/moe_test.py | 57 ++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py index e146671eb..57ccc87d6 100644 --- a/MaxText/tests/moe_test.py +++ b/MaxText/tests/moe_test.py @@ -124,35 +124,44 @@ def test_generate_masks(self): expected_combine_mask = jnp.array( [ [ - [[0.2, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.8, 0], [0, 0], [0, 0]], - [[0, 0.68], [0, 0], [0, 0], [0, 0], [0.32, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0.78, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0.32, 0], [0, 0], [0, 0.68], [0, 0], [0, 0]], - ], - [ 
- [[0, 0], [0.26, 0], [0.74, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0.79], [0, 0], [0, 0], [0.21, 0], [0, 0], [0, 0], [0, 0]], - [[0.89, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.11, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.89, 0]], - ], - [ - [[0, 0], [0, 0], [0.26, 0], [0, 0], [0, 0], [0, 0], [0.74, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0.88], [0.12, 0], [0, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0, 0], [0.83, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0.35, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], - ], - [ - [[0, 0], [0.47, 0], [0, 0], [0, 0], [0.53, 0], [0, 0], [0, 0], [0, 0]], - [[0.36, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.64, 0]], - [[0, 0.15], [0, 0], [0, 0], [0, 0], [0, 0], [0.85, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0.18], [0, 0], [0, 0], [0, 0.82]], - ], + [ + [ + [[0.2, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.8, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.68], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.32, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.78, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.32, 0.0], [0.0, 0.0], [0.0, 0.68], [0.0, 0.0], [0.0, 0.0]], + ] + ], + [ + [ + [[0.0, 0.0], [0.26, 0.0], [0.74, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.79], [0.0, 0.0], [0.0, 0.0], [0.21, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.89, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.11, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.89, 0.0]], + ] + ], + [ + [ + [[0.0, 0.0], [0.0, 0.0], [0.26, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.74, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.88], [0.12, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.83, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.35, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ] + ], + [ + [ + [[0.0, 0.0], [0.47, 0.0], [0.0, 0.0], [0.0, 0.0], [0.53, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.36, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.64, 0.0]], + [[0.0, 0.15], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.85, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.18], [0.0, 0.0], [0.0, 0.0], [0.0, 0.82]], + ] + ], + ] ], dtype=jnp.float32, ) expected_dispatch_mask = expected_combine_mask.astype(bool) actual_dispatch_mask, actual_combine_mask = self.model.generate_masks(top_k_indices, softmax_probs) - self.assertTrue((expected_dispatch_mask == actual_dispatch_mask).all()) self.assertTrue(jax.numpy.allclose(expected_combine_mask, actual_combine_mask, rtol=1e-02, atol=1e-02)) From d6787c394c561f7456cf3adec5f32f48038b6ce2 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 7 Mar 2025 22:58:44 +0000 Subject: [PATCH 31/34] add comment to explain grouping in generate_mask for moe model --- MaxText/layers/linears.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 558e61c4d..2038f7fdf 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -531,6 +531,7 @@ def generate_masks(self, top_k_indices, softmax_probs): batch_size, seq_len, _ = top_k_indices.shape cp, sub_seq = 
self.get_context_partition_and_sub_seq(seq_len) + # breaking the sequence into sub sequences. It is effectively grouping the tokens in a sequence into groups, and route only within each group. top_k_indices = jnp.reshape(top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2])) tokens_per_batch = sub_seq * self.num_experts_per_tok From f964acd12a1955837b3d8f4deb2849a13a4aea6e Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Sat, 8 Mar 2025 01:09:27 +0000 Subject: [PATCH 32/34] address the comments --- MaxText/layers/linears.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 2038f7fdf..340552f1d 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -645,7 +645,7 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel, kernel_ax def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): # gate_logits: batch, length, expert - # follow router_logits = shd.shard(router_logits, (None, None, None)) + # follow router_logits non-sharded kernel gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None)) softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) # shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok) @@ -738,10 +738,11 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): ) if self.config.activations_in_float32: intermediate_layer = intermediate_layer.astype(jnp.float32) - # intermediate_layer = nn.with_logical_constraint( - # intermediate_layer, - # ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), - # ) + if self.config.model_call_mode != "inference": + intermediate_layer = nn.with_logical_constraint( + intermediate_layer, + ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), + ) intermediate_layer = checkpoint_name(intermediate_layer, "mlpwo") with jax.named_scope("combine"): # Matmul & element wise operation From c5174de79b3c03211dda9305ceb476e93ce1c7f8 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Sat, 8 Mar 2025 01:29:54 +0000 Subject: [PATCH 33/34] update to fix tests --- MaxText/configs/base.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 538feeff7..d0f6c9ce0 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -265,7 +265,6 @@ logical_axis_rules: [ ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_length', ['sequence', 'context']], - ['activation_length', ['context']], ['activation_length_q', ['context']], ['activation_kv_length', ['context']], ['activation_norm_length', ['tensor_sequence', 'sequence']], From b86e0356a8c0ae0dcecf5a7b2c916842ea8f20a5 Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Sat, 8 Mar 2025 04:41:17 +0000 Subject: [PATCH 34/34] seperate yml for inference --- MaxText/configs/base.yml | 28 +- MaxText/configs/inference.yml | 626 ++++++++++++++++++++++++++++++++++ 2 files changed, 639 insertions(+), 15 deletions(-) create mode 100644 MaxText/configs/inference.yml diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index d0f6c9ce0..3af88cc08 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -264,26 +264,24 @@ logical_axis_rules: [ ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 
'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], - ['activation_length', ['sequence', 'context']], - ['activation_length_q', ['context']], - ['activation_kv_length', ['context']], + ['activation_length', ['sequence']], ['activation_norm_length', ['tensor_sequence', 'sequence']], ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], - ['activation_vocab', ['sequence', 'context']], + ['activation_vocab', ['sequence']], ['activation_stage', 'stage'], - ['activation_exp', ['expert', 'context']], - ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], + ['activation_exp', ['expert']], + ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', []], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], - ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context']], + ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -291,21 +289,21 @@ logical_axis_rules: [ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'sequence', 'context', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_no_exp', ['fsdp', 'sequence', 'context']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose']], + ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']], + ['embed_no_exp', ['fsdp', 'sequence']], ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['layers', 'stage'], ['kv', []], ['kv_head_dim', []], ['cache_batch_prefill', []], - ['cache_batch', ['context']], + ['cache_batch', []], ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], - ['cache_sequence', ['context']], - ['exp', ['expert', 'context']], + ['cache_sequence', []], + ['exp', 'expert'], ['paged_kv_heads', []], ['num_pages', ['tensor']], ['tokens_per_page', []], diff --git a/MaxText/configs/inference.yml b/MaxText/configs/inference.yml new file mode 100644 index 000000000..7fa591455 --- /dev/null +++ b/MaxText/configs/inference.yml @@ -0,0 +1,626 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +# If there is already a checkpoint under this run, that checkpoint will auto-resume. +run_name: "" + +model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this! +normalization_layer_epsilon: 1.e-05 + +################################## CHECKPOINTING ################################## +# Checkpointing makes the following choices in the following order, starting with (1): +# (1) If there is already a checkpoint for this run_name, we load the latest entire checkpoint. +# This ensures if we're resuming a run after preemption or hardware failure we lose minimum state. +# (2) Same priority and mutually exclusive -- you can't set both! +# * If load_parameters_path is set, we load a parameter only checkpoint from that path. +# * If load_full_state_path is set, we load a full state checkpoint from that path. +# (3) We don't load a checkpoint and initialize state instead! + +# Loads a just parameters from a specific directory +# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items +load_parameters_path: "" +# Loads a full checkpoint including optimizer state and step count from a specific directory +# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items +load_full_state_path: "" + +# If enable_checkpointing is true, an asynchronous checkpointer will be used if +# async_checkpointing is true, else a synchronous one is used. If you have +# problems with the checkpointer we recommend trying the synchronous one. +enable_checkpointing: True +async_checkpointing: True +checkpoint_period: 10_000 +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +force_unroll: False # during generate_param_only_checkpoint should we unroll the loop? + +# checkpointing using orbax has two important parameters: array driver +# and its underlying storage - the kvstore (preferably ocdbt) +# orbax supports setting a target file size, chunking a single +# large arrays into small physical files (<2GB) can speed up distributed and over +# the network loading enormously +checkpoint_storage_target_data_file_size_bytes: 2147483648 +checkpoint_storage_use_ocdbt: True +checkpoint_storage_use_zarr3: True +# larger models requires higher concurrent GB for I/O +# default concurrent gb for PytreeCheckpointHandler is 96GB +checkpoint_storage_concurrent_gb: 96 +############################### END CHECKPOINTING ################################## + + +reuse_example_batch: 0 # for testing TPU performance, this options repeated uses the same batch. + + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. 
+# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +gcs_metrics: False + +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False + +# Activation dtypes. +dtype: "bfloat16" +# Used to configure quantization in the transformer layers, defaults to null implying bf16. +# Possible alternative settings are as follows: +# 'int8' for dynamic range quantization using 8-bits +# 'intmp' for mixed precision quantization for inference as described here: MaxText/configs/quantization/README.md +# 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs. +# 'nanoo_fp8' for 8-bit floating-point GeMMs on AMD MI300/MI325 GPUs. +quantization: "" +# Choose one of default, high, and highest. +# https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +matmul_precision: "default" +activations_in_float32: False # Sets activations to float32 before nonlinearity it true, else dtype +# Used to replicate the quantization scale to avoid the inefficient XLA fusion for 2d sharding. +replicate_quant_scale: False +# Path to file with quantization config for intmp. +quant_cfg_path: "" +quantize_kvcache: False # Set to True to quantize KV Cache values, defaults to False +# Valid kv_quant_axis values: +# - "" is valid only when quantize_kvcache is False +# - "dkv" indicates quantize kv cache over the cache_kv, i.e. kv dimension axis +# - "heads_and_dkv" indicates quantize kv cache over cache_heads and cache_kv axes +# Default to "heads_and_dkv" for faster compution, kv_quant_axis is not used when quantize_kvcache is False +# - "dkv" is expected with better accuracy but degraded computation +kv_quant_axis: "heads_and_dkv" +kv_quant_dtype: "int8" +checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint +# Saves params quantized on fly at following path +save_quantized_params_path: "" +#Used to configure the mode in which model is called +# when left as is, corresponds to training +# accepted values are "inference" +model_call_mode: "" + +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 + +decoder_block: "llama2" # which style of DecoderBlock to use. +# Global parameter scale needs to be a power of 2. If you want finer grained control of the model sizes +# then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads, +# base_mlp_dim, base_num_decoder_layers and/or head_dim. +weight_dtype: float32 +global_parameter_scale: 1 +base_emb_dim: 2048 +base_num_query_heads: 16 +base_num_kv_heads: 16 +base_mlp_dim: 7168 +base_num_decoder_layers: 16 +head_dim: 128 +mlp_activations: ["silu", "linear"] +dropout_rate: 0.0 +logits_via_embedding: False +normalize_embedding_logits: True # whether to normlize pre-softmax logits if logits_via_embedding is true +logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability +cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly. 
+float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product +float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax + +# mixture of experts (moe) +num_experts: 1 +num_experts_per_tok: 1 +megablox: True +sparse_matmul: True +capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default +load_balance_loss_weight: 0.01 # weight for the load balance loss + +# deepseek moe +base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer (use base_mlp_dim if not DeepSeek style) +first_num_dense_layers: 0 # number of initial dense layers in the model +shared_experts: 1 +routed_scaling_factor: 1.0 # scaling factor for routing scores +routed_score_func: "" # scoring function for routing +routed_bias: False # a flag if a bias term is added for routing + +# pipeline parallelism +# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats. +# There is a tradeoff between the num_layers_per_pipeline_stage and num_pipeline_repeats: The more layers per stage the easier +# it is to hide the pipeline communication behind the compute since there is more compute per stage, however there will be a larger bubble +# since there are fewer repeats. Similarly there is tradeoff for num_pipeline_microbatches - more microbatches leads to a smaller bubble, +# but a smaller size per microbatch which may hurt per-stage performance. Additionally note when microbatches > num_stages we have the opportunity to +# perform the circular transfer (last stage to first) asynchronously. +# The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1) +num_layers_per_pipeline_stage: 1 +# The number of repeats will be set to num_decoder_layers / (num_pipeline_stages * num_layers_per_pipeline_stage) +num_pipeline_repeats: -1 +# num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages. +# Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices +num_pipeline_microbatches: -1 +pipeline_delay_activation_forwarding: False # This delays the activation forwarding one loop iteration simplifying XLA's task of overlapping since +# the communication and compute in each iteration are now independent. However this comes at the cost of doubling the pipeline bubble, +# and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay). + +pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration. +# This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed +# to only one stage's worth, however we only execute one all-gather and reduce across per repeat, as opposed +# to every microbatch. This is similar to zero-1 sharding, since we also don't need to all gather the FSDP weights in the backward pass. +# An alternative to setting this to true may be to replace any FSDP with DP and use optimizer offloading if necessary. +# A more optimal behavior is to all-gather at the start of each repeat, which would ideally get the best of both worlds - +# a small amount of memory and time, however this has proven hard to implement in SPMD, see b/364386697 for more. 
+ +# There are two loops for PP: +# 1) Outer loop over microbatches (pipeline iterations) +# 2) Inner loop over layers (layers per stage) +# We have observed extra remat when a remat policy and scanning is performed on both, and recommend the default +# settings below of scanning and setting a remat policy only over the pipeline iterations. +# It may be useful to do the reverse when the layers_per_stage is very large. +# The below settings only have effect when using pipeline parallelism. +scan_pipeline_iterations: True +# The layers per stage scanning option is set by scan_layers, we recommend setting scan_layers=False +set_remat_policy_on_pipeline_iterations: True +set_remat_policy_on_layers_per_stage: False + + +# Choose 'remat_policy' between 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', +# 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'. +# These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) +remat_policy: 'full' +# If "custom" remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory. +# Pick one of these options for following tensors: ['remat','device','offload'] +decoder_layer_input: 'device' # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points +context: 'remat' # From https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/attention.py#L581-L583 +mlpwi: 'remat' +mlpwi_0: 'remat' +mlpwi_1: 'remat' +mlpwo: 'remat' +query_proj: 'remat' +key_proj: 'remat' +value_proj: 'remat' +qkv_proj: 'remat' +out_proj: 'remat' + +optimizer_memory_host_offload: False +scan_layers: True # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. +param_scan_axis: 1 + +# The attention parameter dictates the specific algorithm/methodology used to compute the attention scores +# The attention_type parameter determines the variants of attention, e.g. global or local_sliding +attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te +attention_type: 'global' # Supported attention_type: global, local_sliding, mla +sliding_window_size: 0 +attn_logits_soft_cap: 0.0 +final_logits_soft_cap: 0.0 +use_post_attn_norm: False +use_post_ffw_norm: False + +# MLA parameters +q_lora_rank: 0 +kv_lora_rank: 512 +qk_nope_head_dim: 128 +qk_rope_head_dim: 64 +v_head_dim: 128 + +# Combine matmuls for QKV and MLP +fused_qkv: False +fused_mlp: False + +record_internal_nn_metrics: 0 + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Whether or not to enable emergency checkpoint. If True, `local_checkpoint_directory` and a non-zero `local_checkpoint_period` must also be specified. +# Emergency checkpoint is an experimental Orbax feature that: periodically saves to persistent storage and, with a larger invertal, saves to a local directory. +# During restore, if a local copy is available in any slice, it will be broadcast to other slices without having to fetch from persistent storage. +# See more details on https://github.com/google/orbax/tree/main/checkpoint/orbax/checkpoint/experimental/emergency. +enable_emergency_checkpoint: False + +# It should be specified when and only when `enable_emergency_checkpoint` is True. 
+local_checkpoint_directory: "" + +# It should be a positive number when and only when `enable_emergency_checkpoint` is True. +local_checkpoint_period: 0 + +# Whether or not to use emergency checkpoint with the replicator service. +use_replicator_service: False + +# The interval to backup local checkpoints to the persistent storage. +replicator_backup_interval_minutes: 0 + +# Jax cache directory +jax_cache_dir: "~/jax_cache" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' + +# Parallelism +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] +logical_axis_rules: [ + ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], + ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], + ['activation_length', ['context', 'sequence']], + ['activation_length', ['context']], + ['activation_length_q', ['context']], + ['activation_kv_length', ['context']], + ['activation_norm_length', ['tensor_sequence', 'sequence']], + ['activation_embed', ['tensor', 'tensor_transpose']], + ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], + ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'tensor_transpose']], + ['activation_vocab', 'tensor_sequence'], + ['activation_vocab', ['sequence', 'context']], + ['activation_stage', 'stage'], + ['activation_exp', ['expert', 'context']], + ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], + ['decode_length', []], + ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context']], + ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']], + ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], + ['embed', ['fsdp', 'sequence', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose']], + ['embed_no_exp', ['fsdp', 'sequence', 'context', 'tensor_transpose']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_no_exp', ['fsdp', 'sequence', 'context']], + ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['layers', 'stage'], + ['kv', []], + ['kv_head_dim', []], + ['cache_batch_prefill', []], + ['cache_batch', ['context']], + ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], + ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], + ['cache_kv', 
[]],
+ ['cache_sequence', ['context']],
+ ['exp', ['expert', 'context']],
+ ['paged_kv_heads', []],
+ ['num_pages', ['tensor']],
+ ['tokens_per_page', []],
+ ['paged_kv_head_dim_size', []],
+ ]
+# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+
+# sharding tolerance: float between 0.0 and 1.0 representing the allowed fraction of non-sharded parameters.
+sharding_tolerance: 0.02
+
+# One axis for each parallelism type may hold a placeholder (-1)
+# value to auto-shard based on available slices and devices.
+# By default, product of the DCN axes should equal number of slices
+# and product of the ICI axes should equal number of devices per slice.
+dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
+dcn_fsdp_parallelism: 1
+dcn_fsdp_transpose_parallelism: 1
+dcn_sequence_parallelism: 1 # never recommended
+dcn_context_parallelism: 1
+dcn_tensor_parallelism: 1 # never recommended
+dcn_tensor_transpose_parallelism: 1
+dcn_tensor_sequence_parallelism: 1 # never recommended
+dcn_pipeline_parallelism: 1
+dcn_expert_parallelism: 1
+dcn_autoregressive_parallelism: 1 # never recommended
+ici_data_parallelism: 1
+ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_transpose_parallelism: 1
+ici_sequence_parallelism: 1
+ici_context_parallelism: 1
+ici_tensor_parallelism: 1
+ici_tensor_transpose_parallelism: 1
+ici_tensor_sequence_parallelism: 1
+ici_autoregressive_parallelism: 1
+ici_pipeline_parallelism: 1
+ici_expert_parallelism: 1
+
+# The number of TPU slices is automatically determined; you should not set this explicitly. For ahead of time compilation,
+# you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
+num_slices: -1
+
+# Tokenizer
+vocab_size: 32_000 # powers of 2 for sharding
+tokenizer_path: "assets/tokenizer.llama2"
+# tfds pipeline supports tokenizer_type: sentencepiece, huggingface, tiktoken
+# grain pipeline supports tokenizer_type: sentencepiece, huggingface
+# hf pipeline only supports the huggingface type and will ignore the tokenizer_type flag
+tokenizer_type: "sentencepiece"
+tokenize_train_data: True # False if the dataset is pre-tokenized
+tokenize_eval_data: True # False if the dataset is pre-tokenized
+add_bos: True
+add_eos: True
+
+# Dataset
+per_device_batch_size: 12.0
+expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data hosts will pull data from GCS.
+eval_per_device_batch_size: 0.0
+max_corpus_chars: 10_000_000
+train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
+eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
+
+# direct preference optimization (DPO)
+use_dpo: False
+dpo_label_smoothing: 0.0
+dpo_beta: 0.1
+
+# dataset_type must be one of synthetic, hf, grain, or tfds
+# details in: https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md
+dataset_type: tfds
+# for TFDS input pipeline (dataset_type=tfds)
+dataset_path: "" # your path given as argument in download_dataset.sh, e.g. 
"gs://my-maxtext-dataset/" +dataset_name: 'c4/en:3.0.1' +eval_dataset_name: 'c4/en:3.0.1' +eval_split: 'validation' +# for HuggingFace input pipeline (dataset_type=hf) +hf_path: '' +hf_data_dir: '' +hf_train_files: '' +hf_eval_split: '' +hf_eval_files: '' +hf_access_token: '' +# for Grain input pipeline (dataset_type=grain) +grain_train_files: '' +grain_eval_files: '' +grain_worker_count: 1 + +# Training loop +steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps +log_period: 100 # Flushes Tensorboard + +jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py +# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers +# only to the jax coordination service. +jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. +skip_jax_distributed_system: False # If True we will not initialize the jax distributed system. +# Currently the jax distributed is needed on cloud TPUs for async checkpointing. +# However when run on google internal TPUs the coordination service is started automatically +# and we should set this to True so we won't try to initialize a second time manually. + +# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 +# Learning rate schedule has either two or three parts: +# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] +# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps +# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. +# The zero learning rate section can be used to more accurately measure the fully trained model's performance. +learning_rate: 3.e-5 +cosine_learning_rate_final_fraction: 0.1 +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +max_target_length: 2048 # Maximum sequence length +max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression +prompt: "I love to" # Prompt for language model sampling. +load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory +prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from directory. If set, decode.py writes to directory +autoregressive_decode_assert: "" + +# For nsys profiler, pass the training command to nsys command +# e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} +profiler: "" # Supported profiler: '', xplane, nsys +# If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. +upload_all_profiler_results: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. 
+skip_first_n_steps_for_profiler: 1
+# Profile for a small number of steps to avoid a large profile file size.
+profiler_steps: 5
+profile_cleanly: True # If set to true, adds a block_until_ready on the train state, which aligns the profile for each step.
+profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
+# This is useful for debugging scenarios where performance changes over time.
+
+
+# Dump HLO options
+dump_hlo: False
+dump_hlo_local_dir: "/tmp/xla_dump/"
+dump_hlo_delete_local_after: True # Cleans the local directory after it is uploaded
+dump_hlo_gcs_dir: "" # Defaults to {base_output_directory}/{run_name}/xla_dump
+dump_hlo_module_name: "jit_train_step" # Filter uploading modules by this string. Set to empty string to remove any filter.
+dump_hlo_xla_flags: "" # Defaults to "--xla_dump_to={dump_hlo_local_dir} --xla_dump_hlo_module_re={dump_hlo_module_name} --xla_dump_large_constants"
+dump_hlo_upload_all: False # If True, all hosts dump HLO; if False, only the host with jax.process_index()==0 does.
+# All hosts should have identical HLO for SPMD programs; however, we have encountered some bugs
+# where this is not the case, and it is helpful to compare HLO across hosts.
+
+# When dropout is false the model is a deterministic function of the
+# data_shuffle_seed and init_weights_seed (i.e. reproducible losses)
+enable_dropout: True
+enable_data_shuffling: True
+data_shuffle_seed: 0
+init_weights_seed: 0
+
+# You may disable clipping by setting gradient_clipping_threshold to zero.
+gradient_clipping_threshold: 1.0
+
+# Instead of updating the weights every step, you may effectively use a larger
+# batch by accumulating the gradient over a set of steps.
+gradient_accumulation_steps: 1
+
+# AdamW optimizer parameters
+# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+opt_type: "adamw" # one of "adam_pax" or "adamw"
+adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
+adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
+adam_eps: 1.e-8 # A small constant applied to the denominator outside of the square root.
+adam_eps_root: 0. # A small constant applied to the denominator inside the square root.
+adam_weight_decay: 0.1 # AdamW weight decay
+
+# Stack trace parameters
+collect_stack_trace: False
+stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
+stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds.
+
+# Use iota operator in Embed
+use_iota_embed: False
+# Use positional embedding
+use_untrainable_positional_embedding: False
+trainable_position_size: -1 # Enable gpt3-style position embedding by setting a positive trainable_position_size
+# RoPE parameters
+rope_type: "default" # one of "default", "llama3.1" or "yarn"
+rope_min_timescale: 1
+rope_max_timescale: 10_000
+
+# yarn RoPE parameters
+original_seq_len: 4096
+rope_theta: 10000.0
+rope_factor: 40
+beta_fast: 32
+beta_slow: 1
+mscale: 1.0
+
+# Ahead of time Compilation (aka AOT)
+# Only set these arguments if you are running train_compile or loading a compiled train step.
+compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle
+compile_topology: '' # Target hardware version, e.g. 'v5e-256'
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
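+# As an illustrative sketch only (the script path and values below are assumptions, not taken
+# from this config): an ahead-of-time compile run would typically override these keys on the
+# command line, e.g.
+#   python3 MaxText/train_compile.py MaxText/configs/base.yml \
+#     compile_topology=v5e-256 compile_topology_num_slices=2 \
+#     compiled_trainstep_file=compiled_train_v5e-256.pickle
+# and a later training run would set compiled_trainstep_file to the same pickle to load the
+# pre-compiled train step instead of compiling it again.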
+
+decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, or topk
+decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
+decode_sampling_top_k: 0 # set if you're doing top-k
+decode_sampling_temperature: 1.
+
+eval_interval: -1 # the number of training steps between eval steps
+eval_steps: -1 # run this number of steps for eval; setting this is recommended to prevent errors caused by running out of eval data
+target_eval_loss: 0. # stop training early once the target eval loss is reached
+
+# Goodput parameters
+enable_goodput_recording: True
+monitor_goodput: True
+goodput_upload_interval_seconds: 30
+enable_pathways_goodput: False
+monitor_step_time_deviation: True
+step_deviation_interval_seconds: 30
+
+# GCP workload monitoring
+report_heartbeat_metric_for_gcp_monitoring: False
+heartbeat_reporting_interval_in_seconds: 5
+report_performance_metric_for_gcp_monitoring: False
+
+enable_tensorboard: True
+
+# Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md
+# Set to True for GCE, False if running via XPK
+use_vertex_tensorboard: False
+# Project to create Vertex AI Tensorboard in for GCE, blank if project is set using 'gcloud config set project'
+# Set this to blank if running via XPK
+vertex_tensorboard_project: ""
+# Region to create Vertex AI Tensorboard in for GCE, blank if running via XPK
+# Vertex AI supported regions: https://cloud.google.com/vertex-ai/docs/general/locations#available-regions
+vertex_tensorboard_region: ""
+
+# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will affect performance.
+max_checkify: False
+
+# Inference
+inference_microbenchmark_prefix_cache_entries_num: 100
+inference_microbenchmark_prefix_cache_common_prefix_proportion: 0.5
+inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
+inference_microbenchmark_stages: "prefill,generate"
+inference_microbenchmark_loop_iters: 10
+inference_microbenchmark_log_file_path: ""
+inference_metadata_file: "" # path to a json file
+inference_server: "MaxtextInterleavedServer" # inference server to start
+inference_benchmark_test: False
+enable_model_warmup: False
+multi_sampling: False
+
+# Stack the prefill result cache across layers to reduce Python-level latency.
+stack_prefill_result_cache: False + +# KV Cache layout control +# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV +# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV +prefill_cache_axis_order: "1,2,0,3" +ar_cache_axis_order: "1,2,0,3" + +# Compute layout control +# Default layout: 0,1,2,3 ; BATCH, LENGTH, HEAD, D_KV +# Currently only support compute layout: 0,1,2,3 and 0,2,1,3 +compute_axis_order: "0,1,2,3" + +reshape_q: False + +# Maxengine Metrics +prometheus_port: 0 + +# Maxengine server +enable_jax_profiler: False +jax_profiler_port: 9999 + +log_config: True # Prints the config (after defaults have been set by pyconfig logic) + +# Checkpoint Structured logging +enable_checkpoint_cloud_logger: False + +# Single-controller +enable_single_controller: False + +custom_mesh: "" # Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8'] +# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html +allow_split_physical_axes: False +# Apply transformations to the mesh to optimize for TPU v6e +optimize_mesh_for_tpu_v6e: False + +use_ragged_attention: False +ragged_block_size: 256 + +### Splash attention block sizes +# These can be tuned for specific hardware generations, and can be set up to +# the model's sequence length. +sa_block_q: 512 +sa_block_kv: 512 +sa_block_kv_compute: 512 +sa_block_q_dkv: 512 +sa_block_kv_dkv: 512 +sa_block_kv_dkv_compute: 512 +sa_block_q_dq: 512 +sa_block_kv_dq: 512 +sa_use_fused_bwd_kernel: False +sa_q_layout: "HEAD_DIM_MINOR" +sa_k_layout: "HEAD_DIM_MINOR" +sa_v_layout: "HEAD_DIM_MINOR" + +####################### +### Paged Attention ### +####################### +# These settings take effect only when `attention=paged`. +# They should be adjusted based on the available HBM and model config. +pagedattn_num_pages: 64 +pagedattn_tokens_per_page: 32 +pagedattn_pages_per_compute_block: 8 + +# Chunked Prefill Parameters +prefill_chunk_size: 256 +use_chunked_prefill: False
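+
+# A rough capacity sketch, offered as a hedged illustration rather than a guarantee (the
+# interpretation below is an assumption, not stated elsewhere in this file): with the defaults
+# above, the paged KV cache can hold pagedattn_num_pages * pagedattn_tokens_per_page
+# = 64 * 32 = 2048 tokens, which matches max_target_length: 2048, and a full-length prompt
+# would be prefilled in max_target_length / prefill_chunk_size = 2048 / 256 = 8 chunks when
+# use_chunked_prefill is True.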