Skip to content

Commit

Permalink
add supervised training
Browse files Browse the repository at this point in the history
  • Loading branch information
ngbountos committed Mar 15, 2023
1 parent 9b9e55b commit 695d811
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 1 deletion.
119 changes: 118 additions & 1 deletion Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
import numpy as np
import torch
from tqdm import tqdm

import json
from utilities import augmentations
import random
import einops

np.random.seed(999)
torch.manual_seed(999)
random.seed(999)


class Dataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -65,3 +71,114 @@ def __getitem__(self, index):
index = 0

return insar, 0





class SupervisedDataset(torch.utils.data.Dataset):
def __init__(self, config, mode):
self.config = config
self.mode = mode
self.train_test_split = json.load(open(config['split_path'],'r'))
self.valid_files = self.train_test_split[mode]['0']
self.valid_files.extend(self.train_test_split[mode]['1'])
self.records = []
self.oversampling = config['oversampling']
self.positives = []
self.negatives = []
self.frameIDs = []
self.augmentations = augmentations.get_augmentations(config)

for file in self.valid_files:
annotation = json.load(open(file,'r'))
sample = file.split('/')[-1]
insar_path = self.config['supervised_data_path'] + '/labeled/'+str(annotation['Deformation'][0])+'/' + sample[:-5] + '.png'
mask_path = self.config['supervised_data_path'] + '/masks/'+str(annotation['Deformation'][0])+'/' + sample[:-5] + '.png'

if len(annotation['Deformation'])>1:
for activity in annotation['Deformation']:
insar_path = self.config['supervised_data_path'] + '/labeled/'+str(activity)+'/' + sample[:-5] + '.png'
mask_path = self.config['supervised_data_path'] + '/masks/'+str(activity)+'/' + sample[:-5] + '.png'

if os.path.isfile(insar_path):
break
record = {'frameID':annotation['frameID'],'label':annotation['Deformation'],'intensity':annotation['Intensity'],'phase':annotation['Phase'],'insar_path':insar_path, 'mask_path':mask_path}
if 0 not in record['label']:
self.positives.append(record)
else:
self.negatives.append(record)
self.records.append(record)
self.frameIDs.append(record['frameID'])
self.num_examples = len(self.records)
self.frameIDs = np.unique(self.frameIDs)
self.frame_dict = {}
for idx, frame in enumerate(self.frameIDs):
self.frame_dict[frame] = idx

print(mode + ' Number of ground deformation samples: ',len(self.positives))
print(mode + ' Number of non deformation samples: ',len(self.negatives))
def __len__(self):
return self.num_examples

def read_insar(self, path):
insar = cv.imread(path, 0)
if insar is None:
print("None")
return insar
if self.config['augment'] and self.mode == 'train':
transform = self.augmentations(image=insar)
insar = transform['image']
insar = einops.repeat(insar, 'h w -> c h w', c=3)
insar = torch.from_numpy(insar)/255
return insar

def __getitem__(self, index):
label = None
if self.oversampling and self.mode=='train':
choice = random.randint(0,1)
if choice == 0:
choice_neg = random.randint(0,len(self.negatives)-1)
sample = self.negatives[choice_neg]
else:

choice_pos = random.randint(0,len(self.positives)-1)
sample = self.positives[choice_pos]
else:
sample = self.records[index]

insar = self.read_insar(sample['insar_path'])


if label is None:
if not os.path.isfile(sample['mask_path']) and 0 in sample['label']:
mask = np.zeros((224,224))
else:
mask = cv.imread(sample['mask_path'],0)

if self.config['num_classes']<=2:
if np.sum(mask)>0:
label = 1.0
else:
label = 0.0
if self.config['num_classes']==2:

label = torch.tensor(label).long()

return (insar.float(), torch.from_numpy(mask).long(), label)
else:
label = torch.tensor(label).float()
return (insar.float(), torch.from_numpy(mask).float(), label)
else:
one_hot = torch.zeros((13))
for act in sample['label']:
one_hot[act] = 1

