Merge pull request #10 from chrhansk/feature-deriv-checks
Improve derivative checks
chrhansk authored Oct 25, 2023
2 parents 5989bc6 + 47ee707 · commit 8ec2b97
Showing 4 changed files with 141 additions and 20 deletions.
56 changes: 53 additions & 3 deletions pygradflow/deriv_check.py
@@ -6,6 +6,43 @@
from pygradflow.params import Params


+class DerivError(ValueError):
+    def __init__(self, expected_value, actual_value, col_index, atol) -> None:
+        self.expected_value = expected_value
+        self.actual_value = actual_value
+        self.atol = atol
+
+        # Mark entries where the analytic and finite-difference values disagree.
+        self.invalid_deriv = np.isclose(self.expected_value,
+                                        self.actual_value,
+                                        atol=self.atol)
+
+        self.invalid_deriv = np.logical_not(self.invalid_deriv)
+        self.invalid_indices = np.where(self.invalid_deriv)[0]
+
+        self.deriv_diffs = np.abs(self.expected_value - self.actual_value)
+        self.max_deriv_diff = self.deriv_diffs.max()
+
+        # Build one (row, col) pair per invalid entry.
+        num_invalid = self.invalid_indices.size
+
+        invalid_indices = np.zeros((num_invalid, 2), dtype=int)
+        invalid_indices[:, 0] = self.invalid_indices
+        invalid_indices[:, 1] = col_index
+
+        self.invalid_indices = invalid_indices
+        self.col_index = col_index
+
+    def __str__(self):
+        num_invalid_indices = len(self.invalid_indices)
+
+        message = (f"Expected derivative: {self.expected_value} "
+                   f"and actual (findiff) derivative: {self.actual_value} "
+                   f"differ at the {num_invalid_indices} "
+                   f"indices: {self.invalid_indices} "
+                   f"(max diff: {self.max_deriv_diff}, tolerance: {self.atol})")
+
+        return message


def deriv_check(
f: Callable,
xval: np.ndarray,
@@ -15,19 +52,24 @@ def deriv_check(
    (n,) = xval.shape

    fval = f(xval)
+    fval = np.atleast_1d(fval)
+
+    (m,) = fval.shape
+
+    sparse_dval = sp.sparse.issparse(dval)

    eps = params.deriv_pert

    dsparse = False

-    if sp.sparse.issparse(dval):
+    if sparse_dval:
        dval = dval.tocsc()
        dsparse = True
+    else:
+        dval = np.atleast_2d(dval)
+
+    assert dval.shape == (m, n)

    xtest = np.copy(xval)

for i in range(n):
@@ -37,11 +79,19 @@ def deriv_check(

        apx_dval = (testval - fval) / eps

+        if apx_dval.ndim == 1:
+            apx_dval = apx_dval[:, np.newaxis]
+
        darray = dval[:, i]

        if dsparse:
            darray = darray.toarray()

-        assert np.allclose(darray, apx_dval[:, None], atol=params.deriv_tol)
+        # A dense column comes back 1D; lift it to a column vector to match apx_dval.
+        if darray.ndim == 1:
+            darray = darray[:, np.newaxis]
+
+        assert darray.shape == apx_dval.shape
+
+        if not np.allclose(darray, apx_dval, atol=params.deriv_tol):
+            raise DerivError(darray, apx_dval, i, params.deriv_tol)

        xtest[i] -= eps
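
Note: the check above is a plain forward difference, comparing column i of the supplied derivative against (f(x + eps·e_i) − f(x)) / eps. A self-contained sketch of the same idea (illustrative only, not the pygradflow API; the test function and tolerance are invented):

    import numpy as np

    def forward_diff_jacobian(f, x, eps=1e-8):
        # Column i approximates (f(x + eps * e_i) - f(x)) / eps.
        fx = np.atleast_1d(f(x))
        jac = np.zeros((fx.size, x.size))
        for i in range(x.size):
            xp = x.copy()
            xp[i] += eps
            jac[:, i] = (np.atleast_1d(f(xp)) - fx) / eps
        return jac

    f = lambda x: np.array([x[0] ** 2, x[0] * x[1]])
    x = np.array([1.0, 2.0])
    exact = np.array([[2.0, 0.0], [2.0, 1.0]])  # analytic Jacobian of f at x
    assert np.allclose(forward_diff_jacobian(f, x), exact, atol=1e-4)

With a perturbation of 1e-8, the combined truncation and rounding error of a forward difference is around 1e-8 for well-scaled problems, which is why a comparison tolerance near 1e-4 (deriv_tol) is a forgiving default.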
13 changes: 11 additions & 2 deletions pygradflow/params.py
@@ -1,4 +1,4 @@
-from enum import Enum, auto
+from enum import Enum, Flag, auto
from dataclasses import dataclass

import numpy as np
@@ -33,6 +33,13 @@ class Precision(Enum):
Double = auto()


+class DerivCheck(Flag):
+    NoCheck = 0
+    CheckFirst = (1 << 0)
+    CheckSecond = (1 << 1)
+    CheckAll = CheckFirst | CheckSecond


@dataclass
class Params:
rho: float = 1e2
@@ -43,7 +50,9 @@ class Params:
theta_ref: float = 0.5

lamb_init: float = 1.0
+    # Up to 1e-6 for single precision?
    lamb_min: float = 1e-12
+    lamb_max: float = 1e12
lamb_inc: float = 2.0
lamb_red: float = 0.5

@@ -62,7 +71,7 @@ class Params:
linear_solver_type: LinearSolverType = LinearSolverType.LU
penalty_update: PenaltyUpdate = PenaltyUpdate.DualNorm

-    deriv_check: bool = False
+    deriv_check: DerivCheck = DerivCheck.NoCheck
deriv_pert: float = 1e-8
deriv_tol: float = 1e-4

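Note: because DerivCheck derives from enum.Flag, settings compose with bitwise OR and are queried with bitwise AND — exactly how the solver dispatches on it below. A standalone sketch of the semantics, mirroring the definition above:

    from enum import Flag

    class DerivCheck(Flag):
        NoCheck = 0
        CheckFirst = (1 << 0)
        CheckSecond = (1 << 1)
        CheckAll = CheckFirst | CheckSecond

    setting = DerivCheck.CheckFirst | DerivCheck.CheckSecond
    assert setting == DerivCheck.CheckAll                     # flags compose with |
    assert setting & DerivCheck.CheckSecond                   # membership via & is truthy
    assert not (DerivCheck.NoCheck & DerivCheck.CheckFirst)   # NoCheck matches nothing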
34 changes: 20 additions & 14 deletions pygradflow/solver.py
@@ -93,26 +93,30 @@ def next_iterates():

    def _deriv_check(self, x: np.ndarray, y: np.ndarray) -> None:
        from pygradflow.deriv_check import deriv_check
+        from pygradflow.params import DerivCheck

        eval = self.evaluator
        params = self.params
+        deriv_check_type = params.deriv_check

-        logger.info("Checking objective derivative")
+        if deriv_check_type == DerivCheck.NoCheck:
+            return

-        deriv_check(lambda x: eval.obj(x), x, eval.obj_grad(x), params)
+        if deriv_check_type & DerivCheck.CheckFirst:
+            logger.info("Checking objective derivative")
+            deriv_check(lambda x: eval.obj(x), x, eval.obj_grad(x), params)

-        logger.info("Checking constraint derivative")
-        deriv_check(lambda x: eval.cons(x), x, eval.cons_jac(x), params)
+            logger.info("Checking constraint derivative")
+            deriv_check(lambda x: eval.cons(x), x, eval.cons_jac(x), params)

-        logger.info("Checking Hessian")
+        if deriv_check_type & DerivCheck.CheckSecond:
+            logger.info("Checking Hessian")

-        deriv_check(
-            lambda x: eval.obj_grad(x) + eval.cons_jac(x).T.dot(y),
-            x,
-            eval.lag_hess(x, y),
-            params,
-        )
+            deriv_check(
+                lambda x: eval.obj_grad(x) + eval.cons_jac(x).T.dot(y),
+                x,
+                eval.lag_hess(x, y),
+                params)

    def print_result(self, iterate: Iterate) -> None:
        rho = self.rho
@@ -144,8 +148,7 @@ def solve(self, x_0: np.ndarray, y_0: np.ndarray) -> Result:

controller = DistanceRatioController(problem, params)

-        if params.deriv_check:
-            self._deriv_check(x, y)
+        self._deriv_check(x, y)

iterate = Iterate(problem, params, x, y, self.evaluator)
self.rho = self.penalty.initial(iterate)
@@ -169,6 +172,9 @@ def solve(self, x_0: np.ndarray, y_0: np.ndarray) -> Result:
accept = step_result.accepted
lamb = step_result.lamb

+            if lamb >= params.lamb_max:
+                raise Exception(f"Inverse step size {lamb} exceeded maximum {params.lamb_max} (incorrect derivatives?)")

accept_str = (
colored("Accept", "green") if accept else colored("Reject", "red")
)
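Note: derivative checking is now selected through Params rather than a boolean. A hedged usage sketch (problem, x_0, and y_0 stand in for an existing problem instance and starting point, as in the tests below):

    from pygradflow.params import DerivCheck, Params
    from pygradflow.solver import Solver

    # Verify only first derivatives (gradient and Jacobian), skip the Hessian.
    params = Params(deriv_check=DerivCheck.CheckFirst)
    solver = Solver(problem, params)

    # solve() performs the check up front; a wrong derivative now raises
    # DerivError (a ValueError) rather than tripping a bare assertion.
    result = solver.solve(x_0, y_0)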
58 changes: 57 additions & 1 deletion tests/pygradflow/test_solver.py
@@ -11,6 +11,7 @@
from pygradflow.params import (
NewtonType,
Params,
+    DerivCheck,
PenaltyUpdate,
Precision,
StepSolverType,
@@ -347,7 +348,9 @@ def test_solve_tame():
problem = Tame()
x_0 = np.array([0.0, 0.0])
y_0 = np.array([0.0])
-    params = Params(newton_type=NewtonType.Full, deriv_check=True)
+    params = Params(newton_type=NewtonType.Full,
+                    deriv_check=DerivCheck.CheckAll)

solver = Solver(problem, params)

result = solver.solve(x_0, y_0)
@@ -371,3 +374,56 @@ def test_solve_with_newton_types(hs71_instance, newton_type):
result = solver.solve(x_0, y_0)

assert result.success


+def test_grad_errors():
+    problem = Tame()
+
+    def obj_grad(x):
+        g = Tame().obj_grad(x)
+        g[0] += 1.
+        return g
+
+    problem.obj_grad = obj_grad
+
+    x_0 = np.array([0.0, 0.0])
+    y_0 = np.array([0.0])
+    params = Params(deriv_check=DerivCheck.CheckAll)
+
+    solver = Solver(problem, params)
+
+    with pytest.raises(ValueError) as e:
+        solver.solve(x_0, y_0)
+
+    e = e.value
+    assert (e.invalid_indices == [[0, 0]]).all()
+
+
+def test_cons_errors():
+    problem = Tame()
+
+    invalid_index = 1
+
+    def cons_jac(x):
+        g = Tame().cons_jac(x)
+        g.data[invalid_index] += 1.
+        return g
+
+    problem.cons_jac = cons_jac
+
+    x_0 = np.array([0.0, 0.0])
+    y_0 = np.array([0.0])
+    params = Params(deriv_check=DerivCheck.CheckAll)
+
+    solver = Solver(problem, params)
+
+    with pytest.raises(ValueError) as e:
+        solver.solve(x_0, y_0)
+
+    e = e.value
+    jac = Tame().cons_jac(x_0)
+
+    invalid_row = jac.row[invalid_index]
+    invalid_col = jac.col[invalid_index]
+
+    assert (e.invalid_indices == [[invalid_row, invalid_col]]).all()
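
Note: the new tests can be run in isolation with pytest's -k filter, e.g.:

    pytest tests/pygradflow/test_solver.py -k "grad_errors or cons_errors"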
