Skip to content

Commit

Permalink
compute ms ssim in test steps
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 21, 2025
1 parent a65f88d commit a6f26f2
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ml4h/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def __init__(self, name="multi_scale_ssim", **kwargs):

def update_state(self, y_true, y_pred, max_val, sample_weight=None):
# Calculate MS-SSIM for the batch
ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=max_val)
ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=max_val, power_factors=[0.01, 0.2, 0.5, 0.29])
if sample_weight is not None:
ssim = tf.multiply(ssim, sample_weight)

Expand Down
3 changes: 1 addition & 2 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,8 +888,7 @@ def test_step(self, batch):
self.kid.update_state(images, generated_images)
max_pixel_value = tf.reduce_max(tf.abs(images))
max_val = 2 * max_pixel_value # Double the max absolute value
self.ms_ssim.update_state(tf.tile(images, [1, 1, 1, 3]),
tf.tile(generated_images, [1, 1, 1, 3]), max_val)
self.ms_ssim.update_state(images, generated_images, max_val)

return {m.name: m.result() for m in self.metrics}

Expand Down

0 comments on commit a6f26f2

Please sign in to comment.