diff --git a/Dataset.py b/Dataset.py index 2644891..1952074 100644 --- a/Dataset.py +++ b/Dataset.py @@ -1,73 +1,67 @@ -import torch import os + import cv2 as cv import numpy as np -import torchvision -import albumentations as A -import random -import einops -from utilities import augmentations +import torch from tqdm import tqdm -import matplotlib.pyplot as plt -from annotation_utils import read_annotation, get_insar_path + +from utilities import augmentations + class Dataset(torch.utils.data.Dataset): def __init__(self, config): self.data_path = config["data_path"] self.augmentations = augmentations.get_augmentations(config) self.interferograms = [] - self.channels = config['num_channels'] + self.channels = config["num_channels"] frames = os.listdir(self.data_path) for frame in tqdm(frames): - frame_path = self.data_path + '/'+frame +'/interferograms/' + frame_path = self.data_path + "/" + frame + "/interferograms/" caption = os.listdir(frame_path) for cap in caption: - caption_path = frame_path + cap +'/'+cap+ '.geo.diff.png' + caption_path = frame_path + cap + "/" + cap + ".geo.diff.png" if os.path.exists(caption_path): - image_dict = {'path': caption_path, 'frame': frame} + image_dict = {"path": caption_path, "frame": frame} self.interferograms.append(image_dict) self.num_examples = len(self.interferograms) - def __len__(self): return self.num_examples - - def prepare_insar(self,insar): - insar = torch.from_numpy(insar).float().permute(2,0,1) + def prepare_insar(self, insar): + insar = torch.from_numpy(insar).float().permute(2, 0, 1) insar /= 255 return insar - def read_insar(self,path): - insar = cv.imread(path,0) + def read_insar(self, path): + insar = cv.imread(path, 0) if insar is None: - print('None') + print("None") return insar - insar = einops.repeat(insar, 'h w -> h w c', c=self.channels) + insar = np.expand_dims(insar, axis=2).repeat(self.channels, axis=2) transform = self.augmentations(image=insar) - insar_1 = transform['image'] + insar_1 = 
transform["image"] transform_2 = self.augmentations(image=insar) - insar_2 = transform_2['image'] + insar_2 = transform_2["image"] insar_1 = self.prepare_insar(insar_1) insar_2 = self.prepare_insar(insar_2) - return (insar_1,insar_2) + return (insar_1, insar_2) def __getitem__(self, index): insar = None while insar is None: sample = self.interferograms[index] - path = sample['path'] + path = sample["path"] insar = self.read_insar(path) if insar is None: - if index + +

