[Fix] VFI: update param_scheduler, hooks, and epoch_base_runner
Yshuo-Li committed Jul 18, 2022
1 parent 5dc5589 commit 8eb5b5d
Showing 6 changed files with 25 additions and 42 deletions.
25 changes: 11 additions & 14 deletions  configs/video_interpolators/cain/cain_b5_g1b32_vimeo90k_triplet.py

@@ -120,11 +120,7 @@
 ]
 test_evaluator = val_evaluator
 
-# 1604 iters == 1 epoch
-epoch_length = 1604
-
-train_cfg = dict(
-    type='IterBasedTrainLoop', max_iters=300_000, val_interval=epoch_length)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 
@@ -138,23 +134,24 @@
 # learning policy
 param_scheduler = dict(
     type='ReduceLR',
-    by_epoch=False,
+    by_epoch=True,
     mode='min',
     factor=0.5,
     patience=5,
-    cooldown=0,
-    verbose=True)
+    cooldown=0)
 
 default_hooks = dict(
     checkpoint=dict(
-        type='CheckpointHook',
-        interval=epoch_length * 4,
-        save_optimizer=True,
-        by_epoch=False),
+        type='CheckpointHook', interval=1, save_optimizer=True, by_epoch=True),
     timer=dict(type='IterTimerHook'),
-    logger=dict(type='LoggerHook', interval=1),
+    logger=dict(type='LoggerHook', interval=100),
     sampler_seed=dict(type='DistSamplerSeedHook'),
     # visualization=dict(type='EditVisualizationHook'),
     param_scheduler=dict(
-        type='ReduceLRSchedulerHook', by_epoch=False, val_metric='MAE'),
+        type='ReduceLRSchedulerHook',
+        by_epoch=True,
+        interval=1,
+        val_metric='MAE'),
 )
+
+log_processor = dict(type='LogProcessor', by_epoch=True)
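Under the new EpochBasedTrainLoop, validation, checkpointing, and the LR hook above all tick once per epoch, and ReduceLR halves the learning rate (factor=0.5) after patience=5 epochs without an improvement in validation MAE. For intuition only, here is a self-contained toy re-implementation of that plateau policy; it is not the actual ReduceLR/ReduceLRSchedulerHook code, and the MAE curve and base lr below are hypothetical:

def plateau_schedule(maes, lr, factor=0.5, patience=5, cooldown=0):
    """Toy reduce-on-plateau: cut lr by `factor` after `patience`
    epochs without a new best MAE (mode='min')."""
    best, bad_epochs, cooling = float('inf'), 0, 0
    for epoch, mae in enumerate(maes):
        if mae < best:
            best, bad_epochs = mae, 0
        elif cooling > 0:
            cooling -= 1          # cooldown epochs are not counted as "bad"
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                lr *= factor
                bad_epochs, cooling = 0, cooldown
        print(f'epoch {epoch:2d}  MAE={mae:.3f}  lr={lr:.1e}')
    return lr

# Improvement stalls after epoch 1, so once patience (5 epochs) is
# exhausted the lr is halved.
plateau_schedule([0.30, 0.28, 0.28, 0.28, 0.28, 0.28, 0.28, 0.28], lr=1e-4)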
(second changed file; file name not preserved in this extract)

@@ -131,10 +131,7 @@
 ]
 test_evaluator = val_evaluator
 
-epoch_length = 2020
-
-train_cfg = dict(
-    type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=epoch_length)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 
@@ -146,30 +143,26 @@
 ))
 
 # learning policy
-# 1604 iters == 1 epoch
-total_iters = 1000000
-lr_config = dict(
+param_scheduler = dict(
     type='ReduceLR',
-    by_epoch=False,
+    by_epoch=True,
     mode='min',
     factor=0.5,
     patience=10,
-    cooldown=20,
-    verbose=True)
+    cooldown=20)
 
 default_hooks = dict(
    checkpoint=dict(
-        type='CheckpointHook',
-        interval=epoch_length * 2,
-        save_optimizer=True,
-        by_epoch=False),
+        type='CheckpointHook', interval=1, save_optimizer=True, by_epoch=True),
     timer=dict(type='IterTimerHook'),
     logger=dict(type='LoggerHook', interval=100),
     sampler_seed=dict(type='DistSamplerSeedHook'),
     # visualization=dict(type='EditVisualizationHook'),
     param_scheduler=dict(
         type='ReduceLRSchedulerHook',
-        by_epoch=False,
-        interval=epoch_length,
+        by_epoch=True,
+        interval=1,
         val_metric='MAE'),
 )
+
+log_processor = dict(type='LogProcessor', by_epoch=True)
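A quick sanity check (arithmetic added here for reference, not part of the commit): the removed epoch_length = 2020 puts the old 1,000,000-iteration budget at almost exactly the 500 epochs the new config requests:

iters_per_epoch = 2020           # the removed 'epoch_length = 2020'
old_max_iters = 1_000_000        # the removed IterBasedTrainLoop budget
print(old_max_iters / iters_per_epoch)  # ~495 epochs, so max_epochs=500 keeps the budget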
(third changed file; file name not preserved in this extract)

@@ -89,7 +89,7 @@
 ))
 
 # learning policy
-lr_config = dict(
+param_scheduler = dict(
     type='StepLR',
     by_epoch=False,
     gamma=0.5,
10 changes: 1 addition & 9 deletions  mmedit/optimizer/scheduler/reduce_lr_scheduler.py

@@ -47,8 +47,6 @@ class ReduceLR(_ParamScheduler):
         eps (float, optional): Minimal decay applied to lr. If the difference
             between new and old lr is smaller than eps, the update is
             ignored. Default: 1e-8.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
         begin (int): Step at which to start updating the learning rate.
             Defaults to 0.
         end (int): Step at which to stop updating the learning rate.
@@ -68,7 +66,6 @@ def __init__(self,
                 cooldown: int = 0,
                 min_lr: float = 0.,
                 eps: float = 1e-8,
-                verbose: bool = False,
                 **kwargs):
 
         super().__init__(optimizer=optimizer, param_name='lr', **kwargs)
@@ -99,7 +96,6 @@ def __init__(self,
         self.mode_worse = None  # the worse value for the chosen mode
         self.min_lr = min_lr
         self.eps = eps
-        self.verbose = verbose
         self.last_epoch = 0
         self._init_is_better(self.mode)
         self._reset()
@@ -130,11 +126,7 @@ def _get_value(self):
         for group in self.optimizer.param_groups:
             regular_lr = group[self.param_name]
             if regular_lr - regular_lr * self.factor > self.eps:
-                new_lr = max(regular_lr * self.factor, self.min_lr)
-                if self.verbose:
-                    print(f'Reducing learning rate of {group} from '
-                          f'{regular_lr:.4e} to {new_lr:.4e}.')
-                regular_lr = new_lr
+                regular_lr = max(regular_lr * self.factor, self.min_lr)
             results.append(regular_lr)
         return results
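The simplified branch above applies the decay only when it exceeds eps and clamps the result at min_lr. A minimal standalone sketch of that rule, with hypothetical input values:

def reduce_lr(lr, factor=0.5, min_lr=0.0, eps=1e-8):
    # Mirrors the plateau step in ReduceLR._get_value above.
    if lr - lr * factor > eps:
        return max(lr * factor, min_lr)
    return lr  # the decay would be smaller than eps, so keep the current lr

print(reduce_lr(1e-4))  # 5e-05
print(reduce_lr(1e-9))  # 1e-09 (unchanged: the step is below eps)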
4 changes: 3 additions & 1 deletion  mmedit/registry.py

@@ -78,13 +78,15 @@ def register_all_modules(init_default_scope: bool = True) -> None:
             Defaults to True.
     """  # noqa
     import mmedit.datasets  # noqa: F401,F403
+    import mmedit.hooks  # noqa: F401,F403
     import mmedit.metrics  # noqa: F401,F403
     import mmedit.models  # noqa: F401,F403
+    import mmedit.optimizer  # noqa: F401,F403
     import mmedit.transforms  # noqa: F401,F403
 
     if init_default_scope:
         never_created = DefaultScope.get_current_instance() is None \
-            or not DefaultScope.check_instance_created('mmedit')
+                        or not DefaultScope.check_instance_created('mmedit')
         if never_created:
             DefaultScope.get_instance('mmedit', scope_name='mmedit')
         return
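A typical call, given the signature above; with the two new imports, hook and optimizer-scheduler registrations now happen here as well:

from mmedit.registry import register_all_modules

# Populates every mmedit registry (datasets, hooks, metrics, models,
# optimizer schedulers, transforms) and sets 'mmedit' as the default scope.
register_all_modules(init_default_scope=True)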
1 change: 0 additions & 1 deletion  tools/dist_train.sh

@@ -16,5 +16,4 @@ python -m torch.distributed.launch \
     --master_port=$PORT \
     $(dirname "$0")/train.py \
     $CONFIG \
-    --seed 0 \
     --launcher pytorch ${@:3}
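With --seed 0 no longer hard-coded in the script, a fixed seed has to be passed explicitly through the trailing ${@:3}, for example: bash tools/dist_train.sh $CONFIG 8 --seed 0 (assuming train.py still accepts a --seed flag).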
