update: vae.py to support both tensorflow and torch backends
soumik12345 committed Dec 3, 2023
1 parent dfc5695 commit 3f33056
Showing 1 changed file with 42 additions and 12 deletions.
examples/generative/vae.py (54 changes: 42 additions & 12 deletions)
@@ -13,12 +13,11 @@

 import os

-os.environ["KERAS_BACKEND"] = "tensorflow"
+os.environ["KERAS_BACKEND"] = "tensorflow"  # @param ["tensorflow", "torch"]

 import numpy as np
-import tensorflow as tf
 import keras
-from keras import layers
+from keras import layers, ops

 """
 ## Create a sampling layer
@@ -30,10 +29,10 @@ class Sampling(layers.Layer):

     def call(self, inputs):
         z_mean, z_log_var = inputs
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
-        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
+        batch = ops.shape(z_mean)[0]
+        dim = ops.shape(z_mean)[1]
+        epsilon = keras.random.normal(shape=(batch, dim))
+        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


 """
@@ -90,21 +89,52 @@ def metrics(self):
             self.kl_loss_tracker,
         ]

-    def train_step(self, data):
+    def _tf_train_step(self, data):
+        import tensorflow as tf
+
         with tf.GradientTape() as tape:
             z_mean, z_log_var, z = self.encoder(data)
             reconstruction = self.decoder(z)
-            reconstruction_loss = tf.reduce_mean(
-                tf.reduce_sum(
+            reconstruction_loss = ops.mean(
+                ops.sum(
                     keras.losses.binary_crossentropy(data, reconstruction),
                     axis=(1, 2),
                 )
             )
-            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
-            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
             total_loss = reconstruction_loss + kl_loss
         grads = tape.gradient(total_loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
+        return total_loss, reconstruction_loss, kl_loss
+
+    def _torch_train_step(self, data):
+        import torch
+
+        self.zero_grad()
+        z_mean, z_log_var, z = self.encoder(data)
+        reconstruction = self.decoder(z)
+        reconstruction_loss = ops.mean(
+            ops.sum(
+                keras.losses.binary_crossentropy(data, reconstruction),
+                axis=(1, 2),
+            )
+        )
+        kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+        kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
+        total_loss = reconstruction_loss + kl_loss
+        total_loss.backward()
+        trainable_weights = [v for v in self.trainable_weights]
+        gradients = [v.value.grad for v in trainable_weights]
+        with torch.no_grad():
+            self.optimizer.apply(gradients, trainable_weights)
+        return total_loss, reconstruction_loss, kl_loss
+
+    def train_step(self, data):
+        if keras.backend.backend() == "tensorflow":
+            total_loss, reconstruction_loss, kl_loss = self._tf_train_step(data)
+        elif keras.backend.backend() == "torch":
+            total_loss, reconstruction_loss, kl_loss = self._torch_train_step(data)
         self.total_loss_tracker.update_state(total_loss)
         self.reconstruction_loss_tracker.update_state(reconstruction_loss)
         self.kl_loss_tracker.update_state(kl_loss)
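A note on the dispatch above (not part of the commit): it follows the standard Keras 3 recipe for backend-specific custom training steps. Under TensorFlow, gradients come from a tf.GradientTape; under torch, from loss.backward() followed by optimizer.apply() on the collected .grad values inside torch.no_grad(). A minimal self-contained sketch of the same pattern on a toy regression model (the ToyModel class and random data are illustrative, not taken from vae.py):

import os

os.environ["KERAS_BACKEND"] = "torch"  # or "tensorflow"

import numpy as np
import keras
from keras import layers, ops


class ToyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = layers.Dense(1)

    def call(self, x):
        return self.dense(x)

    def train_step(self, data):
        x, y = data
        if keras.backend.backend() == "tensorflow":
            import tensorflow as tf

            # TF path: record the forward pass on a tape, then differentiate.
            with tf.GradientTape() as tape:
                loss = ops.mean(ops.square(self(x) - y))
            grads = tape.gradient(loss, self.trainable_weights)
            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        else:
            import torch

            # Torch path (assumes KERAS_BACKEND="torch"): clear stale grads,
            # backprop, then apply each variable's .grad with the Keras optimizer.
            self.zero_grad()
            loss = ops.mean(ops.square(self(x) - y))
            loss.backward()
            grads = [v.value.grad for v in self.trainable_weights]
            with torch.no_grad():
                self.optimizer.apply(grads, self.trainable_weights)
        return {"loss": loss}


model = ToyModel()
model.compile(optimizer=keras.optimizers.Adam())
x = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model.fit(x, y, epochs=1, batch_size=16, verbose=0)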
