Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
s-m-e committed Jan 25, 2024
1 parent f6e7adc commit d4b1ab3
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 58 deletions.
20 changes: 18 additions & 2 deletions src/hapsira/core/math/ivp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
from ._solve import solve_ivp
from ._brentq import brentq
from ._brentq import (
BRENTQ_CONVERGED,
BRENTQ_SIGNERR,
BRENTQ_CONVERR,
BRENTQ_ERROR,
BRENTQ_XTOL,
BRENTQ_RTOL,
BRENTQ_MAXITER,
brentq_hf,
)

__all__ = [
"solve_ivp",
"brentq",
"BRENTQ_CONVERGED",
"BRENTQ_SIGNERR",
"BRENTQ_CONVERR",
"BRENTQ_ERROR",
"BRENTQ_XTOL",
"BRENTQ_RTOL",
"BRENTQ_MAXITER",
"brentq_hf",
]
94 changes: 46 additions & 48 deletions src/hapsira/core/math/ivp/_brentq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,26 @@
from ...jit import hjit


CONVERGED = 0
SIGNERR = -1
CONVERR = -2
__all__ = [
"BRENTQ_CONVERGED",
"BRENTQ_SIGNERR",
"BRENTQ_CONVERR",
"BRENTQ_ERROR",
"BRENTQ_XTOL",
"BRENTQ_RTOL",
"BRENTQ_MAXITER",
"brentq_hf",
]


BRENTQ_CONVERGED = 0
BRENTQ_SIGNERR = -1
BRENTQ_CONVERR = -2
BRENTQ_ERROR = -3

BRENTQ_ITER = 100
BRENTQ_XTOL = 2e-12
BRENTQ_RTOL = 4 * EPS
BRENTQ_MAXITER = 100


@hjit("f(f,f)")
Expand All @@ -23,32 +36,48 @@ def _signbit_s_hf(a):
return a < 0


@hjit("f(F(f(f)),f,f,f,f,f)", forceobj=True, nopython=False, cache=False)
def _brentq_hf(
@hjit("Tuple([f,i8])(F(f(f)),f,f,f,f,f)", forceobj=True, nopython=False, cache=False)
def brentq_hf(
func, # callback_type
xa, # double
xb, # double
xtol, # double
rtol, # double
iter_, # int
maxiter, # int
):
"""
Loosely adapted from
https://github.com/scipy/scipy/blob/d23363809572e9a44074a3f06f66137083446b48/scipy/optimize/_zeros_py.py#L682
"""

# if not xtol + 0. > 0:
# return 0., BRENTQ_ERROR
# if not rtol + 0. >= BRENTQ_RTOL:
# return 0., BRENTQ_ERROR
# if not maxiter + 0 >= 0:
# return 0., BRENTQ_ERROR

xpre, xcur = xa, xb
xblk = 0.0
fpre, fcur, fblk = 0.0, 0.0, 0.0
spre, scur = 0.0, 0.0

fpre = func(xpre)
assert not isnan(fpre)
if isnan(fpre):
return 0.0, BRENTQ_ERROR

fcur = func(xcur)
assert not isnan(fcur)
if isnan(fcur):
return 0.0, BRENTQ_ERROR

if fpre == 0:
return xpre, CONVERGED
return xpre, BRENTQ_CONVERGED
if fcur == 0:
return xcur, CONVERGED
return xcur, BRENTQ_CONVERGED
if _signbit_s_hf(fpre) == _signbit_s_hf(fcur):
return 0.0, SIGNERR
return 0.0, BRENTQ_SIGNERR

for _ in range(0, iter_):
for _ in range(0, maxiter):
if fpre != 0 and fcur != 0 and _signbit_s_hf(fpre) != _signbit_s_hf(fcur):
xblk = xpre
fblk = fpre
Expand All @@ -66,7 +95,7 @@ def _brentq_hf(
delta = (xtol + rtol * fabs(xcur)) / 2
sbis = (xblk - xcur) / 2
if fcur == 0 or fabs(sbis) < delta:
return xcur, CONVERGED
return xcur, BRENTQ_CONVERGED

if fabs(spre) > delta and fabs(fcur) < fabs(fpre):
if xpre == xblk:
Expand Down Expand Up @@ -95,38 +124,7 @@ def _brentq_hf(
xcur += delta if sbis > 0 else -delta

fcur = func(xcur)
assert not isnan(fcur)

return xcur, CONVERR


@hjit("f(F(f(f)),f,f,f,f,f)", forceobj=True, nopython=False, cache=False)
def brentq(
f,
a,
b,
xtol=BRENTQ_XTOL,
rtol=BRENTQ_RTOL,
maxiter=BRENTQ_ITER,
):
"""
Loosely adapted from
https://github.com/scipy/scipy/blob/d23363809572e9a44074a3f06f66137083446b48/scipy/optimize/_zeros_py.py#L682
"""

assert xtol > 0
assert rtol >= BRENTQ_RTOL
assert maxiter >= 0

zero, error_num = _brentq_hf(
f,
a,
b,
xtol,
rtol,
maxiter,
)

assert error_num == CONVERGED
if isnan(fcur):
return 0.0, BRENTQ_ERROR

return zero
return xcur, BRENTQ_CONVERR
11 changes: 7 additions & 4 deletions src/hapsira/core/math/ivp/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from ._brentq import brentq
from ._brentq import brentq_hf, BRENTQ_CONVERGED, BRENTQ_MAXITER
from ._solution import OdeSolution
from ._rk import DOP853
from ...math.linalg import EPS
Expand Down Expand Up @@ -42,13 +42,16 @@ def _solve_event_equation(
def wrapper(t):
return event(t, sol(t), argk)

return brentq(
value, status = brentq_hf(
wrapper,
t_old,
t,
xtol=4 * EPS,
rtol=4 * EPS,
4 * EPS,
4 * EPS,
BRENTQ_MAXITER,
)
assert BRENTQ_CONVERGED == status
return value


def _handle_events(
Expand Down
23 changes: 19 additions & 4 deletions src/hapsira/threebody/restricted.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import numpy as np

from hapsira.core.jit import hjit
from hapsira.core.math.ivp import brentq
from hapsira.core.math.ivp import (
brentq_hf,
BRENTQ_XTOL,
BRENTQ_RTOL,
BRENTQ_MAXITER,
BRENTQ_CONVERGED,
)
from hapsira.util import norm


Expand Down Expand Up @@ -51,15 +57,24 @@ def eq_L123(xi):
tol = 1e-11 # `brentq` uses a xtol of 2e-12, so it should be covered
a = -pi2 + tol
b = 1 - pi2 - tol
xi = brentq(eq_L123, a, b)
xi, status = brentq_hf(
eq_L123, a, b, BRENTQ_XTOL, BRENTQ_RTOL, BRENTQ_MAXITER
) # TODO call into hf
assert status == BRENTQ_CONVERGED
lp[0] = xi + pi2

# L2
xi = brentq(eq_L123, 1, 1.5)
xi, status = brentq_hf(
eq_L123, 1, 1.5, BRENTQ_XTOL, BRENTQ_RTOL, BRENTQ_MAXITER
) # TODO call into hf
assert status == BRENTQ_CONVERGED
lp[1] = xi + pi2

# L3
xi = brentq(eq_L123, -1.5, -1)
xi, status = brentq_hf(
eq_L123, -1.5, -1, BRENTQ_XTOL, BRENTQ_RTOL, BRENTQ_MAXITER
) # TODO call into hf
assert status == BRENTQ_CONVERGED
lp[2] = xi + pi2

# L4, L5
Expand Down

0 comments on commit d4b1ab3

Please sign in to comment.