diff --git a/tests/fixtures/config/datasets.py b/tests/fixtures/config/datasets.py
index c6220d10..7f24eb84 100644
--- a/tests/fixtures/config/datasets.py
+++ b/tests/fixtures/config/datasets.py
@@ -87,6 +87,57 @@
     }
 )
 
+# House-prices config targeting RegressionPandasDataModule, target column set.
+regrhouse_prices_datamodule_cfg_w_target = DictConfig(
+    {
+        "task": ["table-regression"],
+        "name": "house prices",
+        "description": "",
+        "markup_info": "",
+        "date_time": "01.07.2022",
+        "_target_": "innofw.core.datamodules.pandas_datamodules.RegressionPandasDataModule",
+        "train": {
+            "source": str(
+                get_test_folder_path()
+                / "data/tabular/regression/house_prices/train/train.csv"
+            )
+        },
+        "test": {
+            "source": str(
+                get_test_folder_path()
+                / "data/tabular/regression/house_prices/test/test.csv"
+            )
+        },
+        "target_col": "price",
+    }
+)
+
+# House-prices config targeting ClusteringPandasDataModule, target column set.
+# NOTE(review): reuses the regression house_prices CSV files - confirm intended.
+clusthouse_prices_datamodule_cfg_w_target = DictConfig(
+    {
+        "task": ["table-clustering"],
+        "name": "house prices",
+        "description": "",
+        "markup_info": "",
+        "date_time": "01.07.2022",
+        "_target_": "innofw.core.datamodules.pandas_datamodules.ClusteringPandasDataModule",
+        "train": {
+            "source": str(
+                get_test_folder_path()
+                / "data/tabular/regression/house_prices/train/train.csv"
+            )
+        },
+        "test": {
+            "source": str(
+                get_test_folder_path()
+                / "data/tabular/regression/house_prices/test/test.csv"
+            )
+        },
+        "target_col": "price",
+    }
+)
+
 wheat_datamodule_cfg_w_target = DictConfig(
     {
         "task": ["image-detection"],
diff --git a/tests/unit/datamodules/pandas_datamodules/test_pandas_dm.py b/tests/unit/datamodules/pandas_datamodules/test_pandas_dm.py
index 1b95ea1e..275841a1 100644
--- a/tests/unit/datamodules/pandas_datamodules/test_pandas_dm.py
+++ b/tests/unit/datamodules/pandas_datamodules/test_pandas_dm.py
@@ -3,10 +3,14 @@
 from innofw.constants import Stages
 from innofw.core.datamodules import PandasDataModule
 from innofw.utils.framework import get_datamodule
-from tests.fixtures.config.datasets import house_prices_datamodule_cfg_w_target
+from tests.fixtures.config.datasets import (
+    clusthouse_prices_datamodule_cfg_w_target,
+    house_prices_datamodule_cfg_w_target,
+    regrhouse_prices_datamodule_cfg_w_target,
+)
 
 
-def test_save_preds(tmp_path):
+def test_classsave_preds(tmp_path):
     # create a house price dm
     fw = Frameworks.sklearn
     task = "table-regression"
@@ -17,6 +21,9 @@ def test_save_preds(tmp_path):
     stage = Stages.train
     # get target col values
     df = dm.get_stage_dataloader(stage)
+    # NOTE(review): the train result above is discarded; y is read from the
+    # test split while save_preds is still called with stage=train - confirm.
+    df = dm.get_stage_dataloader(Stages.test)
     y = df["y"]
     # pass "preds", stage, path to the function
     dm.save_preds(y, stage, tmp_path)
@@ -28,3 +35,59 @@ def test_save_preds(tmp_path):
         "regression.csv",
         "clustering.csv",
     ]
+
+
+def test_regrsave_preds(tmp_path):
+    """Check save_preds leaves exactly one expected CSV (regression config)."""
+    # create a house price dm
+    fw = Frameworks.sklearn
+    task = "table-regression"
+    dm: PandasDataModule = get_datamodule(
+        regrhouse_prices_datamodule_cfg_w_target, fw, task=task
+    )
+    # for stage train
+    stage = Stages.train
+    # get target col values
+    df = dm.get_stage_dataloader(stage)
+    # NOTE(review): the train result above is discarded; y is read from the
+    # test split while save_preds is still called with stage=train - confirm.
+    df = dm.get_stage_dataloader(Stages.test)
+    y = df["y"]
+    # pass "preds", stage, path to the function
+    dm.save_preds(y, stage, tmp_path)
+    # check if a file has been created
+    files = list(tmp_path.iterdir())
+    assert len(files) == 1
+    assert files[0].name in [
+        "prediction.csv",
+        "regression.csv",
+        "clustering.csv",
+    ]
+
+
+def test_clustsave_preds(tmp_path):
+    """Check save_preds leaves exactly one expected CSV (clustering config)."""
+    # create a house price dm
+    fw = Frameworks.sklearn
+    task = "table-clustering"
+    dm: PandasDataModule = get_datamodule(
+        clusthouse_prices_datamodule_cfg_w_target, fw, task=task
+    )
+    # for stage train
+    stage = Stages.train
+    # get target col values
+    df = dm.get_stage_dataloader(stage)
+    # NOTE(review): the train result above is discarded; y is read from the
+    # test split while save_preds is still called with stage=train - confirm.
+    df = dm.get_stage_dataloader(Stages.test)
+    y = df["y"]
+    # pass "preds", stage, path to the function
+    dm.save_preds(y, stage, tmp_path)
+    # check if a file has been created
+    files = list(tmp_path.iterdir())
+    assert len(files) == 1
+    assert files[0].name in [
+        "prediction.csv",
+        "regression.csv",
+        "clustering.csv",
+    ]