Commit
Add helper to build hessian vector product
Co-authored-by: Adrian Seyboldt <[email protected]>
ricardoV94 and aseyboldt committed Jun 27, 2024
1 parent 94a055d commit c86763f
Showing 3 changed files with 126 additions and 0 deletions.
10 changes: 10 additions & 0 deletions doc/tutorial/gradients.rst
@@ -267,6 +267,16 @@ or, making use of the R-operator:
>>> f([4, 4], [2, 2])
array([ 4., 4.])

There is a built-in helper that uses the first method:

>>> x = pt.dvector('x')
>>> v = pt.dvector('v')
>>> y = pt.sum(x ** 2)
>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
>>> f = pytensor.function([x, v], Hv)
>>> f([4, 4], [2, 2])
array([ 4., 4.])


Final Pointers
==============
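For comparison, here is a minimal sketch (not part of this commit) of the alternative the tutorial alludes to: reverse mode for the gradient, followed by the R-operator for the product. It assumes every Op in the gradient graph implements `Rop`.

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Rop, grad

x = pt.dvector("x")
v = pt.dvector("v")
y = pt.sum(x**2)

gy = grad(y, x)     # reverse pass: dy/dx
Hv = Rop(gy, x, v)  # forward pass through the gradient graph: H @ v
f = pytensor.function([x, v], Hv)
print(f([4, 4], [2, 2]))  # [4., 4.], since the Hessian here is 2 * I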
77 changes: 77 additions & 0 deletions pytensor/gradient.py
@@ -2052,6 +2052,83 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
return as_list_or_tuple(using_list, using_tuple, hessians)


def hessian_vector_product(cost, wrt, p, **grad_kwargs):
"""Return the expression of the Hessian times a vector p.
Parameters
----------
cost: Scalar (0-dimensional) variable.
wrt: Vector (1-dimensional tensor) `Variable` or list of Vectors
p: Vector (1-dimensional tensor) `Variable` or list of Vectors
    Each vector will be used for the Hessian-vector product with respect
    to the corresponding input variable.
**grad_kwargs:
    Keyword arguments passed to the `grad` function.

Returns
-------
Vector or list of Vectors
    The Hessian times `p` of the `cost` with respect to (elements of) `wrt`.

Notes
-----
This function uses backward autodiff twice to obtain the desired expression.
You may want to manually build the equivalent expression by combining backward
followed by forward (if all Ops support it) autodiff.
See {ref}`docs/_tutcomputinggrads#Hessian-times-a-Vector` for how to do this.

Examples
--------
>>> import numpy as np
>>> from scipy.optimize import minimize
>>> from pytensor import function
>>> from pytensor.tensor import vector
>>> from pytensor.gradient import grad, hessian_vector_product
>>>
>>> x = vector('x')
>>> p = vector('p')
>>>
>>> rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
>>> rosen_jac = grad(rosen, x)
>>> rosen_hessp = hessian_vector_product(rosen, x, p)
>>>
>>> rosen_fn = function([x], rosen)
>>> rosen_jac_fn = function([x], rosen_jac)
>>> rosen_hessp_fn = function([x, p], rosen_hessp)
>>> x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
>>> res = minimize(
... rosen_fn,
... x0,
... method="Newton-CG",
... jac=rosen_jac_fn,
... hessp=rosen_hessp_fn,
... options={"xtol": 1e-8, "disp": True},
... )
Optimization terminated successfully.
Current function value: 0.000000
Iterations: 24
Function evaluations: 33
Gradient evaluations: 33
Hessian evaluations: 66
>>> res.x
array([1. , 1. , 1. , 0.99999999, 0.99999999])
"""
wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
p_list = p if isinstance(p, Sequence) else [p]
grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
hessian_cost = pytensor.tensor.add(
*[
(grad_wrt * p).sum()
for grad_wrt, p in zip(grad_wrt_list, p_list, strict=True)
]
)
Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)

if isinstance(wrt, Variable):
return Hp_list[0]
return Hp_list


def _is_zero(x):
"""
Returns 'yes', 'no', or 'maybe' indicating whether x
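As an aside (illustration only, not from the diff): the identity the helper relies on is that, for a scalar cost c, H @ p equals the gradient of (grad(c, x) · p), since p is held constant during the second backward pass. A standalone sketch:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import grad

x = pt.dvector("x")
p = pt.dvector("p")
cost = (x**3).sum()                      # Hessian is diag(6 * x)
Hp = grad((grad(cost, x) * p).sum(), x)  # the double-backward trick
fn = pytensor.function([x, p], Hp)
x0 = np.array([1.0, 2.0])
p0 = np.array([1.0, 1.0])
np.testing.assert_allclose(fn(x0, p0), 6 * x0 * p0)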
39 changes: 39 additions & 0 deletions tests/test_gradient.py
@@ -2,6 +2,7 @@

import numpy as np
import pytest
from scipy.optimize import rosen_hess_prod

import pytensor
import pytensor.tensor.basic as ptb
@@ -22,6 +23,7 @@
grad_scale,
grad_undefined,
hessian,
hessian_vector_product,
jacobian,
subgraph_grad,
zero_grad,
@@ -1081,3 +1083,40 @@ def test_jacobian_disconnected_inputs():
func_s = pytensor.function([s2], jacobian_s)
val = np.array(1.0).astype(pytensor.config.floatX)
assert np.allclose(func_s(val), np.zeros(1))


class TestHessianVectorProduct:
def test_rosen(self):
x = vector("x", dtype="float64")
rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()

p = vector("p", dtype="float64")
rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)

x_test = 0.1 * np.arange(9)
p_test = 0.5 * np.arange(9)
np.testing.assert_allclose(
rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
rosen_hess_prod(x_test, p_test),
)

def test_multiple_wrt(self):
x = vector("x", dtype="float64")
y = vector("y", dtype="float64")
p_x = vector("p_x", dtype="float64")
p_y = vector("p_y", dtype="float64")

cost = (x**2 - y**2).sum()
hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])

hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])
test = {
# x, y don't matter
"x": np.full((3,), np.nan),
"y": np.full((3,), np.nan),
"p_x": [1, 2, 3],
"p_y": [3, 2, 1],
}
hessp_x_eval, hessp_y_eval = hessp_fn(**test)
np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])

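A further way to sanity-check the helper (a sketch, not part of the committed tests) is against the dense Hessian built by `pytensor.gradient.hessian`, which is only feasible for small inputs:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import hessian, hessian_vector_product

x = pt.dvector("x")
p = pt.dvector("p")
cost = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()

dense = pt.dot(hessian(cost, x), p)       # forms the full Hessian
hvp = hessian_vector_product(cost, x, p)  # never forms the Hessian

fn = pytensor.function([x, p], [dense, hvp])
x0 = 0.1 * np.arange(5.0)
p0 = 0.5 * np.arange(5.0)
a, b = fn(x0, p0)
np.testing.assert_allclose(a, b)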