-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathwarmup.py
73 lines (62 loc) · 2.49 KB
/
warmup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Copyright (C) eqtgroup.com Ltd 2021
https://github.com/EQTPartners/pause
License: MIT, https://github.com/EQTPartners/pause/LICENSE.md
"""
import tensorflow as tf
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with a power-law warm-up in front of another schedule.

    While ``step < warmup_steps`` the rate is
    ``initial_learning_rate * (step / warmup_steps) ** power``. From
    ``warmup_steps`` onward, the wrapped ``decay_schedule_fn`` is called with the
    same, un-shifted ``step`` value.
    """

    def __init__(
        self,
        initial_learning_rate: float,
        decay_schedule_fn: tf.keras.optimizers.schedules.LearningRateSchedule,
        warmup_steps: int,
        power: float = 1.0,
        name: str = None,
    ) -> None:
        """Store the warm-up hyper-parameters.

        Args:
            initial_learning_rate (float): Scale of the warm-up ramp; the ramp
                value would equal it at ``step == warmup_steps``.
            decay_schedule_fn (tf.keras.optimizers.schedules.LearningRateSchedule):
                Schedule used once warm-up is over; it is called with the raw step.
            warmup_steps (int): Number of steps covered by the warm-up ramp.
            power (float, optional): Exponent applied to the warm-up progress
                fraction; 1.0 gives a linear ramp. Defaults to 1.0.
            name (str, optional): Name scope for the created ops; "WarmUp" is
                used when this is falsy. Defaults to None.
        """
        super().__init__()
        self.name = name
        self.decay_schedule_fn = decay_schedule_fn
        self.power = power
        self.warmup_steps = warmup_steps
        self.initial_learning_rate = initial_learning_rate

    def __call__(self, step: int) -> tf.Tensor:
        """Compute the learning rate for a given step.

        Args:
            step (int): The current training step; it is cast to float32 for
                the warm-up arithmetic.

        Returns:
            tf.Tensor: The warm-up rate while ``step < warmup_steps``, otherwise
            the value returned by ``decay_schedule_fn(step)``.
        """
        scope_label = self.name or "WarmUp"
        with tf.name_scope(scope_label) as scope:
            step_f32 = tf.cast(step, tf.float32)
            ramp_len_f32 = tf.cast(self.warmup_steps, tf.float32)
            # The ramp value is built before the branch choice. With
            # warmup_steps == 0 it can be inf/nan, but for non-negative steps the
            # predicate below is False, so the decay branch is taken instead.
            progress = step_f32 / ramp_len_f32
            ramp_rate = self.initial_learning_rate * tf.math.pow(progress, self.power)
            still_warming_up = step_f32 < ramp_len_f32
            return tf.cond(
                still_warming_up,
                true_fn=lambda: ramp_rate,
                false_fn=lambda: self.decay_schedule_fn(step),
                name=scope,
            )

    def get_config(self) -> dict:
        """Return the constructor arguments that describe this schedule.

        Returns:
            dict: Each ``__init__`` parameter name mapped to its stored value.
        """
        # NOTE(review): decay_schedule_fn is returned as the live schedule object
        # rather than a serialized config; confirm this is acceptable wherever
        # this config is persisted.
        return dict(
            initial_learning_rate=self.initial_learning_rate,
            decay_schedule_fn=self.decay_schedule_fn,
            warmup_steps=self.warmup_steps,
            power=self.power,
            name=self.name,
        )