diff --git a/src/careamics/losses/lvae/losses.py b/src/careamics/losses/lvae/losses.py
index 310d0bb80..9514846c3 100644
--- a/src/careamics/losses/lvae/losses.py
+++ b/src/careamics/losses/lvae/losses.py
@@ -168,6 +168,9 @@ def get_kl_divergence_loss(
         dim=1,
     )  # shape: (B, n_layers)
 
+    # Apply free bits (& batch average)
+    kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
+
     # In 3D case, rescale by Z dim
     # TODO If we have downsampling in Z dimension, then this needs to change.
     if len(img_shape) == 3:
@@ -175,23 +178,20 @@ def get_kl_divergence_loss(
 
     # Rescaling
     if rescaling == "latent_dim":
-        for i in range(kl.shape[1]):
+        for i in range(len(kl)):
             latent_dim = topdown_data["z"][i].shape[1:]
             norm_factor = np.prod(latent_dim)
-            kl[:, i] = kl[:, i] / norm_factor
+            kl[i] = kl[i] / norm_factor
     elif rescaling == "image_dim":
         kl = kl / np.prod(img_shape[-2:])
 
-    # Apply free bits
-    kl_loss = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
-
     # Aggregation
     if aggregation == "mean":
-        kl_loss = kl_loss.mean()  # shape: (1,)
+        kl = kl.mean()  # shape: (1,)
     elif aggregation == "sum":
-        kl_loss = kl_loss.sum()  # shape: (1,)
+        kl = kl.sum()  # shape: (1,)
 
-    return kl_loss
+    return kl
 
 
 def _get_kl_divergence_loss_musplit(
@@ -220,7 +220,7 @@ def _get_kl_divergence_loss_musplit(
         The KL divergence loss for the muSplit case. Shape is (1, ).
     """
     return get_kl_divergence_loss(
-        kl_type=kl_type,
+        kl_type="kl",  # TODO: hardcoded, deal in future PR
         topdown_data=topdown_data,
         rescaling="latent_dim",
         aggregation="mean",
diff --git a/src/careamics/models/lvae/utils.py b/src/careamics/models/lvae/utils.py
index 1089932a9..2698dbf5a 100644
--- a/src/careamics/models/lvae/utils.py
+++ b/src/careamics/models/lvae/utils.py
@@ -402,40 +402,3 @@ def kl_normal_mc(z, p_mulv, q_mulv):
     p_distrib = Normal(p_mu.get(), p_std)
     q_distrib = Normal(q_mu.get(), q_std)
     return q_distrib.log_prob(z) - p_distrib.log_prob(z)
-
-
-def free_bits_kl(
-    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
-) -> torch.Tensor:
-    """
-    Computes free-bits version of KL divergence.
-    Ensures that the KL doesn't go to zero for any latent dimension.
-    Hence, it contributes to use latent variables more efficiently,
-    leading to better representation learning.
-
-    NOTE:
-    Takes in the KL with shape (batch size, layers), returns the KL with
-    free bits (for optimization) with shape (layers,), which is the average
-    free-bits KL per layer in the current batch.
-    If batch_average is False (default), the free bits are per layer and
-    per batch element. Otherwise, the free bits are still per layer, but
-    are assigned on average to the whole batch. In both cases, the batch
-    average is returned, so it's simply a matter of doing mean(clamp(KL))
-    or clamp(mean(KL)).
-
-    Args:
-        kl (torch.Tensor)
-        free_bits (float)
-        batch_average (bool, optional))
-        eps (float, optional)
-
-    Returns
-    -------
-        The KL with free bits
-    """
-    assert kl.dim() == 2
-    if free_bits < eps:
-        return kl.mean(0)
-    if batch_average:
-        return kl.mean(0).clamp(min=free_bits)
-    return kl.clamp(min=free_bits).mean(0)
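
Note on the relocated free-bits clamp: after this change, get_kl_divergence_loss calls free_bits_kl on the raw per-layer KL before any rescaling, while the helper's definition is dropped from models/lvae/utils.py (its new home is not shown in this diff). Below is a minimal, self-contained sketch of that helper, with the body reproduced from the removed lines, plus a toy call contrasting the two batch-averaging modes, mean(clamp(KL)) versus clamp(mean(KL)). The tensor values in the example are made up for illustration.

import torch


def free_bits_kl(
    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> torch.Tensor:
    """Clamp the per-layer KL so no layer falls below `free_bits`.

    Input has shape (batch, n_layers); output has shape (n_layers,), the
    batch-averaged free-bits KL per layer. Body copied from the helper
    removed from models/lvae/utils.py in this diff.
    """
    assert kl.dim() == 2
    if free_bits < eps:
        # Free bits effectively disabled: plain batch average per layer.
        return kl.mean(0)
    if batch_average:
        # clamp(mean(KL)): free bits assigned to the batch as a whole.
        return kl.mean(0).clamp(min=free_bits)
    # mean(clamp(KL)): free bits applied per batch element, then averaged.
    return kl.clamp(min=free_bits).mean(0)


# Toy KL tensor with batch size 2 and 2 layers (illustrative values only).
kl = torch.tensor([[0.1, 0.5], [0.9, 0.5]])
print(free_bits_kl(kl, free_bits=0.4))                      # tensor([0.6500, 0.5000])
print(free_bits_kl(kl, free_bits=0.4, batch_average=True))  # tensor([0.5000, 0.5000])

With the reordering above, the "latent_dim"/"image_dim" rescaling now divides this already clamped (n_layers,) tensor rather than the raw (B, n_layers) KL, which is why the indexing in the rescaling loop changes from kl[:, i] to kl[i].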