Skip to content

Commit

Permalink
bugfix in storing/loading KerasSequenceCNN (deepcopy on keras object …
Browse files Browse the repository at this point in the history
…crashes in newer versions)
  • Loading branch information
LonnekeScheffer committed Apr 24, 2024
1 parent 64ef227 commit 83fd21f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
14 changes: 10 additions & 4 deletions immuneML/ml_methods/KerasSequenceCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __init__(self, units_per_layer: list = None, activation: str = None, trainin
self.training_percentage = training_percentage

self.background_probabilities = None
self.CNN = None
self.label = None
self.class_mapping = None
self.result_path = None
Expand Down Expand Up @@ -202,7 +201,7 @@ def store(self, path: Path, feature_names=None, details_path: Path = None):

self.model.save(path / "model.keras")

custom_vars = copy.deepcopy(vars(self))
custom_vars = self.get_params()
del custom_vars["model"]
del custom_vars["result_path"]

Expand Down Expand Up @@ -234,8 +233,15 @@ def check_if_exists(self, path):
return self.model is not None

def get_params(self):
params = copy.deepcopy(vars(self))
params["model"] = copy.deepcopy(self.model).state_dict()
params = dict()

# using 'deepcopy' on the model directly results in an error, therefore loop over all other items
for key, value in vars(self).items():
if key != "model":
params[key] = copy.deepcopy(value)

params["model"] = copy.deepcopy(self.model.get_config())

return params

def get_label_name(self):
Expand Down
16 changes: 6 additions & 10 deletions test/ml_methods/test_kerasSequenceCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,17 @@ def _test_fit(self):
cnn2 = KerasSequenceCNN()
cnn2.load(path / "model_storage")

cnn2_vars = vars(cnn2)
del cnn2_vars["CNN"]
cnn_vars = vars(cnn)
del cnn_vars["CNN"]
cnn2_params = cnn2.get_params()
cnn_params = cnn.get_params()

for item, value in cnn_vars.items():
for item, value in cnn_params.items():
if isinstance(value, Label):
self.assertDictEqual(vars(value), (vars(cnn2_vars[item])))
elif not isinstance(value, keras.Sequential):
self.assertEqual(value, cnn2_vars[item])
self.assertDictEqual(vars(value), (vars(cnn2_params[item])))
else:
self.assertEqual(value, cnn2_params[item])

predictions_proba2 = cnn2.predict_proba(enc_dataset.encoded_data, label)

print(predictions_proba2)

self.assertTrue(all(predictions_proba["CMV"]["yes"] == predictions_proba2["CMV"]["yes"]))
self.assertTrue(all(predictions_proba["CMV"]["no"] == predictions_proba2["CMV"]["no"]))

Expand Down

0 comments on commit 83fd21f

Please sign in to comment.