# Optimizers

We currently have the following optimizers:

| Name             | Optimizer                   | LR Scheduler                        |
|------------------|-----------------------------|-------------------------------------|
| `adamw_baseline` | AdamW                       | Cosine Annealing with linear warmup |
| `adamcpr`        | AdamCPR                     | Cosine Annealing with linear warmup |
| `sgd_baseline`   | Stochastic Gradient Descent | Cosine Annealing                    |
| `sgd_stepwise`   | Stochastic Gradient Descent | StepLR                              |
| `adafactor`      | Adafactor                   | Constant                            |

## Creating your own optimizer

To add your own optimizer, create a subfolder in the `optimizers` directory. The name of that folder is the name used to invoke the optimizer. Within the folder you need to provide two files: `optimizer.py` and `default.yaml`. There is a template optimizer with useful comments, which can be used as a starting point.
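
For example, the layout for a new optimizer named `my_awesome_optimizer` (the name also used in the `default.yaml` example below) would look roughly like this:

```
optimizers/
└── my_awesome_optimizer/
    ├── optimizer.py    # implements configure_optimizers
    └── default.yaml    # default hyperparameters for this optimizer
```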

### `optimizer.py`

Here you need to implement a function `configure_optimizers` with the following signature (a minimal sketch is given after the list below):

```python
configure_optimizers(model: GroupedModel, config: OptimizerConfig) -> OptimizerLRScheduler
```
- The return type is the same as that of PyTorch Lightning's `configure_optimizers` hook; see the Lightning documentation for the accepted return formats.
- The `GroupedModel` is a wrapper around a `torch.nn.Module`. It additionally provides a method `grouped_parameters`, which returns the model parameters grouped by their `weight_decay` and `learning_rate` settings. This is useful for tasks that, for example, use lower learning rates for certain parts of the model or exclude norm layers from weight decay. The underlying `torch.nn.Module` can be accessed via `model.model`.
- The `OptimizerConfig` has the attributes `lr_interval`, `max_steps`, and `max_epochs`. It additionally receives all attributes provided in the `optimizer` section of the `experiment.yaml`.
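
As a minimal sketch, an AdamW optimizer with cosine annealing could be wired up roughly as follows. The import paths for `GroupedModel`, `OptimizerConfig`, and `OptimizerLRScheduler`, the keyword arguments of `grouped_parameters`, and the `weight_decay` hyperparameter are assumptions here; copy the exact names from the template optimizer.

```python
# Hedged sketch of an optimizer.py; import paths and the grouped_parameters
# signature are assumptions -- take the real ones from the template optimizer.
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR


def configure_optimizers(model: "GroupedModel", config: "OptimizerConfig") -> "OptimizerLRScheduler":
    # Use the grouped parameters so that per-group learning_rate and
    # weight_decay settings (e.g. no weight decay on norm layers) are respected.
    # NOTE: the keyword arguments below are hypothetical.
    parameter_groups = model.grouped_parameters(
        lr=config.learning_rate,
        weight_decay=config.weight_decay,  # assumed to be set in default.yaml
    )
    optimizer = AdamW(parameter_groups, lr=config.learning_rate)

    # max_steps and lr_interval are provided by the OptimizerConfig.
    scheduler = CosineAnnealingLR(optimizer, T_max=config.max_steps)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": config.lr_interval,  # typically "step" or "epoch"
        },
    }
```

Lightning also accepts returning just the optimizer if no scheduler is needed; see the Lightning documentation for the full set of accepted return formats.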

### `default.yaml`

Here you can provide default values for all the hyperparameters your optimizer needs. These values are added to the `OptimizerConfig` passed to `configure_optimizers`. So if you have the following `default.yaml`:

```yaml
optimizer:
  name: my_awesome_optimizer
  output_dir_name: my_awesome_optimizer
  learning_rate: 1.e-3
  important:
    extra:
      parameter: 42
```

you could use `config.important.extra.parameter` in the `configure_optimizers` function.
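
For illustration, inside `configure_optimizers` the values from this example would then be readable via attribute access (the nesting mirrors the YAML above):

```python
def configure_optimizers(model, config):
    # Defaults from default.yaml, possibly overridden by experiment.yaml:
    lr = config.learning_rate                  # 1e-3 unless overridden
    magic = config.important.extra.parameter   # 42
    ...
```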