Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TRN2 Meshes and Configurations #916

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

apoorvtintin
Copy link
Contributor

This PR adds meshes for TRN2/1 for Fuji models and transformer layer configuration favorable to Neuron.

Neuron supports stacked transformer and GroupedQKVLinear instead of FusedGroupedQKVLinear for Grouped Query Attention (GQA)

This is a newer version of the PR #885. This PR resolved all comments and requested changes mentioned in the linked PR.

@apoorvtintin apoorvtintin requested review from ruomingp, markblee and a team as code owners January 10, 2025 00:48
@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 6b404f6 to 3f7c840 Compare January 10, 2025 00:53
@apoorvtintin
Copy link
Contributor Author

Added a ModelConfigModifier that overrides the class for a module. Allowing different model configurations based on Model size and platform.

Copy link
Contributor

@kelvin-zou kelvin-zou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for making such change, overall looks good. A few nit comments.

continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try to extract a common util function named something like
def replace_module_recursive(target_modules:str, config_key: str, target_config) and make it applied to both here and RematSpecModifier

Copy link
Contributor Author

@apoorvtintin apoorvtintin Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I extracted a helper function, let me know if this looks good

axlearn/common/trainer_config_modifier_test.py Outdated Show resolved Hide resolved
@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 708fc5e to d481132 Compare January 10, 2025 07:38
@apoorvtintin
Copy link
Contributor Author

apoorvtintin commented Jan 10, 2025

Added ParameterPartitionSpecModifier for parameters to shard Embeddings in a vocab parallel manner as described in Megatron LM.

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 5be50d7 to 9b10041 Compare January 10, 2025 08:10
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved

found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)

# Copy configurations from the config being replaced on a best effort basis
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the goal is to change the config to a similar module. This means most of the configuration can be reused from before. Essentially replacing the module but merging the config. Let me extract out a merge function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Abstracted out a merge function let me know if more changes are needed for this.

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 9b10041 to 0f0a530 Compare January 12, 2025 07:06
@apoorvtintin
Copy link
Contributor Author

@ruomingp Thank you for the review, I have addressed all your comments, please let me know if more changes are needed.

@apoorvtintin apoorvtintin requested a review from ruomingp January 12, 2025 07:08
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
Comment on lines 239 to 244
for module_name, model_cfg in self._model_cfg_modifications.items():
found_module = _find_target_module(module_name, cfg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In utils.py we have get_recursively and set_recursively for Nested[...]. I wonder if it will be useful to add corresponding methods to ConfigBase. Then we can do something like:

Suggested change
for module_name, model_cfg in self._model_cfg_modifications.items():
found_module = _find_target_module(module_name, cfg)
for cfg_path, cfg_modification in self._model_cfg_modifications.items():
child_cfg = cfg.get_recursively(cfg_path)
child_cfg = cfg_modification(child_cfg, path=cfg_path)
cfg.set_recursively(cfg_path, value=child_cfg)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added get_recursively and set_recursively functions to ConfigBase. Let me know if it looks good

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if an alternative (which aims to simplify the ConfigBase api) is to do something similar to Python's sorted; we allow utils.get_resursively to take a value fn:

# Default behavior is to use key lookup:
utils.get_recursively(..., value_fn=lambda k,v: v[k])

# Custom behavior can be attribute lookup:
utils.get_recursively(..., value_fn=lambda k,v: getattr(v,k))

A benefit is that other non-config instances can also leverage get_recursively.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @markblee , maybe we can do this in a follow-up PR?

@apoorvtintin
Copy link
Contributor Author

apoorvtintin commented Jan 15, 2025

Added a more flexible PartitionSpecModifier that can modify multiple partition_spec attributes in a single module config.

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 45c7df1 to 8807856 Compare January 17, 2025 01:17
Copy link
Contributor

@kelvin-zou kelvin-zou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly lgtm, some minor comments.

axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
axlearn/common/trainer_config_modifier.py Outdated Show resolved Hide resolved
@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 8807856 to 25510d6 Compare January 22, 2025 01:39
@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from eec33eb to 86bafa8 Compare January 23, 2025 05:40
@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 780d424 to b6ae638 Compare February 5, 2025 19:01
@apoorvtintin apoorvtintin requested a review from ruomingp February 5, 2025 19:03
@kelvin-zou kelvin-zou requested a review from hanzhi713 February 5, 2025 19:52
@@ -151,6 +155,72 @@ def get_trainer_kwargs(

rope_theta = ROPE_THETA[version]

# TRN2 specific model config modifications
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move all the modifications in a helper function?
saying

def _generate_trainium2_custom_configs():
...
return trn2_model_modifications, trn2_partition_spec_modifications

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed it, lmk if it looks good

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from b6ae638 to f10ebd0 Compare February 5, 2025 22:44
Copy link
Contributor

@kelvin-zou kelvin-zou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve overall, please address @hanzhi713 's comment.

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from f10ebd0 to 37986ce Compare February 6, 2025 04:53
@apoorvtintin
Copy link
Contributor Author

Updated the PR to address failing tests, can we re-trigger the CI please? Thank you

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 37986ce to 4cda0dd Compare February 6, 2025 22:13
@apoorvtintin
Copy link
Contributor Author

@kelvin-zou @hanzhi713 Thank you both for the review, I addressed all the comments. Let's merge this if it looks good

Copy link
Contributor

@kelvin-zou kelvin-zou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@apoorvtintin
Copy link
Contributor Author

Hello @ruomingp, can I please get an approval if this PR looks good? Looks like that is needed for the PR to merge.

@kelvin-zou
Copy link
Contributor

maybe @markblee can have a second eye for the final round?

Copy link
Contributor

@markblee markblee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve to unblock.

axlearn/common/config.py Outdated Show resolved Hide resolved
@markblee
Copy link
Contributor

markblee commented Feb 7, 2025

(Looks like we still need @ruomingp 's approval to unblock the 'requested changes'.)

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 4cda0dd to 53472f2 Compare February 10, 2025 01:23
@apoorvtintin
Copy link
Contributor Author

Addressed the final comment from @markblee, thanks everyone! Can we run the CI again and merge this?

@ruomingp ruomingp added this pull request to the merge queue Feb 10, 2025
@ruomingp ruomingp removed this pull request from the merge queue due to a manual request Feb 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants