Skip to content

Commit

Permalink
- add --selected_task_group option to aggregator cli, default to "lea…
Browse files Browse the repository at this point in the history
…rning" (#1258)

- enhance Aggregator to take selected_task_group attribute to enable fedeval or learning switching at aggregator level
- rebase 16.Jan.1
- fix aggregator cli test cases as per new "selected_task_group" field in start
- changed default assigner task_group name to "learning" and "evaluation"
- updated worspaces to use new task_group names - learning / evaluation
- updated as per review comments
- update the FedEval documentation with e2e usage steps
- Rebased 15-Jan.1
- Fixed docs indentation issue,reduced the verbosity in doc
Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant authored Jan 16, 2025
1 parent b8e2c70 commit 69a2ceb
Show file tree
Hide file tree
Showing 7 changed files with 560 additions and 31 deletions.
482 changes: 467 additions & 15 deletions docs/about/features_index/fed_eval.rst

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion openfl-workspace/torch_cnn_mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ aggregator:
assigner:
settings:
task_groups:
- name: train_and_validate
- name: learning
percentage: 1.0
tasks:
- aggregated_model_validation
Expand Down
2 changes: 1 addition & 1 deletion openfl-workspace/workspace/plan/defaults/assigner.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
template : openfl.component.RandomGroupedAssigner
settings :
task_groups :
- name : train_and_validate
- name : learning
percentage : 1.0
tasks :
- aggregated_model_validation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
template : openfl.component.RandomGroupedAssigner
settings :
task_groups :
- name : validate
- name : evaluation
percentage : 1.0
tasks :
- aggregated_model_validation
10 changes: 8 additions & 2 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Aggregator module."""

import logging
Expand Down Expand Up @@ -85,6 +84,7 @@ def __init__(
callbacks: Optional[List] = None,
persist_checkpoint=True,
persistent_db_path=None,
task_group: str = "learning",
):
"""Initializes the Aggregator.
Expand All @@ -111,7 +111,9 @@ def __init__(
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
task_group (str, optional): Selected task_group for assignment.
"""
self.task_group = task_group
self.round_number = 0
self.next_model_round_number = 0

Expand Down Expand Up @@ -298,9 +300,13 @@ def _load_initial_tensors(self):
self.model, compression_pipeline=self.compression_pipeline
)

if round_number > self.round_number:
# Check selected task_group before updating round number
if self.task_group == "evaluation":
logger.info(f"Skipping round_number check for {self.task_group} task_group")
elif round_number > self.round_number:
logger.info(f"Starting training from round {round_number} of previously saved model")
self.round_number = round_number

tensor_key_dict = {
TensorKey(k, self.uuid, self.round_number, False, ("model",)): v
for k, v in tensor_dict.items()
Expand Down
49 changes: 41 additions & 8 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Aggregator module."""

import sys
from logging import getLogger
from pathlib import Path

from click import Path as ClickPath
from click import confirm, echo, group, option, pass_context, style
from click import (
Path as ClickPath,
)
from click import (
confirm,
echo,
group,
option,
pass_context,
style,
)

from openfl.cryptography.ca import sign_certificate
from openfl.cryptography.io import get_csr_hash, read_crt, read_csr, read_key, write_crt, write_key
from openfl.cryptography.io import (
get_csr_hash,
read_crt,
read_csr,
read_key,
write_crt,
write_key,
)
from openfl.cryptography.participant import generate_csr
from openfl.federated import Plan
from openfl.interface.cli_helper import CERT_DIR
Expand Down Expand Up @@ -52,24 +67,42 @@ def aggregator(context):
default="plan/cols.yaml",
type=ClickPath(exists=True),
)
def start_(plan, authorized_cols):
"""Start the aggregator service."""
@option(
"--task_group",
required=False,
default="learning",
help="Selected task-group for assignment - defaults to learning",
)
def start_(plan, authorized_cols, task_group):
"""Start the aggregator service.
Args:
plan (str): Path to plan config file
authorized_cols (str): Path to authorized collaborators file
task_group (str): Selected task-group for assignement - defaults to 'learning'
"""
if is_directory_traversal(plan):
echo("Federated learning plan path is out of the openfl workspace scope.")
sys.exit(1)
if is_directory_traversal(authorized_cols):
echo("Authorized collaborator list file path is out of the openfl workspace scope.")
sys.exit(1)

plan = Plan.parse(
# Parse plan and override mode if specified
parsed_plan = Plan.parse(
plan_config_path=Path(plan).absolute(),
cols_config_path=Path(authorized_cols).absolute(),
)

# Set task_group in aggregator settings
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

logger.info("🧿 Starting the Aggregator Service.")

plan.get_server().serve()
parsed_plan.get_server().serve()


@aggregator.command(name="generate-cert-request")
Expand Down
44 changes: 41 additions & 3 deletions tests/openfl/interface/test_aggregator_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@ def test_aggregator_start(mock_parse):
plan_config = plan_path.joinpath('plan.yaml')
cols_config = plan_path.joinpath('cols.yaml')

mock_parse.return_value = mock.Mock()
# Create a mock plan with the required fields
mock_plan = mock.MagicMock()
mock_plan.__getitem__.side_effect = {'task_group': 'learning'}.get
mock_plan.get = {'task_group': 'learning'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'task_group': 'learning'
}
}
}
mock_parse.return_value = mock_plan

ret = start_(['-p', plan_config,
'-c', cols_config], standalone_mode=False)
Expand All @@ -32,7 +44,20 @@ def test_aggregator_start_illegal_plan(mock_parse, mock_is_directory_traversal):
plan_config = plan_path.joinpath('plan.yaml')
cols_config = plan_path.joinpath('cols.yaml')

mock_parse.return_value = mock.Mock()
# Create a mock plan with the required fields
mock_plan = mock.MagicMock()
mock_plan.__getitem__.side_effect = {'task_group': 'learning'}.get
mock_plan.get = {'task_group': 'learning'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'task_group': 'learning'
}
}
}
mock_parse.return_value = mock_plan

mock_is_directory_traversal.side_effect = [True, False]

with TestCase.assertRaises(test_aggregator_start_illegal_plan, SystemExit):
Expand All @@ -48,7 +73,20 @@ def test_aggregator_start_illegal_cols(mock_parse, mock_is_directory_traversal):
plan_config = plan_path.joinpath('plan.yaml')
cols_config = plan_path.joinpath('cols.yaml')

mock_parse.return_value = mock.Mock()
# Create a mock plan with the required fields
mock_plan = mock.MagicMock()
mock_plan.__getitem__.side_effect = {'task_group': 'learning'}.get
mock_plan.get = {'task_group': 'learning'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'task_group': 'learning'
}
}
}
mock_parse.return_value = mock_plan

mock_is_directory_traversal.side_effect = [False, True]

with TestCase.assertRaises(test_aggregator_start_illegal_cols, SystemExit):
Expand Down

0 comments on commit 69a2ceb

Please sign in to comment.