Skip to content

Commit

Permalink
Add get_train_times as defined in the book page 106 on CV.
Browse files Browse the repository at this point in the history
  • Loading branch information
imcu committed Jul 1, 2019
1 parent 1899cdb commit e38a523
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
.idea/*
.idea/*
__pycache__
test_reports
.coverage
9 changes: 9 additions & 0 deletions mlfinlab/cross_validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Functions derived from Chapter 7: Cross Validation
"""

from mlfinlab.cross_validation.cross_validation import get_train_times

# Names exported by `from mlfinlab.cross_validation import *`; currently only
# the get_train_times function imported above.
__all__ = [
'get_train_times',
]
21 changes: 21 additions & 0 deletions mlfinlab/cross_validation/cross_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Implements the book chapter 7 on Cross Validation for financial data
"""

import pandas as pd


def get_train_times(observations: pd.Series, test_times: pd.Series) -> pd.Series:  # pragma: no cover
    """
    Given test_times, find the times of the training observations.

    Every observation whose time span overlaps any test span is removed, so
    the returned series only contains observations that can safely be used
    for training.

    :param observations: Series describing all observations.
        - observations.index: Time when the observation started.
        - observations.value: Time when the observation ended.
    :param test_times: Series describing the testing observations
        (index: start time of a test span, value: end time of that span).
    :return: A new series with the observations that do not overlap any test
        span. The input series is not modified.
    """
    train = observations.copy(deep=True)
    # Series.items() replaces Series.iteritems(), which no longer exists in
    # pandas >= 2.0.
    for start_ix, end_ix in test_times.items():
        starts = train.index
        # Train starts within test.
        starts_within = (start_ix <= starts) & (starts <= end_ix)
        # Train ends within test.
        ends_within = ((start_ix <= train) & (train <= end_ix)).values
        # Train envelops test.
        envelops = (starts <= start_ix) & (end_ix <= train).values
        # Filter positionally instead of dropping by label, so that rows
        # sharing a start time with an overlapping row are not removed too.
        train = train[~(starts_within | ends_within | envelops)]
    return train
115 changes: 115 additions & 0 deletions mlfinlab/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Tests the cross validation technique described in Ch.7 of the book
"""
import unittest
import os
import pandas as pd

from mlfinlab.cross_validation.cross_validation import get_train_times


class TestCrossValidation(unittest.TestCase):
    """
    Test the functionality of the time series cross validation technique

    Each test builds a single test span and checks that get_train_times
    removes exactly the observations overlapping that span.
    """

    def __init__(self, *args, **kwargs):
        # Forward everything to TestCase so that both positional and keyword
        # construction (e.g. methodName=...) keep working.
        super().__init__(*args, **kwargs)
        # Set to True to print the intermediate series while debugging.
        self.verbose = False

    def log(self, msg):
        """
        Simple method to suppress debugging strings
        """
        if self.verbose:  # pragma: no cover
            print(msg)

    def setUp(self):
        """
        Build the observations series shared by all tests.

        This is how the observations dataset looks like
        (index = start time, value = end time):
        2019-01-01 00:00:00 2019-01-01 00:02:00
        2019-01-01 00:01:00 2019-01-01 00:03:00
        2019-01-01 00:02:00 2019-01-01 00:04:00
        2019-01-01 00:03:00 2019-01-01 00:05:00
        2019-01-01 00:04:00 2019-01-01 00:06:00
        2019-01-01 00:05:00 2019-01-01 00:07:00
        2019-01-01 00:06:00 2019-01-01 00:08:00
        2019-01-01 00:07:00 2019-01-01 00:09:00
        2019-01-01 00:08:00 2019-01-01 00:10:00
        2019-01-01 00:09:00 2019-01-01 00:11:00
        """
        pwd_path = os.path.dirname(__file__)
        self.log(f"pwd_path= {pwd_path}")

        # 'min' is used instead of the 'T' alias, which newer pandas
        # versions deprecate.
        self.observations = pd.Series(
            index=pd.date_range(start='2019-01-01 00:00:00', periods=10, freq='min'),
            data=pd.date_range(start='2019-01-01 00:02:00', periods=10, freq='min'),
        )
        self.log(self.observations)

    def test_get_train_times_1(self):
        """
        Tests the get_train_times method for the case where the train STARTS within test
        """
        test_times = pd.Series(
            index=pd.date_range(start='2019-01-01 00:01:00', periods=1, freq='min'),
            data=pd.date_range(start='2019-01-01 00:02:00', periods=1, freq='min'),
        )
        self.log(f"test_times=\n{test_times}")
        train_times_ret = get_train_times(self.observations, test_times)
        self.log(f"train_times_ret=\n{train_times_ret}")

        # Observations starting at 00:00-00:02 overlap the test span.
        train_times_ok = pd.Series(
            index=pd.date_range(start='2019-01-01 00:03:00', end='2019-01-01 00:09:00', freq='min'),
            data=pd.date_range(start='2019-01-01 00:05:00', end='2019-01-01 00:11:00', freq='min'),
        )
        self.log(f"train_times=\n{train_times_ok}")

        self.assertTrue(train_times_ret.equals(train_times_ok), "train dataset doesn't match")

    def test_get_train_times_2(self):
        """
        Tests the get_train_times method for the case where the train ENDS within test
        """
        test_times = pd.Series(
            index=pd.date_range(start='2019-01-01 00:08:00', periods=1, freq='min'),
            data=pd.date_range(start='2019-01-01 00:11:00', periods=1, freq='min'),
        )
        self.log(f"test_times=\n{test_times}")
        train_times_ret = get_train_times(self.observations, test_times)
        self.log(f"train_times_ret=\n{train_times_ret}")

        # Observations starting at 00:06-00:09 overlap the test span.
        train_times_ok = pd.Series(
            index=pd.date_range(start='2019-01-01 00:00:00', end='2019-01-01 00:05:00', freq='min'),
            data=pd.date_range(start='2019-01-01 00:02:00', end='2019-01-01 00:07:00', freq='min'),
        )
        self.log(f"train_times=\n{train_times_ok}")

        self.assertTrue(train_times_ret.equals(train_times_ok), "train dataset doesn't match")

    def test_get_train_times_3(self):
        """
        Tests the get_train_times method for the case where the train ENVELOPES test
        """
        test_times = pd.Series(
            index=pd.date_range(start='2019-01-01 00:06:00', periods=1, freq='min'),
            data=pd.date_range(start='2019-01-01 00:08:00', periods=1, freq='min'),
        )
        self.log(f"test_times=\n{test_times}")
        train_times_ret = get_train_times(self.observations, test_times)
        self.log(f"train_times_ret=\n{train_times_ret}")

        # Observations starting at 00:04-00:08 overlap the test span, leaving
        # two disjoint pieces of training data.
        train_times_ok1 = pd.Series(
            index=pd.date_range(start='2019-01-01 00:00:00', end='2019-01-01 00:03:00', freq='min'),
            data=pd.date_range(start='2019-01-01 00:02:00', end='2019-01-01 00:05:00', freq='min'),
        )
        train_times_ok2 = pd.Series(
            index=pd.date_range(start='2019-01-01 00:09:00', end='2019-01-01 00:09:00', freq='min'),
            data=pd.date_range(start='2019-01-01 00:11:00', end='2019-01-01 00:11:00', freq='min'),
        )
        train_times_ok = pd.concat([train_times_ok1, train_times_ok2])
        self.log(f"train_times=\n{train_times_ok}")

        self.assertTrue(train_times_ret.equals(train_times_ok), "train dataset doesn't match")
4 changes: 4 additions & 0 deletions scripts/run_tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Helper module used to run tests
"""

import unittest
import xmlrunner

Expand Down

0 comments on commit e38a523

Please sign in to comment.