From fd30ac05bac582402b0f8d9b4ed83913cdd12f4d Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Wed, 7 Aug 2024 16:37:09 +0300 Subject: [PATCH 01/12] Add (failing) test --- tests/helpers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 29c5767..a131f09 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -317,7 +317,7 @@ def loss(model, x, y): def square_minus_one(x: Array, args: PyTree): """A simple ||x||^2 - 1 function.""" - return jnp.sum(jnp.square(x)) - 1.0 + return jnp.sum(jnp.square(jnp.abs(x))) - 1.0 # @@ -463,6 +463,12 @@ def get_weights(model): ), # Problems with initial value of 0 (square_minus_one, jnp.array(-1.0), jnp.array(1.0), None), + ( + square_minus_one, + jnp.array(-1.0, dtype=jnp.complex128), + jnp.array(1.0, dtype=jnp.complex128), + None, + ), ) # ROOT FIND/FIXED POINT PROBLEMS From 5db155bdd14e8a90b0a5c4abcd35d01f66c0cc76 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Wed, 7 Aug 2024 16:50:15 +0300 Subject: [PATCH 02/12] Rough fixes --- optimistix/_solver/backtracking.py | 2 +- optimistix/_solver/bfgs.py | 6 ++++-- optimistix/_solver/dogleg.py | 11 ++++++++--- optimistix/_solver/nonlinear_cg.py | 5 ++++- optimistix/_solver/trust_region.py | 6 +++--- tests/test_minimise.py | 2 +- 6 files changed, 21 insertions(+), 11 deletions(-) diff --git a/optimistix/_solver/backtracking.py b/optimistix/_solver/backtracking.py index 7e2f904..5bbb7ac 100644 --- a/optimistix/_solver/backtracking.py +++ b/optimistix/_solver/backtracking.py @@ -85,7 +85,7 @@ def step( ) y_diff = (y_eval**ω - y**ω).ω - predicted_reduction = tree_dot(grad, y_diff) + predicted_reduction = tree_dot(grad, y_diff).real # Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)` # must do better than its linear approximation: # `fn(y_eval) < fn(y) + grad•y_diff` diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py index 12a5ed0..1182303 100644 --- a/optimistix/_solver/bfgs.py +++ b/optimistix/_solver/bfgs.py @@ -50,7 +50,9 @@ def _identity_pytree(pytree: PyTree[Array]) -> lx.PyTreeLinearOperator: for i2, l2 in enumerate(leaves): if i1 == i2: eye_leaves.append( - jnp.eye(jnp.size(l1)).reshape(jnp.shape(l1) + jnp.shape(l2)) + jnp.eye(jnp.size(l1), dtype=l1.dtype).reshape( + jnp.shape(l1) + jnp.shape(l2) + ) ) else: eye_leaves.append(jnp.zeros(jnp.shape(l1) + jnp.shape(l2))) @@ -111,7 +113,7 @@ def no_update(hessian, hessian_inv): # this we jump straight to the line search. # Likewise we get inner <= eps on convergence, and so again we make no update # to avoid a division by zero. 
- inner_nonzero = inner > jnp.finfo(inner.dtype).eps + inner_nonzero = jnp.abs(inner) > jnp.finfo(inner.dtype).eps hessian, hessian_inv = filter_cond( inner_nonzero, bfgs_update, no_update, hessian, hessian_inv ) diff --git a/optimistix/_solver/dogleg.py b/optimistix/_solver/dogleg.py index 29ae75e..0cab5e4 100644 --- a/optimistix/_solver/dogleg.py +++ b/optimistix/_solver/dogleg.py @@ -2,6 +2,7 @@ from typing import Any, cast, Generic, Union import equinox as eqx +import jax import jax.lax as lax import jax.numpy as jnp import lineax as lx @@ -85,10 +86,14 @@ def query( "quasi-Newton minimisers which make approximations to the Hessian " "(like `optx.BFGS(use_inverse=False)`)" ) - denom_nonzero = denom > jnp.finfo(denom.dtype).eps + denom_nonzero = jnp.abs(denom) > jnp.finfo(denom.dtype).eps safe_denom = jnp.where(denom_nonzero, denom, 1) # Compute `grad^T grad / (grad^T Hess grad)` - scaling = jnp.where(denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0) + + with jax.numpy_dtype_promotion("standard"): + scaling = jnp.where( + denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0 + ) scaling = cast(Array, scaling) # Downhill towards the bottom of the quadratic basin. @@ -152,7 +157,7 @@ def interpolate(t): # find the value which hits the trust region radius. if self.trust_region_norm is two_norm: a = sum_squares((newton**ω - cauchy**ω).ω) - inner_prod = tree_dot(cauchy, (newton**ω - cauchy**ω).ω) + inner_prod = tree_dot(cauchy, (newton**ω - cauchy**ω).ω).real b = 2 * (inner_prod - a) c = state.cauchy_norm**2 - 2 * inner_prod + a - scaled_step_size**2 quadratic_1 = jnp.clip( diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index cc43e55..af71b17 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -2,6 +2,7 @@ from typing import Any, cast, Generic, Union import equinox as eqx +import jax import jax.numpy as jnp from equinox.internal import ω from jaxtyping import Array, PyTree, Scalar @@ -31,7 +32,9 @@ def polak_ribiere(grad_vector: Y, grad_prev: Y, y_diff_prev: Y) -> Scalar: # have a gradient. In either case we set β=0 to revert to just gradient descent. pred = denominator > jnp.finfo(denominator.dtype).eps safe_denom = jnp.where(pred, denominator, 1) - out = jnp.where(pred, jnp.clip(numerator / safe_denom, min=0), 0) + + with jax.numpy_dtype_promotion("standard"): + out = jnp.where(pred, jnp.clip(numerator / safe_denom, min=0), 0) return cast(Scalar, out) diff --git a/optimistix/_solver/trust_region.py b/optimistix/_solver/trust_region.py index 4a6b850..83f5dfb 100644 --- a/optimistix/_solver/trust_region.py +++ b/optimistix/_solver/trust_region.py @@ -168,7 +168,7 @@ def predict_reduction( return tree_dot( y_diff, (f_info.grad**ω + 0.5 * f_info.hessian.mv(y_diff) ** ω).ω, - ) + ).real elif isinstance(f_info, FunctionInfo.ResidualJac): # Least-squares algorithm. So instead of considering fn (which returns the # residuals), instead consider `0.5*fn(y)^2`, and then apply the logic as @@ -190,7 +190,7 @@ def predict_reduction( jacobian_term = sum_squares( (f_info.jac.mv(y_diff) ** ω + f_info.residual**ω).ω ) - return 0.5 * (jacobian_term - rtr) + return 0.5 * (jacobian_term - rtr).real else: raise ValueError( "Cannot use `ClassicalTrustRegion` with this solver. This is because " @@ -273,7 +273,7 @@ def predict_reduction( FunctionInfo.ResidualJac, ), ): - return tree_dot(f_info.grad, y_diff) + return tree_dot(f_info.grad, y_diff).real else: raise ValueError( "Cannot use `LinearTrustRegion` with this solver. 
This is because " diff --git a/tests/test_minimise.py b/tests/test_minimise.py index 2b6ab2c..f51676d 100644 --- a/tests/test_minimise.py +++ b/tests/test_minimise.py @@ -53,7 +53,7 @@ def test_minimise(solver, _fn, minimum, init, args): throw=False, ).value optx_min = _fn(optx_argmin, args) - assert tree_allclose(optx_min, minimum, atol=atol, rtol=rtol) + assert tree_allclose(optx_min, minimum.real, atol=atol, rtol=rtol) @pytest.mark.parametrize("solver", minimisers) From 0878d871cf0284edad93e31c81350f5ae6a224ec Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Wed, 7 Aug 2024 17:57:46 +0300 Subject: [PATCH 03/12] And least squares test too --- optimistix/_solver/bfgs.py | 4 +++- tests/helpers.py | 20 ++++++++++++++++++-- tests/test_minimise.py | 2 +- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py index 1182303..7a4cb84 100644 --- a/optimistix/_solver/bfgs.py +++ b/optimistix/_solver/bfgs.py @@ -55,7 +55,9 @@ def _identity_pytree(pytree: PyTree[Array]) -> lx.PyTreeLinearOperator: ) ) else: - eye_leaves.append(jnp.zeros(jnp.shape(l1) + jnp.shape(l2))) + eye_leaves.append( + jnp.zeros(jnp.shape(l1) + jnp.shape(l2), dtype=l1.dtype) + ) # This has a Lineax positive_semidefinite tag. This is okay because the BFGS update # preserves positive-definiteness. diff --git a/tests/helpers.py b/tests/helpers.py index a131f09..cb0f609 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -221,13 +221,13 @@ def bowl(tree: PyTree[Array], args: Array): # Trivial quadratic bowl smoke test for convergence. (y, _) = jfu.ravel_pytree(tree) matrix = args - return y.T @ matrix @ y + return y.T.conj() @ matrix @ y def diagonal_quadratic_bowl(tree: PyTree[Array], args: PyTree[Array]): # A diagonal quadratic bowl smoke test for convergence. 
weight_vector = args - return (ω(tree).call(jnp.square) * (0.1 + weight_vector**ω)).ω + return (ω(tree).call(jnp.abs).call(jnp.square) * (0.1 + weight_vector**ω)).ω def rosenbrock(tree: PyTree[Array], args: Scalar): @@ -383,6 +383,16 @@ def get_weights(model): [jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves] ) +diagonal_bowl_init_complex = ( + {"a": 0.05 * jnp.ones((2, 3, 3), dtype=jnp.complex128)}, + (0.05j * jnp.ones(2, dtype=jnp.complex128)), +) +leaves_complex, treedef_complex = jtu.tree_flatten(diagonal_bowl_init_complex) +key = jr.PRNGKey(17) +diagonal_bowl_args_complex = treedef.unflatten( + [jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves] +) + # neural net args ffn_data = jnp.linspace(0, 1, 100)[..., None] ffn_args = (ffn_static, ffn_data) @@ -394,6 +404,12 @@ def get_weights(model): diagonal_bowl_init, diagonal_bowl_args, ), + ( + diagonal_quadratic_bowl, + jnp.array(0.0), + diagonal_bowl_init_complex, + diagonal_bowl_args_complex, + ), ( rosenbrock, jnp.array(0.0), diff --git a/tests/test_minimise.py b/tests/test_minimise.py index f51676d..2b6ab2c 100644 --- a/tests/test_minimise.py +++ b/tests/test_minimise.py @@ -53,7 +53,7 @@ def test_minimise(solver, _fn, minimum, init, args): throw=False, ).value optx_min = _fn(optx_argmin, args) - assert tree_allclose(optx_min, minimum.real, atol=atol, rtol=rtol) + assert tree_allclose(optx_min, minimum, atol=atol, rtol=rtol) @pytest.mark.parametrize("solver", minimisers) From b33bbef812ef5d0cabec23701efc51798ae7741c Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Thu, 8 Aug 2024 15:57:36 +0300 Subject: [PATCH 04/12] More fixes --- optimistix/_solver/dogleg.py | 6 ++++-- optimistix/_solver/levenberg_marquardt.py | 7 ++++++- optimistix/_solver/nonlinear_cg.py | 3 ++- tests/helpers.py | 6 +++--- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/optimistix/_solver/dogleg.py b/optimistix/_solver/dogleg.py index 0cab5e4..ff073b8 100644 --- a/optimistix/_solver/dogleg.py +++ b/optimistix/_solver/dogleg.py @@ -102,7 +102,8 @@ def query( newton_norm = self.trust_region_norm(newton_sol) # Downhill steepest descent. - cauchy = (-scaling * f_info.grad**ω).ω + with jax.numpy_dtype_promotion("standard"): + cauchy = (-scaling * f_info.grad**ω).ω cauchy_norm = self.trust_region_norm(cauchy) return _DoglegDescentState( @@ -144,7 +145,8 @@ def interpolate_cauchy_and_newton(cauchy, newton): """ def interpolate(t): - return (cauchy**ω + (t - 1) * (newton**ω - cauchy**ω)).ω + with jax.numpy_dtype_promotion("standard"): + return (cauchy**ω + (t - 1) * (newton**ω - cauchy**ω)).ω # The vast majority of the time we expect users to use `two_norm`, # ie. the classic, elliptical trust region radius. 
In this case, we diff --git a/optimistix/_solver/levenberg_marquardt.py b/optimistix/_solver/levenberg_marquardt.py index 9e27c56..82fabda 100644 --- a/optimistix/_solver/levenberg_marquardt.py +++ b/optimistix/_solver/levenberg_marquardt.py @@ -9,6 +9,7 @@ import lineax as lx from equinox.internal import ω from jaxtyping import Array, Float, PyTree, Scalar, ScalarLike +from lineax.internal import default_floating_dtype as default_floating_dtype from .._custom_types import Aux, Out, Y from .._misc import max_norm, tree_full_like, two_norm @@ -57,7 +58,11 @@ def damped_newton_step( lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size).max) lm_param = cast(Array, lm_param) if isinstance(f_info, FunctionInfo.EvalGradHessian): - operator = f_info.hessian + lm_param * lx.IdentityLinearOperator( + leaves = jtu.tree_leaves(f_info.hessian.in_structure()) + dtype = ( + default_floating_dtype() if len(leaves) == 0 else jnp.result_type(*leaves) + ) + operator = f_info.hessian + lm_param.astype(dtype) * lx.IdentityLinearOperator( f_info.hessian.in_structure() ) vector = f_info.grad diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index af71b17..bbaaff7 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -70,7 +70,8 @@ def dai_yuan(grad: Y, grad_prev: Y, y_diff_prev: Y) -> Scalar: # Triggers at initialisation and convergence, as above. pred = jnp.abs(denominator) > jnp.finfo(denominator.dtype).eps safe_denom = jnp.where(pred, denominator, 1) - return jnp.where(pred, numerator / safe_denom, 0) + with jax.numpy_dtype_promotion("standard"): + return jnp.where(pred, numerator / safe_denom, 0) class _NonlinearCGDescentState(eqx.Module, Generic[Y], strict=True): diff --git a/tests/helpers.py b/tests/helpers.py index cb0f609..367a7ce 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -227,7 +227,7 @@ def bowl(tree: PyTree[Array], args: Array): def diagonal_quadratic_bowl(tree: PyTree[Array], args: PyTree[Array]): # A diagonal quadratic bowl smoke test for convergence. weight_vector = args - return (ω(tree).call(jnp.abs).call(jnp.square) * (0.1 + weight_vector**ω)).ω + return (ω(tree).call(jnp.square) * (0.1 + weight_vector**ω)).call(jnp.abs).ω def rosenbrock(tree: PyTree[Array], args: Scalar): @@ -390,7 +390,7 @@ def get_weights(model): leaves_complex, treedef_complex = jtu.tree_flatten(diagonal_bowl_init_complex) key = jr.PRNGKey(17) diagonal_bowl_args_complex = treedef.unflatten( - [jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves] + [jr.normal(key, leaf.shape, leaf.dtype) ** 2 for leaf in leaves_complex] ) # neural net args @@ -481,7 +481,7 @@ def get_weights(model): (square_minus_one, jnp.array(-1.0), jnp.array(1.0), None), ( square_minus_one, - jnp.array(-1.0, dtype=jnp.complex128), + jnp.array(-1.0), jnp.array(1.0, dtype=jnp.complex128), None, ), From bfaa5d765b4641695cfb2c465dd6dd8dcf2ec036 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Thu, 8 Aug 2024 16:04:45 +0300 Subject: [PATCH 05/12] Test updates --- optimistix/_solver/nonlinear_cg.py | 3 ++- tests/test_least_squares.py | 6 ++++-- tests/test_minimise.py | 23 ++++++++++++++--------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index bbaaff7..e7fb5f9 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -146,7 +146,8 @@ def query( # on this next step we'll again just use gradient descent, and stop. 
beta = self.method(f_info.grad, state.grad, state.y_diff) neg_grad = (-(f_info.grad**ω)).ω - nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω + with jax.numpy_dtype_promotion("standard"): + nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω # Check if this is a descent direction. Use gradient descent if it isn't. y_diff = tree_where( tree_dot(f_info.grad, nonlinear_cg_direction) < 0, diff --git a/tests/test_least_squares.py b/tests/test_least_squares.py index 5a5ac0b..55c1dbf 100644 --- a/tests/test_least_squares.py +++ b/tests/test_least_squares.py @@ -67,8 +67,10 @@ def test_least_squares_jvp(getkey, solver, _fn, minimum, init, args): fn = _fn dynamic_args, static_args = eqx.partition(args, eqx.is_array) - t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init) - t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args) + t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), init) + t_dynamic_args = jtu.tree_map( + lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), dynamic_args + ) def least_squares(x, dynamic_args, *, adjoint): args = eqx.combine(dynamic_args, static_args) diff --git a/tests/test_minimise.py b/tests/test_minimise.py index 2b6ab2c..f633905 100644 --- a/tests/test_minimise.py +++ b/tests/test_minimise.py @@ -72,8 +72,10 @@ def test_minimise_jvp(getkey, solver, _fn, minimum, init, args): fn = _fn dynamic_args, static_args = eqx.partition(args, eqx.is_array) - t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), init) - t_dynamic_args = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape), dynamic_args) + t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), init) + t_dynamic_args = jtu.tree_map( + lambda x: jr.normal(getkey(), x.shape, dtype=x.dtype), dynamic_args + ) def minimise(x, dynamic_args, *, adjoint): args = eqx.combine(dynamic_args, static_args) @@ -138,18 +140,19 @@ def minimise(x, dynamic_args, *, adjoint): # assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol) +@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128]) @pytest.mark.parametrize( "method", [optx.polak_ribiere, optx.fletcher_reeves, optx.hestenes_stiefel, optx.dai_yuan], ) -def test_nonlinear_cg_methods(method): +def test_nonlinear_cg_methods(method, dtype): solver = optx.NonlinearCG(rtol=1e-10, atol=1e-10, method=method) def f(y, _): - A = jnp.array([[2.0, -1.0], [-1.0, 3.0]]) - b = jnp.array([-100.0, 5.0]) - c = jnp.array(100.0) - return jnp.einsum("ij,i,j", A, y, y) + jnp.dot(b, y) + c + A = jnp.array([[2.0, -1.0], [-1.0, 3.0]], dtype=dtype) + b = jnp.array([-100.0, 5.0], dtype=dtype) + c = jnp.array(100.0, dtype=dtype) + return (jnp.einsum("ij,i,j", A, y, y) + jnp.dot(b, y) + c).real # Analytic minimum: # 0 = df/dyk @@ -158,9 +161,11 @@ def f(y, _): # => y = -0.5 A^{-1} b # = [[-0.3, 0.1], [0.1, 0.2]] [-100, 5] # = [29.5, 9] - y0 = jnp.array([2.0, 3.0]) + y0 = jnp.array([2.0, 3.0], dtype=dtype) sol = optx.minimise(f, solver, y0, max_steps=500) - assert tree_allclose(sol.value, jnp.array([29.5, 9.0]), rtol=1e-5, atol=1e-5) + assert tree_allclose( + sol.value, jnp.array([29.5, 9.0], dtype=dtype), rtol=1e-5, atol=1e-5 + ) def test_optax_recompilation(): From 9ca911855f90207bb375fd66aa37d1a580e7bd4c Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 12 Aug 2024 12:56:25 +0300 Subject: [PATCH 06/12] Update init --- tests/helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py 
index 367a7ce..995523e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -384,8 +384,8 @@ def get_weights(model): ) diagonal_bowl_init_complex = ( - {"a": 0.05 * jnp.ones((2, 3, 3), dtype=jnp.complex128)}, - (0.05j * jnp.ones(2, dtype=jnp.complex128)), + {"a": (0.05 + 0.01j) * jnp.ones((2, 3, 3), dtype=jnp.complex128)}, + ((0.01 + 0.05j) * jnp.ones(2, dtype=jnp.complex128)), ) leaves_complex, treedef_complex = jtu.tree_flatten(diagonal_bowl_init_complex) key = jr.PRNGKey(17) @@ -482,7 +482,7 @@ def get_weights(model): ( square_minus_one, jnp.array(-1.0), - jnp.array(1.0, dtype=jnp.complex128), + jnp.array(1.0 + 1.0j, dtype=jnp.complex128), None, ), ) From e596abf3ac71f1ea46cb2271fcfc995d606fcc98 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 12 Aug 2024 14:56:53 +0300 Subject: [PATCH 07/12] Fix gauss-newton --- optimistix/_solver/gauss_newton.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimistix/_solver/gauss_newton.py b/optimistix/_solver/gauss_newton.py index 17d8a06..939a3f3 100644 --- a/optimistix/_solver/gauss_newton.py +++ b/optimistix/_solver/gauss_newton.py @@ -76,6 +76,7 @@ def newton_step( out = lx.linear_solve(operator, vector, linear_solver) newton = out.value result = RESULTS.promote(out.result) + newton = jax.tree_map(jnp.conj, newton) return newton, result From 7e7503878b0eb9f2797227c6b51e78cab0bf9fbe Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 12 Aug 2024 15:13:03 +0300 Subject: [PATCH 08/12] Fix some more steps with conj --- optimistix/_solver/dogleg.py | 5 +++-- optimistix/_solver/levenberg_marquardt.py | 2 +- optimistix/_solver/nonlinear_cg.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/optimistix/_solver/dogleg.py b/optimistix/_solver/dogleg.py index ff073b8..67a70b5 100644 --- a/optimistix/_solver/dogleg.py +++ b/optimistix/_solver/dogleg.py @@ -74,12 +74,13 @@ def query( state: _DoglegDescentState, ) -> _DoglegDescentState: del state + conj_grad = jax.tree_map(jnp.conj, f_info.grad) # Compute `denom = grad^T Hess grad.` if isinstance(f_info, FunctionInfo.EvalGradHessian): - denom = tree_dot(f_info.grad, f_info.hessian.mv(f_info.grad)) + denom = tree_dot(conj_grad, f_info.hessian.mv(f_info.grad)) elif isinstance(f_info, FunctionInfo.ResidualJac): # Use Gauss--Newton approximation `Hess ~ J^T J` - denom = sum_squares(f_info.jac.mv(f_info.grad)) + denom = sum_squares(f_info.jac.mv(conj_grad)) else: raise ValueError( "`DoglegDescent` can only be used with least-squares solvers, or " diff --git a/optimistix/_solver/levenberg_marquardt.py b/optimistix/_solver/levenberg_marquardt.py index 82fabda..8abebc5 100644 --- a/optimistix/_solver/levenberg_marquardt.py +++ b/optimistix/_solver/levenberg_marquardt.py @@ -78,7 +78,7 @@ def damped_newton_step( "provide (approximate) Hessian information." ) linear_sol = lx.linear_solve(operator, vector, linear_solver, throw=False) - return linear_sol.value, RESULTS.promote(linear_sol.result) + return jax.tree_map(jnp.conj, linear_sol.value), RESULTS.promote(linear_sol.result) class _DampedNewtonDescentState(eqx.Module, strict=True): diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index e7fb5f9..5d7a658 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -145,7 +145,7 @@ def query( # `state.{grad, y_diff} = 0`, i.e. our previous step hit a local minima, then # on this next step we'll again just use gradient descent, and stop. 
beta = self.method(f_info.grad, state.grad, state.y_diff) - neg_grad = (-(f_info.grad**ω)).ω + neg_grad = (-(jax.tree_map(jnp.conj, f_info.grad) ** ω)).ω with jax.numpy_dtype_promotion("standard"): nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω # Check if this is a descent direction. Use gradient descent if it isn't. From ced8903f5fdc025ad4e0095dd4455833fd84e369 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 12 Aug 2024 15:22:59 +0300 Subject: [PATCH 09/12] Fix some more steps with conj --- optimistix/_solver/trust_region.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optimistix/_solver/trust_region.py b/optimistix/_solver/trust_region.py index 83f5dfb..1238825 100644 --- a/optimistix/_solver/trust_region.py +++ b/optimistix/_solver/trust_region.py @@ -3,6 +3,7 @@ from typing_extensions import TypeAlias import equinox as eqx +import jax import jax.numpy as jnp from equinox import AbstractVar from equinox.internal import ω @@ -166,7 +167,7 @@ def predict_reduction( if isinstance(f_info, FunctionInfo.EvalGradHessian): # Minimisation algorithm. Directly compute the quadratic approximation. return tree_dot( - y_diff, + jax.tree_map(jnp.conj, y_diff), (f_info.grad**ω + 0.5 * f_info.hessian.mv(y_diff) ** ω).ω, ).real elif isinstance(f_info, FunctionInfo.ResidualJac): @@ -273,7 +274,7 @@ def predict_reduction( FunctionInfo.ResidualJac, ), ): - return tree_dot(f_info.grad, y_diff).real + return tree_dot(jax.tree_map(jnp.conj, f_info.grad), y_diff).real else: raise ValueError( "Cannot use `LinearTrustRegion` with this solver. This is because " From 347bd2055ddabbb881da3be80da467d2ceea7af3 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 12 Aug 2024 16:08:29 +0300 Subject: [PATCH 10/12] Tree dot already conjugates --- optimistix/_solver/dogleg.py | 8 +++----- optimistix/_solver/gauss_newton.py | 6 +++--- optimistix/_solver/trust_region.py | 5 ++--- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/optimistix/_solver/dogleg.py b/optimistix/_solver/dogleg.py index 67a70b5..1d51620 100644 --- a/optimistix/_solver/dogleg.py +++ b/optimistix/_solver/dogleg.py @@ -77,7 +77,7 @@ def query( conj_grad = jax.tree_map(jnp.conj, f_info.grad) # Compute `denom = grad^T Hess grad.` if isinstance(f_info, FunctionInfo.EvalGradHessian): - denom = tree_dot(conj_grad, f_info.hessian.mv(f_info.grad)) + denom = tree_dot(f_info.grad, f_info.hessian.mv(f_info.grad)) elif isinstance(f_info, FunctionInfo.ResidualJac): # Use Gauss--Newton approximation `Hess ~ J^T J` denom = sum_squares(f_info.jac.mv(conj_grad)) @@ -92,9 +92,7 @@ def query( # Compute `grad^T grad / (grad^T Hess grad)` with jax.numpy_dtype_promotion("standard"): - scaling = jnp.where( - denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0 - ) + scaling = jnp.where(denom_nonzero, sum_squares(conj_grad) / safe_denom, 0.0) scaling = cast(Array, scaling) # Downhill towards the bottom of the quadratic basin. @@ -104,7 +102,7 @@ def query( # Downhill steepest descent. with jax.numpy_dtype_promotion("standard"): - cauchy = (-scaling * f_info.grad**ω).ω + cauchy = (-scaling * conj_grad**ω).ω cauchy_norm = self.trust_region_norm(cauchy) return _DoglegDescentState( diff --git a/optimistix/_solver/gauss_newton.py b/optimistix/_solver/gauss_newton.py index 939a3f3..80216b5 100644 --- a/optimistix/_solver/gauss_newton.py +++ b/optimistix/_solver/gauss_newton.py @@ -59,7 +59,8 @@ def newton_step( value.) 
""" if isinstance(f_info, FunctionInfo.EvalGradHessianInv): - newton = f_info.hessian_inv.mv(f_info.grad) + conj_grad = jax.tree_map(jnp.conj, f_info.grad) + newton = f_info.hessian_inv.mv(conj_grad) result = RESULTS.successful else: if isinstance(f_info, FunctionInfo.EvalGradHessian): @@ -73,10 +74,9 @@ def newton_step( "Cannot use a Newton descent with a solver that only evaluates the " "gradient, or only the function itself." ) - out = lx.linear_solve(operator, vector, linear_solver) + out = lx.linear_solve(operator, jax.tree_map(jnp.conj, vector), linear_solver) newton = out.value result = RESULTS.promote(out.result) - newton = jax.tree_map(jnp.conj, newton) return newton, result diff --git a/optimistix/_solver/trust_region.py b/optimistix/_solver/trust_region.py index 1238825..a3d0b0f 100644 --- a/optimistix/_solver/trust_region.py +++ b/optimistix/_solver/trust_region.py @@ -3,7 +3,6 @@ from typing_extensions import TypeAlias import equinox as eqx -import jax import jax.numpy as jnp from equinox import AbstractVar from equinox.internal import ω @@ -167,7 +166,7 @@ def predict_reduction( if isinstance(f_info, FunctionInfo.EvalGradHessian): # Minimisation algorithm. Directly compute the quadratic approximation. return tree_dot( - jax.tree_map(jnp.conj, y_diff), + f_info.grad, (f_info.grad**ω + 0.5 * f_info.hessian.mv(y_diff) ** ω).ω, ).real elif isinstance(f_info, FunctionInfo.ResidualJac): @@ -274,7 +273,7 @@ def predict_reduction( FunctionInfo.ResidualJac, ), ): - return tree_dot(jax.tree_map(jnp.conj, f_info.grad), y_diff).real + return tree_dot(f_info.grad, y_diff).real else: raise ValueError( "Cannot use `LinearTrustRegion` with this solver. This is because " From 92c532284defa887b89f59e45c4d250547787c0a Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 12 Aug 2024 16:19:46 +0300 Subject: [PATCH 11/12] Gradient descent --- optimistix/_solver/gradient_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimistix/_solver/gradient_methods.py b/optimistix/_solver/gradient_methods.py index 90f9398..ac84d0b 100644 --- a/optimistix/_solver/gradient_methods.py +++ b/optimistix/_solver/gradient_methods.py @@ -70,7 +70,7 @@ def query( ) if self.norm is not None: grad = (grad**ω / self.norm(grad)).ω - return _SteepestDescentState(grad) + return _SteepestDescentState(jax.tree_map(jnp.conj, grad)) def step( self, step_size: Scalar, state: _SteepestDescentState From de1529a5145eca747ff0135f3008877fb43b2d84 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 12 Aug 2024 17:36:25 +0300 Subject: [PATCH 12/12] Fix more conjugations --- optimistix/_solver/backtracking.py | 3 ++- optimistix/_solver/nonlinear_cg.py | 5 +++-- optimistix/_solver/optax.py | 4 +++- optimistix/_solver/trust_region.py | 5 +++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/optimistix/_solver/backtracking.py b/optimistix/_solver/backtracking.py index 5bbb7ac..040ab93 100644 --- a/optimistix/_solver/backtracking.py +++ b/optimistix/_solver/backtracking.py @@ -2,6 +2,7 @@ from typing_extensions import TypeAlias import equinox as eqx +import jax import jax.numpy as jnp from equinox.internal import ω from jaxtyping import Array, Bool, Scalar, ScalarLike @@ -85,7 +86,7 @@ def step( ) y_diff = (y_eval**ω - y**ω).ω - predicted_reduction = tree_dot(grad, y_diff).real + predicted_reduction = tree_dot(jax.tree_map(jnp.conj, grad), y_diff).real # Terminate when the Armijo condition is satisfied. 
That is, `fn(y_eval)` # must do better than its linear approximation: # `fn(y_eval) < fn(y) + grad•y_diff` diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index 5d7a658..48766a6 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -145,12 +145,13 @@ def query( # `state.{grad, y_diff} = 0`, i.e. our previous step hit a local minima, then # on this next step we'll again just use gradient descent, and stop. beta = self.method(f_info.grad, state.grad, state.y_diff) - neg_grad = (-(jax.tree_map(jnp.conj, f_info.grad) ** ω)).ω + conj_grad = jax.tree_map(jnp.conj, f_info.grad) + neg_grad = (-(conj_grad**ω)).ω with jax.numpy_dtype_promotion("standard"): nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω # Check if this is a descent direction. Use gradient descent if it isn't. y_diff = tree_where( - tree_dot(f_info.grad, nonlinear_cg_direction) < 0, + tree_dot(conj_grad, nonlinear_cg_direction).real < 0, nonlinear_cg_direction, neg_grad, ) diff --git a/optimistix/_solver/optax.py b/optimistix/_solver/optax.py index 1f6a26b..8d9d7c9 100644 --- a/optimistix/_solver/optax.py +++ b/optimistix/_solver/optax.py @@ -97,7 +97,9 @@ def step( ("loss" in self.verbose, "Loss", f), ("y" in self.verbose, "y", y), ) - updates, new_opt_state = self.optim.update(grads, state.opt_state) + updates, new_opt_state = self.optim.update( + jax.tree_map(jnp.conj, grads), state.opt_state + ) new_y = eqx.apply_updates(y, updates) terminate = cauchy_termination( self.rtol, diff --git a/optimistix/_solver/trust_region.py b/optimistix/_solver/trust_region.py index a3d0b0f..4ee62b9 100644 --- a/optimistix/_solver/trust_region.py +++ b/optimistix/_solver/trust_region.py @@ -3,6 +3,7 @@ from typing_extensions import TypeAlias import equinox as eqx +import jax import jax.numpy as jnp from equinox import AbstractVar from equinox.internal import ω @@ -166,7 +167,7 @@ def predict_reduction( if isinstance(f_info, FunctionInfo.EvalGradHessian): # Minimisation algorithm. Directly compute the quadratic approximation. return tree_dot( - f_info.grad, + jax.tree_map(jnp.conj, f_info.grad), (f_info.grad**ω + 0.5 * f_info.hessian.mv(y_diff) ** ω).ω, ).real elif isinstance(f_info, FunctionInfo.ResidualJac): @@ -273,7 +274,7 @@ def predict_reduction( FunctionInfo.ResidualJac, ), ): - return tree_dot(f_info.grad, y_diff).real + return tree_dot(jax.tree_map(jnp.conj, f_info.grad), y_diff).real else: raise ValueError( "Cannot use `LinearTrustRegion` with this solver. This is because "