# default.yaml — default pretraining configuration
# (from a fork of SalesforceAIResearch/uni2ts)
---
# Hydra output directory for a single (non-multirun) launch:
# outputs/pretrain/<chosen model config>/<chosen data config>/<run_name>.
hydra:
  run:
    dir: outputs/pretrain/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name}
# Hydra defaults list.
#   model, data : mandatory config-group choices (`???` = must be supplied)
#   val_data    : optional group, nothing selected unless overridden
#   _self_      : listed last, so values in this file override the groups
defaults:
  - model: ???
  - data: ???
  - val_data: null
  - _self_
# Mandatory; also the last component of the Hydra run directory.
run_name: ???
seed: 0
tf32: true
# false disables compilation; otherwise set one of the modes:
# default, reduce-overhead, max-autotune.
compile: false
ckpt_path: null  # set to "last" to resume training
# Arguments for lightning.Trainer, instantiated via `_target_`.
trainer:
  _target_: lightning.Trainer
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 32
  # TensorBoard logs go to <hydra output dir>/logs.
  logger:
    _target_: lightning.pytorch.loggers.TensorBoardLogger
    save_dir: ${hydra:runtime.output_dir}
    name: logs
  callbacks:
    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: epoch
    # Rolling resume checkpoint: one file named `last`, refreshed every
    # 10 epochs. Monitoring `epoch` with mode=max and save_top_k=1 keeps
    # only the most recent one.
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: ${hydra:runtime.output_dir}/checkpoints
      filename: last
      monitor: epoch
      mode: max
      save_top_k: 1
      every_n_epochs: 10
    # Weights-only snapshots, all kept (save_top_k: -1), written every
    # floordiv(max_epochs, 10) epochs.
    # NOTE(review): assuming `floordiv` is integer division, this becomes
    # 0 when max_epochs < 10, and every_n_epochs=0 disables this callback
    # in Lightning — confirm that is acceptable for short runs.
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: ${hydra:runtime.output_dir}/checkpoints
      monitor: epoch
      save_weights_only: true
      mode: max
      save_top_k: -1
      every_n_epochs: ${floordiv:${trainer.max_epochs},10}
    # Project callback; presumably exports a Hugging Face-format copy of
    # the latest model every epoch — confirm against its implementation.
    - _target_: uni2ts.callbacks.HuggingFaceCheckpoint.HuggingFaceCheckpoint
      dirpath: ${hydra:runtime.output_dir}/HF_checkpoints
      filename: last
      monitor: epoch
      mode: max
      save_top_k: 1
      every_n_epochs: 1
  # Epoch-based training provides averaged metrics.
  # Do not use max_steps with epoch-based training: resuming from a
  # checkpoint would restart on the wrong epoch.
  max_epochs: 1000
  enable_progress_bar: true
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: norm
# Training data loader (project class uni2ts.data.loader.DataLoader).
# An epoch is num_batches_per_epoch batches; `cycle: true` presumably
# re-iterates the dataset when exhausted — confirm in the loader.
train_dataloader:
  _target_: uni2ts.data.loader.DataLoader
  batch_size: 128
  batch_size_factor: 2.0
  cycle: true
  num_batches_per_epoch: 100
  shuffle: true
  num_workers: 11
  # Packing collator: max length comes from the model's max_seq_len, and
  # the sequence fields / padding functions are read off the model class
  # through the `cls_getattr` resolver.
  collate_fn:
    _target_: uni2ts.data.loader.PackCollate
    max_length: ${model.module_kwargs.max_seq_len}
    seq_fields: ${cls_getattr:${model._target_},seq_fields}
    pad_func_map: ${cls_getattr:${model._target_},pad_func_map}
  pin_memory: true
  drop_last: true
  fill_last: false
  worker_init_fn: null
  prefetch_factor: 2
  persistent_workers: true