From e7d1a21fc9504d0268a5b757fda42f5408bb431c Mon Sep 17 00:00:00 2001 From: Ben Eisner Date: Thu, 6 Jun 2024 00:23:31 -0400 Subject: [PATCH] resume training (#63) --- configs/train_ndf.yaml | 2 ++ scripts/train_residual_flow.py | 23 ++++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/configs/train_ndf.yaml b/configs/train_ndf.yaml index e34e33b..1693320 100644 --- a/configs/train_ndf.yaml +++ b/configs/train_ndf.yaml @@ -53,9 +53,11 @@ training: check_val_every_n_epoch: 5 seed: 0 +resume_ckpt: null resources: num_workers: 8 wandb: group: Null + run_id_override: Null diff --git a/scripts/train_residual_flow.py b/scripts/train_residual_flow.py index e669730..390c1ab 100644 --- a/scripts/train_residual_flow.py +++ b/scripts/train_residual_flow.py @@ -14,6 +14,7 @@ from taxpose.training.flow_equivariance_training_module_nocentering import ( EquivarianceTrainingModule, ) +from taxpose.utils.load_model import get_weights_path def load_emb_weights(checkpoint_reference, wandb_cfg=None, run=None): @@ -53,6 +54,24 @@ def main(cfg): # torch.set_float32_matmul_precision("medium") TESTING = "PYTEST_CURRENT_TEST" in os.environ + if cfg.resume_ckpt: + print("Resuming from checkpoint") + print(cfg.resume_ckpt) + resume_ckpt = get_weights_path(cfg.resume_ckpt, cfg.wandb) + + # Resume the wandb run + if cfg.resume_ckpt.startswith(cfg.wandb.entity): + # Get the run_id from the checkpoint + resume_run_id = cfg.resume_ckpt.split("/")[2].split("-")[1].split(":")[0] + elif cfg.wandb.run_id_override is not None: + resume_run_id = cfg.wandb.run_id_override + else: + resume_run_id = None + + else: + resume_ckpt = None + resume_run_id = None + pl.seed_everything(cfg.seed) logger = WandbLogger( entity=cfg.wandb.entity, @@ -62,6 +81,7 @@ def main(cfg): job_type=cfg.job_type, save_code=True, log_model=True, + id=resume_run_id, config=omegaconf.OmegaConf.to_container(cfg, resolve=True), ) # logger.log_hyperparams(cfg) @@ -174,7 +194,8 @@ def main(cfg): cfg.model.pretraining.anchor.ckpt_path ) ) - trainer.fit(model, dm) + + trainer.fit(model, dm, ckpt_path=resume_ckpt) # Print he run id of the current run print("Run ID: {} ".format(logger.experiment.id))