diff --git a/lightly/loss/detcon_loss.py b/lightly/loss/detcon_loss.py index b96019976..5c6d0d9ed 100644 --- a/lightly/loss/detcon_loss.py +++ b/lightly/loss/detcon_loss.py @@ -1,9 +1,11 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch import distributed as dist +from torch import distributed as torch_dist from torch.nn import Module +import lightly.utils.dist as lightly_dist + class DetConSLoss(Module): """Implementation of the DetConS loss. [2]_ @@ -129,7 +131,7 @@ def __init__( self.gather_distributed = gather_distributed if abs(self.temperature) < self.eps: raise ValueError(f"Illegal temperature: abs({self.temperature}) < 1e-8") - if self.gather_distributed and not dist.is_available(): + if self.gather_distributed and not torch_dist.is_available(): raise ValueError( "gather_distributed is True but torch.distributed is not available. " "Please set gather_distributed=False or install a torch version with " @@ -175,18 +177,18 @@ def forward( infinity_proxy = 1e9 # gather distributed - if not self.gather_distributed or dist.get_world_size() < 2: + if self.gather_distributed and lightly_dist.world_size() > 1: + target_view0_large = torch.cat(lightly_dist.gather(target_view0), dim=0) + target_view1_large = torch.cat(lightly_dist.gather(target_view1), dim=0) + replica_id = lightly_dist.rank() + labels_idx = torch.arange(b, device=pred_view0.device) + replica_id * b + enlarged_b = b * lightly_dist.world_size() + labels_local = F.one_hot(labels_idx, num_classes=enlarged_b) + else: target_view0_large = target_view0 target_view1_large = target_view1 labels_local = torch.eye(b, device=pred_view0.device) enlarged_b = b - else: - target_view0_large = torch.cat(dist.gather(target_view0), dim=0) - target_view1_large = torch.cat(dist.gather(target_view1), dim=0) - replica_id = dist.get_rank() - labels_idx = torch.arange(b, device=pred_view0.device) + replica_id * b - enlarged_b = b * dist.get_world_size() - labels_local = F.one_hot(labels_idx, num_classes=enlarged_b) # normalize pred_view0 = F.normalize(pred_view0, p=2, dim=2) diff --git a/tests/loss/test_detcon_loss.py b/tests/loss/test_detcon_loss.py index a3cf25cdd..584601759 100644 --- a/tests/loss/test_detcon_loss.py +++ b/tests/loss/test_detcon_loss.py @@ -5,9 +5,10 @@ import torch from pytest_mock import MockerFixture from torch import Tensor -from torch import distributed as dist +from torch import distributed as torch_dist from lightly.loss import DetConBLoss, DetConSLoss +from lightly.utils import dist as lightly_dist class TestDetConSLoss: @@ -108,18 +109,20 @@ def test_DetConBLoss_distributed_against_original( temperature=temperature, ) - mock_is_available = mocker.patch.object(dist, "is_available", return_value=True) + mock_is_available = mocker.patch.object( + torch_dist, "is_available", return_value=True + ) mock_get_world_size = mocker.patch.object( - dist, "get_world_size", return_value=world_size + lightly_dist, "world_size", return_value=world_size ) loss_fn = DetConBLoss(temperature=temperature, gather_distributed=True) total_loss: Tensor = torch.tensor(0.0) for rank in range(world_size): - mock_get_rank = mocker.patch.object(dist, "get_rank", return_value=rank) + mock_get_rank = mocker.patch.object(lightly_dist, "rank", return_value=rank) mock_gather = mocker.patch.object( - dist, + lightly_dist, "gather", side_effect=[ [t["target1"] for t in tensors], @@ -137,6 +140,7 @@ def test_DetConBLoss_distributed_against_original( total_loss += loss_val total_loss /= world_size + print(world_size, total_loss, loss_nondist) assert torch.allclose( total_loss, torch.tensor(loss_nondist, dtype=torch.float32), atol=1e-4 )