This repository hosts code for converting the original MLP-Mixer models [1] (JAX) to TensorFlow. The converted models are hosted on TensorFlow Hub and can be found here: https://tfhub.dev/sayakpaul/collections/mlp-mixer/1.
Note that TensorFlow 2.6 or greater is required to use the converted models.
Several model variants are available, each in both a classification and a feature-extractor format (the difference between the two formats is sketched right after the list):
- B/16 (classification, feature-extractor)
- B/32 (classification, feature-extractor)
- L/16 (classification, feature-extractor)
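As a minimal sketch of the difference: a classification model outputs ImageNet-1k logits, while a feature-extractor model outputs a pooled embedding you can build on. The model handles below are illustrative assumptions; pick the exact URLs from the TF Hub collection linked above.

```python
import tensorflow as tf
import tensorflow_hub as hub

# NOTE: illustrative handles, not guaranteed to be exact; consult the
# TF Hub collection for the real model URLs.
classifier = hub.KerasLayer("https://tfhub.dev/sayakpaul/mixer_b16_i1k_classification/1")
feature_extractor = hub.KerasLayer("https://tfhub.dev/sayakpaul/mixer_b16_i1k_fe/1")

# Dummy 224x224 RGB batch; check each model page for the expected input range.
images = tf.random.uniform((1, 224, 224, 3))

print(classifier(images).shape)         # (1, 1000) -> ImageNet-1k logits
print(feature_extractor(images).shape)  # (1, 768)  -> pooled embedding (B/16)
```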
For more details on the training protocols, please refer to [1, 3].
The original model classes and weights [4] were converted using the jax2tf tool [5]. For details on the conversion process, please refer to the `conversion.ipynb` notebook.
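For a rough idea of what such a conversion looks like, here is a simplified sketch loosely following the public jax2tf SavedModel examples, not the notebook's exact code. `mixer_apply` and `params` are placeholders for the original model's apply function (assumed signature `apply(params, images)`) and its pre-trained weight pytree.

```python
import tensorflow as tf
from jax.experimental import jax2tf

def export_saved_model(mixer_apply, params, export_dir, image_size=224):
    # Convert the JAX function into one that accepts and returns tf.Tensors.
    tf_fn = jax2tf.convert(mixer_apply)

    # Wrap the weight pytree in tf.Variables so they get tracked and saved.
    param_vars = tf.nest.map_structure(tf.Variable, params)

    module = tf.Module()
    module._variables = tf.nest.flatten(param_vars)
    module.predict = tf.function(
        lambda images: tf_fn(param_vars, images),
        input_signature=[
            tf.TensorSpec([None, image_size, image_size, 3], tf.float32)
        ],
        autograph=False,
    )
    tf.saved_model.save(module, export_dir)
```

The exported artifact can then be reloaded with `tf.saved_model.load(export_dir)`, with inference running entirely in TensorFlow.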
I independently validated two of the models on the ImageNet-1k validation set. The table below reports their top-1 accuracies along with links to the corresponding tensorboard.dev logs.
| Model | Top-1 Accuracy | tensorboard.dev |
| --- | --- | --- |
| B-16 fine-tuned on ImageNet-1k | 75.31% | Link |
| B-16 pre-trained on ImageNet-1k using SAM | 75.58% | Link |
Here is a tensorboard.dev run that logs the results of fine-tuning this model on the Flowers dataset.
The repository also includes the following notebooks:

- `classification.ipynb`: Shows how to load an MLP-Mixer model from TensorFlow Hub and run image classification.
- `fine-tune.ipynb`: Shows how to fine-tune an MLP-Mixer model from TensorFlow Hub on the `tf_flowers` dataset.
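In the spirit of `fine-tune.ipynb`, here is a condensed sketch of fine-tuning a converted feature extractor on `tf_flowers`. The model handle, input scaling, and hyperparameters are illustrative assumptions; check the model page and the notebook for the real values.

```python
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

IMAGE_SIZE = 224
NUM_CLASSES = 5  # tf_flowers has 5 classes

def preprocess(example):
    image = tf.image.resize(example["image"], (IMAGE_SIZE, IMAGE_SIZE))
    image = image / 127.5 - 1.0  # assumed [-1, 1] scaling; check the model page
    return image, example["label"]

train_ds = (
    tfds.load("tf_flowers", split="train[:90%]")
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(512)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)),
    hub.KerasLayer(
        "https://tfhub.dev/sayakpaul/mixer_b16_i1k_fe/1",  # assumed handle
        trainable=True,  # unfreeze the Mixer backbone for fine-tuning
    ),
    tf.keras.layers.Dense(NUM_CLASSES),
])

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(train_ds, epochs=5)
```

Setting `trainable=True` on the `hub.KerasLayer` fine-tunes the whole backbone; leave it `False` to train only the new classification head.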
[1] MLP-Mixer: An all-MLP Architecture for Vision by Tolstikhin et al.
[2] Sharpness-Aware Minimization for Efficiently Improving Generalization by Foret et al.
[5] jax2tf tool
Thanks to the ML-GDE program for providing GCP credit support that helped me execute the experiments for this project.