Skip to content

Commit

Permalink
Typecheck TiCo loss (#1756)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth authored Dec 23, 2024
1 parent 4b7c719 commit 53d1af4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
7 changes: 5 additions & 2 deletions lightly/loss/tico_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Union

import torch
import torch.distributed as dist
from torch import Tensor

from lightly.utils.dist import gather

Expand Down Expand Up @@ -70,7 +73,7 @@ def __init__(

self.beta = beta
self.rho = rho
self.C = None
self.C: Union[Tensor, None] = None
self.gather_distributed = gather_distributed

def forward(
Expand Down Expand Up @@ -131,7 +134,7 @@ def forward(
transformative_invariance_loss = 1.0 - (z_a * z_b).sum(dim=1).mean()
covariance_contrast_loss = self.rho * (torch.mm(z_a, C) * z_a).sum(dim=1).mean()

loss = transformative_invariance_loss + covariance_contrast_loss
loss: Tensor = transformative_invariance_loss + covariance_contrast_loss

# Update covariance matrix
if update_covariance_matrix:
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ exclude = '''(?x)(
lightly/cli/train_cli.py |
lightly/cli/_cli_simclr.py |
lightly/cli/_helpers.py |
lightly/loss/tico_loss.py |
lightly/loss/pmsn_loss.py |
lightly/loss/swav_loss.py |
lightly/loss/negative_cosine_similarity.py |
Expand Down Expand Up @@ -257,7 +256,6 @@ exclude = '''(?x)(
tests/loss/test_barlow_twins_loss.py |
tests/loss/test_SymNegCosineSimilarityLoss.py |
tests/loss/test_MemoryBank.py |
tests/loss/test_TicoLoss.py |
tests/loss/test_PMSNLoss.py |
tests/loss/test_HyperSphere.py |
tests/loss/test_SwaVLoss.py |
Expand Down
23 changes: 9 additions & 14 deletions tests/loss/test_TicoLoss.py → tests/loss/test_tico_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import unittest

import pytest
import torch
from pytest_mock import MockerFixture
Expand All @@ -24,10 +22,7 @@ def test__gather_distributed_dist_not_available(
TiCoLoss(gather_distributed=True)
mock_is_available.assert_called_once()


class TestTiCoLossUnitTest(unittest.TestCase):
# Old tests in unittest style, please add new tests to TestTiCoLoss using pytest.
def test_forward_pass(self):
def test_forward_pass(self) -> None:
torch.manual_seed(0)
loss = TiCoLoss()
for bsz in range(2, 4):
Expand All @@ -37,10 +32,10 @@ def test_forward_pass(self):
# symmetry
l1 = loss(x0, x1, update_covariance_matrix=False)
l2 = loss(x1, x0, update_covariance_matrix=False)
self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0, 2)
assert l1 == pytest.approx(l2, abs=1e-2)

@unittest.skipUnless(torch.cuda.is_available(), "Cuda not available")
def test_forward_pass_cuda(self):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
def test_forward_pass_cuda(self) -> None:
torch.manual_seed(0)
loss = TiCoLoss()
for bsz in range(2, 4):
Expand All @@ -50,20 +45,20 @@ def test_forward_pass_cuda(self):
# symmetry
l1 = loss(x0, x1, update_covariance_matrix=False)
l2 = loss(x1, x0, update_covariance_matrix=False)
self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0, 2)
assert l1 == pytest.approx(l2, abs=1e-2)

def test_forward_pass__error_batch_size_1(self):
def test_forward_pass__error_batch_size_1(self) -> None:
torch.manual_seed(0)
loss = TiCoLoss()
x0 = torch.randn((1, 256))
x1 = torch.randn((1, 256))
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
loss(x0, x1, update_covariance_matrix=False)

def test_forward_pass__error_different_shapes(self):
def test_forward_pass__error_different_shapes(self) -> None:
torch.manual_seed(0)
loss = TiCoLoss()
x0 = torch.randn((2, 32))
x1 = torch.randn((2, 16))
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
loss(x0, x1, update_covariance_matrix=False)

0 comments on commit 53d1af4

Please sign in to comment.