diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py
index 4b85d5a3d4..6339f0392c 100644
--- a/flax/nnx/training/optimizer.py
+++ b/flax/nnx/training/optimizer.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 from __future__ import annotations
+import typing as tp
+
 import jax
 import jax.numpy as jnp
 import optax
@@ -23,6 +25,8 @@
 from flax.nnx.object import Object
 from flax.nnx.variablelib import Variable, VariableState
 
+M = tp.TypeVar('M', bound=nnx.Module)
+
 # TODO: add tests and docstrings
 
 
@@ -101,7 +105,7 @@ def optimizer_update_variables(x, update):
   return jax.tree.map(optimizer_update_variables, opt_state, updates)
 
 
-class Optimizer(Object):
+class Optimizer(Object, tp.Generic[M]):
   """Simple train state for the common case with a single Optax optimizer.
 
   Example usage::
@@ -168,7 +172,7 @@ class Optimizer(Object):
 
   def __init__(
     self,
-    model: nnx.Module,
+    model: M,
     tx: optax.GradientTransformation,
     wrt: filterlib.Filter = nnx.Param,
   ):
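
For context, here is a minimal sketch (not part of the diff) of the typing benefit this change targets. The Linear module, its sizes, and the learning rate are illustrative assumptions; only the nnx.Optimizer(model, tx) call mirrors the signature shown above, where wrt still defaults to nnx.Param.

# Illustrative only: what the new Generic[M] parameter buys at call sites.
import jax
import optax
from flax import nnx

class Linear(nnx.Module):  # hypothetical toy module, not part of this PR
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    # store the kernel as a trainable nnx.Param
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))

  def __call__(self, x):
    return x @ self.w

model = Linear(2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
# Previously the model argument was annotated as the nnx.Module base class, so
# static checkers lost the concrete type. With `model: M` and
# `Optimizer(Object, tp.Generic[M])`, the instance can be inferred as
# Optimizer[Linear] directly from the constructor argument.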