Skip to content

Commit

Permalink
Fix docs, check ny, more flexible plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
bjodah committed Jul 12, 2016
1 parent f36f089 commit f472624
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
16 changes: 8 additions & 8 deletions pyodesys/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import os

from .util import _ensure_4args
from .util import _ensure_4args, _default
from .plotting import plot_result, plot_phase_plane


Expand Down Expand Up @@ -220,6 +220,9 @@ def integrate(self, xout, y0, params=(), **kwargs):
"""
intern_xout, intern_y0, self.internal_params = self.pre_process(
xout, y0, params)
if hasattr(self, 'ny'):
if len(intern_y0) != self.ny:
raise ValueError("Incorrect length of intern_y0")
integrator = kwargs.pop('integrator', None)
if integrator is None:
integrator = os.environ.get('PYODESYS_INTEGRATOR', 'scipy')
Expand Down Expand Up @@ -458,13 +461,10 @@ def _plot(self, cb, internal_xout=None, internal_yout=None,

if 'names' not in kwargs:
kwargs['names'] = getattr(self, 'names', None)
if (internal_xout, internal_yout, internal_params) == (None,)*3:
internal_xout = self.internal_xout
internal_yout = self.internal_yout
internal_params = self.internal_params
elif None in (internal_xout, internal_yout, internal_params):
raise ValueError("Pass either all or none of internal_* kwargs")
return cb(internal_xout, internal_yout, internal_params, **kwargs)

return cb(_default(internal_xout, self.internal_xout),
_default(internal_yout, self.internal_yout),
_default(internal_params, self.internal_params), **kwargs)

def plot_result(self, **kwargs):
""" Plots the integrated dependent variables from last integration.
Expand Down
4 changes: 2 additions & 2 deletions pyodesys/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def plot_result(x, y, params=(), indices=None, plot=None, plot_kwargs_cb=None,
post_processors : iterable of callback (default: tuple())
"""
import matplotlib.pyplot as plt

if plot is None:
from matplotlib.pyplot import plot
if plot_kwargs_cb is None:
Expand Down Expand Up @@ -106,7 +108,6 @@ def post_process(x, y, params):
idx, lines=False, markers=markers, labels=names))

if xlabel is None:
import matplotlib.pyplot as plt
try:
plt.xlabel(x_post.dimensionality.latex)
except AttributeError:
Expand All @@ -115,7 +116,6 @@ def post_process(x, y, params):
plt.xlabel(xlabel)

if ylabel is None:
import matplotlib.pyplot as plt
try:
plt.ylabel(y_post.dimensionality.latex)
except AttributeError:
Expand Down
2 changes: 0 additions & 2 deletions pyodesys/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ def from_callback(cls, cb, ny, nparams=0, backend=None, **kwargs):
length of p
backend : module (optional)
default: sympy
\*args :
arguments passed onto :class:`SymbolicSys`
\*\*kwargs :
keyword arguments passed onto :class:`SymbolicSys`
Expand Down
9 changes: 9 additions & 0 deletions pyodesys/tests/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def identity(x):
logexp = (sp.log, sp.exp)


def test_SymbolicSys():
odesys = SymbolicSys.from_callback(lambda x, y, p, be: [y[1], -y[0]], 2)
with pytest.raises(ValueError):
odesys.integrate(1, [0])


def decay_rhs(t, y, k):
ny = len(y)
dydt = [0]*ny
Expand Down Expand Up @@ -97,6 +103,9 @@ def f(t, x, k):
ref = np.array(bateman_full(y0, k+[0], xout - xout[0], exp=np.exp)).T
assert np.allclose(yout, ref, rtol=3e-11, atol=3e-11)

with pytest.raises(TypeError):
odesys.integrate([1e-12, 1], [0]*len(k), k, integrator='scipy')


def test_ScaledSys_from_callback__exprs():
def f(t, x, k):
Expand Down
4 changes: 4 additions & 0 deletions pyodesys/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,7 @@ def _ensure_4args(func):
return lambda x, y, p=(), backend=math: func(x, y)
else:
raise ValueError("Incorrect numer of arguments")


def _default(arg, default):
return default if arg is None else arg

0 comments on commit f472624

Please sign in to comment.