diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 84fc623b..34049670 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -17,7 +17,7 @@ from sklearn.model_selection import GridSearchCV from sklearn.tree import DecisionTreeClassifier -from molpipeline import ErrorFilter, Pipeline +from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper from molpipeline.any2mol import AutoToMol, SmilesToMol from molpipeline.mol2any import MolToMorganFP, MolToRDKitPhysChem, MolToSmiles from molpipeline.mol2mol import ( @@ -384,16 +384,24 @@ def test_calibrated_classifier(self) -> None: smi2mol = SmilesToMol() mol2morgan = MolToMorganFP(radius=FP_RADIUS, n_bits=FP_SIZE) d_tree = DecisionTreeClassifier() + error_filter = ErrorFilter(filter_everything=True) s_pipeline = Pipeline( [ ("smi2mol", smi2mol), ("morgan", mol2morgan), + ("error_filter", error_filter), ("decision_tree", d_tree), + ( + "error_replacer", + PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, None) + ), + ), ] ) - calibrated_pipeline = CalibratedClassifierCV(s_pipeline) + calibrated_pipeline = CalibratedClassifierCV(s_pipeline, cv=2) calibrated_pipeline.fit(TEST_SMILES, CONTAINS_OX) - predicted_value_array = s_pipeline.predict(TEST_SMILES) + predicted_value_array = calibrated_pipeline.predict(TEST_SMILES) for pred_val, true_val in zip(predicted_value_array, CONTAINS_OX): self.assertEqual(pred_val, true_val)