diff --git a/trainer/VERSION b/trainer/VERSION
index c55013b..08692e0 100644
--- a/trainer/VERSION
+++ b/trainer/VERSION
@@ -1 +1 @@
-v0.0.25
+v0.0.26
diff --git a/trainer/trainer.py b/trainer/trainer.py
index 750ec63..cd9029a 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -156,6 +156,12 @@ class TrainerConfig(Coqpit):
     )
     # Fields for training specs
     mixed_precision: bool = field(default=False, metadata={"help": "Use mixed precision training. Defaults to False"})
+    precision: str = field(
+        default="fp16",
+        metadata={
+            "help": "Precision to use in mixed precision training. `fp16` for float16 and `bf16` for bfloat16. Defaults to 'fp16'"
+        },
+    )
     epochs: int = field(default=1000, metadata={"help": "Number of epochs to train. Defaults to 1000"})
     batch_size: int = field(default=32, metadata={"help": "Batch size to use. Defaults to 32"})
     eval_batch_size: int = field(default=16, metadata={"help": "Batch size to use for eval. Defaults to 16"})
@@ -438,7 +444,11 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.keep_avg_eval = None
 
         self.use_apex = self._is_apex_available()
-        self.use_amp_scaler = self.use_cuda if self.config.mixed_precision else self.config.use_grad_scaler
+        self.use_amp_scaler = (
+            self.use_cuda
+            if self.config.mixed_precision and self.config.precision == "fp16"
+            else self.config.use_grad_scaler
+        )
 
         if train_samples is not None:
             # use the provided samples
@@ -993,12 +1003,19 @@ def _model_train_step(
             return model.module.train_step(*input_args)
         return model.train_step(*input_args)
 
-    def _get_autocast_args(self, mixed_precision: bool):
+    def _get_autocast_args(self, mixed_precision: bool, precision: str):
         device = "cpu"
         dtype = torch.get_autocast_cpu_dtype()
         if self.use_cuda:
             device = "cuda"
-            dtype = torch.float16 if mixed_precision else torch.float32
+            dtype = torch.float32
+            if mixed_precision:
+                if precision == "fp16":
+                    dtype = torch.float16
+                elif precision == "bf16":
+                    dtype = torch.bfloat16
+                else:
+                    raise ValueError(f" ❗ Unknown precision {precision}")
         elif mixed_precision:
             dtype = torch.bfloat16
         return device, dtype
@@ -1057,7 +1074,7 @@ def optimize(
         step_start_time = time.time()
 
         # forward pass and loss computation
-        device, dtype = self._get_autocast_args(config.mixed_precision)
+        device, dtype = self._get_autocast_args(config.mixed_precision, config.precision)
         with torch.autocast(device_type=device, dtype=dtype, enabled=config.mixed_precision):
             if optimizer_idx is not None:
                 outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
@@ -1170,7 +1187,7 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_ti
         if isimplemented(self.model, "optimize"):  # pylint: disable=too-many-nested-blocks
             # custom optimize for the model
             step_time = time.time()
-            device, dtype = self._get_autocast_args(self.config.mixed_precision)
+            device, dtype = self._get_autocast_args(self.config.mixed_precision, self.config.precision)
             with torch.autocast(device_type=device, dtype=dtype, enabled=self.config.mixed_precision):
                 outputs, loss_dict_new = self.model.optimize(
                     batch,
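
A minimal sketch of how the new `precision` field is exercised, appended here for illustration and not part of the diff; it assumes `TrainerConfig` is importable from the package root.

```python
from trainer import TrainerConfig  # assumed public export of trainer/trainer.py

# With this diff, "bf16" autocast skips the GradScaler: use_amp_scaler is only
# forced on when mixed_precision is True AND precision == "fp16".
config = TrainerConfig(mixed_precision=True, precision="bf16")
print(config.precision)  # -> "bf16"
```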
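And a self-contained sketch of the dtype selection that `_get_autocast_args` now performs; the helper name `pick_autocast_dtype` is hypothetical, but the branching mirrors the hunk above.

```python
import torch

def pick_autocast_dtype(mixed_precision: bool, precision: str, use_cuda: bool) -> torch.dtype:
    if not use_cuda:
        # CPU autocast only supports bfloat16, so `precision` is ignored there.
        return torch.bfloat16 if mixed_precision else torch.get_autocast_cpu_dtype()
    if not mixed_precision:
        return torch.float32
    if precision == "fp16":
        return torch.float16
    if precision == "bf16":
        return torch.bfloat16
    raise ValueError(f" ❗ Unknown precision {precision}")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = pick_autocast_dtype(mixed_precision=True, precision="bf16", use_cuda=device == "cuda")
with torch.autocast(device_type=device, dtype=dtype, enabled=True):
    out = torch.nn.Linear(8, 8).to(device)(torch.randn(2, 8, device=device))
print(out.dtype)  # torch.bfloat16 under autocast on either device
```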