Save on exit (#102)
* Save on keyboard interrupt

* Bump up to v0.0.23
erogol authored Mar 6, 2023
1 parent ace9f13 commit 542bd23
Showing 3 changed files with 73 additions and 47 deletions.
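
The core of this commit is to catch Ctrl+C during training and write a checkpoint before exiting, gated by a new `save_on_interrupt` config flag. Below is a minimal, self-contained sketch of that pattern; the toy loop and file write only stand in for `Trainer.fit()` and `Trainer.save_checkpoint()`, whose real changes follow in the diff.

import logging
import time

logger = logging.getLogger("trainer")


def train(save_on_interrupt: bool = True) -> None:
    """Toy training loop illustrating the save-on-interrupt pattern."""
    step = 0
    try:
        while True:  # stands in for the real optimization loop
            time.sleep(0.1)  # stands in for one training step
            step += 1
    except KeyboardInterrupt:
        logger.info(" > Keyboard interrupt detected.")
        if save_on_interrupt:
            logger.info(" > Saving model before exiting...")
            # stands in for Trainer.save_checkpoint(); here we only persist the step count
            with open("checkpoint_step.txt", "w", encoding="utf-8") as f:
                f.write(str(step))
        raise SystemExit(1)  # exit with a non-zero status, as the updated fit() now does


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    train()
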
2 changes: 1 addition & 1 deletion trainer/VERSION
@@ -1 +1 @@
-v0.0.22
+v0.0.23
1 change: 0 additions & 1 deletion trainer/io.py
@@ -1,6 +1,5 @@
 import datetime
 import json
-import sys
 import os
 import re
 import sys
117 changes: 72 additions & 45 deletions trainer/trainer.py
@@ -46,7 +46,7 @@
     setup_torch_training_env,
 )
 from trainer.utils.cuda_memory import cuda_meminfo, should_reduce_batch_size
-from trainer.utils.distributed import init_distributed
+from trainer.utils.distributed import init_distributed, rank_zero_only
 
 logger = logging.getLogger("trainer")

@@ -111,6 +111,9 @@ class TrainerConfig(Coqpit):
default="tensorboard", metadata={"help": "Logger to use for the tracking dashboard. Defaults to 'tensorboard'"}
)
# Fields for checkpointing
save_on_interrupt: bool = field(
default=True, metadata={"help": "Save checkpoint on interrupt (Ctrl+C). Defaults to True"}
)
log_model_step: int = field(
default=None,
metadata={
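
Since `TrainerConfig` is a Coqpit dataclass, the new field behaves like any other config entry. A hedged usage sketch (assuming `TrainerConfig` is importable from `trainer.trainer`, where this diff defines it; configs that subclass it inherit the flag):

from trainer.trainer import TrainerConfig

config = TrainerConfig()
print(config.save_on_interrupt)  # True by default after this change

# opt out of checkpointing on Ctrl+C for throwaway runs
config.save_on_interrupt = False
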
@@ -455,7 +458,7 @@ def __init__( # pylint: disable=dangerous-default-value
         self.eval_samples = None
         self.test_samples = None
 
-        #define custom train and eval loader
+        # define custom train and eval loader
         self.train_loader = train_loader
         self.eval_loader = eval_loader

@@ -1295,47 +1298,11 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_ti
         if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
             if self.config.save_checkpoints:
                 # checkpoint the model
-                target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
-                save_checkpoint(
-                    self.config,
-                    self.model,
-                    self.optimizer,
-                    self.scaler if self.use_amp_scaler else None,
-                    self.total_steps_done,
-                    self.epochs_done,
-                    self.output_path,
-                    model_loss=target_avg_loss,
-                    save_n_checkpoints=self.config.save_n_checkpoints,
-                    save_func=self.dashboard_logger.save_model,
-                )
+                self.save_checkpoint()
 
-                if self.total_steps_done % self.config.log_model_step == 0:
-                    # log checkpoint as artifact
-                    aliases = [
-                        f"epoch-{self.epochs_done}",
-                        f"step-{self.total_steps_done}",
-                    ]
-                    self.dashboard_logger.add_artifact(
-                        file_or_dir=self.output_path, name="checkpoint", artifact_type="model", aliases=aliases
-                    )
-
-            # training visualizations
-            if hasattr(self.model, "module") and isimplemented(self.model.module, "train_log"):
-                self.model.module.train_log(
-                    batch,
-                    outputs,
-                    self.dashboard_logger,
-                    self.training_assets,
-                    self.total_steps_done,
-                )
-            elif isimplemented(self.model, "train_log"):
-                self.model.train_log(
-                    batch,
-                    outputs,
-                    self.dashboard_logger,
-                    self.training_assets,
-                    self.total_steps_done,
-                )
+                if self.total_steps_done % self.config.log_model_step == 0:
+                    # log checkpoint as artifact
+                    self.update_training_dashboard_logger(batch=batch, outputs=outputs)
 
         self.dashboard_logger.flush()
 
@@ -1683,6 +1650,14 @@ def fit(self) -> None:
             if self.args.rank == 0:
                 self.dashboard_logger.finish()
         except KeyboardInterrupt:
+            logger.info(" > Keyboard interrupt detected.")
+            if self.config.save_on_interrupt:
+                logger.info(" > Saving model before exiting...")
+                # save the model on keyboard interrupt
+                self.save_checkpoint()
+                # update the training dashboard logger
+                self.update_training_dashboard_logger()
+            # call the keyboard interrupt callback
             self.callbacks.on_keyboard_interrupt(self)
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
@@ -1694,9 +1669,9 @@ def fit(self) -> None:
                 self.dashboard_logger.finish()
             # stop without error signal
             try:
-                sys.exit(0)
+                sys.exit(1)
             except SystemExit:
-                os._exit(0)  # pylint: disable=protected-access
+                os._exit(1)  # pylint: disable=protected-access
         except BaseException:  # pylint: disable=broad-except
             remove_experiment_folder(self.output_path)
             traceback.print_exc()
@@ -1746,6 +1721,7 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
         self.torch_profiler.stop()
         return self.torch_profiler
 
+    @rank_zero_only
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller then the previous."""

@@ -1768,6 +1744,52 @@ def save_best_model(self) -> None:
                 save_func=self.dashboard_logger.save_model,
             )
 
+    @rank_zero_only
+    def save_checkpoint(self) -> None:
+        """Save the current model checkpoint."""
+        target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
+        save_checkpoint(
+            self.config,
+            self.model,
+            self.optimizer,
+            self.scaler if self.use_amp_scaler else None,
+            self.total_steps_done,
+            self.epochs_done,
+            self.output_path,
+            model_loss=target_avg_loss,
+            save_n_checkpoints=self.config.save_n_checkpoints,
+            save_func=self.dashboard_logger.save_model,
+        )
+
+    @rank_zero_only
+    def update_training_dashboard_logger(self, batch=None, outputs=None):
+        aliases = [
+            f"epoch-{self.epochs_done}",
+            f"step-{self.total_steps_done}",
+        ]
+        self.dashboard_logger.add_artifact(
+            file_or_dir=self.output_path, name="checkpoint", artifact_type="model", aliases=aliases
+        )
+
+        # training visualizations
+        if batch is not None and outputs is not None:
+            if hasattr(self.model, "module") and isimplemented(self.model.module, "train_log"):
+                self.model.module.train_log(
+                    batch,
+                    outputs,
+                    self.dashboard_logger,
+                    self.training_assets,
+                    self.total_steps_done,
+                )
+            elif isimplemented(self.model, "train_log"):
+                self.model.train_log(
+                    batch,
+                    outputs,
+                    self.dashboard_logger,
+                    self.training_assets,
+                    self.total_steps_done,
+                )
+
     #####################
     # GET FUNCTIONS
     #####################
@@ -1921,7 +1943,12 @@ def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
if "target_loss" in self.config and self.config.target_loss:
if f"avg_{self.config.target_loss}" in keep_avg_target.avg_values.keys():
return keep_avg_target[f"avg_{self.config.target_loss}"]
return keep_avg_target["avg_loss_1"]
target_loss = keep_avg_target["avg_loss_1"]
if target_loss is None:
raise ValueError(
" [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
)
return target_loss

# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
if isinstance(self.optimizer, list):
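
Both new helpers are decorated with the newly imported `rank_zero_only`, so under multi-GPU (DDP) training only the rank-0 process writes checkpoints and dashboard artifacts while the other ranks skip the call. The decorator's implementation is not part of this diff; the sketch below only illustrates how such a guard is commonly written, it is not the code in trainer/utils/distributed.py.

import os
from functools import wraps
from typing import Any, Callable, Optional


def rank_zero_only(fn: Callable[..., Any]) -> Callable[..., Optional[Any]]:
    """Run the wrapped function only on the rank-0 process (illustrative sketch)."""

    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
        # assumes the launcher exports RANK (or LOCAL_RANK); defaults to 0 for single-GPU runs
        rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
        if rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapper

Gating the methods themselves rather than every call site keeps the interrupt handler in fit() simple: every process may call save_checkpoint(), but only one actually touches the filesystem.
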
