Skip to content

Commit

Permalink
all brentq in one place
Browse files Browse the repository at this point in the history
  • Loading branch information
s-m-e committed Jan 10, 2024
1 parent 292908b commit 2b7cf89
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 253 deletions.
262 changes: 140 additions & 122 deletions src/hapsira/core/math/ivp/_brentq.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,141 @@
from math import fabs

import operator

from numba import njit, jit
import numpy as np


from ._solver import brentq_sf
CONVERGED = 0
SIGNERR = -1
CONVERR = -2
# EVALUEERR = -3
INPROGRESS = 1

BRENTQ_ITER = 100
BRENTQ_XTOL = 2e-12
BRENTQ_RTOL = 4 * np.finfo(float).eps


@njit
def _min_hf(a, b):
return a if a < b else b


@njit
def _signbit_hf(a):
return a < 0


@jit
def _brentq_hf(
func, # callback_type
xa, # double
xb, # double
xtol, # double
rtol, # double
iter_, # int
):
xpre, xcur = xa, xb
xblk = 0.0
fpre, fcur, fblk = 0.0, 0.0, 0.0
spre, scur = 0.0, 0.0

iterations = 0

fpre = func(xpre)
fcur = func(xcur)
funcalls = 2
if fpre == 0:
return xpre, funcalls, iterations, CONVERGED
if fcur == 0:
return xcur, funcalls, iterations, CONVERGED
if _signbit_hf(fpre) == _signbit_hf(fcur):
return 0.0, funcalls, iterations, SIGNERR

iterations = 0
for _ in range(0, iter_):
iterations += 1
if fpre != 0 and fcur != 0 and _signbit_hf(fpre) != _signbit_hf(fcur):
xblk = xpre
fblk = fpre
scur = xcur - xpre
spre = scur
if fabs(fblk) < fabs(fcur):
xpre = xcur
xcur = xblk
xblk = xpre

fpre = fcur
fcur = fblk
fblk = fpre

delta = (xtol + rtol * fabs(xcur)) / 2
sbis = (xblk - xcur) / 2
if fcur == 0 or fabs(sbis) < delta:
return xcur, funcalls, iterations, CONVERGED

if fabs(spre) > delta and fabs(fcur) < fabs(fpre):
if xpre == xblk:
stry = -fcur * (xcur - xpre) / (fcur - fpre)
else:
dpre = (fpre - fcur) / (xpre - xcur)
dblk = (fblk - fcur) / (xblk - xcur)
stry = (
-fcur * (fblk * dblk - fpre * dpre) / (dblk * dpre * (fblk - fpre))
)
if 2 * fabs(stry) < _min_hf(fabs(spre), 3 * fabs(sbis) - delta):
spre = scur
scur = stry
else:
spre = sbis
scur = sbis
else:
spre = sbis
scur = sbis

xpre = xcur
fpre = fcur
if fabs(scur) > delta:
xcur += scur
else:
xcur += delta if sbis > 0 else -delta

fcur = func(xcur)
funcalls += 1

return xcur, funcalls, iterations, CONVERR


@jit
def brentq_sf(
func, # func
a, # double
b, # double
xtol, # double
rtol, # double
iter_, # int
):
if xtol < 0:
raise ValueError("xtol must be >= 0")
if iter_ < 0:
raise ValueError("maxiter should be > 0")

zero, funcalls, iterations, error_num = _brentq_hf(
func,
a,
b,
xtol,
rtol,
iter_,
)

if error_num == SIGNERR:
raise ValueError("f(a) and f(b) must have different signs")
if error_num == CONVERR:
raise RuntimeError("Failed to converge after %d iterations." % iterations)

return zero # double


def _wrap_nan_raise(f):
Expand All @@ -22,137 +154,23 @@ def f_raise(x):
return f_raise


_iter = 100
_xtol = 2e-12
_rtol = 4 * np.finfo(float).eps


