fix detcon loss distributed issue
liopeer committed Jan 15, 2025
1 parent 5db25cf commit 4750c7c
Showing 2 changed files with 21 additions and 15 deletions.
22 changes: 12 additions & 10 deletions lightly/loss/detcon_loss.py
@@ -1,9 +1,11 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch import distributed as dist
+from torch import distributed as torch_dist
 from torch.nn import Module
 
+import lightly.utils.dist as lightly_dist
+
 
 class DetConSLoss(Module):
     """Implementation of the DetConS loss. [2]_
@@ -129,7 +131,7 @@ def __init__(
         self.gather_distributed = gather_distributed
         if abs(self.temperature) < self.eps:
             raise ValueError(f"Illegal temperature: abs({self.temperature}) < 1e-8")
-        if self.gather_distributed and not dist.is_available():
+        if self.gather_distributed and not torch_dist.is_available():
             raise ValueError(
                 "gather_distributed is True but torch.distributed is not available. "
                 "Please set gather_distributed=False or install a torch version with "
@@ -175,18 +177,18 @@ def forward(
         infinity_proxy = 1e9
 
         # gather distributed
-        if not self.gather_distributed or dist.get_world_size() < 2:
+        if self.gather_distributed and lightly_dist.world_size() > 1:
+            target_view0_large = torch.cat(lightly_dist.gather(target_view0), dim=0)
+            target_view1_large = torch.cat(lightly_dist.gather(target_view1), dim=0)
+            replica_id = lightly_dist.rank()
+            labels_idx = torch.arange(b, device=pred_view0.device) + replica_id * b
+            enlarged_b = b * lightly_dist.world_size()
+            labels_local = F.one_hot(labels_idx, num_classes=enlarged_b)
+        else:
             target_view0_large = target_view0
             target_view1_large = target_view1
             labels_local = torch.eye(b, device=pred_view0.device)
             enlarged_b = b
-        else:
-            target_view0_large = torch.cat(dist.gather(target_view0), dim=0)
-            target_view1_large = torch.cat(dist.gather(target_view1), dim=0)
-            replica_id = dist.get_rank()
-            labels_idx = torch.arange(b, device=pred_view0.device) + replica_id * b
-            enlarged_b = b * dist.get_world_size()
-            labels_local = F.one_hot(labels_idx, num_classes=enlarged_b)
 
         # normalize
         pred_view0 = F.normalize(pred_view0, p=2, dim=2)
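Context on the fix, hedged: torch.distributed.gather gathers onto a single destination rank and returns None, so the old torch.cat(dist.gather(...)) could not have produced a tensor, and calling dist.get_world_size() without an initialized process group raises. The replacement helpers in lightly.utils.dist (gather, rank, world_size) appear to implement all-gather semantics that return a tuple of tensors and degrade gracefully on a single process. A minimal sketch of the label construction the new branch performs, using hypothetical values b=3, world_size=2, rank=1 (not taken from the commit):

# Sketch of the cross-replica one-hot labels built in the gather branch.
import torch
import torch.nn.functional as F

b, world_size, rank = 3, 2, 1            # hypothetical replica setup
enlarged_b = b * world_size              # gathered batch spans 6 samples
labels_idx = torch.arange(b) + rank * b  # tensor([3, 4, 5])
labels_local = F.one_hot(labels_idx, num_classes=enlarged_b)
# tensor([[0, 0, 0, 1, 0, 0],
#         [0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 1]])
# Each local sample is positive with its own copy inside the gathered
# batch, reducing to torch.eye(b) in the single-process case.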
14 changes: 9 additions & 5 deletions tests/loss/test_detcon_loss.py
@@ -5,9 +5,10 @@
 import torch
 from pytest_mock import MockerFixture
 from torch import Tensor
-from torch import distributed as dist
+from torch import distributed as torch_dist
 
 from lightly.loss import DetConBLoss, DetConSLoss
+from lightly.utils import dist as lightly_dist
 
 
 class TestDetConSLoss:
@@ -108,18 +109,20 @@ def test_DetConBLoss_distributed_against_original(
             temperature=temperature,
         )
 
-        mock_is_available = mocker.patch.object(dist, "is_available", return_value=True)
+        mock_is_available = mocker.patch.object(
+            torch_dist, "is_available", return_value=True
+        )
         mock_get_world_size = mocker.patch.object(
-            dist, "get_world_size", return_value=world_size
+            lightly_dist, "world_size", return_value=world_size
         )
 
         loss_fn = DetConBLoss(temperature=temperature, gather_distributed=True)
 
         total_loss: Tensor = torch.tensor(0.0)
         for rank in range(world_size):
-            mock_get_rank = mocker.patch.object(dist, "get_rank", return_value=rank)
+            mock_get_rank = mocker.patch.object(lightly_dist, "rank", return_value=rank)
             mock_gather = mocker.patch.object(
-                dist,
+                lightly_dist,
                 "gather",
                 side_effect=[
                     [t["target1"] for t in tensors],
@@ -137,6 +140,7 @@
             total_loss += loss_val
         total_loss /= world_size
 
+        print(world_size, total_loss, loss_nondist)
         assert torch.allclose(
             total_loss, torch.tensor(loss_nondist, dtype=torch.float32), atol=1e-4
         )
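The test simulates a multi-replica run inside a single process by patching the same lightly_dist helpers the loss now calls (world_size, rank, gather) via pytest-mock. A minimal sketch of that patching pattern, reduced to the mock itself rather than the repository's full test:

# Sketch: patch lightly's dist helpers so single-process code sees
# a pretend 2-replica world (assumed simplification of the real test).
import lightly.utils.dist as lightly_dist
from pytest_mock import MockerFixture

def test_patched_dist_helpers(mocker: MockerFixture) -> None:
    mocker.patch.object(lightly_dist, "world_size", return_value=2)
    mocker.patch.object(lightly_dist, "rank", return_value=0)
    # Any code reading these helpers now behaves as replica 0 of 2.
    assert lightly_dist.world_size() == 2
    assert lightly_dist.rank() == 0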
