From 83fd21f0a236a9e8ebc5530110eb79b5fe2682a7 Mon Sep 17 00:00:00 2001 From: Lonneke Scheffer Date: Wed, 24 Apr 2024 18:03:03 +0200 Subject: [PATCH] bugfix in storing/loading KerasSequenceCNN (deepcopy on keras object crashes in newer versions) --- immuneML/ml_methods/KerasSequenceCNN.py | 14 ++++++++++---- test/ml_methods/test_kerasSequenceCNN.py | 16 ++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/immuneML/ml_methods/KerasSequenceCNN.py b/immuneML/ml_methods/KerasSequenceCNN.py index d888c518a..d63d9bb0e 100644 --- a/immuneML/ml_methods/KerasSequenceCNN.py +++ b/immuneML/ml_methods/KerasSequenceCNN.py @@ -69,7 +69,6 @@ def __init__(self, units_per_layer: list = None, activation: str = None, trainin self.training_percentage = training_percentage self.background_probabilities = None - self.CNN = None self.label = None self.class_mapping = None self.result_path = None @@ -202,7 +201,7 @@ def store(self, path: Path, feature_names=None, details_path: Path = None): self.model.save(path / "model.keras") - custom_vars = copy.deepcopy(vars(self)) + custom_vars = self.get_params() del custom_vars["model"] del custom_vars["result_path"] @@ -234,8 +233,15 @@ def check_if_exists(self, path): return self.model is not None def get_params(self): - params = copy.deepcopy(vars(self)) - params["model"] = copy.deepcopy(self.model).state_dict() + params = dict() + + # using 'deepcopy' on the model directly results in an error, therefore loop over all other items + for key, value in vars(self).items(): + if key != "model": + params[key] = copy.deepcopy(value) + + params["model"] = copy.deepcopy(self.model.get_config()) + return params def get_label_name(self): diff --git a/test/ml_methods/test_kerasSequenceCNN.py b/test/ml_methods/test_kerasSequenceCNN.py index 038dda884..a63165ca8 100644 --- a/test/ml_methods/test_kerasSequenceCNN.py +++ b/test/ml_methods/test_kerasSequenceCNN.py @@ -75,21 +75,17 @@ def _test_fit(self): cnn2 = KerasSequenceCNN() cnn2.load(path / "model_storage") - cnn2_vars = vars(cnn2) - del cnn2_vars["CNN"] - cnn_vars = vars(cnn) - del cnn_vars["CNN"] + cnn2_params = cnn2.get_params() + cnn_params = cnn.get_params() - for item, value in cnn_vars.items(): + for item, value in cnn_params.items(): if isinstance(value, Label): - self.assertDictEqual(vars(value), (vars(cnn2_vars[item]))) - elif not isinstance(value, keras.Sequential): - self.assertEqual(value, cnn2_vars[item]) + self.assertDictEqual(vars(value), (vars(cnn2_params[item]))) + else: + self.assertEqual(value, cnn2_params[item]) predictions_proba2 = cnn2.predict_proba(enc_dataset.encoded_data, label) - print(predictions_proba2) - self.assertTrue(all(predictions_proba["CMV"]["yes"] == predictions_proba2["CMV"]["yes"])) self.assertTrue(all(predictions_proba["CMV"]["no"] == predictions_proba2["CMV"]["no"]))