diff --git a/globalemu/preprocess.py b/globalemu/preprocess.py index 4448d6d..eb5b046 100644 --- a/globalemu/preprocess.py +++ b/globalemu/preprocess.py @@ -267,9 +267,8 @@ def load_data(file): norm_train_labels = norm_train_labels.flatten() np.save(self.base_dir + 'labels_stds.npy', labels_stds) - test_labels_stds = test_labels.std() norm_test_labels = [ - test_labels[i, :]/test_labels_stds + test_labels[i, :]/labels_stds for i in range(test_labels.shape[0])] norm_test_labels = np.array(norm_test_labels)