# step_3_split_train_test.py
import pathlib
from typing import Tuple

import pandas as pd

import piml.config
from piml.config.dataset import DatasetConfig
from piml.utils.pandas import df_f64_f32, to_gz_csv


def split_test_train(df: pd.DataFrame, test_interval: Tuple[str, str]) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """ Split into training and test sets; the test interval is half-open, i.e. [start, end). """
    df_test_mask = df["TIME"].between(*test_interval, inclusive="left")
    df_test = df.loc[df_test_mask]
    df_train = df.loc[~df_test_mask]
    return df_train, df_test
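
# Illustrative sketch only: how the half-open split behaves on a toy frame.
# The column name TIME matches what this script expects; the dates are made up
# for demonstration.
#
#   toy = pd.DataFrame({"TIME": ["2021-01-31 23:00", "2021-02-01 00:00",
#                                "2021-02-28 23:00", "2021-03-01 00:00"]})
#   train, test = split_test_train(toy, ("2021-02-01", "2021-03-01"))
#   # test  -> the two February rows (inclusive="left" keeps the interval start
#   #          and drops the end); train -> the remaining two rows.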


def write_dataset(df_train: pd.DataFrame, df_test: pd.DataFrame,
                  ds_config: DatasetConfig, train_test_dir: pathlib.Path) -> None:
    """ Write the train and test datasets to gzipped csv files. """
    to_gz_csv(
        df_f64_f32(df_train),
        train_test_dir / ds_config.get_train_name(with_suffix=True),
        index=False
    )
    to_gz_csv(
        df_f64_f32(df_test),
        train_test_dir / ds_config.get_test_name(with_suffix=True),
        index=False
    )
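
# Note on the helpers above: judging by their names, df_f64_f32 downcasts
# float64 columns to float32 and to_gz_csv writes a gzip-compressed csv, which
# keeps the train/test artifacts compact. This is inferred from the names in
# piml.utils.pandas, not from their implementations.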


def validate_dataset(df: pd.DataFrame, ws: piml.Workspace) -> None:
    # Validate that all required input variables are present.
    dim_vars = ws.config.dim_vars
    required_vars = dim_vars.input_strs + ["TIME", "DAY_YEAR"]
    missing_vars = set(required_vars) - set(df.columns)
    if len(missing_vars) > 0:
        raise ValueError(f"Validation failed! The following variables are missing: {missing_vars}.")

    # The expected output/target variable name changes depending on whether a pre-pi transform is requested.
    dim_target = dim_vars.output.symbol.name
    pre_pi_tf = ws.config.dataset.target_transformers.get("pre_pi")
    if pre_pi_tf is None:
        # If no pre-pi transform is enabled, just require the variable to be present.
        if dim_target not in df.columns:
            raise ValueError(f"Validation failed! The dimensional target variable {dim_target} is missing.")
    else:
        # If a pre-pi transform is enabled, we explicitly require the dim target variable to end in `_tf`.
        dim_target_tf = dim_target
        if not dim_target_tf.endswith("_tf"):
            raise ValueError(
                f"Validation failed! You requested a pre-pi transform: {pre_pi_tf}, "
                f"so dim_output must end in `_tf`, but got `{dim_target_tf}`."
            )
        dim_target = dim_target_tf[:-len("_tf")]
        if dim_target not in df.columns:
            raise ValueError(
                f"Validation failed! You requested a pre-pi transform: {pre_pi_tf}. "
                f"For clarity, the dataset must contain the non-transformed dimensional target variable in col XXX "
                f"and the matching transformed variable must be named XXX_tf as dim_output. "
                f"Could not find `{dim_target}` based on `{dim_target_tf}`."
            )
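
# Illustrative example of the convention enforced above (the variable name "P"
# is hypothetical): if a pre-pi target transform is configured and
# dim_vars.output.symbol.name is "P_tf", then the processed csv must still
# contain the untransformed column "P".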


if __name__ == '__main__':
    ws = piml.Workspace.auto()

    for f in ws.data_processed.glob("*.csv.gz"):
        # Read
        print(f"Reading {f.name}... ", end=" ")
        df = pd.read_csv(f)
        df = df.rename(columns=ws.config.dataset.col_to_var)

        # Validate
        validate_dataset(df, ws)
        print("Valid! ", end=" ")

        # Split into test and training
        test_interval = ws.config.dataset.test_interval
        test_interval = (str(test_interval[0]), str(test_interval[1]))
        print(f"Performing train/test split with test interval {test_interval[0]} -- {test_interval[1]}... ", end=" ")
        df_train, df_test = split_test_train(df, test_interval)

        # Print diagnostics
        print(f"Test ratio: {len(df_test) / len(df):.2f}.", end=" ")
        print(f"Number of days in test: {len(df_test['DAY_YEAR'].unique())}.", end=" ")

        # Write to disk
        write_dataset(df_train, df_test, ds_config=ws.config.dataset, train_test_dir=ws.data_train_test)
        print("Done!")