Commit c916d1e: Remove ignores
philippmwirth committed Nov 20, 2023
1 parent d5fb239
Showing 7 changed files with 54 additions and 28 deletions.
13 changes: 8 additions & 5 deletions lightly/loss/memory_bank.py
@@ -3,10 +3,10 @@
 # Copyright (c) 2020. Lightly AG and its affiliates.
 # All Rights Reserved
 
-import functools
 from typing import Optional, Tuple, Union
 
 import torch
+from torch import Tensor
 
 
 class MemoryBankModule(torch.nn.Module):
@@ -68,9 +68,9 @@ def _init_memory_bank(self, dim: int) -> None:
         # we could use register buffers like in the moco repo
         # https://github.com/facebookresearch/moco but we don't
         # want to pollute our checkpoints
-        self.bank = torch.randn(dim, self.size).type_as(self.bank)  # type: ignore
-        self.bank = torch.nn.functional.normalize(self.bank, dim=0)
-        self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)  # type: ignore
+        bank: Tensor = torch.randn(dim, self.size).type_as(self.bank)
+        self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0)
+        self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr)
 
     @torch.no_grad()
     def _dequeue_and_enqueue(self, batch: torch.Tensor) -> None:
@@ -92,7 +92,10 @@ def _dequeue_and_enqueue(self, batch: torch.Tensor) -> None:
         self.bank_ptr[0] = ptr + batch_size
 
     def forward(
-        self, output: torch.Tensor, labels: Optional[torch.Tensor] = None, update: bool = False
+        self,
+        output: torch.Tensor,
+        labels: Optional[torch.Tensor] = None,
+        update: bool = False,
     ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
         """Query memory bank for additional negative samples
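
For context, here is a minimal sketch of the ring-buffer pattern this module implements, written in the annotation style the diff adopts. It is illustrative only; TinyMemoryBank and its method names are hypothetical, not lightly's API.

# Illustrative sketch, not library code: annotating the attributes as
# Tensor gives mypy concrete types, which is what lets the ignores above
# disappear.
import torch
from torch import Tensor


class TinyMemoryBank(torch.nn.Module):
    def __init__(self, dim: int, size: int) -> None:
        super().__init__()
        self.size = size
        bank: Tensor = torch.randn(dim, size)
        self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0)
        self.bank_ptr: Tensor = torch.zeros(1, dtype=torch.long)

    @torch.no_grad()
    def dequeue_and_enqueue(self, batch: Tensor) -> None:
        """Overwrite the oldest columns of the bank with a new batch."""
        batch_size = batch.shape[0]
        ptr = int(self.bank_ptr)
        if ptr + batch_size >= self.size:
            # Wrap around: keep what fits, then reset the pointer.
            self.bank[:, ptr:] = batch[: self.size - ptr].T.detach()
            self.bank_ptr[0] = 0
        else:
            self.bank[:, ptr : ptr + batch_size] = batch.T.detach()
            self.bank_ptr[0] = ptr + batch_size


bank = TinyMemoryBank(dim=128, size=4096)
bank.dequeue_and_enqueue(torch.randn(32, 128))
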
1 change: 1 addition & 0 deletions lightly/models/_momentum.py
@@ -5,6 +5,7 @@
 
 import copy
 from typing import Iterable, Tuple
+
 import torch
 import torch.nn as nn
 from torch.nn.parameter import Parameter
9 changes: 6 additions & 3 deletions lightly/models/batchnorm.py
@@ -4,9 +4,11 @@
 # All Rights Reserved
 
 from __future__ import annotations
+
+from typing import Any
+
 import torch
 import torch.nn as nn
-from typing import Any, Optional, Tuple
 
 
 class SplitBatchNorm(nn.BatchNorm2d):
@@ -63,10 +65,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 self.eps,
             ).view(N, C, H, W)
         else:
+            assert self.running_mean is not None and self.running_var is not None
             result = nn.functional.batch_norm(
                 input,
-                self.running_mean[: self.num_features],  # type: ignore
-                self.running_var[: self.num_features],  # type: ignore
+                self.running_mean[: self.num_features],
+                self.running_var[: self.num_features],
                 self.weight,
                 self.bias,
                 False,
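
The added assert is what makes the two ignores above unnecessary: running_mean and running_var are declared Optional[Tensor] on the nn.BatchNorm2d base class, and asserting them non-None narrows the type before indexing. A minimal, hypothetical illustration of the same narrowing:

# Hypothetical example, not library code: mypy rejects indexing an
# Optional[Tensor] until an assert rules out None.
from typing import Optional

import torch
from torch import Tensor


class Stats:
    running_mean: Optional[Tensor] = None
    running_var: Optional[Tensor] = None


def truncated_stats(stats: Stats, n: int) -> Tensor:
    # Without this assert, mypy flags stats.running_mean[:n] because the
    # attribute may still be None at this point.
    assert stats.running_mean is not None and stats.running_var is not None
    return stats.running_mean[:n] + stats.running_var[:n]


stats = Stats()
stats.running_mean = torch.zeros(8)
stats.running_var = torch.ones(8)
print(truncated_stats(stats, 4))  # tensor([1., 1., 1., 1.])
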
25 changes: 16 additions & 9 deletions lightly/models/modules/heads.py
@@ -31,11 +31,12 @@ class ProjectionHead(nn.Module):
     """
 
-    def __init__(self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]) -> None:
-
+    def __init__(
+        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
+    ) -> None:
         super(ProjectionHead, self).__init__()
 
         layers: List[nn.Module] = []
         for input_dim, output_dim, batch_norm, non_linearity in blocks:
             use_bias = not bool(batch_norm)
             layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
@@ -53,7 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             Input of shape bsz x num_ftrs.
         """
-        return self.layers(x)  # type: ignore
+        projection: torch.Tensor = self.layers(x)
+        return projection
 
 
 class BarlowTwinsProjectionHead(ProjectionHead):
@@ -369,7 +371,9 @@ def get_updated_group_features(self, x: torch.Tensor) -> torch.Tensor:
             mask = assignments == assigned_class
             group_features[assigned_class] = self.beta * self.group_features[
                 assigned_class
-            ] + (1 - self.beta) * x[mask].mean(axis=0)  # type: ignore
+            ] + (1 - self.beta) * x[mask].mean(
+                axis=0
+            )  # type: ignore
 
         return group_features

@@ -523,7 +527,9 @@ def __init__(
         )
         self.n_steps_frozen_prototypes = n_steps_frozen_prototypes
 
-    def forward(self, x: torch.Tensor, step: Optional[int]=None) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def forward(
+        self, x: torch.Tensor, step: Optional[int] = None
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         self._freeze_prototypes_if_required(step)
         out = []
         for layer in self.heads:
@@ -535,7 +541,7 @@ def normalize(self) -> None:
         for layer in self.heads:
             utils.normalize_weight(layer.weight)
 
-    def _freeze_prototypes_if_required(self, step: Optional[int]=None) -> None:
+    def _freeze_prototypes_if_required(self, step: Optional[int] = None) -> None:
         if self.n_steps_frozen_prototypes > 0:
             if step is None:
                 raise ValueError(
@@ -602,8 +608,9 @@ def __init__(
         self.freeze_last_layer = freeze_last_layer
         self.last_layer = nn.Linear(bottleneck_dim, output_dim, bias=False)
         self.last_layer = nn.utils.weight_norm(self.last_layer)
-        self.last_layer.weight_g.data.fill_(1)  # type: ignore
+        # Tell mypy this is ok because fill_ is overloaded.
+        self.last_layer.weight_g.data.fill_(1)  # type: ignore
 
         # Option to normalize last layer.
         if norm_last_layer:
             self.last_layer.weight_g.requires_grad = False
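
The projection change above is the usual workaround for nn.Module.__call__ being typed as returning Any: annotate a local, then return it. A small sketch of the same pattern; the head layout and the TwoLayerHead name are invented for illustration, not lightly's architecture.

# Sketch only: a toy projection head showing the annotate-then-return
# pattern from the diff.
import torch
from torch import nn


class TwoLayerHead(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.layers(x) is typed as Any; the annotated local recovers a
        # checked Tensor return type without a "# type: ignore".
        projection: torch.Tensor = self.layers(x)
        return projection


head = TwoLayerHead(512, 512, 128)
print(head(torch.randn(4, 512)).shape)  # torch.Size([4, 128])
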
12 changes: 9 additions & 3 deletions lightly/models/modules/nn_memory_bank.py
@@ -4,12 +4,13 @@
 # All Rights Reserved
 
 from typing import Optional
+
 import torch
 
 from lightly.loss.memory_bank import MemoryBankModule
 
 
-class NNMemoryBankModule(MemoryBankModule):
+class NNMemoryBankModule(MemoryBankModule):  # type: ignore # Cannot subclass type Any.
     """Nearest Neighbour Memory Bank implementation
 
     This class implements a nearest neighbour memory bank as described in the
@@ -41,7 +42,12 @@ class NNMemoryBankModule(MemoryBankModule):
     def __init__(self, size: int = 2**16):
         super(NNMemoryBankModule, self).__init__(size)
 
-    def forward(self, output: torch.Tensor, labels: Optional[torch.Tensor] = None, update: bool = False) -> torch.Tensor:
+    def forward(
+        self,
+        output: torch.Tensor,
+        labels: Optional[torch.Tensor] = None,
+        update: bool = False,
+    ) -> torch.Tensor:
         """Returns nearest neighbour of output tensor from memory bank
 
         Args:
@@ -51,7 +57,7 @@ def forward(self, output: torch.Tensor, labels: Optional[torch.Tensor] = None, update: bool = False) -> torch.Tensor:
         """
 
         output, bank = super(NNMemoryBankModule, self).forward(output, update=update)
-        bank = bank.to(output.device).t()  # type: ignore
+        bank = bank.to(output.device).t()
 
         output_normed = torch.nn.functional.normalize(output, dim=1)
         bank_normed = torch.nn.functional.normalize(bank, dim=1)
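
The lookup this module performs can be sketched in a few lines. This is a simplified stand-in (a free function rather than the module, with invented shapes) for the cosine-similarity nearest-neighbour query the docstring describes:

# Simplified sketch of a nearest-neighbour query against a memory bank;
# not lightly's implementation.
import torch


def nearest_neighbours(output: torch.Tensor, bank: torch.Tensor) -> torch.Tensor:
    """output: (batch, dim); bank: (bank_size, dim)."""
    output_normed = torch.nn.functional.normalize(output, dim=1)
    bank_normed = torch.nn.functional.normalize(bank, dim=1)
    # Cosine similarity of every query against every bank entry.
    similarity = torch.einsum("nd,md->nm", output_normed, bank_normed)
    index_nearest = similarity.argmax(dim=1)
    return torch.index_select(bank, dim=0, index=index_nearest)


out = torch.randn(8, 128)
bank = torch.randn(4096, 128)
print(nearest_neighbours(out, bank).shape)  # torch.Size([8, 128])
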
21 changes: 14 additions & 7 deletions lightly/models/resnet.py
@@ -74,7 +74,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             Tensor of shape bsz x channels x W x H
         """
 
-        out = self.conv1(x)
+        out: torch.Tensor = self.conv1(x)
         out = self.bn1(out)
         out = F.relu(out)
 
@@ -84,7 +84,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         out += self.shortcut(x)
         out = F.relu(out)
 
-        return out  # type: ignore
+        return out
 
 
 class Bottleneck(nn.Module):
@@ -144,7 +144,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             Tensor of shape bsz x channels x W x H
         """
 
-        out = self.conv1(x)
+        out: torch.Tensor = self.conv1(x)
         out = self.bn1(out)
         out = F.relu(out)
 
@@ -158,7 +158,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         out += self.shortcut(x)
         out = F.relu(out)
 
-        return out  # type: ignore
+        return out
 
 
 class ResNet(nn.Module):
@@ -209,7 +209,14 @@ def __init__(
         )
         self.linear = nn.Linear(self.base * 8 * block.expansion, num_classes)
 
-    def _make_layer(self, block: type[BasicBlock], planes: int, num_layers: int, stride: int, num_splits: int) -> nn.Sequential:
+    def _make_layer(
+        self,
+        block: type[BasicBlock],
+        planes: int,
+        num_layers: int,
+        stride: int,
+        num_splits: int,
+    ) -> nn.Sequential:
         strides = [stride] + [1] * (num_layers - 1)
         layers = []
         for stride in strides:
@@ -287,8 +294,8 @@ def ResNetGenerator(
     )
 
     return ResNet(
-        **model_params[name],  # type: ignore
+        **model_params[name],  # type: ignore # Cannot unpack dict to type "ResNet".
         width=width,
         num_classes=num_classes,
-        num_splits=num_splits
+        num_splits=num_splits,
     )
1 change: 0 additions & 1 deletion lightly/models/zoo.py
@@ -5,7 +5,6 @@
 
 from typing import List
 
-
 ZOO = {
     "resnet-9/simclr/d16/w0.0625": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.0625-i-ce0d6bd9.pth",
     "resnet-9/simclr/d16/w0.125": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.125-i-7269c38d.pth",
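
The ZOO entries are plain checkpoint URLs. As a hedged usage sketch, one way to consume an entry, assuming the URL points at an ordinary torch checkpoint (the key layout is not confirmed by this page):

# Hedged sketch: fetch a ZOO checkpoint as a state dict and inspect it.
import torch

url = "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.0625-i-ce0d6bd9.pth"
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
print(sorted(state_dict)[:5])  # peek at the first few keys
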
