Minor changes to Checkpointer
hanzhi713 committed Feb 26, 2025
1 parent daec8c5 commit 11a31d3
Showing 3 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion axlearn/common/array_serialization.py
@@ -459,7 +459,7 @@ async def _run_deserializer():
return fut.result()


-class BoundedDataShardedAsyncCheckpointManager(serialization.GlobalAsyncCheckpointManager):
+class BoundedDataShardedAsyncCheckpointManager(GlobalAsyncCheckpointManager):
"""Similar to GlobalAsyncCheckpointManager but with few improvements:
1. Writing to tensorstore requires no host-to-host copy most of the time. This reduces host
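For context, the only change in array_serialization.py is how the base class is named: `BoundedDataShardedAsyncCheckpointManager` now refers to `GlobalAsyncCheckpointManager` directly rather than through the `serialization` module alias. A minimal sketch of what that pattern can look like, assuming the name is pulled in via a direct import (the commit may equally define a local `GlobalAsyncCheckpointManager` in the same file; this is illustrative only):

```python
# Sketch only: the exact import/definition used by axlearn is an assumption here.
from jax.experimental.array_serialization.serialization import GlobalAsyncCheckpointManager


class BoundedDataShardedAsyncCheckpointManager(GlobalAsyncCheckpointManager):
    """Similar to GlobalAsyncCheckpointManager but with a few improvements; see the
    class docstring in axlearn/common/array_serialization.py for the full list."""
```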
19 changes: 13 additions & 6 deletions axlearn/common/checkpointer.py
@@ -560,7 +560,9 @@ def restore_from_dir(
)
return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)

-def _restore_tensorstore_state(self, state, *, ckpt_dir: str, spec: CheckpointSpec):
+def _restore_tensorstore_state(
+    self, state, *, ckpt_dir: str, spec: CheckpointSpec, sync: bool = True
+):
restored_gda_values = self._manager.deserialize(
shardings=spec.shardings,
tensorstore_specs=spec.tensorstore_specs,
@@ -584,7 +586,8 @@ def _restore_tensorstore_state(self, state, *, ckpt_dir: str, spec: CheckpointSp
restored_state = jax.tree_util.tree_unflatten(
jax.tree_util.tree_structure(state), state_leaves
)
-multihost_utils.sync_global_devices(ckpt_dir)
+if sync:
+    multihost_utils.sync_global_devices(ckpt_dir)
return restored_state

def stop(self):
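The new `sync` keyword makes the cross-host barrier at the end of `_restore_tensorstore_state` optional. A minimal, self-contained sketch of the resulting pattern (assumed usage; `finish_restore` and the commented call sites below are illustrative, not axlearn APIs):

```python
from jax.experimental import multihost_utils


def finish_restore(restored_state, *, ckpt_dir: str, sync: bool = True):
    """Illustrates the new control flow: synchronize hosts only when `sync=True`."""
    if sync:
        # Every host blocks here until all hosts have finished restoring `ckpt_dir`.
        multihost_utils.sync_global_devices(ckpt_dir)
    return restored_state


# Hypothetical usage: skip the per-call barrier and issue a single one afterwards.
# state_a = finish_restore(state_a, ckpt_dir=dir_a, sync=False)
# state_b = finish_restore(state_b, ckpt_dir=dir_b, sync=False)
# multihost_utils.sync_global_devices("all_restores_done")
```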
@@ -906,7 +909,11 @@ class Config(BaseCheckpointer.Config):
def _all_checkpoint_paths(cls, base_dir: str) -> list[str]:
"""Like `checkpoint_paths`, but also include non-committed checkpoints."""
try:
-return [path for path in fs.listdir(base_dir) if path.startswith(STEP_PREFIX)]
+return [
+    os.path.join(base_dir, path.rstrip("/"))
+    for path in fs.listdir(base_dir)
+    if path.startswith(STEP_PREFIX)
+]
except fs.NotFoundError:
return []

@@ -918,7 +925,7 @@ def checkpoint_paths(cls, base_dir: str) -> list[str]:
# gcs when there are many checkpoint files, even if using a "native" solution like
# `google-cloud-python` SDK.
paths = cls._all_checkpoint_paths(base_dir)
-paths = [os.path.join(base_dir, path, "index") for path in paths]
+paths = [os.path.join(path, "index") for path in paths]
with futures.ThreadPoolExecutor() as pool:
index_exists = pool.map(fs.exists, paths)
return [os.path.dirname(path) for path, committed in zip(paths, index_exists) if committed]
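These two hunks change the contract of `_all_checkpoint_paths`: it now returns full step-directory paths (with any trailing slash stripped) instead of bare directory names, so `checkpoint_paths` only appends `index` before probing which checkpoints are committed, and the `_run_garbage_collection` hunk below consumes the full paths directly. A standalone sketch of the committed-checkpoint filter, using `os` in place of axlearn's `fs` module and an assumed step-directory prefix:

```python
import os
from concurrent import futures

STEP_PREFIX = "step"  # Assumption: checkpoint step directories start with this prefix.


def all_checkpoint_paths(base_dir: str) -> list[str]:
    """Full paths of every step directory, committed or not (mirrors the new contract)."""
    try:
        return [
            os.path.join(base_dir, name.rstrip("/"))
            for name in os.listdir(base_dir)
            if name.startswith(STEP_PREFIX)
        ]
    except FileNotFoundError:
        return []


def committed_checkpoint_paths(base_dir: str) -> list[str]:
    """Keeps only directories whose "index" file exists; existence checks run in a
    thread pool because serial per-file checks are slow on object stores like GCS."""
    index_paths = [os.path.join(path, "index") for path in all_checkpoint_paths(base_dir)]
    with futures.ThreadPoolExecutor() as pool:
        index_exists = list(pool.map(os.path.exists, index_paths))
    return [os.path.dirname(p) for p, ok in zip(index_paths, index_exists) if ok]
```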
@@ -1042,12 +1049,12 @@ def _run_garbage_collection(self):
remaining_dirs, gc_dirs = [], []

try:
-step_dirs = [step.rstrip("/") for step in self._all_checkpoint_paths(cfg.dir)]
+step_dirs = self._all_checkpoint_paths(cfg.dir)
except fs.NotFoundError:
step_dirs = []

# Gather all candidate checkpoint dirs, as well as all committed checkpoint dirs.
-dirs = sorted([os.path.join(cfg.dir, step) for step in step_dirs], reverse=True)
+dirs = sorted(step_dirs, reverse=True)
committed_dirs = set(self.checkpoint_paths(cfg.dir))

# Collect the recent non-committed checkpoints, since any of them could be in-progress.
15 changes: 12 additions & 3 deletions axlearn/common/checkpointer_test.py
@@ -1137,8 +1137,13 @@ def make_state(float_dtype):
),
)

-@parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16)
-def test_save_and_restore_from_dir_async(self, restore_floats_as: jnp.dtype):
+@parameterized.product(
+    restore_floats_as=[jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16],
+    max_concurrent_gb=[None, 1],
+)
+def test_save_and_restore_from_dir_async(
+    self, restore_floats_as: jnp.dtype, max_concurrent_gb: Optional[int]
+):
mesh_shape = (1, 1)
if not test_utils.is_supported_mesh_shape(mesh_shape):
return
@@ -1148,7 +1153,11 @@ def make_state(float_dtype):

with _mesh(mesh_shape):
state = make_state(float_dtype=jnp.float32)
-storage = TensorStoreStateStorage.default_config().instantiate()
+storage = (
+    TensorStoreStateStorage.default_config()
+    .set(max_concurrent_gb=max_concurrent_gb)
+    .instantiate()
+)
with tempfile.TemporaryDirectory() as root_dir:
step = 1000
# Save ckpt.
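The test now sweeps `max_concurrent_gb` together with the restore dtype via `parameterized.product`, so the bounded-concurrency path of the storage is exercised as well. A minimal sketch of that cross-product parameterization with absltest (the test body below is a placeholder, not the real checkpoint round-trip):

```python
from typing import Optional

import jax.numpy as jnp
from absl.testing import absltest, parameterized


class ParameterizationSketchTest(parameterized.TestCase):
    """Illustrative only: shows how `parameterized.product` expands into 4 x 2 = 8 cases."""

    @parameterized.product(
        restore_floats_as=[jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16],
        max_concurrent_gb=[None, 1],
    )
    def test_product_parameterization(
        self, restore_floats_as: jnp.dtype, max_concurrent_gb: Optional[int]
    ):
        # Placeholder assertions; the real test saves and restores a checkpoint.
        self.assertIn(max_concurrent_gb, (None, 1))
        self.assertTrue(jnp.issubdtype(restore_floats_as, jnp.number))


if __name__ == "__main__":
    absltest.main()
```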
