diff --git a/openfl-tutorials/experimental/Vision_Transformer/Workflow_Interface_102_Vision_Transformer.ipynb b/openfl-tutorials/experimental/Vision_Transformer/Workflow_Interface_102_Vision_Transformer.ipynb new file mode 100644 index 00000000000..571d2b2ea34 --- /dev/null +++ b/openfl-tutorials/experimental/Vision_Transformer/Workflow_Interface_102_Vision_Transformer.ipynb @@ -0,0 +1,690 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "79a3d106", + "metadata": {}, + "source": [ + "# Workflow Interface 102: \n", + "# Vision Transformer for Image Classification using MedMNIST\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Vision_Transformer/Workflow_Interface_102_Vision_Transformer.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "dc0ccb0a", + "metadata": {}, + "source": [ + "Introduced in the seminal paper \"Attention Is All You Need\", transformers have revolutionized natural language processing by using self-attention mechanisms to capture global dependencies in textual data. Leveraging this, Dosovitskiy et al. introduced one of the first successful and empirically validated pure transformer models for image classification in [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929v2). \n", + "\n", + "\n", + "| | \n", + "|:--:| \n", + "| *[source](https://arxiv.org/abs/2010.11929v2)* |\n", + "\n", + "In contrast to traditional convolutional neural networks, which focus on capturing local image features within a spatial window using a sliding filter, the self-attention mechanism enables vision transformers to capture global relationships between image patches. \n", + "\n", + "In this tutorial, you will learn how to set up a horizontal federated learning workflow using the OpenFL Experimental Workflow Interface to train a vision transformer to classify images from the MedMNIST dataset. This notebook expands on the use case from the [first](https://github.com/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb) quick start notebook. 
Its objective is to demonstrate how a user can modify the workflow interface for different use cases." + ] + }, + { + "cell_type": "markdown", + "id": "ff6e97a0", + "metadata": {}, + "source": [ + "# Getting Started" + ] + }, + { + "cell_type": "markdown", + "id": "7085394d", + "metadata": {}, + "source": [ + "First, we install the necessary dependencies for the workflow interface and the vision transformer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2504d13c", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install git+https://github.com/intel/openfl.git\n", + "# !pip install -r requirements_workflow_interface.txt\n", + "# !pip install -r requirements_vision_transformer.txt\n", + "\n", + "# Uncomment this if running in Google Colab\n", + "#!pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/Vision_Transformer/requirements_workflow_interface.txt\n", + "#!pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/Vision_Transformer/requirements_vision_transformer.txt\n", + "\n", + "#import os\n", + "#os.environ[\"USERNAME\"] = \"colab\"" + ] + }, + { + "cell_type": "markdown", + "id": "c6ff2f9b", + "metadata": {}, + "source": [ + "# Setting up the experiment" + ] + }, + { + "cell_type": "markdown", + "id": "605dd5ca", + "metadata": {}, + "source": [ + "If you are familiar with a standard deep learning training pipeline, you may recognize that this section demonstrates many familiar steps, such as setting up your data and defining your dataloader, model, parameters, helper functions, etc." + ] + }, + { + "cell_type": "markdown", + "id": "796fe058", + "metadata": {}, + "source": [ + "We start by importing the [MedMNIST](https://github.com/MedMNIST/MedMNIST/) package and defining our dataset. This cell will provide information about the package and list the available datasets. We will use the PathMNIST dataset. This is a colon pathology dataset comprised of 107,180 unique 2D images. We will train our vision transformer to classify an individual image as one of 9 classes.\n", + "\n", + "| | \n", + "|:--:| \n", + "| *Sample of images [(source)](https://medmnist.com/)* |\n", + "\n", + "Set `data_flag` to choose a different dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0c87319", + "metadata": {}, + "outputs": [], + "source": [ + "# https://github.com/MedMNIST/MedMNIST/blob/main/examples/getting_started.ipynb\n", + "import medmnist\n", + "from medmnist import INFO, Evaluator\n", + "\n", + "print(f\"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}\")\n", + "\n", + "print('\\n---- List of Available datasets ----\\n')\n", + "for key in INFO:\n", + " print(key)\n", + " \n", + "print('\\n------------------------------------\\n')\n", + "\n", + "data_flag = 'pathmnist'\n", + "print(f'Chosen dataset: {data_flag}')\n", + "\n", + "info = INFO[data_flag]\n", + "task = info['task']\n", + "n_channels = info['n_channels']\n", + "n_classes = len(info['label'])\n", + "\n", + "DataClass = getattr(medmnist, info['python_class'])" + ] + }, + { + "cell_type": "markdown", + "id": "2ed5bba7", + "metadata": {}, + "source": [ + "Next, we will load our dataset and prepare it to be consumed by our model. We will be using the Hugging Face Transformers library's implementation of the [vision transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit) pretrained on ImageNet-21k as the backbone of our network. 
To that end, we will use `ViTImageProcessor`, which will provide the proper parameters needed to process and transform our dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4cb74c7", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import DataLoader, Subset\n", + "from transformers import ViTImageProcessor\n", + "\n", + "import time\n", + "import numpy as np\n", + "\n", + "# preprocessing\n", + "processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')\n", + "\n", + "image_mean, image_std = processor.image_mean, processor.image_std\n", + "h = processor.size[\"height\"]\n", + "w = processor.size[\"width\"]\n", + "\n", + "train_transforms = transforms.Compose([\n", + " transforms.Resize([h, w]),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=image_mean, std=image_std)\n", + " ])\n", + "\n", + "test_transforms = transforms.Compose([\n", + " transforms.Resize([h, w]),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=image_mean, std=image_std)\n", + " ])\n", + "\n", + "\n", + "# load the data\n", + "medmnist_train = DataClass(split='train', transform=train_transforms, download=True)\n", + "medmnist_test = DataClass(split='test', transform=test_transforms, download=True)\n", + "\n", + "# For demonstration purposes, we take a subset to reduce overall size and training time\n", + "##################\n", + "subset_indices = range(320)\n", + "medmnist_train = Subset(medmnist_train, subset_indices)\n", + "medmnist_test = Subset(medmnist_test, subset_indices)\n", + "##################" + ] + }, + { + "cell_type": "markdown", + "id": "0e1bebe0", + "metadata": {}, + "source": [ + "We now define our network and inference function. As previously noted, our network will use a pretrained vision transformer backbone, `ViTModel`. 
We add a custom classification head, which will enable us to fine-tune our model on the chosen PathMNIST dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff3b4fc2", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from transformers import ViTModel\n", + "\n", + "\n", + "class CustomVisionTransformer(nn.Module):\n", + " def __init__(self, num_classes):\n", + " super(CustomVisionTransformer, self).__init__()\n", + " self.backbone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')\n", + " self.classifier = nn.Linear(self.backbone.config.hidden_size, num_classes)\n", + "\n", + " def forward(self, x):\n", + " # Extract features from the transformer\n", + " features = self.backbone(x)\n", + " # Take the hidden state from the [CLS] token\n", + " cls_token = features.last_hidden_state[:, 0, :]\n", + " # Pass it through the classification head\n", + " logits = self.classifier(cls_token)\n", + " return logits\n", + " \n", + " \n", + "def inference(model, test_loader, criterion):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " test_loss = 0\n", + "\n", + " with torch.no_grad():\n", + " for data, labels in test_loader:\n", + " outputs = model(data)\n", + " test_loss += criterion(outputs, labels.flatten())\n", + " \n", + " _, predicted = torch.max(outputs, 1)\n", + " \n", + " correct += (predicted == labels.flatten()).sum().item()\n", + " \n", + " test_loss /= len(test_loader.dataset)\n", + "\n", + " accuracy = float(correct / len(test_loader.dataset))\n", + " return accuracy" + ] + }, + { + "cell_type": "markdown", + "id": "68741136", + "metadata": {}, + "source": [ + "# Setting up the OpenFL Workflow Interface" + ] + }, + { + "cell_type": "markdown", + "id": "2194cccb", + "metadata": {}, + "source": [ + "We will now set up the experimental OpenFL workflow interface in order to fine-tune our model in a horizontal federated learning framework. We import the `FLSpec`, `LocalRuntime`, and placement decorators.\n", + "\n", + "- `FLSpec` – Defines the flow specification. User defined flows are subclasses of this.\n", + "- `Runtime` – Defines where the flow runs, infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.\n", + "- `aggregator/collaborator` - placement decorators that define where the task will be assigned" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abf44fa5", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "from openfl.experimental.interface import FLSpec, Aggregator, Collaborator\n", + "from openfl.experimental.runtime import LocalRuntime\n", + "from openfl.experimental.placement import aggregator, collaborator\n", + "\n", + "\n", + "def FedAvg(models, weights=None):\n", + " new_model = models[0]\n", + " state_dicts = [model.state_dict() for model in models]\n", + " state_dict = new_model.state_dict()\n", + " for key in models[1].state_dict():\n", + " state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],\n", + " axis=0, \n", + " weights=weights))\n", + " new_model.load_state_dict(state_dict)\n", + " return new_model" + ] + }, + { + "cell_type": "markdown", + "id": "5247407f", + "metadata": {}, + "source": [ + "Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow, that every workflow begins with `start` and concludes with the `end` task. 
The aggregator begins with a base model and optimizer. The aggregator begins the flow with the `start` task, where the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and is then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`. The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the `start` function on the aggregator to the `aggregated_model_validation` task on the collaborator. Where the tasks run is determined by the placement decorator that precedes each task definition (`@aggregator` or `@collaborator`). Once each of the collaborators (defined in the runtime) complete the `aggregated_model_validation` task, they pass their current state onto the `train` task, from `train` to `local_model_validation`, and then finally to `join` at the aggregator. It is in `join` that an average is taken of the model weights, and the next round can begin. Throughout the process, we will save out the collaborator models as well as the final aggregated model.\n", + "\n", + "| | \n", + "|:--:| \n", + "| *General OpenFL Workflow Interface architecture* |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdb4895c", + "metadata": {}, + "outputs": [], + "source": [ + "class FederatedFlow(FLSpec):\n", + " def __init__(self, model, optimizer, criterion, rounds=2, epochs=3, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.model = model\n", + " self.optimizer = optimizer\n", + " self.criterion = criterion\n", + " self.rounds = rounds\n", + " self.epochs = epochs\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " print(f'Performing initialization for model')\n", + " self.collaborators = self.runtime.collaborators\n", + " self.private = 10\n", + " self.current_round = 0\n", + " self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])\n", + "\n", + " @collaborator\n", + " def aggregated_model_validation(self):\n", + " print(f'Round: {self.current_round+1}\\n-------------------------------')\n", + " print(f'Performing aggregated model validation for collaborator {self.input}')\n", + " self.agg_validation_score = inference(self.model, self.test_loader, self.criterion)\n", + " print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " if not os.path.exists(os.path.join('weights',f'{self.input}')):\n", + " os.makedirs(os.path.join('weights',f'{self.input}'))\n", + " \n", + " best_acc = 0.0\n", + " \n", + " print(f\"{self.input}\")\n", + " for t in range(self.epochs):\n", + " for phase in ['train', 'val']:\n", + " \n", + " if phase == 'train':\n", + " self.model.train()\n", + " self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.5)\n", + " train_loss = 0.0\n", + "\n", + " for batch_idx, (images, labels) in enumerate(self.train_loader):\n", + " self.optimizer.zero_grad()\n", + " outputs = self.model(images)\n", + "\n", + " loss = self.criterion(outputs, labels.flatten())\n", + " loss.backward()\n", + " self.optimizer.step()\n", + "\n", + " train_loss += loss.item() * images.size(0)\n", + " data_size = len(self.train_loader)*images.size(0)\n", + " \n", + " else:\n", + " self.local_validation_score = inference(self.model, self.test_loader, self.criterion)\n", + " \n", + " self.loss = train_loss/data_size\n", + " print(f'Epoch {t+1} | Train Loss: {self.loss:.4f} | Local 
Acc: {self.local_validation_score:.4f}')\n", + "\n", + " if phase == 'val' and self.local_validation_score > best_acc:\n", + " best_acc = self.local_validation_score\n", + " torch.save(self.model.state_dict(), os.path.join('weights', f'{self.input}','model.pth'))\n", + " torch.save(self.optimizer.state_dict(), os.path.join('weights', f'{self.input}','optimizer.pth'))\n", + " \n", + " self.training_completed = True\n", + " self.next(self.local_model_validation)\n", + " \n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model,self.test_loader, self.criterion)\n", + " print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " self.next(self.join, exclude=['training_completed'])\n", + "\n", + " @aggregator\n", + " def join(self,inputs):\n", + " self.average_loss = sum(input.loss for input in inputs)/len(inputs)\n", + " self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)\n", + " self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + " self.model = FedAvg([input.model for input in inputs])\n", + " self.optimizer = [input.optimizer for input in inputs][0]\n", + " \n", + " torch.save(self.model.state_dict(), os.path.join('weights', 'aggregated_model.pth'))\n", + " torch.save(self.optimizer.state_dict(), os.path.join('weights', 'aggregated_optimizer.pth'))\n", + " \n", + " self.current_round += 1\n", + " if self.current_round < self.rounds:\n", + " self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n", + " else:\n", + " self.next(self.end)\n", + " \n", + " @aggregator\n", + " def end(self):\n", + " print(f'This is the end of the flow') " + ] + }, + { + "cell_type": "markdown", + "id": "34923542", + "metadata": {}, + "source": [ + "You'll notice in the `FederatedFlow` definition above that there were certain attributes that the flow was not initialized with, namely the `train_loader` and `test_loader` for each of the collaborators. These are **private_attributes** that are exposed only through the runtime. Each participant has its own set of private attributes: a dictionary where the key is the attribute name, and the value is the object that will be made accessible through that participant's task. \n", + "\n", + "Below, we segment shards of the PathMNIST dataset for **four collaborators**: Portland, Seattle, Chandler, and Bangalore. Each has its own slice of the dataset that's accessible via the `train_loader` or `test_loader` attribute. Note that the private attributes are flexible, and you can choose to pass in a completely different type of object to any of the collaborators or aggregator (with an arbitrary name). These private attributes will always be filtered out of the current state when transferring from collaborator to aggregator, or vice versa. \n", + "\n", + "You'll see that, for the sake of this demonstration, we simply sample an even amount of data from our main dataset and assign it to each collaborator. 
It is also here that we define `BATCH_SIZE`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b887419", + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_SIZE = 8\n", + "\n", + "# Setup participants\n", + "aggregator = Aggregator()\n", + "aggregator.private_attributes = {}\n", + "\n", + "# Setup collaborators with private attributes\n", + "collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']\n", + "collaborators = [Collaborator(name=name) for name in collaborator_names]\n", + "\n", + "for idx, collaborator in enumerate(collaborators):\n", + " train_subset_indices = np.array(range(idx,len(medmnist_train),len(collaborators)))\n", + " local_train = Subset(medmnist_train, train_subset_indices)\n", + " \n", + " test_subset_indices = np.array(range(idx,len(medmnist_test),len(collaborators)))\n", + " local_test = Subset(medmnist_test, test_subset_indices)\n", + " collaborator.private_attributes = {\n", + " 'train_loader': DataLoader(dataset=local_train, batch_size=BATCH_SIZE, shuffle=True),\n", + " 'test_loader': DataLoader(dataset=local_test, batch_size=BATCH_SIZE, shuffle=True)\n", + " }\n", + "\n", + "local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')\n", + "print(f'Local runtime collaborators = {local_runtime.collaborators}')" + ] + }, + { + "cell_type": "markdown", + "id": "a96af773", + "metadata": {}, + "source": [ + "Now that we have our flow and runtime defined, let's run the experiment!\n", + "\n", + "We will begin by defining a base model, optimizer, and loss function that'll be used by each collaborator. You may also define the number of rounds and epochs here if you do not wish to use the default values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50da077d", + "metadata": {}, + "outputs": [], + "source": [ + "model = CustomVisionTransformer(n_classes)\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "flflow = FederatedFlow(model=model, optimizer=optimizer, criterion=criterion)\n", + "flflow.runtime = local_runtime\n", + "flflow.run()" + ] + }, + { + "cell_type": "markdown", + "id": "3b5f9315", + "metadata": {}, + "source": [ + "Now that the flow has completed, let's get the final model and accuracy. Note that the reported aggregated model accuracy was computed prior to the final round of training. However, the model saved during the `join` task is the final aggregated model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef943794", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'Sample of the final model weights: {flflow.model.state_dict()[\"classifier.weight\"][0]}')\n", + "\n", + "print(f'\\nFinal aggregated model accuracy for {flflow.rounds} rounds of training: {flflow.aggregated_model_accuracy}')" + ] + }, + { + "cell_type": "markdown", + "id": "20f65b76", + "metadata": {}, + "source": [ + "We can get the final model and all other aggregator attributes after the flow completes. But what if there's an intermediate model task and its specific output that we want to look at in detail? 
This is where **checkpointing** and reuse of Metaflow tooling come in handy.\n", + "\n", + "Let's make a tweak to the flow object, and run the experiment one more time (we can even use our previous model / optimizer as a base for the experiment)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bf07f3e", + "metadata": {}, + "outputs": [], + "source": [ + "flflow2 = FederatedFlow(model=flflow.model, optimizer=flflow.optimizer, criterion=flflow.criterion, \n", + " checkpoint=True)\n", + "\n", + "flflow2.runtime = local_runtime\n", + "flflow2.run()" + ] + }, + { + "cell_type": "markdown", + "id": "8eee9a9b", + "metadata": {}, + "source": [ + "Now that the flow is complete, let's dig into some of the information captured along the way. \n", + "\n", + "**Note:** this required `checkpoint=True` to be set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aeae05d4", + "metadata": {}, + "outputs": [], + "source": [ + "run_id = flflow2._run_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fa68057", + "metadata": {}, + "outputs": [], + "source": [ + "import metaflow\n", + "from metaflow import Metaflow, Flow, Task, Step" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e33f22b", + "metadata": {}, + "outputs": [], + "source": [ + "m = Metaflow()\n", + "list(m)" + ] + }, + { + "cell_type": "markdown", + "id": "a2091889", + "metadata": {}, + "source": [ + "Let's look at the latest run that generated some results:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90ee1fea", + "metadata": {}, + "outputs": [], + "source": [ + "f = Flow('FederatedFlow').latest_run\n", + "list(f)" + ] + }, + { + "cell_type": "markdown", + "id": "2288f8b8", + "metadata": {}, + "source": [ + "And its list of steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de42dca6", + "metadata": {}, + "outputs": [], + "source": [ + "s = Step(f'FederatedFlow/{run_id}/train')\n", + "list(s)" + ] + }, + { + "cell_type": "markdown", + "id": "a902ced0", + "metadata": {}, + "source": [ + "Now we see **4x** steps: **4** collaborators each performed **x** rounds of model training " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "854bd1fb", + "metadata": {}, + "outputs": [], + "source": [ + "t = Task(f'FederatedFlow/{run_id}/train/3')\n", + "t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "769eb896", + "metadata": {}, + "outputs": [], + "source": [ + "t.data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7596d3a", + "metadata": {}, + "outputs": [], + "source": [ + "t.data.model" + ] + }, + { + "cell_type": "markdown", + "id": "7cad8ede", + "metadata": {}, + "source": [ + "Now let's look at its log output (stdout)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9fd662f", + "metadata": {}, + "outputs": [], + "source": [ + "print(t.stdout)" + ] + }, + { + "cell_type": "markdown", + "id": "ed71204c", + "metadata": {}, + "source": [ + "And any error logs? 
(stderr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8433dbd8", + "metadata": {}, + "outputs": [], + "source": [ + "print(t.stderr)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "openfl_ViT", + "language": "python", + "name": "openfl_vit" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/openfl-tutorials/experimental/Vision_Transformer/images/pathmnist.png b/openfl-tutorials/experimental/Vision_Transformer/images/pathmnist.png new file mode 100644 index 00000000000..f03feeda9d0 Binary files /dev/null and b/openfl-tutorials/experimental/Vision_Transformer/images/pathmnist.png differ diff --git a/openfl-tutorials/experimental/Vision_Transformer/images/vision_transformer.png b/openfl-tutorials/experimental/Vision_Transformer/images/vision_transformer.png new file mode 100644 index 00000000000..3d13b917284 Binary files /dev/null and b/openfl-tutorials/experimental/Vision_Transformer/images/vision_transformer.png differ diff --git a/openfl-tutorials/experimental/Vision_Transformer/images/workflow.png b/openfl-tutorials/experimental/Vision_Transformer/images/workflow.png new file mode 100644 index 00000000000..73dbb0c333a Binary files /dev/null and b/openfl-tutorials/experimental/Vision_Transformer/images/workflow.png differ diff --git a/openfl-tutorials/experimental/Vision_Transformer/requirements_vision_transformer.txt b/openfl-tutorials/experimental/Vision_Transformer/requirements_vision_transformer.txt new file mode 100644 index 00000000000..6fc96a35844 --- /dev/null +++ b/openfl-tutorials/experimental/Vision_Transformer/requirements_vision_transformer.txt @@ -0,0 +1,4 @@ +torch==2.0.1 +torchvision==0.15.2 +medmnist==2.2.2 +transformers==4.30.1 \ No newline at end of file diff --git a/openfl-tutorials/experimental/Vision_Transformer/requirements_workflow_interface.txt b/openfl-tutorials/experimental/Vision_Transformer/requirements_workflow_interface.txt new file mode 100644 index 00000000000..a721bb7e289 --- /dev/null +++ b/openfl-tutorials/experimental/Vision_Transformer/requirements_workflow_interface.txt @@ -0,0 +1,3 @@ +dill==0.3.6 +metaflow==2.7.15 +ray==2.2.0 \ No newline at end of file