one_hot[sample['intensity']+7] = 1
one_hot[sample['phase'] + 10] = 1

return insar,mask, one_hot
else:
label = torch.tensor(label).long()
mask = np.zeros((224,224))
return (insar.float(), torch.from_numpy(mask).long(), label)
29 changes: 29 additions & 0 deletions configs/supervised_configs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"wandb_project":"YOUR_PROJECT",
"wandb_entity":"YOUR_ENTITY",
"task":"classification",
"num_classes":2,
"device":"cuda:1",
"wandb":false,
"ssl_encoder":null,
"contrastive":false,
"split_path":"./train_test_split.json",
"supervised_data_path":"/mnt/nvme1/nbountos/datasets/HephaestusSSL",
"data_path":"/mnt/nvme1/nbountos/datasets/Hephaestus_Raw",
"batch_size":64,
"num_workers":12,
"lr":0.0001,
"weight_decay":1e-4,
"epochs":10,
"architecture":"ResNet18",
"linear_evaluation":false,
"oversampling":true,
"threshold":null,
"multilabel":false,
"augment":false,
"num_channels":3,
"class_weights":[1.0,1.0],
"synthetic_probability":0,
"synth_path":"/mnt/nvme1/nbountos/datasets/synth/",
"seed":999
}
190 changes: 190 additions & 0 deletions supervised_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import numpy as np
from utilities import utils as utils
import json
from tqdm import tqdm
from pathlib import Path
import random
import timm
from torchmetrics.classification import BinaryStatScores
import wandb
import pprint

np.random.seed(999)
torch.manual_seed(999)
random.seed(999)

def define_checkpoint(configs):
checkpoint_path = (
Path("finetuning_checkpoints") /
configs["task"].lower() /
configs["architecture"].lower()
)
checkpoint_path.mkdir(parents=True, exist_ok=True)
return checkpoint_path

def eval_cls(configs,loader,criterion,mode='Val',model=None, epoch=-1):
print('\n\nInitializing Evaluation: ' + mode+'\n\n\n')
if model is None:
checkpoint_path = 'best_model_task_'+configs['task']+'.pt'
model = torch.load(configs['checkpoint_path'] / checkpoint_path)['model']
model.eval()
accuracy, fscore, precision, recall, _, average_precision = utils.initialize_metrics(configs)
metric = BinaryStatScores().to(configs['device'])
loss = 0.0
prediction_list = []
ground_truth = []

with torch.no_grad():
for batch in tqdm(loader):
insar, _ , label = batch
insar = insar.to(configs['device'])

label = label.to(configs['device'])
out = model(insar)
if configs['num_classes']==1:
label = label.unsqueeze(1)
if configs['multilabel'] or configs['num_classes']==1:
predictions = torch.sigmoid(out)
tmp_pred = predictions.clone()

predictions[tmp_pred>=configs['threshold']] = 1
predictions[tmp_pred<configs['threshold']] = 0
else:
predictions = out.argmax(1)

prediction_list.extend(predictions.cpu().detach().numpy())
ground_truth.extend(label.cpu().detach().numpy())

accuracy(predictions,label)
fscore(predictions,label)
precision(predictions,label)
recall(predictions,label)
average_precision(out[:,1].float(), label)

metric(predictions, label)
loss += criterion(out,label)
print('TP, FP, TN, FN, Support')
print(metric.compute())

if configs['wandb']:
wandb.log({mode+ ' F-Score':fscore.compute(),mode+' Acc':accuracy.compute(),mode + ' Precision':precision.compute(),mode + ' Recall': recall.compute(), mode + ' Avg. Precision':average_precision.compute(), mode + 'Loss: ':loss/len(loader.dataset)})
print({mode+ ' F-Score':fscore.compute(),mode+' Acc':accuracy.compute(),mode + ' Precision':precision.compute(),mode + ' Recall': recall.compute(), mode + ' Avg. Precision':average_precision.compute()}, mode + 'Loss: ',loss/len(loader.dataset),'Epoch: ',epoch)

