Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Straggler handling Follow-up #1097

Open
wants to merge 42 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
2606ed6
Renamed Straggler Handling package
ishant162 Oct 18, 2024
87152a5
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 18, 2024
c779a4a
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 21, 2024
2f3873d
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 23, 2024
896786f
Incorporated review comments
ishant162 Oct 23, 2024
f5cb6d6
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 25, 2024
2f8ad23
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 25, 2024
be8aeb7
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 25, 2024
989828d
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 2, 2024
92c8871
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 8, 2024
664774a
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 11, 2024
bd23ceb
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 12, 2024
92f8497
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 13, 2024
1c86a4c
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 19, 2024
e89959b
Merge branch 'develop' into straggler_handling_update
ishant162 Nov 27, 2024
03f2bc1
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Dec 2, 2024
cab6d9f
resolving merge conflicts
ishant162 Dec 17, 2024
3969416
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Dec 19, 2024
eff79fc
Incorporated karan's review comments
ishant162 Dec 20, 2024
7f16cdb
Resolving merge conflicts
ishant162 Dec 20, 2024
f55e890
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Dec 20, 2024
afd4bbb
Merge branch 'develop' into straggler_handling_update
ishant162 Dec 30, 2024
4c2ebf7
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 3, 2025
83e2b9a
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 3, 2025
d3e92eb
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 5, 2025
1c9341b
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 7, 2025
9f27f5f
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 8, 2025
6ed5267
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 10, 2025
d117bd2
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 13, 2025
2aa0b22
Merge branch 'develop' into straggler_handling_update
ishant162 Jan 15, 2025
a66b5ac
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 16, 2025
8e0ac64
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 16, 2025
a1c7fae
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 16, 2025
fbaf7f0
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 16, 2025
8cc37c4
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 17, 2025
8f2fc6b
Fix code format
ishant162 Jan 17, 2025
76d3b82
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 20, 2025
784229b
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 21, 2025
b4ac4f4
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 22, 2025
3133cb6
Incorporated review comments
ishant162 Jan 22, 2025
833eb20
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 22, 2025
62c313b
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Jan 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ The Open Federated Learning (OpenFL) framework supports straggler handling inter

The following are the straggler handling algorithms supported in OpenFL:

``CutoffTimeBasedStragglerHandling``
``CutoffPolicy``
ishant162 marked this conversation as resolved.
Show resolved Hide resolved
Identifies stragglers based on the cutoff time specified in the settings. Arguments to the function are:
- *Cutoff Time* (straggler_cutoff_time), specifies the cutoff time by which the aggregator should end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.

For example, in a federation of 5 collaborators, if :code:`straggler_cutoff_time` (in seconds) is set to 20 and :code:`minimum_reporting` is set to 2, atleast 2 collaborators (or more) would be included in the round, provided that the time limit of 20 seconds is not exceeded.
In an event where :code:`minimum_reporting` collaborators don't make it within the :code:`straggler_cutoff_time`, the straggler handling policy is disregarded.

``PercentageBasedStragglerHandling``
``PercentagePolicy``
Identifies stragglers based on the percetage specified. Arguments to the function are:
- *Percentage of collaborators* (percent_collaborators_needed), specifies a percentage of collaborators enough to end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.
Expand All @@ -29,12 +29,12 @@ The following are the straggler handling algorithms supported in OpenFL:
Demonstration of adding the straggler handling interface
=========================================================

The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentageBasedStragglerHandling``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimeBasedStragglerHandling** function instead:
The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffPolicy** function instead:

.. code-block:: yaml

straggler_handling_policy :
template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.CutoffPolicy
settings :
straggler_cutoff_time : 20
minimum_reporting : 1
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml

straggler_handling_policy :
template : openfl.component.straggler_handling_functions.PercentageBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.PercentagePolicy
settings :
percent_collaborators_needed : 0.5
minimum_reporting : 1
14 changes: 5 additions & 9 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@


from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerPolicy,
)
from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner
from openfl.component.collaborator.collaborator import Collaborator
from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import (
CutoffTimeBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import (
PercentageBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingPolicy,
)
5 changes: 5 additions & 0 deletions openfl/component/aggregator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@


from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerPolicy,
)
10 changes: 4 additions & 6 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import List, Optional

import openfl.callbacks as callbacks_module
from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy
from openfl.databases import PersistentTensorDB, TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
last_state_path,
assigner,
use_delta_updates=True,
straggler_handling_policy=None,
straggler_handling_policy: StragglerPolicy = CutoffPolicy,
rounds_to_train=256,
single_col_cert_common_name=None,
compression_pipeline=None,
Expand All @@ -100,7 +100,6 @@ def __init__(
weight.
assigner: Assigner object.
straggler_handling_policy (optional): Straggler handling policy.
Defaults to CutoffTimeBasedStragglerHandling.
rounds_to_train (int, optional): Number of rounds to train.
Defaults to 256.
single_col_cert_common_name (str, optional): Common name for single
Expand All @@ -127,9 +126,8 @@ def __init__(
# FIXME: "" instead of None is for protobuf compatibility.
self.single_col_cert_common_name = single_col_cert_common_name or ""

self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self.straggler_handling_policy = straggler_handling_policy()

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,68 @@
# SPDX-License-Identifier: Apache-2.0


"""Cutoff time based Straggler Handling function."""
"""Straggler handling module."""

import threading
import time
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Callable

import numpy as np

from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingPolicy,
)
logger = getLogger(__name__)


class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy):
class StragglerPolicy(ABC):
"""Federated Learning straggler handling interface."""

