Update checkpointing to use fsspec
Summary:

- Make the data/checkpoint code fsspec-compatible
- S3 saves still will not work, because `torch.distributed.checkpoint.save` does not work with `fsspec` out of the box; that will be implemented in a follow-up PR. The general fsspec pattern the checkpoint code now uses is sketched below.
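For context, the pattern behind this change is to resolve an `fsspec` filesystem from the configured dump path and route file I/O through it instead of `pathlib`/`open`. A minimal sketch of that pattern (the `resolve_fs` helper name is illustrative, not code from this PR):

```
import fsspec
from fsspec.core import url_to_fs


def resolve_fs(path: str) -> tuple[fsspec.AbstractFileSystem, str]:
    # fsspec picks the backend from the URI scheme: "s3://bucket/key" maps to
    # s3fs.S3FileSystem, while a bare local path maps to LocalFileSystem.
    fs, root = url_to_fs(path)
    return fs, root


fs, root = resolve_fs("s3://blt/scratch/checkpoint-test/")
fs.mkdirs(root, exist_ok=True)  # S3 has no real directories; this mostly matters for local paths
with fs.open(root.rstrip("/") + "/config.yaml", "w") as f:
    f.write("dump_dir: s3://blt/scratch/checkpoint-test/\n")
```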


Test Plan:

Run unit tests and the commands below

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

These currently won't work because of the torch distributed save issue, but they should be tested at a later date (one possible follow-up is sketched after the commands below).

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```
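For reference on the remaining gap: `torch.distributed.checkpoint.save` writes through the local filesystem, which is why pointing `dump_dir` at `s3://...` fails today. One possible shape for the follow-up is sketched below; it assumes the installed PyTorch exposes an fsspec-aware storage writer (`FsspecWriter`), and the helper is illustrative, not code from this commit:

```
import torch.distributed.checkpoint as dcp


def save_dcp_checkpoint(state_dict: dict, path: str) -> None:
    # Hypothetical follow-up sketch; availability and import location of
    # FsspecWriter depend on the installed PyTorch version.
    if path.startswith("s3://"):
        from torch.distributed.checkpoint import FsspecWriter

        dcp.save(state_dict, storage_writer=FsspecWriter(path))
    else:
        # Local paths already work with the default filesystem-based writer.
        dcp.save(state_dict, checkpoint_id=path)
```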
EntilZha committed Feb 6, 2025
1 parent c79b1fd commit 3412646
Showing 5 changed files with 112 additions and 72 deletions.
14 changes: 9 additions & 5 deletions bytelatent/args.py
@@ -294,15 +294,19 @@ class TrainArgs(BaseModel):
def dump_to_yaml_file(
self, path: str, log_config: bool = True, sort_keys: bool = True
):
yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys)
with open(path, "w") as f:
if log_config:
logger.info("Using the following config for this run:")
logger.info(yaml_str)
f.write(yaml_str)

def dump_to_yaml_str(self, sort_keys: bool = True):
model_dict = self.model_dump(mode="json")
yaml_str = yaml.dump(
model_dict,
allow_unicode=True,
sort_keys=sort_keys,
default_flow_style=False,
)
with open(path, "w") as f:
if log_config:
logger.info("Using the following config for this run:")
logger.info(yaml_str)
f.write(yaml_str)
return yaml_str
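For reference, a minimal usage sketch of the split introduced above (assumes a constructed `TrainArgs` instance named `args`; not code from this commit):

```
# dump_to_yaml_str() returns the YAML serialization without touching disk;
# dump_to_yaml_file() writes it out through the same helper and optionally logs it.
yaml_str = args.dump_to_yaml_str(sort_keys=True)
args.dump_to_yaml_file("config_dump.yaml", log_config=False)
```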
115 changes: 62 additions & 53 deletions bytelatent/checkpoint.py
@@ -4,10 +4,9 @@
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple

import fsspec
import s3fs
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
@@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
Returns the path to the consolidated checkpoint
"""
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
if not (consolidate_path / CONSOLIDATE_NAME).exists():
consolidate_path.mkdir(exist_ok=True)
logger.info(f"Consolidating to: {str(consolidate_path)}")
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
(consolidate_path / CONFIG_NAME).write_text(
(Path(ckpt_dir) / CONFIG_NAME).read_text()
consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
if not fs.exists(consolidate_name):
fs.mkdirs(consolidate_path, exist_ok=True)
logger.info(f"Consolidating to: {consolidate_path}")
dcp_to_torch_save(ckpt_dir, consolidate_name)
fs.write_text(
os.path.join(consolidate_path, CONFIG_NAME),
fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
)
logger.info("Consolidated !")
return consolidate_path


def load_from_checkpoint(
fs: fsspec.AbstractFileSystem,
ckpt_dir: str,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer: torch.optim.Optimizer | None = None,
model_key: str = "model",
optim_key: str = "optim",
):
if not (Path(ckpt_dir) / ".metadata").exists():
if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
raise ValueError(
f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
)
@@ -115,19 +117,24 @@ def __init__(self, args: CheckpointArgs):
self.init_ckpt_path = args.init_ckpt_path
self.continue_training_from_init = args.continue_training_from_init

assert self.fs.exists(
self.path
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
if not isinstance(self.fs, s3fs.S3FileSystem):
# S3 does not have a concept of directories
assert self.fs.exists(
self.path
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"

self.existing_saves = self.get_existing_saves()

def get_existing_saves(self) -> List[Path]:
folders = [
p
for p in Path(self.path).iterdir()
if p.is_dir() and re.match(RE_FOLDER, p.name)
]
folders.sort(key=lambda p: _get_key_step(p.name))
def get_existing_saves(self) -> list[str]:
if self.fs.exists(self.path) and self.fs.isdir(self.path):
folders = [
p
for p in self.fs.ls(self.path)
if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
]
else:
folders = []
folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
return folders

def clean_up(self):
@@ -136,8 +143,9 @@ def clean_up(self):
eval_folders = []
other_folders = []
for p in self.existing_saves:
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
assert isinstance(p, str), f"Base path type: {p}"
is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
if is_dump:
dump_folders.append(p)
if is_eval:
@@ -161,40 +169,39 @@ def clean_up(self):

if dist.get_rank() == 0:
for folder in folder_to_remove:
for file in folder.iterdir():
if file.is_file():
file.unlink()
elif file.is_dir():
assert file.name in [CONSOLIDATE_FOLDER]
for f in file.iterdir():
f.unlink()
file.rmdir()
folder.rmdir()
for file in self.fs.ls(folder):
if self.fs.isfile(file):
self.fs.rm_file(file)
elif self.fs.isdir(file):
assert os.path.basename(file) in [CONSOLIDATE_FOLDER]
for f in self.fs.ls(file):
self.fs.rm(f)
self.fs.rmdir(file)
self.fs.rmdir(folder)

dist.barrier()

self.existing_saves = list(folder_to_keep)
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))

def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
def get_last_step_path(self, dp_rank: int = 0) -> str | None:
path = None
for p in reversed(self.existing_saves):
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():

if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
path = p
break
return path

def _create_folder(self, base_path: Path, folder_name: str) -> Path:
folder = base_path / folder_name
def _create_folder(self, base_path: str, folder_name: str) -> str:
folder = os.path.join(base_path, folder_name)
if get_is_master():
folder.mkdir(parents=False, exist_ok=True)
self.fs.mkdirs(folder, exist_ok=True)
if dist.is_initialized():
dist.barrier()
return folder

def _get_dp_tp_mesh(
self, device_mesh: Optional[DeviceMesh] = None
) -> Tuple[int, int]:
def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
dp_rank = 0
tp_rank = 0
if device_mesh is not None:
@@ -222,14 +229,14 @@ def save(
model,
optimizer,
train_state,
config,
device_mesh: Optional[DeviceMesh] = None,
config: BaseModel,
device_mesh: DeviceMesh | None = None,
) -> bool:

# When creating directory check if only rank0 or is there other solution
path = Path(self.path)
path = self.path
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
logger.info(f"Saving to: {str(curr_save_dir)}")
logger.info(f"Saving to: {curr_save_dir}")

if dist.is_initialized():
dist.barrier()
@@ -242,17 +249,19 @@ def save(
if dist.is_initialized():
dist.barrier()

print("config type", type(config))
if get_is_master():
config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
self.fs.write_text(
os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
)

# Add json dump here
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
if tp_rank == 0:
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info(
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
)
with open(curr_save_dir / train_state_name, "w") as f:
train_state_full_path = os.path.join(curr_save_dir, train_state_name)
logger.info(f"Saving train state to: {train_state_full_path}")
with self.fs.open(train_state_full_path, "w") as f:
json.dump(train_state.state_dict(), f)
logger.info("Train state saved !")

@@ -271,7 +280,7 @@ def load(
optimizer,
train_state,
device_mesh: DeviceMesh,
path: Optional[Path] = None,
path: str | None = None,
):
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
# Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +293,12 @@ def load(
# Only load train state if it's provided, the files exist and we're not loading from init path
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info("Reloading train state")
with open(path / train_state_name, "r") as f:
with self.fs.open(os.path.join(path, train_state_name), "r") as f:
train_state_dict = json.load(f)
train_state.load_state_dict(train_state_dict)
logger.info("Train state reloaded")

logger.info(f"Loading from: {str(path)}")
logger.info(f"Loading from: {path}")
state_dict = self.get_state_dict(
model=model,
optimizer=optimizer,
9 changes: 8 additions & 1 deletion bytelatent/logger.py
@@ -6,6 +6,8 @@
import time
from datetime import timedelta

import fsspec

from bytelatent.distributed import get_global_rank, get_is_slurm_job


@@ -92,6 +94,7 @@ def init_logger(
*,
name: str | None = None,
level: str = "INFO",
fs: fsspec.AbstractFileSystem | None = None,
):
"""
Setup logging.
Expand Down Expand Up @@ -121,7 +124,11 @@ def init_logger(

if log_file is not None and get_global_rank() == 0:
# build file handler
file_handler = logging.FileHandler(log_file, "a")
if fs is None:
file_handler = logging.FileHandler(log_file, "a")
else:
file_stream = fs.open(log_file, mode="a")
file_handler = logging.StreamHandler(file_stream)
file_handler.setLevel(logging.NOTSET)
file_handler.setFormatter(LogFormatter())
# update logger
15 changes: 13 additions & 2 deletions bytelatent/metrics.py
@@ -9,6 +9,7 @@
from pathlib import Path
from typing import Any, Union

import fsspec
import torch
import torch.nn as nn
import wandb
@@ -54,14 +55,24 @@ class LoggingArgs(BaseModel):


class MetricLogger:
def __init__(self, outdir: Path, args: Any | None = None):
def __init__(
self,
outdir: Path,
# args: TrainArgs
args: Any | None = None,
fs: fsspec.AbstractFileSystem | None = None,
):
self.outdir = outdir
self.jsonl_writer = None
self.fs = fs
self.args = args

def open(self):
if self.jsonl_writer is None:
self.jsonl_writer = open(self.outdir, "a")
if self.fs is None:
self.jsonl_writer = open(self.outdir, "a")
else:
self.jsonl_writer = self.fs.open(self.outdir, "a")
if (
self.args is not None
and self.args.logging.wandb is not None
