
Commit 42bd402

Merge branch 'master' into jeremy-lig-5769-remove-docker-archive-from-lightlyssl-docs
japrescott authored Jan 6, 2025
2 parents ef7cbe3 + 356ae56 commit 42bd402
Showing 49 changed files with 982 additions and 1,044 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_example_nbs.yml
@@ -14,7 +14,7 @@ on:
 jobs:
   convert-to-nbs:
     name: "Check Example Notebooks"
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
       - name: Checkout Code
         uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion .github/workflows/discord_release_notification.yml
@@ -6,7 +6,7 @@ on:

 jobs:
   notify-discord:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
       - name: Send Notification to Discord
         env:
2 changes: 1 addition & 1 deletion .github/workflows/release_pypi.yml
@@ -9,7 +9,7 @@ on:
 jobs:
   build:
     name: Build and release
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -10,7 +10,7 @@ on:
 jobs:
   detect-code-changes:
     name: Detect Code Changes
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     outputs:
       run-tests: ${{ steps.filter.outputs.run-tests }}
     steps:
@@ -29,7 +29,7 @@ jobs:
     name: Test
     needs: detect-code-changes
     if: needs.detect-code-changes.outputs.run-tests == 'true'
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     strategy:
       matrix:
         python: ["3.7", "3.12"]
4 changes: 2 additions & 2 deletions .github/workflows/test_api_deps_only.yml
@@ -11,7 +11,7 @@ on:
 jobs:
   detect-code-changes:
     name: Detect Code Changes
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     outputs:
       run-tests: ${{ steps.filter.outputs.run-tests }}
     steps:
@@ -30,7 +30,7 @@ jobs:
     name: Test
     needs: detect-code-changes
     if: needs.detect-code-changes.outputs.run-tests == 'true'
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04

     steps:
       - name: Checkout Code
2 changes: 1 addition & 1 deletion .github/workflows/test_code_format.yml
@@ -10,7 +10,7 @@ on:
 jobs:
   test:
     name: Check
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     strategy:
       matrix:
         python: ["3.7", "3.12"]
4 changes: 2 additions & 2 deletions .github/workflows/test_minimal_deps.yml
@@ -10,7 +10,7 @@ on:
 jobs:
   detect-code-changes:
     name: Detect Code Changes
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     outputs:
       run-tests: ${{ steps.filter.outputs.run-tests }}
     steps:
@@ -29,7 +29,7 @@ jobs:
     name: Test
     needs: detect-code-changes
     if: needs.detect-code-changes.outputs.run-tests == 'true'
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     strategy:
       matrix:
         python: ["3.7"]
4 changes: 2 additions & 2 deletions .github/workflows/test_setup.yml
@@ -10,7 +10,7 @@ on:
 jobs:
   detect-code-changes:
     name: Detect Code Changes
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     outputs:
       run-tests: ${{ steps.filter.outputs.run-tests }}
     steps:
@@ -27,7 +27,7 @@ jobs:
     name: Test setup.py
     needs: detect-code-changes
     if: needs.detect-code-changes.outputs.run-tests == 'true'
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04

     steps:
       - name: Checkout Code
2 changes: 1 addition & 1 deletion .github/workflows/tests_unmocked.yml
@@ -10,7 +10,7 @@ on: [workflow_dispatch]
 jobs:
   test:
     name: Run unmocked tests
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
       - name: Checkout Code
         uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion .github/workflows/weekly_dependency_test.yml
@@ -12,7 +12,7 @@ on:
 jobs:
   test_fresh_install:
     name: Test fresh install
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
2 changes: 2 additions & 0 deletions README.md
@@ -7,6 +7,8 @@
 [![Downloads](https://static.pepy.tech/badge/lightly)](https://pepy.tech/project/lightly)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 [![Discord](https://img.shields.io/discord/752876370337726585?logo=discord&logoColor=white&label=discord&color=7289da)](https://discord.gg/xvNJW94)
+![codecov.io](https://codecov.io/github/lightly-ai/lightly/coverage.svg?branch=master)
+

 Lightly**SSL** is a computer vision framework for self-supervised learning.
5 changes: 3 additions & 2 deletions lightly/loss/barlow_twins_loss.py
@@ -3,6 +3,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from torch import Tensor


 class BarlowTwinsLoss(torch.nn.Module):
@@ -81,7 +82,7 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:

         invariance_loss = torch.diagonal(c).add_(-1).pow_(2).sum()
         redundancy_reduction_loss = _off_diagonal(c).pow_(2).sum()
-        loss = invariance_loss + self.lambda_param * redundancy_reduction_loss
+        loss: Tensor = invariance_loss + self.lambda_param * redundancy_reduction_loss

         return loss
@@ -106,7 +107,7 @@ def _normalize(
     return normalized[0], normalized[1]


-def _off_diagonal(x):
+def _off_diagonal(x: Tensor) -> Tensor:
     """Returns a flattened view of the off-diagonal elements of a square matrix."""

     # Ensure the input is a square matrix
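
For context on the hunks above, a toy check (not part of the commit) of how the cross-correlation matrix splits into the two loss terms. The 2×2 matrix `c` is made up, and the flatten/view trick is the standard off-diagonal extraction; the helper's exact body is not shown in this diff:

```python
import torch

# Hypothetical 2x2 cross-correlation matrix.
c = torch.tensor([[0.9, 0.1], [0.2, 0.8]])

# Diagonal entries should be 1 -> invariance term.
invariance = torch.diagonal(c).add(-1).pow(2).sum()  # (0.9-1)^2 + (0.8-1)^2

# Off-diagonal entries should be 0 -> redundancy reduction term.
n = c.shape[0]
off_diagonal = c.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()  # [0.1, 0.2]
redundancy = off_diagonal.pow(2).sum()

loss = invariance + 5e-3 * redundancy  # 5e-3 is the paper's lambda
```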
12 changes: 7 additions & 5 deletions lightly/loss/dcl_loss.py
@@ -11,7 +11,7 @@

 def negative_mises_fisher_weights(
     out0: Tensor, out1: Tensor, sigma: float = 0.5
-) -> torch.Tensor:
+) -> Tensor:
     """Negative Mises-Fisher weighting function as presented in Decoupled Contrastive Learning [0].

     The implementation was inspired by [1].
@@ -35,7 +35,7 @@ def negative_mises_fisher_weights(
     similarity = torch.einsum("nm,nm->n", out0.detach(), out1.detach()) / sigma

     # Return negative Mises-Fisher weights
-    return 2 - out0.shape[0] * nn.functional.softmax(similarity, dim=0)
+    return torch.tensor(2 - out0.shape[0] * nn.functional.softmax(similarity, dim=0))


 class DCLLoss(nn.Module):
@@ -148,13 +148,15 @@ def forward(
         out1_all = out1

         # Calculate symmetric loss
-        loss0 = self._loss(out0, out1, out0_all, out1_all)
-        loss1 = self._loss(out1, out0, out1_all, out0_all)
+        loss0: Tensor = self._loss(out0, out1, out0_all, out1_all)
+        loss1: Tensor = self._loss(out1, out0, out1_all, out0_all)

         # Return the mean loss over the mini-batch
         return 0.5 * (loss0 + loss1)

-    def _loss(self, out0, out1, out0_all, out1_all):
+    def _loss(
+        self, out0: Tensor, out1: Tensor, out0_all: Tensor, out1_all: Tensor
+    ) -> Tensor:
         """Calculates DCL loss for out0 with respect to its positives in out1
         and the negatives in out1, out0_all, and out1_all.
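
A brief usage sketch of the function and class touched above (not part of this commit; the shapes are arbitrary, and passing `negative_mises_fisher_weights` as `weight_fn` is assumed from the DCLW formulation):

```python
import torch
import torch.nn.functional as F

from lightly.loss.dcl_loss import DCLLoss, negative_mises_fisher_weights

# Two batches of L2-normalized embeddings from two views of the same images.
out0 = F.normalize(torch.randn(8, 32), dim=1)
out1 = F.normalize(torch.randn(8, 32), dim=1)

# Plain decoupled contrastive loss.
loss = DCLLoss(temperature=0.1)(out0, out1)

# DCLW variant: reweight positive pairs with the negative
# von Mises-Fisher weights defined above.
loss_w = DCLLoss(temperature=0.1, weight_fn=negative_mises_fisher_weights)(out0, out1)
```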
21 changes: 12 additions & 9 deletions lightly/loss/dino_loss.py
@@ -3,7 +3,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter

 from lightly.models.modules import center
 from lightly.models.modules.center import CENTER_MODE_TO_FUNCTION
@@ -83,6 +83,7 @@ def __init__(

         # TODO(Guarin, 08/24): Refactor this to use the Center module directly once
         # we do a breaking change.
+        self.center: Parameter
         self.register_buffer("center", torch.zeros(1, 1, output_dim))

         # we apply a warm up for the teacher temperature because
@@ -123,13 +124,15 @@ def forward(
         if epoch < self.warmup_teacher_temp_epochs:
             teacher_temp = self.teacher_temp_schedule[epoch]
         else:
-            teacher_temp = self.teacher_temp
+            teacher_temp = torch.tensor(self.teacher_temp)

-        teacher_out = torch.stack(teacher_out)
-        t_out = F.softmax((teacher_out - self.center) / teacher_temp, dim=-1)
+        teacher_out_stacked = torch.stack(teacher_out)
+        t_out: Tensor = F.softmax(
+            (teacher_out_stacked - self.center) / teacher_temp, dim=-1
+        )

-        student_out = torch.stack(student_out)
-        s_out = F.log_softmax(student_out / self.student_temp, dim=-1)
+        student_out_stacked = torch.stack(student_out)
+        s_out = F.log_softmax(student_out_stacked / self.student_temp, dim=-1)

         # Calculate feature similarities, ignoring the diagonal
         # b = batch_size, t = n_views_teacher, s = n_views_student, d = output_dim
@@ -138,12 +141,12 @@ def forward(

         # Number of loss terms, ignoring the diagonal
         n_terms = loss.numel() - loss.diagonal().numel()
-        batch_size = teacher_out.shape[1]
+        batch_size = teacher_out_stacked.shape[1]

         loss = loss.sum() / (n_terms * batch_size)

         # Update the center used for the teacher output
-        self.update_center(teacher_out)
+        self.update_center(teacher_out_stacked)

         return loss
@@ -161,6 +164,6 @@ def update_center(self, teacher_out: Tensor) -> None:
         batch_center = self._center_fn(x=teacher_out, dim=(0, 1))

         # Update the center with a moving average
-        self.center = center.center_momentum(
+        self.center.data = center.center_momentum(
             center=self.center, batch_center=batch_center, momentum=self.center_momentum
         )
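
The final hunk writes to `self.center.data`, so the registered buffer is updated in place rather than rebinding the attribute to a fresh tensor. The update itself is a moving average; a toy rerun (not from the repository), assuming the conventional EMA form `new = m * center + (1 - m) * batch_center`:

```python
import torch

momentum = 0.9
center = torch.zeros(1, 1, 4)
batch_center = torch.full((1, 1, 4), 0.5)

# Assumed EMA form of center.center_momentum.
new_center = momentum * center + (1 - momentum) * batch_center
print(new_center)  # tensor([[[0.0500, 0.0500, 0.0500, 0.0500]]])
```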
17 changes: 10 additions & 7 deletions lightly/loss/hypersphere_loss.py
@@ -5,9 +5,11 @@

 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Module


-class HypersphereLoss(torch.nn.Module):
+class HypersphereLoss(Module):
     """Implementation of the loss described in 'Understanding Contrastive Representation Learning through
     Alignment and Uniformity on the Hypersphere.' [0]
@@ -44,7 +46,7 @@ class HypersphereLoss(torch.nn.Module):
     >>> loss = loss_fn(out0, out1)
     """

-    def __init__(self, t=1.0, lam=1.0, alpha=2.0):
+    def __init__(self, t: float = 1.0, lam: float = 1.0, alpha: float = 2.0):
         """Initializes the HypersphereLoss module with the specified parameters.

         Parameters as described in [0]
@@ -63,7 +65,7 @@ def __init__(self, t=1.0, lam=1.0, alpha=2.0):
         self.lam = lam
         self.alpha = alpha

-    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
+    def forward(self, z_a: Tensor, z_b: Tensor) -> Tensor:
         """Computes the Hypersphere loss, which combines alignment and uniformity loss terms.

         Args:
@@ -80,13 +82,14 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
         y = F.normalize(z_b)

         # Calculate alignment loss
-        def lalign(x, y):
-            return (x - y).norm(dim=1).pow(self.alpha).mean()
+        def lalign(x: Tensor, y: Tensor) -> Tensor:
+            lalign_: Tensor = (x - y).norm(dim=1).pow(self.alpha).mean()
+            return lalign_

         # Calculate uniformity loss
-        def lunif(x):
+        def lunif(x: Tensor) -> Tensor:
             sq_pdist = torch.pdist(x, p=2).pow(2)
             return sq_pdist.mul(-self.t).exp().mean().log()

         # Combine alignment and uniformity loss terms
-        return lalign(x, y) + self.lam * (lunif(x) + lunif(y)) / 2
+        return lalign(x, y) + self.lam * (lunif(x) + lunif(y)) / 2.0
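
A standalone sketch (not part of the commit) of the two terms combined in the return above, using the default `t=1.0`, `lam=1.0`, `alpha=2.0` and made-up embeddings:

```python
import torch
import torch.nn.functional as F

# Made-up batches of embeddings, normalized onto the unit hypersphere.
z_a = F.normalize(torch.randn(16, 8), dim=1)
z_b = F.normalize(torch.randn(16, 8), dim=1)

# Alignment: positive pairs should sit close together on the sphere.
align = (z_a - z_b).norm(dim=1).pow(2.0).mean()

# Uniformity: embeddings should spread out over the sphere.
def lunif(x: torch.Tensor) -> torch.Tensor:
    sq_pdist = torch.pdist(x, p=2).pow(2)
    return sq_pdist.mul(-1.0).exp().mean().log()

loss = align + 1.0 * (lunif(z_a) + lunif(z_b)) / 2.0
```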
16 changes: 9 additions & 7 deletions lightly/loss/ntx_ent_loss.py
@@ -6,14 +6,15 @@
 from typing import Sequence, Union

 import torch
+from torch import Tensor
 from torch import distributed as torch_dist
 from torch import nn

 from lightly.models.modules.memory_bank import MemoryBankModule
 from lightly.utils import dist


-class NTXentLoss(MemoryBankModule):
+class NTXentLoss(nn.Module):
     """Implementation of the Contrastive Cross Entropy Loss.

     This implementation follows the SimCLR[0] paper. If you enable the memory
@@ -80,7 +81,10 @@ def __init__(
             ValueError: If temperature is less than 1e-8 to prevent divide by zero.
             ValueError: If gather_distributed is True but torch.distributed is not available.
         """
-        super().__init__(size=memory_bank_size, gather_distributed=gather_distributed)
+        super().__init__()
+        self.memory_bank = MemoryBankModule(
+            size=memory_bank_size, gather_distributed=gather_distributed
+        )
         self.temperature = temperature
         self.gather_distributed = gather_distributed
         self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
@@ -97,7 +101,7 @@ def __init__(
                 "distributed support."
             )

-    def forward(self, out0: torch.Tensor, out1: torch.Tensor):
+    def forward(self, out0: Tensor, out1: Tensor) -> Tensor:
         """Forward pass through Contrastive Cross-Entropy Loss.

         If used with a memory bank, the samples from the memory bank are used
@@ -129,9 +133,7 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor):
         # for evaluating the loss on the test set)
         # out1: shape: (batch_size, embedding_size)
         # negatives: shape: (embedding_size, memory_bank_size)
-        out1, negatives = super(NTXentLoss, self).forward(
-            out1, update=out0.requires_grad
-        )
+        out1, negatives = self.memory_bank.forward(out1, update=out0.requires_grad)

         # Use cosine similarity (dot product) as all vectors are normalized to unit length
         # Notation in einsum: n = batch_size, c = embedding_size and k = memory_bank_size.
@@ -192,6 +194,6 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor):
             labels = labels.repeat(2)

         # Calculate the cross-entropy loss
-        loss = self.cross_entropy(logits, labels)
+        loss: Tensor = self.cross_entropy(logits, labels)

         return loss
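
Since this file switches `NTXentLoss` from inheriting `MemoryBankModule` to holding one in `self.memory_bank`, here is a usage sketch for both modes (not part of the diff; the sizes and the tuple form of `memory_bank_size` are assumptions, the latter suggested by the `Sequence` import above):

```python
import torch
import torch.nn.functional as F

from lightly.loss import NTXentLoss

out0 = F.normalize(torch.randn(64, 128), dim=1)
out1 = F.normalize(torch.randn(64, 128), dim=1)

# SimCLR-style: no memory bank, negatives come from the batch itself.
loss = NTXentLoss(temperature=0.5)(out0, out1)

# MoCo-style: negatives come from the composed memory bank.
criterion = NTXentLoss(temperature=0.1, memory_bank_size=(4096, 128))
loss_mb = criterion(out0, out1)
```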
2 changes: 1 addition & 1 deletion lightly/loss/pmsn_loss.py
@@ -177,6 +177,6 @@ def _power_law_distribution(size: int, exponent: float, device: torch.device) ->
         A power law distribution tensor summing up to 1.
     """
     k = torch.arange(1, size + 1, device=device)
-    power_dist = k ** (-exponent)
+    power_dist = torch.tensor(k ** (-exponent))
     power_dist = power_dist / power_dist.sum()
     return power_dist
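
For reference, the prior this helper builds, rerun standalone with made-up `size` and `exponent` (not part of the diff):

```python
import torch

size, exponent = 5, 1.0
k = torch.arange(1, size + 1, dtype=torch.float32)
power_dist = k ** (-exponent)               # [1, 1/2, 1/3, 1/4, 1/5]
power_dist = power_dist / power_dist.sum()  # normalize so the prior sums to 1
print(power_dist.sum())                     # tensor(1.)
```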