Added docstrings and type hints to the following functions: #19
base: main
_compute_accuracy, accuracy_scorer, marginal_gain_scorer
Also format using Black; you can set up VS Code to use it by default.
accuracy_scorer, marginal_gain_scorer, removing unnecessary imports
@@ -23,7 +57,7 @@ def marginal_gain_scorer(weights, prev_scores, testloader):
]

-def multikrum_scorer(weights):
+def multikrum_scorer(weights: List[Mapping[str, Any]]):
Return type is a list of floats:
-def multikrum_scorer(weights: List[Mapping[str, Any]]):
+def multikrum_scorer(weights: List[Mapping[str, Any]]) -> List[float]:
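For reference, a fully annotated multikrum_scorer could look like the sketch below. It assumes the function computes standard Multi-Krum scores (sum of squared distances to the n - f - 2 nearest other updates), which the PR does not confirm. To stay self-contained it treats each state-dict value as a flat list of floats instead of a tensor, and num_byzantine is a hypothetical parameter not present in the original signature.

```python
from typing import Any, List, Mapping

# Assumption: each state-dict value is a flat sequence of floats. Real PyTorch
# state dicts hold tensors, which would need flattening first.


def _flatten(weight: Mapping[str, Any]) -> List[float]:
    """Concatenate all parameters into one vector, in sorted key order."""
    return [float(v) for key in sorted(weight) for v in weight[key]]


def multikrum_scorer(
    weights: List[Mapping[str, Any]], num_byzantine: int = 1
) -> List[float]:
    """Return one Krum score per client update; lower means more central.

    Args:
        weights: One state dict per client.
        num_byzantine: Hypothetical parameter, the number of faulty clients
            to tolerate (f in Krum's notation).

    Returns:
        For each update, the sum of squared Euclidean distances to its
        n - num_byzantine - 2 nearest other updates.
    """
    vectors = [_flatten(w) for w in weights]
    k = len(vectors) - num_byzantine - 2
    if k < 1:
        raise ValueError("need at least num_byzantine + 3 updates")
    scores: List[float] = []
    for i, vi in enumerate(vectors):
        # Squared distances from update i to every other update, smallest first.
        distances = sorted(
            sum((a - b) ** 2 for a, b in zip(vi, vj))
            for j, vj in enumerate(vectors)
            if j != i
        )
        scores.append(float(sum(distances[:k])))
    return scores


updates = [{"w": [0.0, 0.0]}, {"w": [1.0, 0.0]}, {"w": [0.0, 1.0]}, {"w": [10.0, 10.0]}]
print(multikrum_scorer(updates))  # the outlier at (10, 10) gets the largest score
```

Lower scores mark updates that sit close to many others, so an aggregator would typically keep the lowest-scoring updates.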
model = create_model()
model.load_state_dict(weight)
return test_model(model, testloader)[0]

-def accuracy_scorer(weights, testloader):
+def accuracy_scorer(weights: List[Mapping[str, Any]], testloader: DataLoader)-> List(float):
Typo: List(float) should be List[float].
-def accuracy_scorer(weights: List[Mapping[str, Any]], testloader: DataLoader)-> List(float):
+def accuracy_scorer(weights: List[Mapping[str, Any]], testloader: DataLoader)-> List[float]:
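The typo matters beyond style: typing.List is subscripted, not called, and calling it raises TypeError. Depending on the Python version and whether from __future__ import annotations is in effect, that error fires when the module is imported or only when the annotations are inspected. The sketch below shows the failure and a corrected, documented accuracy_scorer. To stay self-contained it swaps DataLoader for Iterable[Any], and _compute_accuracy here is a hypothetical stub that reads a made-up "acc" key rather than the real helper that builds and tests a model.

```python
from typing import Any, Iterable, List, Mapping, get_type_hints

# Calling typing.List like a function is an error; subscripting is the correct form.
try:
    List(float)
    call_form_error = None
except TypeError as exc:
    call_form_error = str(exc)


def _compute_accuracy(weight: Mapping[str, Any], testloader: Iterable[Any]) -> float:
    """Hypothetical stub: the real helper loads ``weight`` into a model and
    returns the accuracy reported by ``test_model``."""
    return float(weight.get("acc", 0.0))


def accuracy_scorer(
    weights: List[Mapping[str, Any]], testloader: Iterable[Any]
) -> List[float]:
    """Score each client's weights by test-set accuracy.

    Args:
        weights: One state dict per client.
        testloader: Batches of held-out test data (a DataLoader in the PR).

    Returns:
        One accuracy value per entry in ``weights``, in the same order.
    """
    return [_compute_accuracy(weight, testloader) for weight in weights]


print(call_form_error)
print(get_type_hints(accuracy_scorer)["return"])
print(accuracy_scorer([{"acc": 91.5}, {"acc": 88.0}], testloader=[]))
```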
@@ -63,7 +63,7 @@ def test_model(model: Model, testloader: DataLoader) -> Tuple[float, float]:
return 100 * correct, test_loss

-def fedavg_models(weights):
+def fedavg_models(weights: list[dict])-> dict:
Refer to the previous comments.
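Applying the convention from the earlier comments (List[Mapping[str, Any]] in, a concrete return type out) to fedavg_models might look like the sketch below. It is not the PR's implementation. Assumptions: every client dict has the same keys, each value is a flat list of floats rather than a tensor, and all clients count equally, whereas FedAvg as published weights each client by its number of training samples.

```python
from typing import Any, Dict, List, Mapping


def fedavg_models(weights: List[Mapping[str, Any]]) -> Dict[str, Any]:
    """Average client state dicts parameter by parameter.

    Assumptions for this sketch: identical keys across clients, flat float
    sequences as values, and equal weight for every client.
    """
    if not weights:
        raise ValueError("need at least one set of weights")
    n = len(weights)
    # zip(*...) lines up the i-th element of each client's parameter list.
    return {
        key: [sum(column) / n for column in zip(*(w[key] for w in weights))]
        for key in weights[0]
    }


clients = [{"w": [1.0, 2.0], "b": [0.0]}, {"w": [3.0, 4.0], "b": [1.0]}]
print(fedavg_models(clients))  # {'w': [2.0, 3.0], 'b': [0.5]}
```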
@@ -23,7 +23,7 @@ def __init__(self) -> None:
nn.Linear(128, 10),
)

-def forward(self, x):
+def forward(self, x: DataLoader)-> DataLoader:
Take it lite; no need to annotate this one:
-def forward(self, x: DataLoader)-> DataLoader:
+def forward(self, x):
Sorry, I think I pushed the scorer.py changes by accident; I was waiting until Druva implemented the other ML model to complete it properly. I'll fix these.