
Saving a checkpoint requires at least 3x more RAM than the size of the array itself #1322

Open
rodrigo-f-nogueira opened this issue Feb 28, 2025 · 0 comments

I created the same issue in the orbax repository, but I'm wondering if memory consumption when saving checkpoints could be improved on the MaxText side.

Saving a 4GB array, as shown in the example below, can require up to 14GB of RAM.

This poses a significant problem since TPUs like the v5p typically have only slightly more CPU RAM (~440GB) than their TPU RAM per VM (4 chips × 95GB = 380GB). Consequently, while I am able to fit and train a large model, I encounter an out-of-memory (OOM) error when saving a checkpoint.
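
To put rough numbers on it (an illustrative back-of-the-envelope calculation only; the ~3.5x factor comes from the measurement below, and the hardware figures are approximate):

host_ram_gb = 440            # approximate CPU RAM per v5p VM
device_state_gb = 4 * 95     # 4 chips x 95GB HBM = 380GB of model/optimizer state
overhead = 14 / 4            # ~3.5x host RAM per GB of array, observed in the repro below

# Host RAM needed to checkpoint a full VM's worth of device state:
needed_gb = device_state_gb * overhead
print(f"need ~{needed_gb:.0f}GB of host RAM vs {host_ram_gb}GB available")  # ~1330GB vs 440GB

So even a considerably smaller per-array overhead would make saving the full state infeasible on a single host.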

Below is a minimal example that reproduces this behavior: a 1-billion-parameter array (4GB) that consumes as much as 14GB of CPU memory when being saved to the local disk:

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax.training.train_state import TrainState
from etils import epath

params = 1e9  # 1B params => 4GB in fp32 
checkpoint_dir = "/root/temp_dir"

dim1 = int(params ** 0.5)
weight = jax.random.uniform(jax.random.PRNGKey(123), shape=(dim1, dim1), dtype=jnp.float32)

print(f"{dim1=}")
print(f"{weight.size=}; Total GB: {4 * weight.size / 1e9}")

state = TrainState(step=0, apply_fn=None, params=weight, tx=None, opt_state=None)

p = epath.Path(checkpoint_dir)
p.mkdir(exist_ok=True, parents=True)

item_handlers = {"items": ocp.PyTreeCheckpointHandler(save_concurrent_gb=8, use_ocdbt=True, use_zarr3=True)}
checkpoint_manager = ocp.CheckpointManager(
    p,
    item_names=("items",),
    item_handlers=item_handlers,
    options=ocp.CheckpointManagerOptions(
        create=True,
        save_interval_steps=10,
        enable_async_checkpointing=True,
    ),
    logger=None,
)

chunk_byte_size = 2 * 1024**3  # 2 GB
save_args = jax.tree.map(lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), state)

step = 10
checkpoint_manager.save(
        step,
        args=ocp.args.Composite(
            items=ocp.args.PyTreeSave(
                item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size
            )
        ),
    )

checkpoint_manager.wait_until_finished()

I also tried different values for save_concurrent_gb, chunk_byte_size, and ocdbt_target_data_file_size, but none of them reduced the memory usage.
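
For completeness, this is roughly how the memory figure can be reproduced (a minimal sketch, assuming psutil is installed): sample the process RSS in a background thread while the save above runs, reusing the objects from the repro.

import os
import threading
import time
import psutil

proc = psutil.Process(os.getpid())
peak_rss = proc.memory_info().rss
stop = threading.Event()

def watch():
    # Track the highest resident set size seen while the checkpoint is written.
    global peak_rss
    while not stop.is_set():
        peak_rss = max(peak_rss, proc.memory_info().rss)
        time.sleep(0.05)

watcher = threading.Thread(target=watch, daemon=True)
watcher.start()

# Same save call as in the repro above.
checkpoint_manager.save(
    step,
    args=ocp.args.Composite(
        items=ocp.args.PyTreeSave(
            item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size
        )
    ),
)
checkpoint_manager.wait_until_finished()

stop.set()
watcher.join()
print(f"Peak RSS during save: {peak_rss / 1e9:.1f} GB")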
