From f472624705877c540868d8fcc7053be0207357b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Dahlgren?= Date: Tue, 12 Jul 2016 16:16:24 +0200 Subject: [PATCH] Fix docs, check ny, more flexible plotting --- pyodesys/core.py | 16 ++++++++-------- pyodesys/plotting.py | 4 ++-- pyodesys/symbolic.py | 2 -- pyodesys/tests/test_symbolic.py | 9 +++++++++ pyodesys/util.py | 4 ++++ 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/pyodesys/core.py b/pyodesys/core.py index bca719ee..75acac32 100644 --- a/pyodesys/core.py +++ b/pyodesys/core.py @@ -13,7 +13,7 @@ import os -from .util import _ensure_4args +from .util import _ensure_4args, _default from .plotting import plot_result, plot_phase_plane @@ -220,6 +220,9 @@ def integrate(self, xout, y0, params=(), **kwargs): """ intern_xout, intern_y0, self.internal_params = self.pre_process( xout, y0, params) + if hasattr(self, 'ny'): + if len(intern_y0) != self.ny: + raise ValueError("Incorrect length of intern_y0") integrator = kwargs.pop('integrator', None) if integrator is None: integrator = os.environ.get('PYODESYS_INTEGRATOR', 'scipy') @@ -458,13 +461,10 @@ def _plot(self, cb, internal_xout=None, internal_yout=None, if 'names' not in kwargs: kwargs['names'] = getattr(self, 'names', None) - if (internal_xout, internal_yout, internal_params) == (None,)*3: - internal_xout = self.internal_xout - internal_yout = self.internal_yout - internal_params = self.internal_params - elif None in (internal_xout, internal_yout, internal_params): - raise ValueError("Pass either all or none of internal_* kwargs") - return cb(internal_xout, internal_yout, internal_params, **kwargs) + + return cb(_default(internal_xout, self.internal_xout), + _default(internal_yout, self.internal_yout), + _default(internal_params, self.internal_params), **kwargs) def plot_result(self, **kwargs): """ Plots the integrated dependent variables from last integration. diff --git a/pyodesys/plotting.py b/pyodesys/plotting.py index eac66757..92937402 100644 --- a/pyodesys/plotting.py +++ b/pyodesys/plotting.py @@ -55,6 +55,8 @@ def plot_result(x, y, params=(), indices=None, plot=None, plot_kwargs_cb=None, post_processors : iterable of callback (default: tuple()) """ + import matplotlib.pyplot as plt + if plot is None: from matplotlib.pyplot import plot if plot_kwargs_cb is None: @@ -106,7 +108,6 @@ def post_process(x, y, params): idx, lines=False, markers=markers, labels=names)) if xlabel is None: - import matplotlib.pyplot as plt try: plt.xlabel(x_post.dimensionality.latex) except AttributeError: @@ -115,7 +116,6 @@ def post_process(x, y, params): plt.xlabel(xlabel) if ylabel is None: - import matplotlib.pyplot as plt try: plt.ylabel(y_post.dimensionality.latex) except AttributeError: diff --git a/pyodesys/symbolic.py b/pyodesys/symbolic.py index 9ccc5771..f1277264 100644 --- a/pyodesys/symbolic.py +++ b/pyodesys/symbolic.py @@ -216,8 +216,6 @@ def from_callback(cls, cb, ny, nparams=0, backend=None, **kwargs): length of p backend : module (optional) default: sympy - \*args : - arguments passed onto :class:`SymbolicSys` \*\*kwargs : keyword arguments passed onto :class:`SymbolicSys` diff --git a/pyodesys/tests/test_symbolic.py b/pyodesys/tests/test_symbolic.py index c4d791d5..8fc9979c 100644 --- a/pyodesys/tests/test_symbolic.py +++ b/pyodesys/tests/test_symbolic.py @@ -22,6 +22,12 @@ def identity(x): logexp = (sp.log, sp.exp) +def test_SymbolicSys(): + odesys = SymbolicSys.from_callback(lambda x, y, p, be: [y[1], -y[0]], 2) + with pytest.raises(ValueError): + odesys.integrate(1, [0]) + + def decay_rhs(t, y, k): ny = len(y) dydt = [0]*ny @@ -97,6 +103,9 @@ def f(t, x, k): ref = np.array(bateman_full(y0, k+[0], xout - xout[0], exp=np.exp)).T assert np.allclose(yout, ref, rtol=3e-11, atol=3e-11) + with pytest.raises(TypeError): + odesys.integrate([1e-12, 1], [0]*len(k), k, integrator='scipy') + def test_ScaledSys_from_callback__exprs(): def f(t, x, k): diff --git a/pyodesys/util.py b/pyodesys/util.py index 026a3318..8e7592ec 100644 --- a/pyodesys/util.py +++ b/pyodesys/util.py @@ -157,3 +157,7 @@ def _ensure_4args(func): return lambda x, y, p=(), backend=math: func(x, y) else: raise ValueError("Incorrect numer of arguments") + + +def _default(arg, default): + return default if arg is None else arg