diff --git a/dataset/Dataset.py b/dataset/Dataset.py index aa32f6c..5089489 100644 --- a/dataset/Dataset.py +++ b/dataset/Dataset.py @@ -13,6 +13,7 @@ import einops import torchvision import albumentations as A + np.random.seed(999) torch.manual_seed(999) random.seed(999) @@ -257,13 +258,15 @@ def __init__(self, config, mode="train"): label = 0 else: label = 1 - + sample_insar_path = annotation_utils.get_insar_path( + annotation_path=annotation_path + annotation_file, + root_path=self.config["data_path"]) + sample_cc_path = sample_insar_path[:-8] + 'cc.png' + if not os.path.isfile(sample_cc_path) or not os.path.isfile(sample_insar_path): + continue sample_dict = { "frameID": annotation["frameID"], - "insar_path": annotation_utils.get_insar_path( - annotation_path=annotation_path + annotation_file, - root_path=self.config["data_path"], - ), + "insar_path":sample_insar_path, "label": annotation, } self.interferograms.append(sample_dict) diff --git a/utilities/utils.py b/utilities/utils.py index 8b45154..e78b7b5 100644 --- a/utilities/utils.py +++ b/utilities/utils.py @@ -194,7 +194,6 @@ def load_checkpoint(model, optimizer, args): def extract_state_dict_from_ddp_checkpoint(checkpoint_path): print("=> loading checkpoint '{}'".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path, map_location="cpu") - print(checkpoint.keys()) encoder_state_dict = {} for key in list(checkpoint["state_dict"].keys()): checkpoint["state_dict"][key.replace("module.", "")] = checkpoint[