Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Aggregation functions library #901

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"id": "14821d97",
"metadata": {},
"source": [
"# Workflow Interface 101: Quickstart\n",
"# Workflow Interface 101: MNIST with FedAvg aggregation algorithm\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb)"
]
},
Expand Down Expand Up @@ -94,6 +94,9 @@
"metadata": {},
"outputs": [],
"source": [
"# !pip install torch\n",
"# !pip install torchvision\n",
"\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
Expand Down Expand Up @@ -198,17 +201,9 @@
"from openfl.experimental.runtime import LocalRuntime\n",
"from openfl.experimental.placement import aggregator, collaborator\n",
"\n",
"\n",
"def FedAvg(models, weights=None):\n",
" new_model = models[0]\n",
" state_dicts = [model.state_dict() for model in models]\n",
" state_dict = new_model.state_dict()\n",
" for key in models[1].state_dict():\n",
" state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],\n",
" axis=0, \n",
" weights=weights))\n",
" new_model.load_state_dict(state_dict)\n",
" return new_model"
"# Import plugin adapter and FedAvg aggregation algorithm\n",
"from openfl.plugins.frameworks_adapters.pytorch_adapter import FrameworkAdapterPlugin as fa\n",
"from openfl.experimental.interface.aggregation_functions.fedavg import FedAvg\n"
]
},
{
Expand Down Expand Up @@ -247,6 +242,7 @@
" self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n",
" momentum=momentum)\n",
" self.rounds = rounds\n",
" self.agg_func = FedAvg()\n",
"\n",
" @aggregator\n",
" def start(self):\n",
Expand All @@ -264,7 +260,9 @@
" self.next(self.train)\n",
"\n",
" @collaborator\n",
" def train(self):\n",
" def train(self): \n",
" self.train_dataset_length = len(self.train_loader.dataset)\n",
" \n",
" self.model.train()\n",
" self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n",
" momentum=momentum)\n",
Expand Down Expand Up @@ -302,8 +300,25 @@
" print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n",
" print(f'Average training loss = {self.average_loss}')\n",
" print(f'Average local model validation values = {self.local_model_accuracy}')\n",
" self.model = FedAvg([input.model for input in inputs])\n",
" self.optimizer = [input.optimizer for input in inputs][0]\n",
"\n",
" model_weights, collaborator_weights=[],[]\n",
" for input in inputs:\n",
" collaborator_weights.append(input.train_dataset_length/len(mnist_train.data))\n",
" keys_list, inner_tensors_list = [], []\n",
" for k,v in (fa.get_tensor_dict(input.model, input.optimizer)).items():\n",
" if k == '__opt_state_needed':\n",
" continue\n",
" else:\n",
" inner_tensors_list.append(v)\n",
" keys_list.append(k)\n",
" model_weights.append(inner_tensors_list)\n",
" avg_tensors = self.agg_func.aggregate_models(model_weights, collaborator_weights)\n",
" \n",
" state_dict = dict(zip(keys_list, avg_tensors))\n",
" # Add back __opt_state_needed key\n",
" state_dict['__opt_state_needed'] = 'true'\n",
" fa.set_tensor_dict(self.model, state_dict, self.optimizer)\n",
" \n",
" self.current_round += 1\n",
" if self.current_round < self.rounds:\n",
" self.next(self.aggregated_model_validation,\n",
Expand Down Expand Up @@ -343,8 +358,11 @@
"aggregator.private_attributes = {}\n",
"\n",
"# Setup collaborators with private attributes\n",
"collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']\n",
"collaborator_names = ['Portland', 'Seattle', 'Chandler', 'Bangalore']\n",
"collaborators = [Collaborator(name=name) for name in collaborator_names]\n",
"# Keep a dict of collaborator weights. The weights are decided by the number of samples for each collaborator\n",
"collaborators_weights_dict = {}\n",
"\n",
"for idx, collaborator in enumerate(collaborators):\n",
" local_train = deepcopy(mnist_train)\n",
" local_test = deepcopy(mnist_test)\n",
Expand All @@ -356,6 +374,15 @@
" 'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),\n",
" 'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)\n",
" }\n",
" collaborators_weights_dict[collaborator] = len(local_train.data)\n",
"\n",
"for col in collaborators_weights_dict:\n",
" collaborators_weights_dict[col] /= len(mnist_train.data)\n",
"\n",
"if len(collaborators_weights_dict) != 0:\n",
" assert np.abs(1.0 - sum(collaborators_weights_dict.values())) < 0.01, (\n",
" f'Collaborator weights do not sum to 1.0: {collaborators_weights_dict}'\n",
" )\n",
"\n",
"local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')\n",
"print(f'Local runtime collaborators = {local_runtime.collaborators}')"
Expand Down
Loading