
Add seq parallelism for attention and MoE MLP #1328

Open · wants to merge 24 commits into main

Conversation

@suexu1025 (Collaborator) commented Mar 1, 2025

Description

  1. Add sequence parallelism (seq_parallelism) plus expert parallelism (exp_parallelism) for the attention module and the MoE MLP module that follows it. With SP + EP, MoE inference at 2k sequence length improved by 20% for the customer workload (a rough sketch of the sharding idea follows the Description).
  2. Fix the prefill KV-cache sharding mismatch under sequence parallelism.
  3. Decode improved by 10%.
  4. Enable inference auto layout in the Mistral model.

FIXES: b/374773995
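
For readers unfamiliar with the approach, here is a minimal, hypothetical sketch of the sharding idea: split the activation's sequence dimension across a "context" mesh axis so that attention and the MoE MLP each operate on a sub-sequence per device. The mesh setup, axis names, and shapes below are assumptions for illustration only, not the PR's actual code.

  import jax
  import jax.numpy as jnp
  from jax.experimental import mesh_utils
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  # Build a 1-D device mesh whose single axis plays the role of "context"
  # (i.e. what ici_context_parallelism controls in MaxText configs).
  devices = mesh_utils.create_device_mesh((jax.device_count(),))
  mesh = Mesh(devices, ("context",))

  batch, seq_len, emb = 1, 2048, 4096
  activations = jnp.zeros((batch, seq_len, emb))

  # Shard the sequence dimension across the "context" axis; batch and embed
  # stay unsharded in this toy example, so each device holds a sub-sequence.
  activations = jax.device_put(
      activations, NamedSharding(mesh, P(None, "context", None))
  )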

Tests

Tested on v6e/v5p:
SEQ=2048

python MaxText/inference_microbenchmark.py MaxText/configs/base.yml \
  max_prefill_predict_length=$SEQ max_target_length=6144 model_name=mixtral-8x7b \
  ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_expert_parallelism=1 \
  ici_context_parallelism=4 ici_tensor_parallelism=1 scan_layers=false \
  per_device_batch_size=1 attention=dot_product megablox=False quantization=int8 \
  checkpoint_is_quantized=True quantize_kvcache=True capacity_factor=1 \
  tokenizer_path=assets/tokenizer.mistral-v3 compute_axis_order=0,2,1,3 \
  ar_cache_axis_order=0,2,1,3 enable_jax_profiler=True \
  inference_microbenchmark_prefill_lengths="$SEQ" base_output_directory=$OUT_DIR \
  run_name=$RUN_NAME profiler=xplane model_call_mode=inference \
  inference_microbenchmark_stages=prefill

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [x] I have performed a self-review of my code.
  • [x] I have added necessary comments in my code, particularly in hard-to-understand areas.
  • [x] I have run end-to-end tests and provided workload links above if applicable.
  • [x] I have made or will make corresponding changes to the doc if needed.

google-cla bot commented Mar 1, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

@suexu1025 suexu1025 changed the title [Draft] Add seq parallelism for attention and MLP Add seq parallelism for attention and MoE MLP Mar 6, 2025
@mailvijayasingh (Collaborator) left a comment:
Overall, looks good, left some comments.

Let's test dense models, training, and MoE 8x22B on v6e-8 before pushing.

@@ -1424,7 +1457,7 @@ def out_projection(self, output_dim: int, out: Array) -> Array:
features=output_dim,
axis=(-2, -1),
kernel_init=self.kernel_init,
kernel_axes=("heads", "kv", "embed"),
kernel_axes=(None, None, None), # trade speed with memory
Collaborator:
does this mean we might OOM in some higher batch sizes?

@@ -103,6 +103,9 @@ def __call__(
float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
Collaborator:
thanks for fixing this :)

@@ -212,6 +212,7 @@ def validate_model_name(s: str) -> bool:
"llama3.1-70b",
"llama3.1-405b",
"llama3.3-70b",
"subsup",
Collaborator:
remove

cp = self.config.ici_context_parallelism
batch_size = inputs.shape[0]
seq_len = inputs.shape[1]
if seq_len % cp != 0:
Collaborator:
nit: maybe abstract this part to get_cp and get_sub_seq_length?

Collaborator:
+1, let's extract this into a helper function and reuse it, or have get_context_partition_and_sub_seq return both.
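
A rough sketch of what the suggested helper could look like, using the name proposed in this thread; the fallback to cp = 1 when the sequence length is not divisible is an assumption based on the surrounding diff, not the PR's actual code.

  def get_context_partition_and_sub_seq(config, inputs):
    """Returns (context_partitions, sub_seq_len) for sequence parallelism."""
    cp = config.ici_context_parallelism
    seq_len = inputs.shape[1]
    if cp <= 0 or seq_len % cp != 0:
      cp = 1  # fall back to no context partitioning
    return cp, seq_len // cp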

['activation_length', ['sequence']],
['activation_length', ['sequence', 'context']],
['activation_length', ['context']],
['activation_length_q', ['context']],
Collaborator:
As discussed offline, shall we pull this out into a separate config, e.g. moe_config_inference.yml, specifically for MoE?

Collaborator:
What's this config? i.e. is this context parallelism related?

LENGTH = common_types.LENGTH
KV_LENGTH = common_types.KV_LENGTH
Collaborator:
We will have to make changes wherever kv_quant=True, or else it's broken. One workaround is to push this code with an assertion that kv_quant must be False when context_parallelism != 1, and quickly follow up with the quantization changes. But if it's not too hard, I would prefer they go in together.
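
A minimal sketch of the guard described above, using the config names that appear in the benchmark command (quantize_kvcache, ici_context_parallelism); the exact flag names and where the check lives are assumptions, not the PR's code.

  # Sketch only: refuse KV-cache quantization while context parallelism is on.
  if config.ici_context_parallelism != 1:
    assert not config.quantize_kvcache, (
        "quantize_kvcache is not yet supported with ici_context_parallelism != 1"
    )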

@RissyRan (Collaborator) left a comment:
Thanks for the change and improvement! Overall, LGTM! One thing I'd like to see is the performance impact on training. When it's ready, could you help run a benchmark on 8x7B (or another model size) with FSDP + EP sharding in dropping mode, with and without this change? Capturing profiles would be great! Thank you!

@@ -568,16 +582,32 @@ def apply_attention_dot(
key = key.astype(jnp.float32)

q_seq_len = query.shape[1]
# special sharding for decode
if self.config.ici_context_parallelism > 0 and q_seq_len == 1:
Collaborator:
Nit: could we wrap self.config.ici_context_parallelism > 0 and q_seq_len == 1 in a helper function and reuse it, named something like is_context_parallelism_in_decoding()?
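
For illustration, the suggested predicate might look like this; the method name comes from the review, and placing it on the attention class is an assumption.

  def is_context_parallelism_in_decoding(self, q_seq_len):
    """True when context parallelism is enabled and we are decoding one token."""
    return self.config.ici_context_parallelism > 0 and q_seq_len == 1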

cp = 1
sub_seq = seq_len // cp

top_k_indices = jnp.reshape(top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2]))
Collaborator:
Could we rename cp as context_partitions or similar? And add a comment about top_k_indices shape, i.e. [batch_size, context_partition, sub_sequence, num_experts_per_tok]?
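
A sketch of the suggested rename and shape comment, based on the quoted lines above (illustrative only, not the PR's code):

  context_partitions = 1  # no context partitioning on this path
  sub_seq = seq_len // context_partitions
  # top_k_indices: [batch_size, context_partitions, sub_seq, num_experts_per_tok]
  top_k_indices = jnp.reshape(
      top_k_indices, (batch_size, context_partitions, sub_seq, top_k_indices.shape[2])
  )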

# intermediate_layer = nn.with_logical_constraint(
# intermediate_layer,
# ("activation_exp", "activation_batch_no_exp", None, "activation_embed"),
# )
Collaborator:
I see this block is removed

if self.config.activations_in_float32:
          intermediate_layer = intermediate_layer.astype(jnp.float32)
