Skip to content

Commit

Permalink
compute ms ssim in test steps
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 21, 2025
1 parent 0ed7c5d commit 1f439a4
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
26 changes: 26 additions & 0 deletions ml4h/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,3 +833,29 @@ def result(self):
def reset_state(self):
self.is_tracker.reset_state()


class MultiScaleSSIM(keras.metrics.Metric):
def __init__(self, max_val=6.0, name="multi_scale_ssim", **kwargs):
super(MultiScaleSSIM, self).__init__(name=name, **kwargs)
self.max_val = max_val
self.total_ssim = self.add_weight(name="total_ssim", initializer="zeros")
self.count = self.add_weight(name="count", initializer="zeros")

def update_state(self, y_true, y_pred, sample_weight=None):
# Calculate MS-SSIM for the batch
ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=self.max_val)
if sample_weight is not None:
ssim = tf.multiply(ssim, sample_weight)

# Update total MS-SSIM and count
self.total_ssim.assign_add(tf.reduce_sum(ssim))
self.count.assign_add(tf.cast(tf.size(ssim), tf.float32))

def result(self):
# Return the mean MS-SSIM over all batches
return tf.divide(self.total_ssim, self.count)

def reset_states(self):
# Reset the metric state variables
self.total_ssim.assign(0.0)
self.count.assign(0.0)
8 changes: 4 additions & 4 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from keras import layers

from ml4h.defines import IMAGE_EXT
from ml4h.metrics import KernelInceptionDistance, InceptionScore
from ml4h.metrics import KernelInceptionDistance, InceptionScore, MultiScaleSSIM
from ml4h.models.Block import Block
from ml4h.TensorMap import TensorMap

Expand Down Expand Up @@ -669,7 +669,7 @@ def compile(self, **kwargs):
self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss")
if self.input_map.axes() == 3 and self.inspect_model:
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.input_map.shape, kernel_image_size=299)
self.inception_score = InceptionScore(name = "is", input_shape = self.input_map.shape, kernel_image_size=299)
self.ms_ssim = MultiScaleSSIM()

@property
def metrics(self):
Expand All @@ -678,7 +678,7 @@ def metrics(self):
m.append(self.supervised_loss_tracker)
if self.input_map.axes() == 3 and self.inspect_model:
m.append(self.kid)
m.append(self.inception_score)
m.append(self.ms_ssim)
return m

def denormalize(self, images):
Expand Down Expand Up @@ -886,7 +886,7 @@ def test_step(self, batch):
num_images=self.batch_size, diffusion_steps=20
)
self.kid.update_state(images, generated_images)
self.inception_score.update_state(images)
self.ms_ssim.update_state(images, generated_images)

return {m.name: m.result() for m in self.metrics}

Expand Down

0 comments on commit 1f439a4

Please sign in to comment.