add cosine restart learning rate #2953

Open · wants to merge 3 commits into base: devel
1 change: 1 addition & 0 deletions deepmd/common.py
@@ -125,6 +125,7 @@ def gelu_wrapper(x):
"softplus": tf.nn.softplus,
"sigmoid": tf.sigmoid,
"tanh": tf.nn.tanh,
"swish": tf.nn.swish,
Member:
It seems that it has been renamed to silu: tensorflow/tensorflow#41066

"gelu": gelu,
"gelu_tf": gelu_tf,
"None": None,
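For reference, one way to act on the review comment above about the swish → silu rename while keeping older TensorFlow releases working would be to resolve the function once and register it under both spellings. This is a sketch, not part of this PR; the dict name ACTIVATION_FN_DICT and the selection of entries shown are illustrative.

```python
import tensorflow as tf

# tf.nn.silu exists on newer TensorFlow releases; older ones only expose tf.nn.swish.
_silu_fn = getattr(tf.nn, "silu", tf.nn.swish)

ACTIVATION_FN_DICT = {
    "tanh": tf.nn.tanh,
    "swish": _silu_fn,  # keep the old spelling as an alias
    "silu": _silu_fn,
    # ... remaining activations as in the original mapping ...
}
```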
56 changes: 34 additions & 22 deletions deepmd/train/trainer.py
@@ -58,6 +58,8 @@
)
from deepmd.utils.learning_rate import (
LearningRateExp,
LearningRateCos,
LearningRateCosRestarts,
)
from deepmd.utils.sess import (
run_sess,
@@ -113,13 +115,21 @@ def get_lr_and_coef(lr_param):
scale_lr_coef = np.sqrt(self.run_opt.world_size).real
else:
scale_lr_coef = 1.0
lr_type = lr_param.get("type", "exp")
if lr_type == "exp":
self.lr_type = lr_param.get("type", "exp")
if self.lr_type == "exp":
lr = LearningRateExp(
lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
)
elif self.lr_type == "cos":
lr = LearningRateCos(
lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
)
elif self.lr_type == "cosrestart":
lr = LearningRateCosRestarts(
lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
)
else:
raise RuntimeError("unknown learning_rate type " + lr_type)
raise RuntimeError("unknown learning_rate type " + self.lr_type)
return lr, scale_lr_coef

# learning rate
@@ -553,29 +563,31 @@ def train(self, train_data=None, valid_data=None):
is_first_step = True
self.cur_batch = cur_batch
if not self.multi_task_mode:
log.info(
"start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
% (
run_sess(self.sess, self.learning_rate),
self.lr.value(cur_batch),
self.lr.decay_steps_,
self.lr.decay_rate_,
self.lr.value(stop_batch),
)
)
else:
for fitting_key in self.fitting:
if self.lr_type == "exp":
Member:
Switching on the learning-rate type inside the Trainer is not good practice. Instead, implement a LearningRate.log_start method (LearningRate should be an abstract base class inherited by all learning-rate classes) and call self.lr.log_start(self.sess) here.

log.info(
"%s: start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
"start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
% (
fitting_key,
run_sess(self.sess, self.learning_rate_dict[fitting_key]),
self.lr_dict[fitting_key].value(cur_batch),
self.lr_dict[fitting_key].decay_steps_,
self.lr_dict[fitting_key].decay_rate_,
self.lr_dict[fitting_key].value(stop_batch),
run_sess(self.sess, self.learning_rate),
self.lr.value(cur_batch),
self.lr.decay_steps_,
self.lr.decay_rate_,
self.lr.value(stop_batch),
)
)
else:
for fitting_key in self.fitting:
if self.lr_type == "exp":
log.info(
"%s: start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
% (
fitting_key,
run_sess(self.sess, self.learning_rate_dict[fitting_key]),
self.lr_dict[fitting_key].value(cur_batch),
self.lr_dict[fitting_key].decay_steps_,
self.lr_dict[fitting_key].decay_rate_,
self.lr_dict[fitting_key].value(stop_batch),
)
)

prf_options = None
prf_run_metadata = None
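To make the reviewer's log_start suggestion above concrete, here is a rough sketch of such an abstract base class. The class name, signatures, and logged fields are illustrative rather than taken from the PR; it assumes build() keeps the created tensor on the instance and that log and run_sess are available as elsewhere in trainer.py.

```python
from abc import ABC, abstractmethod


class BaseLearningRate(ABC):
    """Common interface so the Trainer never branches on the schedule type."""

    @abstractmethod
    def build(self, global_step, stop_step=None):
        """Build the learning-rate tensor and keep it as self.lr_tensor_."""

    @abstractmethod
    def value(self, step):
        """Return the learning rate at a given training step."""

    def log_start(self, sess, cur_batch, stop_batch):
        """Log the starting learning rate; subclasses may override to add
        schedule-specific fields such as decay_steps_ or decay_rate_."""
        log.info(
            "start training at lr %.2e (== %.2e), final lr will be %.2e",
            run_sess(sess, self.lr_tensor_),
            self.value(cur_batch),
            self.value(stop_batch),
        )
```

The Trainer could then call self.lr.log_start(self.sess, cur_batch, stop_batch) in both branches and drop the `if self.lr_type == "exp"` checks.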
31 changes: 30 additions & 1 deletion deepmd/utils/argcheck.py
@@ -1010,13 +1010,42 @@ def learning_rate_exp():
]
return args

def learning_rate_cos():
doc_start_lr = "The learning rate the start of the training."
doc_stop_lr = "The desired learning rate at the end of the training."
doc_decay_steps = (
"Number of steps to decay over."
)

args = [
Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
Argument("decay_steps", int, optional=True, default=100000, doc=doc_decay_steps),
]
return args

def learning_rate_cosrestarts():
doc_start_lr = "The learning rate the start of the training."
doc_stop_lr = "The desired learning rate at the end of the training."
doc_decay_steps = (
"Number of steps to decay over of the first decay."
)

args = [
Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
Argument("decay_steps", int, optional=True, default=10000, doc=doc_decay_steps),
]
return args

def learning_rate_variant_type_args():
doc_lr = "The type of the learning rate."

return Variant(
"type",
[Argument("exp", dict, learning_rate_exp())],
[Argument("exp", dict, learning_rate_exp()),
Argument("cos", dict, learning_rate_cos()),
Argument("cosrestart", dict, learning_rate_cosrestarts())],
Member:
You may need to add some documentation to the variants (doc="xxx"); otherwise, no one knows what they are.

optional=True,
default_tag="exp",
doc=doc_lr,
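Two hedged sketches related to this file; wording, doc strings, and values are illustrative, not from the PR. First, the documentation the reviewer asks for could be attached directly to the new variants, since Argument already takes a doc keyword elsewhere in argcheck.py:

```python
def learning_rate_variant_type_args():
    doc_lr = "The type of the learning rate."
    return Variant(
        "type",
        [
            Argument("exp", dict, learning_rate_exp(),
                     doc="Exponentially decaying learning rate."),
            Argument("cos", dict, learning_rate_cos(),
                     doc="Cosine-decaying learning rate."),
            Argument("cosrestart", dict, learning_rate_cosrestarts(),
                     doc="Cosine-decaying learning rate with warm restarts."),
        ],
        optional=True,
        default_tag="exp",
        doc=doc_lr,
    )
```

Second, a user would then select one of the new schedules from the learning_rate section of the training input; expressed as a Python dict with the defaults declared above:

```python
learning_rate = {
    "type": "cosrestart",    # or "cos" for the non-restarting schedule
    "start_lr": 1.0e-3,
    "stop_lr": 1.0e-8,
    "decay_steps": 10000,    # length of the first cosine period
}
```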
171 changes: 171 additions & 0 deletions deepmd/utils/learning_rate.py
@@ -105,3 +105,174 @@ def start_lr(self) -> float:
def value(self, step: int) -> float:
"""Get the lr at a certain step."""
return self.start_lr_ * np.power(self.decay_rate_, (step // self.decay_steps_))

class LearningRateCos:
r"""The cosine decaying learning rate.

The function returns the decayed learning rate. It is computed as:
```python
global_step = min(global_step, decay_steps)
cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
decayed = (1 - alpha) * cosine_decay + alpha
decayed_learning_rate = learning_rate * decayed
```

Parameters
----------
start_lr
Starting learning rate
stop_lr
The minimum learning rate at the end of the decay; the cosine alpha is computed as stop_lr / start_lr.
decay_steps
Number of steps to decay over.
"""

def __init__(
self,
start_lr: float,
stop_lr: float = 5e-8,
decay_steps: int = 100000,
) -> None:
"""Constructor."""
self.cd = {}
self.cd["start_lr"] = start_lr
self.cd["stop_lr"] = stop_lr
self.cd["decay_steps"] = decay_steps
self.start_lr_ = self.cd["start_lr"]
self.alpha_ = self.cd["stop_lr"]/self.cd["start_lr"]

def build(
self, global_step: tf.Tensor, stop_step: Optional[int] = None
) -> tf.Tensor:
"""Build the learning rate.

Parameters
----------
global_step
The tf Tensor providing the global training step
stop_step
The stop step.

Returns
-------
learning_rate
The learning rate
"""
if stop_step is None:
self.decay_steps_ = (
self.cd["decay_steps"] if self.cd["decay_steps"] is not None else 100000
)
else:
self.stop_lr_ = (
self.cd["stop_lr"] if self.cd["stop_lr"] is not None else 5e-8
)
self.decay_steps_ = (
self.cd["decay_steps"]
if self.cd["decay_steps"] is not None
else stop_step
)

return tf.train.cosine_decay(
self.start_lr_,
global_step,
self.decay_steps_,
self.alpha_,
name="cosine",
)

def start_lr(self) -> float:
"""Get the start lr."""
return self.start_lr_

def value(self, step: int) -> float:
"""Get the lr at a certain step."""
step = min(step, self.decay_steps_)
cosine_decay = 0.5 * (1 + np.cos(np.pi * step / self.decay_steps_))
decayed = (1 - self.alpha_) * cosine_decay + self.alpha_
decayed_learning_rate = self.start_lr_ * decayed
return decayed_learning_rate
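As a quick sanity check of the closed form in value() above (illustrative numbers; decay_steps_ is assigned by hand here because build() is what normally sets it):

```python
lr = LearningRateCos(start_lr=1.0e-3, stop_lr=1.0e-8, decay_steps=100000)
lr.decay_steps_ = 100000      # normally assigned inside build()
print(lr.value(0))            # 1.0e-3: cosine_decay = 1, decayed = 1
print(lr.value(50000))        # ~5.0e-4: cosine_decay = 0.5
print(lr.value(100000))       # 1.0e-8: cosine_decay = 0, decayed = alpha
```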


class LearningRateCosRestarts:
r"""The cosine decaying restart learning rate.

The function returns the cosine decayed learning rate while taking into account
possible warm restarts.
```
Member:
This line should be removed.


Parameters
----------
start_lr
Starting learning rate
stop_lr
The minimum learning rate at the end of each decay period; the cosine alpha is computed as stop_lr / start_lr.
decay_steps
Number of steps to decay over.
"""

def __init__(
self,
start_lr: float,
stop_lr: float = 5e-8,
decay_steps: int = 10000,
) -> None:
"""Constructor."""
self.cd = {}
self.cd["start_lr"] = start_lr
self.cd["stop_lr"] = stop_lr
self.cd["decay_steps"] = decay_steps
self.start_lr_ = self.cd["start_lr"]
self.alpha_ = self.cd["stop_lr"]/self.cd["start_lr"]

def build(
self, global_step: tf.Tensor, stop_step: Optional[int] = None
) -> tf.Tensor:
"""Build the learning rate.

Parameters
----------
global_step
The tf Tensor providing the global training step
stop_step
The stop step.

Returns
-------
learning_rate
The learning rate
"""
if stop_step is None:
self.decay_steps_ = (
self.cd["decay_steps"] if self.cd["decay_steps"] is not None else 10000
)
else:
self.stop_lr_ = (
self.cd["stop_lr"] if self.cd["stop_lr"] is not None else 5e-8
)
self.decay_steps_ = (
self.cd["decay_steps"]
if self.cd["decay_steps"] is not None
else stop_step
)

return tf.train.cosine_decay_restarts(
learning_rate=self.start_lr_,
global_step=global_step,
first_decay_steps=self.decay_steps_,
alpha=self.alpha_,
name="cosinerestart",
)

def start_lr(self) -> float:
"""Get the start lr."""
return self.start_lr_

def value(self, step: int) -> float:
Collaborator:
You may not need to implement the value method if you do not print the learning-rate information at the beginning of the training:
https://github.com/hellozhaoming/deepmd-kit/blob/05052c195308f61b63ce2bab130ce0e8cba60604/deepmd/train/trainer.py#L566

"""Get the lr at a certain step. Need to revise later"""
step = min(step, self.decay_steps_)
cosine_decay = 0.5 * (1 + np.cos(np.pi * step / self.decay_steps_))
decayed = (1 - self.alpha_) * cosine_decay + self.alpha_
decayed_learning_rate = self.start_lr_ * decayed
return decayed_learning_rate
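Following up on the comment above: if the logged value should also reflect the warm restarts, the arithmetic of tf.train.cosine_decay_restarts (with the defaults t_mul=2.0 and m_mul=1.0 that build() leaves in place) can be reproduced with NumPy. A hedged sketch of a replacement for this value method, not verified against TensorFlow output:

```python
def value(self, step: int) -> float:
    """Get the lr at a certain step, accounting for warm restarts.

    Assumes t_mul=2.0 and m_mul=1.0, the tf.train.cosine_decay_restarts
    defaults used when build() does not override them.
    """
    t_mul, m_mul = 2.0, 1.0
    completed = step / self.decay_steps_
    # Index of the restart period the step falls in; period lengths grow by t_mul.
    i_restart = np.floor(
        np.log(1.0 - completed * (1.0 - t_mul)) / np.log(t_mul)
    )
    # Fraction of the current period that has elapsed.
    sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
    completed = (completed - sum_r) / t_mul**i_restart
    cosine_decayed = 0.5 * m_mul**i_restart * (1.0 + np.cos(np.pi * completed))
    decayed = (1.0 - self.alpha_) * cosine_decayed + self.alpha_
    return self.start_lr_ * decayed
```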