def brentq(
f,
a,
b,
xtol=_xtol,
rtol=_rtol,
maxiter=_iter,
xtol=BRENTQ_XTOL,
rtol=BRENTQ_RTOL,
maxiter=BRENTQ_ITER,
):
"""
Find a root of a function in a bracketing interval using Brent's method.
Uses the classic Brent's method to find a root of the function `f` on
the sign changing interval [a , b]. Generally considered the best of the
rootfinding routines here. It is a safe version of the secant method that
uses inverse quadratic extrapolation. Brent's method combines root
bracketing, interval bisection, and inverse quadratic interpolation. It is
sometimes known as the van Wijngaarden-Dekker-Brent method. Brent (1973)
claims convergence is guaranteed for functions computable within [a,b].
[Brent1973]_ provides the classic description of the algorithm. Another
description can be found in a recent edition of Numerical Recipes, including
[PressEtal1992]_. A third description is at
http://mathworld.wolfram.com/BrentsMethod.html. It should be easy to
understand the algorithm just by reading our code. Our code diverges a bit
from standard presentations: we choose a different formula for the
extrapolation step.
Parameters
----------
f : function
Python function returning a number. The function :math:`f`
must be continuous, and :math:`f(a)` and :math:`f(b)` must
have opposite signs.
a : scalar
One end of the bracketing interval :math:`[a, b]`.
b : scalar
The other end of the bracketing interval :math:`[a, b]`.
xtol : number, optional
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
parameter must be positive. For nice functions, Brent's
method will often satisfy the above condition with ``xtol/2``
and ``rtol/2``. [Brent1973]_
rtol : number, optional
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
parameter cannot be smaller than its default value of
``4*np.finfo(float).eps``. For nice functions, Brent's
method will often satisfy the above condition with ``xtol/2``
and ``rtol/2``. [Brent1973]_
maxiter : int, optional
If convergence is not achieved in `maxiter` iterations, an error is
raised. Must be >= 0.
full_output : bool, optional
If `full_output` is False, the root is returned. If `full_output` is
True, the return value is ``(x, r)``, where `x` is the root, and `r` is
a `RootResults` object.
disp : bool, optional
If True, raise RuntimeError if the algorithm didn't converge.
Otherwise, the convergence status is recorded in any `RootResults`
return object.
Returns
-------
root : float
Root of `f` between `a` and `b`.
r : `RootResults` (present if ``full_output = True``)
Object containing information about the convergence. In particular,
``r.converged`` is True if the routine converged.
Notes
-----
`f` must be continuous. f(a) and f(b) must have opposite signs.
Related functions fall into several classes:
multivariate local optimizers
`fmin`, `fmin_powell`, `fmin_cg`, `fmin_bfgs`, `fmin_ncg`
nonlinear least squares minimizer
`leastsq`
constrained multivariate optimizers
`fmin_l_bfgs_b`, `fmin_tnc`, `fmin_cobyla`
global optimizers
`basinhopping`, `brute`, `differential_evolution`
local scalar minimizers
`fminbound`, `brent`, `golden`, `bracket`
N-D root-finding
`fsolve`
1-D root-finding
`brenth`, `ridder`, `bisect`, `newton`
scalar fixed-point finder
`fixed_point`
References
----------
.. [Brent1973]
Brent, R. P.,
*Algorithms for Minimization Without Derivatives*.
Englewood Cliffs, NJ: Prentice-Hall, 1973. Ch. 3-4.
.. [PressEtal1992]
Press, W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T.
*Numerical Recipes in FORTRAN: The Art of Scientific Computing*, 2nd ed.
Cambridge, England: Cambridge University Press, pp. 352-355, 1992.
Section 9.3: "Van Wijngaarden-Dekker-Brent Method."
Examples
--------
>>> def f(x):
... return (x**2 - 1)
>>> from scipy import optimize
>>> root = optimize.brentq(f, -2, 0)
>>> root
-1.0
>>> root = optimize.brentq(f, 0, 2)
>>> root
1.0
Loosely adapted from
https://github.com/scipy/scipy/blob/d23363809572e9a44074a3f06f66137083446b48/scipy/optimize/_zeros_py.py#L682
"""
maxiter = operator.index(maxiter)
if xtol <= 0:
raise ValueError("xtol too small (%g <= 0)" % xtol)
if rtol < _rtol:
raise ValueError(f"rtol too small ({rtol:g} < {_rtol:g})")
if rtol < BRENTQ_RTOL:
raise ValueError(f"rtol too small ({rtol:g} < {BRENTQ_RTOL:g})")
f = _wrap_nan_raise(f)
r = brentq_sf(f, a, b, xtol, rtol, maxiter)
return r
Loading

0 comments on commit 2b7cf89

Please sign in to comment.