The scaling configuration for Llama 3.1 405B on 1 Trillium pod shards model.embed_tokens.weight as [fsdp, tensor]. If we replace model.embed_tokens.weight: [fsdp, tensor] with model.embed_tokens.weight: [tensor, fsdp], one would assume the model would train just as well, because this change won't affect any subsequent decoder layers. In practice we observe that:
- The gradient of some decoder layers becomes NaN by the 6th iteration.
- The collectives in the backward pass are drastically different (e.g. all-reduce becomes all-gather).
- The model uses less HBM.
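For concreteness, here is a minimal sketch (not from the issue) of what the two orderings mean under torch_xla's SPMD API, assuming a 2D device mesh with axes named fsdp and tensor sized to match the repro command below (64 x 4, i.e. 256 devices); the embedding dimensions are placeholders:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# 2D mesh matching the repro command: 64-way FSDP x 4-way tensor parallelism.
num_devices = xr.global_runtime_device_count()  # assumed to be 256 here
mesh = xs.Mesh(np.arange(num_devices), (64, 4), ('fsdp', 'tensor'))

# Placeholder embedding table (vocab x hidden); sizes are illustrative only.
embed = torch.nn.Embedding(128256, 16384).to(xm.xla_device())

# [fsdp, tensor]: dim 0 (vocab) sharded along the fsdp axis,
# dim 1 (hidden) sharded along the tensor axis.
xs.mark_sharding(embed.weight, mesh, ('fsdp', 'tensor'))

# [tensor, fsdp]: the swapped ordering under investigation in this issue.
# xs.mark_sharding(embed.weight, mesh, ('tensor', 'fsdp'))
```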
This is a tracking bug to find the root cause of this problem. Some hypotheses:
- The model.embed_tokens.weight sharding got propagated to other tensors during the backward pass, changing the collectives significantly and introducing numerical instability. We'll need to inspect how GSPMD propagated the shardings to tensors in the backward pass to dig deeper.
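One way to start checking this hypothesis (a sketch under the same assumptions as above, not an established recipe; _get_xla_sharding_spec is a private torch_xla helper) is to print the sharding spec GSPMD assigned to each parameter and its gradient after a training step, and to dump the partitioned HLO to compare the collectives:

```python
import torch_xla

def dump_shardings(model):
    # Print the sharding annotation attached to each parameter and,
    # if present, its gradient after backward + mark_step/sync.
    for name, param in model.named_parameters():
        spec = torch_xla._XLAC._get_xla_sharding_spec(param)
        grad_spec = (torch_xla._XLAC._get_xla_sharding_spec(param.grad)
                     if param.grad is not None else "<no grad>")
        print(f"{name}\n  param: {spec}\n  grad:  {grad_spec}")

# Call after loss.backward() and the step has been materialized, e.g.:
#   dump_shardings(model)
#
# To see which collectives the SPMD partitioner inserted (all-reduce vs.
# all-gather), dump the optimized HLO for the training step, e.g.:
#   XLA_FLAGS=--xla_dump_to=/tmp/hlo_dump tp run ...
```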
To repro:
```sh
tp run torchprime/torch_xla_models/train.py model=llama-3.1-405b global_batch_size=64 mesh.fsdp=64 mesh.tensor=4 dataset_config_name=wikitext-103-raw-v1 profile_step=15 logging_steps=1 model.scaling.sharding='{model.embed_tokens.weight:[tensor,fsdp]}'
```