Update example "Train a Vision Transformer on small datasets" from keras 2 to 3 version #1922

Open · wants to merge 3 commits into base: master
29 changes: 9 additions & 20 deletions examples/vision/shiftvit.py
@@ -27,24 +27,15 @@
In this example, we minimally implement the paper with close alignment to the author's
[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).

-This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
-be installed using the following command:
-"""
-"""shell
-pip install -qq -U tensorflow-addons
-"""

"""
## Setup and imports
"""

import numpy as np
import matplotlib.pyplot as plt

-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
+import keras
+from keras import layers

import pathlib
import glob
@@ -280,7 +271,7 @@ def __init__(self, drop_path_prob, **kwargs):
    def call(self, x, training=False):
        if training:
            keep_prob = 1 - self.drop_path_prob
-            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
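
A note on the one-line change above: `len(x.shape)` reads the tensor's static rank directly, instead of building a `tf.shape` op just to count dimensions. A minimal standalone sketch of the same stochastic-depth masking (the function name and values are illustrative, not part of the PR):

```python
import tensorflow as tf

def drop_path(x, drop_path_prob, training=False):
    """Stochastic depth: randomly zero out whole samples during training."""
    if not training or drop_path_prob == 0.0:
        return x  # identity at inference time
    keep_prob = 1.0 - drop_path_prob
    # One Bernoulli draw per sample: shape (batch, 1, 1, ...) broadcasts
    # over all remaining axes. len(x.shape) is the static rank of x.
    shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
    random_tensor = keep_prob + tf.random.uniform(shape, 0.0, 1.0)
    binary_mask = tf.floor(random_tensor)  # 1.0 with prob keep_prob, else 0.0
    # Rescale so the expected value of the output equals x.
    return (x / keep_prob) * binary_mask
```
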
@@ -871,7 +862,7 @@ def get_config(self):
)

# Get the optimizer.
-optimizer = tfa.optimizers.AdamW(
+optimizer = keras.optimizers.AdamW(
learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)
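
The built-in `keras.optimizers.AdamW` is a drop-in replacement for `tfa.optimizers.AdamW` here; one behavioral detail worth noting is that Keras applies the decoupled weight decay to every trainable variable unless told otherwise. A small usage sketch (hyperparameter values and variable-name patterns are illustrative):

```python
import keras

# Decoupled weight decay, as in tfa.optimizers.AdamW (illustrative values).
optimizer = keras.optimizers.AdamW(
    learning_rate=1e-3,  # a LearningRateSchedule such as WarmUpCosine also works
    weight_decay=1e-4,
)
# Optionally exempt variables from decay, e.g. biases and normalization params.
optimizer.exclude_from_weight_decay(var_names=["bias", "gamma", "beta"])
```
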

@@ -913,7 +904,7 @@ def get_config(self):

It can be saved in the TF SavedModel format only. In general, this is also the recommended format for saving models.
"""
-model.save("ShiftViT")
+model.export("ShiftViT")

"""
## Model inference
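
On the `model.save` → `model.export` change in the hunk above: in Keras 3, `model.save()` expects a `.keras` (or legacy `.h5`) path and no longer writes a TF SavedModel directory, while `model.export()` produces a SavedModel intended for serving and inference-only reloading. A sketch of the two options (the `.keras` path is illustrative; reloading the `.keras` file would require the custom objects, e.g. `WarmUpCosine`, to be available):

```python
# Native Keras 3 format: full model, reloadable with keras.models.load_model().
model.save("shiftvit.keras")

# TF SavedModel export, reloadable for inference via keras.layers.TFSMLayer.
model.export("ShiftViT")
```
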
@@ -932,12 +923,10 @@ def get_config(self):
"""
**Load saved model**
"""
-# Custom objects are not included when the model is saved.
-# At loading time, these objects need to be passed for reconstruction of the model.
-saved_model = tf.keras.models.load_model(
-    "ShiftViT",
-    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
-)
+saved_layer = keras.layers.TFSMLayer("ShiftViT")
+inputs = tf.keras.Input(shape=(config.input_shape))  # specify your input shape
+outputs = saved_layer(inputs)
+saved_model = tf.keras.Model(inputs, outputs)
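
`TFSMLayer` wraps the exported SavedModel as an inference-only layer (training state and custom objects are not restored), which is why the snippet above rebuilds a small functional model around it. A hedged sketch of calling the reconstructed model (batch size and dtype are illustrative):

```python
import numpy as np

# TFSMLayer invokes the SavedModel's serving signature; pass
# call_endpoint="serving_default" explicitly if the default differs.
batch = np.random.rand(4, *config.input_shape).astype("float32")
logits = saved_model.predict(batch)  # shape: (4, num_classes)
```
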

"""
**Utility functions for inference**
17 changes: 5 additions & 12 deletions examples/vision/vit_small_ds.py
@@ -35,13 +35,7 @@
example is inspired from
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).

-_Note_: This example requires TensorFlow 2.6 or higher, as well as
-[TensorFlow Addons](https://www.tensorflow.org/addons), which can be
-installed using the following command:
-
-```python
-pip install -qq -U tensorflow-addons
-```
+_Note_: This example requires Keras 3 or higher.
"""
"""
## Setup
@@ -50,10 +44,9 @@
import math
import numpy as np
import tensorflow as tf
-from tensorflow import keras
-import tensorflow_addons as tfa
+import keras
import matplotlib.pyplot as plt
-from tensorflow.keras import layers
+from keras import layers

# Setting seed for reproducibility
SEED = 42
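
If the example is being moved fully onto the Keras 3 API, seeding could also go through the unified helper rather than per-library calls; a sketch under that assumption:

```python
import keras

SEED = 42
# Seeds Python's `random` module, NumPy, and the backend framework at once.
keras.utils.set_random_seed(SEED)
```
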
@@ -355,7 +348,7 @@ def call(self, encoded_patches):
"""


-class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
+class MultiHeadAttentionLSA(keras.layers.MultiHeadAttention):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# The trainable temperature term. The initial value is
@@ -499,7 +492,7 @@ def run_experiment(model):
warmup_steps=warmup_steps,
)

-    optimizer = tfa.optimizers.AdamW(
+    optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
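
Beyond swapping in `keras.optimizers.AdamW`, the custom warmup-cosine schedule could arguably be replaced as well: recent Keras versions ship `keras.optimizers.schedules.CosineDecay` with built-in linear warmup via `warmup_target`/`warmup_steps`. A sketch under that assumption (step counts and rates are illustrative):

```python
import keras

# Linear warmup from 0 to warmup_target, then cosine decay over decay_steps.
scheduled_lrs = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=9000,
    warmup_target=1e-3,
    warmup_steps=1000,
)
optimizer = keras.optimizers.AdamW(
    learning_rate=scheduled_lrs, weight_decay=1e-4
)
```
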
