Skip to content

Commit

Permalink
Add nutella training loop
Browse files Browse the repository at this point in the history
  • Loading branch information
mbazzani committed Aug 10, 2023
1 parent e13b11f commit 2c02f45
Showing 1 changed file with 82 additions and 10 deletions.
92 changes: 82 additions & 10 deletions pyha_analyzer/nutella.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,42 @@
import scipy
import torch

from pyha_analyzer import pseudolabel, dataset, config
from pyha_analyzer.models.timm_model import TimmModel

cfg = config.cfg

def compute_nearest_neighbors(
batch_feature: np.ndarray,
dataset_feature: np.ndarray,
knn: int,
memory_efficient_computation: bool = False
) -> np.ndarray:
"""
Compute batch_feature's nearest-neighbors among dataset_feature.
Args:
batch_feature: The features for the provided batch of data,
shape [batch_size, feature_dim]
dataset_feature: The features for the whole dataset,
shape [dataset_size, feature_dim]
knn: The number of nearest-neighbors to use.
memory_efficient_computation: Whether to make computation memory
efficient. This option trades speed for memory footprint by looping over
samples in the batch instead of fully vectorizing nearest-neighbor
computation. For large datasets, memory usage can be a bottleneck, so
consider enabling this option; note that it defaults to False.
Returns:
The batch's nearest-neighbors affinity matrix of shape [batch_size, dataset_size],
where position (i, j) indicates whether dataset_feature[j]
belongs to batch_feature[i]'s nearest-neighbors.
Raises:
ValueError: If batch_feature and dataset_feature don't have the same
number of dimensions, or if their feature dimensions don't match.
"""

def compute_nearest_neighbors(batch_feature: np.ndarray,
dataset_feature: np.ndarray,
knn: int,
memory_efficient_computation: bool = False) -> np.ndarray:
""" Algorithm to compute the nearest neighbors between two sets of features. """
batch_shape = batch_feature.shape
dataset_shape = dataset_feature.shape

Expand Down Expand Up @@ -65,7 +93,7 @@ def teacher_step(
alpha: float = 1.0,
normalize_pseudo_labels: bool = True,
eps: float = 1e-8
) -> np.ndarray:
) -> np.ndarray:
"""Computes the pseudo-labels (teacher-step) following Eq.(3) in the paper.
Args:
Expand Down Expand Up @@ -118,8 +146,52 @@ def teacher_step(
pseudo_label /= pseudo_label.sum(axis=-1, keepdims=True) + eps
return pseudo_label

def get_dataset_info(model, dl):
    """Run ``model`` over every batch of ``dl`` and gather the results.

    Args:
        model: Object that is callable on a batch of inputs and that also
            exposes ``get_features(batch)``.
        dl: Sized iterable yielding ``(mels, _, index)`` triples; the middle
            element is ignored.

    Returns:
        A list ``[indices, data, predictions, features]`` where each entry is
        the per-batch values joined with ``torch.cat``.
    """
    # Four independent accumulators. A chained assignment
    # (``a = b = c = d = []``) would bind every name to the same list and mix
    # all four streams together.
    indices, data, predictions, features = [], [], [], []
    print(f"{len(dl)}")
    # NOTE(review): outputs are collected with autograd state untouched;
    # consider wrapping the loop in torch.no_grad() if gradients are not
    # needed by callers -- confirm before changing.
    for mels, _, index in dl:
        data.append(mels)
        predictions.append(model(mels))
        features.append(model.get_features(mels))
        indices.append(index)
    print(f"{len(indices)}")
    print(f"{len(data)}")
    print(f"{len(predictions)}")
    print(f"{len(features)}")
    return [
        torch.cat(chunks)
        for chunks in (indices, data, predictions, features)
    ]

def regularized_pseudolabels(model, dataset):
    """Compute nearest-neighbor-regularized pseudo-labels for ``dataset``.

    The model's features are used to build a nearest-neighbor matrix of the
    dataset against itself, which is then passed to ``teacher_step`` together
    with the model's predictions.

    Args:
        model: Object callable on ``dataset`` and exposing
            ``get_features(dataset)``.
        dataset: The full set of model inputs. Note that this parameter
            shadows the imported ``dataset`` module inside this function.

    Returns:
        The pseudo-labels produced by ``teacher_step``.
    """
    # NOTE(review): compute_nearest_neighbors / teacher_step are annotated
    # with np.ndarray while the model presumably returns torch tensors --
    # confirm whether a conversion is needed here.
    dataset_features = model.get_features(dataset)
    # Might need to rewrite iteratively for less memory
    dataset_predictions = model(dataset)
    # The whole dataset acts as both the "batch" and the reference set, so
    # the same features are passed for both arguments.
    nn_matrix = compute_nearest_neighbors(
        dataset_features, dataset_features, knn=cfg.notela_knn
    )  # [dataset_size, dataset_size]
    pseudo_labels = teacher_step(
        dataset_predictions,
        dataset_predictions,
        nn_matrix,
        lambda_=cfg.notela_lambda,
    )  # [dataset_size, num_classes]
    return pseudo_labels  # TODO: Convert to one hot

def finetune(model):
    """Fine-tune ``model`` on pseudo-labels.

    Builds the pseudo-labelled dataloaders, runs an initial validation pass,
    then for ``cfg.epochs`` epochs recomputes regularized pseudo-labels over
    the training data and trains/validates for one epoch.

    Args:
        model: The model to fine-tune; it is passed to
            ``pseudolabel.pseudo_label_data``, ``TrainProcess`` and the
            pseudo-label helpers.
    """
    # NOTE(review): `logger` and `TrainProcess` are not among the imports
    # visible in this part of the file -- confirm they are in scope.
    pseudo_df, train_dl, valid_dl, infer_dl = pseudolabel.pseudo_label_data(model)

    logger.info("Finetuning on pseudo labels...")
    train_process = TrainProcess(model, train_dl, valid_dl, infer_dl)
    train_process.valid()
    train_process.inference_valid()
    for _ in range(cfg.epochs):
        # get_dataset_info takes (model, dl) and returns its values in the
        # order (indices, data, predictions, features).
        indices, data, predictions, features = get_dataset_info(model, train_dl)
        # Use the model instance rather than a class-level attribute lookup
        # on TrainProcess.
        pseudo_labels = regularized_pseudolabels(model, data)
        # TODO: feed the labels back into the training data, e.g.
        # train_process.update_dataset_predictions(indices, pseudo_labels)
        train_process.run_epoch()
        train_process.valid()
        train_process.inference_valid()

if __name__ == "__main__":
    # Quick feature-extraction check on a random image-shaped tensor.
    model = TimmModel(10)
    image = torch.rand((1, 3, 100, 100))
    features = model.get_features(image)
    print(f"{features=}")

    # Build the model on the configured device and fine-tune it.
    model = TimmModel(3).to(cfg.device)
    finetune(model)

0 comments on commit 2c02f45

Please sign in to comment.