Skip to content
This repository has been archived by the owner on Apr 27, 2024. It is now read-only.

Added docstrings, type hints to following functions: #19

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions tinyfl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self) -> None:
nn.Linear(128, 10),
)

def forward(self, x):
def forward(self, x: DataLoader)-> DataLoader:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

take lite

Suggested change
def forward(self, x: DataLoader)-> DataLoader:
def forward(self, x):

x = self.flatten(x)
x = self.layers(x)
return x
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_model(model: Model, testloader: DataLoader) -> Tuple[float, float]:
return 100 * correct, test_loss


def fedavg_models(weights):
def fedavg_models(weights: list[dict])-> dict:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refer previous comments

avg = copy.deepcopy(weights[0])
for i in range(1, len(weights)):
for key in avg:
Expand All @@ -78,6 +78,16 @@ def fedavg_models(weights):


def stratified_split_dataset(dataset: Dataset, num_parties: int) -> List[List[int]]:
"""Splits dataset among parties for local models to be trained upon

Args:
dataset: Dataset for the models to be trained on
num_parties: number of parties

Returns:


"""
def partition_list(l, n):
indices = list(range(len(l)))
shuffle(indices)
Expand Down
42 changes: 38 additions & 4 deletions tinyfl/scorer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,52 @@
from tinyfl.model import create_model, test_model
import numpy as np
from torch.utils.data import DataLoader
from typing import Any,List,Mapping


def _compute_accuracy(weight, testloader):
def _compute_accuracy(weight:Mapping[str, Any], testloader: DataLoader)-> float:
"""
Computes accuracy of model.

Compares output of model with current set of weights to calculate percentage of correct answers.

Args:
weight: Weights of the model stored in a dictionary
testloader: The loaded dataset

Returns:
A float value of the accuracy of the model (% of correct answers)
"""
model = create_model()
model.load_state_dict(weight)
return test_model(model, testloader)[0]


def accuracy_scorer(weights, testloader):
def accuracy_scorer(weights: List[Mapping[str, Any]], testloader: DataLoader)-> List(float):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

Suggested change
def accuracy_scorer(weights: List[Mapping[str, Any]], testloader: DataLoader)-> List(float):
def accuracy_scorer(weights: List[Mapping[str, Any]], testloader: DataLoader)-> List[float]:

"""Computes accuracy of models.

Args:
weights: A list of weights of each model which are stored in dictionaries
testloader: The loaded dataset

Returns:
A list with float values of the accuracies of the models (% of correct answers)
"""
return [_compute_accuracy(weight, testloader) for weight in weights]


def marginal_gain_scorer(weights, prev_scores, testloader):
def marginal_gain_scorer(weights: List[Mapping[str, Any]], prev_scores: List[float], testloader: DataLoader)-> List[float]:
"""Calculates marginal gain in accuracy of model

Calculates increase in accuracy of model after pulling wieghts

Args:
weights: A list of weights of each model which are stored in dictionaries
prev_scores: List storing accuracies of models prior to most recent updation of weights

Returns:
List of floats which represent the marginal increases in accuracies(if any) of each party
"""
assert len(weights) == len(prev_scores)
return [
max(a - b, 0)
Expand All @@ -23,7 +57,7 @@ def marginal_gain_scorer(weights, prev_scores, testloader):
]


def multikrum_scorer(weights):
def multikrum_scorer(weights: List[Mapping[str, Any]]):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return type is list of floats

Suggested change
def multikrum_scorer(weights: List[Mapping[str, Any]]):
def multikrum_scorer(weights: List[Mapping[str, Any]]) -> List[float]:

R = len(weights)
f = R // 3 - 1
closest_updates = R - f - 2
Expand Down