From adf6d8e4683ec91cb581c17481d6d1fb606505e6 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Fri, 19 Apr 2024 16:47:45 -0700 Subject: [PATCH] Make sure cluster_label_adapted values cannot be invalid --- mlreco/models/full_chain.py | 2 +- mlreco/utils/unwrap.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlreco/models/full_chain.py b/mlreco/models/full_chain.py index bfa4e2bc..5d3bf5f5 100644 --- a/mlreco/models/full_chain.py +++ b/mlreco/models/full_chain.py @@ -432,7 +432,7 @@ def full_chain_cnn(self, input): if self.enable_cnn_clust or self.enable_dbscan: cnn_result.update({'segment_label_tmp': [semantic_labels] }) if label_clustering is not None: - if 'input_rescaled' in cnn_result: + if len(label_clustering[0]) == len(input[0]): label_clustering[0][:, VALUE_COL] = input[0][:, VALUE_COL] cnn_result.update({'cluster_label_adapted': label_clustering }) diff --git a/mlreco/utils/unwrap.py b/mlreco/utils/unwrap.py index 0d3f1d45..a79f08b7 100644 --- a/mlreco/utils/unwrap.py +++ b/mlreco/utils/unwrap.py @@ -398,7 +398,7 @@ def input_unwrap_rules(schemas): for name, schema in schemas.items(): parser = schema['parser'] assert parser in INPUT_RULES, f'Unable to unwrap data from {parser}' - rules[name] = deepcopy(RULES[parser]) + rules[name] = deepcopy(INPUT_RULES[parser]) if rules[name][0] == 'tensor': rules[name][1] = name