From 5ee8c70dc00b63742956bb70d6a9271b410356e5 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 12 Feb 2025 17:46:29 +0100 Subject: [PATCH] add test for ChEMPROP Model, which fails --- .../test_chemprop/test_chemprop_pipeline.py | 61 ++++++++++++++----- 1 file changed, 47 insertions(+), 14 deletions(-) diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index c600a3a3..305cfaa6 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -9,6 +9,7 @@ import pandas as pd from lightning import pytorch as pl from sklearn.base import clone +from sklearn.calibration import CalibratedClassifierCV from molpipeline.any2mol import SmilesToMol from molpipeline.error_handling import ErrorFilter, FilterReinserter @@ -317,28 +318,34 @@ def test_prediction(self) -> None: class TestClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for classification.""" - def test_prediction(self) -> None: - """Test the prediction of the classification model.""" - - molecule_net_bbbp_df = pd.read_csv( + def setUp(self) -> None: + """Set up repeated variables.""" + self.molecule_net_bbbp_df = pd.read_csv( TEST_DATA_DIR / "molecule_net_bbbp.tsv.gz", sep="\t", nrows=100 ) + + def test_prediction(self) -> None: + """Test the prediction of the classification model.""" classification_model = get_classification_pipeline() classification_model.fit( - molecule_net_bbbp_df["smiles"].tolist(), - molecule_net_bbbp_df["p_np"].to_numpy(), + self.molecule_net_bbbp_df["smiles"].tolist(), + self.molecule_net_bbbp_df["p_np"].to_numpy(), + ) + pred = classification_model.predict( + self.molecule_net_bbbp_df["smiles"].tolist() ) - pred = classification_model.predict(molecule_net_bbbp_df["smiles"].tolist()) proba = classification_model.predict_proba( - molecule_net_bbbp_df["smiles"].tolist() + self.molecule_net_bbbp_df["smiles"].tolist() ) - self.assertEqual(len(pred), len(molecule_net_bbbp_df)) + self.assertEqual(len(pred), len(self.molecule_net_bbbp_df)) self.assertEqual(proba.shape[1], 2) - self.assertEqual(proba.shape[0], len(molecule_net_bbbp_df)) + self.assertEqual(proba.shape[0], len(self.molecule_net_bbbp_df)) model_copy = joblib_dump_load(classification_model) - pred_copy = model_copy.predict(molecule_net_bbbp_df["smiles"].tolist()) - proba_copy = model_copy.predict_proba(molecule_net_bbbp_df["smiles"].tolist()) + pred_copy = model_copy.predict(self.molecule_net_bbbp_df["smiles"].tolist()) + proba_copy = model_copy.predict_proba( + self.molecule_net_bbbp_df["smiles"].tolist() + ) nan_indices = np.isnan(pred) self.assertListEqual(nan_indices.tolist(), np.isnan(pred_copy).tolist()) @@ -349,14 +356,40 @@ def test_prediction(self) -> None: # Test single prediction, this was causing an error before single_mol_pred = classification_model.predict( - [molecule_net_bbbp_df["smiles"].iloc[0]] + [self.molecule_net_bbbp_df["smiles"].iloc[0]] ) self.assertEqual(single_mol_pred.shape, (1,)) single_mol_proba = classification_model.predict_proba( - [molecule_net_bbbp_df["smiles"].iloc[0]] + [self.molecule_net_bbbp_df["smiles"].iloc[0]] ) self.assertEqual(single_mol_proba.shape, (1, 2)) + def test_calibrated_classifier(self) -> None: + """Test if the pipeline can be used with a CalibratedClassifierCV.""" + calibrated_pipeline = CalibratedClassifierCV( + get_classification_pipeline(), cv=2, ensemble=True, method="isotonic" + ) + calibrated_pipeline.fit( + self.molecule_net_bbbp_df["smiles"].tolist(), + self.molecule_net_bbbp_df["p_np"].to_numpy(), + ) + predicted_value_array = calibrated_pipeline.predict( + self.molecule_net_bbbp_df["smiles"].tolist() + ) + predicted_proba_array = calibrated_pipeline.predict_proba( + self.molecule_net_bbbp_df["smiles"].tolist() + ) + self.assertIsInstance(predicted_value_array, np.ndarray) + self.assertIsInstance(predicted_proba_array, np.ndarray) + self.assertEqual( + predicted_value_array.shape, + (len(self.molecule_net_bbbp_df["smiles"].tolist()),), + ) + self.assertEqual( + predicted_proba_array.shape, + (len(self.molecule_net_bbbp_df["smiles"].tolist()), 2), + ) + class TestMulticlassClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for multiclass classification."""