This repository contains the data and code used in [Hephaestus: A large scale multitask dataset towards InSAR understanding](https://openaccess.thecvf.com/content/CVPR2022W/EarthVision/papers/Bountos_Hephaestus_A_Large_Scale_Multitask_Dataset_Towards_InSAR_Understanding_CVPRW_2022_paper.pdf) as published in CVPR 2022 workshop Earthvision. @@ -17,9 +19,39 @@ If you use this work, please cite: } ``` ### Dependancies -This repo has been tested with python3.9. To install the necessary dependancies run: +This repo has been tested with python3.9. To install the necessary dependencies run: `pip install -r requirements.txt` +### Multi-GPU / Multi-Node training +You can make use of torchrun or SLURM to launch distributed jobs. + +###### torchrun +Single-Node Multi-GPU: +``` +torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py +``` + +Multi-Node Multi-GPU: +``` +# On XXX.XXX.XXX.62 (the master node) +torchrun \ +--nproc_per_node=2 --nnodes=2 --node_rank=0 \ +--master_addr=XXX.XXX.XXX.62 --master_port=1234 \ +main.py + +# On XXX.XXX.XXX.63 (the worker node) +torchrun \ +--nproc_per_node=2 --nnodes=2 --node_rank=1 \ +--master_addr=XXX.XXX.XXX.62 --master_port=1234 \ +main.py +``` + +###### SLURM: +After setting the relevant parameters inside hephaestus.slurm: +``` +sbatch hephaestus.slurm +``` + ### Dataset and pretrained models The annotation files can be downloaded [here](https://www.dropbox.com/s/i08mz5514gczksz/annotations_hephaestus.zip?dl=0). @@ -38,7 +70,7 @@ The dataset is organized in the following structure: The cropped 224x224 patches, along with the respective masks and labels can be found [here](https://www.dropbox.com/s/2bkpj79jepk0vks/Hephaestus_Classification.zip?dl=0). Some examples of these patches can be seen in the following figure. 
-![figure](examples.png) +![figure](docs/examples.png) The directory structure for the cropped patches is: @@ -72,7 +104,7 @@ The script will automatically create folders for the checkpoints and store the c ### Annotation The dataset contains both labeled and unlabeled data. The labeled part covers 38 frames summing up to 19,919 annotated InSAR. -The list of the studied volcanoes, along with the temporal distribution of their samples can be seen below. ![below](volcano_distribution.png) +The list of the studied volcanoes, along with the temporal distribution of their samples can be seen below. ![below](docs/volcano_distribution.png) Each labeled InSAR is accompanied by a json file containing the annotation details. Below we present an example of an annotation file. A detailed description can be seen in the original paper (section 2.2): ```python diff --git a/annotation_utils.py b/annotation_utils.py index 06a5c53..ea6f2d6 100644 --- a/annotation_utils.py +++ b/annotation_utils.py @@ -1,9 +1,10 @@ import json +import os import random + +import cv2 as cv import matplotlib.pyplot as plt import numpy as np -import cv2 as cv -import os from tqdm import tqdm np.random.seed(999) @@ -15,38 +16,58 @@ def read_annotation(annotation_path): annotation = json.load(reader) return annotation + def extract_label(annotation_path): annotation = read_annotation(annotation_path) - if annotation['label']=='Non_Deformation': + if annotation["label"] == "Non_Deformation": label = 0 - elif annotation['label']=='Deformation': + elif annotation["label"] == "Deformation": label = 1 else: label = 2 return label -def get_insar_path(annotation_path,root_path='Hephaestus_Raw/'): + +def get_insar_path(annotation_path, root_path="Hephaestus_Raw/"): annotation = read_annotation(annotation_path) - frameID = annotation['frameID'] - primaryDate = annotation['primary_date'] - secondaryDate = annotation['secondary_date'] - primary_secondary = primaryDate + '_' + secondaryDate - img_path = root_path + 
frameID + '/interferograms/' + primary_secondary + '/' + primary_secondary + '.geo.diff.png' + frameID = annotation["frameID"] + primaryDate = annotation["primary_date"] + secondaryDate = annotation["secondary_date"] + primary_secondary = primaryDate + "_" + secondaryDate + img_path = ( + root_path + + frameID + + "/interferograms/" + + primary_secondary + + "/" + + primary_secondary + + ".geo.diff.png" + ) return img_path -def get_segmentation(annotation_path='1.json', raw_insar_path='Hephaestus_Raw/',verbose=True): - ''' + +def get_segmentation( + annotation_path="1.json", raw_insar_path="Hephaestus_Raw/", verbose=True +): + """ :param annotation_path: :param raw_insar_path: :param verbose: :return: - ''' - class_dir = {'Mogi':1, 'Dyke':2, 'Sill': 3, 'Spheroid':4, 'Earthquake':5, 'Unidentified':6} + """ + class_dir = { + "Mogi": 1, + "Dyke": 2, + "Sill": 3, + "Spheroid": 4, + "Earthquake": 5, + "Unidentified": 6, + } annotation = read_annotation(annotation_path) - img_path = get_insar_path(annotation_path,root_path=raw_insar_path) + img_path = get_insar_path(annotation_path, root_path=raw_insar_path) - segmentation = annotation['segmentation_mask'] - insar = cv.imread(img_path,0) + segmentation = annotation["segmentation_mask"] + insar = cv.imread(img_path, 0) if verbose: plt.imshow(insar) plt.show() @@ -55,52 +76,78 @@ def get_segmentation(annotation_path='1.json', raw_insar_path='Hephaestus_Raw/', return [] if not any(isinstance(el, list) for el in segmentation): segmentation = [segmentation] - for idx,seg in enumerate(segmentation): + for idx, seg in enumerate(segmentation): i = 0 points = [] mask = np.zeros(insar.shape) - while i +1 < len(seg): + while i + 1 < len(seg): x = int(seg[i]) - y = int(seg[i+1]) - points.append([x,y]) - i+=2 + y = int(seg[i + 1]) + points.append([x, y]) + i += 2 - cv.fillPoly(mask, [np.asarray(points)], class_dir[annotation['activity_type'][idx]])#255) + cv.fillPoly( + mask, [np.asarray(points)], 
class_dir[annotation["activity_type"][idx]] + ) # 255) if verbose: - print('File : ', annotation_path) + print("File : ", annotation_path) plt.imshow(mask) plt.show() masks.append(mask) if verbose: - print('Number of mask: ',len(masks)) + print("Number of mask: ", len(masks)) return masks -def mask_boundaries(annotation_path,raw_insar_path='Hephaestus_Raw/',verbose=True,index=0): - ''' + +def mask_boundaries( + annotation_path, raw_insar_path="Hephaestus_Raw/", verbose=True, index=0 +): + """ :param annotation_path: path of annotation file :param raw_insar_path: Path of raw InSAR images :return: segmentation mask boundaries (top_left_point,bottom_right_point) - ''' - class_dir = {'Mogi':1, 'Dyke':2, 'Sill': 3, 'Spheroid':4, 'Earthquake':5, 'Unidentified':6} + """ + class_dir = { + "Mogi": 1, + "Dyke": 2, + "Sill": 3, + "Spheroid": 4, + "Earthquake": 5, + "Unidentified": 6, + } annotation = read_annotation(annotation_path) - if type(annotation['segmentation_mask']) is not list: + if type(annotation["segmentation_mask"]) is not list: print(annotation_path) - mask = get_segmentation(annotation_path,raw_insar_path=raw_insar_path,verbose=False)[index] + mask = get_segmentation( + annotation_path, raw_insar_path=raw_insar_path, verbose=False + )[index] - row,col = (mask==class_dir[annotation['activity_type'][index]]).nonzero() + row, col = (mask == class_dir[annotation["activity_type"][index]]).nonzero() if verbose: - rect = cv.rectangle(mask, pt1=(col.min(),row.min()), pt2=(col.max(),row.max()), color=255, thickness=2) + rect = cv.rectangle( + mask, + pt1=(col.min(), row.min()), + pt2=(col.max(), row.max()), + color=255, + thickness=2, + ) plt.imshow(rect) plt.show() - return (row.min(),col.min()), (row.max(),col.max()) + return (row.min(), col.min()), (row.max(), col.max()) -def crop_around_object(annotation_path,verbose=True,output_size=64,raw_insar_path='Hephaestus_Raw/',index=0): - ''' +def crop_around_object( + annotation_path, + verbose=True, + output_size=64, + 
raw_insar_path="Hephaestus_Raw/", + index=0, +): + """ :param annotation_path: annotation file path :param verbose: Option to plot and save the cropped image and mask. @@ -108,182 +155,258 @@ def crop_around_object(annotation_path,verbose=True,output_size=64,raw_insar_pat :param raw_insar_path: Path of raw InSAR images :param index: Index of ground deformation in the InSAR. Useful for cases where the InSAR contains multiple ground deformation patterns. :return: Randomly cropped image to output_size along with the respective mask. The cropped image is guaranteed to contain the ground deformation pattern. - ''' - #Get all masks - masks = get_segmentation(annotation_path,raw_insar_path=raw_insar_path,verbose=False) - - - (row_min,col_min),(row_max,col_max) = mask_boundaries(annotation_path,raw_insar_path=raw_insar_path,verbose=False,index=index) - mask = get_segmentation(annotation_path,raw_insar_path=raw_insar_path,verbose=False)[index] + """ + # Get all masks + masks = get_segmentation( + annotation_path, raw_insar_path=raw_insar_path, verbose=False + ) + + (row_min, col_min), (row_max, col_max) = mask_boundaries( + annotation_path, raw_insar_path=raw_insar_path, verbose=False, index=index + ) + mask = get_segmentation( + annotation_path, raw_insar_path=raw_insar_path, verbose=False + )[index] object_width = col_max - col_min object_height = row_max - row_min - low = col_max+max(output_size-col_min,0) - high = min(col_max+abs(output_size-object_width),mask.shape[1]) - if object_width>=output_size: + low = col_max + max(output_size - col_min, 0) + high = min(col_max + abs(output_size - object_width), mask.shape[1]) + if object_width >= output_size: low = col_min - high = col_min+output_size - print('='*20) - print('WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ',output_size,'. 
THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.') - print('='*20) - print('Object width: ',object_width) - print('Set low to: ',low) - print('Set high to: ',high) - print('Class: ',mask.max()) + high = col_min + output_size + print("=" * 20) + print( + "WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ", + output_size, + ". THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.", + ) + print("=" * 20) + print("Object width: ", object_width) + print("Set low to: ", low) + print("Set high to: ", high) + print("Class: ", mask.max()) if low >= high: - print('Low', low) - print('High',high) - print('Object width: ',object_width) - print('Mask width: ',mask.shape[1]) - random_right = np.random.randint(low,high) + print("Low", low) + print("High", high) + print("Object width: ", object_width) + print("Mask width: ", mask.shape[1]) + random_right = np.random.randint(low, high) left = random_right - output_size - low_down = row_max+max(output_size-row_min,0) - high_down = min(row_max+abs(output_size-object_height),mask.shape[0]) - if object_height>=output_size: + low_down = row_max + max(output_size - row_min, 0) + high_down = min(row_max + abs(output_size - object_height), mask.shape[0]) + if object_height >= output_size: low_down = row_min - high_down = row_min+output_size - print('='*20) - print('WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ',output_size,'. THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.') - print('='*20) - print('Object height: ',object_height) - print('Set low to: ',low_down) - print('Set high to: ',high_down) - print('Class: ',mask.max()) - random_down = np.random.randint(low_down,high_down) + high_down = row_min + output_size + print("=" * 20) + print( + "WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ", + output_size, + ". 
THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.", + ) + print("=" * 20) + print("Object height: ", object_height) + print("Set low to: ", low_down) + print("Set high to: ", high_down) + print("Class: ", mask.max()) + random_down = np.random.randint(low_down, high_down) up = random_down - output_size - #Unite Other Deformation Masks - if len(masks)>0: + # Unite Other Deformation Masks + if len(masks) > 0: for k in range(len(masks)): - if k!=index: + if k != index: mask = mask + masks[k] - mask = mask[up:random_down,left:random_right] - image_path = get_insar_path(annotation_path,root_path=raw_insar_path) + mask = mask[up:random_down, left:random_right] + image_path = get_insar_path(annotation_path, root_path=raw_insar_path) image = cv.imread(image_path) if verbose: - - insar_path = get_insar_path(annotation_path,root_path=raw_insar_path) + insar_path = get_insar_path(annotation_path, root_path=raw_insar_path) insar = cv.imread(insar_path) - insar = cv.cvtColor(insar,cv.COLOR_BGR2RGB) - print(insar[up:random_down,left:random_right].shape) + insar = cv.cvtColor(insar, cv.COLOR_BGR2RGB) + print(insar[up:random_down, left:random_right].shape) print(mask.shape) - plt.imshow(insar[up:random_down,left:random_right,:]) - plt.axis('off') + plt.imshow(insar[up:random_down, left:random_right, :]) + plt.axis("off") plt.show() - cmap = plt.get_cmap('tab10', 7) - plt.imshow(mask,cmap=cmap,vmax=6.5,vmin=-0.5) - plt.axis('off') + cmap = plt.get_cmap("tab10", 7) + plt.imshow(mask, cmap=cmap, vmax=6.5, vmin=-0.5) + plt.axis("off") plt.show() if image is None: - print('=' * 40) - print('Error. Image path not found\n') + print("=" * 40) + print("Error. 
Image path not found\n") print(image_path) - print('=' * 40) - return image[up:random_down,left:random_right,:], mask + print("=" * 40) + return image[up:random_down, left:random_right, :], mask -def save_crops(annotation_folder='annotations/',save_path = 'Hephaestus/labeled/',mask_path='Hephaestus/masks/',raw_insar_path='Hephaestus_Raw/',out_size=224,verbose=False): - ''' +def save_crops( + annotation_folder="annotations/", + save_path="Hephaestus/labeled/", + mask_path="Hephaestus/masks/", + raw_insar_path="Hephaestus_Raw/", + out_size=224, + verbose=False, +): + """ :param annotation_folder: folder of annotation jsons :param save_path: folder path for generated images :param mask_path: folder path for generated masks :return: Label vector [ Deformation/Non Deformation ( 0 for Non Deformation, 1,2,3,4,5 for Mogi, Dyke, Sill, Spheroid, Earthquake), Phase (0 -> Rest, 1-> Unrest, 2-> Rebound), Intensity Level (0->Low, 1-> Medium, 2->High, 3-> Earthquake (Not volcanic activity related event intensity))] - ''' + """ - print('='*40) - print('Cropping Hephaestus') - print('='*40) + print("=" * 40) + print("Cropping Hephaestus") + print("=" * 40) annotations = os.listdir(annotation_folder) c = 0 - label_path = mask_path[:-6]+'cls_labels/' + label_path = mask_path[:-6] + "cls_labels/" multiple_ctr = 0 - class_dir = {'Mogi':1, 'Dyke':2, 'Sill': 3, 'Spheroid':4, 'Earthquake':5, 'Low':0,'Medium':1,'High':2,'None':0,'Rest':0,'Unrest':1,'Rebound':2,'Unidentified':6} + class_dir = { + "Mogi": 1, + "Dyke": 2, + "Sill": 3, + "Spheroid": 4, + "Earthquake": 5, + "Low": 0, + "Medium": 1, + "High": 2, + "None": 0, + "Rest": 0, + "Unrest": 1, + "Rebound": 2, + "Unidentified": 6, + } for file in tqdm(annotations): label_json = {} - annotation = read_annotation(annotation_folder+file) - - if 'Non_Deformation' in annotation['label']: - image_path = get_insar_path(annotation_folder+file,root_path=raw_insar_path) - image = cv.imread(image_path) - tiles = 
image_tiling(image,tile_size=out_size) - for idx,tile in enumerate(tiles): - if image is None: - print(image_path) - print(file) - continue - #image = data_utilities.random_crop(image) - cv.imwrite(save_path+'0/'+file[:-5]+'_'+str(idx)+'.png',tile) - label_json['Deformation'] = [0] - label_json['Intensity'] = 0 - label_json['Phase'] = class_dir[annotation['phase']] - label_json['frameID'] = annotation['frameID'] - label_json['primary_date'] = annotation['primary_date'] - label_json['secondary_date'] = annotation['secondary_date'] - json_writer = open(label_path+file[:-5]+'_'+str(idx)+'.json','w') - json.dump(label_json,json_writer) - elif int(annotation['is_crowd'])==0: - if 'Non_Deformation' not in annotation['label']: - image, mask = crop_around_object(annotation_path=annotation_folder+file,verbose=False,raw_insar_path=raw_insar_path,output_size=out_size) - folder = str(class_dir[annotation['activity_type'][0]]) - cv.imwrite(save_path+folder+'/'+file[:-5]+'.png',image) - cv.imwrite(mask_path+folder+'/'+file[:-5]+'.png',mask) - label_json['Deformation'] = [class_dir[annotation['activity_type'][0]]] - if folder!=str(5): - label_json['Intensity'] = class_dir[annotation['intensity_level'][0]] + annotation = read_annotation(annotation_folder + file) + + if "Non_Deformation" in annotation["label"]: + image_path = get_insar_path( + annotation_folder + file, root_path=raw_insar_path + ) + image = cv.imread(image_path) + tiles = image_tiling(image, tile_size=out_size) + for idx, tile in enumerate(tiles): + if image is None: + print(image_path) + print(file) + continue + # image = data_utilities.random_crop(image) + cv.imwrite(save_path + "0/" + file[:-5] + "_" + str(idx) + ".png", tile) + label_json["Deformation"] = [0] + label_json["Intensity"] = 0 + label_json["Phase"] = class_dir[annotation["phase"]] + label_json["frameID"] = annotation["frameID"] + label_json["primary_date"] = annotation["primary_date"] + label_json["secondary_date"] = annotation["secondary_date"] + 
json_writer = open( + label_path + file[:-5] + "_" + str(idx) + ".json", "w" + ) + json.dump(label_json, json_writer) + elif int(annotation["is_crowd"]) == 0: + if "Non_Deformation" not in annotation["label"]: + image, mask = crop_around_object( + annotation_path=annotation_folder + file, + verbose=False, + raw_insar_path=raw_insar_path, + output_size=out_size, + ) + folder = str(class_dir[annotation["activity_type"][0]]) + cv.imwrite(save_path + folder + "/" + file[:-5] + ".png", image) + cv.imwrite(mask_path + folder + "/" + file[:-5] + ".png", mask) + label_json["Deformation"] = [class_dir[annotation["activity_type"][0]]] + if folder != str(5): + label_json["Intensity"] = class_dir[ + annotation["intensity_level"][0] + ] else: - label_json['Intensity'] = 3 - label_json['Phase'] = class_dir[annotation['phase']] - label_json['frameID'] = annotation['frameID'] - label_json['primary_date'] = annotation['primary_date'] - label_json['secondary_date'] = annotation['secondary_date'] - json_writer = open(label_path + file, 'w') + label_json["Intensity"] = 3 + label_json["Phase"] = class_dir[annotation["phase"]] + label_json["frameID"] = annotation["frameID"] + label_json["primary_date"] = annotation["primary_date"] + label_json["secondary_date"] = annotation["secondary_date"] + json_writer = open(label_path + file, "w") json.dump(label_json, json_writer) - elif int(annotation['is_crowd'])>0: - for deformation in range(len(annotation['segmentation_mask'])): - image, mask = crop_around_object(annotation_path=annotation_folder + file, verbose=False,raw_insar_path=raw_insar_path,index=deformation,output_size=out_size) + elif int(annotation["is_crowd"]) > 0: + for deformation in range(len(annotation["segmentation_mask"])): + image, mask = crop_around_object( + annotation_path=annotation_folder + file, + verbose=False, + raw_insar_path=raw_insar_path, + index=deformation, + output_size=out_size, + ) if verbose: - print('Deformation index:',deformation) - print('Seg length: 
',len(annotation['segmentation_mask'])) - print('Annotation length: ',len(annotation['activity_type'])) - print('File: ',file) - - folder = str(class_dir[annotation['activity_type'][deformation]]) - cv.imwrite(save_path + folder + '/' + file[:-5] + '_'+ str(deformation) + '.png', image) - cv.imwrite(mask_path + folder + '/' + file[:-5] + '_'+ str(deformation) + '.png', mask) - label_json['Deformation'] = list(np.unique(mask))#class_dir[annotation['activity_type'][deformation]] - label_json['Deformation'] = [int(x) for x in label_json['Deformation']] - label_json['Deformation'].remove(0) - if folder!=str(5): - label_json['Intensity'] = class_dir[annotation['intensity_level'][deformation]] + print("Deformation index:", deformation) + print("Seg length: ", len(annotation["segmentation_mask"])) + print("Annotation length: ", len(annotation["activity_type"])) + print("File: ", file) + + folder = str(class_dir[annotation["activity_type"][deformation]]) + cv.imwrite( + save_path + + folder + + "/" + + file[:-5] + + "_" + + str(deformation) + + ".png", + image, + ) + cv.imwrite( + mask_path + + folder + + "/" + + file[:-5] + + "_" + + str(deformation) + + ".png", + mask, + ) + label_json["Deformation"] = list( + np.unique(mask) + ) # class_dir[annotation['activity_type'][deformation]] + label_json["Deformation"] = [int(x) for x in label_json["Deformation"]] + label_json["Deformation"].remove(0) + if folder != str(5): + label_json["Intensity"] = class_dir[ + annotation["intensity_level"][deformation] + ] else: - label_json['Intensity'] = 3 - label_json['Phase'] = class_dir[annotation['phase']] - label_json['frameID'] = annotation['frameID'] - label_json['primary_date'] = annotation['primary_date'] - label_json['secondary_date'] = annotation['secondary_date'] - json_writer = open(label_path + file[:-5] + '_'+ str(deformation) + '.json', 'w') + label_json["Intensity"] = 3 + label_json["Phase"] = class_dir[annotation["phase"]] + label_json["frameID"] = annotation["frameID"] + 
label_json["primary_date"] = annotation["primary_date"] + label_json["secondary_date"] = annotation["secondary_date"] + json_writer = open( + label_path + file[:-5] + "_" + str(deformation) + ".json", "w" + ) json.dump(label_json, json_writer) - multiple_ctr +=1 + multiple_ctr += 1 - print('='*40) - print('Cropping completed') - print('='*40) + print("=" * 40) + print("Cropping completed") + print("=" * 40) -def image_tiling(image,tile_size=64): +def image_tiling(image, tile_size=64): if image is None: return [] - max_rows = image.shape[0]//tile_size - max_cols = image.shape[1]//tile_size + max_rows = image.shape[0] // tile_size + max_cols = image.shape[1] // tile_size tiles = [] for i in range(max_rows): starting_row = i * tile_size for j in range(max_cols): starting_col = j * tile_size - img = image[starting_row:starting_row+tile_size,starting_col:starting_col+tile_size] + img = image[ + starting_row : starting_row + tile_size, + starting_col : starting_col + tile_size, + ] tiles.append(img) return tiles - diff --git a/configs/configs.json b/configs/configs.json index 19be24b..8ad1f21 100644 --- a/configs/configs.json +++ b/configs/configs.json @@ -12,24 +12,15 @@ "batch_size": 64, "epochs": 200, "lr": 0.03, - "num_workers": 16, + "num_workers": 2, "mixed_precision": false, "pretrained": true, "resolution": 224, "num_channels": 3, - "gpu": -1, "print_frequency": 10, "resume_checkpoint": "YOUR_CHECKPOINT", "start_epoch": 0, - - "distributed": true, - "multiprocessing_distributed": true, - "world_size": 1, - "rank": 0, - "dist_url": "tcp://127.0.0.1:15151", - "dist_backend": "nccl", "seed": "" - } diff --git a/configs/method/mocov2/mocov2.json b/configs/method/mocov2/mocov2.json index 2c2f47e..6110047 100644 --- a/configs/method/mocov2/mocov2.json +++ b/configs/method/mocov2/mocov2.json @@ -4,7 +4,7 @@ "momentum": 0.9, "weight_decay": 1e-4, "moco_dim":128, - "moco_k":65646, + "moco_k":65536, "moco_m":0.999, "moco_t":0.07, "mlp": true, diff --git a/examples.png 
b/docs/examples.png similarity index 100% rename from examples.png rename to docs/examples.png diff --git a/hephaestus-logo.png b/docs/hephaestus-logo.png similarity index 100% rename from hephaestus-logo.png rename to docs/hephaestus-logo.png diff --git a/volcano_distribution.png b/docs/volcano_distribution.png similarity index 100% rename from volcano_distribution.png rename to docs/volcano_distribution.png diff --git a/hephaestus.slurm b/hephaestus.slurm new file mode 100644 index 0000000..d5f963f --- /dev/null +++ b/hephaestus.slurm @@ -0,0 +1,50 @@ +#!/bin/bash -l +#SBATCH --job-name=Hephaestus # Job name +#SBATCH --output=Hephaestus.out # Stdout (%j expands to jobId) +#SBATCH --error=Hephaestus.err # Stderr (%j expands to jobId) +#SBATCH --ntasks=4 # Number of tasks(processes) +#SBATCH --nodes=2 # Number of nodes requested +#SBATCH --gres=gpu:2 # GPUs per node -- must be equal to ntasks per node in case of data parallelism. Change for model parallelism +#SBATCH --ntasks-per-node=2 # Tasks per node +#SBATCH --cpus-per-task=10 # Threads per task +#SBATCH --time=24:00:00 # walltime +#SBATCH --mem=54G # memory per NODE +#SBATCH --partition= # Partition +#SBATCH --account= # Replace with your system project +#SBATCH --wait-all-nodes=1 # Do not begin the execution until all nodes are ready for use + +module purge + +module load gnu/8 +module load java/12.0.2 +module load cuda/10.1.168 +module load intel/18 +module load intelmpi/2018 +module load tftorch/270-191 + +echo "Start at `date`" +echo "SLURM_NTASKS=$SLURM_NTASKS" + +if [ x$SLURM_CPUS_PER_TASK == x ]; then + export OMP_NUM_THREADS=1 +else + export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +fi + +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +echo "MASTER_ADDR=$MASTER_ADDR" + +export MASTER_PORT=12802 +echo "MASTER_PORT=$MASTER_PORT" + +NODES=($( scontrol show hostname $SLURM_NODELIST | uniq )) +export NUM_NODES=${#NODES[@]} +WORKERS=$(printf '%s-ib:'${SLURM_NTASKS_PER_NODE}',' 
"${NODES[@]}" | sed 's/,$//') + +echo "SLURM_NODELIST=$SLURM_NODELIST" +echo "WORKERS=$WORKERS" +echo "NODES=$NODES" + +srun python3 main.py + +echo "End at `date`" diff --git a/main.py b/main.py index 63f9351..1ab566c 100644 --- a/main.py +++ b/main.py @@ -1,65 +1,359 @@ -import torchvision -import sys -import numpy as np -import torch -from torchvision import transforms +import builtins +import json +import math +import os +import random +import shutil +import time +import warnings + import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist import torch.nn as nn -import torchvision -import time -import timm +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed import wandb -import kornia -import random -import os -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel as DDP -import json -import pprint -from tqdm import tqdm -from utilities.utils import * -from models.model_utils import * -import self_supervised -import self_supervised.mocov2 -from self_supervised.mocov2 import mocov2 as moco -#os.environ['CUDA_VISIBLE_DEVICES'] ='1,2' -def exec_ssl(config): - announce_stuff('Initializing ' + config['method'],up=False) - if config['method']=='mocov2': - moco.main(config) - -if __name__ == '__main__': - - #Parse configurations - config_path = 'configs/configs.json' - config = prepare_configuration(config_path) - json.dump(config,open(config['checkpoint_path']+'/config.json','w')) - # Set seeds - seed = config['seed'] - if seed is not None: - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) +import Dataset +from self_supervised.mocov2 import builder +from utilities.utils import prepare_configuration + + +def is_distributed(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) > 1 + if "SLURM_NTASKS" in os.environ: + return int(os.environ["SLURM_NTASKS"]) > 1 + return False + + 
+def world_info_from_env(): + local_rank = 0 + for v in ("LOCAL_RANK", "SLURM_LOCALID"): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ("RANK", "SLURM_PROCID"): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ("WORLD_SIZE", "SLURM_NTASKS"): + if v in os.environ: + world_size = int(os.environ[v]) + break + return local_rank, global_rank, world_size + + +def is_global_master(args): + return args["rank"] == 0 + + +def train(train_loader, model, criterion, optimizer, epoch, args): + print("Training epoch: ", epoch) + batch_time = AverageMeter("Time", ":6.3f") + data_time = AverageMeter("Data", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch), + ) + + # switch to train mode + model.train() + + end = time.time() + for i, (images, _) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + images[0] = images[0].cuda(non_blocking=True) + images[1] = images[1].cuda(non_blocking=True) + + # compute output + output, target = model(im_q=images[0], im_k=images[1]) + loss = criterion(output, target) + + # acc1/acc5 are (K+1)-way contrast classifier accuracy + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images[0].size(0)) + top1.update(acc1[0], images[0].size(0)) + top5.update(acc5[0], images[0].size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args["print_frequency"] == 0: + progress.display(i) + if is_global_master(args): + wandb.log( + { + "top1": top1.avg, + "top5": top5.avg, + "loss": loss.item(), + "epoch": 
epoch, + "iteration": i, + } + ) + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +class AverageMeter(object): + """Computes and stores the average and current value""" - #Initialize wandb + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() - if config['resume_wandb']: - id_json = json.load(open(config['checkpoint_path']+'/id.json')) - config['wandb_id']=id_json['wandb_id'] - wandb.init(project=config['wandb_project'],entity=config['wandb_entity'],id=config['wandb_id'],resume=config['resume_wandb']) + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate based on schedule""" + lr = args["lr"] + if args["cos"]: # cosine lr schedule + lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / args["epochs"])) + else: # stepwise lr schedule + for milestone in args["schedule"]: + lr *= 0.1 if epoch >= milestone else 1.0 + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def accuracy(output, target, topk=(1,)): + 
"""Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + # print(correct[:k].shape) + # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = ( + correct[:k].reshape(k * correct.shape[1]).float().sum(0, keepdim=True) + ) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def exec_model(model, args): + if args["seed"] is not None: + random.seed(args["seed"]) + torch.manual_seed(args["seed"]) + cudnn.deterministic = True + warnings.warn( + "You have chosen to seed training. " + "This will turn on the CUDNN deterministic setting, " + "which can slow down your training considerably! " + "You may see unexpected behavior when restarting " + "from checkpoints." + ) + + if is_distributed(): + if "SLURM_PROCID" in os.environ: + args["local_rank"], args["rank"], args["world_size"] = world_info_from_env() + args["num_workers"] = int(os.environ["SLURM_CPUS_PER_TASK"]) + os.environ["LOCAL_RANK"] = str(args["local_rank"]) + os.environ["RANK"] = str(args["rank"]) + os.environ["WORLD_SIZE"] = str(args["world_size"]) + dist.init_process_group( + backend="nccl", + world_size=args["world_size"], + rank=args["rank"], + ) + os.environ["WANDB_MODE"] = "offline" + else: + args["local_rank"], _, _ = world_info_from_env() + dist.init_process_group(backend="nccl") + args["world_size"] = dist.get_world_size() + args["rank"] = dist.get_rank() + + torch.cuda.set_device(args["local_rank"]) + + # suppress printing if not master + if not is_global_master(args): + + def print_pass(*args): + pass + + builtins.print = print_pass else: - announce_stuff('Initializing Wandb') - id = wandb.util.generate_id() - config['wandb_id'] = id - wandb.init(project=config['wandb_project'], 
entity=config['wandb_entity'],config=config,id=id,resume='allow') - json.dump({'wandb_id':id},open(config['checkpoint_path']+'/id.json','w')) + raise NotImplementedError("Only DistributedDataParallel is supported.") + + print("=> creating model '{}'".format(args["architecture"])) + + if is_global_master(args): + # Initialize wandb + print("Initializing Wandb") + if args["resume_wandb"]: + id_json = json.load(open(args["checkpoint_path"] + "/id.json")) + args["wandb_id"] = id_json["wandb_id"] + wandb.init( + project=args["wandb_project"], + entity=args["wandb_entity"], + id=args["wandb_id"], + resume=args["resume_wandb"], + ) + else: + id = wandb.sdk.lib.runid.generate_id() + args["wandb_id"] = id + wandb.init( + project=args["wandb_project"], + entity=args["wandb_entity"], + config=args, + id=id, + resume="allow", + ) + json.dump({"wandb_id": id}, open(args["checkpoint_path"] + "/id.json", "w")) + wandb.watch(model) + print(model) + + model.cuda() + args["batch_size"] = int(args["batch_size"] / int(args["world_size"])) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args["local_rank"]] + ) + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda() + + optimizer = torch.optim.SGD( + model.parameters(), + args["lr"], + momentum=args["momentum"], + weight_decay=args["weight_decay"], + ) + + # optionally resume from a checkpoint + if args["resume_checkpoint"]: + if os.path.isfile(args["resume_checkpoint"]): + print("=> loading checkpoint '{}'".format(args["resume_checkpoint"])) + checkpoint = torch.load( + args["resume_checkpoint"], map_location=torch.cuda.get_device_name() + ) + args["start_epoch"] = checkpoint["epoch"] + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + print( + "=> loaded checkpoint '{}' (epoch {})".format( + args["resume_checkpoint"], checkpoint["epoch"] + ) + ) + else: + print("=> no checkpoint found at 
'{}'".format(args["resume_checkpoint"])) - announce_stuff('Starting project with the following settings:') - pprint.pprint(config) - announce_stuff('',up=False) - exec_ssl(config) + cudnn.benchmark = True + print("Initializing Dataset") + train_dataset = Dataset.Dataset(args) + print("Dataset initialized. Size: ", len(train_dataset)) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args["batch_size"], + shuffle=(train_sampler is None), + num_workers=args["num_workers"], + pin_memory=True, + sampler=train_sampler, + drop_last=True, + ) + + for epoch in range(args["start_epoch"], args["epochs"]): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + if is_global_master(args): + save_checkpoint( + { + "epoch": epoch + 1, + "arch": args["architecture"], + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + }, + is_best=False, + filename=args["checkpoint_path"] + + "/checkpoint_{:04d}.pth.tar".format(epoch), + ) + if is_global_master(args): + wandb.finish() + + +if __name__ == "__main__": + # Parse configurations + config_path = "configs/configs.json" + config = prepare_configuration(config_path) + json.dump(config, open(config["checkpoint_path"] + "/config.json", "w")) + if config["method"] == "mocov2": + model = builder.MoCo( + config, + config["moco_dim"], + config["moco_k"], + config["moco_m"], + config["moco_t"], + config["mlp"], + ) + else: + raise NotImplementedError(f'{config["method"]} is not supported.') + exec_model(model, config) diff --git a/models/model_utils.py b/models/model_utils.py index d65c2f5..56339eb 100644 --- a/models/model_utils.py +++ b/models/model_utils.py @@ -1,17 +1,20 @@ -import torch -import torch.nn as nn import timm + def create_model(configs): - if 'vit' not in configs['architecture']: - 
model = timm.models.create_model(configs['architecture'].lower(),pretrained=True,num_classes=0) + if "vit" not in configs["architecture"]: + model = timm.models.create_model( + configs["architecture"].lower(), pretrained=True, num_classes=0 + ) else: - model = timm.models.vision_transformer.VisionTransformer(img_size=int(configs['resolution']), - patch_size=int(configs['patches']), - in_chans=configs['num_channel'], - embed_dim=configs['embed_dim'], - depth=configs['depth'], - num_heads=configs['num_heads'], - num_classes=0) + model = timm.models.vision_transformer.VisionTransformer( + img_size=int(configs["resolution"]), + patch_size=int(configs["patches"]), + in_chans=configs["num_channel"], + embed_dim=configs["embed_dim"], + depth=configs["depth"], + num_heads=configs["num_heads"], + num_classes=0, + ) - return model \ No newline at end of file + return model diff --git a/self_supervised/mocov2/builder.py b/self_supervised/mocov2/builder.py index 4146fc0..9a1fb3a 100644 --- a/self_supervised/mocov2/builder.py +++ b/self_supervised/mocov2/builder.py @@ -1,13 +1,15 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +import timm import torch import torch.nn as nn -import timm + class MoCo(nn.Module): """ Build a MoCo model with: a query encoder, a key encoder, and a queue https://arxiv.org/abs/1911.05722 """ + def __init__(self, config, dim=128, K=65536, m=0.999, T=0.07, mlp=False): """ dim: feature dimension (default: 128) @@ -23,15 +25,29 @@ def __init__(self, config, dim=128, K=65536, m=0.999, T=0.07, mlp=False): # create the encoders # num_classes is the output fc dimension - self.encoder_q = timm.create_model(config['architecture'].lower(),num_classes=dim,pretrained=config['pretrained'])#base_encoder(num_classes=dim) - self.encoder_k = timm.create_model(config['architecture'].lower(),num_classes=dim,pretrained=config['pretrained'])#base_encoder(num_classes=dim) + self.encoder_q = timm.create_model( + config["architecture"].lower(), + num_classes=dim, + pretrained=config["pretrained"], + ) # base_encoder(num_classes=dim) + self.encoder_k = timm.create_model( + config["architecture"].lower(), + num_classes=dim, + pretrained=config["pretrained"], + ) # base_encoder(num_classes=dim) if mlp: # hack: brute-force replacement dim_mlp = self.encoder_q.fc.weight.shape[1] - self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) - self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) - - for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + self.encoder_q.fc = nn.Sequential( + nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc + ) + self.encoder_k.fc = nn.Sequential( + nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc + ) + + for param_q, param_k in zip( + self.encoder_q.parameters(), self.encoder_k.parameters() + ): param_k.data.copy_(param_q.data) # initialize param_k.requires_grad = False # not update by gradient @@ -46,8 +62,10 @@ def _momentum_update_key_encoder(self): """ Momentum update of the key encoder """ - for param_q, param_k 
in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): - param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) + for param_q, param_k in zip( + self.encoder_q.parameters(), self.encoder_k.parameters() + ): + param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) @torch.no_grad() def _dequeue_and_enqueue(self, keys): @@ -58,10 +76,12 @@ def _dequeue_and_enqueue(self, keys): ptr = int(self.queue_ptr) - assert self.K % batch_size == 0 # for simplicity + assert ( + self.K % batch_size == 0 + ), f"{self.K} mod {batch_size} should be equal to 0." # for simplicity # replace the keys at ptr (dequeue and enqueue) - self.queue[:, ptr:ptr + batch_size] = keys.T + self.queue[:, ptr : ptr + batch_size] = keys.T ptr = (ptr + batch_size) % self.K # move pointer self.queue_ptr[0] = ptr @@ -142,9 +162,9 @@ def forward(self, im_q, im_k): # compute logits # Einstein sum is more intuitive # positive logits: Nx1 - l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) # negative logits: NxK - l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) # logits: Nx(1+K) logits = torch.cat([l_pos, l_neg], dim=1) @@ -168,9 +188,10 @@ def concat_all_gather(tensor): Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. 
""" - tensors_gather = [torch.ones_like(tensor) - for _ in range(torch.distributed.get_world_size())] + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] torch.distributed.all_gather(tensors_gather, tensor, async_op=False) output = torch.cat(tensors_gather, dim=0) - return output \ No newline at end of file + return output diff --git a/self_supervised/mocov2/mocov2.py b/self_supervised/mocov2/mocov2.py deleted file mode 100644 index 4b4063d..0000000 --- a/self_supervised/mocov2/mocov2.py +++ /dev/null @@ -1,313 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -import argparse -import builtins -import math -import os -import random -import shutil -import time -import warnings - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.distributed as dist -import torch.optim -import torch.multiprocessing as mp -import torch.utils.data -import torch.utils.data.distributed -import torchvision.transforms as transforms -import torchvision.datasets as datasets -import torchvision.models as models -import self_supervised.mocov2.builder as builder -import timm -import kornia -import Dataset -import wandb - -def main(args): - - - if args["seed"] is not None: - random.seed(args["seed"]) - torch.manual_seed(args["seed"]) - cudnn.deterministic = True - warnings.warn('You have chosen to seed training. ' - 'This will turn on the CUDNN deterministic setting, ' - 'which can slow down your training considerably! ' - 'You may see unexpected behavior when restarting ' - 'from checkpoints.') - - if args["gpu"] is not None: - warnings.warn('You have chosen a specific GPU. 
This will completely ' - 'disable data parallelism.') - - if args["dist_url"] == "env://" and args["world_size"] == -1: - args["world_size"] = int(os.environ["WORLD_SIZE"]) - - args["distributed"] = args["world_size"] > 1 or args["multiprocessing_distributed"] - - ngpus_per_node = torch.cuda.device_count() - print('Number of gpus: ',ngpus_per_node) - if args["multiprocessing_distributed"]: - # Since we have ngpus_per_node processes per node, the total world_size - # needs to be adjusted accordingly - args["world_size"] = ngpus_per_node * args["world_size"] - # Use torch.multiprocessing.spawn to launch distributed processes: the - # main_worker process function - mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) - else: - # Simply call main_worker function - main_worker(args["gpu"], ngpus_per_node, args) - - -def main_worker(gpu, ngpus_per_node, args): - args["gpu"] = gpu - - # suppress printing if not master - if args["multiprocessing_distributed"] and args["gpu"] != 0: - def print_pass(*args): - pass - builtins.print = print_pass - - if args["gpu"] is not None: - print("Use GPU: {} for training".format(args["gpu"])) - - if args["distributed"]: - if args["dist_url"] == "env://" and args["rank"] == -1: - args["rank"] = int(os.environ["RANK"]) - if args["multiprocessing_distributed"]: - # For multiprocessing distributed training, rank needs to be the - # global rank among all the processes - args["rank"] = args["rank"] * ngpus_per_node + gpu - print('Initializing: ',args["dist_url"],' rank=',args['rank']) - dist.init_process_group(backend=args["dist_backend"], init_method=args["dist_url"], - world_size=args["world_size"], rank=args["rank"]) - # create model - print("=> creating model '{}'".format(args["architecture"])) - model = builder.MoCo( - args, - args["moco_dim"], args["moco_k"], args["moco_m"], args["moco_t"], args["mlp"]) - if args['gpu']==0: - 
wandb.init(project=args['wandb_project'],entity=args['wandb_entity'],id=args['wandb_id'],resume=True) - wandb.watch(model) - print(model) - - if args["distributed"]: - # For multiprocessing distributed, DistributedDataParallel constructor - # should always set the single device scope, otherwise, - # DistributedDataParallel will use all available devices. - if args["gpu"] is not None: - torch.cuda.set_device(args["gpu"]) - model.cuda(args["gpu"]) - # When using a single GPU per process and per - # DistributedDataParallel, we need to divide the batch size - # ourselves based on the total number of GPUs we have - args["batch_size"] = int(args["batch_size"] / ngpus_per_node) - args["workers"] = int((args["num_workers"] + ngpus_per_node - 1) / ngpus_per_node) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args["gpu"]]) - else: - model.cuda() - # DistributedDataParallel will divide and allocate batch_size to all - # available GPUs if device_ids are not set - model = torch.nn.parallel.DistributedDataParallel(model) - elif args["gpu"] is not None: - torch.cuda.set_device(args["gpu"]) - model = model.cuda(args["gpu"]) - # comment out the following line for debugging - raise NotImplementedError("Only DistributedDataParallel is supported.") - else: - # AllGather implementation (batch shuffle, queue update, etc.) in - # this code only supports DistributedDataParallel. 
- raise NotImplementedError("Only DistributedDataParallel is supported.") - - # define loss function (criterion) and optimizer - criterion = nn.CrossEntropyLoss().cuda(args["gpu"]) - - optimizer = torch.optim.SGD(model.parameters(), args["lr"], - momentum=args["momentum"], - weight_decay=args["weight_decay"]) - - # optionally resume from a checkpoint - if args["resume_checkpoint"]: - if os.path.isfile(args["resume_checkpoint"]): - print("=> loading checkpoint '{}'".format(args["resume_checkpoint"])) - if args["gpu"] is None: - checkpoint = torch.load(args["resume_checkpoint"]) - else: - # Map model to be loaded to specified single gpu. - loc = 'cuda:{}'.format(args["gpu"]) - checkpoint = torch.load(args["resume_checkpoint"], map_location=loc) - args["start_epoch"] = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - print("=> loaded checkpoint '{}' (epoch {})" - .format(args["resume_checkpoint"], checkpoint['epoch'])) - else: - print("=> no checkpoint found at '{}'".format(args["resume_checkpoint"])) - - cudnn.benchmark = True - - # Data loading code - #traindir = args["data_path"] - - print('Initializing Dataset') - train_dataset = Dataset.Dataset(args) - print('Dataset initialized. 
Size: ',len(train_dataset)) - - if args["distributed"]: - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - else: - train_sampler = None - - train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=args["batch_size"], shuffle=(train_sampler is None), - num_workers=args["workers"], pin_memory=True, sampler=train_sampler, drop_last=True) - - for epoch in range(args["start_epoch"], args["epochs"]): - if args["distributed"]: - train_sampler.set_epoch(epoch) - adjust_learning_rate(optimizer, epoch, args) - - # train for one epoch - train(train_loader, model, criterion, optimizer, epoch, args) - - if not args["multiprocessing_distributed"] or (args["multiprocessing_distributed"] - and args["rank"] % ngpus_per_node == 0): - save_checkpoint({ - 'epoch': epoch + 1, - 'arch': args["architecture"], - 'state_dict': model.state_dict(), - 'optimizer' : optimizer.state_dict(), - }, is_best=False, filename=args['checkpoint_path']+'/checkpoint_{:04d}.pth.tar'.format(epoch)) - if args['gpu']==0: - wandb.finish() - -def train(train_loader, model, criterion, optimizer, epoch, args): - print('Training epoch: ',epoch) - batch_time = AverageMeter('Time', ':6.3f') - data_time = AverageMeter('Data', ':6.3f') - losses = AverageMeter('Loss', ':.4e') - top1 = AverageMeter('Acc@1', ':6.2f') - top5 = AverageMeter('Acc@5', ':6.2f') - progress = ProgressMeter( - len(train_loader), - [batch_time, data_time, losses, top1, top5], - prefix="Epoch: [{}]".format(epoch)) - - # switch to train mode - model.train() - - end = time.time() - for i, (images, _) in enumerate(train_loader): - # measure data loading time - data_time.update(time.time() - end) - - if args["gpu"] is not None: - images[0] = images[0].cuda(args["gpu"], non_blocking=True) - images[1] = images[1].cuda(args["gpu"], non_blocking=True) - - # compute output - output, target = model(im_q=images[0], im_k=images[1]) - loss = criterion(output, target) - - # acc1/acc5 are (K+1)-way contrast 
classifier accuracy - # measure accuracy and record loss - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - losses.update(loss.item(), images[0].size(0)) - top1.update(acc1[0], images[0].size(0)) - top5.update(acc5[0], images[0].size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if i % args["print_frequency"] == 0: - progress.display(i) - if args['gpu']==0: - wandb.log({'top1':top1.avg,'top5':top5.avg,'loss':loss.item(),'epoch':epoch,'iteration':i}) - - -def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): - torch.save(state, filename) - if is_best: - shutil.copyfile(filename, 'model_best.pth.tar') - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self, name, fmt=':f'): - self.name = name - self.fmt = fmt - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - def __str__(self): - fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' - return fmtstr.format(**self.__dict__) - - -class ProgressMeter(object): - def __init__(self, num_batches, meters, prefix=""): - self.batch_fmtstr = self._get_batch_fmtstr(num_batches) - self.meters = meters - self.prefix = prefix - - def display(self, batch): - entries = [self.prefix + self.batch_fmtstr.format(batch)] - entries += [str(meter) for meter in self.meters] - print('\t'.join(entries)) - - def _get_batch_fmtstr(self, num_batches): - num_digits = len(str(num_batches // 1)) - fmt = '{:' + str(num_digits) + 'd}' - return '[' + fmt + '/' + fmt.format(num_batches) + ']' - - -def adjust_learning_rate(optimizer, epoch, args): - """Decay the learning rate based on schedule""" - lr = args["lr"] - if args["cos"]: # cosine lr schedule - lr 
*= 0.5 * (1. + math.cos(math.pi * epoch / args["epochs"])) - else: # stepwise lr schedule - for milestone in args["schedule"]: - lr *= 0.1 if epoch >= milestone else 1. - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - -def accuracy(output, target, topk=(1,)): - """Computes the accuracy over the k top predictions for the specified values of k""" - with torch.no_grad(): - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - #print(correct[:k].shape) - #correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) - correct_k = correct[:k].reshape(k * correct.shape[1]).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res diff --git a/utilities/augmentations.py b/utilities/augmentations.py index b90264b..b0f3c61 100644 --- a/utilities/augmentations.py +++ b/utilities/augmentations.py @@ -1,26 +1,35 @@ import albumentations as A + def get_augmentations(config): - augmentations = config['augmentations'] + augmentations = config["augmentations"] independend_aug = [] - for k,v in augmentations.items(): - if k == 'RandomResizedCrop': - aug = A.augmentations.RandomResizedCrop(height=v['value'],width=v['value'],p=v['p']) - elif k=='ColorJitter': - aug = A.augmentations.ColorJitter(brightness=v['value'][0],contrast=v['value'][1],saturation=v['value'][2],hue=v['value'][3],p=v['p']) - elif k=='HorizontalFlip': - aug = A.augmentations.HorizontalFlip(p=v['p']) - elif k=='VerticalFlip': - aug = A.augmentations.VerticalFlip(p=v['p']) - elif k=='GaussianBlur': - aug = A.augmentations.GaussianBlur(sigma_limit=v['value'],p=v['p']) - elif k=='ElasticTransform': - aug = A.augmentations.ElasticTransform(p=v['p']) - elif k=='Cutout': - aug = A.augmentations.CoarseDropout(p=v['p']) - elif k=='GaussianNoise': - aug = A.augmentations.GaussNoise(p=v['p']) - elif k=='MultNoise': - aug = 
A.augmentations.MultiplicativeNoise(p=v['p']) + for k, v in augmentations.items(): + if k == "RandomResizedCrop": + aug = A.augmentations.RandomResizedCrop( + height=v["value"], width=v["value"], p=v["p"] + ) + elif k == "ColorJitter": + aug = A.augmentations.ColorJitter( + brightness=v["value"][0], + contrast=v["value"][1], + saturation=v["value"][2], + hue=v["value"][3], + p=v["p"], + ) + elif k == "HorizontalFlip": + aug = A.augmentations.HorizontalFlip(p=v["p"]) + elif k == "VerticalFlip": + aug = A.augmentations.VerticalFlip(p=v["p"]) + elif k == "GaussianBlur": + aug = A.augmentations.GaussianBlur(sigma_limit=v["value"], p=v["p"]) + elif k == "ElasticTransform": + aug = A.augmentations.ElasticTransform(p=v["p"]) + elif k == "Cutout": + aug = A.augmentations.CoarseDropout(p=v["p"]) + elif k == "GaussianNoise": + aug = A.augmentations.GaussNoise(p=v["p"]) + elif k == "MultNoise": + aug = A.augmentations.MultiplicativeNoise(p=v["p"]) independend_aug.append(aug) - return A.Compose(independend_aug) \ No newline at end of file + return A.Compose(independend_aug) diff --git a/utilities/utils.py b/utilities/utils.py index 3a955e3..696647f 100644 --- a/utilities/utils.py +++ b/utilities/utils.py @@ -1,36 +1,38 @@ -from pathlib import Path import json +from pathlib import Path + def create_checkpoint_directory(args): - checkpoint_path = 'checkpoints/' + args['method'].lower() + '/' + args['architecture'].lower() + '/' + args['architecture'].lower() + '_'+ str(args['resolution']) - Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + checkpoint_path = ( + Path("checkpoints") + / args["method"].lower() + / args["architecture"].lower() + / f'{args["architecture"].lower()}_{str(args["resolution"])}' + ) + checkpoint_path.mkdir(parents=True, exist_ok=True) return checkpoint_path -def announce_stuff(announcement,symbol='=',times=20,up=True,down=True): - if up: - print(symbol*times) - print(announcement) - if down: - print(symbol*times) def prepare_configuration(path): 
- config = json.load(open(path, 'r')) - # Create checkpoint path if it does not exist - checkpoint_path = create_checkpoint_directory(config) - config['checkpoint_path'] = checkpoint_path + # Load configuration files + base_cfg = json.load(open(path, "r")) + if not base_cfg["seed"]: + base_cfg["seed"] = None - # Load augmentation settings - augmentation_config = json.load(open(config['augmentation_config'],'r')) - config.update(augmentation_config) + augmentation_cfg = json.load(open(base_cfg["augmentation_config"], "r")) + base_cfg.update(augmentation_cfg) - #Load model settings - model_config_path = 'configs/method/' + config['method'].lower() + '/' + config['method'].lower() + '.json' - model_config = json.load(open(model_config_path,'r')) - config.update(model_config) + model_cfg = ( + Path("configs/method") + / base_cfg["method"].lower() + / f'{base_cfg["method"].lower()}.json' + ) + with model_cfg.open("r", encoding="UTF-8") as target: + model_config = json.load(target) + base_cfg.update(model_config) - if config['seed'] == '': - config['seed'] = None - if config['gpu']==-1: - config['gpu'] = None + # Create checkpoint path if it does not exist + checkpoint_path = create_checkpoint_directory(base_cfg) + base_cfg["checkpoint_path"] = checkpoint_path.as_posix() - return config \ No newline at end of file + return base_cfg