Skip to content

Commit

Permalink
add metrics for multilabel problems
Browse files Browse the repository at this point in the history
  • Loading branch information
ngbountos committed Jun 17, 2023
1 parent fece70f commit 294323f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
9 changes: 5 additions & 4 deletions configs/supervised_configs.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
{
"wandb_project":"YOUR_PROJECT",
"wandb_project":"YOUR_WANDB_PROJECT",
"wandb_entity":"YOUR_ENTITY",
"task":"classification",
"num_classes":11,
"device":"cuda:1",
"wandb":false,
"wandb":true,
"mixed_precision":true,
"ssl_encoder":null,
"ssl_run_id_path":null,
Expand All @@ -14,9 +14,10 @@
"num_workers":4,
"lr":0.0001,
"weight_decay":1e-4,
"epochs":10,
"architecture":"ResNet50",
"epochs":2,
"architecture":"ResNet18",
"oversampling":true,
"multilabel":true,
"augment":false,
"num_channel":2,
"class_weights":[1.0,1.0],
Expand Down
14 changes: 10 additions & 4 deletions utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,16 @@ def prepare_supervised_learning_loaders(configs):
return train_loader, val_loader, test_loader

def initialize_metrics(configs):
accuracy = Accuracy(task='multiclass', average='micro',multidim_average='global',num_classes=configs['num_classes']).to(configs['device'])
fscore = F1Score(task='multiclass', num_classes=configs['num_classes'],average='micro',multidim_average='global').to(configs['device'])
precision = Precision(task='multiclass', average='micro', num_classes=configs['num_classes'],multidim_average='global').to(configs['device'])
recall = Recall(task='multiclass', average='micro', num_classes=configs['num_classes'],multidim_average='global').to(configs['device'])
if configs['multilabel']:
accuracy = Accuracy(task='multilabel', average='micro',multidim_average='global',num_labels=configs['num_classes']).to(configs['device'])
fscore = F1Score(task='multilabel', num_labels=configs['num_classes'],average='micro',multidim_average='global').to(configs['device'])
precision = Precision(task='multilabel', average='micro', num_labels=configs['num_classes'],multidim_average='global').to(configs['device'])
recall = Recall(task='multilabel', average='micro', num_labels=configs['num_classes'],multidim_average='global').to(configs['device'])
else:
accuracy = Accuracy(task='multiclass', average='micro',multidim_average='global',num_classes=configs['num_classes']).to(configs['device'])
fscore = F1Score(task='multiclass', num_classes=configs['num_classes'],average='micro',multidim_average='global').to(configs['device'])
precision = Precision(task='multiclass', average='micro', num_classes=configs['num_classes'],multidim_average='global').to(configs['device'])
recall = Recall(task='multiclass', average='micro', num_classes=configs['num_classes'],multidim_average='global').to(configs['device'])
return [accuracy, fscore, precision, recall]


Expand Down

0 comments on commit 294323f

Please sign in to comment.