From 2606ed6ffb1396376825dcce2d30bd89ce0a659d Mon Sep 17 00:00:00 2001 From: Ishant Thakare Date: Fri, 18 Oct 2024 12:15:28 +0530 Subject: [PATCH 1/6] Renamed Straggler Handling package Signed-off-by: Ishant Thakare --- .../straggler_handling_algorithms.rst | 2 +- docs/source/api/openfl_component.rst | 2 +- .../plan/plan.yaml | 2 +- openfl/component/__init__.py | 6 +- openfl/component/aggregator/aggregator.py | 2 +- .../straggler_handling_functions/__init__.py | 13 -- .../percentage_based_straggler_handling.py | 78 ----------- .../straggler_handling_function.py | 59 -------- .../straggler_handling_policy/__init__.py | 9 ++ .../straggler_handling_policy.py} | 126 +++++++++++++++++- openfl/federated/plan/plan.py | 2 +- setup.py | 2 +- 12 files changed, 138 insertions(+), 165 deletions(-) delete mode 100644 openfl/component/straggler_handling_functions/__init__.py delete mode 100644 openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py delete mode 100644 openfl/component/straggler_handling_functions/straggler_handling_function.py create mode 100644 openfl/component/straggler_handling_policy/__init__.py rename openfl/component/{straggler_handling_functions/cutoff_time_based_straggler_handling.py => straggler_handling_policy/straggler_handling_policy.py} (54%) diff --git a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst index a71b1385e5..d5f00500e0 100644 --- a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst +++ b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst @@ -34,7 +34,7 @@ The example template, **torch_cnn_mnist_straggler_check**, uses the ``Percentage .. code-block:: yaml straggler_handling_policy : - template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + template : openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling settings : straggler_cutoff_time : 20 minimum_reporting : 1 diff --git a/docs/source/api/openfl_component.rst b/docs/source/api/openfl_component.rst index 8deb6528ac..67639599f5 100644 --- a/docs/source/api/openfl_component.rst +++ b/docs/source/api/openfl_component.rst @@ -17,4 +17,4 @@ Component modules reference: openfl.component.collaborator openfl.component.director openfl.component.envoy - openfl.component.straggler_handling_functions + openfl.component.straggler_handling_policy diff --git a/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml b/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml index a42b064e56..cca9414fb6 100644 --- a/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml +++ b/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml @@ -45,7 +45,7 @@ compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml straggler_handling_policy : - template : openfl.component.straggler_handling_functions.PercentageBasedStragglerHandling + template : openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling settings : percent_collaborators_needed : 0.5 minimum_reporting : 1 \ No newline at end of file diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 3b787f87d0..97788f4a6f 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -7,12 +7,8 @@ from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner from openfl.component.collaborator.collaborator import Collaborator -from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import ( +from openfl.component.straggler_handling_policy import ( CutoffTimeBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import ( PercentageBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.straggler_handling_function import ( StragglerHandlingPolicy, ) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 81d3e7411a..3540c0791e 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -8,7 +8,7 @@ from logging import getLogger from threading import Lock -from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling +from openfl.component.straggler_handling_policy import CutoffTimeBasedStragglerHandling from openfl.databases import TensorDB from openfl.interface.aggregation_functions import WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec diff --git a/openfl/component/straggler_handling_functions/__init__.py b/openfl/component/straggler_handling_functions/__init__.py deleted file mode 100644 index 5ab0af1794..0000000000 --- a/openfl/component/straggler_handling_functions/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2020-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import ( - CutoffTimeBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import ( - PercentageBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingPolicy, -) diff --git a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py deleted file mode 100644 index e556ea291b..0000000000 --- a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2020-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -"""Percentage based Straggler Handling function.""" -from logging import getLogger - -from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingPolicy, -) - - -class PercentageBasedStragglerHandling(StragglerHandlingPolicy): - """Percentage based Straggler Handling function.""" - - def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): - """Initialize a PercentageBasedStragglerHandling object. - - Args: - percent_collaborators_needed (float, optional): The percentage of - collaborators needed. Defaults to 1.0. - minimum_reporting (int, optional): The minimum number of - collaborators that should report. Defaults to 1. - **kwargs: Variable length argument list. - """ - if minimum_reporting <= 0: - raise ValueError("minimum_reporting must be >0") - - self.percent_collaborators_needed = percent_collaborators_needed - self.minimum_reporting = minimum_reporting - self.logger = getLogger(__name__) - - def reset_policy_for_round(self) -> None: - """ - Not required in PercentageBasedStragglerHandling. - """ - pass - - def start_policy(self, **kwargs) -> None: - """ - Not required in PercentageBasedStragglerHandling. - """ - pass - - def straggler_cutoff_check( - self, - num_collaborators_done: int, - num_all_collaborators: int, - ) -> bool: - """ - If percent_collaborators_needed and minimum_reporting collaborators have - reported results, then it is time to end round early. - - Args: - num_collaborators_done (int): The number of collaborators that - have reported. - all_collaborators (list): All the collaborators. - - Returns: - bool: True if the straggler cutoff conditions are met, False - otherwise. - """ - return ( - num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators - ) and self.__minimum_collaborators_reported(num_collaborators_done) - - def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: - """Check if the minimum number of collaborators have reported. - - Args: - num_collaborators_done (int): The number of collaborators that - have reported. - - Returns: - bool: True if the minimum number of collaborators have reported, - False otherwise. - """ - return num_collaborators_done >= self.minimum_reporting diff --git a/openfl/component/straggler_handling_functions/straggler_handling_function.py b/openfl/component/straggler_handling_functions/straggler_handling_function.py deleted file mode 100644 index 8bd47bc045..0000000000 --- a/openfl/component/straggler_handling_functions/straggler_handling_function.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2020-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -"""Straggler handling module.""" - -from abc import ABC, abstractmethod - - -class StragglerHandlingPolicy(ABC): - """Federated Learning straggler handling interface.""" - - @abstractmethod - def start_policy(self, **kwargs) -> None: - """ - Start straggler handling policy for collaborator for a particular round. - NOTE: Refer CutoffTimeBasedStragglerHandling for reference. - - Args: - **kwargs - - Returns: - None - """ - raise NotImplementedError - - @abstractmethod - def reset_policy_for_round(self) -> None: - """ - Reset policy variable for the next round. - - Args: - None - - Returns: - None - """ - raise NotImplementedError - - @abstractmethod - def straggler_cutoff_check( - self, num_collaborators_done: int, num_all_collaborators: int, **kwargs - ) -> bool: - """ - Determines whether it is time to end the round early. - - Args: - num_collaborators_done: int - Number of collaborators finished. - num_all_collaborators: int - Total number of collaborators. - - Returns: - bool: True if it is time to end the round early, False otherwise. - - Raises: - NotImplementedError: This method must be implemented by a subclass. - """ - raise NotImplementedError diff --git a/openfl/component/straggler_handling_policy/__init__.py b/openfl/component/straggler_handling_policy/__init__.py new file mode 100644 index 0000000000..4de28d0e48 --- /dev/null +++ b/openfl/component/straggler_handling_policy/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from openfl.component.straggler_handling_policy.straggler_handling_policy import ( + CutoffTimeBasedStragglerHandling, + PercentageBasedStragglerHandling, + StragglerHandlingPolicy, +) diff --git a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py b/openfl/component/straggler_handling_policy/straggler_handling_policy.py similarity index 54% rename from openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py rename to openfl/component/straggler_handling_policy/straggler_handling_policy.py index ca8e218f7c..3056b5d7b8 100644 --- a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py +++ b/openfl/component/straggler_handling_policy/straggler_handling_policy.py @@ -2,17 +2,67 @@ # SPDX-License-Identifier: Apache-2.0 -"""Cutoff time based Straggler Handling function.""" +"""Straggler handling module.""" + import threading import time +from abc import ABC, abstractmethod from logging import getLogger from typing import Callable import numpy as np -from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingPolicy, -) + +class StragglerHandlingPolicy(ABC): + """Federated Learning straggler handling interface.""" + + @abstractmethod + def start_policy(self, **kwargs) -> None: + """ + Start straggler handling policy for collaborator for a particular round. + NOTE: Refer CutoffTimeBasedStragglerHandling for reference. + + Args: + **kwargs + + Returns: + None + """ + raise NotImplementedError + + @abstractmethod + def reset_policy_for_round(self) -> None: + """ + Reset policy variable for the next round. + + Args: + None + + Returns: + None + """ + raise NotImplementedError + + @abstractmethod + def straggler_cutoff_check( + self, num_collaborators_done: int, num_all_collaborators: int, **kwargs + ) -> bool: + """ + Determines whether it is time to end the round early. + + Args: + num_collaborators_done: int + Number of collaborators finished. + num_all_collaborators: int + Total number of collaborators. + + Returns: + bool: True if it is time to end the round early, False otherwise. + + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ + raise NotImplementedError class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy): @@ -140,3 +190,71 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: False otherwise. """ return num_collaborators_done >= self.minimum_reporting + + +class PercentageBasedStragglerHandling(StragglerHandlingPolicy): + """Percentage based Straggler Handling function.""" + + def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): + """Initialize a PercentageBasedStragglerHandling object. + + Args: + percent_collaborators_needed (float, optional): The percentage of + collaborators needed. Defaults to 1.0. + minimum_reporting (int, optional): The minimum number of + collaborators that should report. Defaults to 1. + **kwargs: Variable length argument list. + """ + if minimum_reporting <= 0: + raise ValueError("minimum_reporting must be >0") + + self.percent_collaborators_needed = percent_collaborators_needed + self.minimum_reporting = minimum_reporting + self.logger = getLogger(__name__) + + def reset_policy_for_round(self) -> None: + """ + Not required in PercentageBasedStragglerHandling. + """ + pass + + def start_policy(self, **kwargs) -> None: + """ + Not required in PercentageBasedStragglerHandling. + """ + pass + + def straggler_cutoff_check( + self, + num_collaborators_done: int, + num_all_collaborators: int, + ) -> bool: + """ + If percent_collaborators_needed and minimum_reporting collaborators have + reported results, then it is time to end round early. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + all_collaborators (list): All the collaborators. + + Returns: + bool: True if the straggler cutoff conditions are met, False + otherwise. + """ + return ( + num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators + ) and self.__minimum_collaborators_reported(num_collaborators_done) + + def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: + """Check if the minimum number of collaborators have reported. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + + Returns: + bool: True if the minimum number of collaborators have reported, + False otherwise. + """ + return num_collaborators_done >= self.minimum_reporting diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 5f0575837d..526e0d49c2 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -421,7 +421,7 @@ def get_tensor_pipe(self): def get_straggler_handling_policy(self): """Get straggler handling policy.""" - template = "openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling" + template = "openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling" defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}}) if self.straggler_policy_ is None: diff --git a/setup.py b/setup.py index b1dea5bcfd..27fc79ac98 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ def run(self): 'openfl.component.collaborator', 'openfl.component.director', 'openfl.component.envoy', - 'openfl.component.straggler_handling_functions', + 'openfl.component.straggler_handling_policy', 'openfl.cryptography', 'openfl.databases', 'openfl.databases.utilities', From 896786f97a437ffa9e3f407b4acee85fbac00cda Mon Sep 17 00:00:00 2001 From: Ishant Thakare Date: Wed, 23 Oct 2024 17:55:43 +0530 Subject: [PATCH 2/6] Incorporated review comments Signed-off-by: Ishant Thakare --- .../straggler_handling_algorithms.rst | 8 +++---- docs/source/api/openfl_component.rst | 1 - .../plan/plan.yaml | 2 +- openfl/component/__init__.py | 10 ++++---- openfl/component/aggregator/__init__.py | 5 ++++ openfl/component/aggregator/aggregator.py | 14 ++++++----- .../straggler_handling.py} | 24 ++++++++++--------- .../straggler_handling_policy/__init__.py | 9 ------- openfl/federated/plan/plan.py | 2 +- setup.py | 1 - 10 files changed, 37 insertions(+), 39 deletions(-) rename openfl/component/{straggler_handling_policy/straggler_handling_policy.py => aggregator/straggler_handling.py} (92%) delete mode 100644 openfl/component/straggler_handling_policy/__init__.py diff --git a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst index d5f00500e0..26527de921 100644 --- a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst +++ b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst @@ -11,7 +11,7 @@ The Open Federated Learning (|productName|) framework supports straggler handlin The following are the straggler handling algorithms supported in |productName|: -``CutoffTimeBasedStragglerHandling`` +``CutoffPolicy`` Identifies stragglers based on the cutoff time specified in the settings. Arguments to the function are: - *Cutoff Time* (straggler_cutoff_time), specifies the cutoff time by which the aggregator should end the round early. - *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model. @@ -19,7 +19,7 @@ The following are the straggler handling algorithms supported in |productName|: For example, in a federation of 5 collaborators, if :code:`straggler_cutoff_time` (in seconds) is set to 20 and :code:`minimum_reporting` is set to 2, atleast 2 collaborators (or more) would be included in the round, provided that the time limit of 20 seconds is not exceeded. In an event where :code:`minimum_reporting` collaborators don't make it within the :code:`straggler_cutoff_time`, the straggler handling policy is disregarded. -``PercentageBasedStragglerHandling`` +``PercentagePolicy`` Identifies stragglers based on the percetage specified. Arguments to the function are: - *Percentage of collaborators* (percent_collaborators_needed), specifies a percentage of collaborators enough to end the round early. - *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model. @@ -29,12 +29,12 @@ The following are the straggler handling algorithms supported in |productName|: Demonstration of adding the straggler handling interface ========================================================= -The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentageBasedStragglerHandling``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimeBasedStragglerHandling** function instead: +The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffPolicy** function instead: .. code-block:: yaml straggler_handling_policy : - template : openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling + template : openfl.component.aggregator.straggler_handling.CutoffPolicy settings : straggler_cutoff_time : 20 minimum_reporting : 1 diff --git a/docs/source/api/openfl_component.rst b/docs/source/api/openfl_component.rst index 67639599f5..0af3099b49 100644 --- a/docs/source/api/openfl_component.rst +++ b/docs/source/api/openfl_component.rst @@ -17,4 +17,3 @@ Component modules reference: openfl.component.collaborator openfl.component.director openfl.component.envoy - openfl.component.straggler_handling_policy diff --git a/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml b/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml index cca9414fb6..e2378424b9 100644 --- a/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml +++ b/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml @@ -45,7 +45,7 @@ compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml straggler_handling_policy : - template : openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling + template : openfl.component.aggregator.straggler_handling.PercentagePolicy settings : percent_collaborators_needed : 0.5 minimum_reporting : 1 \ No newline at end of file diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 97788f4a6f..60a078371e 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -3,12 +3,12 @@ from openfl.component.aggregator.aggregator import Aggregator +from openfl.component.aggregator.straggler_handling import ( + CutoffPolicy, + PercentagePolicy, + StragglerHandlingPolicy, +) from openfl.component.assigner.assigner import Assigner from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner from openfl.component.collaborator.collaborator import Collaborator -from openfl.component.straggler_handling_policy import ( - CutoffTimeBasedStragglerHandling, - PercentageBasedStragglerHandling, - StragglerHandlingPolicy, -) diff --git a/openfl/component/aggregator/__init__.py b/openfl/component/aggregator/__init__.py index ed7661486e..728e4c68d8 100644 --- a/openfl/component/aggregator/__init__.py +++ b/openfl/component/aggregator/__init__.py @@ -3,3 +3,8 @@ from openfl.component.aggregator.aggregator import Aggregator +from openfl.component.aggregator.straggler_handling import ( + CutoffPolicy, + PercentagePolicy, + StragglerHandlingPolicy, +) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 4de7ab40af..eb23bc50ed 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -8,7 +8,7 @@ from logging import getLogger from threading import Lock -from openfl.component.straggler_handling_policy import CutoffTimeBasedStragglerHandling +from openfl.component.aggregator.straggler_handling import CutoffPolicy from openfl.databases import TensorDB from openfl.interface.aggregation_functions import WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec @@ -69,7 +69,7 @@ def __init__( best_state_path, last_state_path, assigner, - straggler_handling_policy=None, + straggler_handling_policy=CutoffPolicy, rounds_to_train=256, single_col_cert_common_name=None, compression_pipeline=None, @@ -92,7 +92,7 @@ def __init__( weight. assigner: Assigner object. straggler_handling_policy (optional): Straggler handling policy. - Defaults to CutoffTimeBasedStragglerHandling. + Defaults to CutoffPolicy. rounds_to_train (int, optional): Number of rounds to train. Defaults to 256. single_col_cert_common_name (str, optional): Common name for single @@ -117,9 +117,11 @@ def __init__( # Cleaner solution? self.single_col_cert_common_name = "" - self.straggler_handling_policy = ( - straggler_handling_policy or CutoffTimeBasedStragglerHandling() - ) + if straggler_handling_policy == CutoffPolicy: + self.straggler_handling_policy = straggler_handling_policy() + else: + self.straggler_handling_policy = straggler_handling_policy + self._end_of_round_check_done = [False] * rounds_to_train self.stragglers = [] diff --git a/openfl/component/straggler_handling_policy/straggler_handling_policy.py b/openfl/component/aggregator/straggler_handling.py similarity index 92% rename from openfl/component/straggler_handling_policy/straggler_handling_policy.py rename to openfl/component/aggregator/straggler_handling.py index 3056b5d7b8..df26080e90 100644 --- a/openfl/component/straggler_handling_policy/straggler_handling_policy.py +++ b/openfl/component/aggregator/straggler_handling.py @@ -20,7 +20,7 @@ class StragglerHandlingPolicy(ABC): def start_policy(self, **kwargs) -> None: """ Start straggler handling policy for collaborator for a particular round. - NOTE: Refer CutoffTimeBasedStragglerHandling for reference. + NOTE: Refer CutoffPolicy for reference. Args: **kwargs @@ -65,14 +65,14 @@ def straggler_cutoff_check( raise NotImplementedError -class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy): +class CutoffPolicy(StragglerHandlingPolicy): """Cutoff time based Straggler Handling function.""" def __init__( self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs ): """ - Initialize a CutoffTimeBasedStragglerHandling object. + Initialize a CutoffPolicy object. Args: round_start_time (optional): The start time of the round. Defaults @@ -89,12 +89,12 @@ def __init__( self.round_start_time = round_start_time self.straggler_cutoff_time = straggler_cutoff_time self.minimum_reporting = minimum_reporting + self.is_timer_started = False self.logger = getLogger(__name__) if self.straggler_cutoff_time == np.inf: self.logger.warning( - "CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time " - "is set to np.inf." + "CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf." ) def reset_policy_for_round(self) -> None: @@ -103,7 +103,7 @@ def reset_policy_for_round(self) -> None: """ if hasattr(self, "timer"): self.timer.cancel() - delattr(self, "timer") + self.is_timer_started = False def start_policy(self, callback: Callable) -> None: """ @@ -120,8 +120,9 @@ def start_policy(self, callback: Callable) -> None: # If straggler_cutoff_time is set to infinity # or if the timer is already running, # do not start the policy. - if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"): + if self.straggler_cutoff_time == np.inf or self.is_timer_started: return + self.round_start_time = time.time() self.timer = threading.Timer( self.straggler_cutoff_time, @@ -129,6 +130,7 @@ def start_policy(self, callback: Callable) -> None: ) self.timer.daemon = True self.timer.start() + self.is_timer_started = True def straggler_cutoff_check( self, @@ -192,11 +194,11 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: return num_collaborators_done >= self.minimum_reporting -class PercentageBasedStragglerHandling(StragglerHandlingPolicy): +class PercentagePolicy(StragglerHandlingPolicy): """Percentage based Straggler Handling function.""" def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): - """Initialize a PercentageBasedStragglerHandling object. + """Initialize a PercentagePolicy object. Args: percent_collaborators_needed (float, optional): The percentage of @@ -214,13 +216,13 @@ def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwar def reset_policy_for_round(self) -> None: """ - Not required in PercentageBasedStragglerHandling. + Not required in PercentagePolicy. """ pass def start_policy(self, **kwargs) -> None: """ - Not required in PercentageBasedStragglerHandling. + Not required in PercentagePolicy. """ pass diff --git a/openfl/component/straggler_handling_policy/__init__.py b/openfl/component/straggler_handling_policy/__init__.py deleted file mode 100644 index 4de28d0e48..0000000000 --- a/openfl/component/straggler_handling_policy/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2020-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -from openfl.component.straggler_handling_policy.straggler_handling_policy import ( - CutoffTimeBasedStragglerHandling, - PercentageBasedStragglerHandling, - StragglerHandlingPolicy, -) diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 526e0d49c2..10ef43589f 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -421,7 +421,7 @@ def get_tensor_pipe(self): def get_straggler_handling_policy(self): """Get straggler handling policy.""" - template = "openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling" + template = "openfl.component.aggregator.straggler_handling.CutoffPolicy" defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}}) if self.straggler_policy_ is None: diff --git a/setup.py b/setup.py index bb64069e07..c6b03dcfc1 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,6 @@ def run(self): 'openfl.component.collaborator', 'openfl.component.director', 'openfl.component.envoy', - 'openfl.component.straggler_handling_policy', 'openfl.cryptography', 'openfl.databases', 'openfl.databases.utilities', From eff79fc52a5c5b59de958c510837127459380f91 Mon Sep 17 00:00:00 2001 From: Ishant Thakare Date: Fri, 20 Dec 2024 14:38:11 +0530 Subject: [PATCH 3/6] Incorporated karan's review comments Signed-off-by: Ishant Thakare --- openfl/component/__init__.py | 2 +- openfl/component/aggregator/__init__.py | 2 +- openfl/component/aggregator/aggregator.py | 10 ++-- .../aggregator/straggler_handling.py | 46 +++++-------------- openfl/federated/plan/plan.py | 8 +++- 5 files changed, 23 insertions(+), 45 deletions(-) diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 60a078371e..792cfd855d 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -6,7 +6,7 @@ from openfl.component.aggregator.straggler_handling import ( CutoffPolicy, PercentagePolicy, - StragglerHandlingPolicy, + StragglerPolicy, ) from openfl.component.assigner.assigner import Assigner from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner diff --git a/openfl/component/aggregator/__init__.py b/openfl/component/aggregator/__init__.py index 728e4c68d8..1bbddaf4f2 100644 --- a/openfl/component/aggregator/__init__.py +++ b/openfl/component/aggregator/__init__.py @@ -6,5 +6,5 @@ from openfl.component.aggregator.straggler_handling import ( CutoffPolicy, PercentagePolicy, - StragglerHandlingPolicy, + StragglerPolicy, ) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 354d512d15..db718dfc85 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -9,7 +9,7 @@ from logging import getLogger from threading import Lock -from openfl.component.aggregator.straggler_handling import CutoffPolicy +from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy from openfl.databases import TensorDB from openfl.interface.aggregation_functions import WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec @@ -71,7 +71,7 @@ def __init__( last_state_path, assigner, use_delta_updates=True, - straggler_handling_policy=CutoffPolicy, + straggler_handling_policy: StragglerPolicy = CutoffPolicy, rounds_to_train=256, single_col_cert_common_name=None, compression_pipeline=None, @@ -95,7 +95,6 @@ def __init__( weight. assigner: Assigner object. straggler_handling_policy (optional): Straggler handling policy. - Defaults to CutoffPolicy. rounds_to_train (int, optional): Number of rounds to train. Defaults to 256. single_col_cert_common_name (str, optional): Common name for single @@ -123,10 +122,7 @@ def __init__( # FIXME: "" instead of None is for protobuf compatibility. self.single_col_cert_common_name = single_col_cert_common_name or "" - if straggler_handling_policy == CutoffPolicy: - self.straggler_handling_policy = straggler_handling_policy() - else: - self.straggler_handling_policy = straggler_handling_policy + self.straggler_handling_policy = straggler_handling_policy() self._end_of_round_check_done = [False] * rounds_to_train self.stragglers = [] diff --git a/openfl/component/aggregator/straggler_handling.py b/openfl/component/aggregator/straggler_handling.py index df26080e90..1d5f6ef7de 100644 --- a/openfl/component/aggregator/straggler_handling.py +++ b/openfl/component/aggregator/straggler_handling.py @@ -12,8 +12,10 @@ import numpy as np +logger = getLogger(__name__) -class StragglerHandlingPolicy(ABC): + +class StragglerPolicy(ABC): """Federated Learning straggler handling interface.""" @abstractmethod @@ -24,23 +26,12 @@ def start_policy(self, **kwargs) -> None: Args: **kwargs - - Returns: - None """ raise NotImplementedError @abstractmethod def reset_policy_for_round(self) -> None: - """ - Reset policy variable for the next round. - - Args: - None - - Returns: - None - """ + """Reset policy for the next round.""" raise NotImplementedError @abstractmethod @@ -65,7 +56,7 @@ def straggler_cutoff_check( raise NotImplementedError -class CutoffPolicy(StragglerHandlingPolicy): +class CutoffPolicy(StragglerPolicy): """Cutoff time based Straggler Handling function.""" def __init__( @@ -90,17 +81,12 @@ def __init__( self.straggler_cutoff_time = straggler_cutoff_time self.minimum_reporting = minimum_reporting self.is_timer_started = False - self.logger = getLogger(__name__) if self.straggler_cutoff_time == np.inf: - self.logger.warning( - "CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf." - ) + logger.warning("CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf.") def reset_policy_for_round(self) -> None: - """ - Reset timer for the next round. - """ + """Reset timer for the next round.""" if hasattr(self, "timer"): self.timer.cancel() self.is_timer_started = False @@ -113,9 +99,6 @@ def start_policy(self, callback: Callable) -> None: Args: callback: Callable Callback function for when straggler_cutoff_time elapses - - Returns: - None """ # If straggler_cutoff_time is set to infinity # or if the timer is already running, @@ -159,13 +142,13 @@ def straggler_cutoff_check( # Time has expired # Check if minimum_reporting collaborators have reported results elif self.__minimum_collaborators_reported(num_collaborators_done): - self.logger.info( + logger.info( f"{num_collaborators_done} collaborators have reported results. " "Applying cutoff policy and proceeding with end of round." ) return True else: - self.logger.info( + logger.info( f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results." ) return False @@ -194,7 +177,7 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: return num_collaborators_done >= self.minimum_reporting -class PercentagePolicy(StragglerHandlingPolicy): +class PercentagePolicy(StragglerPolicy): """Percentage based Straggler Handling function.""" def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): @@ -212,18 +195,13 @@ def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwar self.percent_collaborators_needed = percent_collaborators_needed self.minimum_reporting = minimum_reporting - self.logger = getLogger(__name__) def reset_policy_for_round(self) -> None: - """ - Not required in PercentagePolicy. - """ + """Not required in PercentagePolicy.""" pass def start_policy(self, **kwargs) -> None: - """ - Not required in PercentagePolicy. - """ + """Not required in PercentagePolicy.""" pass def straggler_cutoff_check( diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 9c326a746a..b2321d2011 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -4,6 +4,7 @@ """Plan module.""" +from functools import partial from hashlib import sha384 from importlib import import_module from logging import getLogger @@ -43,7 +44,7 @@ class Plan: server_ (AggregatorGRPCServer): gRPC server object. client_ (AggregatorGRPCClient): gRPC client object. pipe_ (CompressionPipeline): Compression pipeline object. - straggler_policy_ (StragglerHandlingPolicy): Straggler handling policy. + straggler_policy_ (StragglerPolicy): Straggler handling policy. hash_ (str): Hash of the instance. name_ (str): Name of the instance. serializer_ (SerializerPlugin): Serializer plugin. @@ -426,7 +427,10 @@ def get_straggler_handling_policy(self): defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}}) if self.straggler_policy_ is None: - self.straggler_policy_ = Plan.build(**defaults) + # Prepare a partial function for the straggler policy + self.straggler_policy_ = partial( + Plan.import_(defaults["template"]), **defaults["settings"] + ) return self.straggler_policy_ From 7f16cdb77bd3a7fd1b9b38daf8791fde9f316ebf Mon Sep 17 00:00:00 2001 From: Ishant Thakare Date: Fri, 20 Dec 2024 14:55:41 +0530 Subject: [PATCH 4/6] Resolving merge conflicts Signed-off-by: Ishant Thakare --- docs/source/api/openfl_component.rst | 19 ------------------- docs/source/api/openfl_cryptography.rst | 16 ---------------- docs/source/api/openfl_databases.rst | 16 ---------------- docs/source/api/openfl_experimental.rst | 18 ------------------ docs/source/api/openfl_federated.rst | 18 ------------------ docs/source/api/openfl_interface.rst | 16 ---------------- docs/source/api/openfl_native.rst | 16 ---------------- docs/source/api/openfl_pipelines.rst | 16 ---------------- docs/source/api/openfl_plugins.rst | 16 ---------------- docs/source/api/openfl_protocols.rst | 16 ---------------- docs/source/api/openfl_transport.rst | 15 --------------- docs/source/api/openfl_utilities.rst | 16 ---------------- docs/source/openfl/director_workflow.svg | 1 - docs/source/openfl/static_diagram.svg | 1 - 14 files changed, 200 deletions(-) delete mode 100644 docs/source/api/openfl_component.rst delete mode 100644 docs/source/api/openfl_cryptography.rst delete mode 100644 docs/source/api/openfl_databases.rst delete mode 100644 docs/source/api/openfl_experimental.rst delete mode 100644 docs/source/api/openfl_federated.rst delete mode 100644 docs/source/api/openfl_interface.rst delete mode 100644 docs/source/api/openfl_native.rst delete mode 100644 docs/source/api/openfl_pipelines.rst delete mode 100644 docs/source/api/openfl_plugins.rst delete mode 100644 docs/source/api/openfl_protocols.rst delete mode 100644 docs/source/api/openfl_transport.rst delete mode 100644 docs/source/api/openfl_utilities.rst delete mode 100644 docs/source/openfl/director_workflow.svg delete mode 100644 docs/source/openfl/static_diagram.svg diff --git a/docs/source/api/openfl_component.rst b/docs/source/api/openfl_component.rst deleted file mode 100644 index 0af3099b49..0000000000 --- a/docs/source/api/openfl_component.rst +++ /dev/null @@ -1,19 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Component Module -************************************************* - -Component modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.component.aggregator - openfl.component.assigner - openfl.component.collaborator - openfl.component.director - openfl.component.envoy diff --git a/docs/source/api/openfl_cryptography.rst b/docs/source/api/openfl_cryptography.rst deleted file mode 100644 index 475ebd1e9b..0000000000 --- a/docs/source/api/openfl_cryptography.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Cryptography Module -************************************************* - -Cryptography modules reference: - -.. autosummary:: - :toctree: _autosummary - :recursive: - - openfl.cryptography.ca - openfl.cryptography.io - openfl.cryptography.participant diff --git a/docs/source/api/openfl_databases.rst b/docs/source/api/openfl_databases.rst deleted file mode 100644 index 8014d42114..0000000000 --- a/docs/source/api/openfl_databases.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Databases Module -************************************************* - -Databases modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.databases - \ No newline at end of file diff --git a/docs/source/api/openfl_experimental.rst b/docs/source/api/openfl_experimental.rst deleted file mode 100644 index 907645686d..0000000000 --- a/docs/source/api/openfl_experimental.rst +++ /dev/null @@ -1,18 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Experimental Module -************************************************* - -Experimental modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.experimental.workflow.interface - openfl.experimental.workflow.placement - openfl.experimental.workflow.runtime - openfl.experimental.workflow.utilities diff --git a/docs/source/api/openfl_federated.rst b/docs/source/api/openfl_federated.rst deleted file mode 100644 index 8c3d50b81e..0000000000 --- a/docs/source/api/openfl_federated.rst +++ /dev/null @@ -1,18 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Federated Module -************************************************* - -Federated modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.federated.plan - openfl.federated.task - openfl.federated.data - \ No newline at end of file diff --git a/docs/source/api/openfl_interface.rst b/docs/source/api/openfl_interface.rst deleted file mode 100644 index 8685cce5f0..0000000000 --- a/docs/source/api/openfl_interface.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Interface Module -************************************************* - -Interface modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.interface - \ No newline at end of file diff --git a/docs/source/api/openfl_native.rst b/docs/source/api/openfl_native.rst deleted file mode 100644 index 5f3f513340..0000000000 --- a/docs/source/api/openfl_native.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Native Module (Deprecated) -************************************************* - -Native modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.native - \ No newline at end of file diff --git a/docs/source/api/openfl_pipelines.rst b/docs/source/api/openfl_pipelines.rst deleted file mode 100644 index 42ec1b33ad..0000000000 --- a/docs/source/api/openfl_pipelines.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Pipelines Module -************************************************* - -Pipelines modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.pipelines - \ No newline at end of file diff --git a/docs/source/api/openfl_plugins.rst b/docs/source/api/openfl_plugins.rst deleted file mode 100644 index de8df91f4f..0000000000 --- a/docs/source/api/openfl_plugins.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Plugins Module -************************************************* - -Plugins modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.plugins - \ No newline at end of file diff --git a/docs/source/api/openfl_protocols.rst b/docs/source/api/openfl_protocols.rst deleted file mode 100644 index e6e571ccc3..0000000000 --- a/docs/source/api/openfl_protocols.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Protocols Module -************************************************* - -Protocols modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.protocols - \ No newline at end of file diff --git a/docs/source/api/openfl_transport.rst b/docs/source/api/openfl_transport.rst deleted file mode 100644 index 19eb01d839..0000000000 --- a/docs/source/api/openfl_transport.rst +++ /dev/null @@ -1,15 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Transport Module -************************************************* - -Transport modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.transport diff --git a/docs/source/api/openfl_utilities.rst b/docs/source/api/openfl_utilities.rst deleted file mode 100644 index b44e1f74d7..0000000000 --- a/docs/source/api/openfl_utilities.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. # Copyright (C) 2020-2024 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -Utilities Module -************************************************* - -Utilities modules reference: - -.. autosummary:: - :toctree: _autosummary - :template: custom-module-template.rst - :recursive: - - openfl.utilities - \ No newline at end of file diff --git a/docs/source/openfl/director_workflow.svg b/docs/source/openfl/director_workflow.svg deleted file mode 100644 index bdc57e11d0..0000000000 --- a/docs/source/openfl/director_workflow.svg +++ /dev/null @@ -1 +0,0 @@ -OpenFLPython API componentExperiment ManagerDirector ManagerCollaborator ManagerCentral nodeDirectorCollaborator nodeEnvoyAggregatorCollaboratorProvides FL plans, Tasks, Models, Data LoadersLaunches.Sets up global Federation settingsLaunches.Provides local dataset Shard DescriptorsCreates an instance to maintain an FL experimentCreates an instance to maintain an FL experimentCommunicates dataset info, Sends status updatesApproves, Sends FL experimentsSends locally tuned tensors and training metricsSends tasks and initial tensorsRegisters FL experimentsSends info about the Federation. Returns training artifacts \ No newline at end of file diff --git a/docs/source/openfl/static_diagram.svg b/docs/source/openfl/static_diagram.svg deleted file mode 100644 index 8f816d4e88..0000000000 --- a/docs/source/openfl/static_diagram.svg +++ /dev/null @@ -1 +0,0 @@ -Friday, 27 August 2021, 16:25 Moscow Standard TimeContainer diagram for OpenFLOpenFL[Software System]Central node-Collaborator node-Data scientist[Person]A person or group of peopleusing OpenFLEnvoy[Container]A long-living entity that can adapt alocal data set and spawncollaborators+Collaborator manager[Person]Data owner's representativecontrolling EnvoyDirector manager[Person]-Collaborator[Container]Actor executing tasks on local datainside one experiment+Python API component[Container]A set of tools to setup register FLExperiments+Director[Container]A long-living entity that can spawnaggregators-Aggregator[Container]Model server and collaboratororchestrator-Launches. Setsup globalFederationsettings--Remove link.Link options.Launches.Provides localdatasetShardDescriptors--Remove link.Link options.Sends locallytuned tensorsand trainingmetrics--Remove vertex.Remove link.Link options.Provides FL Plans,Tasks, Models,DataLoaders--Remove link.Link options.Sends tasks andinitial tensors--Remove vertex.Remove link.Link options.Approves, SendsFL experiments--Remove vertex.Remove link.Link options.Communicatesdataset info,Sends statusupdates--Remove vertex.Remove link.Link options.Creates aninstance tomaintain an FLexperiment--Remove link.Link options.Creates aninstance tomaintain an FLexperiment--Remove link.Link options.Sendsinformationabout theFederation.Returns trainingartifacts.--Remove vertex.Remove link.Link options.Registers FLexperiments--Remove vertex.Remove link.Link options. \ No newline at end of file From 8f2fc6be9cc241a1e3e1d3f46bd4a4bbc9ec6d86 Mon Sep 17 00:00:00 2001 From: Ishant Thakare Date: Fri, 17 Jan 2025 08:29:21 +0530 Subject: [PATCH 5/6] Fix code format Signed-off-by: Ishant Thakare --- openfl/component/aggregator/straggler_handling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfl/component/aggregator/straggler_handling.py b/openfl/component/aggregator/straggler_handling.py index 1d5f6ef7de..d1157ae3ff 100644 --- a/openfl/component/aggregator/straggler_handling.py +++ b/openfl/component/aggregator/straggler_handling.py @@ -83,7 +83,7 @@ def __init__( self.is_timer_started = False if self.straggler_cutoff_time == np.inf: - logger.warning("CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf.") + logger.warning("CutoffPolicy is disabled as straggler_cutoff_time is set to np.inf.") def reset_policy_for_round(self) -> None: """Reset timer for the next round.""" From 3133cb68965f8edd07d5b1c8e3260f187fb2fe58 Mon Sep 17 00:00:00 2001 From: Ishant Thakare Date: Wed, 22 Jan 2025 08:36:11 +0530 Subject: [PATCH 6/6] Incorporated review comments Signed-off-by: Ishant Thakare --- .../straggler_handling_algorithms.rst | 6 +++--- openfl/component/__init__.py | 2 +- openfl/component/aggregator/__init__.py | 2 +- openfl/component/aggregator/aggregator.py | 4 ++-- openfl/component/aggregator/straggler_handling.py | 13 ++++++++----- openfl/federated/plan/plan.py | 2 +- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst index 32489e9ef1..88579bde57 100644 --- a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst +++ b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst @@ -11,7 +11,7 @@ The Open Federated Learning (OpenFL) framework supports straggler handling inter The following are the straggler handling algorithms supported in OpenFL: -``CutoffPolicy`` +``CutoffTimePolicy`` Identifies stragglers based on the cutoff time specified in the settings. Arguments to the function are: - *Cutoff Time* (straggler_cutoff_time), specifies the cutoff time by which the aggregator should end the round early. - *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model. @@ -29,12 +29,12 @@ The following are the straggler handling algorithms supported in OpenFL: Demonstration of adding the straggler handling interface ========================================================= -The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffPolicy** function instead: +The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimePolicy** function instead: .. code-block:: yaml straggler_handling_policy : - template : openfl.component.aggregator.straggler_handling.CutoffPolicy + template : openfl.component.aggregator.straggler_handling.CutoffTimePolicy settings : straggler_cutoff_time : 20 minimum_reporting : 1 diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 792cfd855d..499af56320 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -4,7 +4,7 @@ from openfl.component.aggregator.aggregator import Aggregator from openfl.component.aggregator.straggler_handling import ( - CutoffPolicy, + CutoffTimePolicy, PercentagePolicy, StragglerPolicy, ) diff --git a/openfl/component/aggregator/__init__.py b/openfl/component/aggregator/__init__.py index 1bbddaf4f2..71a7ce11ea 100644 --- a/openfl/component/aggregator/__init__.py +++ b/openfl/component/aggregator/__init__.py @@ -4,7 +4,7 @@ from openfl.component.aggregator.aggregator import Aggregator from openfl.component.aggregator.straggler_handling import ( - CutoffPolicy, + CutoffTimePolicy, PercentagePolicy, StragglerPolicy, ) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index e48039875e..5ede1c1da9 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -10,7 +10,7 @@ from typing import List, Optional import openfl.callbacks as callbacks_module -from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy +from openfl.component.aggregator.straggler_handling import CutoffTimePolicy, StragglerPolicy from openfl.databases import PersistentTensorDB, TensorDB from openfl.interface.aggregation_functions import WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec @@ -73,7 +73,7 @@ def __init__( last_state_path, assigner, use_delta_updates=True, - straggler_handling_policy: StragglerPolicy = CutoffPolicy, + straggler_handling_policy: StragglerPolicy = CutoffTimePolicy, rounds_to_train=256, single_col_cert_common_name=None, compression_pipeline=None, diff --git a/openfl/component/aggregator/straggler_handling.py b/openfl/component/aggregator/straggler_handling.py index d1157ae3ff..66207bddf1 100644 --- a/openfl/component/aggregator/straggler_handling.py +++ b/openfl/component/aggregator/straggler_handling.py @@ -22,7 +22,7 @@ class StragglerPolicy(ABC): def start_policy(self, **kwargs) -> None: """ Start straggler handling policy for collaborator for a particular round. - NOTE: Refer CutoffPolicy for reference. + NOTE: Refer CutoffTimePolicy for reference. Args: **kwargs @@ -56,14 +56,14 @@ def straggler_cutoff_check( raise NotImplementedError -class CutoffPolicy(StragglerPolicy): +class CutoffTimePolicy(StragglerPolicy): """Cutoff time based Straggler Handling function.""" def __init__( self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs ): """ - Initialize a CutoffPolicy object. + Initialize a CutoffTimePolicy object. Args: round_start_time (optional): The start time of the round. Defaults @@ -71,7 +71,8 @@ def __init__( straggler_cutoff_time (float, optional): The cutoff time for stragglers. Defaults to np.inf. minimum_reporting (int, optional): The minimum number of - collaborators that should report. Defaults to 1. + collaborators that should report before moving to the next round. + Defaults to 1. **kwargs: Variable length argument list. """ if minimum_reporting <= 0: @@ -83,7 +84,9 @@ def __init__( self.is_timer_started = False if self.straggler_cutoff_time == np.inf: - logger.warning("CutoffPolicy is disabled as straggler_cutoff_time is set to np.inf.") + logger.warning( + "CutoffTimePolicy is disabled as straggler_cutoff_time is set to np.inf." + ) def reset_policy_for_round(self) -> None: """Reset timer for the next round.""" diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 0fe1e87c41..e26bfa56d8 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -411,7 +411,7 @@ def get_tensor_pipe(self): def get_straggler_handling_policy(self): """Get straggler handling policy.""" - template = "openfl.component.aggregator.straggler_handling.CutoffPolicy" + template = "openfl.component.aggregator.straggler_handling.CutoffTimePolicy" defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}}) if self.straggler_policy_ is None: