TRN2 Meshes and Configurations #916
Conversation
Force-pushed from 6b404f6 to 3f7c840.
Added a ModelConfigModifier that overrides the class for a module, allowing different model configurations based on model size and platform.
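For illustration, a minimal sketch of the idea only, not the actual ModelConfigModifier API from this PR; the module paths and class names below are assumptions:

```python
# Hedged sketch: pick per-platform module-class overrides as a mapping from a dotted
# config path to the replacement class name. Paths/class names are illustrative.
def select_model_modifications(platform: str, model_size: str) -> dict:
    """Returns a mapping from dotted module path to the replacement class name."""
    # `model_size` could branch the mapping further; it is unused in this sketch.
    if platform == "trn2":
        return {
            # Hypothetical paths: prefer the stacked transformer and the unfused
            # grouped QKV projection on Neuron.
            "model.decoder.transformer": "StackedTransformerLayer",
            "model.decoder.transformer.layer.self_attention.attention.input_linear": "GroupedQKVLinear",
        }
    return {}  # Other platforms keep the base configuration.


mods = select_model_modifications("trn2", "70B")
for path, new_class in mods.items():
    print(f"{path} -> {new_class}")
```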
Thank you for making this change; overall it looks good. A few nit comments.
continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
Can you try to extract a common util function named something like
def replace_module_recursive(target_modules: str, config_key: str, target_config)
and apply it to both here and to RematSpecModifier?
I extracted a helper function; let me know if this looks good.
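As a rough sketch of what such a shared path-walking helper could look like (the `find_target_module` name appears later in the diff; the body below is an assumption, not the PR's implementation):

```python
# Rough sketch: resolve a dotted module path on a nested config-like object via
# attribute access, returning the child plus its parent and key so a caller can
# either read or replace the child.
from typing import Any, Optional, Tuple


def find_target_module(module_name: str, cfg: Any) -> Tuple[Any, Any, Optional[str]]:
    """Resolves a dotted path like 'model.decoder.transformer.layer'."""
    parent, key = None, None
    current = cfg
    for part in module_name.split("."):
        if not hasattr(current, part):
            raise ValueError(f"{module_name} not found: missing attribute {part!r}.")
        parent, key = current, part
        current = getattr(current, part)
    return current, parent, key
```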
Force-pushed from 708fc5e to d481132.
Added
Force-pushed from 5be50d7 to 9b10041.
found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)
# Copy configurations from the config being replaced on a best effort basis |
Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?
Yeah, the goal is to change the config to a similar module, so most of the configuration can be reused from before. Essentially we are replacing the module but merging the config. Let me extract out a merge function.
Abstracted out a merge function; let me know if more changes are needed for this.
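For concreteness, a hedged sketch of a best-effort merge, using plain dataclasses as stand-ins for configs; the real merge function in the PR may behave differently:

```python
# Sketch of the "merge" the thread describes: copy fields from the config being
# replaced onto the new config on a best-effort basis (only fields the new config
# defines and the old config also has).
import dataclasses
from typing import Any


def merge_configs(new_cfg: Any, old_cfg: Any) -> Any:
    """Copies matching fields from old_cfg onto new_cfg, best effort."""
    for f in dataclasses.fields(new_cfg):
        if hasattr(old_cfg, f.name):
            setattr(new_cfg, f.name, getattr(old_cfg, f.name))
    return new_cfg


@dataclasses.dataclass
class FusedAttnCfg:          # Hypothetical "old" config being replaced.
    num_heads: int = 8
    num_kv_heads: int = 2
    fused: bool = True


@dataclasses.dataclass
class UnfusedAttnCfg:        # Hypothetical "new" config taking its place.
    num_heads: int = 4
    num_kv_heads: int = 1


merged = merge_configs(UnfusedAttnCfg(), FusedAttnCfg())
assert (merged.num_heads, merged.num_kv_heads) == (8, 2)  # Shared fields carried over.
```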
Force-pushed from 9b10041 to 0f0a530.
@ruomingp Thank you for the review; I have addressed all your comments. Please let me know if more changes are needed.
for module_name, model_cfg in self._model_cfg_modifications.items():
    found_module = _find_target_module(module_name, cfg)
In utils.py we have get_recursively and set_recursively for Nested[...]. I wonder if it will be useful to add corresponding methods to ConfigBase. Then we can do something like:
- for module_name, model_cfg in self._model_cfg_modifications.items():
-     found_module = _find_target_module(module_name, cfg)
+ for cfg_path, cfg_modification in self._model_cfg_modifications.items():
+     child_cfg = cfg.get_recursively(cfg_path)
+     child_cfg = cfg_modification(child_cfg, path=cfg_path)
+     cfg.set_recursively(cfg_path, value=child_cfg)
Added get_recursively and set_recursively functions to ConfigBase. Let me know if it looks good
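A hedged sketch of what path-based getters/setters like these can look like, written against plain attribute-style objects rather than the actual ConfigBase implementation:

```python
# Sketch of path-based get/set over nested attribute-style objects.
from types import SimpleNamespace
from typing import Any, Sequence


def get_recursively(cfg: Any, path: Sequence[str]) -> Any:
    """Returns the value at `path`, e.g. ["decoder", "transformer", "layer"]."""
    value = cfg
    for key in path:
        value = getattr(value, key)
    return value


def set_recursively(cfg: Any, path: Sequence[str], value: Any) -> None:
    """Sets the value at `path`; the parent of the final key must already exist."""
    parent = get_recursively(cfg, path[:-1])
    setattr(parent, path[-1], value)


cfg = SimpleNamespace(decoder=SimpleNamespace(transformer=SimpleNamespace(layer="old")))
set_recursively(cfg, ["decoder", "transformer", "layer"], "new")
assert get_recursively(cfg, ["decoder", "transformer", "layer"]) == "new"
```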
I wonder if an alternative (which aims to simplify the ConfigBase api) is to do something similar to Python's sorted; we allow utils.get_recursively to take a value_fn:
# Default behavior is to use key lookup:
utils.get_recursively(..., value_fn=lambda k, v: v[k])
# Custom behavior can be attribute lookup:
utils.get_recursively(..., value_fn=lambda k, v: getattr(v, k))
A benefit is that other non-config instances can also leverage get_recursively.
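A minimal sketch of this value_fn idea (signatures here are illustrative, not the actual utils.get_recursively API):

```python
# Sketch: one traversal helper serves dicts (key lookup) and config-like objects
# (attribute lookup) by parameterizing the per-step lookup with value_fn.
from types import SimpleNamespace
from typing import Any, Callable, Sequence


def get_recursively(
    obj: Any,
    path: Sequence[str],
    value_fn: Callable[[str, Any], Any] = lambda k, v: v[k],
) -> Any:
    for key in path:
        obj = value_fn(key, obj)
    return obj


nested_dict = {"decoder": {"transformer": {"layer": "dict-value"}}}
nested_obj = SimpleNamespace(
    decoder=SimpleNamespace(transformer=SimpleNamespace(layer="attr-value"))
)

assert get_recursively(nested_dict, ["decoder", "transformer", "layer"]) == "dict-value"
assert get_recursively(
    nested_obj, ["decoder", "transformer", "layer"], value_fn=lambda k, v: getattr(v, k)
) == "attr-value"
```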
Hi @markblee, maybe we can do this in a follow-up PR?
Force-pushed from c23e3b2 to 94bfff6.
Added a more flexible
Force-pushed from 45c7df1 to 8807856.
Mostly lgtm, some minor comments.
Force-pushed from 8807856 to 25510d6.
Force-pushed from eec33eb to 86bafa8.
Force-pushed from 780d424 to b6ae638.
axlearn/experiments/text/gpt/fuji.py (Outdated)
@@ -151,6 +155,72 @@ def get_trainer_kwargs(
    rope_theta = ROPE_THETA[version]

    # TRN2 specific model config modifications
Can we move all the modifications into a helper function? Something like:
def _generate_trainium2_custom_configs():
    ...
    return trn2_model_modifications, trn2_partition_spec_modifications
Addressed it; lmk if it looks good.
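A hedged sketch of what the suggested helper might return; the paths, class names, and partition specs below are placeholders, not the values actually used in the PR:

```python
# Sketch: collect the TRN2-specific overrides in one helper instead of inlining
# them in get_trainer_kwargs. Keys/values are illustrative only.
from typing import Dict, Tuple


def _generate_trainium2_custom_configs() -> Tuple[Dict[str, str], Dict[str, tuple]]:
    """Returns (model-class overrides, partition-spec overrides) for TRN2."""
    # Module-class overrides keyed by dotted config path (hypothetical path).
    trn2_model_modifications = {
        "model.decoder.transformer.layer.self_attention.attention.input_linear": "GroupedQKVLinear",
    }
    # Partition-spec overrides keyed by parameter path (hypothetical spec).
    trn2_partition_spec_modifications = {
        "model.decoder.transformer.layer.feed_forward.linear1": ("fsdp", "model"),
    }
    return trn2_model_modifications, trn2_partition_spec_modifications


model_mods, partition_mods = _generate_trainium2_custom_configs()
```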
Force-pushed from b6ae638 to f10ebd0.
Approve overall; please address @hanzhi713's comment.
+ Fix modifier tests
Force-pushed from f10ebd0 to 37986ce.
Updated the PR to address the failing tests; can we re-trigger the CI, please? Thank you.
Force-pushed from 37986ce to 4cda0dd.
@kelvin-zou @hanzhi713 Thank you both for the review; I addressed all the comments. Let's merge this if it looks good.
Thank you!
Hello @ruomingp, can I please get an approval if this PR looks good? It looks like that is needed for the PR to merge.
Maybe @markblee can take a second look for the final round?
Approve to unblock.
(Looks like we still need @ruomingp's approval to unblock the 'requested changes'.)
Force-pushed from 4cda0dd to 53472f2.
Addressed the final comment from @markblee; thanks everyone! Can we run the CI again and merge this?
This PR adds meshes for TRN2/TRN1 for the Fuji models, plus a transformer layer configuration favorable to Neuron.
Neuron supports the stacked transformer and GroupedQKVLinear (instead of FusedGroupedQKVLinear) for Grouped Query Attention (GQA).
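As a hedged illustration of the GQA layout that an unfused, GroupedQKVLinear-style projection makes explicit (shapes and numbers below are made up; this is not the axlearn implementation):

```python
# Illustrative only: grouped-query attention keeps fewer KV heads than Q heads, and
# with separate (unfused) Q/K/V projections each weight can be handled independently.
import numpy as np

batch, seq, model_dim = 2, 16, 1024
num_heads, num_kv_heads, head_dim = 8, 2, 128  # GQA: fewer KV heads than Q heads.

x = np.zeros((batch, seq, model_dim), dtype=np.float32)
w_q = np.zeros((model_dim, num_heads, head_dim), dtype=np.float32)
w_k = np.zeros((model_dim, num_kv_heads, head_dim), dtype=np.float32)
w_v = np.zeros((model_dim, num_kv_heads, head_dim), dtype=np.float32)

q = np.einsum("bsd,dnh->bsnh", x, w_q)   # (2, 16, 8, 128)
k = np.einsum("bsd,dnh->bsnh", x, w_k)   # (2, 16, 2, 128)
v = np.einsum("bsd,dnh->bsnh", x, w_v)   # (2, 16, 2, 128)
print(q.shape, k.shape, v.shape)
```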
This is a newer version of PR #885; it resolves all comments and requested changes from the linked PR.