From ff98a07cc5250f1445c48b1d2d266f2f578f2f0f Mon Sep 17 00:00:00 2001
From: federico-carrara
Date: Mon, 18 Nov 2024 16:37:13 +0100
Subject: [PATCH 1/5] fix: changed the order of free-bits and rescaling, which
 was previously wrong

---
 src/careamics/losses/lvae/losses.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/careamics/losses/lvae/losses.py b/src/careamics/losses/lvae/losses.py
index 310d0bb80..e0cd59e8f 100644
--- a/src/careamics/losses/lvae/losses.py
+++ b/src/careamics/losses/lvae/losses.py
@@ -168,6 +168,9 @@ def get_kl_divergence_loss(
         dim=1,
     )  # shape: (B, n_layers)
 
+    # Apply free bits (& batch average)
+    kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
+
     # In 3D case, rescale by Z dim
     # TODO If we have downsampling in Z dimension, then this needs to change.
     if len(img_shape) == 3:
@@ -175,23 +178,20 @@ def get_kl_divergence_loss(
 
     # Rescaling
     if rescaling == "latent_dim":
-        for i in range(kl.shape[1]):
+        for i in range(len(kl)):
             latent_dim = topdown_data["z"][i].shape[1:]
             norm_factor = np.prod(latent_dim)
-            kl[:, i] = kl[:, i] / norm_factor
+            kl[i] = kl[i] / norm_factor
     elif rescaling == "image_dim":
         kl = kl / np.prod(img_shape[-2:])
 
-    # Apply free bits
-    kl_loss = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
-
     # Aggregation
     if aggregation == "mean":
-        kl_loss = kl_loss.mean()  # shape: (1,)
+        kl = kl.mean()  # shape: (1,)
     elif aggregation == "sum":
-        kl_loss = kl_loss.sum()  # shape: (1,)
+        kl = kl.sum()  # shape: (1,)
 
-    return kl_loss
+    return kl
 
 
 def _get_kl_divergence_loss_musplit(

From 4207f3843d3a7c540ec531bcc7395b7c2ed42fb3 Mon Sep 17 00:00:00 2001
From: federico-carrara
Date: Mon, 18 Nov 2024 18:17:18 +0100
Subject: [PATCH 2/5] tmp: hardcoding kl_type in musplit KL loss to "kl",
 since it is always like this

---
 src/careamics/losses/lvae/losses.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/careamics/losses/lvae/losses.py b/src/careamics/losses/lvae/losses.py
index e0cd59e8f..4258b2f48 100644
--- a/src/careamics/losses/lvae/losses.py
+++ b/src/careamics/losses/lvae/losses.py
@@ -220,7 +220,7 @@ def _get_kl_divergence_loss_musplit(
         The KL divergence loss for the muSplit case. Shape is (1, ).
     """
     return get_kl_divergence_loss(
-        kl_type=kl_type,
+        kl_type="kl", # TODO: hardcoded, deal in future PR
         topdown_data=topdown_data,
         rescaling="latent_dim",
         aggregation="mean",

From 71d61904a66b870ad819fe45d3e010f602e82c07 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 18 Nov 2024 17:19:22 +0000
Subject: [PATCH 3/5] style(pre-commit.ci): auto fixes [...]

---
 src/careamics/losses/lvae/losses.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/careamics/losses/lvae/losses.py b/src/careamics/losses/lvae/losses.py
index 4258b2f48..9514846c3 100644
--- a/src/careamics/losses/lvae/losses.py
+++ b/src/careamics/losses/lvae/losses.py
@@ -220,7 +220,7 @@ def _get_kl_divergence_loss_musplit(
         The KL divergence loss for the muSplit case. Shape is (1, ).
     """
     return get_kl_divergence_loss(
-        kl_type="kl", # TODO: hardcoded, deal in future PR
+        kl_type="kl",  # TODO: hardcoded, deal in future PR
         topdown_data=topdown_data,
         rescaling="latent_dim",
         aggregation="mean",

From c3aece1b23e624b448f5334a266dba725c370999 Mon Sep 17 00:00:00 2001
From: federico-carrara
Date: Tue, 19 Nov 2024 11:04:06 +0100
Subject: [PATCH 4/5] rm: removed duplicated function

---
 src/careamics/models/lvae/utils.py | 39 +-----------------------------
 1 file changed, 1 insertion(+), 38 deletions(-)

diff --git a/src/careamics/models/lvae/utils.py b/src/careamics/models/lvae/utils.py
index 1089932a9..cf520aa62 100644
--- a/src/careamics/models/lvae/utils.py
+++ b/src/careamics/models/lvae/utils.py
@@ -401,41 +401,4 @@ def kl_normal_mc(z, p_mulv, q_mulv):
     p_distrib = Normal(p_mu.get(), p_std)
     q_distrib = Normal(q_mu.get(), q_std)
 
-    return q_distrib.log_prob(z) - p_distrib.log_prob(z)
-
-
-def free_bits_kl(
-    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
-) -> torch.Tensor:
-    """
-    Computes free-bits version of KL divergence.
-    Ensures that the KL doesn't go to zero for any latent dimension.
-    Hence, it contributes to use latent variables more efficiently,
-    leading to better representation learning.
-
-    NOTE:
-    Takes in the KL with shape (batch size, layers), returns the KL with
-    free bits (for optimization) with shape (layers,), which is the average
-    free-bits KL per layer in the current batch.
-    If batch_average is False (default), the free bits are per layer and
-    per batch element. Otherwise, the free bits are still per layer, but
-    are assigned on average to the whole batch. In both cases, the batch
-    average is returned, so it's simply a matter of doing mean(clamp(KL))
-    or clamp(mean(KL)).
-
-    Args:
-        kl (torch.Tensor)
-        free_bits (float)
-        batch_average (bool, optional))
-        eps (float, optional)
-
-    Returns
-    -------
-    The KL with free bits
-    """
-    assert kl.dim() == 2
-    if free_bits < eps:
-        return kl.mean(0)
-    if batch_average:
-        return kl.mean(0).clamp(min=free_bits)
-    return kl.clamp(min=free_bits).mean(0)
+    return q_distrib.log_prob(z) - p_distrib.log_prob(z)
\ No newline at end of file

From 52ba7ad65660522995228c1b2f07d53906830ec2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 19 Nov 2024 10:05:50 +0000
Subject: [PATCH 5/5] style(pre-commit.ci): auto fixes [...]

---
 src/careamics/models/lvae/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/careamics/models/lvae/utils.py b/src/careamics/models/lvae/utils.py
index cf520aa62..2698dbf5a 100644
--- a/src/careamics/models/lvae/utils.py
+++ b/src/careamics/models/lvae/utils.py
@@ -401,4 +401,4 @@ def kl_normal_mc(z, p_mulv, q_mulv):
     p_distrib = Normal(p_mu.get(), p_std)
     q_distrib = Normal(q_mu.get(), q_std)
 
-    return q_distrib.log_prob(z) - p_distrib.log_prob(z)
\ No newline at end of file
+    return q_distrib.log_prob(z) - p_distrib.log_prob(z)