Add seq parallelism for attention and MoE MLP #1328
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Overall, looks good, left some comments.
Let's test dense models, training, and MoE 8x22B on v6e-8 before pushing.
MaxText/layers/attentions.py
Outdated
@@ -1424,7 +1457,7 @@ def out_projection(self, output_dim: int, out: Array) -> Array:
         features=output_dim,
         axis=(-2, -1),
         kernel_init=self.kernel_init,
-        kernel_axes=("heads", "kv", "embed"),
+        kernel_axes=(None, None, None),  # trade speed with memory
Does this mean we might OOM at some higher batch sizes?
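For context, a rough back-of-the-envelope sketch of the weight-memory side of that trade-off; the dimensions below are hypothetical and only illustrate why replicating the out-projection kernel leaves less headroom for activations at larger batch sizes:

# Out-projection kernel has shape (heads, kv, embed); with kernel_axes=(None, None, None)
# it is replicated on every device instead of being sharded along "heads".
heads, kv, embed = 32, 128, 4096   # hypothetical model dimensions
bytes_per_param = 2                # bf16
kernel_bytes = heads * kv * embed * bytes_per_param

shards = 4                         # e.g. devices the "heads" axis was previously sharded over
print(f"sharded over heads: {kernel_bytes / shards / 2**20:.1f} MiB per device")
print(f"fully replicated:   {kernel_bytes / 2**20:.1f} MiB per device")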
@@ -103,6 +103,9 @@ def __call__(
         float32_logits=cfg.float32_logits,
         quant=self.quant,
         kv_quant=quantizations.configure_kv_quant(cfg),
+        prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
+        ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
thanks for fixing this :)
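As a quick reference, a tiny sketch of what the added parsing does with the comma-separated axis-order strings (the example value matches the ar_cache_axis_order=0,2,1,3 flag used in the test command below):

ar_cache_axis_order = "0,2,1,3"  # as passed on the command line
axis_order = tuple(int(i) for i in ar_cache_axis_order.split(","))
assert axis_order == (0, 2, 1, 3)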
@@ -212,6 +212,7 @@ def validate_model_name(s: str) -> bool:
     "llama3.1-70b",
     "llama3.1-405b",
     "llama3.3-70b",
+    "subsup",
remove
    cp = self.config.ici_context_parallelism
    batch_size = inputs.shape[0]
    seq_len = inputs.shape[1]
    if seq_len % cp != 0:
Nit: maybe abstract this part into get_cp and get_sub_seq_length?
+1, let's extract those into a helper function and reuse it, or have get_context_partition_and_sub_seq return both.
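A minimal sketch of what such a helper could look like, assuming the config field and fallback behavior shown in the quoted diff (the exact name and signature are only suggestions from this thread):

def get_context_partition_and_sub_seq(config, seq_len: int) -> tuple[int, int]:
  """Returns (context_partitions, sub_seq_len) for sequence sharding.

  Falls back to a single partition when seq_len is not divisible by
  ici_context_parallelism, mirroring the quoted diff.
  """
  cp = config.ici_context_parallelism
  if cp <= 1 or seq_len % cp != 0:
    return 1, seq_len
  return cp, seq_len // cp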
    ['activation_length', ['sequence']],
    ['activation_length', ['sequence', 'context']],
    ['activation_length', ['context']],
    ['activation_length_q', ['context']],
As discussed offline, shall we pull out a config, moe_config_inference.yml, specifically for MoE?
What's this config? i.e. is this context parallelism related?
    LENGTH = common_types.LENGTH
    KV_LENGTH = common_types.KV_LENGTH
We will have to make changes wherever kv_quant=True, or else it's broken. One workaround is to push this code with an assertion that kv_quant must be False when context_parallelism != 1, and quickly follow up with the quantization changes. But if it's not too hard, I would prefer they go in together.
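A hedged sketch of the suggested guard; the flag names mirror quantize_kvcache and ici_context_parallelism from the test command below, and where exactly it would live in config validation is an assumption:

def validate_context_parallelism(config):
  # Suggested fail-fast check (assumption): KV-cache quantization is not yet
  # supported together with context parallelism in this change.
  if config.ici_context_parallelism != 1 and config.quantize_kvcache:
    raise ValueError(
        "quantize_kvcache=True is not supported with ici_context_parallelism != 1; "
        "disable one of the two until the quantization follow-up lands."
    )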
Thanks for the change and improvement! Overall, LGTM! One thing I'd like to see is the performance impact on training. When it's ready, could you help run a benchmark on 8x7B (or another model size) with FSDP + EP sharding in dropping mode (with and without this change)? Capturing profiles would be great! Thank you!
@@ -568,16 +582,32 @@ def apply_attention_dot(
     key = key.astype(jnp.float32)

+    q_seq_len = query.shape[1]
+    # special sharding for decode
+    if self.config.ici_context_parallelism > 0 and q_seq_len == 1:
Nit: could we wrap this self.config.ici_context_parallelism > 0 and q_seq_len == 1 check in a helper function and reuse it, named something like is_context_parallelism_in_decoding()?
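A sketch of the proposed helper, assuming it lives on the attention class next to the quoted code:

def is_context_parallelism_in_decoding(self, q_seq_len: int) -> bool:
  """True when context parallelism is enabled and we are in autoregressive
  decode (query length 1), i.e. the decode-specific sharding should apply."""
  return self.config.ici_context_parallelism > 0 and q_seq_len == 1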
      cp = 1
    sub_seq = seq_len // cp

    top_k_indices = jnp.reshape(top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2]))
Could we rename cp to context_partitions or similar? And add a comment about the top_k_indices shape, i.e. [batch_size, context_partition, sub_sequence, num_experts_per_tok]?
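A sketch of the rename plus the requested shape comment, reusing the surrounding variables (and jnp import) from the quoted diff; the renamed variable is only an illustration:

context_partitions = cp  # renamed for clarity
# top_k_indices: [batch_size, context_partitions, sub_seq, num_experts_per_tok]
top_k_indices = jnp.reshape(
    top_k_indices,
    (batch_size, context_partitions, sub_seq, top_k_indices.shape[2]),
)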
    # intermediate_layer = nn.with_logical_constraint(
    #     intermediate_layer,
    #     ("activation_exp", "activation_batch_no_exp", None, "activation_embed"),
    # )
I see this block is removed:

if self.config.activations_in_float32:
  intermediate_layer = intermediate_layer.astype(jnp.float32)
Description
With SP + EP, customer MoE 2k-sequence inference improved by 20%.
FIXES: b/374773995
Tests
Tested on v6e/v5p:
SEQ=2048
python MaxText/inference_microbenchmark.py MaxText/configs/base.yml max_prefill_predict_length=$SEQ max_target_length=6144 model_name=mixtral-8x7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_expert_parallelism=1 ici_context_parallelism=4 ici_tensor_parallelism=1 scan_layers=false per_device_batch_size=1 attention=dot_product megablox=False quantization=int8 checkpoint_is_quantized=True quantize_kvcache=True capacity_factor=1 tokenizer_path=assets/tokenizer.mistral-v3 compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3 enable_jax_profiler=True inference_microbenchmark_prefill_lengths="$SEQ" base_output_directory=$OUT_DIR run_name=$RUN_NAME profiler=xplane model_call_mode=inference inference_microbenchmark_stages=prefill
Checklist
Before submitting this PR, please make sure (put X in square brackets):