@abstractmethod
def start_policy(self, **kwargs) -> None:
"""
Start straggler handling policy for collaborator for a particular round.
NOTE: Refer CutoffPolicy for reference.
ishant162 marked this conversation as resolved.
Show resolved Hide resolved

Args:
**kwargs
"""
raise NotImplementedError

@abstractmethod
def reset_policy_for_round(self) -> None:
"""Reset policy for the next round."""
raise NotImplementedError

@abstractmethod
def straggler_cutoff_check(
self, num_collaborators_done: int, num_all_collaborators: int, **kwargs
) -> bool:
"""
Determines whether it is time to end the round early.

Args:
num_collaborators_done: int
Number of collaborators finished.
num_all_collaborators: int
Total number of collaborators.

Returns:
bool: True if it is time to end the round early, False otherwise.

Raises:
NotImplementedError: This method must be implemented by a subclass.
"""
raise NotImplementedError


class CutoffPolicy(StragglerPolicy):
"""Cutoff time based Straggler Handling function."""

def __init__(
self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs
):
"""
Initialize a CutoffTimeBasedStragglerHandling object.
Initialize a CutoffPolicy object.

Args:
round_start_time (optional): The start time of the round. Defaults
Expand All @@ -40,21 +80,16 @@ def __init__(
self.round_start_time = round_start_time
self.straggler_cutoff_time = straggler_cutoff_time
self.minimum_reporting = minimum_reporting
self.logger = getLogger(__name__)
self.is_timer_started = False

if self.straggler_cutoff_time == np.inf:
self.logger.warning(
"CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time "
"is set to np.inf."
)
logger.warning("CutoffPolicy is disabled as straggler_cutoff_time is set to np.inf.")

def reset_policy_for_round(self) -> None:
"""
Reset timer for the next round.
"""
"""Reset timer for the next round."""
if hasattr(self, "timer"):
self.timer.cancel()
delattr(self, "timer")
self.is_timer_started = False

def start_policy(self, callback: Callable) -> None:
"""
Expand All @@ -64,22 +99,21 @@ def start_policy(self, callback: Callable) -> None:
Args:
callback: Callable
Callback function for when straggler_cutoff_time elapses

Returns:
None
"""
# If straggler_cutoff_time is set to infinity
# or if the timer is already running,
# do not start the policy.
if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"):
if self.straggler_cutoff_time == np.inf or self.is_timer_started:
return

self.round_start_time = time.time()
self.timer = threading.Timer(
self.straggler_cutoff_time,
callback,
)
self.timer.daemon = True
self.timer.start()
self.is_timer_started = True

def straggler_cutoff_check(
self,
Expand Down Expand Up @@ -108,13 +142,13 @@ def straggler_cutoff_check(
# Time has expired
# Check if minimum_reporting collaborators have reported results
elif self.__minimum_collaborators_reported(num_collaborators_done):
self.logger.info(
logger.info(
f"{num_collaborators_done} collaborators have reported results. "
"Applying cutoff policy and proceeding with end of round."
)
return True
else:
self.logger.info(
logger.info(
f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results."
)
return False
Expand All @@ -141,3 +175,66 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
False otherwise.
"""
return num_collaborators_done >= self.minimum_reporting


class PercentagePolicy(StragglerPolicy):
"""Percentage based Straggler Handling function."""

def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs):
"""Initialize a PercentagePolicy object.

Args:
percent_collaborators_needed (float, optional): The percentage of
collaborators needed. Defaults to 1.0.
minimum_reporting (int, optional): The minimum number of
collaborators that should report. Defaults to 1.
ishant162 marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Variable length argument list.
"""
ishant162 marked this conversation as resolved.
Show resolved Hide resolved
if minimum_reporting <= 0:
raise ValueError("minimum_reporting must be >0")
ishant162 marked this conversation as resolved.
Show resolved Hide resolved

self.percent_collaborators_needed = percent_collaborators_needed
self.minimum_reporting = minimum_reporting

def reset_policy_for_round(self) -> None:
"""Not required in PercentagePolicy."""
pass

def start_policy(self, **kwargs) -> None:
"""Not required in PercentagePolicy."""
pass

def straggler_cutoff_check(
self,
num_collaborators_done: int,
num_all_collaborators: int,
) -> bool:
"""
If percent_collaborators_needed and minimum_reporting collaborators have
reported results, then it is time to end round early.

Args:
num_collaborators_done (int): The number of collaborators that
have reported.
all_collaborators (list): All the collaborators.

Returns:
bool: True if the straggler cutoff conditions are met, False
otherwise.
"""
return (
num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators
) and self.__minimum_collaborators_reported(num_collaborators_done)

def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
"""Check if the minimum number of collaborators have reported.

Args:
num_collaborators_done (int): The number of collaborators that
have reported.

Returns:
bool: True if the minimum number of collaborators have reported,
False otherwise.
"""
return num_collaborators_done >= self.minimum_reporting
13 changes: 0 additions & 13 deletions openfl/component/straggler_handling_functions/__init__.py

This file was deleted.

Loading
Loading