Skip to content

Commit

Permalink
further filtering fix for free
Browse files Browse the repository at this point in the history
  • Loading branch information
MDCHAMP committed Nov 26, 2021
1 parent 02fe27c commit ca3666b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
7 changes: 3 additions & 4 deletions src/toybox/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ class excitation():
Bass class for excitations implemented in toybox, contains generic input filtering and conditioning methods
'''

def scipy_filter(self, a, b, filt_fun):
def apply_filter(self, filt_fun):
'''wrapper for SciPy filters'''
def _filt(x):
return self.filt_fun(a, b, x)
self._filt = _filt
self._filt = filt_fun
return self # important for call signature

def generate(self, ts):
if hasattr(self, '_filt'):

return self._filt(self._generate(ts))
else:
return self._generate(ts)
Expand Down
31 changes: 19 additions & 12 deletions tests/toybox/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import deepcopy

from scipy.integrate import solve_ivp
from scipy.signal import butter, lfilter
from scipy.signal import butter, lfilter, welch

import toybox as tb
from toybox.premade import *
Expand Down Expand Up @@ -154,20 +154,27 @@ def test_normalisation(S):
assert (np.mean(sig) - offset) < 1e-3


@pytest.mark.parametrize('S', systems)
@pytest.mark.parametrize('w', [10,20,50,80])
def test_scipy_filter(S, w):

def test_apply_filter():
n = 1000
fs = 500
normal_cutoff = w / (0.5 * fs)
b, a = butter(4, normal_cutoff, btype='low', analog=False)
wn = 50
ts = np.linspace(0, n/fs, num=n)
b,a = butter(2, wn/(fs*0.5))

def my_filter(x):
return lfilter(b,a,x,axis=0)

S = deepcopy(S)
S.excitation = [None]*S.dofs
x = white_gaussian(0, 1)
x1 = x.generate(ts)

x.apply_filter(my_filter)
x2 = x.generate(ts)

x = white_gaussian(0, 1).scipy_filter(a, b, lfilter)
assert hasattr(x, '_filt')
S.excitation[0] = x


f1, p1 = welch(x1, fs, axis=0)
f2, p2 = welch(x2, fs, axis=0)

assert np.all( p2[f2 > (wn*1.5)] <= p1[f1 > (wn*1.5)]) # filtered fs are lower


0 comments on commit ca3666b

Please sign in to comment.