Add DetConSLoss and DetConBLoss #1771

Merged on Jan 6, 2025
Changes shown are from 51 of the 68 commits.

Commits (68):
22d8058 allow access to v2 transforms availability from anywhere (liopeer, Nov 11, 2024)
c963024 import unnecessary after exception (liopeer, Nov 11, 2024)
988a718 reformat (liopeer, Nov 12, 2024)
4b23bb7 add implementation of AddGridTransform (liopeer, Nov 12, 2024)
f4e74ab make AddGridTransform importable (liopeer, Nov 12, 2024)
a3b7e15 add tests for AddGridTransform (liopeer, Nov 12, 2024)
ff7a14b reformat (liopeer, Nov 12, 2024)
a6fb889 enhance docstring (liopeer, Nov 12, 2024)
59892a5 fix typing issues (liopeer, Nov 12, 2024)
edd1755 Merge branch 'master' into lionel-lig-5625-add-addgridtransform2 (liopeer, Nov 12, 2024)
2e89959 fix import when transforms.v2 not available (liopeer, Nov 12, 2024)
1caaed7 add additional type ignore for 3.7 compatibility (liopeer, Nov 12, 2024)
7ea7675 reformat (liopeer, Nov 12, 2024)
357c4dc add transform to docs (liopeer, Nov 12, 2024)
c91fdb8 change header (liopeer, Nov 12, 2024)
af3a02b add explanation on data structures (liopeer, Nov 12, 2024)
038ec59 use kw args (liopeer, Nov 12, 2024)
d64264f add assertion for mask dimension to be geq 2 (liopeer, Nov 12, 2024)
2cd76cb remove unnecessary fixtures (liopeer, Nov 15, 2024)
ccaa553 make argument order consistent (liopeer, Nov 15, 2024)
7ccf47f add DetCon SingleView and MultiView transforms (liopeer, Nov 18, 2024)
8622953 add MultiViewTransform BaseClass for v2 transforms (liopeer, Nov 18, 2024)
0afb15e add tests for DetCon transform (liopeer, Nov 18, 2024)
7b60c16 export all newly added transforms (liopeer, Nov 18, 2024)
3dee483 add DetCon single view and DetCon multi view transforms (liopeer, Nov 18, 2024)
55cf731 add torchvision transforms v2 compatible MultiViewTransforms (liopeer, Nov 18, 2024)
d341d59 make newly added transforms public (liopeer, Nov 18, 2024)
fa5d635 remove unnecessary fixtures (liopeer, Nov 18, 2024)
2aa8cdc add tests for DetCon transform (liopeer, Nov 18, 2024)
f65e902 Merge branch 'master' into lionel-lig-5626-add-detcontransform (liopeer, Nov 18, 2024)
ad608c5 merge (liopeer, Nov 18, 2024)
a98dec5 remove wrongfully added files (liopeer, Nov 18, 2024)
5436335 add DetCon transform and MultiView transforms for v2 to docs (liopeer, Nov 18, 2024)
f42feef fix docs references (liopeer, Nov 18, 2024)
00b9b60 fix import issues for minimal dependencies (liopeer, Nov 18, 2024)
d552340 fixing code format (liopeer, Nov 18, 2024)
3789630 add test for multiviewtransformv2 (liopeer, Nov 18, 2024)
ac3639a fix testing of multiview (liopeer, Nov 18, 2024)
56edcd4 use singular AddGridTransforms (liopeer, Nov 19, 2024)
6b477c3 consistent naming to DetConS (liopeer, Nov 19, 2024)
f51e834 adjust docstring reference numbering (liopeer, Nov 19, 2024)
8286ee3 name refactoring (liopeer, Nov 19, 2024)
cd37874 Merge branch 'master' into lionel-lig-5628-add-detconloss (liopeer, Nov 21, 2024)
02f5c75 Merge branch 'master' into lionel-lig-5628-add-detconloss (liopeer, Nov 21, 2024)
5928b57 start detconloss implementation (liopeer, Nov 21, 2024)
c4fc6ef Merge branch 'master' into lionel-lig-5628-add-detconloss (liopeer, Dec 23, 2024)
3828ae3 implement detconloss (liopeer, Jan 2, 2025)
2a93b0c test detconloss (liopeer, Jan 2, 2025)
285b490 make detconloss public (liopeer, Jan 2, 2025)
7f183a6 add detconloss to docs (liopeer, Jan 2, 2025)
bf080e5 Merge branch 'master' into lionel-lig-5628-add-detconloss (liopeer, Jan 2, 2025)
2f78124 Update lightly/loss/detcon_loss.py (liopeer, Jan 3, 2025)
18abcba initial small fixes (liopeer, Jan 3, 2025)
a954bb5 add comments and avoid 0 division (liopeer, Jan 3, 2025)
aa07b86 remove labels_ext (liopeer, Jan 3, 2025)
32702a5 remove labels_ext (liopeer, Jan 3, 2025)
10579d4 revert normalization; some formatting (liopeer, Jan 6, 2025)
b8b7bf0 complete rewrite of tests (liopeer, Jan 6, 2025)
49819dc fix typing issues (liopeer, Jan 6, 2025)
78e6c85 remove unused imports (liopeer, Jan 6, 2025)
9e37f70 Update lightly/loss/detcon_loss.py (liopeer, Jan 6, 2025)
f4d17c0 move to f-strings (liopeer, Jan 6, 2025)
9cfeeab create test classes (liopeer, Jan 6, 2025)
e06f71d Merge branch 'lionel-lig-5628-add-detconloss' of github.com:lightly-a… (liopeer, Jan 6, 2025)
d530c5b squeeze instead of 0 indexing (liopeer, Jan 6, 2025)
d79ba68 formatting (liopeer, Jan 6, 2025)
b0bff0f few more comments on tensor shapes (liopeer, Jan 6, 2025)
3211b88 additional comments on tensor shapes (liopeer, Jan 6, 2025)
6 changes: 6 additions & 0 deletions docs/source/lightly.loss.rst
@@ -14,6 +14,12 @@ lightly.loss
.. autoclass:: lightly.loss.dcl_loss.DCLWLoss
    :members:

.. autoclass:: lightly.loss.detcon_loss.DetConBLoss
    :members:

.. autoclass:: lightly.loss.detcon_loss.DetConSLoss
    :members:

.. autoclass:: lightly.loss.dino_loss.DINOLoss
    :members:

1 change: 1 addition & 0 deletions lightly/loss/__init__.py
@@ -4,6 +4,7 @@
# All Rights Reserved
from lightly.loss.barlow_twins_loss import BarlowTwinsLoss
from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
from lightly.loss.detcon_loss import DetConBLoss, DetConSLoss
from lightly.loss.dino_loss import DINOLoss
from lightly.loss.emp_ssl_loss import EMPSSLLoss
from lightly.loss.ibot_loss import IBOTPatchLoss
284 changes: 284 additions & 0 deletions lightly/loss/detcon_loss.py
@@ -0,0 +1,284 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import distributed as torch_dist
from torch.nn import Module

from lightly.utils import dist


class DetConSLoss(Module):
"""Implementation of the DetConS loss. [2]_

The inputs are two-fold:

- Two latent representations of the same batch under different views, as generated\
by SimCLR [3]_ and additional pooling over the regions of the segmentation.
- Two integer masks that indicate the regions of the segmentation that were used\
for pooling.

guarin marked this conversation as resolved.
Show resolved Hide resolved
For calculating the contrastive loss, regions under the same mask in the same image
(under a different view) are considered as positives and everything else as
negatives. With :math:`v_m` and :math:`v_{m'}'` being the pooled feature maps under
mask :math:`m` and :math:`m'` respectively, and additionally scaled to a norm of
:math:`\\frac{1}{\\sqrt{\\tau}}`, the formula for the contrastive loss is

.. math::
\\mathcal{L} = \sum_{m}\sum_{m'} \mathbb{1}_{m, m'} \\left[ - \\log\
liopeer marked this conversation as resolved.
Show resolved Hide resolved
\\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') +\
\\sum_{n}\\exp (v_m \\cdot v_{m'}')} \\right]

where :math:`\\mathbb{1}_{m, m'}` is 1 if the masks are the same and 0 otherwise.

References:
.. [2] DetCon https://arxiv.org/abs/2103.10957
.. [3] SimCLR https://arxiv.org/abs/2002.05709

Attributes:
temperature:
The temperature :math:`\\tau` in the contrastive loss.
gather_distributed:
If True, the similarity matrix is gathered across all GPUs before the loss
is calculated. Else, the loss is calculated on each GPU separately.
"""

    def __init__(
        self, temperature: float = 0.1, gather_distributed: bool = True
    ) -> None:
        super().__init__()
        self.detconbloss = DetConBLoss(
            temperature=temperature, gather_distributed=gather_distributed
        )

    def forward(
        self, view0: Tensor, view1: Tensor, mask_view0: Tensor, mask_view1: Tensor
    ) -> Tensor:
        """Calculate the contrastive loss under the same mask in the same image.

        The tensor shapes and value ranges are given by variables :math:`B, M, D, N`,
        where :math:`B` is the batch size, :math:`M` is the sampled number of image
        masks / regions, :math:`D` is the embedding size and :math:`N` is the total
        number of masks.

        Args:
            view0: Mask-pooled output for the first view, a float tensor of shape
                :math:`(B, M, D)`.
            view1: Mask-pooled output for the second view, a float tensor of shape
                :math:`(B, M, D)`.
            mask_view0: Indices corresponding to the sampled masks for the first view,
                an integer tensor of shape :math:`(B, M)` with (possibly repeated)
                indices in the range :math:`[0, N)`.
            mask_view1: Indices corresponding to the sampled masks for the second view,
                an integer tensor of shape :math:`(B, M)` with (possibly repeated)
                indices in the range :math:`[0, N)`.

        Returns:
            A scalar float tensor containing the contrastive loss.
        """
        # DetConS has a single branch, so the views serve as both the prediction
        # and the target inputs of DetConBLoss.
        loss: Tensor = self.detconbloss(
            view0, view1, view0, view1, mask_view0, mask_view1
        )
        return loss


class DetConBLoss(Module):
    """Implementation of the DetConB loss. [0]_

    The inputs are three-fold:

    - Two latent representations of the same batch under different views, as generated
      by BYOL's [1]_ prediction branch and additional pooling over the regions of the
      segmentation.
    - Two latent representations of the same batch under different views, as generated
      by BYOL's target branch and additional pooling over the regions of the
      segmentation.
    - Two integer masks that indicate the regions of the segmentation that were used
      for pooling.

    For calculating the contrastive loss, regions under the same mask in the same image
    (under a different view) are considered as positives and everything else as
    negatives. With :math:`v_m` and :math:`v_{m'}'` being the pooled feature maps under
    mask :math:`m` and :math:`m'` respectively, and additionally scaled to a norm of
    :math:`\\frac{1}{\\sqrt{\\tau}}`, the formula for the contrastive loss is

    .. math::
        \\mathcal{L} = \\sum_{m}\\sum_{m'} \\mathbb{1}_{m, m'} \\left[ - \\log\
            \\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') +\
            \\sum_{n}\\exp(v_m \\cdot v_{n}')} \\right]

    where :math:`\\mathbb{1}_{m, m'}` is 1 if the masks are the same and 0 otherwise.
    Since :math:`v_m` and :math:`v_{m'}'` stem from different branches, the loss is
    symmetrized by also calculating the loss with the roles of the views reversed. [1]_
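
    Written out, with :math:`\\mathcal{L}(v, v')` denoting the loss above, the total
    loss is

    .. math::
        \\mathcal{L}_{\\text{total}} = \\mathcal{L}(v, v') + \\mathcal{L}(v', v)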

    References:
        .. [0] DetCon https://arxiv.org/abs/2103.10957
        .. [1] BYOL https://arxiv.org/abs/2006.07733

    Attributes:
        temperature:
            The temperature :math:`\\tau` in the contrastive loss.
        gather_distributed:
            If True, the similarity matrix is gathered across all GPUs before the loss
            is calculated. Else, the loss is calculated on each GPU separately.
    """

    def __init__(
        self, temperature: float = 0.1, gather_distributed: bool = True
    ) -> None:
        super().__init__()
        self.eps = 1e-8
        self.temperature = temperature
        self.gather_distributed = gather_distributed

        if abs(self.temperature) < self.eps:
            raise ValueError(
                f"Illegal temperature: abs({self.temperature}) < 1e-8"
            )
        if self.gather_distributed and not torch_dist.is_available():
            raise ValueError(
                "gather_distributed is True but torch.distributed is not available. "
                "Please set gather_distributed=False or install a torch version with "
                "distributed support."
            )

    def forward(
        self,
        pred_view0: Tensor,
        pred_view1: Tensor,
        target_view0: Tensor,
        target_view1: Tensor,
        mask_view0: Tensor,
        mask_view1: Tensor,
    ) -> Tensor:
        """Calculate the contrastive loss under the same mask in the same image.

        The tensor shapes and value ranges are given by variables :math:`B, M, D, N`,
        where :math:`B` is the batch size, :math:`M` is the sampled number of image
        masks / regions, :math:`D` is the embedding size and :math:`N` is the total
        number of masks.

        Args:
            pred_view0: Mask-pooled output of the prediction branch for the first
                view, a float tensor of shape :math:`(B, M, D)`.
            pred_view1: Mask-pooled output of the prediction branch for the second
                view, a float tensor of shape :math:`(B, M, D)`.
            target_view0: Mask-pooled output of the target branch for the first view,
                a float tensor of shape :math:`(B, M, D)`.
            target_view1: Mask-pooled output of the target branch for the second view,
                a float tensor of shape :math:`(B, M, D)`.
            mask_view0: Indices corresponding to the sampled masks for the first view,
                an integer tensor of shape :math:`(B, M)` with (possibly repeated)
                indices in the range :math:`[0, N)`.
            mask_view1: Indices corresponding to the sampled masks for the second
                view, an integer tensor of shape :math:`(B, M)` with (possibly
                repeated) indices in the range :math:`[0, N)`.

        Returns:
            A scalar float tensor containing the contrastive loss.
        """
        b, m, d = pred_view0.size()
        infinity_proxy = 1e9

        # gather distributed
        if not self.gather_distributed or dist.world_size() < 2:
            target_view0_large = target_view0
            target_view1_large = target_view1
            labels_local = torch.eye(b, device=pred_view0.device)
        else:
            target_view0_large = torch.cat(dist.gather(target_view0), dim=0)
            target_view1_large = torch.cat(dist.gather(target_view1), dim=0)
            replica_id = dist.rank()
            labels_idx = torch.arange(b, device=pred_view0.device) + replica_id * b
            enlarged_b = b * dist.world_size()
            labels_local = F.one_hot(labels_idx, num_classes=enlarged_b)

        # normalize to unit length
        pred_view0 = F.normalize(pred_view0, p=2, dim=2)
        pred_view1 = F.normalize(pred_view1, p=2, dim=2)
        target_view0_large = F.normalize(target_view0_large, p=2, dim=2)
        target_view1_large = F.normalize(target_view1_large, p=2, dim=2)

        # (B, B) -> (B, 1, B, 1) to broadcast against the (B, M, B, M) mask indicators
        labels_local = labels_local[:, None, :, None]

        # calculate similarity matrices
        logits_aa = (
            torch.einsum("abk,uvk->abuv", pred_view0, target_view0_large)
            / self.temperature
        )
        logits_bb = (
            torch.einsum("abk,uvk->abuv", pred_view1, target_view1_large)
            / self.temperature
        )
        logits_ab = (
            torch.einsum("abk,uvk->abuv", pred_view0, target_view1_large)
            / self.temperature
        )
        logits_ba = (
            torch.einsum("abk,uvk->abuv", pred_view1, target_view0_large)
            / self.temperature
        )

        # determine where the masks are the same, shape (B, M, 1, M)
        same_mask_aa = _same_mask(mask_view0, mask_view0)
        same_mask_bb = _same_mask(mask_view1, mask_view1)
        same_mask_ab = _same_mask(mask_view0, mask_view1)
        same_mask_ba = _same_mask(mask_view1, mask_view0)

        # positives: same image and same mask, shape (B, M, B, M) after broadcast
        labels_aa = labels_local * same_mask_aa
        labels_bb = labels_local * same_mask_bb
        labels_ab = labels_local * same_mask_ab
        labels_ba = labels_local * same_mask_ba

        # suppress intra-view same-mask similarities and drop them as positives
        logits_aa = logits_aa - infinity_proxy * labels_aa
        logits_bb = logits_bb - infinity_proxy * labels_bb
        labels_aa = 0.0 * labels_aa
        labels_bb = 0.0 * labels_bb

        labels_abaa = torch.cat([labels_ab, labels_aa], dim=2)
        labels_babb = torch.cat([labels_ba, labels_bb], dim=2)

        # flatten to (B, M, 2 * B * M)
        labels_0 = labels_abaa.view(b, m, -1)
        labels_1 = labels_babb.view(b, m, -1)

        num_positives_0 = torch.sum(labels_0, dim=-1, keepdim=True)
        num_positives_1 = torch.sum(labels_1, dim=-1, keepdim=True)

        # normalize the positives, avoiding division by zero
        labels_0 = labels_0 / torch.maximum(num_positives_0, torch.tensor(1))
        labels_1 = labels_1 / torch.maximum(num_positives_1, torch.tensor(1))

        obj_area_0 = torch.sum(_same_mask(mask_view0, mask_view0), dim=(2, 3))
        obj_area_1 = torch.sum(_same_mask(mask_view1, mask_view1), dim=(2, 3))

        # weight each mask by the inverse of its area; ignore masks without positives
        weights_0 = torch.gt(num_positives_0[..., 0], 1e-3).float()
        weights_0 = weights_0 / obj_area_0
        weights_1 = torch.gt(num_positives_1[..., 0], 1e-3).float()
        weights_1 = weights_1 / obj_area_1

        logits_abaa = torch.cat([logits_ab, logits_aa], dim=2)
        logits_babb = torch.cat([logits_ba, logits_bb], dim=2)

        logits_abaa = logits_abaa.view(b, m, -1)
        logits_babb = logits_babb.view(b, m, -1)

        # symmetrized cross-entropy over both view orders
        loss_a = torch_manual_cross_entropy(labels_0, logits_abaa, weights_0)
        loss_b = torch_manual_cross_entropy(labels_1, logits_babb, weights_1)
        loss = loss_a + loss_b
        return loss


def _same_mask(mask0: Tensor, mask1: Tensor) -> Tensor:
    """Return a float tensor of shape (B, M, 1, M) that is 1 where the mask indices
    of the two views match and 0 elsewhere."""
    return (mask0[:, :, None] == mask1[:, None, :]).float()[:, :, None, :]


def torch_manual_cross_entropy(
    labels: Tensor, logits: Tensor, weight: Tensor
) -> Tensor:
    """Weighted cross-entropy between soft labels and logits, averaged over all masks
    and batch entries."""
    ce = -weight * torch.sum(labels * F.log_softmax(logits, dim=-1), dim=-1)
    return torch.mean(ce)
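
For context, here is a minimal usage sketch of the two new losses (not part of the diff). The tensor sizes are illustrative placeholders that follow the :math:`(B, M, D)` shape conventions from the docstrings, and gather_distributed=False keeps the example single-process; the import path matches the one added in lightly/loss/__init__.py above.

    import torch

    from lightly.loss import DetConBLoss, DetConSLoss

    # Illustrative sizes: batch, masks per image, embedding dim, total masks.
    B, M, D, N = 4, 8, 128, 16

    # Random placeholders for the mask-pooled embeddings. One way to obtain such
    # pooled features (illustrative, not from this PR): average a dense feature map
    # feats of shape (B, H*W, D) over each region via a one-hot region indicator
    # onehot of shape (B, M, H*W), e.g.
    #   pooled = torch.einsum("bmp,bpd->bmd", onehot, feats)
    #   pooled = pooled / onehot.sum(-1, keepdim=True)
    pred_view0 = torch.randn(B, M, D)    # prediction branch, first view
    pred_view1 = torch.randn(B, M, D)    # prediction branch, second view
    target_view0 = torch.randn(B, M, D)  # target branch, first view
    target_view1 = torch.randn(B, M, D)  # target branch, second view
    mask_view0 = torch.randint(0, N, (B, M))  # sampled mask indices, possibly repeated
    mask_view1 = torch.randint(0, N, (B, M))

    # DetConB: separate prediction and target branches (BYOL-style).
    detconb = DetConBLoss(temperature=0.1, gather_distributed=False)
    loss_b = detconb(
        pred_view0, pred_view1, target_view0, target_view1, mask_view0, mask_view1
    )

    # DetConS: a single branch (SimCLR-style); delegates to DetConBLoss internally.
    detcons = DetConSLoss(temperature=0.1, gather_distributed=False)
    loss_s = detcons(pred_view0, pred_view1, mask_view0, mask_view1)

Both calls return a scalar float tensor suitable for loss.backward().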