return fscore.compute()


def train_cls(configs):
train_loader, val_loader, test_loader = utils.prepare_supervised_learning_loaders(configs)

accuracy, fscore, precision, recall, _, avg = utils.initialize_metrics(configs)
class_weights = torch.tensor(configs['class_weights']).to(configs['device'])
if configs['num_classes']>=2 or configs['num_classes']==1:
criterion = nn.CrossEntropyLoss(weight=class_weights)
else:
criterion = nn.BCEWithLogitsLoss()

if configs['ssl_encoder'] is None:
print('='*20)
print('Creating model pretrained on Imagenet')
print('='*20)
model = timm.create_model(configs['architecture'].lower(),num_classes=configs['num_classes'],pretrained=True)
if configs['linear_evaluation']:
for param in model.parameters():
param.requires_grad = False
model.eval()
if 'vit' in configs['architecture'] or 'swin' in configs['architecture']:
model.head = nn.Linear(model.head.in_features,configs['num_classes'])
else:
model.fc = nn.Linear(model.fc.in_features,configs['num_classes'])
else:
print('='*20)
print('Creating SSL model pretrained on Hephaestus')
print('='*20)

model = torch.load('pretrained_encoder_resnet50.pt',map_location=configs['device'])
out_dim = 2048

if configs['linear_evaluation']:
for param in model.parameters():
param.requires_grad = False
model.eval()
else:
for param in model.parameters():
param.requires_grad = False

model.fc = nn.Linear(out_dim,configs['num_classes'])
model.to(configs['device'])
#print(model)
optimizer = torch.optim.AdamW(model.parameters(),lr=configs['lr'],weight_decay=configs['weight_decay'])

best_val = 0.0
best_epoch = 0.0
best_stats = {}
for epoch in range(configs['epochs']):
for idx, batch in tqdm(enumerate(train_loader)):
if idx>200:
break
optimizer.zero_grad()
if not configs['linear_evaluation']:
model.train()
insar, _, label = batch
insar = insar.to(configs['device'])
label = label.to(configs['device'])
if configs['num_classes']==1:
label = label.unsqueeze(1)
out = model(insar)
if configs['multilabel'] or configs['num_classes']==1:
predictions = torch.sigmoid(out)
tmp_pred = predictions.clone()
predictions[tmp_pred>=configs['threshold']] = 1
predictions[tmp_pred<configs['threshold']] = 0
else:
predictions = out.argmax(1)
accuracy(predictions,label)

loss = criterion(out,label)
loss.backward()
optimizer.step()

if idx%100 == 0:
log_dict = {'Loss: ':loss.mean().item(),' Train accuracy: ':accuracy.compute(),'Epoch':epoch}
print(log_dict)
if configs['wandb']:
wandb.log(log_dict)

val_loss = eval_cls(configs,val_loader,criterion,model=model, epoch=epoch)
if val_loss >= best_val:
best_val = val_loss
best_epoch = epoch
best_stats = {'val_loss':best_val,'epoch':best_epoch}
print('New Best model: ',best_stats)

model_path = 'best_model_task_'+configs['task']+'.pt'
print('Saving model to: ',configs['checkpoint_path'] / model_path )
torch.save({'model':model,'stats':best_stats},configs['checkpoint_path'] / model_path )

print('='*20)
print('Start Testing')
print('='*20)
eval_cls(configs,test_loader,criterion,mode='Test')

def train_segmentation(configs):
pass

if __name__== '__main__':

config_path = 'configs/supervised_configs.json'
configs =json.load(open(config_path,'r'))
augmentation_cfg = json.load(open("configs/augmentations/augmentation.json", "r"))
configs.update(augmentation_cfg)
configs['checkpoint_path'] = define_checkpoint(configs)
if configs['wandb']:
wandb.init(
project=configs['wandb_project'],
entity=configs['wandb_entity'],
config=configs
)
pprint.pprint(configs)
train_cls(configs)

0 comments on commit 695d811

Please sign in to comment.