Skip to content

Commit

Permalink
Added tests for pandas dm
Browse files Browse the repository at this point in the history
  • Loading branch information
BarzaH committed May 19, 2024
1 parent a920b48 commit 56f85b0
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 2 deletions.
48 changes: 48 additions & 0 deletions tests/fixtures/config/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,54 @@
}
)

# Datamodule fixture for the house-prices "table-regression" task.
# "_target_" names RegressionPandasDataModule; "price" is the target column.
regrhouse_prices_datamodule_cfg_w_target = DictConfig(
    {
        "task": ["table-regression"],
        "name": "house prices",
        "description": "",
        "markup_info": "",
        "date_time": "01.07.2022",
        "_target_": "innofw.core.datamodules.pandas_datamodules.RegressionPandasDataModule",
        # Both CSV locations are built relative to get_test_folder_path().
        "train": {
            "source": str(
                get_test_folder_path()
                / "data/tabular/regression/house_prices/train/train.csv"
            ),
        },
        "test": {
            "source": str(
                get_test_folder_path()
                / "data/tabular/regression/house_prices/test/test.csv"
            ),
        },
        "target_col": "price",
    }
)

# Datamodule fixture for the "table-clustering" task.
# "_target_" names ClusteringPandasDataModule. The train/test CSVs are the
# same regression/house_prices files used by the regression fixture.
clusthouse_prices_datamodule_cfg_w_target = DictConfig(
    {
        "task": ["table-clustering"],
        "name": "house prices",
        "description": "",
        "markup_info": "",
        "date_time": "01.07.2022",
        "_target_": "innofw.core.datamodules.pandas_datamodules.ClusteringPandasDataModule",
        # Both CSV locations are built relative to get_test_folder_path().
        "train": {
            "source": str(
                get_test_folder_path()
                / "data/tabular/regression/house_prices/train/train.csv"
            ),
        },
        "test": {
            "source": str(
                get_test_folder_path()
                / "data/tabular/regression/house_prices/test/test.csv"
            ),
        },
        "target_col": "price",
    }
)

wheat_datamodule_cfg_w_target = DictConfig(
{
"task": ["image-detection"],
Expand Down
56 changes: 54 additions & 2 deletions tests/unit/datamodules/pandas_datamodules/test_pandas_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from innofw.constants import Stages
from innofw.core.datamodules import PandasDataModule
from innofw.utils.framework import get_datamodule
from tests.fixtures.config.datasets import house_prices_datamodule_cfg_w_target
from tests.fixtures.config.datasets import house_prices_datamodule_cfg_w_target, \
regrhouse_prices_datamodule_cfg_w_target,\
clusthouse_prices_datamodule_cfg_w_target


def test_save_preds(tmp_path):
def test_classsave_preds(tmp_path):
# create a house price dm
fw = Frameworks.sklearn
task = "table-regression"
Expand All @@ -17,6 +19,7 @@ def test_save_preds(tmp_path):
stage = Stages.train
# get target col values
df = dm.get_stage_dataloader(stage)
df = dm.get_stage_dataloader(Stages.test)
y = df["y"]
# pass "preds", stage, path to the function
dm.save_preds(y, stage, tmp_path)
Expand All @@ -28,3 +31,52 @@ def test_save_preds(tmp_path):
"regression.csv",
"clustering.csv",
]

def test_regrsave_preds(tmp_path):
    """Check that save_preds on the regression datamodule writes one file.

    Builds the datamodule from the regression house-prices fixture, requests
    the train and test stage dataloaders, saves the test-stage "y" values
    into ``tmp_path`` and asserts that exactly one file with an expected
    name was created there.
    """
    dm: PandasDataModule = get_datamodule(
        regrhouse_prices_datamodule_cfg_w_target,
        Frameworks.sklearn,
        task="table-regression",
    )

    save_stage = Stages.train
    # Request both stage dataloaders, train first; only the test-stage
    # result is used below.
    dm.get_stage_dataloader(save_stage)
    targets = dm.get_stage_dataloader(Stages.test)["y"]

    # NOTE(review): targets come from the test stage but are saved under the
    # train stage -- confirm this mismatch is intended.
    dm.save_preds(targets, save_stage, tmp_path)

    # Exactly one output file with a recognised name must have been created.
    created_names = [entry.name for entry in tmp_path.iterdir()]
    assert len(created_names) == 1
    assert created_names[0] in [
        "prediction.csv",
        "regression.csv",
        "clustering.csv",
    ]


def test_clustsave_preds(tmp_path):
    """Check that save_preds on the clustering datamodule writes one file.

    Builds the datamodule from the clustering house-prices fixture, requests
    the train and test stage dataloaders, saves the test-stage "y" values
    into ``tmp_path`` and asserts that exactly one file with an expected
    name was created there.
    """
    dm: PandasDataModule = get_datamodule(
        clusthouse_prices_datamodule_cfg_w_target,
        Frameworks.sklearn,
        task="table-clustering",
    )

    save_stage = Stages.train
    # Request both stage dataloaders, train first; only the test-stage
    # result is used below.
    dm.get_stage_dataloader(save_stage)
    targets = dm.get_stage_dataloader(Stages.test)["y"]

    # NOTE(review): targets come from the test stage but are saved under the
    # train stage -- confirm this mismatch is intended.
    dm.save_preds(targets, save_stage, tmp_path)

    # Exactly one output file with a recognised name must have been created.
    created_names = [entry.name for entry in tmp_path.iterdir()]
    assert len(created_names) == 1
    assert created_names[0] in [
        "prediction.csv",
        "regression.csv",
        "clustering.csv",
    ]

0 comments on commit 56f85b0

Please sign in to comment.