Skip to content

Commit

Permalink
Adding SSL-EY to docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Dec 6, 2023
1 parent 410d971 commit fbd6438
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
2 changes: 2 additions & 0 deletions benchmarks/imagenet/resnet50/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import knn_eval
import linear_eval
import simclr
import ssley
import swav
import torch
import vicreg
Expand Down Expand Up @@ -58,6 +59,7 @@
"dclw": {"model": dclw.DCLW, "transform": dclw.transform},
"dino": {"model": dino.DINO, "transform": dino.transform},
"simclr": {"model": simclr.SimCLR, "transform": simclr.transform},
"ssley": {"model": ssley.SSLEY, "transform": ssley.transform},
"swav": {"model": swav.SwAV, "transform": swav.transform},
"vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
}
Expand Down
3 changes: 3 additions & 0 deletions docs/source/lightly.loss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ lightly.loss
.. autoclass:: lightly.loss.regularizer.co2.CO2Regularizer
:members:

.. autoclass:: lightly.loss.ssley_loss.SSLEYLoss
:members:

.. autoclass:: lightly.loss.swav_loss.SwaVLoss
:members:

Expand Down
67 changes: 67 additions & 0 deletions tests/loss/test_SSLEYLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

import pytest
import torch
import torch.nn.functional as F
from pytest_mock import MockerFixture
from torch import Tensor
from torch import distributed as dist

from lightly.loss import SSLEYLoss


class TestSSLEYLoss:
def test__gather_distributed(self, mocker: MockerFixture) -> None:
mock_is_available = mocker.patch.object(dist, "is_available", return_value=True)
SSLEYLoss(gather_distributed=True)
mock_is_available.assert_called_once()

def test__gather_distributed_dist_not_available(
self, mocker: MockerFixture
) -> None:
mock_is_available = mocker.patch.object(
dist, "is_available", return_value=False
)
with pytest.raises(ValueError):
SSLEYLoss(gather_distributed=True)
mock_is_available.assert_called_once()


class TestSSLEYLossUnitTest(unittest.TestCase):
# Old tests in unittest style, please add new tests to TestSSLEYLoss using pytest.
def test_forward_pass(self):
loss = SSLEYLoss()
for bsz in range(2, 4):
x0 = torch.randn((bsz, 32))
x1 = torch.randn((bsz, 32))

# symmetry
l1 = loss(x0, x1)
l2 = loss(x1, x0)
self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0)

@unittest.skipUnless(torch.cuda.is_available(), "Cuda not available")
def test_forward_pass_cuda(self):
loss = SSLEYLoss()
for bsz in range(2, 4):
x0 = torch.randn((bsz, 32)).cuda()
x1 = torch.randn((bsz, 32)).cuda()

# symmetry
l1 = loss(x0, x1)
l2 = loss(x1, x0)
self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0)

def test_forward_pass__error_batch_size_1(self):
loss = SSLEYLoss()
x0 = torch.randn((1, 32))
x1 = torch.randn((1, 32))
with self.assertRaises(AssertionError):
loss(x0, x1)

def test_forward_pass__error_different_shapes(self):
loss = SSLEYLoss()
x0 = torch.randn((2, 32))
x1 = torch.randn((2, 16))
with self.assertRaises(AssertionError):
loss(x0, x1)

0 comments on commit fbd6438

Please sign in to comment.