diff --git a/Dataset.py b/Dataset.py
index 2644891..1952074 100644
--- a/Dataset.py
+++ b/Dataset.py
@@ -1,73 +1,67 @@
-import torch
import os
+
import cv2 as cv
import numpy as np
-import torchvision
-import albumentations as A
-import random
-import einops
-from utilities import augmentations
+import torch
from tqdm import tqdm
-import matplotlib.pyplot as plt
-from annotation_utils import read_annotation, get_insar_path
+
+from utilities import augmentations
+
class Dataset(torch.utils.data.Dataset):
def __init__(self, config):
self.data_path = config["data_path"]
self.augmentations = augmentations.get_augmentations(config)
self.interferograms = []
- self.channels = config['num_channels']
+ self.channels = config["num_channels"]
frames = os.listdir(self.data_path)
for frame in tqdm(frames):
- frame_path = self.data_path + '/'+frame +'/interferograms/'
+ frame_path = self.data_path + "/" + frame + "/interferograms/"
caption = os.listdir(frame_path)
for cap in caption:
- caption_path = frame_path + cap +'/'+cap+ '.geo.diff.png'
+ caption_path = frame_path + cap + "/" + cap + ".geo.diff.png"
if os.path.exists(caption_path):
- image_dict = {'path': caption_path, 'frame': frame}
+ image_dict = {"path": caption_path, "frame": frame}
self.interferograms.append(image_dict)
self.num_examples = len(self.interferograms)
-
def __len__(self):
return self.num_examples
-
- def prepare_insar(self,insar):
- insar = torch.from_numpy(insar).float().permute(2,0,1)
+ def prepare_insar(self, insar):
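+        # convert the HWC uint8 image to a CHW float tensor scaled to [0, 1]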
+ insar = torch.from_numpy(insar).float().permute(2, 0, 1)
insar /= 255
return insar
- def read_insar(self,path):
- insar = cv.imread(path,0)
+ def read_insar(self, path):
+ insar = cv.imread(path, 0)
if insar is None:
- print('None')
+ print("None")
return insar
- insar = einops.repeat(insar, 'h w -> h w c', c=self.channels)
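+        # replicate the single grayscale band to match the configured channel count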
+ insar = np.expand_dims(insar, axis=2).repeat(self.channels, axis=2)
transform = self.augmentations(image=insar)
- insar_1 = transform['image']
+ insar_1 = transform["image"]
transform_2 = self.augmentations(image=insar)
- insar_2 = transform_2['image']
+ insar_2 = transform_2["image"]
insar_1 = self.prepare_insar(insar_1)
insar_2 = self.prepare_insar(insar_2)
- return (insar_1,insar_2)
+ return (insar_1, insar_2)
def __getitem__(self, index):
insar = None
while insar is None:
sample = self.interferograms[index]
- path = sample['path']
+ path = sample["path"]
insar = self.read_insar(path)
if insar is None:
- if index
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
This repository contains the data and code used in [Hephaestus: A large scale multitask dataset towards InSAR understanding](https://openaccess.thecvf.com/content/CVPR2022W/EarthVision/papers/Bountos_Hephaestus_A_Large_Scale_Multitask_Dataset_Towards_InSAR_Understanding_CVPRW_2022_paper.pdf), as published in the CVPR 2022 EarthVision workshop.
@@ -17,9 +19,39 @@ If you use this work, please cite:
}
```
### Dependencies
-This repo has been tested with python3.9. To install the necessary dependancies run:
+This repo has been tested with Python 3.9. To install the necessary dependencies, run:
`pip install -r requirements.txt`
+### Multi-GPU / Multi-Node training
+You can use either torchrun or SLURM to launch distributed jobs.
+
+###### torchrun
+Single-Node Multi-GPU:
+```
+torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py
+```
+
+Multi-Node Multi-GPU:
+```
+# On XXX.XXX.XXX.62 (the master node)
+torchrun \
+--nproc_per_node=2 --nnodes=2 --node_rank=0 \
+--master_addr=XXX.XXX.XXX.62 --master_port=1234 \
+main.py
+
+# On XXX.XXX.XXX.63 (the worker node)
+torchrun \
+--nproc_per_node=2 --nnodes=2 --node_rank=1 \
+--master_addr=XXX.XXX.XXX.62 --master_port=1234 \
+main.py
+```
+
+###### SLURM
+After setting the relevant parameters inside hephaestus.slurm:
+```
+sbatch hephaestus.slurm
+```
+
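+To sanity-check that all ranks initialize correctly, you can launch a minimal probe with torchrun the same way as main.py. This is a hypothetical helper for illustration, not part of this repository:
+```python
+# probe.py -- hypothetical sanity check, not shipped with this repo
+import os
+
+import torch.distributed as dist
+
+# torchrun provides MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for env:// rendezvous
+dist.init_process_group(backend="nccl")
+print(
+    f"rank={dist.get_rank()} / world_size={dist.get_world_size()}, "
+    f"local_rank={os.environ.get('LOCAL_RANK')}"
+)
+dist.destroy_process_group()
+```
+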
### Dataset and pretrained models
The annotation files can be downloaded [here](https://www.dropbox.com/s/i08mz5514gczksz/annotations_hephaestus.zip?dl=0).
@@ -38,7 +70,7 @@ The dataset is organized in the following structure:
The cropped 224x224 patches, along with the respective masks and labels, can be found [here](https://www.dropbox.com/s/2bkpj79jepk0vks/Hephaestus_Classification.zip?dl=0).
Some examples of these patches can be seen in the following figure.
-
+
The directory structure for the cropped patches is:
@@ -72,7 +104,7 @@ The script will automatically create folders for the checkpoints and store the c
### Annotation
The dataset contains both labeled and unlabeled data. The labeled part covers 38 frames summing up to 19,919 annotated InSAR.
-The list of the studied volcanoes, along with the temporal distribution of their samples can be seen below. 
+The list of the studied volcanoes, along with the temporal distribution of their samples can be seen below. 
Each labeled InSAR is accompanied by a json file containing the annotation details. Below we present an illustrative example of an annotation file; a detailed description can be found in the original paper (section 2.2):
```python
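# Hypothetical sketch of an annotation: the field names follow those read by
# annotation_utils.py, but every value below is an illustrative placeholder.
{
    "frameID": "124D_05291_081406",
    "label": "Deformation",
    "activity_type": ["Mogi"],
    "intensity_level": ["Medium"],
    "phase": "Unrest",
    "is_crowd": 0,
    "primary_date": "20190101",
    "secondary_date": "20190113",
    "segmentation_mask": [[100, 100, 150, 100, 150, 150, 100, 150]]
}
```
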
diff --git a/annotation_utils.py b/annotation_utils.py
index 06a5c53..ea6f2d6 100644
--- a/annotation_utils.py
+++ b/annotation_utils.py
@@ -1,9 +1,10 @@
import json
+import os
import random
+
+import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
-import cv2 as cv
-import os
from tqdm import tqdm
np.random.seed(999)
@@ -15,38 +16,58 @@ def read_annotation(annotation_path):
annotation = json.load(reader)
return annotation
+
def extract_label(annotation_path):
annotation = read_annotation(annotation_path)
- if annotation['label']=='Non_Deformation':
+ if annotation["label"] == "Non_Deformation":
label = 0
- elif annotation['label']=='Deformation':
+ elif annotation["label"] == "Deformation":
label = 1
else:
label = 2
return label
-def get_insar_path(annotation_path,root_path='Hephaestus_Raw/'):
+
+def get_insar_path(annotation_path, root_path="Hephaestus_Raw/"):
annotation = read_annotation(annotation_path)
- frameID = annotation['frameID']
- primaryDate = annotation['primary_date']
- secondaryDate = annotation['secondary_date']
- primary_secondary = primaryDate + '_' + secondaryDate
- img_path = root_path + frameID + '/interferograms/' + primary_secondary + '/' + primary_secondary + '.geo.diff.png'
+ frameID = annotation["frameID"]
+ primaryDate = annotation["primary_date"]
+ secondaryDate = annotation["secondary_date"]
+ primary_secondary = primaryDate + "_" + secondaryDate
+ img_path = (
+ root_path
+ + frameID
+ + "/interferograms/"
+ + primary_secondary
+ + "/"
+ + primary_secondary
+ + ".geo.diff.png"
+ )
return img_path
-def get_segmentation(annotation_path='1.json', raw_insar_path='Hephaestus_Raw/',verbose=True):
- '''
+
+def get_segmentation(
+ annotation_path="1.json", raw_insar_path="Hephaestus_Raw/", verbose=True
+):
+ """
:param annotation_path:
:param raw_insar_path:
:param verbose:
:return:
- '''
- class_dir = {'Mogi':1, 'Dyke':2, 'Sill': 3, 'Spheroid':4, 'Earthquake':5, 'Unidentified':6}
+ """
+ class_dir = {
+ "Mogi": 1,
+ "Dyke": 2,
+ "Sill": 3,
+ "Spheroid": 4,
+ "Earthquake": 5,
+ "Unidentified": 6,
+ }
annotation = read_annotation(annotation_path)
- img_path = get_insar_path(annotation_path,root_path=raw_insar_path)
+ img_path = get_insar_path(annotation_path, root_path=raw_insar_path)
- segmentation = annotation['segmentation_mask']
- insar = cv.imread(img_path,0)
+ segmentation = annotation["segmentation_mask"]
+ insar = cv.imread(img_path, 0)
if verbose:
plt.imshow(insar)
plt.show()
@@ -55,52 +76,78 @@ def get_segmentation(annotation_path='1.json', raw_insar_path='Hephaestus_Raw/',
return []
if not any(isinstance(el, list) for el in segmentation):
segmentation = [segmentation]
- for idx,seg in enumerate(segmentation):
+ for idx, seg in enumerate(segmentation):
i = 0
points = []
mask = np.zeros(insar.shape)
- while i +1 < len(seg):
+ while i + 1 < len(seg):
x = int(seg[i])
- y = int(seg[i+1])
- points.append([x,y])
- i+=2
+ y = int(seg[i + 1])
+ points.append([x, y])
+ i += 2
- cv.fillPoly(mask, [np.asarray(points)], class_dir[annotation['activity_type'][idx]])#255)
+ cv.fillPoly(
+ mask, [np.asarray(points)], class_dir[annotation["activity_type"][idx]]
+ ) # 255)
if verbose:
- print('File : ', annotation_path)
+ print("File : ", annotation_path)
plt.imshow(mask)
plt.show()
masks.append(mask)
if verbose:
- print('Number of mask: ',len(masks))
+ print("Number of mask: ", len(masks))
return masks
-def mask_boundaries(annotation_path,raw_insar_path='Hephaestus_Raw/',verbose=True,index=0):
- '''
+
+def mask_boundaries(
+ annotation_path, raw_insar_path="Hephaestus_Raw/", verbose=True, index=0
+):
+ """
:param annotation_path: path of annotation file
:param raw_insar_path: Path of raw InSAR images
:return: segmentation mask boundaries (top_left_point,bottom_right_point)
- '''
- class_dir = {'Mogi':1, 'Dyke':2, 'Sill': 3, 'Spheroid':4, 'Earthquake':5, 'Unidentified':6}
+ """
+ class_dir = {
+ "Mogi": 1,
+ "Dyke": 2,
+ "Sill": 3,
+ "Spheroid": 4,
+ "Earthquake": 5,
+ "Unidentified": 6,
+ }
annotation = read_annotation(annotation_path)
- if type(annotation['segmentation_mask']) is not list:
+ if type(annotation["segmentation_mask"]) is not list:
print(annotation_path)
- mask = get_segmentation(annotation_path,raw_insar_path=raw_insar_path,verbose=False)[index]
+ mask = get_segmentation(
+ annotation_path, raw_insar_path=raw_insar_path, verbose=False
+ )[index]
- row,col = (mask==class_dir[annotation['activity_type'][index]]).nonzero()
+ row, col = (mask == class_dir[annotation["activity_type"][index]]).nonzero()
if verbose:
- rect = cv.rectangle(mask, pt1=(col.min(),row.min()), pt2=(col.max(),row.max()), color=255, thickness=2)
+ rect = cv.rectangle(
+ mask,
+ pt1=(col.min(), row.min()),
+ pt2=(col.max(), row.max()),
+ color=255,
+ thickness=2,
+ )
plt.imshow(rect)
plt.show()
- return (row.min(),col.min()), (row.max(),col.max())
+ return (row.min(), col.min()), (row.max(), col.max())
-def crop_around_object(annotation_path,verbose=True,output_size=64,raw_insar_path='Hephaestus_Raw/',index=0):
- '''
+def crop_around_object(
+ annotation_path,
+ verbose=True,
+ output_size=64,
+ raw_insar_path="Hephaestus_Raw/",
+ index=0,
+):
+ """
:param annotation_path: annotation file path
:param verbose: Option to plot and save the cropped image and mask.
@@ -108,182 +155,258 @@ def crop_around_object(annotation_path,verbose=True,output_size=64,raw_insar_pat
:param raw_insar_path: Path of raw InSAR images
:param index: Index of ground deformation in the InSAR. Useful for cases where the InSAR contains multiple ground deformation patterns.
:return: Randomly cropped image to output_size along with the respective mask. The cropped image is guaranteed to contain the ground deformation pattern.
- '''
- #Get all masks
- masks = get_segmentation(annotation_path,raw_insar_path=raw_insar_path,verbose=False)
-
-
- (row_min,col_min),(row_max,col_max) = mask_boundaries(annotation_path,raw_insar_path=raw_insar_path,verbose=False,index=index)
- mask = get_segmentation(annotation_path,raw_insar_path=raw_insar_path,verbose=False)[index]
+ """
+ # Get all masks
+ masks = get_segmentation(
+ annotation_path, raw_insar_path=raw_insar_path, verbose=False
+ )
+
+ (row_min, col_min), (row_max, col_max) = mask_boundaries(
+ annotation_path, raw_insar_path=raw_insar_path, verbose=False, index=index
+ )
+ mask = get_segmentation(
+ annotation_path, raw_insar_path=raw_insar_path, verbose=False
+ )[index]
object_width = col_max - col_min
object_height = row_max - row_min
- low = col_max+max(output_size-col_min,0)
- high = min(col_max+abs(output_size-object_width),mask.shape[1])
- if object_width>=output_size:
+ low = col_max + max(output_size - col_min, 0)
+ high = min(col_max + abs(output_size - object_width), mask.shape[1])
+ if object_width >= output_size:
low = col_min
- high = col_min+output_size
- print('='*20)
- print('WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ',output_size,'. THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.')
- print('='*20)
- print('Object width: ',object_width)
- print('Set low to: ',low)
- print('Set high to: ',high)
- print('Class: ',mask.max())
+ high = col_min + output_size
+ print("=" * 20)
+ print(
+ "WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ",
+ output_size,
+ ". THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.",
+ )
+ print("=" * 20)
+ print("Object width: ", object_width)
+ print("Set low to: ", low)
+ print("Set high to: ", high)
+ print("Class: ", mask.max())
if low >= high:
- print('Low', low)
- print('High',high)
- print('Object width: ',object_width)
- print('Mask width: ',mask.shape[1])
- random_right = np.random.randint(low,high)
+ print("Low", low)
+ print("High", high)
+ print("Object width: ", object_width)
+ print("Mask width: ", mask.shape[1])
+ random_right = np.random.randint(low, high)
left = random_right - output_size
- low_down = row_max+max(output_size-row_min,0)
- high_down = min(row_max+abs(output_size-object_height),mask.shape[0])
- if object_height>=output_size:
+ low_down = row_max + max(output_size - row_min, 0)
+ high_down = min(row_max + abs(output_size - object_height), mask.shape[0])
+ if object_height >= output_size:
low_down = row_min
- high_down = row_min+output_size
- print('='*20)
- print('WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ',output_size,'. THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.')
- print('='*20)
- print('Object height: ',object_height)
- print('Set low to: ',low_down)
- print('Set high to: ',high_down)
- print('Class: ',mask.max())
- random_down = np.random.randint(low_down,high_down)
+ high_down = row_min + output_size
+ print("=" * 20)
+ print(
+ "WARNING: MASK IS >= THAN DESIRED OUTPUT_SIZE OF: ",
+ output_size,
+ ". THE MASK WILL BE CROPPED TO FIT EXPECTED OUTPUT SIZE.",
+ )
+ print("=" * 20)
+ print("Object height: ", object_height)
+ print("Set low to: ", low_down)
+ print("Set high to: ", high_down)
+ print("Class: ", mask.max())
+ random_down = np.random.randint(low_down, high_down)
up = random_down - output_size
- #Unite Other Deformation Masks
- if len(masks)>0:
+ # Unite Other Deformation Masks
+ if len(masks) > 0:
for k in range(len(masks)):
- if k!=index:
+ if k != index:
mask = mask + masks[k]
- mask = mask[up:random_down,left:random_right]
- image_path = get_insar_path(annotation_path,root_path=raw_insar_path)
+ mask = mask[up:random_down, left:random_right]
+ image_path = get_insar_path(annotation_path, root_path=raw_insar_path)
image = cv.imread(image_path)
if verbose:
-
- insar_path = get_insar_path(annotation_path,root_path=raw_insar_path)
+ insar_path = get_insar_path(annotation_path, root_path=raw_insar_path)
insar = cv.imread(insar_path)
- insar = cv.cvtColor(insar,cv.COLOR_BGR2RGB)
- print(insar[up:random_down,left:random_right].shape)
+ insar = cv.cvtColor(insar, cv.COLOR_BGR2RGB)
+ print(insar[up:random_down, left:random_right].shape)
print(mask.shape)
- plt.imshow(insar[up:random_down,left:random_right,:])
- plt.axis('off')
+ plt.imshow(insar[up:random_down, left:random_right, :])
+ plt.axis("off")
plt.show()
- cmap = plt.get_cmap('tab10', 7)
- plt.imshow(mask,cmap=cmap,vmax=6.5,vmin=-0.5)
- plt.axis('off')
+ cmap = plt.get_cmap("tab10", 7)
+ plt.imshow(mask, cmap=cmap, vmax=6.5, vmin=-0.5)
+ plt.axis("off")
plt.show()
if image is None:
- print('=' * 40)
- print('Error. Image path not found\n')
+ print("=" * 40)
+ print("Error. Image path not found\n")
print(image_path)
- print('=' * 40)
- return image[up:random_down,left:random_right,:], mask
+ print("=" * 40)
+ return image[up:random_down, left:random_right, :], mask
-def save_crops(annotation_folder='annotations/',save_path = 'Hephaestus/labeled/',mask_path='Hephaestus/masks/',raw_insar_path='Hephaestus_Raw/',out_size=224,verbose=False):
- '''
+def save_crops(
+ annotation_folder="annotations/",
+ save_path="Hephaestus/labeled/",
+ mask_path="Hephaestus/masks/",
+ raw_insar_path="Hephaestus_Raw/",
+ out_size=224,
+ verbose=False,
+):
+ """
:param annotation_folder: folder of annotation jsons
:param save_path: folder path for generated images
:param mask_path: folder path for generated masks
:return: Label vector [ Deformation/Non Deformation ( 0 for Non Deformation, 1,2,3,4,5 for Mogi, Dyke, Sill, Spheroid, Earthquake), Phase (0 -> Rest, 1-> Unrest, 2-> Rebound), Intensity Level (0->Low, 1-> Medium, 2->High, 3-> Earthquake (Not volcanic activity related event intensity))]
- '''
+ """
- print('='*40)
- print('Cropping Hephaestus')
- print('='*40)
+ print("=" * 40)
+ print("Cropping Hephaestus")
+ print("=" * 40)
annotations = os.listdir(annotation_folder)
c = 0
- label_path = mask_path[:-6]+'cls_labels/'
+ label_path = mask_path[:-6] + "cls_labels/"
multiple_ctr = 0
- class_dir = {'Mogi':1, 'Dyke':2, 'Sill': 3, 'Spheroid':4, 'Earthquake':5, 'Low':0,'Medium':1,'High':2,'None':0,'Rest':0,'Unrest':1,'Rebound':2,'Unidentified':6}
+ class_dir = {
+ "Mogi": 1,
+ "Dyke": 2,
+ "Sill": 3,
+ "Spheroid": 4,
+ "Earthquake": 5,
+ "Low": 0,
+ "Medium": 1,
+ "High": 2,
+ "None": 0,
+ "Rest": 0,
+ "Unrest": 1,
+ "Rebound": 2,
+ "Unidentified": 6,
+ }
for file in tqdm(annotations):
label_json = {}
- annotation = read_annotation(annotation_folder+file)
-
- if 'Non_Deformation' in annotation['label']:
- image_path = get_insar_path(annotation_folder+file,root_path=raw_insar_path)
- image = cv.imread(image_path)
- tiles = image_tiling(image,tile_size=out_size)
- for idx,tile in enumerate(tiles):
- if image is None:
- print(image_path)
- print(file)
- continue
- #image = data_utilities.random_crop(image)
- cv.imwrite(save_path+'0/'+file[:-5]+'_'+str(idx)+'.png',tile)
- label_json['Deformation'] = [0]
- label_json['Intensity'] = 0
- label_json['Phase'] = class_dir[annotation['phase']]
- label_json['frameID'] = annotation['frameID']
- label_json['primary_date'] = annotation['primary_date']
- label_json['secondary_date'] = annotation['secondary_date']
- json_writer = open(label_path+file[:-5]+'_'+str(idx)+'.json','w')
- json.dump(label_json,json_writer)
- elif int(annotation['is_crowd'])==0:
- if 'Non_Deformation' not in annotation['label']:
- image, mask = crop_around_object(annotation_path=annotation_folder+file,verbose=False,raw_insar_path=raw_insar_path,output_size=out_size)
- folder = str(class_dir[annotation['activity_type'][0]])
- cv.imwrite(save_path+folder+'/'+file[:-5]+'.png',image)
- cv.imwrite(mask_path+folder+'/'+file[:-5]+'.png',mask)
- label_json['Deformation'] = [class_dir[annotation['activity_type'][0]]]
- if folder!=str(5):
- label_json['Intensity'] = class_dir[annotation['intensity_level'][0]]
+ annotation = read_annotation(annotation_folder + file)
+
+ if "Non_Deformation" in annotation["label"]:
+ image_path = get_insar_path(
+ annotation_folder + file, root_path=raw_insar_path
+ )
+ image = cv.imread(image_path)
+ tiles = image_tiling(image, tile_size=out_size)
+ for idx, tile in enumerate(tiles):
+ if image is None:
+ print(image_path)
+ print(file)
+ continue
+ # image = data_utilities.random_crop(image)
+ cv.imwrite(save_path + "0/" + file[:-5] + "_" + str(idx) + ".png", tile)
+ label_json["Deformation"] = [0]
+ label_json["Intensity"] = 0
+ label_json["Phase"] = class_dir[annotation["phase"]]
+ label_json["frameID"] = annotation["frameID"]
+ label_json["primary_date"] = annotation["primary_date"]
+ label_json["secondary_date"] = annotation["secondary_date"]
+ json_writer = open(
+ label_path + file[:-5] + "_" + str(idx) + ".json", "w"
+ )
+ json.dump(label_json, json_writer)
+ elif int(annotation["is_crowd"]) == 0:
+ if "Non_Deformation" not in annotation["label"]:
+ image, mask = crop_around_object(
+ annotation_path=annotation_folder + file,
+ verbose=False,
+ raw_insar_path=raw_insar_path,
+ output_size=out_size,
+ )
+ folder = str(class_dir[annotation["activity_type"][0]])
+ cv.imwrite(save_path + folder + "/" + file[:-5] + ".png", image)
+ cv.imwrite(mask_path + folder + "/" + file[:-5] + ".png", mask)
+ label_json["Deformation"] = [class_dir[annotation["activity_type"][0]]]
+ if folder != str(5):
+ label_json["Intensity"] = class_dir[
+ annotation["intensity_level"][0]
+ ]
else:
- label_json['Intensity'] = 3
- label_json['Phase'] = class_dir[annotation['phase']]
- label_json['frameID'] = annotation['frameID']
- label_json['primary_date'] = annotation['primary_date']
- label_json['secondary_date'] = annotation['secondary_date']
- json_writer = open(label_path + file, 'w')
+ label_json["Intensity"] = 3
+ label_json["Phase"] = class_dir[annotation["phase"]]
+ label_json["frameID"] = annotation["frameID"]
+ label_json["primary_date"] = annotation["primary_date"]
+ label_json["secondary_date"] = annotation["secondary_date"]
+ json_writer = open(label_path + file, "w")
json.dump(label_json, json_writer)
- elif int(annotation['is_crowd'])>0:
- for deformation in range(len(annotation['segmentation_mask'])):
- image, mask = crop_around_object(annotation_path=annotation_folder + file, verbose=False,raw_insar_path=raw_insar_path,index=deformation,output_size=out_size)
+ elif int(annotation["is_crowd"]) > 0:
+ for deformation in range(len(annotation["segmentation_mask"])):
+ image, mask = crop_around_object(
+ annotation_path=annotation_folder + file,
+ verbose=False,
+ raw_insar_path=raw_insar_path,
+ index=deformation,
+ output_size=out_size,
+ )
if verbose:
- print('Deformation index:',deformation)
- print('Seg length: ',len(annotation['segmentation_mask']))
- print('Annotation length: ',len(annotation['activity_type']))
- print('File: ',file)
-
- folder = str(class_dir[annotation['activity_type'][deformation]])
- cv.imwrite(save_path + folder + '/' + file[:-5] + '_'+ str(deformation) + '.png', image)
- cv.imwrite(mask_path + folder + '/' + file[:-5] + '_'+ str(deformation) + '.png', mask)
- label_json['Deformation'] = list(np.unique(mask))#class_dir[annotation['activity_type'][deformation]]
- label_json['Deformation'] = [int(x) for x in label_json['Deformation']]
- label_json['Deformation'].remove(0)
- if folder!=str(5):
- label_json['Intensity'] = class_dir[annotation['intensity_level'][deformation]]
+ print("Deformation index:", deformation)
+ print("Seg length: ", len(annotation["segmentation_mask"]))
+ print("Annotation length: ", len(annotation["activity_type"]))
+ print("File: ", file)
+
+ folder = str(class_dir[annotation["activity_type"][deformation]])
+ cv.imwrite(
+ save_path
+ + folder
+ + "/"
+ + file[:-5]
+ + "_"
+ + str(deformation)
+ + ".png",
+ image,
+ )
+ cv.imwrite(
+ mask_path
+ + folder
+ + "/"
+ + file[:-5]
+ + "_"
+ + str(deformation)
+ + ".png",
+ mask,
+ )
+ label_json["Deformation"] = list(
+ np.unique(mask)
+ ) # class_dir[annotation['activity_type'][deformation]]
+ label_json["Deformation"] = [int(x) for x in label_json["Deformation"]]
+ label_json["Deformation"].remove(0)
+ if folder != str(5):
+ label_json["Intensity"] = class_dir[
+ annotation["intensity_level"][deformation]
+ ]
else:
- label_json['Intensity'] = 3
- label_json['Phase'] = class_dir[annotation['phase']]
- label_json['frameID'] = annotation['frameID']
- label_json['primary_date'] = annotation['primary_date']
- label_json['secondary_date'] = annotation['secondary_date']
- json_writer = open(label_path + file[:-5] + '_'+ str(deformation) + '.json', 'w')
+ label_json["Intensity"] = 3
+ label_json["Phase"] = class_dir[annotation["phase"]]
+ label_json["frameID"] = annotation["frameID"]
+ label_json["primary_date"] = annotation["primary_date"]
+ label_json["secondary_date"] = annotation["secondary_date"]
+ json_writer = open(
+ label_path + file[:-5] + "_" + str(deformation) + ".json", "w"
+ )
json.dump(label_json, json_writer)
- multiple_ctr +=1
+ multiple_ctr += 1
- print('='*40)
- print('Cropping completed')
- print('='*40)
+ print("=" * 40)
+ print("Cropping completed")
+ print("=" * 40)
-def image_tiling(image,tile_size=64):
+def image_tiling(image, tile_size=64):
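+    # split the image into non-overlapping tile_size x tile_size patches (edge remainders are discarded)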
if image is None:
return []
- max_rows = image.shape[0]//tile_size
- max_cols = image.shape[1]//tile_size
+ max_rows = image.shape[0] // tile_size
+ max_cols = image.shape[1] // tile_size
tiles = []
for i in range(max_rows):
starting_row = i * tile_size
for j in range(max_cols):
starting_col = j * tile_size
- img = image[starting_row:starting_row+tile_size,starting_col:starting_col+tile_size]
+ img = image[
+ starting_row : starting_row + tile_size,
+ starting_col : starting_col + tile_size,
+ ]
tiles.append(img)
return tiles
-
diff --git a/configs/configs.json b/configs/configs.json
index 19be24b..8ad1f21 100644
--- a/configs/configs.json
+++ b/configs/configs.json
@@ -12,24 +12,15 @@
"batch_size": 64,
"epochs": 200,
"lr": 0.03,
- "num_workers": 16,
+ "num_workers": 2,
"mixed_precision": false,
"pretrained": true,
"resolution": 224,
"num_channels": 3,
- "gpu": -1,
"print_frequency": 10,
"resume_checkpoint": "YOUR_CHECKPOINT",
"start_epoch": 0,
-
- "distributed": true,
- "multiprocessing_distributed": true,
- "world_size": 1,
- "rank": 0,
- "dist_url": "tcp://127.0.0.1:15151",
- "dist_backend": "nccl",
"seed": ""
-
}
diff --git a/configs/method/mocov2/mocov2.json b/configs/method/mocov2/mocov2.json
index 2c2f47e..6110047 100644
--- a/configs/method/mocov2/mocov2.json
+++ b/configs/method/mocov2/mocov2.json
@@ -4,7 +4,7 @@
"momentum": 0.9,
"weight_decay": 1e-4,
"moco_dim":128,
- "moco_k":65646,
+ "moco_k":65536,
"moco_m":0.999,
"moco_t":0.07,
"mlp": true,
diff --git a/examples.png b/docs/examples.png
similarity index 100%
rename from examples.png
rename to docs/examples.png
diff --git a/hephaestus-logo.png b/docs/hephaestus-logo.png
similarity index 100%
rename from hephaestus-logo.png
rename to docs/hephaestus-logo.png
diff --git a/volcano_distribution.png b/docs/volcano_distribution.png
similarity index 100%
rename from volcano_distribution.png
rename to docs/volcano_distribution.png
diff --git a/hephaestus.slurm b/hephaestus.slurm
new file mode 100644
index 0000000..d5f963f
--- /dev/null
+++ b/hephaestus.slurm
@@ -0,0 +1,50 @@
+#!/bin/bash -l
+#SBATCH --job-name=Hephaestus # Job name
+#SBATCH --output=Hephaestus.out         # Stdout
+#SBATCH --error=Hephaestus.err          # Stderr
+#SBATCH --ntasks=4 # Number of tasks(processes)
+#SBATCH --nodes=2 # Number of nodes requested
+#SBATCH --gres=gpu:2 # GPUs per node -- must be equal to ntasks per node in case of data parallelism. Change for model parallelism
+#SBATCH --ntasks-per-node=2 # Tasks per node
+#SBATCH --cpus-per-task=10 # Threads per task
+#SBATCH --time=24:00:00 # walltime
+#SBATCH --mem=54G # memory per NODE
+#SBATCH --partition= # Partition
+#SBATCH --account= # Replace with your system project
+#SBATCH --wait-all-nodes=1 # Do not begin the execution until all nodes are ready for use
+
+module purge
+
+module load gnu/8
+module load java/12.0.2
+module load cuda/10.1.168
+module load intel/18
+module load intelmpi/2018
+module load tftorch/270-191
+
+echo "Start at `date`"
+echo "SLURM_NTASKS=$SLURM_NTASKS"
+
+if [ x$SLURM_CPUS_PER_TASK == x ]; then
+ export OMP_NUM_THREADS=1
+else
+ export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
+fi
+
+export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+echo "MASTER_ADDR=$MASTER_ADDR"
+
+export MASTER_PORT=12802
+echo "MASTER_PORT=$MASTER_PORT"
+
+NODES=($( scontrol show hostname $SLURM_NODELIST | uniq ))
+export NUM_NODES=${#NODES[@]}
+WORKERS=$(printf '%s-ib:'${SLURM_NTASKS_PER_NODE}',' "${NODES[@]}" | sed 's/,$//')
+
+echo "SLULM_NODELIST=$SLULM_NODELIST"
+echo "WORKERS=$WORKERS"
+echo "NODES=$NODES"
+
+srun python3 main.py
+
+echo "End at `date`"
diff --git a/main.py b/main.py
index 63f9351..1ab566c 100644
--- a/main.py
+++ b/main.py
@@ -1,65 +1,359 @@
-import torchvision
-import sys
-import numpy as np
-import torch
-from torchvision import transforms
+import builtins
+import json
+import math
+import os
+import random
+import shutil
+import time
+import warnings
+
import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
import torch.nn as nn
-import torchvision
-import time
-import timm
+import torch.nn.parallel
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
import wandb
-import kornia
-import random
-import os
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-import json
-import pprint
-from tqdm import tqdm
-from utilities.utils import *
-from models.model_utils import *
-import self_supervised
-import self_supervised.mocov2
-from self_supervised.mocov2 import mocov2 as moco
-#os.environ['CUDA_VISIBLE_DEVICES'] ='1,2'
-def exec_ssl(config):
- announce_stuff('Initializing ' + config['method'],up=False)
- if config['method']=='mocov2':
- moco.main(config)
-
-if __name__ == '__main__':
-
- #Parse configurations
- config_path = 'configs/configs.json'
- config = prepare_configuration(config_path)
- json.dump(config,open(config['checkpoint_path']+'/config.json','w'))
- # Set seeds
- seed = config['seed']
- if seed is not None:
- torch.manual_seed(seed)
- np.random.seed(seed)
- random.seed(seed)
+import Dataset
+from self_supervised.mocov2 import builder
+from utilities.utils import prepare_configuration
+
+
+def is_distributed():
+ if "WORLD_SIZE" in os.environ:
+ return int(os.environ["WORLD_SIZE"]) > 1
+ if "SLURM_NTASKS" in os.environ:
+ return int(os.environ["SLURM_NTASKS"]) > 1
+ return False
+
+
+def world_info_from_env():
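+    # torchrun exports LOCAL_RANK/RANK/WORLD_SIZE; SLURM exports SLURM_LOCALID/SLURM_PROCID/SLURM_NTASKS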
+ local_rank = 0
+ for v in ("LOCAL_RANK", "SLURM_LOCALID"):
+ if v in os.environ:
+ local_rank = int(os.environ[v])
+ break
+ global_rank = 0
+ for v in ("RANK", "SLURM_PROCID"):
+ if v in os.environ:
+ global_rank = int(os.environ[v])
+ break
+ world_size = 1
+ for v in ("WORLD_SIZE", "SLURM_NTASKS"):
+ if v in os.environ:
+ world_size = int(os.environ[v])
+ break
+ return local_rank, global_rank, world_size
+
+
+def is_global_master(args):
+ return args["rank"] == 0
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+ print("Training epoch: ", epoch)
+ batch_time = AverageMeter("Time", ":6.3f")
+ data_time = AverageMeter("Data", ":6.3f")
+ losses = AverageMeter("Loss", ":.4e")
+ top1 = AverageMeter("Acc@1", ":6.2f")
+ top5 = AverageMeter("Acc@5", ":6.2f")
+ progress = ProgressMeter(
+ len(train_loader),
+ [batch_time, data_time, losses, top1, top5],
+ prefix="Epoch: [{}]".format(epoch),
+ )
+
+ # switch to train mode
+ model.train()
+
+ end = time.time()
+ for i, (images, _) in enumerate(train_loader):
+ # measure data loading time
+ data_time.update(time.time() - end)
+
+ images[0] = images[0].cuda(non_blocking=True)
+ images[1] = images[1].cuda(non_blocking=True)
+
+ # compute output
+ output, target = model(im_q=images[0], im_k=images[1])
+ loss = criterion(output, target)
+
+ # acc1/acc5 are (K+1)-way contrast classifier accuracy
+ # measure accuracy and record loss
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+ losses.update(loss.item(), images[0].size(0))
+ top1.update(acc1[0], images[0].size(0))
+ top5.update(acc5[0], images[0].size(0))
+
+ # compute gradient and do SGD step
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if i % args["print_frequency"] == 0:
+ progress.display(i)
+ if is_global_master(args):
+ wandb.log(
+ {
+ "top1": top1.avg,
+ "top5": top5.avg,
+ "loss": loss.item(),
+ "epoch": epoch,
+ "iteration": i,
+ }
+ )
+
+
+def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
+ torch.save(state, filename)
+ if is_best:
+ shutil.copyfile(filename, "model_best.pth.tar")
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
- #Initialize wandb
+ def __init__(self, name, fmt=":f"):
+ self.name = name
+ self.fmt = fmt
+ self.reset()
- if config['resume_wandb']:
- id_json = json.load(open(config['checkpoint_path']+'/id.json'))
- config['wandb_id']=id_json['wandb_id']
- wandb.init(project=config['wandb_project'],entity=config['wandb_entity'],id=config['wandb_id'],resume=config['resume_wandb'])
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+ return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+ def __init__(self, num_batches, meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.prefix = prefix
+
+ def display(self, batch):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(meter) for meter in self.meters]
+ print("\t".join(entries))
+
+ def _get_batch_fmtstr(self, num_batches):
+ num_digits = len(str(num_batches // 1))
+ fmt = "{:" + str(num_digits) + "d}"
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate based on schedule"""
+ lr = args["lr"]
+ if args["cos"]: # cosine lr schedule
+ lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / args["epochs"]))
+ else: # stepwise lr schedule
+ for milestone in args["schedule"]:
+ lr *= 0.1 if epoch >= milestone else 1.0
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ with torch.no_grad():
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ # print(correct[:k].shape)
+ # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+ correct_k = (
+ correct[:k].reshape(k * correct.shape[1]).float().sum(0, keepdim=True)
+ )
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+def exec_model(model, args):
+ if args["seed"] is not None:
+ random.seed(args["seed"])
+ torch.manual_seed(args["seed"])
+ cudnn.deterministic = True
+ warnings.warn(
+ "You have chosen to seed training. "
+ "This will turn on the CUDNN deterministic setting, "
+ "which can slow down your training considerably! "
+ "You may see unexpected behavior when restarting "
+ "from checkpoints."
+ )
+
+ if is_distributed():
+ if "SLURM_PROCID" in os.environ:
+ args["local_rank"], args["rank"], args["world_size"] = world_info_from_env()
+ args["num_workers"] = int(os.environ["SLURM_CPUS_PER_TASK"])
+ os.environ["LOCAL_RANK"] = str(args["local_rank"])
+ os.environ["RANK"] = str(args["rank"])
+ os.environ["WORLD_SIZE"] = str(args["world_size"])
+ dist.init_process_group(
+ backend="nccl",
+ world_size=args["world_size"],
+ rank=args["rank"],
+ )
+ os.environ["WANDB_MODE"] = "offline"
+ else:
+ args["local_rank"], _, _ = world_info_from_env()
+ dist.init_process_group(backend="nccl")
+ args["world_size"] = dist.get_world_size()
+ args["rank"] = dist.get_rank()
+
+ torch.cuda.set_device(args["local_rank"])
+
+ # suppress printing if not master
+ if not is_global_master(args):
+
+ def print_pass(*args):
+ pass
+
+ builtins.print = print_pass
else:
- announce_stuff('Initializing Wandb')
- id = wandb.util.generate_id()
- config['wandb_id'] = id
- wandb.init(project=config['wandb_project'], entity=config['wandb_entity'],config=config,id=id,resume='allow')
- json.dump({'wandb_id':id},open(config['checkpoint_path']+'/id.json','w'))
+ raise NotImplementedError("Only DistributedDataParallel is supported.")
+
+ print("=> creating model '{}'".format(args["architecture"]))
+
+ if is_global_master(args):
+ # Initialize wandb
+ print("Initializing Wandb")
+ if args["resume_wandb"]:
+ id_json = json.load(open(args["checkpoint_path"] + "/id.json"))
+ args["wandb_id"] = id_json["wandb_id"]
+ wandb.init(
+ project=args["wandb_project"],
+ entity=args["wandb_entity"],
+ id=args["wandb_id"],
+ resume=args["resume_wandb"],
+ )
+ else:
+ id = wandb.sdk.lib.runid.generate_id()
+ args["wandb_id"] = id
+ wandb.init(
+ project=args["wandb_project"],
+ entity=args["wandb_entity"],
+ config=args,
+ id=id,
+ resume="allow",
+ )
+ json.dump({"wandb_id": id}, open(args["checkpoint_path"] + "/id.json", "w"))
+ wandb.watch(model)
+ print(model)
+
+ model.cuda()
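+    # the configured batch size is global; divide it so each rank processes its share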
+ args["batch_size"] = int(args["batch_size"] / int(args["world_size"]))
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[args["local_rank"]]
+ )
+
+ # define loss function (criterion) and optimizer
+ criterion = nn.CrossEntropyLoss().cuda()
+
+ optimizer = torch.optim.SGD(
+ model.parameters(),
+ args["lr"],
+ momentum=args["momentum"],
+ weight_decay=args["weight_decay"],
+ )
+
+ # optionally resume from a checkpoint
+ if args["resume_checkpoint"]:
+ if os.path.isfile(args["resume_checkpoint"]):
+ print("=> loading checkpoint '{}'".format(args["resume_checkpoint"]))
+            checkpoint = torch.load(
+                args["resume_checkpoint"],
+                map_location="cuda:{}".format(args["local_rank"]),
+            )
+ args["start_epoch"] = checkpoint["epoch"]
+ model.load_state_dict(checkpoint["state_dict"])
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ print(
+ "=> loaded checkpoint '{}' (epoch {})".format(
+ args["resume_checkpoint"], checkpoint["epoch"]
+ )
+ )
+ else:
+ print("=> no checkpoint found at '{}'".format(args["resume_checkpoint"]))
- announce_stuff('Starting project with the following settings:')
- pprint.pprint(config)
- announce_stuff('',up=False)
- exec_ssl(config)
+ cudnn.benchmark = True
+ print("Initializing Dataset")
+ train_dataset = Dataset.Dataset(args)
+ print("Dataset initialized. Size: ", len(train_dataset))
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args["batch_size"],
+ shuffle=(train_sampler is None),
+ num_workers=args["num_workers"],
+ pin_memory=True,
+ sampler=train_sampler,
+ drop_last=True,
+ )
+
+ for epoch in range(args["start_epoch"], args["epochs"]):
+ train_sampler.set_epoch(epoch)
+ adjust_learning_rate(optimizer, epoch, args)
+
+ # train for one epoch
+ train(train_loader, model, criterion, optimizer, epoch, args)
+
+ if is_global_master(args):
+ save_checkpoint(
+ {
+ "epoch": epoch + 1,
+ "arch": args["architecture"],
+ "state_dict": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ },
+ is_best=False,
+ filename=args["checkpoint_path"]
+ + "/checkpoint_{:04d}.pth.tar".format(epoch),
+ )
+ if is_global_master(args):
+ wandb.finish()
+
+
+if __name__ == "__main__":
+ # Parse configurations
+ config_path = "configs/configs.json"
+ config = prepare_configuration(config_path)
+ json.dump(config, open(config["checkpoint_path"] + "/config.json", "w"))
+ if config["method"] == "mocov2":
+ model = builder.MoCo(
+ config,
+ config["moco_dim"],
+ config["moco_k"],
+ config["moco_m"],
+ config["moco_t"],
+ config["mlp"],
+ )
+ else:
+ raise NotImplementedError(f'{config["method"]} is not supported.')
+ exec_model(model, config)
diff --git a/models/model_utils.py b/models/model_utils.py
index d65c2f5..56339eb 100644
--- a/models/model_utils.py
+++ b/models/model_utils.py
@@ -1,17 +1,20 @@
-import torch
-import torch.nn as nn
import timm
+
def create_model(configs):
- if 'vit' not in configs['architecture']:
- model = timm.models.create_model(configs['architecture'].lower(),pretrained=True,num_classes=0)
+ if "vit" not in configs["architecture"]:
+ model = timm.models.create_model(
+ configs["architecture"].lower(), pretrained=True, num_classes=0
+ )
else:
- model = timm.models.vision_transformer.VisionTransformer(img_size=int(configs['resolution']),
- patch_size=int(configs['patches']),
- in_chans=configs['num_channel'],
- embed_dim=configs['embed_dim'],
- depth=configs['depth'],
- num_heads=configs['num_heads'],
- num_classes=0)
+ model = timm.models.vision_transformer.VisionTransformer(
+ img_size=int(configs["resolution"]),
+ patch_size=int(configs["patches"]),
+            in_chans=configs["num_channels"],
+ embed_dim=configs["embed_dim"],
+ depth=configs["depth"],
+ num_heads=configs["num_heads"],
+ num_classes=0,
+ )
- return model
\ No newline at end of file
+ return model
diff --git a/self_supervised/mocov2/builder.py b/self_supervised/mocov2/builder.py
index 4146fc0..9a1fb3a 100644
--- a/self_supervised/mocov2/builder.py
+++ b/self_supervised/mocov2/builder.py
@@ -1,13 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import timm
import torch
import torch.nn as nn
-import timm
+
class MoCo(nn.Module):
"""
Build a MoCo model with: a query encoder, a key encoder, and a queue
https://arxiv.org/abs/1911.05722
"""
+
def __init__(self, config, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
"""
dim: feature dimension (default: 128)
@@ -23,15 +25,29 @@ def __init__(self, config, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
# create the encoders
# num_classes is the output fc dimension
- self.encoder_q = timm.create_model(config['architecture'].lower(),num_classes=dim,pretrained=config['pretrained'])#base_encoder(num_classes=dim)
- self.encoder_k = timm.create_model(config['architecture'].lower(),num_classes=dim,pretrained=config['pretrained'])#base_encoder(num_classes=dim)
+ self.encoder_q = timm.create_model(
+ config["architecture"].lower(),
+ num_classes=dim,
+ pretrained=config["pretrained"],
+ ) # base_encoder(num_classes=dim)
+ self.encoder_k = timm.create_model(
+ config["architecture"].lower(),
+ num_classes=dim,
+ pretrained=config["pretrained"],
+ ) # base_encoder(num_classes=dim)
if mlp: # hack: brute-force replacement
dim_mlp = self.encoder_q.fc.weight.shape[1]
- self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
- self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
-
- for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+ self.encoder_q.fc = nn.Sequential(
+ nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
+ )
+ self.encoder_k.fc = nn.Sequential(
+ nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
+ )
+
+ for param_q, param_k in zip(
+ self.encoder_q.parameters(), self.encoder_k.parameters()
+ ):
param_k.data.copy_(param_q.data) # initialize
param_k.requires_grad = False # not update by gradient
@@ -46,8 +62,10 @@ def _momentum_update_key_encoder(self):
"""
Momentum update of the key encoder
"""
- for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
- param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
+ for param_q, param_k in zip(
+ self.encoder_q.parameters(), self.encoder_k.parameters()
+ ):
+ param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
@@ -58,10 +76,12 @@ def _dequeue_and_enqueue(self, keys):
ptr = int(self.queue_ptr)
- assert self.K % batch_size == 0 # for simplicity
+ assert (
+ self.K % batch_size == 0
+ ), f"{self.K} mod {batch_size} should be equal to 0." # for simplicity
# replace the keys at ptr (dequeue and enqueue)
- self.queue[:, ptr:ptr + batch_size] = keys.T
+ self.queue[:, ptr : ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K # move pointer
self.queue_ptr[0] = ptr
@@ -142,9 +162,9 @@ def forward(self, im_q, im_k):
# compute logits
# Einstein sum is more intuitive
# positive logits: Nx1
- l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
+ l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
# negative logits: NxK
- l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+ l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
@@ -168,9 +188,10 @@ def concat_all_gather(tensor):
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
- tensors_gather = [torch.ones_like(tensor)
- for _ in range(torch.distributed.get_world_size())]
+ tensors_gather = [
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+ ]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
- return output
\ No newline at end of file
+ return output
diff --git a/self_supervised/mocov2/mocov2.py b/self_supervised/mocov2/mocov2.py
deleted file mode 100644
index 4b4063d..0000000
--- a/self_supervised/mocov2/mocov2.py
+++ /dev/null
@@ -1,313 +0,0 @@
-#!/usr/bin/env python
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-import argparse
-import builtins
-import math
-import os
-import random
-import shutil
-import time
-import warnings
-
-import torch
-import torch.nn as nn
-import torch.nn.parallel
-import torch.backends.cudnn as cudnn
-import torch.distributed as dist
-import torch.optim
-import torch.multiprocessing as mp
-import torch.utils.data
-import torch.utils.data.distributed
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
-import torchvision.models as models
-import self_supervised.mocov2.builder as builder
-import timm
-import kornia
-import Dataset
-import wandb
-
-def main(args):
-
-
- if args["seed"] is not None:
- random.seed(args["seed"])
- torch.manual_seed(args["seed"])
- cudnn.deterministic = True
- warnings.warn('You have chosen to seed training. '
- 'This will turn on the CUDNN deterministic setting, '
- 'which can slow down your training considerably! '
- 'You may see unexpected behavior when restarting '
- 'from checkpoints.')
-
- if args["gpu"] is not None:
- warnings.warn('You have chosen a specific GPU. This will completely '
- 'disable data parallelism.')
-
- if args["dist_url"] == "env://" and args["world_size"] == -1:
- args["world_size"] = int(os.environ["WORLD_SIZE"])
-
- args["distributed"] = args["world_size"] > 1 or args["multiprocessing_distributed"]
-
- ngpus_per_node = torch.cuda.device_count()
- print('Number of gpus: ',ngpus_per_node)
- if args["multiprocessing_distributed"]:
- # Since we have ngpus_per_node processes per node, the total world_size
- # needs to be adjusted accordingly
- args["world_size"] = ngpus_per_node * args["world_size"]
- # Use torch.multiprocessing.spawn to launch distributed processes: the
- # main_worker process function
- mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
- else:
- # Simply call main_worker function
- main_worker(args["gpu"], ngpus_per_node, args)
-
-
-def main_worker(gpu, ngpus_per_node, args):
- args["gpu"] = gpu
-
- # suppress printing if not master
- if args["multiprocessing_distributed"] and args["gpu"] != 0:
- def print_pass(*args):
- pass
- builtins.print = print_pass
-
- if args["gpu"] is not None:
- print("Use GPU: {} for training".format(args["gpu"]))
-
- if args["distributed"]:
- if args["dist_url"] == "env://" and args["rank"] == -1:
- args["rank"] = int(os.environ["RANK"])
- if args["multiprocessing_distributed"]:
- # For multiprocessing distributed training, rank needs to be the
- # global rank among all the processes
- args["rank"] = args["rank"] * ngpus_per_node + gpu
- print('Initializing: ',args["dist_url"],' rank=',args['rank'])
- dist.init_process_group(backend=args["dist_backend"], init_method=args["dist_url"],
- world_size=args["world_size"], rank=args["rank"])
- # create model
- print("=> creating model '{}'".format(args["architecture"]))
- model = builder.MoCo(
- args,
- args["moco_dim"], args["moco_k"], args["moco_m"], args["moco_t"], args["mlp"])
- if args['gpu']==0:
- wandb.init(project=args['wandb_project'],entity=args['wandb_entity'],id=args['wandb_id'],resume=True)
- wandb.watch(model)
- print(model)
-
- if args["distributed"]:
- # For multiprocessing distributed, DistributedDataParallel constructor
- # should always set the single device scope, otherwise,
- # DistributedDataParallel will use all available devices.
- if args["gpu"] is not None:
- torch.cuda.set_device(args["gpu"])
- model.cuda(args["gpu"])
- # When using a single GPU per process and per
- # DistributedDataParallel, we need to divide the batch size
- # ourselves based on the total number of GPUs we have
- args["batch_size"] = int(args["batch_size"] / ngpus_per_node)
- args["workers"] = int((args["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
- model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args["gpu"]])
- else:
- model.cuda()
- # DistributedDataParallel will divide and allocate batch_size to all
- # available GPUs if device_ids are not set
- model = torch.nn.parallel.DistributedDataParallel(model)
- elif args["gpu"] is not None:
- torch.cuda.set_device(args["gpu"])
- model = model.cuda(args["gpu"])
- # comment out the following line for debugging
- raise NotImplementedError("Only DistributedDataParallel is supported.")
- else:
- # AllGather implementation (batch shuffle, queue update, etc.) in
- # this code only supports DistributedDataParallel.
- raise NotImplementedError("Only DistributedDataParallel is supported.")
-
- # define loss function (criterion) and optimizer
- criterion = nn.CrossEntropyLoss().cuda(args["gpu"])
-
- optimizer = torch.optim.SGD(model.parameters(), args["lr"],
- momentum=args["momentum"],
- weight_decay=args["weight_decay"])
-
- # optionally resume from a checkpoint
- if args["resume_checkpoint"]:
- if os.path.isfile(args["resume_checkpoint"]):
- print("=> loading checkpoint '{}'".format(args["resume_checkpoint"]))
- if args["gpu"] is None:
- checkpoint = torch.load(args["resume_checkpoint"])
- else:
- # Map model to be loaded to specified single gpu.
- loc = 'cuda:{}'.format(args["gpu"])
- checkpoint = torch.load(args["resume_checkpoint"], map_location=loc)
- args["start_epoch"] = checkpoint['epoch']
- model.load_state_dict(checkpoint['state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer'])
- print("=> loaded checkpoint '{}' (epoch {})"
- .format(args["resume_checkpoint"], checkpoint['epoch']))
- else:
- print("=> no checkpoint found at '{}'".format(args["resume_checkpoint"]))
-
- cudnn.benchmark = True
-
- # Data loading code
- #traindir = args["data_path"]
-
- print('Initializing Dataset')
- train_dataset = Dataset.Dataset(args)
- print('Dataset initialized. Size: ',len(train_dataset))
-
- if args["distributed"]:
- train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
- else:
- train_sampler = None
-
- train_loader = torch.utils.data.DataLoader(
- train_dataset, batch_size=args["batch_size"], shuffle=(train_sampler is None),
- num_workers=args["workers"], pin_memory=True, sampler=train_sampler, drop_last=True)
-
- for epoch in range(args["start_epoch"], args["epochs"]):
- if args["distributed"]:
- train_sampler.set_epoch(epoch)
- adjust_learning_rate(optimizer, epoch, args)
-
- # train for one epoch
- train(train_loader, model, criterion, optimizer, epoch, args)
-
- if not args["multiprocessing_distributed"] or (args["multiprocessing_distributed"]
- and args["rank"] % ngpus_per_node == 0):
- save_checkpoint({
- 'epoch': epoch + 1,
- 'arch': args["architecture"],
- 'state_dict': model.state_dict(),
- 'optimizer' : optimizer.state_dict(),
- }, is_best=False, filename=args['checkpoint_path']+'/checkpoint_{:04d}.pth.tar'.format(epoch))
- if args['gpu']==0:
- wandb.finish()
-
-def train(train_loader, model, criterion, optimizer, epoch, args):
- print('Training epoch: ',epoch)
- batch_time = AverageMeter('Time', ':6.3f')
- data_time = AverageMeter('Data', ':6.3f')
- losses = AverageMeter('Loss', ':.4e')
- top1 = AverageMeter('Acc@1', ':6.2f')
- top5 = AverageMeter('Acc@5', ':6.2f')
- progress = ProgressMeter(
- len(train_loader),
- [batch_time, data_time, losses, top1, top5],
- prefix="Epoch: [{}]".format(epoch))
-
- # switch to train mode
- model.train()
-
- end = time.time()
- for i, (images, _) in enumerate(train_loader):
- # measure data loading time
- data_time.update(time.time() - end)
-
- if args["gpu"] is not None:
- images[0] = images[0].cuda(args["gpu"], non_blocking=True)
- images[1] = images[1].cuda(args["gpu"], non_blocking=True)
-
- # compute output
- output, target = model(im_q=images[0], im_k=images[1])
- loss = criterion(output, target)
-
- # acc1/acc5 are (K+1)-way contrast classifier accuracy
- # measure accuracy and record loss
- acc1, acc5 = accuracy(output, target, topk=(1, 5))
- losses.update(loss.item(), images[0].size(0))
- top1.update(acc1[0], images[0].size(0))
- top5.update(acc5[0], images[0].size(0))
-
- # compute gradient and do SGD step
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- # measure elapsed time
- batch_time.update(time.time() - end)
- end = time.time()
-
- if i % args["print_frequency"] == 0:
- progress.display(i)
- if args['gpu']==0:
- wandb.log({'top1':top1.avg,'top5':top5.avg,'loss':loss.item(),'epoch':epoch,'iteration':i})
-
-
-def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
- torch.save(state, filename)
- if is_best:
- shutil.copyfile(filename, 'model_best.pth.tar')
-
-
-class AverageMeter(object):
- """Computes and stores the average and current value"""
- def __init__(self, name, fmt=':f'):
- self.name = name
- self.fmt = fmt
- self.reset()
-
- def reset(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
-
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
-
- def __str__(self):
- fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
- return fmtstr.format(**self.__dict__)
-
-
-class ProgressMeter(object):
- def __init__(self, num_batches, meters, prefix=""):
- self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
- self.meters = meters
- self.prefix = prefix
-
- def display(self, batch):
- entries = [self.prefix + self.batch_fmtstr.format(batch)]
- entries += [str(meter) for meter in self.meters]
- print('\t'.join(entries))
-
- def _get_batch_fmtstr(self, num_batches):
- num_digits = len(str(num_batches // 1))
- fmt = '{:' + str(num_digits) + 'd}'
- return '[' + fmt + '/' + fmt.format(num_batches) + ']'
-
-
-def adjust_learning_rate(optimizer, epoch, args):
- """Decay the learning rate based on schedule"""
- lr = args["lr"]
- if args["cos"]: # cosine lr schedule
- lr *= 0.5 * (1. + math.cos(math.pi * epoch / args["epochs"]))
- else: # stepwise lr schedule
- for milestone in args["schedule"]:
- lr *= 0.1 if epoch >= milestone else 1.
- for param_group in optimizer.param_groups:
- param_group['lr'] = lr
-
-
-def accuracy(output, target, topk=(1,)):
- """Computes the accuracy over the k top predictions for the specified values of k"""
- with torch.no_grad():
- maxk = max(topk)
- batch_size = target.size(0)
-
- _, pred = output.topk(maxk, 1, True, True)
- pred = pred.t()
- correct = pred.eq(target.view(1, -1).expand_as(pred))
-
- res = []
- for k in topk:
- #print(correct[:k].shape)
- #correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
- correct_k = correct[:k].reshape(k * correct.shape[1]).float().sum(0, keepdim=True)
- res.append(correct_k.mul_(100.0 / batch_size))
- return res
diff --git a/utilities/augmentations.py b/utilities/augmentations.py
index b90264b..b0f3c61 100644
--- a/utilities/augmentations.py
+++ b/utilities/augmentations.py
@@ -1,26 +1,35 @@
import albumentations as A
+
def get_augmentations(config):
- augmentations = config['augmentations']
+ augmentations = config["augmentations"]
independend_aug = []
- for k,v in augmentations.items():
- if k == 'RandomResizedCrop':
- aug = A.augmentations.RandomResizedCrop(height=v['value'],width=v['value'],p=v['p'])
- elif k=='ColorJitter':
- aug = A.augmentations.ColorJitter(brightness=v['value'][0],contrast=v['value'][1],saturation=v['value'][2],hue=v['value'][3],p=v['p'])
- elif k=='HorizontalFlip':
- aug = A.augmentations.HorizontalFlip(p=v['p'])
- elif k=='VerticalFlip':
- aug = A.augmentations.VerticalFlip(p=v['p'])
- elif k=='GaussianBlur':
- aug = A.augmentations.GaussianBlur(sigma_limit=v['value'],p=v['p'])
- elif k=='ElasticTransform':
- aug = A.augmentations.ElasticTransform(p=v['p'])
- elif k=='Cutout':
- aug = A.augmentations.CoarseDropout(p=v['p'])
- elif k=='GaussianNoise':
- aug = A.augmentations.GaussNoise(p=v['p'])
- elif k=='MultNoise':
- aug = A.augmentations.MultiplicativeNoise(p=v['p'])
+ for k, v in augmentations.items():
+ if k == "RandomResizedCrop":
+ aug = A.augmentations.RandomResizedCrop(
+ height=v["value"], width=v["value"], p=v["p"]
+ )
+ elif k == "ColorJitter":
+ aug = A.augmentations.ColorJitter(
+ brightness=v["value"][0],
+ contrast=v["value"][1],
+ saturation=v["value"][2],
+ hue=v["value"][3],
+ p=v["p"],
+ )
+ elif k == "HorizontalFlip":
+ aug = A.augmentations.HorizontalFlip(p=v["p"])
+ elif k == "VerticalFlip":
+ aug = A.augmentations.VerticalFlip(p=v["p"])
+ elif k == "GaussianBlur":
+ aug = A.augmentations.GaussianBlur(sigma_limit=v["value"], p=v["p"])
+ elif k == "ElasticTransform":
+ aug = A.augmentations.ElasticTransform(p=v["p"])
+ elif k == "Cutout":
+ aug = A.augmentations.CoarseDropout(p=v["p"])
+ elif k == "GaussianNoise":
+ aug = A.augmentations.GaussNoise(p=v["p"])
+ elif k == "MultNoise":
+ aug = A.augmentations.MultiplicativeNoise(p=v["p"])
independend_aug.append(aug)
- return A.Compose(independend_aug)
\ No newline at end of file
+ return A.Compose(independend_aug)
diff --git a/utilities/utils.py b/utilities/utils.py
index 3a955e3..696647f 100644
--- a/utilities/utils.py
+++ b/utilities/utils.py
@@ -1,36 +1,38 @@
-from pathlib import Path
import json
+from pathlib import Path
+
def create_checkpoint_directory(args):
- checkpoint_path = 'checkpoints/' + args['method'].lower() + '/' + args['architecture'].lower() + '/' + args['architecture'].lower() + '_'+ str(args['resolution'])
- Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+ checkpoint_path = (
+ Path("checkpoints")
+ / args["method"].lower()
+ / args["architecture"].lower()
+ / f'{args["architecture"].lower()}_{str(args["resolution"])}'
+ )
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
return checkpoint_path
-def announce_stuff(announcement,symbol='=',times=20,up=True,down=True):
- if up:
- print(symbol*times)
- print(announcement)
- if down:
- print(symbol*times)
def prepare_configuration(path):
- config = json.load(open(path, 'r'))
- # Create checkpoint path if it does not exist
- checkpoint_path = create_checkpoint_directory(config)
- config['checkpoint_path'] = checkpoint_path
+ # Load configuration files
+ base_cfg = json.load(open(path, "r"))
+ if not base_cfg["seed"]:
+ base_cfg["seed"] = None
- # Load augmentation settings
- augmentation_config = json.load(open(config['augmentation_config'],'r'))
- config.update(augmentation_config)
+ augmentation_cfg = json.load(open(base_cfg["augmentation_config"], "r"))
+ base_cfg.update(augmentation_cfg)
- #Load model settings
- model_config_path = 'configs/method/' + config['method'].lower() + '/' + config['method'].lower() + '.json'
- model_config = json.load(open(model_config_path,'r'))
- config.update(model_config)
+ model_cfg = (
+ Path("configs/method")
+ / base_cfg["method"].lower()
+ / f'{base_cfg["method"].lower()}.json'
+ )
+ with model_cfg.open("r", encoding="UTF-8") as target:
+ model_config = json.load(target)
+ base_cfg.update(model_config)
- if config['seed'] == '':
- config['seed'] = None
- if config['gpu']==-1:
- config['gpu'] = None
+ # Create checkpoint path if it does not exist
+ checkpoint_path = create_checkpoint_directory(base_cfg)
+ base_cfg["checkpoint_path"] = checkpoint_path.as_posix()
- return config
\ No newline at end of file
+ return base_cfg