diff --git a/optimistix/_solver/backtracking.py b/optimistix/_solver/backtracking.py index 7e2f904..040ab93 100644 --- a/optimistix/_solver/backtracking.py +++ b/optimistix/_solver/backtracking.py @@ -2,6 +2,7 @@ from typing_extensions import TypeAlias import equinox as eqx +import jax import jax.numpy as jnp from equinox.internal import ω from jaxtyping import Array, Bool, Scalar, ScalarLike @@ -85,7 +86,7 @@ def step( ) y_diff = (y_eval**ω - y**ω).ω - predicted_reduction = tree_dot(grad, y_diff) + predicted_reduction = tree_dot(jax.tree_map(jnp.conj, grad), y_diff).real # Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)` # must do better than its linear approximation: # `fn(y_eval) < fn(y) + grad•y_diff` diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py index 12a5ed0..7a4cb84 100644 --- a/optimistix/_solver/bfgs.py +++ b/optimistix/_solver/bfgs.py @@ -50,10 +50,14 @@ def _identity_pytree(pytree: PyTree[Array]) -> lx.PyTreeLinearOperator: for i2, l2 in enumerate(leaves): if i1 == i2: eye_leaves.append( - jnp.eye(jnp.size(l1)).reshape(jnp.shape(l1) + jnp.shape(l2)) + jnp.eye(jnp.size(l1), dtype=l1.dtype).reshape( + jnp.shape(l1) + jnp.shape(l2) + ) ) else: - eye_leaves.append(jnp.zeros(jnp.shape(l1) + jnp.shape(l2))) + eye_leaves.append( + jnp.zeros(jnp.shape(l1) + jnp.shape(l2), dtype=l1.dtype) + ) # This has a Lineax positive_semidefinite tag. This is okay because the BFGS update # preserves positive-definiteness. @@ -111,7 +115,7 @@ def no_update(hessian, hessian_inv): # this we jump straight to the line search. # Likewise we get inner <= eps on convergence, and so again we make no update # to avoid a division by zero. - inner_nonzero = inner > jnp.finfo(inner.dtype).eps + inner_nonzero = jnp.abs(inner) > jnp.finfo(inner.dtype).eps hessian, hessian_inv = filter_cond( inner_nonzero, bfgs_update, no_update, hessian, hessian_inv ) diff --git a/optimistix/_solver/dogleg.py b/optimistix/_solver/dogleg.py index 29ae75e..1d51620 100644 --- a/optimistix/_solver/dogleg.py +++ b/optimistix/_solver/dogleg.py @@ -2,6 +2,7 @@ from typing import Any, cast, Generic, Union import equinox as eqx +import jax import jax.lax as lax import jax.numpy as jnp import lineax as lx @@ -73,22 +74,25 @@ def query( state: _DoglegDescentState, ) -> _DoglegDescentState: del state + conj_grad = jax.tree_map(jnp.conj, f_info.grad) # Compute `denom = grad^T Hess grad.` if isinstance(f_info, FunctionInfo.EvalGradHessian): denom = tree_dot(f_info.grad, f_info.hessian.mv(f_info.grad)) elif isinstance(f_info, FunctionInfo.ResidualJac): # Use Gauss--Newton approximation `Hess ~ J^T J` - denom = sum_squares(f_info.jac.mv(f_info.grad)) + denom = sum_squares(f_info.jac.mv(conj_grad)) else: raise ValueError( "`DoglegDescent` can only be used with least-squares solvers, or " "quasi-Newton minimisers which make approximations to the Hessian " "(like `optx.BFGS(use_inverse=False)`)" ) - denom_nonzero = denom > jnp.finfo(denom.dtype).eps + denom_nonzero = jnp.abs(denom) > jnp.finfo(denom.dtype).eps safe_denom = jnp.where(denom_nonzero, denom, 1) # Compute `grad^T grad / (grad^T Hess grad)` - scaling = jnp.where(denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0) + + with jax.numpy_dtype_promotion("standard"): + scaling = jnp.where(denom_nonzero, sum_squares(conj_grad) / safe_denom, 0.0) scaling = cast(Array, scaling) # Downhill towards the bottom of the quadratic basin. @@ -97,7 +101,8 @@ def query( newton_norm = self.trust_region_norm(newton_sol) # Downhill steepest descent. - cauchy = (-scaling * f_info.grad**ω).ω + with jax.numpy_dtype_promotion("standard"): + cauchy = (-scaling * conj_grad**ω).ω cauchy_norm = self.trust_region_norm(cauchy) return _DoglegDescentState( @@ -139,7 +144,8 @@ def interpolate_cauchy_and_newton(cauchy, newton): """ def interpolate(t): - return (cauchy**ω + (t - 1) * (newton**ω - cauchy**ω)).ω + with jax.numpy_dtype_promotion("standard"): + return (cauchy**ω + (t - 1) * (newton**ω - cauchy**ω)).ω # The vast majority of the time we expect users to use `two_norm`, # ie. the classic, elliptical trust region radius. In this case, we @@ -152,7 +158,7 @@ def interpolate(t): # find the value which hits the trust region radius. if self.trust_region_norm is two_norm: a = sum_squares((newton**ω - cauchy**ω).ω) - inner_prod = tree_dot(cauchy, (newton**ω - cauchy**ω).ω) + inner_prod = tree_dot(cauchy, (newton**ω - cauchy**ω).ω).real b = 2 * (inner_prod - a) c = state.cauchy_norm**2 - 2 * inner_prod + a - scaled_step_size**2 quadratic_1 = jnp.clip( diff --git a/optimistix/_solver/gauss_newton.py b/optimistix/_solver/gauss_newton.py index 17d8a06..80216b5 100644 --- a/optimistix/_solver/gauss_newton.py +++ b/optimistix/_solver/gauss_newton.py @@ -59,7 +59,8 @@ def newton_step( value.) """ if isinstance(f_info, FunctionInfo.EvalGradHessianInv): - newton = f_info.hessian_inv.mv(f_info.grad) + conj_grad = jax.tree_map(jnp.conj, f_info.grad) + newton = f_info.hessian_inv.mv(conj_grad) result = RESULTS.successful else: if isinstance(f_info, FunctionInfo.EvalGradHessian): @@ -73,7 +74,7 @@ def newton_step( "Cannot use a Newton descent with a solver that only evaluates the " "gradient, or only the function itself." ) - out = lx.linear_solve(operator, vector, linear_solver) + out = lx.linear_solve(operator, jax.tree_map(jnp.conj, vector), linear_solver) newton = out.value result = RESULTS.promote(out.result) return newton, result diff --git a/optimistix/_solver/gradient_methods.py b/optimistix/_solver/gradient_methods.py index 90f9398..ac84d0b 100644 --- a/optimistix/_solver/gradient_methods.py +++ b/optimistix/_solver/gradient_methods.py @@ -70,7 +70,7 @@ def query( ) if self.norm is not None: grad = (grad**ω / self.norm(grad)).ω - return _SteepestDescentState(grad) + return _SteepestDescentState(jax.tree_map(jnp.conj, grad)) def step( self, step_size: Scalar, state: _SteepestDescentState diff --git a/optimistix/_solver/levenberg_marquardt.py b/optimistix/_solver/levenberg_marquardt.py index 9e27c56..8abebc5 100644 --- a/optimistix/_solver/levenberg_marquardt.py +++ b/optimistix/_solver/levenberg_marquardt.py @@ -9,6 +9,7 @@ import lineax as lx from equinox.internal import ω from jaxtyping import Array, Float, PyTree, Scalar, ScalarLike +from lineax.internal import default_floating_dtype as default_floating_dtype from .._custom_types import Aux, Out, Y from .._misc import max_norm, tree_full_like, two_norm @@ -57,7 +58,11 @@ def damped_newton_step( lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size).max) lm_param = cast(Array, lm_param) if isinstance(f_info, FunctionInfo.EvalGradHessian): - operator = f_info.hessian + lm_param * lx.IdentityLinearOperator( + leaves = jtu.tree_leaves(f_info.hessian.in_structure()) + dtype = ( + default_floating_dtype() if len(leaves) == 0 else jnp.result_type(*leaves) + ) + operator = f_info.hessian + lm_param.astype(dtype) * lx.IdentityLinearOperator( f_info.hessian.in_structure() ) vector = f_info.grad @@ -73,7 +78,7 @@ def damped_newton_step( "provide (approximate) Hessian information." ) linear_sol = lx.linear_solve(operator, vector, linear_solver, throw=False) - return linear_sol.value, RESULTS.promote(linear_sol.result) + return jax.tree_map(jnp.conj, linear_sol.value), RESULTS.promote(linear_sol.result) class _DampedNewtonDescentState(eqx.Module, strict=True): diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index cc43e55..48766a6 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -2,6 +2,7 @@ from typing import Any, cast, Generic, Union import equinox as eqx +import jax import jax.numpy as jnp from equinox.internal import ω from jaxtyping import Array, PyTree, Scalar @@ -31,7 +32,9 @@ def polak_ribiere(grad_vector: Y, grad_prev: Y, y_diff_prev: Y) -> Scalar: # have a gradient. In either case we set β=0 to revert to just gradient descent. pred = denominator > jnp.finfo(denominator.dtype).eps safe_denom = jnp.where(pred, denominator, 1) - out = jnp.where(pred, jnp.clip(numerator / safe_denom, min=0), 0) + + with jax.numpy_dtype_promotion("standard"): + out = jnp.where(pred, jnp.clip(numerator / safe_denom, min=0), 0) return cast(Scalar, out) @@ -67,7 +70,8 @@ def dai_yuan(grad: Y, grad_prev: Y, y_diff_prev: Y) -> Scalar: # Triggers at initialisation and convergence, as above. pred = jnp.abs(denominator) > jnp.finfo(denominator.dtype).eps safe_denom = jnp.where(pred, denominator, 1) - return jnp.where(pred, numerator / safe_denom, 0) + with jax.numpy_dtype_promotion("standard"): + return jnp.where(pred, numerator / safe_denom, 0) class _NonlinearCGDescentState(eqx.Module, Generic[Y], strict=True): @@ -141,11 +145,13 @@ def query( # `state.{grad, y_diff} = 0`, i.e. our previous step hit a local minima, then # on this next step we'll again just use gradient descent, and stop. beta = self.method(f_info.grad, state.grad, state.y_diff) - neg_grad = (-(f_info.grad**ω)).ω - nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω + conj_grad = jax.tree_map(jnp.conj, f_info.grad) + neg_grad = (-(conj_grad**ω)).ω + with jax.numpy_dtype_promotion("standard"): + nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω # Check if this is a descent direction. Use gradient descent if it isn't. y_diff = tree_where( - tree_dot(f_info.grad, nonlinear_cg_direction) < 0, + tree_dot(conj_grad, nonlinear_cg_direction).real < 0, nonlinear_cg_direction, neg_grad, ) diff --git a/optimistix/_solver/optax.py b/optimistix/_solver/optax.py index 1f6a26b..8d9d7c9 100644 --- a/optimistix/_solver/optax.py +++ b/optimistix/_solver/optax.py @@ -97,7 +97,9 @@ def step( ("loss" in self.verbose, "Loss", f), ("y" in self.verbose, "y", y), ) - updates, new_opt_state = self.optim.update(grads, state.opt_state) + updates, new_opt_state = self.optim.update( + jax.tree_map(jnp.conj, grads), state.opt_state + ) new_y = eqx.apply_updates(y, updates) terminate = cauchy_termination( self.rtol, diff --git a/optimistix/_solver/trust_region.py b/optimistix/_solver/trust_region.py index 4a6b850..4ee62b9 100644 --- a/optimistix/_solver/trust_region.py +++ b/optimistix/_solver/trust_region.py @@ -3,6 +3,7 @@ from typing_extensions import TypeAlias import equinox as eqx +import jax import jax.numpy as jnp from equinox import AbstractVar from equinox.internal import ω @@ -166,9 +167,9 @@ def predict_reduction( if isinstance(f_info, FunctionInfo.EvalGradHessian): # Minimisation algorithm. Directly compute the quadratic approximation. return tree_dot( - y_diff, + jax.tree_map(jnp.conj, f_info.grad), (f_info.grad**ω + 0.5 * f_info.hessian.mv(y_diff) ** ω).ω, - ) + ).real elif isinstance(f_info, FunctionInfo.ResidualJac): # Least-squares algorithm. So instead of considering fn (which returns the # residuals), instead consider `0.5*fn(y)^2`, and then apply the logic as @@ -190,7 +191,7 @@ def predict_reduction( jacobian_term = sum_squares( (f_info.jac.mv(y_diff) ** ω + f_info.residual**ω).ω ) - return 0.5 * (jacobian_term - rtr) + return 0.5 * (jacobian_term - rtr).real else: raise ValueError( "Cannot use `ClassicalTrustRegion` with this solver. This is because " @@ -273,7 +274,7 @@ def predict_reduction( FunctionInfo.ResidualJac, ), ): - return tree_dot(f_info.grad, y_diff) + return tree_dot(jax.tree_map(jnp.conj, f_info.grad), y_diff).real else: raise ValueError( "Cannot use `LinearTrustRegion` with this solver. This is because " diff --git a/tests/helpers.py b/tests/helpers.py index 29c5767..995523e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -221,13 +221,13 @@ def bowl(tree: PyTree[Array], args: Array): # Trivial quadratic bowl smoke test for convergence. (y, _) = jfu.ravel_pytree(tree) matrix = args - return y.T @ matrix @ y + return y.T.conj() @ matrix @ y def diagonal_quadratic_bowl(tree: PyTree[Array], args: PyTree[Array]): # A diagonal quadratic bowl smoke test for convergence. weight_vector = args - return (ω(tree).call(jnp.square) * (0.1 + weight_vector**ω)).ω + return (ω(tree).call(jnp.square) * (0.1 + weight_vector**ω)).call(jnp.abs).ω def rosenbrock(tree: PyTree[Array], args: Scalar): @@ -317,7 +317,7 @@ def loss(model, x, y): def square_minus_one(x: Array, args: PyTree): """A simple ||x||^2 - 1 function.""" - return jnp.sum(jnp.square(x)) - 1.0 + return jnp.sum(jnp.square(jnp.abs(x))) - 1.0 # @@ -383,6 +383,16 @@ def get_weights(model): [jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves] ) +diagonal_bowl_init_complex = ( + {"a": (0.05 + 0.01j) * jnp.ones((2, 3, 3), dtype=jnp.complex128)}, + ((0.01 + 0.05j) * jnp.ones(2, dtype=jnp.complex128)), +) +leaves_complex, treedef_complex = jtu.tree_flatten(diagonal_bowl_init_complex) +key = jr.PRNGKey(17) +diagonal_bowl_args_complex = treedef.unflatten( + [jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves_complex] +) + # neural net args ffn_data = jnp.linspace(0, 1, 100)[..., None] ffn_args = (ffn_static, ffn_data) @@ -394,6 +404,12 @@ def get_weights(model): diagonal_bowl_init, diagonal_bowl_args, ), + ( + diagonal_quadratic_bowl, + jnp.array(0.0), + diagonal_bowl_init_complex, + diagonal_bowl_args_complex, + ), ( rosenbrock, jnp.array(0.0), @@ -463,6 +479,12 @@ def get_weights(model): ), # Problems with initial value of 0 (square_minus_one, jnp.array(-1.0), jnp.array(1.0), None), + ( + square_minus_one, + jnp.array(-1.0), + jnp.array(1.0 + 1.0j, dtype=jnp.complex128), + None, + ), ) # ROOT FIND/FIXED POINT PROBLEMS diff --git a/tests/test_least_squares.py b/tests/test_least_squares.py index 5a5ac0b..55c1dbf 100644 --- a/tests/test_least_squares.py +++ b/tests/test_least_squares.py @@ -67,8 +67,10 @@ def test_least_squares_jvp(getkey, solver, _fn, minimum, init, args): fn = _fn dynamic_args, static_args = eqx.partition(args, eqx.is_array) - t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init) - t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args) + t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), init) + t_dynamic_args = jtu.tree_map( + lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), dynamic_args + ) def least_squares(x, dynamic_args, *, adjoint): args = eqx.combine(dynamic_args, static_args) diff --git a/tests/test_minimise.py b/tests/test_minimise.py index 2b6ab2c..f633905 100644 --- a/tests/test_minimise.py +++ b/tests/test_minimise.py @@ -72,8 +72,10 @@ def test_minimise_jvp(getkey, solver, _fn, minimum, init, args): fn = _fn dynamic_args, static_args = eqx.partition(args, eqx.is_array) - t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init) - t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args) + t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), init) + t_dynamic_args = jtu.tree_map( + lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), dynamic_args + ) def minimise(x, dynamic_args, *, adjoint): args = eqx.combine(dynamic_args, static_args) @@ -138,18 +140,19 @@ def minimise(x, dynamic_args, *, adjoint): # assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol) +@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128]) @pytest.mark.parametrize( "method", [optx.polak_ribiere, optx.fletcher_reeves, optx.hestenes_stiefel, optx.dai_yuan], ) -def test_nonlinear_cg_methods(method): +def test_nonlinear_cg_methods(method, dtype): solver = optx.NonlinearCG(rtol=1e-10, atol=1e-10, method=method) def f(y, _): - A = jnp.array([[2.0, -1.0], [-1.0, 3.0]]) - b = jnp.array([-100.0, 5.0]) - c = jnp.array(100.0) - return jnp.einsum("ij,i,j", A, y, y) + jnp.dot(b, y) + c + A = jnp.array([[2.0, -1.0], [-1.0, 3.0]], dtype=dtype) + b = jnp.array([-100.0, 5.0], dtype=dtype) + c = jnp.array(100.0, dtype=dtype) + return (jnp.einsum("ij,i,j", A, y, y) + jnp.dot(b, y) + c).real # Analytic minimum: # 0 = df/dyk @@ -158,9 +161,11 @@ def f(y, _): # => y = -0.5 A^{-1} b # = [[-0.3, 0.1], [0.1, 0.2]] [-100, 5] # = [29.5, 9] - y0 = jnp.array([2.0, 3.0]) + y0 = jnp.array([2.0, 3.0], dtype=dtype) sol = optx.minimise(f, solver, y0, max_steps=500) - assert tree_allclose(sol.value, jnp.array([29.5, 9.0]), rtol=1e-5, atol=1e-5) + assert tree_allclose( + sol.value, jnp.array([29.5, 9.0], dtype=dtype), rtol=1e-5, atol=1e-5 + ) def test_optax_recompilation():