With PGPE, Adam's step counter t only increases once every `self._lr_decay_steps` steps. This means `mhat` and `vhat` will not work as moving averages, because the bias-correction term `(1 - jnp.asarray(b1, m.dtype) ** (i + 1))` always stays very small. (Below is the Adam update code.)
```python
def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v
```
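To make the effect concrete, here is a small sketch using the Adam defaults b1 = 0.9, b2 = 0.999 and a made-up `self._lr_decay_steps` of 100 (the actual configured value may differ):

```python
b1, b2 = 0.9, 0.999      # Adam defaults
lr_decay_steps = 100     # hypothetical value, for illustration only

t = 50                   # 50 real PGPE iterations...
i = t // lr_decay_steps  # ...but the counter Adam sees is still 0

print(1 - b1 ** (i + 1))  # ~0.1   (should be 1 - 0.9**51   ~= 0.995)
print(1 - b2 ** (i + 1))  # ~0.001 (should be 1 - 0.999**51 ~= 0.050)
```

So `mhat` is inflated by roughly 10x and `vhat` by roughly 1000x for as long as i stays at 0, instead of the corrections fading out as Adam assumes.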
This effectively deactivates `decay_coef`, and not a single config in the PGPE configs changes this.
So there is no easy way to check whether the proposed change is a regression. But when I created the C++ PGPE implementation wrapped in fpgpec.py, I implemented the C++ code exactly as you propose now; see pgpe.cpp. Do you have a benchmark problem that actually uses `decay_coef`?
### Bug
When using Adam with PGPE, this code means Adam's t will increase only once every `self._lr_decay_steps`. And it means `mhat` and `vhat` will not work as moving averages, because `(1 - jnp.asarray(b1, m.dtype) ** (i + 1))` will always be very small (the Adam update code is quoted above).

### Suggestion
I think it is better to change this code to
and to remove `self._lr_decay_steps` at
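The linked snippets are not reproduced here, so the following is only a hypothetical sketch of the kind of change being suggested, assuming the optimizer is `jax.example_libraries.optimizers.adam` and using made-up values for the learning-rate settings:

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

lr, decay_coef = 0.01, 0.999  # hypothetical values, for illustration only

# Pass the raw iteration counter t to opt_update so the bias correction
# sees every step; any learning-rate decay belongs in the step_size
# schedule, not in Adam's step index.
step_size = lambda t: lr * decay_coef ** t
opt_init, opt_update, get_params = optimizers.adam(step_size)

params = jnp.zeros(3)
opt_state = opt_init(params)
for t in range(100):
    grad = jnp.ones(3)                          # placeholder gradient
    opt_state = opt_update(t, grad, opt_state)  # raw t, no // lr_decay_steps
params = get_params(opt_state)
```

With the raw counter, `(1 - b1 ** (t + 1))` approaches 1 as Adam assumes, and `decay_coef` still takes effect through `step_size`; whether the decay should apply every step or only once every `lr_decay_steps` steps is exactly the behavioral question raised in the comment above.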