diff --git a/run.py b/run.py
index e2424c6a..7a4beff7 100644
--- a/run.py
+++ b/run.py
@@ -47,14 +47,14 @@
                     help='Log every n steps')
 parser.add_argument('--temperature', default=0.07, type=float,
                     help='softmax temperature (default: 0.07)')
-parser.add_argument('--n-views', default=2, type=int, metavar='N',
+parser.add_argument('--n-views', default=4, type=int, metavar='N',
                     help='Number of views for contrastive learning training.')
 parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
 
 
 def main():
     args = parser.parse_args()
-    assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
+    # assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
     # check if gpu training is available
     if not args.disable_cuda and torch.cuda.is_available():
         args.device = torch.device('cuda')
diff --git a/simclr.py b/simclr.py
index e022dca6..3b49bd29 100644
--- a/simclr.py
+++ b/simclr.py
@@ -3,6 +3,7 @@
 import sys
 
 import torch
+from torch import nn
 import torch.nn.functional as F
 from torch.cuda.amp import GradScaler, autocast
 from torch.utils.tensorboard import SummaryWriter
@@ -12,48 +13,60 @@
 torch.manual_seed(0)
 
 
-class SimCLR(object):
-
-    def __init__(self, *args, **kwargs):
-        self.args = kwargs['args']
-        self.model = kwargs['model'].to(self.args.device)
-        self.optimizer = kwargs['optimizer']
-        self.scheduler = kwargs['scheduler']
-        self.writer = SummaryWriter()
-        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
-        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
-
-    def info_nce_loss(self, features):
+class InfoNCELoss(nn.Module):
 
-        labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
+    @staticmethod
+    def loss_forward(features: torch.Tensor, batch_size: int, n_views: int, temperature: float):
+        labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0).to(features.device)
+        # noinspection PyUnresolvedReferences
         labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
-        labels = labels.to(self.args.device)
-
         features = F.normalize(features, dim=1)
 
         similarity_matrix = torch.matmul(features, features.T)
-        # assert similarity_matrix.shape == (
-        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
-        # assert similarity_matrix.shape == labels.shape
 
         # discard the main diagonal from both: labels and similarities matrix
-        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
+        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(features.device)
         labels = labels[~mask].view(labels.shape[0], -1)
         similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
-        # assert similarity_matrix.shape == labels.shape
-
-        # select and combine multiple positives
-        positives = similarity_matrix[labels.bool()].view(labels.shape[0] * (n_views - 1), -1)
+        positives = similarity_matrix[labels.bool()].view(labels.shape[0] * (n_views - 1), -1)
 
-        # select only the negatives the negatives
-        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
+        # select only the negatives
+        # change: repeat each anchor's negatives (n_views - 1) times so that, for n_views > 2, every positive pair row is matched with its own anchor's negatives
+        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1).repeat_interleave(n_views - 1, dim=0)
 
         logits = torch.cat([positives, negatives], dim=1)
-        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
-
-        logits = logits / self.args.temperature
+        # column 0 holds the similarity between views of the same image (the positive pair), while the
+        # remaining columns hold the similarities between negative pairs.
+        # the objective is to learn feature representations where positive pairs have high similarity
+        # (column 0 of logits) and negative pairs (the remaining columns) have low similarity.
+        # therefore every label is set to 0 and cross-entropy is applied afterwards between labels and logits.
+        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(features.device)
+
+        logits = logits / temperature
         return logits, labels
 
+    def __init__(self, batch_size, n_views, temperature):
+        super().__init__()
+        self.batch_size = batch_size
+        self.n_views = n_views
+        self.temperature = temperature
+
+    def forward(self, features):
+        return InfoNCELoss.loss_forward(features, self.batch_size, self.n_views, self.temperature)
+
+
+class SimCLR(object):
+
+    def __init__(self, *args, **kwargs):
+        self.args = kwargs['args']
+        self.model = kwargs['model'].to(self.args.device)
+        self.optimizer = kwargs['optimizer']
+        self.scheduler = kwargs['scheduler']
+        self.writer = SummaryWriter()
+        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
+        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
+        self.info_nce_loss = InfoNCELoss(self.args.batch_size, self.args.n_views, self.args.temperature)
+
     def train(self, train_loader):
 
         scaler = GradScaler(enabled=self.args.fp16_precision)
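
Reviewer note (not part of the diff): a minimal shape sanity check for the refactored `InfoNCELoss` with `n_views > 2`, assuming the changes above are applied. The batch size, view count, and feature dimension below are arbitrary example values, and the random tensor merely stands in for projection-head outputs (normalization happens inside `loss_forward`).

```python
import torch
from simclr import InfoNCELoss

# arbitrary example values; n_views > 2 exercises the new code path
batch_size, n_views, dim = 8, 4, 128
criterion = InfoNCELoss(batch_size, n_views, temperature=0.07)

# stand-in for the projection-head outputs of n_views augmentations per image
features = torch.randn(batch_size * n_views, dim)
logits, labels = criterion(features)

# one row per (anchor, positive) pair: batch_size * n_views * (n_views - 1) rows;
# each row holds 1 positive column plus (batch_size * n_views - n_views) negatives
assert logits.shape == (batch_size * n_views * (n_views - 1),
                        1 + batch_size * n_views - n_views)
assert labels.shape == (batch_size * n_views * (n_views - 1),)
assert (labels == 0).all()  # the positive similarity always sits in column 0

loss = torch.nn.CrossEntropyLoss()(logits, labels)
```

Note that the number of logit rows grows as n_views * (n_views - 1), so memory for the logits matrix increases quickly with the view count.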