diff --git a/tinyfl/model.py b/tinyfl/model.py index 80bdc54..5de15cb 100644 --- a/tinyfl/model.py +++ b/tinyfl/model.py @@ -23,7 +23,7 @@ def __init__(self) -> None: nn.Linear(128, 10), ) - def forward(self, x): + def forward(self, x: DataLoader)-> DataLoader: x = self.flatten(x) x = self.layers(x) return x @@ -63,7 +63,7 @@ def test_model(model: Model, testloader: DataLoader) -> Tuple[float, float]: return 100 * correct, test_loss -def fedavg_models(weights): +def fedavg_models(weights: list[dict])-> dict: avg = copy.deepcopy(weights[0]) for i in range(1, len(weights)): for key in avg: @@ -78,6 +78,16 @@ def fedavg_models(weights): def stratified_split_dataset(dataset: Dataset, num_parties: int) -> List[List[int]]: + """Splits dataset among parties for local models to be trained upon + + Args: + dataset: Dataset for the models to be trained on + num_parties: number of parties + + Returns: + + + """ def partition_list(l, n): indices = list(range(len(l))) shuffle(indices) diff --git a/tinyfl/scorer.py b/tinyfl/scorer.py index f474e03..06110f9 100644 --- a/tinyfl/scorer.py +++ b/tinyfl/scorer.py @@ -1,18 +1,52 @@ from tinyfl.model import create_model, test_model import numpy as np +from torch.utils.data import DataLoader +from typing import Any,List,Mapping -def _compute_accuracy(weight, testloader): +def _compute_accuracy(weight:Mapping[str, Any], testloader: DataLoader)-> float: + """ + Computes accuracy of model. + + Compares output of model with current set of weights to calculate percentage of correct answers. + + Args: + weight: Weights of the model stored in a dictionary + testloader: The loaded dataset + + Returns: + A float value of the accuracy of the model (% of correct answers) + """ model = create_model() model.load_state_dict(weight) return test_model(model, testloader)[0] -def accuracy_scorer(weights, testloader): +def accuracy_scorer(weights: List[Mapping[str, Any]], testloader: DataLoader)-> List(float): + """Computes accuracy of models. + + Args: + weights: A list of weights of each model which are stored in dictionaries + testloader: The loaded dataset + + Returns: + A list with float values of the accuracies of the models (% of correct answers) + """ return [_compute_accuracy(weight, testloader) for weight in weights] -def marginal_gain_scorer(weights, prev_scores, testloader): +def marginal_gain_scorer(weights: List[Mapping[str, Any]], prev_scores: List[float], testloader: DataLoader)-> List[float]: + """Calculates marginal gain in accuracy of model + + Calculates increase in accuracy of model after pulling wieghts + + Args: + weights: A list of weights of each model which are stored in dictionaries + prev_scores: List storing accuracies of models prior to most recent updation of weights + + Returns: + List of floats which represent the marginal increases in accuracies(if any) of each party + """ assert len(weights) == len(prev_scores) return [ max(a - b, 0) @@ -23,7 +57,7 @@ def marginal_gain_scorer(weights, prev_scores, testloader): ] -def multikrum_scorer(weights): +def multikrum_scorer(weights: List[Mapping[str, Any]]): R = len(weights) f = R // 3 - 1 closest_updates = R - f - 2