fix: fixed a bug in KL loss aggregation (LVAE) #277

Merged 7 commits on Nov 23, 2024
src/careamics/losses/lvae/losses.py (18 changes: 9 additions & 9 deletions)
@@ -168,30 +168,30 @@ def get_kl_divergence_loss(
         dim=1,
     )  # shape: (B, n_layers)
 
+    # Apply free bits (& batch average)
+    kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
+
     # In 3D case, rescale by Z dim
     # TODO If we have downsampling in Z dimension, then this needs to change.
     if len(img_shape) == 3:
         kl = kl / img_shape[0]
 
     # Rescaling
     if rescaling == "latent_dim":
-        for i in range(kl.shape[1]):
+        for i in range(len(kl)):
             latent_dim = topdown_data["z"][i].shape[1:]
             norm_factor = np.prod(latent_dim)
-            kl[:, i] = kl[:, i] / norm_factor
+            kl[i] = kl[i] / norm_factor
     elif rescaling == "image_dim":
         kl = kl / np.prod(img_shape[-2:])
 
-    # Apply free bits
-    kl_loss = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
-
     # Aggregation
     if aggregation == "mean":
-        kl_loss = kl_loss.mean()  # shape: (1,)
+        kl = kl.mean()  # shape: (1,)
     elif aggregation == "sum":
-        kl_loss = kl_loss.sum()  # shape: (1,)
+        kl = kl.sum()  # shape: (1,)
 
-    return kl_loss
+    return kl
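
Side note (not part of the diff): a minimal sketch of the shape flow the reordered code assumes. The batch size, number of layers, latent shapes, and free-bits value below are made up, and free_bits_kl is inlined as its default clamp-then-batch-mean behaviour. The point is that once free bits are applied, the batch dimension is gone, so the rescaling loop indexes a 1D tensor (kl[i]) rather than a 2D one (kl[:, i]).

```python
import numpy as np
import torch

# Hypothetical sizes, for illustration only.
B, n_layers = 4, 3
kl = torch.rand(B, n_layers)  # per-sample, per-layer KL
z = [torch.rand(B, 32, 16, 16) for _ in range(n_layers)]  # one latent per layer

# Free bits + batch average (assumed coefficient of 1.0): (B, n_layers) -> (n_layers,)
kl = kl.clamp(min=1.0).mean(0)

# Rescaling by latent dimensionality now indexes a 1D tensor.
for i in range(len(kl)):
    kl[i] = kl[i] / np.prod(z[i].shape[1:])

kl_loss = kl.mean()  # "mean" aggregation -> scalar
```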


@@ -220,7 +220,7 @@ def _get_kl_divergence_loss_musplit(
         The KL divergence loss for the muSplit case. Shape is (1, ).
     """
     return get_kl_divergence_loss(
-        kl_type=kl_type,
+        kl_type="kl",  # TODO: hardcoded, deal in future PR
         topdown_data=topdown_data,
         rescaling="latent_dim",
         aggregation="mean",
src/careamics/models/lvae/utils.py (37 changes: 0 additions & 37 deletions)
@@ -402,40 +402,3 @@ def kl_normal_mc(z, p_mulv, q_mulv):
     p_distrib = Normal(p_mu.get(), p_std)
     q_distrib = Normal(q_mu.get(), q_std)
     return q_distrib.log_prob(z) - p_distrib.log_prob(z)
-
-
-def free_bits_kl(
-    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
-) -> torch.Tensor:
-    """
-    Computes free-bits version of KL divergence.
-    Ensures that the KL doesn't go to zero for any latent dimension.
-    Hence, it contributes to use latent variables more efficiently,
-    leading to better representation learning.
-
-    NOTE:
-    Takes in the KL with shape (batch size, layers), returns the KL with
-    free bits (for optimization) with shape (layers,), which is the average
-    free-bits KL per layer in the current batch.
-    If batch_average is False (default), the free bits are per layer and
-    per batch element. Otherwise, the free bits are still per layer, but
-    are assigned on average to the whole batch. In both cases, the batch
-    average is returned, so it's simply a matter of doing mean(clamp(KL))
-    or clamp(mean(KL)).
-
-    Args:
-        kl (torch.Tensor)
-        free_bits (float)
-        batch_average (bool, optional))
-        eps (float, optional)
-
-    Returns
-    -------
-    The KL with free bits
-    """
-    assert kl.dim() == 2
-    if free_bits < eps:
-        return kl.mean(0)
-    if batch_average:
-        return kl.mean(0).clamp(min=free_bits)
-    return kl.clamp(min=free_bits).mean(0)
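
For reference, the behaviour documented in the removed docstring comes down to one clamp and one batch mean. A small illustration with made-up values (not from the codebase), assuming free_bits=1.0:

```python
import torch

kl = torch.tensor([[0.2, 3.0],
                   [4.0, 0.1]])  # shape: (batch=2, layers=2)
free_bits = 1.0

# batch_average=False (default): clamp each element, then average over the batch.
print(kl.clamp(min=free_bits).mean(0))  # tensor([2.5000, 2.0000])

# batch_average=True: average over the batch first, then clamp the per-layer means.
print(kl.mean(0).clamp(min=free_bits))  # tensor([2.1000, 1.5500])
```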