I created the same issue in the orbax repository, but I'm wondering if memory consumption when saving checkpoints could be improved on MaxText's side.
Saving a 4GB array—as shown in the example below—can require up to 14GB of RAM.
This poses a significant problem since TPUs like the v5p typically have only slightly more CPU RAM (~440GB) than their TPU RAM per VM (4 chips × 95GB = 380GB). Consequently, while I am able to fit and train a large model, I encounter an out-of-memory (OOM) error when saving a checkpoint.
Below is a minimal example that reproduces this behavior: a 1-billion-parameter array (4GB) that consumes as much as 14GB of CPU memory when being saved to the local disk:
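A minimal sketch of this kind of repro, assuming a single ~4GB float32 array saved with Orbax's `PyTreeCheckpointer` to a local path (the `/tmp/ckpt_repro` path and the `psutil`-based RSS readout are illustrative additions, not part of the original script):

```python
import os

import jax.numpy as jnp
import orbax.checkpoint as ocp
import psutil  # used only to report process memory; not required for the repro itself


def rss_gb() -> float:
    """Resident set size of the current process, in GB."""
    return psutil.Process(os.getpid()).memory_info().rss / 1e9


# 1 billion float32 parameters ≈ 4GB.
params = {"w": jnp.ones((1_000_000_000,), dtype=jnp.float32)}
print(f"RSS before save: {rss_gb():.1f} GB")

# A plain PyTreeCheckpointer writing to local disk; MaxText goes through Orbax's
# CheckpointManager, but the array serialization ends up in the same Orbax machinery.
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save("/tmp/ckpt_repro", params)

# RSS after the save returns is only a rough proxy; the ~14GB figure quoted above
# is the peak observed while the save is in progress.
print(f"RSS after save: {rss_gb():.1f} GB")
```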
Also, I tried different values for `save_concurrent_gb`, `chunk_byte_size`, and `ocdbt_target_data_file_size`, but none of them reduced the peak memory usage.
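For reference, one plausible way these knobs are wired through Orbax is sketched below; field names and their exact placement vary across Orbax releases, so treat this as an unverified illustration rather than the configuration actually tested. `chunk_byte_size` is passed via per-leaf `SaveArgs`, `ocdbt_target_data_file_size` via the pytree-level save args, and `save_concurrent_gb` is a handler-level option whose constructor argument depends on the installed version, so it is only noted in a comment. The path and the 256MB/2GB values are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Same ~4GB pytree as in the repro above.
params = {"w": jnp.ones((1_000_000_000,), dtype=jnp.float32)}

# Smaller write chunks: 256MB per chunk is an arbitrary illustrative value.
save_args = jax.tree_util.tree_map(
    lambda _: ocp.SaveArgs(chunk_byte_size=256 * 1024**2), params
)

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save(
    "/tmp/ckpt_repro_tuned",  # illustrative path
    args=ocp.args.PyTreeSave(
        params,
        save_args=save_args,
        # Cap on the size of each OCDBT data file; 2GB is an arbitrary value.
        ocdbt_target_data_file_size=2 * 1024**3,
    ),
)
# save_concurrent_gb is configured on the PyTreeCheckpointHandler itself rather than
# per save call; how it is passed depends on the Orbax version, so it is omitted here.
```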