MLPMixer-jax2tf


[Demo: example usage]

This repository hosts code for converting the original MLP-Mixer models [1] (JAX) to TensorFlow. The converted models are hosted on TensorFlow Hub and can be found here: https://tfhub.dev/sayakpaul/collections/mlp-mixer/1.

Note that TensorFlow 2.6 or greater is required to use the converted models.
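Since the models live on TF Hub, loading one takes only a few lines once the version requirement is met. The sketch below is illustrative, not part of this repository: the `meets_min_version` helper is hypothetical, and the TF Hub call is shown in comments (a concrete model handle from the collection URL above must be substituted for the elided `...`).

```python
# Minimal sketch: guard on the TensorFlow >= 2.6 requirement before loading
# a converted model. `meets_min_version` is an illustrative helper, not part
# of this repository.
def meets_min_version(version: str, minimum: str = "2.6") -> bool:
    """Compare dotted version strings numerically on (major, minor)."""
    parse = lambda v: tuple(int(p) for p in v.split(".")[:2])
    return parse(version) >= parse(minimum)

# With tensorflow and tensorflow_hub installed, loading would look like:
#
#   import tensorflow as tf
#   import tensorflow_hub as hub
#
#   assert meets_min_version(tf.__version__)
#   model = hub.load("https://tfhub.dev/sayakpaul/...")  # pick a handle
#                                                        # from the collection
```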

Several model variants are available:

  • SAM [2] pre-trained (pre-trained on ImageNet-1k)
  • ImageNet-1k fine-tuned
  • ImageNet-21k pre-trained

For more details on the training protocols, please follow [1, 3].

The original model classes and weights [4] were converted using the jax2tf tool [5]. For details on the conversion process, please refer to the conversion.ipynb notebook.

I independently validated two models on the ImageNet-1k validation set. The table below reports the top-1 accuracies along with their respective logs from tensorboard.dev.

Model                                       Top-1 Accuracy   tb.dev link
B-16 fine-tuned on ImageNet-1k              75.31%           Link
B-16 pre-trained on ImageNet-1k using SAM   75.58%           Link

Here is a tensorboard.dev run that logs fine-tuning results (using this model) for the Flowers dataset.

Other notebooks

  • classification.ipynb: Shows how to load a Vision Transformer model from TensorFlow Hub and run image classification.
  • fine-tune.ipynb: Shows how to fine-tune a Vision Transformer model from TensorFlow Hub on the tf_flowers dataset.
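Both notebooks feed images to the Hub models after scaling pixel values. Assuming these models follow the convention of the original JAX Vision Transformer / MLP-Mixer code of inputs scaled to [-1, 1] (verify against the model's TF Hub page; this is an assumption, and `scale_pixels` is an illustrative helper, not from the notebooks), the core preprocessing step reduces to:

```python
# Sketch of input preprocessing, assuming the models expect pixel values
# scaled to [-1, 1] (the convention of the original JAX codebase; check the
# model's TF Hub page before relying on this).
def scale_pixels(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]
```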

References

[1] MLP-Mixer: An all-MLP Architecture for Vision by Tolstikhin et al.

[2] Sharpness-Aware Minimization for Efficiently Improving Generalization by Foret et al.

[3] When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations by Chen et al.

[4] Vision Transformer GitHub

[5] jax2tf tool

Acknowledgements

Thanks to the ML-GDE program for providing the GCP credit support that helped me run the experiments for this project.
