Add types optim module (#1942)
* init

* more types

* refine hints

* attempt keep TypeVar

* TypeVar -> Any
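
The last two bullets record the key design call: the module-level TypeVar aliases were replaced with Any, presumably because `_NumPyroOptim` is not a generic class, so a bare TypeVar cannot tie the parameter and state types together across its methods. Below is a minimal usage sketch of the typed wrapper API that results; the step size, parameters, and loss are illustrative and not taken from this commit.

import jax.numpy as jnp
from numpyro.optim import Adam

def loss_fn(params):
    # fn passed to eval_and_update must return a (value, aux) tuple
    return jnp.sum((params["w"] - 3.0) ** 2), None

optimizer = Adam(step_size=0.1)
state = optimizer.init({"w": jnp.zeros(2)})   # state: _IterOptState, a (step, opt_state) pair
(loss, _), state = optimizer.eval_and_update(loss_fn, state)
params = optimizer.get_params(state)          # params: _Params, the parameter pytree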
juanitorduz authored Jan 3, 2025
1 parent e5fc673 commit 6ae76ea
Showing 2 changed files with 49 additions and 30 deletions.
78 changes: 48 additions & 30 deletions numpyro/optim.py
@@ -9,7 +9,7 @@

from collections import namedtuple
from collections.abc import Callable
from typing import Any, TypeVar
from typing import Any

import jax
from jax import jacfwd, lax, value_and_grad
@@ -18,6 +18,7 @@
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.tree_util import register_pytree_node
from jax.typing import ArrayLike

__all__ = [
"Adam",
@@ -31,12 +32,12 @@
"SM3",
]

_Params = TypeVar("_Params")
_OptState = TypeVar("_OptState")
_IterOptState = tuple[int, _OptState]
_Params = Any
_OptState = Any
_IterOptState = tuple[ArrayLike, _OptState]


def _value_and_grad(f, x, forward_mode_differentiation=False):
def _value_and_grad(f, x, forward_mode_differentiation=False) -> tuple:
if forward_mode_differentiation:

def _wrapper(x):
@@ -51,6 +52,9 @@ def _wrapper(x):

class _NumPyroOptim(object):
def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
self.init_fn: Callable[[_Params], _IterOptState]
self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState]
self.get_params_fn: Callable[[_OptState], _Params]
self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)

def init(self, params: _Params) -> _IterOptState:
@@ -80,7 +84,7 @@ def eval_and_update(
fn: Callable[[Any], tuple],
state: _IterOptState,
forward_mode_differentiation: bool = False,
):
) -> tuple[tuple[Any, Any], _IterOptState]:
"""
Performs an optimization step for the objective function `fn`.
For most optimizers, the update is performed based on the gradient
@@ -96,7 +100,7 @@ def eval_and_update(
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
:return: a pair of the output of objective function and the new optimizer state.
"""
params = self.get_params(state)
params: _Params = self.get_params(state)
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
@@ -107,7 +111,7 @@ def eval_and_stable_update(
fn: Callable[[Any], tuple],
state: _IterOptState,
forward_mode_differentiation: bool = False,
):
) -> tuple[tuple[Any, Any], _IterOptState]:
"""
Like :meth:`eval_and_update` but when the value of the objective function
or the gradients are not finite, we will not update the input `state`
@@ -118,7 +122,7 @@ def eval_and_stable_update(
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
:return: a pair of the output of objective function and the new optimizer state.
"""
params = self.get_params(state)
params: _Params = self.get_params(state)
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
@@ -141,7 +145,7 @@ def get_params(self, state: _IterOptState) -> _Params:
return self.get_params_fn(opt_state)


def _add_doc(fn):
def _add_doc(fn) -> Callable[[Any], Any]:
def _wrapped(cls):
cls.__doc__ = "Wrapper class for the JAX optimizer: :func:`~jax.example_libraries.optimizers.{}`".format(
fn.__name__
@@ -153,7 +157,7 @@ def _wrapped(cls):

@_add_doc(optimizers.adam)
class Adam(_NumPyroOptim):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(Adam, self).__init__(optimizers.adam, *args, **kwargs)


@@ -170,11 +174,11 @@ class ClippedAdam(_NumPyroOptim):
https://arxiv.org/abs/1412.6980
"""

def __init__(self, *args, clip_norm=10.0, **kwargs):
def __init__(self, *args, clip_norm: float = 10.0, **kwargs) -> None:
self.clip_norm = clip_norm
super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs)

def update(self, g, state):
def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
i, opt_state = state
# clip norm
g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
@@ -184,39 +188,39 @@ def update(self, g, state):

@_add_doc(optimizers.adagrad)
class Adagrad(_NumPyroOptim):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(Adagrad, self).__init__(optimizers.adagrad, *args, **kwargs)


@_add_doc(optimizers.momentum)
class Momentum(_NumPyroOptim):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(Momentum, self).__init__(optimizers.momentum, *args, **kwargs)


@_add_doc(optimizers.rmsprop)
class RMSProp(_NumPyroOptim):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(RMSProp, self).__init__(optimizers.rmsprop, *args, **kwargs)


@_add_doc(optimizers.rmsprop_momentum)
class RMSPropMomentum(_NumPyroOptim):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(RMSPropMomentum, self).__init__(
optimizers.rmsprop_momentum, *args, **kwargs
)


@_add_doc(optimizers.sgd)
class SGD(_NumPyroOptim):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(SGD, self).__init__(optimizers.sgd, *args, **kwargs)


@_add_doc(optimizers.sm3)
class SM3(_NumPyroOptim):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(SM3, self).__init__(optimizers.sm3, *args, **kwargs)


@@ -225,24 +229,36 @@ def __init__(self, *args, **kwargs):
# and pass `unravel_fn` around.
# When arbitrary pytree is supported in JAX, we can just simply use
# identity functions for `init_fn` and `get_params`.
_MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"])
class _MinimizeState(namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])):
flat_params: ArrayLike
unravel_fn: Callable[[ArrayLike], _Params]


register_pytree_node(
_MinimizeState,
lambda state: ((state.flat_params,), (state.unravel_fn,)),
lambda data, xs: _MinimizeState(xs[0], data[0]),
)


def _minimize_wrapper():
def init_fn(params):
def _minimize_wrapper() -> (
tuple[
Callable[[_Params], _MinimizeState],
Callable[[Any, Any, _MinimizeState], _MinimizeState],
Callable[[_MinimizeState], _Params],
]
):
def init_fn(params: _Params) -> _MinimizeState:
flat_params, unravel_fn = ravel_pytree(params)
return _MinimizeState(flat_params, unravel_fn)

def update_fn(i, grad_tree, opt_state):
def update_fn(
i: ArrayLike, grad_tree: ArrayLike, opt_state: _MinimizeState
) -> _MinimizeState:
# we don't use update_fn in Minimize, so let it do nothing
return opt_state

def get_params(opt_state):
def get_params(opt_state: _MinimizeState) -> _Params:
flat_params, unravel_fn = opt_state
return unravel_fn(flat_params)

@@ -289,7 +305,7 @@ class Minimize(_NumPyroOptim):
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)
"""

def __init__(self, method="BFGS", **kwargs):
def __init__(self, method="BFGS", **kwargs) -> None:
super().__init__(_minimize_wrapper)
self._method = method
self._kwargs = kwargs
@@ -298,8 +314,8 @@ def eval_and_update(
self,
fn: Callable[[Any], tuple],
state: _IterOptState,
forward_mode_differentiation=False,
):
forward_mode_differentiation: bool = False,
) -> tuple[tuple[Any, None], _IterOptState]:
i, (flat_params, unravel_fn) = state

def loss_fn(x):
@@ -333,17 +349,19 @@ def optax_to_numpyro(transformation) -> _NumPyroOptim:
"""
import optax

def init_fn(params):
def init_fn(params: _Params) -> tuple[_Params, Any]:
opt_state = transformation.init(params)
return params, opt_state

def update_fn(step, grads, state):
def update_fn(
step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any]
) -> tuple[_Params, Any]:
params, opt_state = state
updates, opt_state = transformation.update(grads, opt_state, params)
updated_params = optax.apply_updates(params, updates)
return updated_params, opt_state

def get_params_fn(state):
def get_params_fn(state: tuple[_Params, Any]) -> _Params:
params, _ = state
return params

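The annotated optax_to_numpyro adapter above keeps a (params, opt_state) pair as its internal state. Roughly how it is used, under the same illustrative assumptions as before; the particular optax chain is an example, not something this commit prescribes.

import jax.numpy as jnp
import optax
from numpyro.optim import optax_to_numpyro

optimizer = optax_to_numpyro(optax.chain(optax.clip(10.0), optax.adam(1e-3)))
state = optimizer.init({"w": jnp.zeros(3)})
# From here it exposes the same typed init / eval_and_update / get_params
# interface as the jax.example_libraries wrappers earlier in the file.
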
1 change: 1 addition & 0 deletions pyproject.toml
@@ -116,6 +116,7 @@ module = [
"numpyro.contrib.stochastic_support.*",
"numpyro.diagnostics.*",
"numpyro.handlers.*",
"numpyro.optim.*",
"numpyro.primitives.*",
"numpyro.patch.*",
"numpyro.util.*",
