This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Implement Adagrad optimizer #19

Open · wants to merge 2 commits into master
Changes from all commits
70 changes: 70 additions & 0 deletions ngraph/frontends/neon/optimizer.py
@@ -363,3 +363,73 @@ def variable_update(self, variable, grad, scale_factor):
variable - (scale_factor * self.ell * m) / (ng.sqrt(v) + self.epsilon))
])
return updates


class Adagrad(LearningRateOptimizer):
"""
Contributor:
Could you add a docstring? You can use this one as a good reference: https://github.com/NervanaSystems/neon/blob/master/neon/optimizers/optimizer.py#L637

Contributor Author:
Sure, I'll do that.

Contributor Author:
Done. Updated. Thanks.

Adagrad optimization algorithm.

Adagrad is an algorithm that adapts the learning rate individually for each parameter
by dividing by the :math:`L_2`-norm of all previous gradients. Given the parameters
:math:`\\theta`, gradient :math:`\\nabla J`, accumulating norm :math:`G`, and smoothing
factor :math:`\\epsilon`, we use the update equations:

.. math::

G' = G + (\\nabla J)^2

.. math::

\\theta' = \\theta - \\frac{\\alpha}{\\sqrt{G' + \\epsilon}} \\nabla J

where the smoothing factor :math:`\\epsilon` prevents division by zero.
By adjusting the learning rate individually for each parameter, Adagrad adapts
to the geometry of the error surface, so differently scaled weights receive
appropriately scaled update steps.

Example usage:

.. code-block:: python

import ngraph as ng
from ngraph.frontends.neon.optimizers import Adagrad

# use Adagrad with a learning rate of 1e-3
optimizer = Adagrad(learning_rate=1e-3, epsilon=1e-8)
"""
metadata = {'layer_type': 'adagrad_optimizer'}

def __init__(
self,
learning_rate=1e-3,
epsilon=1e-8,
gradient_clip_norm=None,
gradient_clip_value=None,
**kwargs
):
"""
Class constructor.
Arguments:
learning_rate (float): the multiplication coefficient of updates
epsilon (float): numerical stability factor
gradient_clip_norm (float, optional): Target gradient norm.
Defaults to None.
gradient_clip_value (float, optional): Value to element-wise clip
gradients.
Defaults to None.
"""
super(Adagrad, self).__init__(learning_rate, **kwargs)
self.epsilon = epsilon
self.gradient_clip_norm = gradient_clip_norm
self.gradient_clip_value = gradient_clip_value

def variable_update(self, variable, grad, scale_factor):
grad = clip_gradient_value(grad, self.gradient_clip_value)
state = ng.persistent_tensor(axes=grad.axes, initial_value=0.)
updates = ng.sequential([
ng.assign(state, state + ng.square(grad)),
ng.assign(variable,
variable - (scale_factor * self.lrate * grad)
/ (ng.sqrt(state + self.epsilon)))
])
return updates
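
The update rule documented in the docstring above is easiest to see in isolation. Below is a minimal NumPy sketch (not part of the diff; the values, shapes, and step count are made up) of the same accumulate-and-scale behaviour that variable_update expresses with ng.assign and ng.sequential:

import numpy as np

# Illustrative only: plain-NumPy version of the Adagrad update described above.
learning_rate, epsilon = 0.1, 1e-8
theta = np.array([1.0, 1.0])       # two parameters
G = np.zeros_like(theta)           # accumulated squared gradients

for step in range(3):
    grad = np.array([1.0, 0.1])    # the first parameter sees 10x larger gradients
    G += grad ** 2                 # G' = G + (grad)^2
    theta -= learning_rate * grad / np.sqrt(G + epsilon)
    print(step, theta, learning_rate / np.sqrt(G + epsilon))

Because G only grows, the effective step size learning_rate / sqrt(G + epsilon) decays monotonically, and it decays fastest for the parameter with consistently large gradients, which is the per-parameter adaptation the docstring describes.
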
47 changes: 46 additions & 1 deletion ngraph/frontends/neon/tests/test_optimizer.py
@@ -19,7 +19,8 @@
import pytest
import numpy as np
import ngraph as ng
-from ngraph.frontends.neon import GradientDescentMomentum, RMSProp, Adam, LearningRateOptimizer
+from ngraph.frontends.neon import GradientDescentMomentum, RMSProp, Adam
+from ngraph.frontends.neon import LearningRateOptimizer, Adagrad
from ngraph.testing.execution import ExecutorFactory

pytestmark = pytest.mark.transformer_dependent
@@ -124,6 +125,33 @@ def __call__(self, input_data, weights):
return weights


class AdagradReference(object):
'''
Simple numpy reference for Adagrad
'''
def __init__(self, learning_rate, epsilon):
self.learning_rate = learning_rate
self.epsilon = epsilon
self.state = None

def __call__(self, input_data, weights):
'''
input_data in this case is a numpy array with batch_size on axis 1
and weights is a matrix with 1 column
'''
if self.state is None:
self.state = np.zeros_like(weights)

gradient = - input_data.mean(axis=1)

self.state[:] = self.state + np.square(gradient)

weights[:] = weights \
- gradient * self.learning_rate / (np.sqrt(self.state + self.epsilon))

return weights


def compare_optimizer(opt_ng, opt_ref):

# Set up data placeholders
@@ -279,6 +307,22 @@ def test_adam(random_learning_rate, random_beta_1, random_beta_2, epsilon, selec
compare_optimizer(adam, adam_reference)


@pytest.mark.parametrize("epsilon", [1e-6])
@pytest.mark.parametrize("select_variables", [False, True])
def test_adagrad(random_learning_rate, epsilon, select_variables):
adagrad_args = {'learning_rate': random_learning_rate,
'epsilon': epsilon}

adagrad_ref = AdagradReference(**adagrad_args)
adagrad = Adagrad(**adagrad_args)

# test baseline against reference
if select_variables:
compare_optimizer_variable_select(adagrad, adagrad_ref)
else:
compare_optimizer(adagrad, adagrad_ref)


@pytest.config.argon_disabled # TODO triage
@pytest.config.flex_disabled(reason="Unknown problem yet")
def test_learning_policy_step():
@@ -366,3 +410,4 @@ def test_learning_policy_schedule(drop_factor):
test_rmsprop(0.1, 0.95, 1e-6)
test_gdm(0.1, 0.1, 0.1, False)
test_adam(0.1, 0.5, 0.9, 1e-6, None)
test_adagrad(0.1, 1e-6)
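
Outside of the ExecutorFactory-based comparison, the NumPy reference above can also be exercised by hand. The sketch below is illustrative only: it assumes AdagradReference (from the test diff above) is in scope, and the 1-D weight shape and random inputs are assumptions rather than what compare_optimizer actually feeds it:

import numpy as np

# Hypothetical, hand-rolled use of AdagradReference; shapes are assumptions.
ref = AdagradReference(learning_rate=0.1, epsilon=1e-6)

rng = np.random.RandomState(0)
weights = rng.randn(4)                  # four weights, 1-D for simplicity
for _ in range(5):
    input_data = rng.randn(4, 8)        # batch of 8 samples on axis 1
    weights = ref(input_data, weights)  # one Adagrad step, applied in place

print(weights)

Each call accumulates the squared mean gradient in ref.state, so repeated calls with similar inputs produce progressively smaller weight changes.
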