Remove progress bar from validation #123

Closed · wants to merge 7 commits
pyha_analyzer/train.py (15 additions, 13 deletions)
@@ -8,6 +8,7 @@
 import datetime
 import logging
 import os
+import time
 from typing import Any, Tuple
 
 import numpy as np
@@ -18,13 +19,13 @@
 from torch.utils.data import DataLoader
 from torchmetrics.classification import MultilabelAveragePrecision
 from tqdm import tqdm
-import wandb
 
+import wandb
 from pyha_analyzer import config
 from pyha_analyzer.dataset import get_datasets, make_dataloaders
-from pyha_analyzer.utils import set_seed
 from pyha_analyzer.models.early_stopper import EarlyStopper
 from pyha_analyzer.models.timm_model import TimmModel
+from pyha_analyzer.utils import set_seed
 
 tqdm.pandas()
 time_now = datetime.datetime.now().strftime('%Y%m%d-%H%M')
@@ -176,23 +177,30 @@ def valid(model: Any,
         dataset_ratio = 1.0
 
     num_valid_samples = int(len(data_loader)*dataset_ratio)
 
-    # tqdm is a progress bar
-    dl_iter = tqdm(data_loader, position=5, total=num_valid_samples)
+    start_time = time.time()
 
     with torch.no_grad():
-        for index, (mels, labels) in enumerate(dl_iter):
+        for index, (mels, labels) in enumerate(data_loader):
             if index > num_valid_samples:
                 # Stop early if not doing full validation
                 break
 
+            # Janky progress bar
+            # Using this instead of tqdm because of https://github.com/wandb/wandb/issues/1265
+            for proportion in [0.25, 0.5, 0.75]:
+                if index == int(proportion * num_valid_samples):
+                    logger.info("Validation is %d%% complete",
+                                int(100 * proportion))
+
             loss, outputs = run_batch(model, mels, labels)
 
             running_loss += loss.item()
 
             log_pred.append(torch.clone(outputs.cpu()).detach())
             log_label.append(torch.clone(labels.cpu()).detach())
 
+    # Print duration of validation
+    end_time = time.time()
+    logger.info("Validation took %d seconds", int(end_time - start_time))
+
     # softmax predictions
     log_pred = F.softmax(torch.cat(log_pred)).to(cfg.device)
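
The hunk above swaps the per-batch tqdm bar for a handful of logger.info milestones, because the bar interacts badly with wandb's console capture (https://github.com/wandb/wandb/issues/1265). A minimal, self-contained sketch of the same pattern follows; `validate`, the `model(mels)` call, and the returned average loss are illustrative stand-ins, not code from pyha_analyzer.

```python
import logging
import time

import torch

logger = logging.getLogger(__name__)


def validate(model, data_loader, dataset_ratio: float = 1.0) -> float:
    """Sketch: validation loop that logs coarse progress milestones instead of drawing a tqdm bar."""
    num_valid_samples = int(len(data_loader) * dataset_ratio)
    start_time = time.time()
    running_loss = 0.0

    model.eval()
    with torch.no_grad():
        for index, (mels, _labels) in enumerate(data_loader):
            if index > num_valid_samples:
                break  # stop early when validating on a subset

            # Log at fixed fractions of the loop rather than redrawing a bar,
            # so wandb's captured console output stays readable.
            for proportion in (0.25, 0.5, 0.75):
                if index == int(proportion * num_valid_samples):
                    logger.info("Validation is %d%% complete", int(100 * proportion))

            outputs = model(mels)                 # stand-in for run_batch(model, mels, labels)
            running_loss += float(outputs.sum())  # stand-in for the real loss

    logger.info("Validation took %d seconds", int(time.time() - start_time))
    return running_loss / max(num_valid_samples, 1)
```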
@@ -250,12 +258,6 @@ def logging_setup() -> None:
 def main(in_sweep=True) -> None:
     """ Main function
     """
-    # pylint: disable-next=global-statement
-    global EPOCH
-    # pylint: disable-next=global-statement
-    global BEST_VALID_MAP
-    EPOCH = 0
-    BEST_VALID_MAP = 0
     logger.info("Device is: %s, Preprocessing Device is %s", cfg.device, cfg.prepros_device)
     set_seed(cfg.seed)
 
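
The final hunk drops the `global` declarations and the explicit `EPOCH = 0` / `BEST_VALID_MAP = 0` resets from `main()`, presumably because the module-level definitions already provide those defaults. Below is a minimal sketch of that pattern, assuming the counters are defined once at module scope; `run_epoch()` is a hypothetical stand-in, not a function from this repository.

```python
# Module-level training state; the defaults live here, so main() does not
# need to re-initialize them just to reset them to the same values.
EPOCH = 0
BEST_VALID_MAP = 0.0


def run_epoch() -> float:
    """Hypothetical stand-in: train and validate one epoch, return its validation mAP."""
    return 0.0


def main(max_epochs: int = 10) -> None:
    # `global` is still needed where the values are actually reassigned.
    # pylint: disable-next=global-statement
    global EPOCH, BEST_VALID_MAP
    for _ in range(max_epochs):
        valid_map = run_epoch()
        BEST_VALID_MAP = max(BEST_VALID_MAP, valid_map)
        EPOCH += 1


if __name__ == "__main__":
    main()
```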