Commit fc7cd8f

add linear evaluation functionality

ngbountos committed Jun 18, 2023
1 parent 55a2656 commit fc7cd8f
Showing 2 changed files with 30 additions and 9 deletions.
configs/supervised_configs.json (14 changes: 6 additions & 8 deletions)
@@ -1,26 +1,24 @@
 {
-"wandb_project":"YOUR_WANDB_PROJECT",
+"wandb_project":"YOUR_PROJECT",
 "wandb_entity":"YOUR_ENTITY",
 "task":"classification",
 "num_classes":11,
-"device":"cuda:1",
 "wandb":true,
 "mixed_precision":true,
 "ssl_encoder":null,
 "ssl_run_id_path":null,
 "annotation_path":"YOUR_PATH/annotations/",
 "data_path":"YOUR_PATH/Hephaestus_Raw/",
-"batch_size":64,
+"batch_size":128,
 "num_workers":4,
+"device":"cuda:1",
 "lr":0.0001,
 "weight_decay":1e-4,
-"epochs":2,
-"architecture":"ResNet18",
+"epochs":10,
+"architecture":"ResNet50",
 "oversampling":true,
 "multilabel":true,
+"linear_evaluation":true,
 "augment":false,
 "num_channel":2,
-"class_weights":[1.0,1.0],
 "seed":999,
 "image_size":224
 }
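
For reference, the new linear_evaluation flag is an ordinary JSON key, so a training script only needs to load the file and branch on it. A minimal sketch, assuming the standard json module and the config path shown above (the actual loading code in the repository may differ):

import json

# Load the supervised training config shown above.
with open('configs/supervised_configs.json') as f:
    configs = json.load(f)

# The flag gates the backbone-freezing and eval-mode logic in train_supervised.py.
if configs['linear_evaluation']:
    print('Linear evaluation: backbone frozen, training only the classification head')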
training/train_supervised.py (25 changes: 24 additions & 1 deletion)
@@ -13,7 +13,8 @@


 def train_epoch(train_loader,model,optimizer,criterion,epoch,configs):
-    model.train()
+    if not configs['linear_evaluation']:
+        model.train()
 
     for idx, batch in enumerate(tqdm.tqdm(train_loader)):

@@ -63,8 +64,30 @@ def train(configs):
     criterion = nn.BCEWithLogitsLoss()
     if configs['ssl_encoder'] is None:
         base_model = model_utils.create_model(configs)
+        if 'vit' in configs['architecture']:
+            in_features = base_model.head.in_features
+        else:
+            in_features = base_model.fc.in_features
     else:
         print('Loading SSL checkpoint: ',configs['ssl_encoder'])
         base_model = torch.load(configs['ssl_encoder'],map_location='cpu')
+        # Create dummy model to get fully connected layer's input dim
+        dummy_model = model_utils.create_model(configs)
+        if 'vit' in configs['architecture']:
+            in_features = dummy_model.head.in_features
+        else:
+            in_features = dummy_model.fc.in_features
+        del dummy_model
+
+    if configs['linear_evaluation']:
+        for param in base_model.parameters():
+            param.requires_grad = False
+        if 'vit' not in configs['architecture']:
+            base_model.fc = nn.Linear(in_features,configs['num_classes'])
+        else:
+            base_model.head = nn.Linear(in_features,configs['num_classes'])
+
+        base_model.eval()
 
     optimizer = torch.optim.AdamW(base_model.parameters(),lr=configs['lr'],weight_decay=configs['weight_decay'])

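Taken together, the two hunks implement standard linear probing: every pretrained parameter is frozen, the classification head (fc for CNNs, head for ViTs) is replaced with a fresh nn.Linear, and the frozen model is kept in eval mode, which is presumably why train_epoch now skips model.train() when linear_evaluation is set (calling train() would update BatchNorm running statistics even though the weights are frozen). A self-contained sketch of the same idea, using a torchvision ResNet-50 as a stand-in for model_utils.create_model (the torchvision import and the hyperparameter values are assumptions for illustration):

import torch
import torch.nn as nn
from torchvision.models import resnet50

num_classes = 11  # matches "num_classes" in the config above

# Stand-in backbone; the repository builds models via model_utils.create_model.
base_model = resnet50()

# Freeze every pretrained parameter, as in the diff above.
for param in base_model.parameters():
    param.requires_grad = False

# Replace the classification head; nn.Linear creates its weights with
# requires_grad=True, so only the new head receives gradient updates.
in_features = base_model.fc.in_features
base_model.fc = nn.Linear(in_features, num_classes)

# Keep the frozen backbone in eval mode so BatchNorm running statistics
# are not updated by the probe's batches.
base_model.eval()

# Passing all parameters is harmless: frozen parameters never get a .grad,
# so AdamW skips them; filtering on requires_grad would work equally well.
optimizer = torch.optim.AdamW(base_model.parameters(), lr=1e-4, weight_decay=1e-4)

With this setup the backbone acts as a fixed feature extractor, and only the new head's weight and bias are trained.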
