Migrate vivit tutorial to Keras 3 (all backends)
SuryanarayanaY committed Jan 15, 2024
1 parent 9bb7074 commit 699621c
Showing 1 changed file with 13 additions and 12 deletions.
examples/vision/vivit.py: 25 changes (13 additions, 12 deletions)
@@ -2,7 +2,7 @@
Title: Video Vision Transformer
Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush Thakur](https://twitter.com/ayushthakur0) (equal contribution)
Date created: 2022/01/12
-Last modified: 2022/01/12
+Last modified: 2024/01/15
Description: A Transformer-based architecture for video classification.
Accelerator: GPU
"""
@@ -30,8 +30,8 @@
the embedding scheme and one of the variants of the Transformer
architecture, for simplicity.
-This example requires TensorFlow 2.6 or higher, and the `medmnist`
-package, which can be installed by running the code cell below.
+This example requires `medmnist` package, which can be installed
+by running the code cell below.
"""

"""shell
@@ -48,9 +48,9 @@
import medmnist
import ipywidgets
import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+import tensorflow as tf # for data preprocessing only
+import keras
+from keras import layers, ops

# Setting seed for reproducibility
SEED = 42
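The import hunk above is the heart of the migration: the model code now imports `keras` and `keras.ops` directly instead of going through `tensorflow.keras`, and `tf` is kept only for the `tf.data` input pipeline. A minimal sketch of how a backend would be selected under Keras 3 (not part of this commit; the backend name below is an illustrative choice):

import os

# Assumption for illustration: Keras 3 is installed with the JAX backend available.
# The backend must be chosen before `keras` is imported.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"

import keras
from keras import ops

print(keras.backend.backend())  # name of the active backend, e.g. "jax"
print(ops.arange(5))            # a tensor of whichever backend is active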
@@ -137,7 +137,6 @@ def download_and_prepare_dataset(data_info: dict):
"""


-@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
"""Preprocess the frames tensors and parse the labels."""
# Preprocess images
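Dropping `@tf.function` here is safe because `tf.data.Dataset.map` already traces the mapped function into a graph, so the decorator added nothing. A small sketch of the pattern with dummy data (not the tutorial's exact `preprocess` body):

import numpy as np
import tensorflow as tf  # tf.data only

def preprocess(frames, label):
    # Cast to float32 and add a channel axis; Dataset.map traces this
    # function automatically, so no @tf.function decorator is needed.
    frames = tf.cast(frames[..., tf.newaxis], tf.float32) / 255.0
    label = tf.cast(label, tf.float32)
    return frames, label

dummy_videos = np.random.randint(0, 256, size=(4, 28, 28, 28), dtype="uint8")
dummy_labels = np.random.randint(0, 11, size=(4, 1))
ds = (
    tf.data.Dataset.from_tensor_slices((dummy_videos, dummy_labels))
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2)
)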
@@ -235,7 +234,7 @@ def build(self, input_shape):
self.position_embedding = layers.Embedding(
input_dim=num_tokens, output_dim=self.embed_dim
)
-self.positions = tf.range(start=0, limit=num_tokens, delta=1)
+self.positions = ops.arange(0, num_tokens, 1)

def call(self, encoded_tokens):
# Encode the positions and add it to the encoded tokens
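`ops.arange` produces the position indices as a tensor of whichever backend is active, so the positional encoder no longer depends on `tf.range`. A quick sketch of the equivalence (illustrative only):

import numpy as np
from keras import ops

positions = ops.arange(0, 10, 1)                 # backend-native tensor
positions_np = ops.convert_to_numpy(positions)   # back to NumPy if needed
assert np.array_equal(positions_np, np.arange(10))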
@@ -295,8 +294,8 @@ def create_vivit_classifier(
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
x3 = keras.Sequential(
[
-layers.Dense(units=embed_dim * 4, activation=tf.nn.gelu),
-layers.Dense(units=embed_dim, activation=tf.nn.gelu),
+layers.Dense(units=embed_dim * 4, activation=ops.gelu),
+layers.Dense(units=embed_dim, activation=ops.gelu),
]
)(x3)
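The MLP block now takes its activation from `keras.ops` instead of `tf.nn`; a string activation such as "gelu" would work equally well. A brief sketch with an illustrative embedding size (128 is an assumption for this snippet, not taken from the hunk above):

import keras
from keras import layers, ops

embed_dim = 128  # illustrative value
mlp = keras.Sequential(
    [
        layers.Dense(units=embed_dim * 4, activation=ops.gelu),
        layers.Dense(units=embed_dim, activation=ops.gelu),  # or activation="gelu"
    ]
)
y = mlp(ops.ones((2, 16, embed_dim)))  # (batch, tokens, embed_dim)
print(y.shape)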

@@ -368,11 +367,13 @@ def run_experiment():
for i, (testsample, label) in enumerate(zip(testsamples, labels)):
# Generate gif
with io.BytesIO() as gif:
-imageio.mimsave(gif, (testsample.numpy() * 255).astype("uint8"), "GIF", fps=5)
+imageio.mimsave(
+    gif, (np.squeeze(testsample) * 255).astype("uint8"), "GIF", fps=5
+)
videos.append(gif.getvalue())

# Get model prediction
-output = model.predict(tf.expand_dims(testsample, axis=0))[0]
+output = model.predict(np.expand_dims(testsample, axis=0))[0]
pred = np.argmax(output, axis=0)

ground_truths.append(label.numpy().astype("int"))
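On the prediction side, `np.expand_dims` replaces `tf.expand_dims` so batching a single test sample no longer depends on TensorFlow. A hedged sketch of that step with a dummy model and sample (the tutorial itself uses the trained ViViT model and real test clips):

import numpy as np
import keras
from keras import layers

# Dummy stand-ins for illustration only.
dummy_model = keras.Sequential([layers.Flatten(), layers.Dense(11, activation="softmax")])
testsample = np.random.rand(28, 28, 28, 1).astype("float32")

output = dummy_model.predict(np.expand_dims(testsample, axis=0), verbose=0)[0]
pred = np.argmax(output, axis=0)
print(int(pred))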
