diff --git a/docs/_templates/custom-class-template.rst b/docs/_templates/custom-class-template.rst new file mode 100644 index 0000000000..d64b80d51c --- /dev/null +++ b/docs/_templates/custom-class-template.rst @@ -0,0 +1,34 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :show-inheritance: + :inherited-members: + :special-members: __call__, __add__, __mul__ + + {% block methods %} + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + :nosignatures: + {% for item in methods %} + {%- if not item.startswith('_') %} + ~{{ name }}.{{ item }} + {%- endif -%} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Attributes') }} + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} \ No newline at end of file diff --git a/docs/_templates/custom-module-template.rst b/docs/_templates/custom-module-template.rst new file mode 100644 index 0000000000..ec6b7ab05d --- /dev/null +++ b/docs/_templates/custom-module-template.rst @@ -0,0 +1,66 @@ +{{ fullname | escape | underline}} + +.. automodule:: {{ fullname }} + :members: + {% block attributes %} + {% if attributes %} + .. rubric:: Module attributes + + .. autosummary:: + :toctree: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + :toctree: + :nosignatures: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + :toctree: + :template: custom-class-template.rst + :nosignatures: + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + :toctree: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% block modules %} +{% if modules %} +.. autosummary:: + :toctree: + :template: custom-module-template.rst + :recursive: +{% for item in modules %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 8e0983bcbd..5c8b93d1a9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,16 +31,19 @@ # import sphinxcontrib.napoleon # NOQA:E800 extensions = [ + 'sphinx.ext.napoleon', 'sphinx_rtd_theme', 'sphinx.ext.autosectionlabel', - 'sphinx.ext.napoleon', 'sphinx-prompt', 'sphinx_copybutton', 'sphinx_substitution_extensions', 'sphinx.ext.ifconfig', 'sphinxcontrib.mermaid', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', 'recommonmark' ] +autosummary_generate = True # Turn on sphinx.ext.autosummary source_suffix = ['.rst', '.md'] @@ -68,6 +71,37 @@ napoleon_google_docstring = True +# Config the returns section to behave like the Args section +napoleon_custom_sections = [('Returns', 'params_style')] + +# This code extends Sphinx's GoogleDocstring class to support 'Keys', +# 'Attributes', and 'Class Attributes' sections in docstrings. Allows for more +# detailed and structured documentation of Python classes and their attributes. 
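+#
+# As an illustrative sketch only (the class below is hypothetical and not part
+# of the OpenFL codebase), a docstring written against these custom sections
+# could look like:
+#
+#     class TensorStore:
+#         """Stores tensors produced during a round.
+#
+#         Class Attributes:
+#             DEFAULT_ROUNDS (int): Number of rounds kept in the store.
+#
+#         Keys:
+#             tensor_name (str): Layer name under which a tensor is stored.
+#         """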
+from sphinx.ext.napoleon.docstring import GoogleDocstring + +# Define new sections and their corresponding parse methods +new_sections = { + 'keys': 'Keys', + 'attributes': 'Attributes', + 'class attributes': 'Class Attributes' +} + +# Add new sections to GoogleDocstring class. Each title is bound as a default +# argument so that every generated parser keeps its own section title instead +# of late-binding to the last value of the loop variable. +for section, title in new_sections.items(): + setattr(GoogleDocstring, f'_parse_{section}_section', + lambda self, section, title=title: self._format_fields(title, self._consume_fields())) + + +# Patch the parse method to include new sections +def patched_parse(self): + for section in new_sections: + self._sections[section] = getattr(self, f'_parse_{section}_section') + self._unpatched_parse() + +# Apply the patch +GoogleDocstring._unpatched_parse = GoogleDocstring._parse +GoogleDocstring._parse = patched_parse + # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -75,7 +109,8 @@ # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', 'README.md', 'structurizer_dsl/README.md', - '.DS_Store', 'tutorials/*', 'graveyard/*'] + '.DS_Store', 'tutorials/*', 'graveyard/*', '_templates'] + # add temporary unused files exclude_patterns.extend(['modules.rst', 'install.singularity.rst', diff --git a/docs/developer_ref/api_documentation.rst b/docs/developer_ref/api_documentation.rst index d42fc1251d..fa1126d4a9 100644 --- a/docs/developer_ref/api_documentation.rst +++ b/docs/developer_ref/api_documentation.rst @@ -1,10 +1,33 @@ -.. # Copyright (C) 2020-2023 Intel Corporation -.. # SPDX-License-Identifier: Apache-2.0 - -************************************************* -|productName| API -************************************************* - -Welcome to the |productName| API reference: - -TODO \ No newline at end of file +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +|productName| API +************************************************* + +Welcome to the |productName| API reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + +- :doc:`../source/api/openfl_component` +- :doc:`../source/api/openfl_cryptography` +- :doc:`../source/api/openfl_databases` +- :doc:`../source/api/openfl_experimental` +- :doc:`../source/api/openfl_federated` +- :doc:`../source/api/openfl_interface` +- :doc:`../source/api/openfl_native` +- :doc:`../source/api/openfl_pipelines` +- :doc:`../source/api/openfl_plugins` +- :doc:`../source/api/openfl_protocols` +- :doc:`../source/api/openfl_transport` +- :doc:`../source/api/openfl_utilities` + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 8a31cecc36..1d51291cb3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -64,12 +64,4 @@ Looking for the Open Flash Library project also referred to as OpenFL? Find it ` :caption: CONTRIBUTING GUIDELINES :maxdepth: 2 - contributing_guidelines/contributing - - -.. Indices and tables -.. ================== - -.. * :ref:`genindex` -.. * :ref:`modindex` -.. * :ref:`search` \ No newline at end of file + contributing_guidelines/contributing diff --git a/docs/source/api/openfl_component.rst b/docs/source/api/openfl_component.rst new file mode 100644 index 0000000000..8deb6528ac --- /dev/null +++ b/docs/source/api/openfl_component.rst @@ -0,0 +1,20 @@ +.. 
# Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Component Module +************************************************* + +Component modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.component.aggregator + openfl.component.assigner + openfl.component.collaborator + openfl.component.director + openfl.component.envoy + openfl.component.straggler_handling_functions diff --git a/docs/source/api/openfl_cryptography.rst b/docs/source/api/openfl_cryptography.rst new file mode 100644 index 0000000000..475ebd1e9b --- /dev/null +++ b/docs/source/api/openfl_cryptography.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Cryptography Module +************************************************* + +Cryptography modules reference: + +.. autosummary:: + :toctree: _autosummary + :recursive: + + openfl.cryptography.ca + openfl.cryptography.io + openfl.cryptography.participant diff --git a/docs/source/api/openfl_databases.rst b/docs/source/api/openfl_databases.rst new file mode 100644 index 0000000000..8014d42114 --- /dev/null +++ b/docs/source/api/openfl_databases.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Databases Module +************************************************* + +Databases modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.databases + \ No newline at end of file diff --git a/docs/source/api/openfl_experimental.rst b/docs/source/api/openfl_experimental.rst new file mode 100644 index 0000000000..01fbb0fcee --- /dev/null +++ b/docs/source/api/openfl_experimental.rst @@ -0,0 +1,18 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Experimental Module +************************************************* + +Experimental modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.experimental.interface + openfl.experimental.placement + openfl.experimental.runtime + openfl.experimental.utilities diff --git a/docs/source/api/openfl_federated.rst b/docs/source/api/openfl_federated.rst new file mode 100644 index 0000000000..5072c00516 --- /dev/null +++ b/docs/source/api/openfl_federated.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Federated Module +************************************************* + +Federated modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.federated + \ No newline at end of file diff --git a/docs/source/api/openfl_interface.rst b/docs/source/api/openfl_interface.rst new file mode 100644 index 0000000000..8685cce5f0 --- /dev/null +++ b/docs/source/api/openfl_interface.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Interface Module +************************************************* + +Interface modules reference: + +.. 
autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.interface + \ No newline at end of file diff --git a/docs/source/api/openfl_native.rst b/docs/source/api/openfl_native.rst new file mode 100644 index 0000000000..bd9eb608d3 --- /dev/null +++ b/docs/source/api/openfl_native.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Native Module +************************************************* + +Native modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.native + \ No newline at end of file diff --git a/docs/source/api/openfl_pipelines.rst b/docs/source/api/openfl_pipelines.rst new file mode 100644 index 0000000000..42ec1b33ad --- /dev/null +++ b/docs/source/api/openfl_pipelines.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Pipelines Module +************************************************* + +Pipelines modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.pipelines + \ No newline at end of file diff --git a/docs/source/api/openfl_plugins.rst b/docs/source/api/openfl_plugins.rst new file mode 100644 index 0000000000..de8df91f4f --- /dev/null +++ b/docs/source/api/openfl_plugins.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Plugins Module +************************************************* + +Plugins modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.plugins + \ No newline at end of file diff --git a/docs/source/api/openfl_protocols.rst b/docs/source/api/openfl_protocols.rst new file mode 100644 index 0000000000..e6e571ccc3 --- /dev/null +++ b/docs/source/api/openfl_protocols.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Protocols Module +************************************************* + +Protocols modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.protocols + \ No newline at end of file diff --git a/docs/source/api/openfl_transport.rst b/docs/source/api/openfl_transport.rst new file mode 100644 index 0000000000..19eb01d839 --- /dev/null +++ b/docs/source/api/openfl_transport.rst @@ -0,0 +1,15 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +************************************************* +Transport Module +************************************************* + +Transport modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.transport diff --git a/docs/source/api/openfl_utilities.rst b/docs/source/api/openfl_utilities.rst new file mode 100644 index 0000000000..b44e1f74d7 --- /dev/null +++ b/docs/source/api/openfl_utilities.rst @@ -0,0 +1,16 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. 
# SPDX-License-Identifier: Apache-2.0 + +************************************************* +Utilities Module +************************************************* + +Utilities modules reference: + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + openfl.utilities + \ No newline at end of file diff --git a/openfl/__init__.py b/openfl/__init__.py index bb887c6ebf..7fc3c15892 100644 --- a/openfl/__init__.py +++ b/openfl/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""openfl base package.""" -from .__version__ import __version__ -# flake8: noqa -#from .interface.model import get_model +"""Openfl base package.""" +from openfl.__version__ import __version__ +# flake8: noqa \ No newline at end of file diff --git a/openfl/__version__.py b/openfl/__version__.py index fc16e8408a..235bdfab60 100644 --- a/openfl/__version__.py +++ b/openfl/__version__.py @@ -1,4 +1,4 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""openfl version information.""" +"""Openfl version information.""" __version__ = '1.5' diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index f3aa66f7d1..c6ec0f7141 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -1,24 +1,10 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""openfl.component package.""" - -from .aggregator import Aggregator -from .assigner import Assigner -from .assigner import RandomGroupedAssigner -from .assigner import StaticGroupedAssigner -from .collaborator import Collaborator -from .straggler_handling_functions import StragglerHandlingFunction -from .straggler_handling_functions import CutoffTimeBasedStragglerHandling -from .straggler_handling_functions import PercentageBasedStragglerHandling - -__all__ = [ - 'Assigner', - 'RandomGroupedAssigner', - 'StaticGroupedAssigner', - 'Aggregator', - 'Collaborator', - 'StragglerHandlingFunction', - 'CutoffTimeBasedStragglerHandling', - 'PercentageBasedStragglerHandling' -] +from openfl.component.aggregator import Aggregator +from openfl.component.assigner import Assigner, RandomGroupedAssigner, StaticGroupedAssigner +from openfl.component.collaborator import Collaborator +from openfl.component.straggler_handling_functions import ( + CutoffTimeBasedStragglerHandling, + PercentageBasedStragglerHandling, + StragglerHandlingFunction, +) \ No newline at end of file diff --git a/openfl/component/aggregator/__init__.py b/openfl/component/aggregator/__init__.py index 735743adff..5c5f98265f 100644 --- a/openfl/component/aggregator/__init__.py +++ b/openfl/component/aggregator/__init__.py @@ -1,10 +1,3 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""Aggregator package.""" - -from .aggregator import Aggregator - -__all__ = [ - 'Aggregator', -] +from openfl.component.aggregator.aggregator import Aggregator diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 4100b195e4..05a8fbbb3d 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -2,49 +2,70 @@ # SPDX-License-Identifier: Apache-2.0 """Aggregator module.""" -import time import queue +import time from logging import getLogger -from openfl.interface.aggregation_functions import WeightedAverage -from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling +from 
openfl.component.straggler_handling_functions import ( + CutoffTimeBasedStragglerHandling, +) from openfl.databases import TensorDB -from openfl.pipelines import NoCompressionPipeline -from openfl.pipelines import TensorCodec -from openfl.protocols import base_pb2 -from openfl.protocols import utils -from openfl.utilities import change_tags -from openfl.utilities import TaskResultKey -from openfl.utilities import TensorKey +from openfl.interface.aggregation_functions import WeightedAverage +from openfl.pipelines import NoCompressionPipeline, TensorCodec +from openfl.protocols import base_pb2, utils +from openfl.utilities import TaskResultKey, TensorKey, change_tags from openfl.utilities.logs import write_metric class Aggregator: r"""An Aggregator is the central node in federated learning. - Args: - aggregator_uuid (str): Aggregation ID. - federation_uuid (str): Federation ID. - authorized_cols (list of str): The list of IDs of enrolled collaborators. - init_state_path* (str): The location of the initial weight file. - last_state_path* (str): The file location to store the latest weight. - best_state_path* (str): The file location to store the weight of the best model. + Attributes: + round_number (int): Current round number. + single_col_cert_common_name (str): Common name for single + collaborator certificate. + straggler_handling_policy: Policy for handling stragglers. + _end_of_round_check_done (list of bool): Indicates if end of round + check is done for each round. + stragglers (list): List of stragglers. + rounds_to_train (int): Number of rounds to train. + authorized_cols (list of str): IDs of enrolled collaborators. + uuid (int): Aggregator UUID. + federation_uuid (str): Federation UUID. + assigner: Object assigning tasks to collaborators. + quit_job_sent_to (list): Collaborators sent a quit job. + tensor_db (TensorDB): Object for tensor database. db_store_rounds* (int): Rounds to store in TensorDB. - - Note: + logger: Object for logging. + write_logs (bool): Flag to enable log writing. + log_metric_callback: Callback for logging metrics. + best_model_score (optional): Score of the best model. Defaults to + None. + metric_queue (queue.Queue): Queue for metrics. + compression_pipeline: Pipeline for compressing data. + tensor_codec (TensorCodec): Codec for tensor compression. + init_state_path* (str): Initial weight file location. + best_state_path* (str): Where to store the best model weight. + last_state_path* (str): Where to store the latest model weight. + best_tensor_dict (dict): Dict of the best tensors. + last_tensor_dict (dict): Dict of the last tensors. + collaborator_tensor_results (dict): Dict of collaborator tensor + results. + collaborator_tasks_results (dict): Dict of collaborator tasks + results. + collaborator_task_weight (dict): Dict of collaborator task weight. + + .. note:: \* - plan setting. """ def __init__(self, - aggregator_uuid, federation_uuid, authorized_cols, - init_state_path, best_state_path, last_state_path, - assigner, straggler_handling_policy=None, rounds_to_train=256, @@ -54,7 +75,35 @@ def __init__(self, write_logs=False, log_metric_callback=None, **kwargs): - """Initialize.""" + """Initializes the Aggregator. + + Args: + aggregator_uuid (int): Aggregation ID. + federation_uuid (str): Federation ID. + authorized_cols (list of str): The list of IDs of enrolled + collaborators. + init_state_path (str): The location of the initial weight file. + best_state_path (str): The file location to store the weight of + the best model. 
+ last_state_path (str): The file location to store the latest + weight. + assigner: Assigner object. + straggler_handling_policy (optional): Straggler handling policy. + Defaults to CutoffTimeBasedStragglerHandling. + rounds_to_train (int, optional): Number of rounds to train. + Defaults to 256. + single_col_cert_common_name (str, optional): Common name for single + collaborator certificate. Defaults to None. + compression_pipeline (optional): Compression pipeline. Defaults to + NoCompressionPipeline. + db_store_rounds (int, optional): Rounds to store in TensorDB. + Defaults to 1. + write_logs (bool, optional): Whether to write logs. Defaults to + False. + log_metric_callback (optional): Callback for log metric. Defaults + to None. + **kwargs: Additional keyword arguments. + """ self.round_number = 0 self.single_col_cert_common_name = single_col_cert_common_name @@ -116,7 +165,8 @@ def __init__(self, round_number=0, tensor_pipe=self.compression_pipeline) else: - self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) + self.model: base_pb2.ModelProto = utils.load_proto( + self.init_state_path) self._load_initial_tensors() # keys are TensorKeys self.collaborator_tensor_results = {} # {TensorKey: nparray}} @@ -127,8 +177,7 @@ def __init__(self, self.collaborator_task_weight = {} # {TaskResultKey: data_size} def _load_initial_tensors(self): - """ - Load all of the tensors required to begin federated learning. + """Load all of the tensors required to begin federated learning. Required tensors are: \ 1. Initial model. @@ -153,8 +202,7 @@ def _load_initial_tensors(self): self.logger.debug(f'This is the initial tensor_db: {self.tensor_db}') def _load_initial_tensors_from_dict(self, tensor_dict): - """ - Load all of the tensors required to begin federated learning. + """Load all of the tensors required to begin federated learning. Required tensors are: \ 1. Initial model. @@ -171,14 +219,11 @@ def _load_initial_tensors_from_dict(self, tensor_dict): self.logger.debug(f'This is the initial tensor_db: {self.tensor_db}') def _save_model(self, round_number, file_path): - """ - Save the best or latest model. + """Save the best or latest model. Args: - round_number: int - Model round to be saved - file_path: str - Either the best model or latest model file path + round_number (int): Model round to be saved. + file_path (str): Either the best model or latest model file path. Returns: None @@ -208,17 +253,16 @@ def _save_model(self, round_number, file_path): def valid_collaborator_cn_and_id(self, cert_common_name, collaborator_common_name): - """ - Determine if the collaborator certificate and ID are valid for this federation. + """Determine if the collaborator certificate and ID are valid for this + federation. Args: - cert_common_name: Common name for security certificate - collaborator_common_name: Common name for collaborator + cert_common_name (str): Common name for security certificate. + collaborator_common_name (str): Common name for collaborator. Returns: bool: True means the collaborator common name matches the name in - the security certificate. - + the security certificate. """ # if self.test_mode_whitelist is None, then the common_name must # match collaborator_common_name and be in authorized_cols @@ -234,45 +278,45 @@ def valid_collaborator_cn_and_id(self, cert_common_name, and collaborator_common_name in self.authorized_cols) def all_quit_jobs_sent(self): - """Assert all quit jobs are sent to collaborators.""" + """Assert all quit jobs are sent to collaborators. 
+ + Returns: + bool: True if all quit jobs are sent, False otherwise. + """ return set(self.quit_job_sent_to) == set(self.authorized_cols) @staticmethod def _get_sleep_time(): - """ - Sleep 10 seconds. + """Sleep 10 seconds. Returns: - sleep_time: int + int: Sleep time. """ # Decrease sleep period for finer discretezation return 10 def _time_to_quit(self): - """ - If all rounds are complete, it's time to quit. + """If all rounds are complete, it's time to quit. Returns: - is_time_to_quit: bool + bool: True if it's time to quit, False otherwise. """ if self.round_number >= self.rounds_to_train: return True return False def get_tasks(self, collaborator_name): - """ - RPC called by a collaborator to determine which tasks to perform. + """RPC called by a collaborator to determine which tasks to perform. Args: - collaborator_name: str - Requested collaborator name + collaborator_name (str): Requested collaborator name. Returns: - tasks: list[str] - List of tasks to be performed by the requesting collaborator - for the current round. - sleep_time: int - time_to_quit: bool + tasks (list[str]): List of tasks to be performed by the requesting + collaborator for the current round. + round_number (int): Actual round number. + sleep_time (int): Sleep time. + time_to_quit (bool): Whether it's time to quit. """ self.logger.debug( f'Aggregator GetTasks function reached from collaborator {collaborator_name}...' @@ -291,7 +335,8 @@ def get_tasks(self, collaborator_name): time_to_quit = False # otherwise, get the tasks from our task assigner - tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number) + tasks = self.assigner.get_tasks_for_collaborator( + collaborator_name, self.round_number) # if no tasks, tell the collaborator to sleep if len(tasks) == 0: @@ -338,26 +383,28 @@ def get_tasks(self, collaborator_name): def get_aggregated_tensor(self, collaborator_name, tensor_name, round_number, report, tags, require_lossless): - """ - RPC called by collaborator. + """RPC called by collaborator. - Performs local lookup to determine if there is an aggregated tensor available \ - that matches the request. + Performs local lookup to determine if there is an aggregated tensor available + that matches the request. Args: - collaborator_name : str - Requested tensor key collaborator name - tensor_name: str - require_lossless: bool - round_number: int - report: bool - tags: tuple[str, ...] + collaborator_name (str): Requested tensor key collaborator name. + tensor_name (str): Name of the tensor. + round_number (int): Actual round number. + report (bool): Whether to report. + tags (tuple[str, ...]): Tags. + require_lossless (bool): Whether to require lossless. + Returns: - named_tensor : protobuf NamedTensor - the tensor requested by the collaborator + named_tensor (protobuf) : NamedTensor, the tensor requested by the collaborator. + + Raises: + ValueError: if Aggregator does not have an aggregated tensor for {tensor_key}. 
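+
+        Example (illustrative; the collaborator and tensor names below are
+            placeholders rather than values taken from the codebase):
+
+            >>> named_tensor = aggregator.get_aggregated_tensor(
+            ...     collaborator_name='col1',
+            ...     tensor_name='conv1/weight',
+            ...     round_number=0,
+            ...     report=False,
+            ...     tags=('aggregated',),
+            ...     require_lossless=True)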
""" - self.logger.debug(f'Retrieving aggregated tensor {tensor_name},{round_number},{tags} ' - f'for collaborator {collaborator_name}') + self.logger.debug( + f'Retrieving aggregated tensor {tensor_name},{round_number},{tags} ' + f'for collaborator {collaborator_name}') if 'compressed' in tags or require_lossless: compress_lossless = True @@ -372,15 +419,13 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, if 'lossy_compressed' in tags: tags = change_tags(tags, remove_field='lossy_compressed') - tensor_key = TensorKey( - tensor_name, self.uuid, round_number, report, tags - ) + tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, + tags) tensor_name, origin, round_number, report, tags = tensor_key if 'aggregated' in tags and 'delta' in tags and round_number != 0: - agg_tensor_key = TensorKey( - tensor_name, origin, round_number, report, ('aggregated',) - ) + agg_tensor_key = TensorKey(tensor_name, origin, round_number, + report, ('aggregated', )) else: agg_tensor_key = tensor_key @@ -395,7 +440,9 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, break if nparray is None: - raise ValueError(f'Aggregator does not have an aggregated tensor for {tensor_key}') + raise ValueError( + f'Aggregator does not have an aggregated tensor for {tensor_key}' + ) # quite a bit happens in here, including compression, delta handling, # etc... @@ -404,28 +451,35 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, agg_tensor_key, nparray, send_model_deltas=True, - compress_lossless=compress_lossless - ) + compress_lossless=compress_lossless) return named_tensor def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, compress_lossless): - """ - Construct the NamedTensor Protobuf. + """Construct the NamedTensor Protobuf. + + Also includes logic to create delta, compress tensors with the + TensorCodec, etc. + + Args: + tensor_key (TensorKey): Tensor key. + nparray (np.array): Numpy array. + send_model_deltas (bool): Whether to send model deltas. + compress_lossless (bool): Whether to compress lossless. + + Returns: + tensor_key (TensorKey): Tensor key. + nparray (np.array): Numpy array. - Also includes logic to create delta, compress tensors with the TensorCodec, etc. """ tensor_name, origin, round_number, report, tags = tensor_key # if we have an aggregated tensor, we can make a delta if 'aggregated' in tags and send_model_deltas: # Should get the pretrained model to create the delta. 
If training # has happened, Model should already be stored in the TensorDB - model_tk = TensorKey(tensor_name, - origin, - round_number - 1, - report, - ('model',)) + model_tk = TensorKey(tensor_name, origin, round_number - 1, report, + ('model', )) model_nparray = self.tensor_db.get_tensor_from_cache(model_tk) @@ -433,105 +487,87 @@ def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, 'The original model layer should be present if the latest ' 'aggregated model is present') delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta( - tensor_key, - nparray, - model_nparray - ) + tensor_key, nparray, model_nparray) delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress( - delta_tensor_key, - delta_nparray, - lossless=compress_lossless - ) + delta_tensor_key, delta_nparray, lossless=compress_lossless) named_tensor = utils.construct_named_tensor( delta_comp_tensor_key, delta_comp_nparray, metadata, - lossless=compress_lossless - ) + lossless=compress_lossless) else: # Assume every other tensor requires lossless compression compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress( - tensor_key, - nparray, - require_lossless=True - ) + tensor_key, nparray, require_lossless=True) named_tensor = utils.construct_named_tensor( compressed_tensor_key, compressed_nparray, metadata, - lossless=compress_lossless - ) + lossless=compress_lossless) return named_tensor def _collaborator_task_completed(self, collaborator, task_name, round_num): - """ - Check if the collaborator has completed the task for the round. + """Check if the collaborator has completed the task for the round. - The aggregator doesn't actually know which tensors should be sent from the collaborator \ - so it must to rely specifically on the presence of previous results + The aggregator doesn't actually know which tensors should be sent from + the collaborator so it must to rely specifically on the presence of + previous results. - Args: - collaborator : str - collaborator to check if their task has been completed - task_name : str - The name of the task (TaskRunner function) - round_num : int + Args: + collaborator (str): Collaborator to check if their task has been + completed. + task_name (str): The name of the task (TaskRunner function). + round_num (int): Round number. Returns: - task_competed : bool - Whether or not the collaborator has completed the task for this - round + bool: Whether or not the collaborator has completed the task for + this round. """ task_key = TaskResultKey(task_name, collaborator, round_num) return task_key in self.collaborator_tasks_results - def send_local_task_results(self, collaborator_name, round_number, task_name, - data_size, named_tensors): - """ - RPC called by collaborator. + def send_local_task_results(self, collaborator_name, round_number, + task_name, data_size, named_tensors): + """RPC called by collaborator. Transmits collaborator's task results to the aggregator. Args: - collaborator_name: str - task_name: str - round_number: int - data_size: int - named_tensors: protobuf NamedTensor + collaborator_name (str): Collaborator name. + round_number (int): Round number. + task_name (str): Task name. + data_size (int): Data size. + named_tensors (protobuf NamedTensor): Named tensors. + Returns: - None + None """ if self._time_to_quit() or self._is_task_done(task_name): self.logger.warning( f'STRAGGLER: Collaborator {collaborator_name} is reporting results ' - 'after task {task_name} has finished.' 
- ) + 'after task {task_name} has finished.') return if self.round_number != round_number: self.logger.warning( f'Collaborator {collaborator_name} is reporting results' - f' for the wrong round: {round_number}. Ignoring...' - ) + f' for the wrong round: {round_number}. Ignoring...') return self.logger.info( f'Collaborator {collaborator_name} is sending task results ' - f'for {task_name}, round {round_number}' - ) + f'for {task_name}, round {round_number}') task_key = TaskResultKey(task_name, collaborator_name, round_number) # we mustn't have results already - if self._collaborator_task_completed( - collaborator_name, task_name, round_number - ): + if self._collaborator_task_completed(collaborator_name, task_name, + round_number): raise ValueError( f'Aggregator already has task results from collaborator {collaborator_name}' - f' for task {task_key}' - ) + f' for task {task_key}') # By giving task_key it's own weight, we can support different # training/validation weights @@ -568,120 +604,114 @@ def send_local_task_results(self, collaborator_name, round_number, task_name, self._end_of_task_check(task_name) def _process_named_tensor(self, named_tensor, collaborator_name): - """ - Extract the named tensor fields. + """Extract the named tensor fields. - Performs decompression, delta computation, and inserts results into TensorDB. + Performs decompression, delta computation, and inserts results into + TensorDB. Args: - named_tensor: NamedTensor (protobuf) + named_tensor (protobuf NamedTensor): Named tensor. protobuf that will be extracted from and processed - collaborator_name: str + collaborator_name (str): Collaborator name. Collaborator name is needed for proper tagging of resulting - tensorkeys + tensorkeys. Returns: - tensor_key : TensorKey (named_tuple) - The tensorkey extracted from the protobuf - nparray : np.array - The numpy array associated with the returned tensorkey + tensor_key (TensorKey): Tensor key. + The tensorkey extracted from the protobuf. + nparray (np.array): Numpy array. + The numpy array associated with the returned tensorkey. 
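+
+        Example (illustrative; ``named_tensor`` would be a NamedTensor protobuf
+            received from a collaborator, and 'col1' is a placeholder name):
+
+            >>> tensor_key, nparray = self._process_named_tensor(
+            ...     named_tensor, 'col1')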
""" raw_bytes = named_tensor.data_bytes - metadata = [{'int_to_float': proto.int_to_float, - 'int_list': proto.int_list, - 'bool_list': proto.bool_list} - for proto in named_tensor.transformer_metadata] + metadata = [{ + 'int_to_float': proto.int_to_float, + 'int_list': proto.int_list, + 'bool_list': proto.bool_list + } for proto in named_tensor.transformer_metadata] # The tensor has already been transfered to aggregator, # so the newly constructed tensor should have the aggregator origin - tensor_key = TensorKey( - named_tensor.name, - self.uuid, - named_tensor.round_number, - named_tensor.report, - tuple(named_tensor.tags) - ) + tensor_key = TensorKey(named_tensor.name, self.uuid, + named_tensor.round_number, named_tensor.report, + tuple(named_tensor.tags)) tensor_name, origin, round_number, report, tags = tensor_key - assert ('compressed' in tags or 'lossy_compressed' in tags), ( - f'Named tensor {tensor_key} is not compressed' - ) + assert ('compressed' in tags or 'lossy_compressed' + in tags), (f'Named tensor {tensor_key} is not compressed') if 'compressed' in tags: dec_tk, decompressed_nparray = self.tensor_codec.decompress( tensor_key, data=raw_bytes, transformer_metadata=metadata, - require_lossless=True - ) + require_lossless=True) dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk # Need to add the collaborator tag to the resulting tensor new_tags = change_tags(dec_tags, add_field=collaborator_name) # layer.agg.n.trained.delta.col_i - decompressed_tensor_key = TensorKey( - dec_name, dec_origin, dec_round_num, dec_report, new_tags - ) + decompressed_tensor_key = TensorKey(dec_name, dec_origin, + dec_round_num, dec_report, + new_tags) if 'lossy_compressed' in tags: dec_tk, decompressed_nparray = self.tensor_codec.decompress( tensor_key, data=raw_bytes, transformer_metadata=metadata, - require_lossless=False - ) + require_lossless=False) dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk new_tags = change_tags(dec_tags, add_field=collaborator_name) # layer.agg.n.trained.delta.lossy_decompressed.col_i - decompressed_tensor_key = TensorKey( - dec_name, dec_origin, dec_round_num, dec_report, new_tags - ) + decompressed_tensor_key = TensorKey(dec_name, dec_origin, + dec_round_num, dec_report, + new_tags) if 'delta' in tags: - base_model_tensor_key = TensorKey( - tensor_name, origin, round_number, report, ('model',) - ) + base_model_tensor_key = TensorKey(tensor_name, origin, + round_number, report, + ('model', )) base_model_nparray = self.tensor_db.get_tensor_from_cache( - base_model_tensor_key - ) + base_model_tensor_key) if base_model_nparray is None: - raise ValueError(f'Base model {base_model_tensor_key} not present in TensorDB') + raise ValueError( + f'Base model {base_model_tensor_key} not present in TensorDB' + ) final_tensor_key, final_nparray = self.tensor_codec.apply_delta( - decompressed_tensor_key, - decompressed_nparray, base_model_nparray - ) + decompressed_tensor_key, decompressed_nparray, + base_model_nparray) else: final_tensor_key = decompressed_tensor_key final_nparray = decompressed_nparray - assert (final_nparray is not None), f'Could not create tensorkey {final_tensor_key}' + assert (final_nparray + is not None), f'Could not create tensorkey {final_tensor_key}' self.tensor_db.cache_tensor({final_tensor_key: final_nparray}) self.logger.debug(f'Created TensorKey: {final_tensor_key}') return final_tensor_key, final_nparray def _end_of_task_check(self, task_name): - """ - Check whether all collaborators who are supposed to perform the 
task complete. + """Check whether all collaborators who are supposed to perform the + task complete. Args: - task_name : str - The task name to check + task_name (str): Task name. + The task name to check. Returns: - complete : boolean - Is the task done + bool: Whether the task is done. """ if self._is_task_done(task_name): # now check for the end of the round self._end_of_round_check() - def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results): - """ - Prepare aggregated tensorkey tags. + def _prepare_trained(self, tensor_name, origin, round_number, report, + agg_results): + """Prepare aggregated tensorkey tags. Args: - tensor_name : str - origin: - round_number: int - report: bool - agg_results: np.array + tensor_name (str): Tensor name. + origin: Origin. + round_number (int): Round number. + report (bool): Whether to report. + agg_results (np.array): Aggregated results. """ # The aggregated tensorkey tags should have the form of # 'trained' or 'trained.lossy_decompressed' @@ -691,30 +721,18 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # First insert the aggregated model layer with the # correct tensorkey - agg_tag_tk = TensorKey( - tensor_name, - origin, - round_number + 1, - report, - ('aggregated',) - ) + agg_tag_tk = TensorKey(tensor_name, origin, round_number + 1, report, + ('aggregated', )) self.tensor_db.cache_tensor({agg_tag_tk: agg_results}) # Create delta and save it in TensorDB - base_model_tk = TensorKey( - tensor_name, - origin, - round_number, - report, - ('model',) - ) - base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tk) + base_model_tk = TensorKey(tensor_name, origin, round_number, report, + ('model', )) + base_model_nparray = self.tensor_db.get_tensor_from_cache( + base_model_tk) if base_model_nparray is not None: delta_tk, delta_nparray = self.tensor_codec.generate_delta( - agg_tag_tk, - agg_results, - base_model_nparray - ) + agg_tag_tk, agg_results, base_model_nparray) else: # This condition is possible for base model # optimizer states (i.e. Adam/iter:0, SGD, etc.) @@ -724,8 +742,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # Compress lossless/lossy compressed_delta_tk, compressed_delta_nparray, metadata = self.tensor_codec.compress( - delta_tk, delta_nparray - ) + delta_tk, delta_nparray) # TODO extend the TensorDB so that compressed data is # supported. 
Once that is in place @@ -734,21 +751,18 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # Decompress lossless/lossy decompressed_delta_tk, decompressed_delta_nparray = self.tensor_codec.decompress( - compressed_delta_tk, - compressed_delta_nparray, - metadata - ) + compressed_delta_tk, compressed_delta_nparray, metadata) - self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray}) + self.tensor_db.cache_tensor( + {decompressed_delta_tk: decompressed_delta_nparray}) # Apply delta (unless delta couldn't be created) if base_model_nparray is not None: - self.logger.debug(f'Applying delta for layer {decompressed_delta_tk[0]}') + self.logger.debug( + f'Applying delta for layer {decompressed_delta_tk[0]}') new_model_tk, new_model_nparray = self.tensor_codec.apply_delta( - decompressed_delta_tk, - decompressed_delta_nparray, - base_model_nparray - ) + decompressed_delta_tk, decompressed_delta_nparray, + base_model_nparray) else: new_model_tk, new_model_nparray = decompressed_delta_tk, decompressed_delta_nparray @@ -757,36 +771,30 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # Relabel the tags to 'model' (new_model_tensor_name, new_model_origin, new_model_round_number, new_model_report, new_model_tags) = new_model_tk - final_model_tk = TensorKey( - new_model_tensor_name, - new_model_origin, - new_model_round_number, - new_model_report, - ('model',) - ) + final_model_tk = TensorKey(new_model_tensor_name, new_model_origin, + new_model_round_number, new_model_report, + ('model', )) # Finally, cache the updated model tensor self.tensor_db.cache_tensor({final_model_tk: new_model_nparray}) def _compute_validation_related_task_metrics(self, task_name): - """ - Compute all validation related metrics. + """Compute all validation related metrics. Args: - task_name : str - The task name to compute + task_name (str): Task name. 
""" # By default, print out all of the metrics that the validation # task sent # This handles getting the subset of collaborators that may be # part of the validation task all_collaborators_for_task = self.assigner.get_collaborators_for_task( - task_name, self.round_number - ) + task_name, self.round_number) # leave out stragglers for the round collaborators_for_task = [] for c in all_collaborators_for_task: - if self._collaborator_task_completed(c, task_name, self.round_number): + if self._collaborator_task_completed(c, task_name, + self.round_number): collaborators_for_task.append(c) # The collaborator data sizes for that task @@ -805,8 +813,10 @@ def _compute_validation_related_task_metrics(self, task_name): # collaborator in our subset, and apply the correct # transformations to the tensorkey to resolve the aggregated # tensor for that round - task_agg_function = self.assigner.get_aggregation_type_for_task(task_name) - task_key = TaskResultKey(task_name, collaborators_for_task[0], self.round_number) + task_agg_function = self.assigner.get_aggregation_type_for_task( + task_name) + task_key = TaskResultKey(task_name, collaborators_for_task[0], + self.round_number) for tensor_key in self.collaborator_tasks_results[task_key]: tensor_name, origin, round_number, report, tags = tensor_key @@ -814,11 +824,16 @@ def _compute_validation_related_task_metrics(self, task_name): f'Tensor {tensor_key} in task {task_name} has not been processed correctly' ) # Strip the collaborator label, and lookup aggregated tensor - new_tags = change_tags(tags, remove_field=collaborators_for_task[0]) - agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags) - agg_function = WeightedAverage() if 'metric' in tags else task_agg_function + new_tags = change_tags(tags, + remove_field=collaborators_for_task[0]) + agg_tensor_key = TensorKey(tensor_name, origin, round_number, + report, new_tags) + agg_function = WeightedAverage( + ) if 'metric' in tags else task_agg_function agg_results = self.tensor_db.get_aggregated_tensor( - agg_tensor_key, collaborator_weight_dict, aggregation_function=agg_function) + agg_tensor_key, + collaborator_weight_dict, + aggregation_function=agg_function) if report: # Caution: This schema must be followed. It is also used in @@ -838,20 +853,21 @@ def _compute_validation_related_task_metrics(self, task_name): if 'validate_agg' in tags: # Compare the accuracy of the model, potentially save it if self.best_model_score is None or self.best_model_score < agg_results: - self.logger.metric(f'Round {round_number}: saved the best ' - f'model with score {agg_results:f}') + self.logger.metric( + f'Round {round_number}: saved the best ' + f'model with score {agg_results:f}') self.best_model_score = agg_results self._save_model(round_number, self.best_state_path) if 'trained' in tags: - self._prepare_trained(tensor_name, origin, round_number, report, agg_results) + self._prepare_trained(tensor_name, origin, round_number, + report, agg_results) def _end_of_round_check(self): - """ - Check if the round complete. + """Check if the round complete. If so, perform many end of round operations, such as model aggregation, metric reporting, delta generation (+ - associated tensorkey labeling), and save the model + associated tensorkey labeling), and save the model. 
Args: None @@ -859,7 +875,8 @@ def _end_of_round_check(self): Returns: None """ - if not self._is_round_done() or self._end_of_round_check_done[self.round_number]: + if not self._is_round_done() or self._end_of_round_check_done[ + self.round_number]: return # Compute all validation related metrics @@ -887,16 +904,21 @@ def _end_of_round_check(self): self.tensor_db.clean_up(self.db_store_rounds) def _is_task_done(self, task_name): - """Check that task is done.""" + """Check that task is done. + + Args: + task_name (str): Task name. + + Returns: + bool: Whether the task is done. + """ all_collaborators = self.assigner.get_collaborators_for_task( - task_name, self.round_number - ) + task_name, self.round_number) collaborators_done = [] for c in all_collaborators: - if self._collaborator_task_completed( - c, task_name, self.round_number - ): + if self._collaborator_task_completed(c, task_name, + self.round_number): collaborators_done.append(c) straggler_check = self.straggler_handling_policy.straggler_cutoff_check( @@ -906,19 +928,26 @@ def _is_task_done(self, task_name): for c in all_collaborators: if c not in collaborators_done: self.stragglers.append(c) - self.logger.info(f'\tEnding task {task_name} early due to straggler cutoff policy') + self.logger.info( + f'\tEnding task {task_name} early due to straggler cutoff policy' + ) self.logger.warning(f'\tIdentified stragglers: {self.stragglers}') # all are done or straggler policy calls for early round end. - return straggler_check or len(all_collaborators) == len(collaborators_done) + return straggler_check or len(all_collaborators) == len( + collaborators_done) def _is_round_done(self): - """Check that round is done.""" - tasks_for_round = self.assigner.get_all_tasks_for_round(self.round_number) + """Check that round is done. + + Returns: + bool: Whether the round is done. + """ + tasks_for_round = self.assigner.get_all_tasks_for_round( + self.round_number) return all( - self._is_task_done( - task_name) for task_name in tasks_for_round) + self._is_task_done(task_name) for task_name in tasks_for_round) def _log_big_warning(self): """Warn user about single collaborator cert mode.""" @@ -926,11 +955,17 @@ def _log_big_warning(self): f'\n{the_dragon}\nYOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS' f' NOT PROPER PKI AND ' f'SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN' - f' WARNED!!!' - ) + f' WARNED!!!') def stop(self, failed_collaborator: str = None) -> None: - """Stop aggregator execution.""" + """Stop aggregator execution. + + Args: + failed_collaborator (str, optional): Failed collaborator. Defaults to None. + + Returns: + None + """ self.logger.info('Force stopping the aggregator execution.') # We imitate quit_job_sent_to the failed collaborator # So the experiment set to a finished state @@ -939,8 +974,11 @@ def stop(self, failed_collaborator: str = None) -> None: # This code does not actually send `quit` tasks to collaborators, # it just mimics it by filling arrays. - for collaborator_name in filter(lambda c: c != failed_collaborator, self.authorized_cols): - self.logger.info(f'Sending signal to collaborator {collaborator_name} to shutdown...') + for collaborator_name in filter(lambda c: c != failed_collaborator, + self.authorized_cols): + self.logger.info( + f'Sending signal to collaborator {collaborator_name} to shutdown...' 
+ ) self.quit_job_sent_to.append(collaborator_name) diff --git a/openfl/component/assigner/__init__.py b/openfl/component/assigner/__init__.py index 5c3dbdc8c8..e6ea55b325 100644 --- a/openfl/component/assigner/__init__.py +++ b/openfl/component/assigner/__init__.py @@ -1,15 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""Assigner package.""" - -from .assigner import Assigner -from .random_grouped_assigner import RandomGroupedAssigner -from .static_grouped_assigner import StaticGroupedAssigner - - -__all__ = [ - 'Assigner', - 'RandomGroupedAssigner', - 'StaticGroupedAssigner', -] +from openfl.component.assigner.assigner import Assigner +from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner +from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner \ No newline at end of file diff --git a/openfl/component/assigner/assigner.py b/openfl/component/assigner/assigner.py index 0c2a352b95..02235346a9 100644 --- a/openfl/component/assigner/assigner.py +++ b/openfl/component/assigner/assigner.py @@ -4,35 +4,43 @@ class Assigner: - r""" - The task assigner maintains a list of tasks. + r"""The task assigner maintains a list of tasks. Also it decides the policy for which collaborator should run those tasks. - There may be many types of policies implemented, but a natural place to start is with a: - - RandomGroupedTaskAssigner - Given a set of task groups, and a percentage, - assign that task group to that - percentage of collaborators in the federation. - After assigning the tasks to - collaborator, those tasks should be carried - out each round (no reassignment - between rounds) - GroupedTaskAssigner - Given task groups and a list of collaborators that - belong to that task group, - carry out tasks for each round of experiment - - Args: - tasks* (list of object): list of tasks to assign. - authorized_cols (list of str): collaborators. - rounds_to_train (int): number of training rounds. - - Note: + There may be many types of policies implemented, but a natural place to start + is with a: + + - RandomGroupedTaskAssigner : + Given a set of task groups, and a percentage, + assign that task group to that percentage of collaborators in the federation. + After assigning the tasks to collaborator, those tasks should be carried + out each round (no reassignment between rounds). + - GroupedTaskAssigner : + Given task groups and a list of collaborators that belong to that task group, + carry out tasks for each round of experiment. + + Attributes: + tasks* (list of object): List of tasks to assign. + authorized_cols (list of str): Collaborators. + rounds (int): Number of rounds to train. + all_tasks_in_groups (list): All tasks in groups. + task_group_collaborators (dict): Task group collaborators. + collaborators_for_task (dict): Collaborators for each task. + collaborator_tasks (dict): Tasks for each collaborator. + + .. note:: \* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file. """ - def __init__(self, tasks, authorized_cols, - rounds_to_train, **kwargs): - """Initialize.""" + def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs): + """Initializes the Assigner. + + Args: + tasks (list of object): List of tasks to assign. + authorized_cols (list of str): Collaborators. + rounds_to_train (int): Number of training rounds. + **kwargs: Additional keyword arguments. 
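+
+        Example (illustrative; the task entries are placeholders, and the only
+            per-task setting read by this base class is ``aggregation_type``)::
+
+            tasks = {
+                'aggregated_model_validation': {},
+                'train': {'aggregation_type': WeightedAverage()},
+            }
+            # A concrete subclass such as RandomGroupedAssigner is then built
+            # with these tasks, the authorized collaborators and the number
+            # of rounds to train.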
+ """ self.tasks = tasks self.authorized_cols = authorized_cols self.rounds = rounds_to_train @@ -57,16 +65,28 @@ def get_collaborators_for_task(self, task_name, round_number): raise NotImplementedError def get_all_tasks_for_round(self, round_number): - """ - Return tasks for the current round. + """Return tasks for the current round. Currently all tasks are performed on each round, But there may be a reason to change this. + + Args: + round_number (int): Round number. + + Returns: + list: List of tasks for the current round. """ return self.all_tasks_in_groups def get_aggregation_type_for_task(self, task_name): - """Extract aggregation type from self.tasks.""" + """Extract aggregation type from self.tasks. + + Args: + task_name (str): Name of the task. + + Returns: + str: Aggregation type for the task. + """ if 'aggregation_type' not in self.tasks[task_name]: return None return self.tasks[task_name]['aggregation_type'] diff --git a/openfl/component/assigner/custom_assigner.py b/openfl/component/assigner/custom_assigner.py index 134a16c257..cfa7f7e5ef 100644 --- a/openfl/component/assigner/custom_assigner.py +++ b/openfl/component/assigner/custom_assigner.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 """Custom Assigner module.""" - import logging from collections import defaultdict @@ -12,17 +11,37 @@ class Assigner: - """Custom assigner class.""" - - def __init__( - self, - *, - assigner_function, - aggregation_functions_by_task, - authorized_cols, - rounds_to_train - ): - """Initialize.""" + """Custom assigner class. + + Attributes: + agg_functions_by_task (dict): Dictionary mapping tasks to their + respective aggregation functions. + agg_functions_by_task_name (dict): Dictionary mapping task names to + their respective aggregation functions. + authorized_cols (list of str): List of authorized collaborators. + rounds_to_train (int): Number of rounds to train. + all_tasks_for_round (defaultdict): Dictionary mapping round numbers + to tasks. + collaborators_for_task (defaultdict): Dictionary mapping round numbers + to collaborators for each task. + collaborator_tasks (defaultdict): Dictionary mapping round numbers + to tasks for each collaborator. + assigner_function (function): Function to assign tasks to + collaborators. + """ + + def __init__(self, *, assigner_function, aggregation_functions_by_task, + authorized_cols, rounds_to_train): + """Initialize the Custom assigner object. + + Args: + assigner_function (function): Function to assign tasks to + collaborators. + aggregation_functions_by_task (dict): Dictionary mapping tasks to + their respective aggregation functions. + authorized_cols (list of str): List of authorized collaborators. + rounds_to_train (int): Number of rounds to train. + """ self.agg_functions_by_task = aggregation_functions_by_task self.agg_functions_by_task_name = {} self.authorized_cols = authorized_cols @@ -35,7 +54,20 @@ def __init__( self.define_task_assignments() def define_task_assignments(self): - """Abstract method.""" + """Define task assignments for each round and collaborator. + + This method uses the assigner function to assign tasks to + collaborators for each round. It also maps tasks to their respective + aggregation functions. + + Abstract method. 
+ + Args: + None + + Returns: + None + """ for round_number in range(self.rounds_to_train): tasks_by_collaborator = self.assigner_function( self.authorized_cols, @@ -53,23 +85,59 @@ def define_task_assignments(self): ] = self.agg_functions_by_task.get(task.function_name, WeightedAverage()) def get_tasks_for_collaborator(self, collaborator_name, round_number): - """Abstract method.""" + """Get tasks for a specific collaborator in a specific round. + + Abstract method. + + Args: + collaborator_name (str): Name of the collaborator. + round_number (int): Round number. + + Returns: + list: List of tasks for the collaborator in the specified round. + """ return self.collaborator_tasks[round_number][collaborator_name] def get_collaborators_for_task(self, task_name, round_number): - """Abstract method.""" + """Get collaborators for a specific task in a specific round. + + Abstract method. + + Args: + task_name (str): Name of the task. + round_number (int): Round number. + + Returns: + list: List of collaborators for the task in the specified round. + """ return self.collaborators_for_task[round_number][task_name] def get_all_tasks_for_round(self, round_number): - """ - Return tasks for the current round. + """Get all tasks for a specific round. Currently all tasks are performed on each round, But there may be a reason to change this. + + Args: + round_number (int): Round number. + + Returns: + list: List of all tasks for the specified round. """ - return [task.name for task in self.all_tasks_for_round[round_number].values()] + return [ + task.name + for task in self.all_tasks_for_round[round_number].values() + ] def get_aggregation_type_for_task(self, task_name): - """Extract aggregation type from self.tasks.""" - agg_fn = self.agg_functions_by_task_name.get(task_name, WeightedAverage()) + """Get the aggregation type for a specific task (from self.tasks). + + Args: + task_name (str): Name of the task. + + Returns: + function: Aggregation function for the task. + """ + agg_fn = self.agg_functions_by_task_name.get(task_name, + WeightedAverage()) return agg_fn diff --git a/openfl/component/assigner/random_grouped_assigner.py b/openfl/component/assigner/random_grouped_assigner.py index 9fb8f62efc..7ecf07cb5e 100644 --- a/openfl/component/assigner/random_grouped_assigner.py +++ b/openfl/component/assigner/random_grouped_assigner.py @@ -1,57 +1,68 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Random grouped assigner module.""" - import numpy as np -from .assigner import Assigner +from openfl.component.assigner import Assigner class RandomGroupedAssigner(Assigner): - r""" - The task assigner maintains a list of tasks. + r"""The task assigner maintains a list of tasks. - Also it decides the policy for - which collaborator should run those tasks + Also it decides the policy for which collaborator should run those tasks There may be many types of policies implemented, but a natural place to start is with a: - RandomGroupedAssigner - Given a set of task groups, and a percentage, - assign that task group to that percentage - of collaborators in the federation. 
After - assigning the tasks to collaborator, those - tasks should be carried out each round (no - reassignment between rounds) - GroupedAssigner - Given task groups and a list of collaborators - that belong to that task group, - carry out tasks for each round of experiment + - RandomGroupedAssigner : + Given a set of task groups, and a percentage, + assign that task group to that percentage of collaborators in the + federation. + After assigning the tasks to collaborator, those tasks should be + carried out each round (no reassignment between rounds). + - GroupedAssigner : + Given task groups and a list of collaborators that belong to that + task group, carry out tasks for each round of experiment. - Args: - task_groups* (list of object): task groups to assign. + Attributes: + task_groups* (list of object): Task groups to assign. - Note: + .. note:: \* - Plan setting. """ def __init__(self, task_groups, **kwargs): - """Initialize.""" + """Initializes the RandomGroupedAssigner. + + Args: + task_groups (list of object): Task groups to assign. + **kwargs: Additional keyword arguments. + """ self.task_groups = task_groups super().__init__(**kwargs) def define_task_assignments(self): - """All of the logic to set up the map of tasks to collaborators is done here.""" + """Define task assignments for each round and collaborator. + + This method uses the assigner function to assign tasks to + collaborators for each round. + It also maps tasks to their respective aggregation functions. + + Args: + None + + Returns: + None + """ assert (np.abs(1.0 - np.sum([group['percentage'] for group in self.task_groups])) < 0.01), ( 'Task group percentages must sum to 100%') # Start by finding all of the tasks in all specified groups - self.all_tasks_in_groups = list({ - task - for group in self.task_groups - for task in group['tasks'] - }) + self.all_tasks_in_groups = list( + {task + for group in self.task_groups + for task in group['tasks']}) # Initialize the map of collaborators for a given task on a given round for task in self.all_tasks_in_groups: @@ -64,11 +75,9 @@ def define_task_assignments(self): col_list_size = len(self.authorized_cols) for round_num in range(self.rounds): - randomized_col_idx = np.random.choice( - len(self.authorized_cols), - len(self.authorized_cols), - replace=False - ) + randomized_col_idx = np.random.choice(len(self.authorized_cols), + len(self.authorized_cols), + replace=False) col_idx = 0 for group in self.task_groups: num_col_in_group = int(group['percentage'] * col_list_size) @@ -76,21 +85,40 @@ def define_task_assignments(self): self.authorized_cols[i] for i in randomized_col_idx[col_idx:col_idx + num_col_in_group] ] - self.task_group_collaborators[group['name']] = rand_col_group_list + self.task_group_collaborators[ + group['name']] = rand_col_group_list for col in rand_col_group_list: self.collaborator_tasks[col][round_num] = group['tasks'] # Now populate reverse lookup of tasks->group for task in group['tasks']: # This should append the list of collaborators performing # that task - self.collaborators_for_task[task][round_num] += rand_col_group_list + self.collaborators_for_task[task][ + round_num] += rand_col_group_list col_idx += num_col_in_group - assert (col_idx == col_list_size), 'Task groups were not divided properly' + assert (col_idx == col_list_size + ), 'Task groups were not divided properly' def get_tasks_for_collaborator(self, collaborator_name, round_number): - """Get tasks for the collaborator specified.""" + """Get tasks for a specific collaborator in a 
specific round. + + Args: + collaborator_name (str): Name of the collaborator. + round_number (int): Round number. + + Returns: + list: List of tasks for the collaborator in the specified round. + """ return self.collaborator_tasks[collaborator_name][round_number] def get_collaborators_for_task(self, task_name, round_number): - """Get collaborators for the task specified.""" + """Get collaborators for a specific task in a specific round. + + Args: + task_name (str): Name of the task. + round_number (int): Round number. + + Returns: + list: List of collaborators for the task in the specified round. + """ return self.collaborators_for_task[task_name][round_number] diff --git a/openfl/component/assigner/static_grouped_assigner.py b/openfl/component/assigner/static_grouped_assigner.py index 835e8a7541..e051baa587 100644 --- a/openfl/component/assigner/static_grouped_assigner.py +++ b/openfl/component/assigner/static_grouped_assigner.py @@ -1,47 +1,60 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Static grouped assigner module.""" -from .assigner import Assigner +from openfl.component.assigner import Assigner class StaticGroupedAssigner(Assigner): - r""" - The task assigner maintains a list of tasks. + r"""The task assigner maintains a list of tasks. - Also it decides the policy for - which collaborator should run those tasks + Also it decides the policy for which collaborator should run those tasks There may be many types of policies implemented, but a natural place to start is with a: - StaticGroupedAssigner - Given a set of task groups, and a list of - collaborators for that group, assign tasks for - of collaborators in the federation. After assigning - the tasks to collaborator, those tasks - should be carried out each round (no reassignment - between rounds) - GroupedAssigner - Given task groups and a list of collaborators that - belong to that task group, carry out tasks for - each round of experiment - - Args: - task_groups* (list of obj): task groups to assign. - - Note: + - StaticGroupedAssigner : + Given a set of task groups, and a list of + collaborators for that group, assign tasks for of collaborators in + the federation. + After assigning the tasks to collaborator, those tasks should be + carried out each round (no reassignment between rounds). + - GroupedAssigner : + Given task groups and a list of collaborators that + belong to that task group, carry out tasks for each round of + experiment. + + Attributes: + task_groups* (list of object): Task groups to assign. + + .. note:: \* - Plan setting. """ def __init__(self, task_groups, **kwargs): - """Initialize.""" + """Initializes the StaticGroupedAssigner. + + Args: + task_groups (list of object): Task groups to assign. + **kwargs: Additional keyword arguments. + """ self.task_groups = task_groups super().__init__(**kwargs) def define_task_assignments(self): - """All of the logic to set up the map of tasks to collaborators is done here.""" - cols_amount = sum([ - len(group['collaborators']) for group in self.task_groups - ]) + """Define task assignments for each round and collaborator. + + This method uses the assigner function to assign tasks to + collaborators for each round. + It also maps tasks to their respective aggregation functions. 
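The task_groups plan setting consumed by both grouped assigners is easier to picture with sample values. The shapes below are inferred from the keys the assigners read ('name', 'percentage' or 'collaborators', and 'tasks'); the task and collaborator names are illustrative placeholders, not values mandated by this patch.

# Percentage-based groups for RandomGroupedAssigner; percentages must sum to 1.0.
random_task_groups = [
    {'name': 'train_and_validate',
     'percentage': 0.8,
     'tasks': ['aggregated_model_validation', 'train',
               'locally_tuned_model_validation']},
    {'name': 'validate_only',
     'percentage': 0.2,
     'tasks': ['aggregated_model_validation']},
]

# Explicit collaborator lists for StaticGroupedAssigner; together the groups
# must cover exactly the federation's authorized collaborators.
static_task_groups = [
    {'name': 'train_and_validate',
     'collaborators': ['col_one', 'col_two'],
     'tasks': ['aggregated_model_validation', 'train',
               'locally_tuned_model_validation']},
    {'name': 'validate_only',
     'collaborators': ['col_three'],
     'tasks': ['aggregated_model_validation']},
]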
+ + Args: + None + + Returns: + None + """ + cols_amount = sum( + [len(group['collaborators']) for group in self.task_groups]) authorized_cols_amount = len(self.authorized_cols) unique_cols = { @@ -51,22 +64,22 @@ def define_task_assignments(self): } unique_authorized_cols = set(self.authorized_cols) - assert (cols_amount == authorized_cols_amount and unique_cols == unique_authorized_cols), ( - f'Collaborators in each group must be distinct: ' - f'{unique_cols}, {unique_authorized_cols}' - ) + assert (cols_amount == authorized_cols_amount + and unique_cols == unique_authorized_cols), ( + f'Collaborators in each group must be distinct: ' + f'{unique_cols}, {unique_authorized_cols}') # Start by finding all of the tasks in all specified groups - self.all_tasks_in_groups = list({ - task - for group in self.task_groups - for task in group['tasks'] - }) + self.all_tasks_in_groups = list( + {task + for group in self.task_groups + for task in group['tasks']}) # Initialize the map of collaborators for a given task on a given round for task in self.all_tasks_in_groups: self.collaborators_for_task[task] = { - i: [] for i in range(self.rounds) + i: [] + for i in range(self.rounds) } for group in self.task_groups: @@ -86,9 +99,25 @@ def define_task_assignments(self): self.collaborators_for_task[task][round_] += group_col_list def get_tasks_for_collaborator(self, collaborator_name, round_number): - """Get tasks for the collaborator specified.""" + """Get tasks for a specific collaborator in a specific round. + + Args: + collaborator_name (str): Name of the collaborator. + round_number (int): Round number. + + Returns: + list: List of tasks for the collaborator in the specified round. + """ return self.collaborator_tasks[collaborator_name][round_number] def get_collaborators_for_task(self, task_name, round_number): - """Get collaborators for the task specified.""" + """Get collaborators for a specific task in a specific round. + + Args: + task_name (str): Name of the task. + round_number (int): Round number. + + Returns: + list: List of collaborators for the task in the specified round. + """ return self.collaborators_for_task[task_name][round_number] diff --git a/openfl/component/assigner/tasks.py b/openfl/component/assigner/tasks.py index 1ca7f07323..b249778cbd 100644 --- a/openfl/component/assigner/tasks.py +++ b/openfl/component/assigner/tasks.py @@ -2,31 +2,61 @@ # SPDX-License-Identifier: Apache-2.0 """Task module.""" - -from dataclasses import dataclass -from dataclasses import field +from dataclasses import dataclass, field @dataclass class Task: - """Task base dataclass.""" - + """Task base dataclass. + + Args: + name (str): Name of the task. + function_name (str): Name of the function to be executed for the task. + task_type (str): Type of the task. + apply_local (bool, optional): Whether to apply the task locally. + Defaults to False. + parameters (dict, optional): Parameters for the task. Defaults to an + empty dictionary. + """ name: str function_name: str task_type: str apply_local: bool = False - parameters: dict = field(default_factory=dict) # We can expend it in the future + parameters: dict = field( + default_factory=dict) # We can expend it in the future @dataclass class TrainTask(Task): - """TrainTask class.""" - + """TrainTask class. + + Args: + name (str): Name of the task. + function_name (str): Name of the function to be executed for the task. + apply_local (bool, optional): Whether to apply the task locally. + Defaults to False. 
+ parameters (dict, optional): Parameters for the task. Defaults to an + empty dictionary. + + Attributes: + task_type (str): Type of the task. Set to 'train'. + """ task_type: str = 'train' @dataclass class ValidateTask(Task): - """Validation Task class.""" - + """Validation Task class. + + Args: + name (str): Name of the task. + function_name (str): Name of the function to be executed for the task. + apply_local (bool, optional): Whether to apply the task locally. + Defaults to False. + parameters (dict, optional): Parameters for the task. Defaults to an + empty dictionary. + + Attributes: + task_type (str): Type of the task. Set to 'validate'. + """ task_type: str = 'validate' diff --git a/openfl/component/collaborator/__init__.py b/openfl/component/collaborator/__init__.py index 3e0bbe1de6..57fc9f3620 100644 --- a/openfl/component/collaborator/__init__.py +++ b/openfl/component/collaborator/__init__.py @@ -1,10 +1,3 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""Collaborator package.""" - -from .collaborator import Collaborator - -__all__ = [ - 'Collaborator', -] +from openfl.component.collaborator.collaborator import Collaborator \ No newline at end of file diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 9fb3d00660..7194770cef 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Collaborator module.""" from enum import Enum @@ -9,14 +8,18 @@ from typing import Tuple from openfl.databases import TensorDB -from openfl.pipelines import NoCompressionPipeline -from openfl.pipelines import TensorCodec +from openfl.pipelines import NoCompressionPipeline, TensorCodec from openfl.protocols import utils from openfl.utilities import TensorKey class DevicePolicy(Enum): - """Device assignment policy.""" + """Device assignment policy. + + Attributes: + CPU_ONLY (int): Assigns tasks to CPU only. + CUDA_PREFERRED (int): Prefers CUDA for task assignment if available. + """ CPU_ONLY = 1 @@ -26,14 +29,12 @@ class DevicePolicy(Enum): class OptTreatment(Enum): """Optimizer Methods. - - RESET tells each collaborator to reset the optimizer state at the beginning - of each round. - - - CONTINUE_LOCAL tells each collaborator to continue with the local optimizer - state from the previous round. - - - CONTINUE_GLOBAL tells each collaborator to continue with the federally - averaged optimizer state from the previous round. + Attributes: + RESET (int): Resets the optimizer state at the beginning of each round. + CONTINUE_LOCAL (int): Continues with the local optimizer state from + the previous round. + CONTINUE_GLOBAL (int): Continues with the federally averaged optimizer + state from the previous round. """ RESET = 1 @@ -44,26 +45,23 @@ class OptTreatment(Enum): class Collaborator: r"""The Collaborator object class. - Args: - collaborator_name (string): The common name for the collaborator - aggregator_uuid: The unique id for the client - federation_uuid: The unique id for the federation - model: The model - opt_treatment* (string): The optimizer state treatment (Defaults to - "CONTINUE_GLOBAL", which is aggreagated state from previous round.) 
- - compression_pipeline: The compression pipeline (Defaults to None) - - num_batches_per_round (int): Number of batches per round - (Defaults to None) - - delta_updates* (bool): True = Only model delta gets sent. - False = Whole model gets sent to collaborator. - Defaults to False. - - single_col_cert_common_name: (Defaults to None) - - Note: + Attributes: + collaborator_name (str): The common name for the collaborator. + aggregator_uuid (str): The unique id for the client. + federation_uuid (str): The unique id for the federation. + client (object): The client object. + task_runner (object): The task runner object. + task_config (dict): The task configuration. + opt_treatment (str)*: The optimizer state treatment. + device_assignment_policy (str): The device assignment policy. + delta_updates (bool)*: If True, only model delta gets sent. If False, + whole model gets sent to collaborator. + compression_pipeline (object): The compression pipeline. + db_store_rounds (int): The number of rounds to store in the database. + single_col_cert_common_name (str): The common name for the single + collaborator certificate. + + .. note:: \* - Plan setting. """ @@ -80,7 +78,28 @@ def __init__(self, compression_pipeline=None, db_store_rounds=1, **kwargs): - """Initialize.""" + """Initialize the Collaborator object. + + Args: + collaborator_name (str): The common name for the collaborator. + aggregator_uuid (str): The unique id for the client. + federation_uuid (str): The unique id for the federation. + client (object): The client object. + task_runner (object): The task runner object. + task_config (dict): The task configuration. + opt_treatment (str, optional): The optimizer state treatment. + Defaults to 'RESET'. + device_assignment_policy (str, optional): The device assignment + policy. Defaults to 'CPU_ONLY'. + delta_updates (bool, optional): If True, only model delta gets + sent. If False, whole model gets sent to collaborator. + Defaults to False. + compression_pipeline (object, optional): The compression pipeline. + Defaults to None. + db_store_rounds (int, optional): The number of rounds to store in + the database. Defaults to 1. + **kwargs: Variable length argument list. + """ self.single_col_cert_common_name = None if self.single_col_cert_common_name is None: @@ -124,10 +143,11 @@ def __init__(self, self.task_runner.set_optimizer_treatment(self.opt_treatment.name) def set_available_devices(self, cuda: Tuple[str] = ()): - """ - Set available CUDA devices. + """Set available CUDA devices. - Cuda tuple contains string indeces, ('1', '3'). + Args: + cuda (Tuple[str]): Tuple containing string indices of available + CUDA devices, e.g. ('1', '3'). """ self.cuda_devices = cuda @@ -150,12 +170,10 @@ def run(self): self.logger.info('End of Federation reached. Exiting...') def run_simulation(self): - """ - Specific function for the simulation. + """Specific function for the simulation. - After the tasks have - been performed for a roundquit, and then the collaborator object will - be reinitialized after the next round + After the tasks have been performed for a round, the collaborator + quits, and the collaborator object will be reinitialized for the next round.
""" while True: tasks, round_number, sleep_time, time_to_quit = self.get_tasks() @@ -168,12 +186,20 @@ def run_simulation(self): self.logger.info(f'Received the following tasks: {tasks}') for task in tasks: self.do_task(task, round_number) - self.logger.info(f'All tasks completed on {self.collaborator_name} ' - f'for round {round_number}...') + self.logger.info( + f'All tasks completed on {self.collaborator_name} ' + f'for round {round_number}...') break def get_tasks(self): - """Get tasks from the aggregator.""" + """Get tasks from the aggregator. + + Returns: + tasks (list_of_str): List of tasks. + round_number (int): Actual round number. + sleep_time (int): Sleep time. + time_to_quit (bool): bool value for quit. + """ # logging wait time to analyze training process self.logger.info('Waiting for tasks...') tasks, round_number, sleep_time, time_to_quit = self.client.get_tasks( @@ -182,7 +208,12 @@ def get_tasks(self): return tasks, round_number, sleep_time, time_to_quit def do_task(self, task, round_number): - """Do the specified task.""" + """Perform the specified task. + + Args: + task (list_of_str): List of tasks. + round_number (int): Actual round number. + """ # map this task to an actual function name and kwargs if hasattr(self.task_runner, 'TASK_REGISTRY'): func_name = task.function_name @@ -203,9 +234,7 @@ def do_task(self, task, round_number): # this would return a list of what tensors we require as TensorKeys required_tensorkeys_relative = self.task_runner.get_required_tensorkeys_for_function( - func_name, - **kwargs - ) + func_name, **kwargs) # models actually return "relative" tensorkeys of (name, LOCAL|GLOBAL, # round_offset) @@ -220,14 +249,12 @@ def do_task(self, task, round_number): # rnd_num is the relative round. So if rnd_num is -1, get the # tensor from the previous round required_tensorkeys.append( - TensorKey(tname, origin, rnd_num + round_number, report, tags) - ) + TensorKey(tname, origin, rnd_num + round_number, report, tags)) # print('Required tensorkeys = {}'.format( # [tk[0] for tk in required_tensorkeys])) input_tensor_dict = self.get_numpy_dict_for_tensorkeys( - required_tensorkeys - ) + required_tensorkeys) # now we have whatever the model needs to do the task if hasattr(self.task_runner, 'TASK_REGISTRY'): @@ -240,7 +267,8 @@ def do_task(self, task, round_number): # those are parameters that the eperiment owner registered for # the task. # There is another set of parameters that created on the - # collaborator side, for instance, local processing unit identifier:s + # collaborator side, for instance, local processing unit + # identifiers: if (self.device_assignment_policy is DevicePolicy.CUDA_PREFERRED and len(self.cuda_devices) > 0): kwargs['device'] = f'cuda:{self.cuda_devices[0]}' @@ -264,24 +292,36 @@ def do_task(self, task, round_number): # send the results for this tasks; delta and compression will occur in # this function - self.send_task_results(global_output_tensor_dict, round_number, task_name) + self.send_task_results(global_output_tensor_dict, round_number, + task_name) def get_numpy_dict_for_tensorkeys(self, tensor_keys): - """Get tensor dictionary for specified tensorkey set.""" - return {k.tensor_name: self.get_data_for_tensorkey(k) for k in tensor_keys} + """Get tensor dictionary for specified tensorkey set. - def get_data_for_tensorkey(self, tensor_key): + Args: + tensor_keys (namedtuple): Tensorkeys that will be resolved locally + or remotely. May be the product of other tensors. 
""" - Resolve the tensor corresponding to the requested tensorkey. + return { + k.tensor_name: self.get_data_for_tensorkey(k) + for k in tensor_keys + } + + def get_data_for_tensorkey(self, tensor_key): + """Resolve the tensor corresponding to the requested tensorkey. + + Args: + tensor_key (namedtuple): Tensorkey that will be resolved locally or + remotely. May be the product of other tensors. - Args - ---- - tensor_key: Tensorkey that will be resolved locally or - remotely. May be the product of other tensors + Returns: + nparray: The decompressed tensor associated with the requested + tensor key. """ # try to get from the store tensor_name, origin, round_number, report, tags = tensor_key - self.logger.debug(f'Attempting to retrieve tensor {tensor_key} from local store') + self.logger.debug( + f'Attempting to retrieve tensor {tensor_key} from local store') nparray = self.tensor_db.get_tensor_from_cache(tensor_key) # if None and origin is our client, request it from the client @@ -293,10 +333,12 @@ def get_data_for_tensorkey(self, tensor_key): prior_round = round_number - 1 while prior_round >= 0: nparray = self.tensor_db.get_tensor_from_cache( - TensorKey(tensor_name, origin, prior_round, report, tags)) + TensorKey(tensor_name, origin, prior_round, report, + tags)) if nparray is not None: - self.logger.debug(f'Found tensor {tensor_name} in local TensorDB ' - f'for round {prior_round}') + self.logger.debug( + f'Found tensor {tensor_name} in local TensorDB ' + f'for round {prior_round}') return nparray prior_round -= 1 self.logger.info( @@ -308,8 +350,7 @@ def get_data_for_tensorkey(self, tensor_key): # dependencies. # Typically, dependencies are only relevant to model layers tensor_dependencies = self.tensor_codec.find_dependencies( - tensor_key, self.delta_updates - ) + tensor_key, self.delta_updates) if len(tensor_dependencies) > 0: # Resolve dependencies # tensor_dependencies[0] corresponds to the prior version @@ -317,12 +358,10 @@ def get_data_for_tensorkey(self, tensor_key): # If it exists locally, should pull the remote delta because # this is the least costly path prior_model_layer = self.tensor_db.get_tensor_from_cache( - tensor_dependencies[0] - ) + tensor_dependencies[0]) if prior_model_layer is not None: uncompressed_delta = self.get_aggregated_tensor_from_aggregator( - tensor_dependencies[1] - ) + tensor_dependencies[1]) new_model_tk, nparray = self.tensor_codec.apply_delta( tensor_dependencies[1], uncompressed_delta, @@ -335,50 +374,45 @@ def get_data_for_tensorkey(self, tensor_key): 'Fetching latest layer from aggregator') # The original model tensor should be fetched from client nparray = self.get_aggregated_tensor_from_aggregator( - tensor_key, - require_lossless=True - ) + tensor_key, require_lossless=True) elif 'model' in tags: # Pulling the model for the first time nparray = self.get_aggregated_tensor_from_aggregator( - tensor_key, - require_lossless=True - ) + tensor_key, require_lossless=True) else: self.logger.debug(f'Found tensor {tensor_key} in local TensorDB') return nparray - def get_aggregated_tensor_from_aggregator(self, tensor_key, + def get_aggregated_tensor_from_aggregator(self, + tensor_key, require_lossless=False): - """ - Return the decompressed tensor associated with the requested tensor key. + """Return the decompressed tensor associated with the requested tensor + key. If the key requests a compressed tensor (in the tag), the tensor will - be decompressed before returning + be decompressed before returning. 
If the key specifies an uncompressed tensor (or just omits a compressed - tag), the decompression operation will be skipped - - Args - ---- - tensor_key : The requested tensor - require_lossless: Should compression of the tensor be allowed - in flight? - For the initial model, it may affect - convergence to apply lossy - compression. And metrics shouldn't be - compressed either - - Returns - ------- - nparray : The decompressed tensor associated with the requested - tensor key + tag), the decompression operation will be skipped. + + Args: + tensor_key (namedtuple): The requested tensor. + require_lossless (bool): Should compression of the tensor be + allowed in flight? For the initial model, it may affect + convergence to apply lossy compression. And metrics shouldn't + be compressed either. + + Returns: + nparray : The decompressed tensor associated with the requested + tensor key. """ tensor_name, origin, round_number, report, tags = tensor_key self.logger.debug(f'Requesting aggregated tensor {tensor_key}') - tensor = self.client.get_aggregated_tensor( - self.collaborator_name, tensor_name, round_number, report, tags, require_lossless) + tensor = self.client.get_aggregated_tensor(self.collaborator_name, + tensor_name, round_number, + report, tags, + require_lossless) # this translates to a numpy array and includes decompression, as # necessary @@ -390,7 +424,13 @@ def get_aggregated_tensor_from_aggregator(self, tensor_key, return nparray def send_task_results(self, tensor_dict, round_number, task_name): - """Send task results to the aggregator.""" + """Send task results to the aggregator. + + Args: + tensor_dict (dict): Tensor dictionary. + round_number (int): Actual round number. + task_name (string): Task name. + """ named_tensors = [ self.nparray_to_named_tensor(k, v) for k, v in tensor_dict.items() ] @@ -417,14 +457,24 @@ def send_task_results(self, tensor_dict, round_number, task_name): f'is sending metric for task {task_name}:' f' {tensor_name}\t{tensor_dict[tensor]:f}') - self.client.send_local_task_results( - self.collaborator_name, round_number, task_name, data_size, named_tensors) + self.client.send_local_task_results(self.collaborator_name, + round_number, task_name, data_size, + named_tensors) def nparray_to_named_tensor(self, tensor_key, nparray): - """ - Construct the NamedTensor Protobuf. + """Construct the NamedTensor Protobuf. - Includes logic to create delta, compress tensors with the TensorCodec, etc. + Includes logic to create delta, compress tensors with the TensorCodec, + etc. + + Args: + tensor_key (namedtuple): Tensorkey that will be resolved locally or + remotely. May be the product of other tensors. + nparray: The decompressed tensor associated with the requested + tensor key. + + Returns: + named_tensor (protobuf) : The tensor constructed from the nparray. """ # if we have an aggregated tensor, we can make a delta tensor_name, origin, round_number, report, tags = tensor_key @@ -433,83 +483,66 @@ def nparray_to_named_tensor(self, tensor_key, nparray): # has happened, # Model should already be stored in the TensorDB model_nparray = self.tensor_db.get_tensor_from_cache( - TensorKey( - tensor_name, - origin, - round_number, - report, - ('model',) - ) - ) + TensorKey(tensor_name, origin, round_number, report, + ('model', ))) # The original model will not be present for the optimizer on the # first round. 
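Because the five-field TensorKey named tuple is unpacked throughout this hunk without ever being shown whole, a small sketch may help; the concrete values below are placeholders, and only the field order (tensor_name, origin, round_number, report, tags) is taken from the surrounding code.

from openfl.utilities import TensorKey

# Field order matches the unpacking used above.
key = TensorKey('conv1.weight', 'aggregator_uuid_0', 0, False, ('model',))
tensor_name, origin, round_number, report, tags = key
# Related tensors differ mainly in their tags; the surrounding code checks
# for tags such as 'model', 'compressed', and 'lossy_compressed'.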
if model_nparray is not None: delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta( - tensor_key, - nparray, - model_nparray - ) + tensor_key, nparray, model_nparray) delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress( - delta_tensor_key, - delta_nparray - ) + delta_tensor_key, delta_nparray) named_tensor = utils.construct_named_tensor( delta_comp_tensor_key, delta_comp_nparray, metadata, - lossless=False - ) + lossless=False) return named_tensor # Assume every other tensor requires lossless compression compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress( - tensor_key, - nparray, - require_lossless=True - ) - named_tensor = utils.construct_named_tensor( - compressed_tensor_key, - compressed_nparray, - metadata, - lossless=True - ) + tensor_key, nparray, require_lossless=True) + named_tensor = utils.construct_named_tensor(compressed_tensor_key, + compressed_nparray, + metadata, + lossless=True) return named_tensor def named_tensor_to_nparray(self, named_tensor): - """Convert named tensor to a numpy array.""" + """Convert named tensor to a numpy array. + + Args: + named_tensor (protobuf): The tensor to convert to nparray. + + Returns: + decompressed_nparray (nparray): The nparray converted. + """ # do the stuff we do now for decompression and frombuffer and stuff # This should probably be moved back to protoutils raw_bytes = named_tensor.data_bytes - metadata = [{'int_to_float': proto.int_to_float, - 'int_list': proto.int_list, - 'bool_list': proto.bool_list - } for proto in named_tensor.transformer_metadata] + metadata = [{ + 'int_to_float': proto.int_to_float, + 'int_list': proto.int_list, + 'bool_list': proto.bool_list + } for proto in named_tensor.transformer_metadata] # The tensor has already been transfered to collaborator, so # the newly constructed tensor should have the collaborator origin - tensor_key = TensorKey( - named_tensor.name, - self.collaborator_name, - named_tensor.round_number, - named_tensor.report, - tuple(named_tensor.tags) - ) + tensor_key = TensorKey(named_tensor.name, self.collaborator_name, + named_tensor.round_number, named_tensor.report, + tuple(named_tensor.tags)) tensor_name, origin, round_number, report, tags = tensor_key if 'compressed' in tags: decompressed_tensor_key, decompressed_nparray = self.tensor_codec.decompress( tensor_key, data=raw_bytes, transformer_metadata=metadata, - require_lossless=True - ) + require_lossless=True) elif 'lossy_compressed' in tags: decompressed_tensor_key, decompressed_nparray = self.tensor_codec.decompress( - tensor_key, - data=raw_bytes, - transformer_metadata=metadata - ) + tensor_key, data=raw_bytes, transformer_metadata=metadata) else: # There could be a case where the compression pipeline is bypassed # entirely @@ -518,7 +551,6 @@ def named_tensor_to_nparray(self, named_tensor): decompressed_nparray = raw_bytes self.tensor_db.cache_tensor( - {decompressed_tensor_key: decompressed_nparray} - ) + {decompressed_tensor_key: decompressed_nparray}) return decompressed_nparray diff --git a/openfl/component/director/__init__.py b/openfl/component/director/__init__.py index bec467778d..a4c547ab00 100644 --- a/openfl/component/director/__init__.py +++ b/openfl/component/director/__init__.py @@ -1,11 +1,3 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""Director package.""" - -from .director import Director - - -__all__ = [ - 'Director', -] +from openfl.component.director.director import Director diff --git 
a/openfl/component/director/director.py b/openfl/component/director/director.py index c1061510b4..061d414985 100644 --- a/openfl/component/director/director.py +++ b/openfl/component/director/director.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Director module.""" import asyncio @@ -8,36 +7,76 @@ import time from collections import defaultdict from pathlib import Path -from typing import Callable -from typing import Iterable -from typing import List -from typing import Union +from typing import Callable, Iterable, List, Union from openfl.transport.grpc.exceptions import ShardNotFoundError -from .experiment import Experiment -from .experiment import ExperimentsRegistry -from .experiment import Status +from openfl.component.director.experiment import Experiment, ExperimentsRegistry, Status logger = logging.getLogger(__name__) class Director: - """Director class.""" - - def __init__( - self, *, - tls: bool = True, - root_certificate: Union[Path, str] = None, - private_key: Union[Path, str] = None, - certificate: Union[Path, str] = None, - sample_shape: list = None, - target_shape: list = None, - review_plan_callback: Union[None, Callable] = None, - envoy_health_check_period: int = 60, - install_requirements: bool = False - ) -> None: - """Initialize a director object.""" + """Director class. The Director is the central node of the federation + (Director-Based Workflow). + + Attributes: + tls (bool): A flag indicating if TLS should be used for connections. + root_certificate (Union[Path, str]): The path to the root certificate + for TLS. + private_key (Union[Path, str]): The path to the private key for TLS. + certificate (Union[Path, str]): The path to the certificate for TLS. + sample_shape (list): The shape of the sample data. + target_shape (list): The shape of the target data. + review_plan_callback (Union[None, Callable]): A callback function for + reviewing the plan. + envoy_health_check_period (int): The period for health check of envoys + in seconds. + install_requirements (bool): A flag indicating if the requirements + should be installed. + _shard_registry (dict): A dictionary to store the shard registry. + experiments_registry (ExperimentsRegistry): An object of + ExperimentsRegistry to store the experiments. + col_exp_queues (defaultdict): A defaultdict to store the experiment + queues for collaborators. + col_exp (dict): A dictionary to store the experiments for + collaborators. + logger (Logger): A logger for logging activities. + """ + + def __init__(self, + *, + tls: bool = True, + root_certificate: Union[Path, str] = None, + private_key: Union[Path, str] = None, + certificate: Union[Path, str] = None, + sample_shape: list = None, + target_shape: list = None, + review_plan_callback: Union[None, Callable] = None, + envoy_health_check_period: int = 60, + install_requirements: bool = False) -> None: + """Initialize the Director object. + + Args: + tls (bool, optional): A flag indicating if TLS should be used for + connections. Defaults to True. + root_certificate (Union[Path, str], optional): The path to the + root certificate for TLS. Defaults to None. + private_key (Union[Path, str], optional): The path to the private + key for TLS. Defaults to None. + certificate (Union[Path, str], optional): The path to the + certificate for TLS. Defaults to None. + sample_shape (list, optional): The shape of the sample data. + Defaults to None. + target_shape (list, optional): The shape of the target data. + Defaults to None. 
+ review_plan_callback (Union[None, Callable], optional): A callback + function for reviewing the plan. Defaults to None. + envoy_health_check_period (int, optional): The period for health + check of envoys in seconds. Defaults to 60. + install_requirements (bool, optional): A flag indicating if the + requirements should be installed. Defaults to False. + """ self.sample_shape, self.target_shape = sample_shape, target_shape self._shard_registry = {} self.tls = tls @@ -52,13 +91,25 @@ def __init__( self.install_requirements = install_requirements def acknowledge_shard(self, shard_info: dict) -> bool: - """Save shard info to shard registry if accepted.""" + """Save shard info to shard registry if it's acceptable. + + Args: + shard_info (dict): The shard info dictionary to be stored in + the shard registry. + + Returns: + is_accepted (bool): True if the shard was accepted and added to + the registry, False otherwise. + """ is_accepted = False if (self.sample_shape != shard_info['sample_shape'] or self.target_shape != shard_info['target_shape']): - logger.info(f'Director did not accept shard for {shard_info["node_info"]["name"]}') + logger.info( + f'Director did not accept shard for {shard_info["node_info"]["name"]}' + ) return is_accepted - logger.info(f'Director accepted shard for {shard_info["node_info"]["name"]}') + logger.info( + f'Director accepted shard for {shard_info["node_info"]["name"]}') self._shard_registry[shard_info['node_info']['name']] = { 'shard_info': shard_info, 'is_online': True, @@ -70,14 +121,26 @@ def acknowledge_shard(self, shard_info: dict) -> bool: return is_accepted async def set_new_experiment( - self, *, - experiment_name: str, - sender_name: str, - tensor_dict: dict, - collaborator_names: Iterable[str], - experiment_archive_path: Path, + self, + *, + experiment_name: str, + sender_name: str, + tensor_dict: dict, + collaborator_names: Iterable[str], + experiment_archive_path: Path, ) -> bool: - """Set new experiment.""" + """Set new experiment. + + Args: + experiment_name (str): String id for experiment. + sender_name (str): The name of the sender. + tensor_dict (dict): Dictionary of tensors. + collaborator_names (Iterable[str]): Names of collaborators. + experiment_archive_path (Path): Path of the experiment archive. + + Returns: + bool: True if the experiment was registered successfully. + """ experiment = Experiment( name=experiment_name, archive_path=experiment_archive_path, @@ -89,21 +152,45 @@ async def set_new_experiment( self.experiments_registry.add(experiment) return True - async def get_experiment_status( - self, - experiment_name: str, - caller: str): - """Get experiment status.""" - if (experiment_name not in self.experiments_registry - or caller not in self.experiments_registry[experiment_name].users): + async def get_experiment_status(self, experiment_name: str, caller: str): + """Get experiment status. + + Args: + experiment_name (str): String id for experiment. + caller (str): String id for experiment owner.
+ + Returns: + str: The status of the experiment can be one of the following: + - PENDING = 'pending' + - FINISHED = 'finished' + - IN_PROGRESS = 'in_progress' + - FAILED = 'failed' + - REJECTED = 'rejected' + """ + if (experiment_name not in self.experiments_registry or caller + not in self.experiments_registry[experiment_name].users): logger.error('No experiment data in the stash') return None return self.experiments_registry[experiment_name].status - def get_trained_model(self, experiment_name: str, caller: str, model_type: str): - """Get trained model.""" - if (experiment_name not in self.experiments_registry - or caller not in self.experiments_registry[experiment_name].users): + def get_trained_model(self, experiment_name: str, caller: str, + model_type: str): + """Get trained model. + + Args: + experiment_name (str): String id for experiment. + caller (str): String id for experiment owner. + model_type (str): The type of the model. + + Returns: + None: One of the following: [No experiment data in the stash] or + [Aggregator have no aggregated model to return] or [Unknown + model type required]. + dict: Dictionary of tensors from the aggregator when the model + type is 'best' or 'last'. + """ + if (experiment_name not in self.experiments_registry or caller + not in self.experiments_registry[experiment_name].users): logger.error('No experiment data in the stash') return None @@ -122,11 +209,25 @@ def get_trained_model(self, experiment_name: str, caller: str, model_type: str): return None def get_experiment_data(self, experiment_name: str) -> Path: - """Get experiment data.""" + """Get experiment data. + + Args: + experiment_name (str): String id for experiment. + + Returns: + str: Path of archive. + """ return self.experiments_registry[experiment_name].archive_path async def wait_experiment(self, envoy_name: str) -> str: - """Wait an experiment.""" + """Wait an experiment. + + Args: + envoy_name (str): The name of the envoy. + + Returns: + str: The name of the experiment on the queue. + """ experiment_name = self.col_exp.get(envoy_name) if experiment_name and experiment_name in self.experiments_registry: # Experiment already set, but the envoy hasn't received experiment @@ -148,33 +249,34 @@ def get_dataset_info(self): def get_registered_shards(self) -> list: # Why is it here? """Get registered shard infos.""" - return [shard_status['shard_info'] for shard_status in self._shard_registry.values()] + return [ + shard_status['shard_info'] + for shard_status in self._shard_registry.values() + ] async def stream_metrics(self, experiment_name: str, caller: str): - """ - Stream metrics from the aggregator. + """Stream metrics from the aggregator. This method takes next metric dictionary from the aggregator's queue and returns it to the caller. - Inputs: - experiment_name - string id for experiment - caller - string id for experiment owner + Args: + experiment_name (str): String id for experiment. + caller (str): String id for experiment owner. Returns: - metric_dict - {'metric_origin','task_name','metric_name','metric_value','round'} - if the queue is not empty - None - f queue is empty but the experiment is still running + metric_dict: {'metric_origin','task_name','metric_name','metric_value','round'} + if the queue is not empty. + None: queue is empty but the experiment is still running. Raises: - StopIteration - if the experiment is finished and there is no more metrics to report + StopIteration: if the experiment is finished and there is no more metrics to report. 
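Since stream_metrics is an asynchronous generator, a short consumption sketch may clarify the shape of what it yields; the director instance, experiment name, and caller below are placeholders.

import asyncio


async def print_metrics(director, experiment_name, caller):
    """Print metric dictionaries as the Director yields them."""
    async for metric in director.stream_metrics(experiment_name, caller):
        if metric is None:
            # Queue is empty but the experiment is still running.
            await asyncio.sleep(1)
            continue
        print(f"round {metric['round']} | {metric['metric_origin']} | "
              f"{metric['task_name']} | {metric['metric_name']} = "
              f"{metric['metric_value']}")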
""" - if (experiment_name not in self.experiments_registry - or caller not in self.experiments_registry[experiment_name].users): + if (experiment_name not in self.experiments_registry or caller + not in self.experiments_registry[experiment_name].users): raise Exception( f'No experiment name "{experiment_name}" in experiments list, or caller "{caller}"' - f' does not have access to this experiment' - ) + f' does not have access to this experiment') while not self.experiments_registry[experiment_name].aggregator: await asyncio.sleep(1) @@ -185,26 +287,39 @@ async def stream_metrics(self, experiment_name: str, caller: str): yield aggregator.metric_queue.get() continue - if aggregator.all_quit_jobs_sent() and aggregator.metric_queue.empty(): + if aggregator.all_quit_jobs_sent( + ) and aggregator.metric_queue.empty(): return yield None def remove_experiment_data(self, experiment_name: str, caller: str): - """Remove experiment data from stash.""" - if (experiment_name in self.experiments_registry - and caller in self.experiments_registry[experiment_name].users): - self.experiments_registry.remove(experiment_name) + """Remove experiment data from stash. - def set_experiment_failed(self, *, experiment_name: str, collaborator_name: str): + Args: + experiment_name (str): String id for experiment. + caller (str): String id for experiment owner. """ - Envoys Set experiment failed RPC. + if (experiment_name in self.experiments_registry and caller + in self.experiments_registry[experiment_name].users): + self.experiments_registry.remove(experiment_name) + + def set_experiment_failed(self, *, experiment_name: str, + collaborator_name: str): + """Envoys Set experiment failed RPC. - This method shoud call `experiment.abort()` and all the code - should be pushed down to that method. + Args: + experiment_name (str): String id for experiment. + collaborator_name (str): String id for collaborator. - It would be also good to be able to interrupt aggregator async task with - the following code: + Return: + None + """ + """This method shoud call `experiment.abort()` and all the code should + be pushed down to that method. + + It would be also good to be able to interrupt aggregator async task + with the following code: ``` run_aggregator_atask = self.experiments_registry[experiment_name].run_aggregator_atask if asyncio.isfuture(run_aggregator_atask) and not run_aggregator_atask.done(): @@ -220,12 +335,27 @@ def set_experiment_failed(self, *, experiment_name: str, collaborator_name: str) self.experiments_registry[experiment_name].status = Status.FAILED def update_envoy_status( - self, *, - envoy_name: str, - is_experiment_running: bool, - cuda_devices_status: list = None, + self, + *, + envoy_name: str, + is_experiment_running: bool, + cuda_devices_status: list = None, ) -> int: - """Accept health check from envoy.""" + """Accept health check from envoy. + + Args: + envoy_name (str): String id for envoy. + is_experiment_running (bool): Boolean value for the status of the + experiment. + cuda_devices_status (list, optional): List of cuda devices and + status. Defaults to None. + + Raises: + ShardNotFoundError: When Unknown shard {envoy_name}. + + Returns: + int: Value of the envoy_health_check_period. 
+ """ shard_info = self._shard_registry.get(envoy_name) if not shard_info: raise ShardNotFoundError(f'Unknown shard {envoy_name}') @@ -242,7 +372,11 @@ def update_envoy_status( return self.envoy_health_check_period def get_envoys(self) -> list: - """Get a status information about envoys.""" + """Get a status information about envoys. + + Returns: + list: List with the status information about envoys. + """ logger.debug(f'Shard registry: {self._shard_registry}') for envoy_info in self._shard_registry.values(): envoy_info['is_online'] = ( @@ -255,7 +389,13 @@ def get_envoys(self) -> list: return self._shard_registry.values() def get_experiments_list(self, caller: str) -> list: - """Get experiments list for specific user.""" + """Get experiments list for specific user. + + Args: + caller (str): String id for experiment owner. + Returns: + list: List with the info of the experiment for specific user. + """ experiments = self.experiments_registry.get_user_experiments(caller) result = [] for exp in experiments: @@ -278,7 +418,15 @@ def get_experiments_list(self, caller: str) -> list: return result def get_experiment_description(self, caller: str, name: str) -> dict: - """Get a experiment information by name for specific user.""" + """Get a experiment information by name for specific user. + + Args: + caller (str): String id for experiment owner. + name (str): String id for experiment name. + + Returns: + dict: Dictionary with the info from the experiment. + """ exp = self.experiments_registry.get(name) if not exp or caller not in exp.users: return {} @@ -308,24 +456,27 @@ async def start_experiment_execution_loop(self): """Run task to monitor and run experiments.""" loop = asyncio.get_event_loop() while True: - async with self.experiments_registry.get_next_experiment() as experiment: + async with self.experiments_registry.get_next_experiment( + ) as experiment: # Review experiment block starts. if self.review_plan_callback: - if not await experiment.review_experiment(self.review_plan_callback): + if not await experiment.review_experiment( + self.review_plan_callback): logger.info( f'"{experiment.name}" Plan was rejected by the Director manager.' ) continue # Review experiment block ends. 
- run_aggregator_future = loop.create_task(experiment.start( - root_certificate=self.root_certificate, - certificate=self.certificate, - private_key=self.private_key, - tls=self.tls, - install_requirements=self.install_requirements, - )) + run_aggregator_future = loop.create_task( + experiment.start( + root_certificate=self.root_certificate, + certificate=self.certificate, + private_key=self.private_key, + tls=self.tls, + install_requirements=self.install_requirements, + )) # Adding the experiment to collaborators queues for col_name in experiment.collaborators: queue = self.col_exp_queues[col_name] diff --git a/openfl/component/director/experiment.py b/openfl/component/director/experiment.py index e052410e80..89ba90f355 100644 --- a/openfl/component/director/experiment.py +++ b/openfl/component/director/experiment.py @@ -1,16 +1,12 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Experiment module.""" import asyncio import logging from contextlib import asynccontextmanager from pathlib import Path -from typing import Callable -from typing import Iterable -from typing import List -from typing import Union +from typing import Callable, Iterable, List, Union from openfl.federated import Plan from openfl.transport import AggregatorGRPCServer @@ -30,19 +26,48 @@ class Status: class Experiment: - """Experiment class.""" + """Experiment class. + + Attributes: + name (str): The name of the experiment. + archive_path (Union[Path, str]): The path to the experiment + archive. + collaborators (List[str]): The list of collaborators. + sender (str): The name of the sender. + init_tensor_dict (dict): The initial tensor dictionary. + plan_path (Union[Path, str]): The path to the plan. + users (Iterable[str]): The list of users. + status (str): The status of the experiment. + aggregator (object): The aggregator object. + run_aggregator_atask (object): The run aggregator async task + object. + """ def __init__( - self, *, - name: str, - archive_path: Union[Path, str], - collaborators: List[str], - sender: str, - init_tensor_dict: dict, - plan_path: Union[Path, str] = 'plan/plan.yaml', - users: Iterable[str] = None, + self, + *, + name: str, + archive_path: Union[Path, str], + collaborators: List[str], + sender: str, + init_tensor_dict: dict, + plan_path: Union[Path, str] = 'plan/plan.yaml', + users: Iterable[str] = None, ) -> None: - """Initialize an experiment object.""" + """Initialize an experiment object. + + Args: + name (str): The name of the experiment. + archive_path (Union[Path, str]): The path to the experiment + archive. + collaborators (List[str]): The list of collaborators. + sender (str): The name of the sender. + init_tensor_dict (dict): The initial tensor dictionary. + plan_path (Union[Path, str], optional): The path to the plan. + Defaults to 'plan/plan.yaml'. + users (Iterable[str], optional): The list of users. Defaults to + None. 
+ """ self.name = name self.archive_path = Path(archive_path).absolute() self.collaborators = collaborators @@ -55,23 +80,37 @@ def __init__( self.run_aggregator_atask = None async def start( - self, *, - tls: bool = True, - root_certificate: Union[Path, str] = None, - private_key: Union[Path, str] = None, - certificate: Union[Path, str] = None, - install_requirements: bool = False, + self, + *, + tls: bool = True, + root_certificate: Union[Path, str] = None, + private_key: Union[Path, str] = None, + certificate: Union[Path, str] = None, + install_requirements: bool = False, ): - """Run experiment.""" + """Run experiment. + + Args: + tls (bool, optional): A flag indicating if TLS should be used for + connections. Defaults to True. + root_certificate (Union[Path, str], optional): The path to the + root certificate for TLS. Defaults to None. + private_key (Union[Path, str], optional): The path to the private + key for TLS. Defaults to None. + certificate (Union[Path, str], optional): The path to the + certificate for TLS. Defaults to None. + install_requirements (bool, optional): A flag indicating if the + requirements should be installed. Defaults to False. + """ self.status = Status.IN_PROGRESS try: logger.info(f'New experiment {self.name} for ' f'collaborators {self.collaborators}') with ExperimentWorkspace( - experiment_name=self.name, - data_file_path=self.archive_path, - install_requirements=install_requirements, + experiment_name=self.name, + data_file_path=self.archive_path, + install_requirements=install_requirements, ): aggregator_grpc_server = self._create_aggregator_grpc_server( tls=tls, @@ -83,33 +122,34 @@ async def start( self.run_aggregator_atask = asyncio.create_task( self._run_aggregator_grpc_server( - aggregator_grpc_server=aggregator_grpc_server, - ) - ) + aggregator_grpc_server=aggregator_grpc_server, )) await self.run_aggregator_atask self.status = Status.FINISHED logger.info(f'Experiment "{self.name}" was finished successfully.') except Exception as e: self.status = Status.FAILED - logger.exception(f'Experiment "{self.name}" failed with error: {e}.') + logger.exception( + f'Experiment "{self.name}" failed with error: {e}.') async def review_experiment(self, review_plan_callback: Callable) -> bool: - """Get plan approve in console.""" + """Get plan approve in console. + + Args: + review_plan_callback (Callable): A callback function for reviewing the plan. + + Returns: + bool: True if the plan was approved, False otherwise. 
+ """ logger.debug("Experiment Review starts") # Extract the workspace for review (without installing requirements) - with ExperimentWorkspace( - self.name, - self.archive_path, - is_install_requirements=False, - remove_archive=False - ): + with ExperimentWorkspace(self.name, + self.archive_path, + is_install_requirements=False, + remove_archive=False): loop = asyncio.get_event_loop() # Call for a review in a separate thread (server is not blocked) review_approve = await loop.run_in_executor( - None, - review_plan_callback, - self.name, self.plan_path - ) + None, review_plan_callback, self.name, self.plan_path) if not review_approve: self.status = Status.REJECTED self.archive_path.unlink(missing_ok=True) @@ -119,16 +159,33 @@ async def review_experiment(self, review_plan_callback: Callable) -> bool: return True def _create_aggregator_grpc_server( - self, *, - tls: bool = True, - root_certificate: Union[Path, str] = None, - private_key: Union[Path, str] = None, - certificate: Union[Path, str] = None, + self, + *, + tls: bool = True, + root_certificate: Union[Path, str] = None, + private_key: Union[Path, str] = None, + certificate: Union[Path, str] = None, ) -> AggregatorGRPCServer: + """Create an aggregator gRPC server. + + Args: + tls (bool, optional): A flag indicating if TLS should be used for + connections. Defaults to True. + root_certificate (Union[Path, str], optional): The path to the + root certificate for TLS. Defaults to None. + private_key (Union[Path, str], optional): The path to the private + key for TLS. Defaults to None. + certificate (Union[Path, str], optional): The path to the + certificate for TLS. Defaults to None. + + Returns: + AggregatorGRPCServer: The created aggregator gRPC server. + """ plan = Plan.parse(plan_config_path=self.plan_path) plan.authorized_cols = list(self.collaborators) - logger.info(f'🧿 Created an Aggregator Server for {self.name} experiment.') + logger.info( + f'🧿 Created an Aggregator Server for {self.name} experiment.') aggregator_grpc_server = plan.interactive_api_get_server( tensor_dict=self.init_tensor_dict, root_certificate=root_certificate, @@ -139,8 +196,14 @@ def _create_aggregator_grpc_server( return aggregator_grpc_server @staticmethod - async def _run_aggregator_grpc_server(aggregator_grpc_server: AggregatorGRPCServer) -> None: - """Run aggregator.""" + async def _run_aggregator_grpc_server( + aggregator_grpc_server: AggregatorGRPCServer) -> None: + """Run aggregator. + + Args: + aggregator_grpc_server (AggregatorGRPCServer): The aggregator gRPC + server to run. + """ logger.info('🧿 Starting the Aggregator Service.') grpc_server = aggregator_grpc_server.get_server() grpc_server.start() @@ -150,7 +213,8 @@ async def _run_aggregator_grpc_server(aggregator_grpc_server: AggregatorGRPCServ while not aggregator_grpc_server.aggregator.all_quit_jobs_sent(): # Awaiting quit job sent to collaborators await asyncio.sleep(10) - logger.debug('Aggregator sent quit jobs calls to all collaborators') + logger.debug( + 'Aggregator sent quit jobs calls to all collaborators') except KeyboardInterrupt: pass finally: @@ -171,23 +235,40 @@ def __init__(self) -> None: @property def active_experiment(self) -> Union[Experiment, None]: - """Get active experiment.""" + """Get active experiment. + + Returns: + Union[Experiment, None]: The active experiment if exists, None + otherwise. 
+ """ if self.__active_experiment_name is None: return None return self.__dict[self.__active_experiment_name] @property def pending_experiments(self) -> List[str]: - """Get queue of not started experiments.""" + """Get queue of not started experiments. + + Returns: + List[str]: The list of pending experiments. + """ return self.__pending_experiments def add(self, experiment: Experiment) -> None: - """Add experiment to queue of not started experiments.""" + """Add experiment to queue of not started experiments. + + Args: + experiment (Experiment): The experiment to add. + """ self.__dict[experiment.name] = experiment self.__pending_experiments.append(experiment.name) def remove(self, name: str) -> None: - """Remove experiment from everywhere.""" + """Remove experiment from everywhere. + + Args: + name (str): The name of the experiment to remove. + """ if self.__active_experiment_name == name: self.__active_experiment_name = None if name in self.__pending_experiments: @@ -198,23 +279,50 @@ def remove(self, name: str) -> None: del self.__dict[name] def __getitem__(self, key: str) -> Experiment: - """Get experiment by name.""" + """Get experiment by name. + + Args: + key (str): The name of the experiment. + + Returns: + Experiment: The experiment with the given name. + """ return self.__dict[key] def get(self, key: str, default=None) -> Experiment: - """Get experiment by name.""" + """Get experiment by name. + + Args: + key (str): The name of the experiment. + default (optional): The default value to return if the experiment + does not exist. + + Returns: + Experiment: The experiment with the given name, or the default + value if the experiment does not exist. + """ return self.__dict.get(key, default) def get_user_experiments(self, user: str) -> List[Experiment]: - """Get list of experiments for specific user.""" - return [ - exp - for exp in self.__dict.values() - if user in exp.users - ] + """Get list of experiments for specific user. + + Args: + user (str): The name of the user. + + Returns: + List[Experiment]: The list of experiments for the specific user. + """ + return [exp for exp in self.__dict.values() if user in exp.users] def __contains__(self, key: str) -> bool: - """Check if experiment exists.""" + """Check if experiment exists. + + Args: + key (str): The name of the experiment. + + Returns: + bool: True if the experiment exists, False otherwise. + """ return key in self.__dict def finish_active(self) -> None: @@ -226,8 +334,8 @@ def finish_active(self) -> None: async def get_next_experiment(self): """Context manager. - On enter get experiment from pending_experiments. - On exit put finished experiment to archive_experiments. + On enter get experiment from pending_experiments. On exit put finished + experiment to archive_experiments. 
""" while True: if self.active_experiment is None and self.pending_experiments: diff --git a/openfl/component/envoy/__init__.py b/openfl/component/envoy/__init__.py index a028d52d39..6b1ddbbd09 100644 --- a/openfl/component/envoy/__init__.py +++ b/openfl/component/envoy/__init__.py @@ -1,4 +1,3 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""Envoy package.""" +from openfl.component.envoy.envoy import Envoy \ No newline at end of file diff --git a/openfl/component/envoy/envoy.py b/openfl/component/envoy/envoy.py index a2f384ef88..0dee5e5b92 100644 --- a/openfl/component/envoy/envoy.py +++ b/openfl/component/envoy/envoy.py @@ -1,25 +1,23 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Envoy module.""" import logging +import sys import time import traceback import uuid -import sys from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Callable -from typing import Optional -from typing import Type -from typing import Union +from typing import Callable, Optional, Type, Union from openfl.federated import Plan from openfl.interface.interactive_api.shard_descriptor import ShardDescriptor -from openfl.plugins.processing_units_monitor.cuda_device_monitor import CUDADeviceMonitor -from openfl.transport.grpc.exceptions import ShardNotFoundError +from openfl.plugins.processing_units_monitor.cuda_device_monitor import ( + CUDADeviceMonitor, +) from openfl.transport.grpc.director_client import ShardDirectorClient +from openfl.transport.grpc.exceptions import ShardNotFoundError from openfl.utilities.workspace import ExperimentWorkspace logger = logging.getLogger(__name__) @@ -28,24 +26,72 @@ class Envoy: - """Envoy class.""" + """Envoy class. The Envoy is a long-lived entity that runs on collaborator + nodes connected to the Director. + + Attributes: + name (str): The name of the shard. + root_certificate (Union[Path, str]): The path to the root certificate + for TLS. + private_key (Union[Path, str]): The path to the private key for TLS. + certificate (Union[Path, str]): The path to the certificate for TLS. + director_client (ShardDirectorClient): The director client. + shard_descriptor (Type[ShardDescriptor]): The shard descriptor. + cuda_devices (tuple): The CUDA devices. + install_requirements (bool): A flag indicating if the requirements + should be installed. + review_plan_callback (Union[None, Callable]): A callback function for + reviewing the plan. + cuda_device_monitor (Optional[Type[CUDADeviceMonitor]]): The CUDA + device monitor. + executor (ThreadPoolExecutor): The executor for running tasks. + running_experiments (dict): A dictionary to store the running + experiments. + is_experiment_running (bool): A flag indicating if an experiment is + running. + _health_check_future (object): The future object for the health check. 
+ """ def __init__( - self, *, - shard_name: str, - director_host: str, - director_port: int, - shard_descriptor: Type[ShardDescriptor], - root_certificate: Optional[Union[Path, str]] = None, - private_key: Optional[Union[Path, str]] = None, - certificate: Optional[Union[Path, str]] = None, - tls: bool = True, - install_requirements: bool = True, - cuda_devices: Union[tuple, list] = (), - cuda_device_monitor: Optional[Type[CUDADeviceMonitor]] = None, - review_plan_callback: Union[None, Callable] = None, + self, + *, + shard_name: str, + director_host: str, + director_port: int, + shard_descriptor: Type[ShardDescriptor], + root_certificate: Optional[Union[Path, str]] = None, + private_key: Optional[Union[Path, str]] = None, + certificate: Optional[Union[Path, str]] = None, + tls: bool = True, + install_requirements: bool = True, + cuda_devices: Union[tuple, list] = (), + cuda_device_monitor: Optional[Type[CUDADeviceMonitor]] = None, + review_plan_callback: Union[None, Callable] = None, ) -> None: - """Initialize a envoy object.""" + """Initialize a envoy object. + + Args: + shard_name (str): The name of the shard. + director_host (str): The host of the director. + director_port (int): The port of the director. + shard_descriptor (Type[ShardDescriptor]): The shard descriptor. + root_certificate (Optional[Union[Path, str]], optional): The path + to the root certificate for TLS. Defaults to None. + private_key (Optional[Union[Path, str]], optional): The path to + the private key for TLS. Defaults to None. + certificate (Optional[Union[Path, str]], optional): The path to + the certificate for TLS. Defaults to None. + tls (bool, optional): A flag indicating if TLS should be used for + connections. Defaults to True. + install_requirements (bool, optional): A flag indicating if the + requirements should be installed. Defaults to True. + cuda_devices (Union[tuple, list], optional): The CUDA devices. + Defaults to (). + cuda_device_monitor (Optional[Type[CUDADeviceMonitor]], optional): + The CUDA device monitor. Defaults to None. + review_plan_callback (Union[None, Callable], optional): A callback + function for reviewing the plan. Defaults to None. + """ self.name = shard_name self.root_certificate = Path( root_certificate).absolute() if root_certificate is not None else None @@ -58,8 +104,7 @@ def __init__( tls=tls, root_certificate=root_certificate, private_key=private_key, - certificate=certificate - ) + certificate=certificate) self.shard_descriptor = shard_descriptor self.cuda_devices = tuple(cuda_devices) @@ -81,7 +126,8 @@ def run(self): try: # Workspace import should not be done by gRPC client! 
experiment_name = self.director_client.wait_experiment() - data_stream = self.director_client.get_experiment_data(experiment_name) + data_stream = self.director_client.get_experiment_data( + experiment_name) except Exception as exc: logger.exception(f'Failed to get experiment: {exc}') time.sleep(DEFAULT_RETRY_TIMEOUT_IN_SECONDS) @@ -93,17 +139,16 @@ def run(self): with ExperimentWorkspace( experiment_name=f'{self.name}_{experiment_name}', data_file_path=data_file_path, - install_requirements=self.install_requirements - ): + install_requirements=self.install_requirements): # If the callback is passed if self.review_plan_callback: # envoy to review the experiment before starting - if not self.review_plan_callback('plan', 'plan/plan.yaml'): + if not self.review_plan_callback( + 'plan', 'plan/plan.yaml'): self.director_client.set_experiment_failed( experiment_name, error_description='Experiment is rejected' - f' by Envoy "{self.name}" manager.' - ) + f' by Envoy "{self.name}" manager.') continue logger.debug( f'Experiment "{experiment_name}" was accepted by Envoy manager' @@ -115,13 +160,20 @@ def run(self): self.director_client.set_experiment_failed( experiment_name, error_code=1, - error_description=traceback.format_exc() - ) + error_description=traceback.format_exc()) finally: self.is_experiment_running = False @staticmethod def _save_data_stream_to_file(data_stream): + """Save data stream to file. + + Args: + data_stream: The data stream to save. + + Returns: + Path: The path to the saved data file. + """ data_file_path = Path(str(uuid.uuid4())).absolute() with open(data_file_path, 'wb') as data_file: for response in data_stream: @@ -144,27 +196,37 @@ def send_health_check(self): cuda_devices_info=cuda_devices_info, ) except ShardNotFoundError: - logger.info('The director has lost information about current shard. Resending...') + logger.info( + 'The director has lost information about current shard. Resending...' + ) self.director_client.report_shard_info( shard_descriptor=self.shard_descriptor, - cuda_devices=self.cuda_devices - ) + cuda_devices=self.cuda_devices) time.sleep(timeout) def _get_cuda_device_info(self): + """Get CUDA device info. + + Returns: + list: A list of dictionaries containing info about each CUDA + device. + """ cuda_devices_info = None try: if self.cuda_device_monitor is not None: cuda_devices_info = [] - cuda_driver_version = self.cuda_device_monitor.get_driver_version() + cuda_driver_version = self.cuda_device_monitor.get_driver_version( + ) cuda_version = self.cuda_device_monitor.get_cuda_version() for device_id in self.cuda_devices: - memory_total = self.cuda_device_monitor.get_device_memory_total(device_id) + memory_total = self.cuda_device_monitor.get_device_memory_total( + device_id) memory_utilized = self.cuda_device_monitor.get_device_memory_utilized( - device_id - ) - device_utilization = self.cuda_device_monitor.get_device_utilization(device_id) - device_name = self.cuda_device_monitor.get_device_name(device_id) + device_id) + device_utilization = self.cuda_device_monitor.get_device_utilization( + device_id) + device_name = self.cuda_device_monitor.get_device_name( + device_id) cuda_devices_info.append({ 'index': device_id, 'memory_total': memory_total, @@ -180,15 +242,22 @@ def _get_cuda_device_info(self): return cuda_devices_info def _run_collaborator(self, plan='plan/plan.yaml'): - """Run the collaborator for the experiment running.""" + """Run the collaborator for the experiment running. + + Args: + plan (str, optional): The path to the plan. 
Defaults to 'plan/plan.yaml'. + """ plan = Plan.parse(plan_config_path=Path(plan)) # TODO: Need to restructure data loader config file loader logger.debug(f'Data = {plan.cols_data_paths}') logger.info('🧿 Starting the Collaborator Service.') - col = plan.get_collaborator(self.name, self.root_certificate, self.private_key, - self.certificate, shard_descriptor=self.shard_descriptor) + col = plan.get_collaborator(self.name, + self.root_certificate, + self.private_key, + self.certificate, + shard_descriptor=self.shard_descriptor) col.set_available_devices(cuda=self.cuda_devices) col.run() @@ -205,7 +274,8 @@ def start(self): if is_accepted: logger.info('Shard was accepted by director') # Shard accepted for participation in the federation - self._health_check_future = self.executor.submit(self.send_health_check) + self._health_check_future = self.executor.submit( + self.send_health_check) self.run() else: # Shut down diff --git a/openfl/component/straggler_handling_functions/__init__.py b/openfl/component/straggler_handling_functions/__init__.py index ab631cdd0b..c957dec87f 100644 --- a/openfl/component/straggler_handling_functions/__init__.py +++ b/openfl/component/straggler_handling_functions/__init__.py @@ -1,12 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""Straggler Handling functions package.""" - -from .straggler_handling_function import StragglerHandlingFunction -from .cutoff_time_based_straggler_handling import CutoffTimeBasedStragglerHandling -from .percentage_based_straggler_handling import PercentageBasedStragglerHandling - -__all__ = ['CutoffTimeBasedStragglerHandling', - 'PercentageBasedStragglerHandling', - 'StragglerHandlingFunction'] +from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import CutoffTimeBasedStragglerHandling +from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import PercentageBasedStragglerHandling +from openfl.component.straggler_handling_functions.straggler_handling_function import StragglerHandlingFunction \ No newline at end of file diff --git a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py b/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py index fba40150fb..5434d9ff16 100644 --- a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py +++ b/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py @@ -1,33 +1,75 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Cutoff time based Straggler Handling function.""" -import numpy as np import time -from openfl.component.straggler_handling_functions import StragglerHandlingFunction +import numpy as np + +from openfl.component.straggler_handling_functions.straggler_handling_function import ( + StragglerHandlingFunction, +) class CutoffTimeBasedStragglerHandling(StragglerHandlingFunction): - def __init__( - self, - round_start_time=None, - straggler_cutoff_time=np.inf, - minimum_reporting=1, - **kwargs - ): + """Cutoff time based Straggler Handling function.""" + + def __init__(self, + round_start_time=None, + straggler_cutoff_time=np.inf, + minimum_reporting=1, + **kwargs): + """Initialize a CutoffTimeBasedStragglerHandling object. + + Args: + round_start_time (optional): The start time of the round. Defaults + to None. + straggler_cutoff_time (float, optional): The cutoff time for + stragglers. Defaults to np.inf. 
+ minimum_reporting (int, optional): The minimum number of + collaborators that should report. Defaults to 1. + **kwargs: Variable length argument list. + """ self.round_start_time = round_start_time self.straggler_cutoff_time = straggler_cutoff_time self.minimum_reporting = minimum_reporting def straggler_time_expired(self): + """Check if the straggler time has expired. + + Returns: + bool: True if the straggler time has expired, False otherwise. + """ return self.round_start_time is not None and ( (time.time() - self.round_start_time) > self.straggler_cutoff_time) def minimum_collaborators_reported(self, num_collaborators_done): + """Check if the minimum number of collaborators have reported. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + + Returns: + bool: True if the minimum number of collaborators have reported, + False otherwise. + """ return num_collaborators_done >= self.minimum_reporting - def straggler_cutoff_check(self, num_collaborators_done, all_collaborators=None): - cutoff = self.straggler_time_expired() and self.minimum_collaborators_reported( - num_collaborators_done) + def straggler_cutoff_check(self, + num_collaborators_done, + all_collaborators=None): + """Check if the straggler cutoff conditions are met. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + all_collaborators (optional): All the collaborators. Defaults to + None. + + Returns: + bool: True if the straggler cutoff conditions are met, False + otherwise. + """ + cutoff = self.straggler_time_expired( + ) and self.minimum_collaborators_reported(num_collaborators_done) return cutoff diff --git a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py index 8e01418f24..3648ffb82e 100644 --- a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py +++ b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py @@ -1,24 +1,58 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Percentage based Straggler Handling function.""" -from openfl.component.straggler_handling_functions import StragglerHandlingFunction +from openfl.component.straggler_handling_functions.straggler_handling_function import ( + StragglerHandlingFunction, +) class PercentageBasedStragglerHandling(StragglerHandlingFunction): - def __init__( - self, - percent_collaborators_needed=1.0, - minimum_reporting=1, - **kwargs - ): + """Percentage based Straggler Handling function.""" + + def __init__(self, + percent_collaborators_needed=1.0, + minimum_reporting=1, + **kwargs): + """Initialize a PercentageBasedStragglerHandling object. + + Args: + percent_collaborators_needed (float, optional): The percentage of + collaborators needed. Defaults to 1.0. + minimum_reporting (int, optional): The minimum number of + collaborators that should report. Defaults to 1. + **kwargs: Variable length argument list. + """ self.percent_collaborators_needed = percent_collaborators_needed self.minimum_reporting = minimum_reporting def minimum_collaborators_reported(self, num_collaborators_done): + """Check if the minimum number of collaborators have reported. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + + Returns: + bool: True if the minimum number of collaborators have reported, + False otherwise. 
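# --- Editor's sketch (not part of this diff) ------------------------------
# How the two straggler policies answer "may this round end early?".
# The numeric values are illustrative only.
import time

from openfl.component.straggler_handling_functions import (
    CutoffTimeBasedStragglerHandling,
    PercentageBasedStragglerHandling,
)

cutoff_policy = CutoffTimeBasedStragglerHandling(
    round_start_time=time.time(),  # when the current round began
    straggler_cutoff_time=60,      # seconds before stragglers may be cut off
    minimum_reporting=2,           # but never with fewer than 2 results
)
# True only once 60 s have elapsed AND at least 2 collaborators reported.
cutoff_policy.straggler_cutoff_check(num_collaborators_done=3)

percentage_policy = PercentageBasedStragglerHandling(
    percent_collaborators_needed=0.8, minimum_reporting=2)
# True: 4 of 5 collaborators (80%) reported and the minimum of 2 is met.
percentage_policy.straggler_cutoff_check(
    num_collaborators_done=4,
    all_collaborators=['col1', 'col2', 'col3', 'col4', 'col5'])
# ---------------------------------------------------------------------------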
+ """ return num_collaborators_done >= self.minimum_reporting - def straggler_cutoff_check(self, num_collaborators_done, all_collaborators): - cutoff = (num_collaborators_done >= self.percent_collaborators_needed * len( - all_collaborators)) and self.minimum_collaborators_reported(num_collaborators_done) + def straggler_cutoff_check(self, num_collaborators_done, + all_collaborators): + """Check if the straggler cutoff conditions are met. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + all_collaborators (list): All the collaborators. + + Returns: + bool: True if the straggler cutoff conditions are met, False + otherwise. + """ + cutoff = ( + num_collaborators_done + >= self.percent_collaborators_needed * len(all_collaborators) + ) and self.minimum_collaborators_reported(num_collaborators_done) return cutoff diff --git a/openfl/component/straggler_handling_functions/straggler_handling_function.py b/openfl/component/straggler_handling_functions/straggler_handling_function.py index 53d1076932..1c9434541d 100644 --- a/openfl/component/straggler_handling_functions/straggler_handling_function.py +++ b/openfl/component/straggler_handling_functions/straggler_handling_function.py @@ -1,10 +1,8 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Straggler handling module.""" -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod class StragglerHandlingFunction(ABC): @@ -12,9 +10,15 @@ class StragglerHandlingFunction(ABC): @abstractmethod def straggler_cutoff_check(self, **kwargs): - """ - Determines whether it is time to end the round early. + """Determines whether it is time to end the round early. + + Args: + **kwargs: Variable length argument list. + Returns: - bool + bool: True if it is time to end the round early, False otherwise. + + Raises: + NotImplementedError: This method must be implemented by a subclass. """ raise NotImplementedError diff --git a/openfl/cryptography/ca.py b/openfl/cryptography/ca.py index ef8c28a32f..db02b8ca4d 100644 --- a/openfl/cryptography/ca.py +++ b/openfl/cryptography/ca.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Cryptography CA utilities.""" import datetime @@ -12,25 +11,30 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from cryptography.x509.base import Certificate -from cryptography.x509.base import CertificateSigningRequest +from cryptography.x509.base import Certificate, CertificateSigningRequest from cryptography.x509.extensions import ExtensionNotFound from cryptography.x509.name import Name -from cryptography.x509.oid import ExtensionOID -from cryptography.x509.oid import NameOID +from cryptography.x509.oid import ExtensionOID, NameOID + +def generate_root_cert( + days_to_expiration: int = 365) -> Tuple[RSAPrivateKey, Certificate]: + """Generate a root certificate and its corresponding private key. + + Args: + days_to_expiration (int, optional): The number of days until the + certificate expires. Defaults to 365. -def generate_root_cert(days_to_expiration: int = 365) -> Tuple[RSAPrivateKey, Certificate]: - """Generate_root_certificate.""" + Returns: + Tuple[RSAPrivateKey, Certificate]: The private key and the certificate. 
+ """ now = datetime.datetime.utcnow() expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) # Generate private key - root_private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=3072, - backend=default_backend() - ) + root_private_key = rsa.generate_private_key(public_exponent=65537, + key_size=3072, + backend=default_backend()) # Generate public key root_public_key = root_private_key.public_key() @@ -40,7 +44,8 @@ def generate_root_cert(days_to_expiration: int = 365) -> Tuple[RSAPrivateKey, Ce x509.NameAttribute(NameOID.DOMAIN_COMPONENT, u'simple'), x509.NameAttribute(NameOID.COMMON_NAME, u'Simple Root CA'), x509.NameAttribute(NameOID.ORGANIZATION_NAME, u'Simple Inc'), - x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, u'Simple Root CA'), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, + u'Simple Root CA'), ]) issuer = subject builder = builder.subject_name(subject) @@ -51,26 +56,30 @@ def generate_root_cert(days_to_expiration: int = 365) -> Tuple[RSAPrivateKey, Ce builder = builder.serial_number(int(uuid.uuid4())) builder = builder.public_key(root_public_key) builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), critical=True, + x509.BasicConstraints(ca=True, path_length=None), + critical=True, ) # Sign the CSR - certificate = builder.sign( - private_key=root_private_key, algorithm=hashes.SHA384(), - backend=default_backend() - ) + certificate = builder.sign(private_key=root_private_key, + algorithm=hashes.SHA384(), + backend=default_backend()) return root_private_key, certificate def generate_signing_csr() -> Tuple[RSAPrivateKey, CertificateSigningRequest]: - """Generate signing CSR.""" + """Generate a Certificate Signing Request (CSR) and its corresponding + private key. + + Returns: + Tuple[RSAPrivateKey, CertificateSigningRequest]: The private key and + the CSR. + """ # Generate private key - signing_private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=3072, - backend=default_backend() - ) + signing_private_key = rsa.generate_private_key(public_exponent=65537, + key_size=3072, + backend=default_backend()) builder = x509.CertificateSigningRequestBuilder() subject = x509.Name([ @@ -78,35 +87,43 @@ def generate_signing_csr() -> Tuple[RSAPrivateKey, CertificateSigningRequest]: x509.NameAttribute(NameOID.DOMAIN_COMPONENT, u'simple'), x509.NameAttribute(NameOID.COMMON_NAME, u'Simple Signing CA'), x509.NameAttribute(NameOID.ORGANIZATION_NAME, u'Simple Inc'), - x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, u'Simple Signing CA'), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, + u'Simple Signing CA'), ]) builder = builder.subject_name(subject) builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), critical=True, + x509.BasicConstraints(ca=True, path_length=None), + critical=True, ) # Sign the CSR - csr = builder.sign( - private_key=signing_private_key, algorithm=hashes.SHA384(), - backend=default_backend() - ) + csr = builder.sign(private_key=signing_private_key, + algorithm=hashes.SHA384(), + backend=default_backend()) return signing_private_key, csr -def sign_certificate(csr: CertificateSigningRequest, issuer_private_key: RSAPrivateKey, - issuer_name: Name, days_to_expiration: int = 365, +def sign_certificate(csr: CertificateSigningRequest, + issuer_private_key: RSAPrivateKey, + issuer_name: Name, + days_to_expiration: int = 365, ca: bool = False) -> Certificate: - """ - Sign the incoming CSR request. 
+ """Sign a incoming Certificate Signing Request (CSR) with the issuer's + private key. Args: - csr : Certificate Signing Request object - issuer_private_key : Root CA private key if the request is for the signing - CA; Signing CA private key otherwise - issuer_name : x509 Name - days_to_expiration : int (365 days by default) - ca : Is this a certificate authority + csr (CertificateSigningRequest): The CSR to be signed. + issuer_private_key (RSAPrivateKey): Root CA private key if the request + is for the signing CA; Signing CA private key otherwise. + issuer_name (Name): The name of the issuer. + days_to_expiration (int, optional): The number of days until the + certificate expires. Defaults to 365. + ca (bool, optional): Whether the certificate is for a certificate + authority (CA). Defaults to False. + + Returns: + Certificate: The signed certificate. """ now = datetime.datetime.utcnow() expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) @@ -119,7 +136,8 @@ def sign_certificate(csr: CertificateSigningRequest, issuer_private_key: RSAPriv builder = builder.serial_number(int(uuid.uuid4())) builder = builder.public_key(csr.public_key()) builder = builder.add_extension( - x509.BasicConstraints(ca=ca, path_length=None), critical=True, + x509.BasicConstraints(ca=ca, path_length=None), + critical=True, ) try: builder = builder.add_extension( @@ -130,7 +148,7 @@ def sign_certificate(csr: CertificateSigningRequest, issuer_private_key: RSAPriv except ExtensionNotFound: pass # Might not have alternative name - signed_cert = builder.sign( - private_key=issuer_private_key, algorithm=hashes.SHA384(), backend=default_backend() - ) + signed_cert = builder.sign(private_key=issuer_private_key, + algorithm=hashes.SHA384(), + backend=default_backend()) return signed_cert diff --git a/openfl/cryptography/io.py b/openfl/cryptography/io.py index 52bfc5e95b..0c066a8111 100644 --- a/openfl/cryptography/io.py +++ b/openfl/cryptography/io.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Cryptography IO utilities.""" import os @@ -13,19 +12,17 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from cryptography.hazmat.primitives.serialization import load_pem_private_key -from cryptography.x509.base import Certificate -from cryptography.x509.base import CertificateSigningRequest +from cryptography.x509.base import Certificate, CertificateSigningRequest def read_key(path: Path) -> RSAPrivateKey: - """ - Read private key. + """Reads a private key from a file. Args: - path : Path (pathlib) + path (Path): The path to the file containing the private key. Returns: - private_key + RSAPrivateKey: The private key. """ with open(path, 'rb') as f: pem_data = f.read() @@ -37,34 +34,32 @@ def read_key(path: Path) -> RSAPrivateKey: def write_key(key: RSAPrivateKey, path: Path) -> None: - """ - Write private key. + """Writes a private key to a file. Args: - key : RSA private key object - path : Path (pathlib) - + key (RSAPrivateKey): The private key to write. + path (Path): The path to the file to write the private key to. 
""" + def key_opener(path, flags): return os.open(path, flags, mode=0o600) with open(path, 'wb', opener=key_opener) as f: - f.write(key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption())) def read_crt(path: Path) -> Certificate: - """ - Read signed TLS certificate. + """Reads a signed TLS certificate from a file. Args: - path : Path (pathlib) + path (Path): The path to the file containing the certificate. Returns: - Cryptography TLS Certificate object + Certificate: The TLS certificate. """ with open(path, 'rb') as f: pem_data = f.read() @@ -76,31 +71,26 @@ def read_crt(path: Path) -> Certificate: def write_crt(certificate: Certificate, path: Path) -> None: - """ - Write cryptography certificate / csr. + """Writes a cryptography certificate / CSR to a file. Args: - certificate : cryptography csr / certificate object - path : Path (pathlib) - - Returns: - Cryptography TLS Certificate object + certificate (Certificate): cryptography csr / certificate object to + write. + path (Path): The path to the file to write the certificate to. """ with open(path, 'wb') as f: f.write(certificate.public_bytes( - encoding=serialization.Encoding.PEM, - )) + encoding=serialization.Encoding.PEM, )) def read_csr(path: Path) -> Tuple[CertificateSigningRequest, str]: - """ - Read certificate signing request. + """Reads a Certificate Signing Request (CSR) from a file. Args: - path : Path (pathlib) + path (Path): The path to the file containing the CSR. Returns: - Cryptography CSR object + Tuple[CertificateSigningRequest, str]: The CSR and its hash. """ with open(path, 'rb') as f: pem_data = f.read() @@ -112,18 +102,17 @@ def read_csr(path: Path) -> Tuple[CertificateSigningRequest, str]: def get_csr_hash(certificate: CertificateSigningRequest) -> str: - """ - Get hash of cryptography certificate. + """Computes the SHA-384 hash of a certificate. Args: - certificate : Cryptography CSR object + certificate (CertificateSigningRequest): The certificate to compute + the hash of. Returns: - Hash of cryptography certificate / csr + str: The SHA-384 hash of the certificate. """ hasher = sha384() encoded_bytes = certificate.public_bytes( - encoding=serialization.Encoding.PEM, - ) + encoding=serialization.Encoding.PEM, ) hasher.update(encoded_bytes) return hasher.hexdigest() diff --git a/openfl/cryptography/participant.py b/openfl/cryptography/participant.py index d6e94712b1..023db253e1 100644 --- a/openfl/cryptography/participant.py +++ b/openfl/cryptography/participant.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Cryptography participant utilities.""" from typing import Tuple @@ -13,15 +12,25 @@ from cryptography.x509.oid import NameOID -def generate_csr(common_name: str, - server: bool = False) -> Tuple[RSAPrivateKey, CertificateSigningRequest]: - """Issue certificate signing request for server and client.""" +def generate_csr( + common_name: str, + server: bool = False +) -> Tuple[RSAPrivateKey, CertificateSigningRequest]: + """Issue a Certificate Signing Request (CSR) for a server or a client. + + Args: + common_name (str): The common name for the certificate. + server (bool, optional): A flag to indicate if the CSR is for a server. 
+ If False, the CSR is for a client. Defaults to False. + + Returns: + Tuple[RSAPrivateKey, CertificateSigningRequest]: The private key and + the CSR. + """ # Generate private key - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=3072, - backend=default_backend() - ) + private_key = rsa.generate_private_key(public_exponent=65537, + key_size=3072, + backend=default_backend()) builder = x509.CertificateSigningRequestBuilder() subject = x509.Name([ @@ -29,7 +38,8 @@ def generate_csr(common_name: str, ]) builder = builder.subject_name(subject) builder = builder.add_extension( - x509.BasicConstraints(ca=False, path_length=None), critical=True, + x509.BasicConstraints(ca=False, path_length=None), + critical=True, ) if server: builder = builder.add_extension( @@ -43,20 +53,16 @@ def generate_csr(common_name: str, critical=True ) - builder = builder.add_extension( - x509.KeyUsage( - digital_signature=True, - key_encipherment=True, - data_encipherment=False, - key_agreement=False, - content_commitment=False, - key_cert_sign=False, - crl_sign=False, - encipher_only=False, - decipher_only=False - ), - critical=True - ) + builder = builder.add_extension(x509.KeyUsage(digital_signature=True, + key_encipherment=True, + data_encipherment=False, + key_agreement=False, + content_commitment=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False), + critical=True) builder = builder.add_extension( x509.SubjectAlternativeName([x509.DNSName(common_name)]), @@ -64,9 +70,8 @@ def generate_csr(common_name: str, ) # Sign the CSR - csr = builder.sign( - private_key=private_key, algorithm=hashes.SHA384(), - backend=default_backend() - ) + csr = builder.sign(private_key=private_key, + algorithm=hashes.SHA384(), + backend=default_backend()) return private_key, csr diff --git a/openfl/databases/__init__.py b/openfl/databases/__init__.py index 3152025247..3a0fef4a3a 100644 --- a/openfl/databases/__init__.py +++ b/openfl/databases/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Databases package.""" from .tensor_db import TensorDB diff --git a/openfl/databases/tensor_db.py b/openfl/databases/tensor_db.py index 0045569d6a..e15ac2b459 100644 --- a/openfl/databases/tensor_db.py +++ b/openfl/databases/tensor_db.py @@ -1,35 +1,40 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """TensorDB Module.""" from threading import Lock -from typing import Dict -from typing import Iterator -from typing import Optional from types import MethodType +from typing import Dict, Iterator, Optional import numpy as np import pandas as pd +from openfl.databases.utilities import ( + ROUND_PLACEHOLDER, + _retrieve, + _search, + _store, +) from openfl.interface.aggregation_functions import AggregationFunction -from openfl.utilities import change_tags -from openfl.utilities import LocalTensor -from openfl.utilities import TensorKey -from openfl.databases.utilities import _search, _store, _retrieve, ROUND_PLACEHOLDER +from openfl.utilities import LocalTensor, TensorKey, change_tags class TensorDB: - """ - The TensorDB stores a tensor key and the data that it corresponds to. + """The TensorDB stores a tensor key and the data that it corresponds to. - It is built on top of a pandas dataframe - for it's easy insertion, retreival and aggregation capabilities. Each - collaborator and aggregator has its own TensorDB. 
+ It is built on top of a pandas dataframe for it's easy insertion, retreival + and aggregation capabilities. Each collaborator and aggregator has its own + TensorDB. + + Attributes: + tensor_db: A pandas DataFrame that stores the tensor key and the data + that it corresponds to. + mutex: A threading Lock object used to ensure thread-safe operations + on the tensor_db Dataframe. """ def __init__(self) -> None: - """Initialize.""" + """Initializes a new instance of the TensorDB class.""" types_dict = { 'tensor_name': 'string', 'origin': 'string', @@ -38,16 +43,17 @@ def __init__(self) -> None: 'tags': 'object', 'nparray': 'object' } - self.tensor_db = pd.DataFrame( - {col: pd.Series(dtype=dtype) for col, dtype in types_dict.items()} - ) + self.tensor_db = pd.DataFrame({ + col: pd.Series(dtype=dtype) + for col, dtype in types_dict.items() + }) self._bind_convenience_methods() self.mutex = Lock() def _bind_convenience_methods(self): - # Bind convenience methods for TensorDB dataframe - # to make storage, retrieval, and search easier + """Bind convenience methods for the TensorDB dataframe to make storage, + retrieval, and search easier.""" if not hasattr(self.tensor_db, 'store'): self.tensor_db.store = MethodType(_store, self.tensor_db) if not hasattr(self.tensor_db, 'retrieve'): @@ -56,33 +62,51 @@ def _bind_convenience_methods(self): self.tensor_db.search = MethodType(_search, self.tensor_db) def __repr__(self) -> str: - """Representation of the object.""" + """Returns the string representation of the TensorDB object. + + Returns: + content (str): The string representation of the TensorDB object. + """ with pd.option_context('display.max_rows', None): - content = self.tensor_db[['tensor_name', 'origin', 'round', 'report', 'tags']] + content = self.tensor_db[[ + 'tensor_name', 'origin', 'round', 'report', 'tags' + ]] return f'TensorDB contents:\n{content}' def __str__(self) -> str: - """Printable string representation.""" + """Returns the string representation of the TensorDB object. + + Returns: + __repr__ (str): The string representation of the TensorDB object. + """ return self.__repr__() def clean_up(self, remove_older_than: int = 1) -> None: - """Remove old entries from database preventing the db from becoming too large and slow.""" + """Removes old entries from the database to prevent it from becoming + too large and slow. + + Args: + remove_older_than (int, optional): Entries older than this number + of rounds are removed. Defaults to 1. + """ if remove_older_than < 0: # Getting a negative argument calls off cleaning return current_round = self.tensor_db['round'].astype(int).max() if current_round == ROUND_PLACEHOLDER: - current_round = np.sort(self.tensor_db['round'].astype(int).unique())[-2] + current_round = np.sort( + self.tensor_db['round'].astype(int).unique())[-2] self.tensor_db = self.tensor_db[ (self.tensor_db['round'].astype(int) > current_round - remove_older_than) - | self.tensor_db['report'] - ].reset_index(drop=True) + | self.tensor_db['report']].reset_index(drop=True) - def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None: - """Insert tensor into TensorDB (dataframe). + def cache_tensor(self, tensor_key_dict: Dict[TensorKey, + np.ndarray]) -> None: + """Insert a tensor into TensorDB (dataframe). Args: - tensor_key_dict: The Tensor Key + tensor_key_dict (Dict[TensorKey, np.ndarray]): A dictionary where + the key is a TensorKey and the value is a numpy array. 
Returns: None @@ -99,16 +123,19 @@ def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None: ) ) - self.tensor_db = pd.concat( - [self.tensor_db, *entries_to_add], ignore_index=True - ) + self.tensor_db = pd.concat([self.tensor_db, *entries_to_add], + ignore_index=True) - def get_tensor_from_cache(self, tensor_key: TensorKey) -> Optional[np.ndarray]: - """ - Perform a lookup of the tensor_key in the TensorDB. + def get_tensor_from_cache(self, + tensor_key: TensorKey) -> Optional[np.ndarray]: + """Perform a lookup of the tensor_key in the TensorDB. + + Args: + tensor_key (TensorKey): The key of the tensor to look up. - Returns the nparray if it is available - Otherwise, it returns 'None' + Returns: + Optional[np.ndarray]: The numpy array if it is available. + Otherwise, returns None. """ tensor_name, origin, fl_round, report, tags = tensor_key @@ -123,27 +150,27 @@ def get_tensor_from_cache(self, tensor_key: TensorKey) -> Optional[np.ndarray]: return None return np.array(df['nparray'].iloc[0]) - def get_aggregated_tensor(self, tensor_key: TensorKey, collaborator_weight_dict: dict, - aggregation_function: AggregationFunction - ) -> Optional[np.ndarray]: - """ - Determine whether all of the collaborator tensors are present for a given tensor key. - - Returns their weighted average. + def get_aggregated_tensor( + self, tensor_key: TensorKey, collaborator_weight_dict: dict, + aggregation_function: AggregationFunction) -> Optional[np.ndarray]: + """Determine whether all of the collaborator tensors are present for a + given tensor key. Args: - tensor_key: The tensor key to be resolved. If origin 'agg_uuid' is - present, can be returned directly. Otherwise must - compute weighted average of all collaborators - collaborator_weight_dict: List of collaborator names in federation - and their respective weights - aggregation_function: Call the underlying numpy aggregation - function. Default is just the weighted - average. - Returns: - weighted_nparray if all collaborator values are present - None if not all values are present + tensor_key (TensorKey): The tensor key to be resolved. If origin + 'agg_uuid' is present, can be returned directly. Otherwise + must compute weighted average of all collaborators. + collaborator_weight_dict (dict): A dictionary where the keys are + collaborator names and the values are their respective weights. + aggregation_function (AggregationFunction): Call the underlying + numpy aggregation function to use to compute the weighted + average. Default is just the weighted average. + Returns: + agg_nparray Optional[np.ndarray]: weighted_nparray The weighted + average if all collaborator values are present. Otherwise, + returns None. + None: if not all values are present. 
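# --- Editor's sketch (not part of this diff) ------------------------------
# Caching and retrieving a tensor. TensorKey fields follow the unpacking used
# throughout this module: (tensor_name, origin, round_number, report, tags).
import numpy as np

from openfl.databases import TensorDB
from openfl.utilities import TensorKey

db = TensorDB()
key = TensorKey('conv1.weight', 'aggregator_uuid', 0, False, ('model',))
db.cache_tensor({key: np.zeros((3, 3))})

assert db.get_tensor_from_cache(key) is not None   # cache hit
db.clean_up(remove_older_than=1)                   # drop entries older than 1 round
# ---------------------------------------------------------------------------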
""" if len(collaborator_weight_dict) != 0: assert np.abs(1.0 - sum(collaborator_weight_dict.values())) < 0.01, ( @@ -179,10 +206,12 @@ def get_aggregated_tensor(self, tensor_key: TensorKey, collaborator_weight_dict: else: agg_tensor_dict[col] = raw_df.iloc[0] - local_tensors = [LocalTensor(col_name=col_name, - tensor=agg_tensor_dict[col_name], - weight=collaborator_weight_dict[col_name]) - for col_name in collaborator_names] + local_tensors = [ + LocalTensor(col_name=col_name, + tensor=agg_tensor_dict[col_name], + weight=collaborator_weight_dict[col_name]) + for col_name in collaborator_names + ] if hasattr(aggregation_function, '_privileged'): if aggregation_function._privileged: @@ -190,25 +219,34 @@ def get_aggregated_tensor(self, tensor_key: TensorKey, collaborator_weight_dict: self._bind_convenience_methods() agg_nparray = aggregation_function(local_tensors, self.tensor_db, - tensor_name, - fl_round, + tensor_name, fl_round, tags) self.cache_tensor({tensor_key: agg_nparray}) return np.array(agg_nparray) db_iterator = self._iterate() - agg_nparray = aggregation_function(local_tensors, - db_iterator, - tensor_name, - fl_round, - tags) + agg_nparray = aggregation_function(local_tensors, db_iterator, + tensor_name, fl_round, tags) self.cache_tensor({tensor_key: agg_nparray}) return np.array(agg_nparray) def _iterate(self, order_by: str = 'round', ascending: bool = False) -> Iterator[pd.Series]: + """Returns an iterator over the rows of the TensorDB, sorted by a + specified column. + + Args: + order_by (str, optional): The column to sort by. Defaults to + 'round'. + ascending (bool, optional): Whether to sort in ascending order. + Defaults to False. + + Returns: + Iterator[pd.Series]: An iterator over the rows of the TensorDB. + """ columns = ['round', 'nparray', 'tensor_name', 'tags'] - rows = self.tensor_db[columns].sort_values(by=order_by, ascending=ascending).iterrows() + rows = self.tensor_db[columns].sort_values( + by=order_by, ascending=ascending).iterrows() for _, row in rows: yield row diff --git a/openfl/databases/utilities/__init__.py b/openfl/databases/utilities/__init__.py index b7f4779adf..5252fb7a65 100644 --- a/openfl/databases/utilities/__init__.py +++ b/openfl/databases/utilities/__init__.py @@ -1,13 +1,7 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Database Utilities.""" -from .dataframe import _search, _store, _retrieve, ROUND_PLACEHOLDER +from .dataframe import ROUND_PLACEHOLDER, _retrieve, _search, _store -__all__ = [ - '_search', - '_store', - '_retrieve', - 'ROUND_PLACEHOLDER' -] +__all__ = ['_search', '_store', '_retrieve', 'ROUND_PLACEHOLDER'] diff --git a/openfl/databases/utilities/dataframe.py b/openfl/databases/utilities/dataframe.py index 9038fa07d3..ed45081b70 100644 --- a/openfl/databases/utilities/dataframe.py +++ b/openfl/databases/utilities/dataframe.py @@ -1,38 +1,42 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Convenience Utilities for DataFrame.""" +from typing import Optional + import numpy as np import pandas as pd -from typing import Optional ROUND_PLACEHOLDER = 1000000 -def _search(self, tensor_name: str = None, origin: str = None, - fl_round: int = None, metric: bool = None, tags: tuple = None - ) -> pd.DataFrame: - """ - Search the tensor_db dataframe based on: - - tensor_name - - origin - - fl_round - - metric - -tags - - Returns a new dataframe that matched the query - - Args: - tensor_name: The name of the tensor (or metric) to be searched - origin: 
Origin of the tensor - fl_round: Round the tensor is associated with - metric: Is the tensor a metric? - tags: Tuple of unstructured tags associated with the tensor - - Returns: - pd.DataFrame : New dataframe that matches the search query from - the tensor_db dataframe +def _search(self, + tensor_name: str = None, + origin: str = None, + fl_round: int = None, + metric: bool = None, + tags: tuple = None) -> pd.DataFrame: + """Returns a new dataframe that matched the query. + + Search the tensor_db dataframe based on: + - tensor_name + - origin + - fl_round + - metric + -tags + + Args: + tensor_name (str, optional): The name of the tensor (or metric) to be + searched. + origin (str, optional): Origin of the tensor. + fl_round (int, optional): Round the tensor is associated with. + metric (bool, optional): Whether the tensor is a metric. + tags (tuple, optional): Tuple of unstructured tags associated with the + tensor. + + Returns: + pd.DataFrame: New dataframe that matches the search query from the + tensor_db dataframe. """ df = pd.DataFrame() query_string = [] @@ -60,26 +64,33 @@ def _search(self, tensor_name: str = None, origin: str = None, return self -def _store(self, tensor_name: str = '_', origin: str = '_', - fl_round: int = ROUND_PLACEHOLDER, metric: bool = False, - tags: tuple = ('_',), nparray: np.array = None, +def _store(self, + tensor_name: str = '_', + origin: str = '_', + fl_round: int = ROUND_PLACEHOLDER, + metric: bool = False, + tags: tuple = ('_', ), + nparray: np.array = None, overwrite: bool = True) -> None: - """ - Convenience method to store a new tensor in the dataframe. - - Args: - tensor_name [ optional ] : The name of the tensor (or metric) to be saved - origin [ optional ] : Origin of the tensor - fl_round [ optional ] : Round the tensor is associated with - metric [ optional ] : Is the tensor a metric? - tags [ optional ] : Tuple of unstructured tags associated with the tensor - nparray [ required ] : Value to store associated with the other - included information (i.e. TensorKey info) - overwrite [ optional ] : If the tensor is already present in the dataframe - should it be overwritten? - - Returns: - None + """Convenience method to store a new tensor in the dataframe. + + Args: + tensor_name (str, optional): The name of the tensor (or metric) to be + saved. Defaults to '_'. + origin (str, optional): Origin of the tensor. Defaults to '_'. + fl_round (int, optional): Round the tensor is associated with. + Defaults to ROUND_PLACEHOLDER. + metric (bool, optional): Whether the tensor is a metric. Defaults to + False. + tags (tuple, optional): Tuple of unstructured tags associated with the + tensor. Defaults to ('_',). + nparray (np.array, optional): Value to store associated with the other + included information (i.e. TensorKey info). + overwrite (bool, optional): If the tensor is already present in the + dataframe, should it be overwritten? Defaults to True. + + Returns: + None """ if nparray is None: @@ -95,25 +106,29 @@ def _store(self, tensor_name: str = '_', origin: str = '_', idx = idx[0] else: idx = self.shape[0] - self.loc[idx] = np.array([tensor_name, origin, fl_round, metric, tags, nparray], dtype=object) + self.loc[idx] = np.array( + [tensor_name, origin, fl_round, metric, tags, nparray], dtype=object) def _retrieve(self, tensor_name: str = '_', origin: str = '_', fl_round: int = ROUND_PLACEHOLDER, metric: bool = False, tags: tuple = ('_',)) -> Optional[np.array]: - """ - Convenience method to retrieve tensor from the dataframe. 
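# --- Editor's sketch (not part of this diff) ------------------------------
# These helpers are bound onto TensorDB's internal dataframe as `store`,
# `retrieve` and `search` (see _bind_convenience_methods above), so a
# privileged aggregation function handed that dataframe can do the following.
# The metric name and values are illustrative.
import numpy as np

tensor_db.store(tensor_name='val_accuracy', origin='col1', fl_round=2,
                metric=True, nparray=np.array(0.91))
history = tensor_db.search(tensor_name='val_accuracy', metric=True)  # DataFrame
latest = tensor_db.retrieve(tensor_name='val_accuracy', origin='col1',
                            fl_round=2, metric=True)  # np.array(0.91), or None
# ---------------------------------------------------------------------------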
- - Args: - tensor_name [ optional ] : The name of the tensor (or metric) to retrieve - origin [ optional ] : Origin of the tensor - fl_round [ optional ] : Round the tensor is associated with - metric: [ optional ] : Is the tensor a metric? - tags: [ optional ] : Tuple of unstructured tags associated with the tensor - should it be overwritten? - - Returns: - Optional[ np.array ] : If there is a match, return the first row + """Convenience method to retrieve tensor from the dataframe. + + Args: + tensor_name (str, optional): The name of the tensor (or metric) to + retrieve. Defaults to '_'. + origin (str, optional): Origin of the tensor. Defaults to '_'. + fl_round (int, optional): Round the tensor is associated with. + Defaults to ROUND_PLACEHOLDER. + metric (bool, optional): Whether the tensor is a metric. Defaults to + False. + tags (tuple, optional): Tuple of unstructured tags associated with the + tensor. Defaults to ('_',). + + Returns: + Optional[np.array]: If there is a match, return the first row. + Otherwise, return None. """ df = self[(self['tensor_name'] == tensor_name) diff --git a/openfl/experimental/__init__.py b/openfl/experimental/__init__.py index a397960a9f..0e118989d6 100644 --- a/openfl/experimental/__init__.py +++ b/openfl/experimental/__init__.py @@ -1,3 +1,3 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""openfl experimental package.""" +"""Openfl experimental package.""" diff --git a/openfl/experimental/component/aggregator/aggregator.py b/openfl/experimental/component/aggregator/aggregator.py index af44cdd6d1..bfc4651711 100644 --- a/openfl/experimental/component/aggregator/aggregator.py +++ b/openfl/experimental/component/aggregator/aggregator.py @@ -20,14 +20,16 @@ class Aggregator: Args: aggregator_uuid (str): Aggregation ID. federation_uuid (str): Federation ID. - authorized_cols (list of str): The list of IDs of enrolled collaborators. - + authorized_cols (list of str): The list of IDs of enrolled + collaborators. flow (Any): Flow class. rounds_to_train (int): External loop rounds. checkpoint (bool): Whether to save checkpoint or noe (default=False). - private_attrs_callable (Callable): Function for Aggregator private attriubtes + private_attrs_callable (Callable): Function for Aggregator private + attriubtes (default=None). - private_attrs_kwargs (Dict): Arguments to call private_attrs_callable (default={}). + private_attrs_kwargs (Dict): Arguments to call private_attrs_callable + (default={}). Returns: None @@ -78,7 +80,8 @@ def __init__( self.collaborator_task_results = Event() # A queue for each task self.__collaborator_tasks_queue = { - collab: queue.Queue() for collab in self.authorized_cols + collab: queue.Queue() + for collab in self.authorized_cols } self.flow = flow @@ -86,8 +89,7 @@ def __init__( self.flow._foreach_methods = [] self.logger.info("MetaflowInterface creation.") self.flow._metaflow_interface = MetaflowInterface( - self.flow.__class__, "single_process" - ) + self.flow.__class__, "single_process") self.flow._run_id = self.flow._metaflow_interface.create_run() self.flow.runtime = FederatedRuntime() self.flow.runtime.aggregator = "aggregator" @@ -104,35 +106,28 @@ def __init__( self.__initialize_private_attributes(private_attributes_kwargs) def __initialize_private_attributes(self, kwargs: Dict) -> None: - """ - Call private_attrs_callable function set - attributes to self.__private_attrs. 
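# --- Editor's sketch (not part of this diff) ------------------------------
# The private-attributes callable is expected to return a dict of
# attribute name -> value; each entry is later set on the flow clone before
# an aggregator step runs. The names below are illustrative only.
def aggregator_private_attributes(n_test_samples=1000, **kwargs):
    # Becomes `self.test_data` inside aggregator steps of the flow.
    return {'test_data': list(range(n_test_samples))}
# ---------------------------------------------------------------------------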
- """ + """Call private_attrs_callable function set attributes to + self.__private_attrs.""" self.__private_attrs = self.__private_attrs_callable(**kwargs) def __set_attributes_to_clone(self, clone: Any) -> None: - """ - Set private_attrs to clone as attributes. - """ + """Set private_attrs to clone as attributes.""" if len(self.__private_attrs) > 0: for name, attr in self.__private_attrs.items(): setattr(clone, name, attr) - def __delete_agg_attrs_from_clone( - self, clone: Any, replace_str: str = None - ) -> None: - """ - Remove aggregator private attributes from FLSpec clone before - transition from Aggregator step to collaborator steps. - """ + def __delete_agg_attrs_from_clone(self, + clone: Any, + replace_str: str = None) -> None: + """Remove aggregator private attributes from FLSpec clone before + transition from Aggregator step to collaborator steps.""" # Update aggregator private attributes by taking latest # parameters from clone, then delete attributes from clone. if len(self.__private_attrs) > 0: for attr_name in self.__private_attrs: if hasattr(clone, attr_name): self.__private_attrs.update( - {attr_name: getattr(clone, attr_name)} - ) + {attr_name: getattr(clone, attr_name)}) if replace_str: setattr(clone, attr_name, replace_str) else: @@ -144,13 +139,11 @@ def _log_big_warning(self) -> None: f"\n{the_dragon}\nYOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS" f" NOT PROPER PKI AND " f"SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN" - f" WARNED!!!" - ) + f" WARNED!!!") @staticmethod def _get_sleep_time() -> int: - """ - Sleep 10 seconds. + """Sleep 10 seconds. Returns: sleep_time: int @@ -158,9 +151,7 @@ def _get_sleep_time() -> int: return 10 def run_flow(self) -> None: - """ - Start the execution and run flow until transition. - """ + """Start the execution and run flow until transition.""" # Start function will be the first step if any flow f_name = "start" @@ -188,44 +179,42 @@ def run_flow(self) -> None: self.logger.info( "Waiting for " + f"{len_connected_collabs}/{len_sel_collabs}" - + " collaborators to connect..." - ) + + " collaborators to connect...") elif self.tasks_sent_to_collaborators != len_sel_collabs: self.logger.info( "Waiting for " + f"{self.tasks_sent_to_collaborators}/{len_sel_collabs}" - + " to make requests for tasks..." - ) + + " to make requests for tasks...") else: # Waiting for selected collaborators to send the results. self.logger.info( "Waiting for " + f"{self.collaborators_counter}/{len_sel_collabs}" - + " collaborators to send results..." - ) + + " collaborators to send results...") time.sleep(Aggregator._get_sleep_time()) self.collaborator_task_results.clear() f_name = self.next_step if hasattr(self, "instance_snapshot"): self.flow.restore_instance_snapshot( - self.flow, list(self.instance_snapshot) - ) + self.flow, list(self.instance_snapshot)) delattr(self, "instance_snapshot") - def call_checkpoint( - self, ctx: Any, f: Callable, stream_buffer: bytes = None - ) -> None: - """ - Perform checkpoint task. + def call_checkpoint(self, + ctx: Any, + f: Callable, + stream_buffer: bytes = None) -> None: + """Perform checkpoint task. Args: - ctx (FLSpec / bytes): Collaborator FLSpec object for which checkpoint is to be - performed. - f (Callable / bytes): Collaborator Step (Function) which is to be checkpointed. - stream_buffer (bytes): Captured object for output and error (default=None). - reserved_attributes (List[str]): List of attribute names which is to be excluded - from checkpoint (default=[]). 
+ ctx (FLSpec / bytes): Collaborator FLSpec object for which + checkpoint is to be performed. + f (Callable / bytes): Collaborator Step (Function) which is to be + checkpointed. + stream_buffer (bytes): Captured object for output and error + (default=None). + reserved_attributes (List[str]): List of attribute names which is + to be excluded from checkpoint (default=[]). Returns: None @@ -242,15 +231,13 @@ def call_checkpoint( f = pickle.loads(f) if isinstance(stream_buffer, bytes): # Set stream buffer as function parameter - setattr( - f.__func__, "_stream_buffer", pickle.loads(stream_buffer) - ) + setattr(f.__func__, "_stream_buffer", + pickle.loads(stream_buffer)) checkpoint(ctx, f) def get_tasks(self, collaborator_name: str) -> Tuple: - """ - RPC called by a collaborator to determine which tasks to perform. + """RPC called by a collaborator to determine which tasks to perform. Tasks will only be sent to selected collaborators. Args: @@ -260,8 +247,8 @@ def get_tasks(self, collaborator_name: str) -> Tuple: next_step (str): Next function to be executed by collaborator clone_bytes (bytes): Function execution context for collaborator """ - # If requesting collaborator is not registered as connected collaborator, - # then register it + # If requesting collaborator is not registered as connected + # collaborator, then register it if collaborator_name not in self.connected_collaborators: self.logger.info(f"Collaborator {collaborator_name} is connected.") self.connected_collaborators.append(collaborator_name) @@ -292,14 +279,12 @@ def get_tasks(self, collaborator_name: str) -> Tuple: # Get collaborator step, and clone for requesting collaborator next_step, clone = self.__collaborator_tasks_queue[ - collaborator_name - ].get() + collaborator_name].get() self.tasks_sent_to_collaborators += 1 self.logger.info( "Sending tasks to collaborator" - + f" {collaborator_name} for round {self.current_round}..." - ) + + f" {collaborator_name} for round {self.current_round}...") return ( self.current_round, next_step, @@ -309,8 +294,7 @@ def get_tasks(self, collaborator_name: str) -> Tuple: ) def do_task(self, f_name: str) -> Any: - """ - Execute aggregator steps until transition. + """Execute aggregator steps until transition. Args: f_name (str): Aggregator step @@ -322,7 +306,8 @@ def do_task(self, f_name: str) -> Any: self.__set_attributes_to_clone(self.flow) not_at_transition_point = True - # Run a loop to execute flow steps until not_at_transition_point is False + # Run a loop to execute flow steps until not_at_transition_point + # is False while not_at_transition_point: f = getattr(self.flow, f_name) # Get the list of parameters of function f @@ -332,8 +317,7 @@ def do_task(self, f_name: str) -> Any: f() # Take the checkpoint of "end" step self.__delete_agg_attrs_from_clone( - self.flow, "Private attributes: Not Available." 
- ) + self.flow, "Private attributes: Not Available.") self.call_checkpoint(self.flow, f) self.__set_attributes_to_clone(self.flow) # Check if all rounds of external loop is executed @@ -352,26 +336,26 @@ def do_task(self, f_name: str) -> Any: selected_clones = () # If function requires arguments then it is join step of the flow if len(args) > 0: - # Check if total number of collaborators and number of selected collaborators - # are the same + # Check if total number of collaborators and number of + # selected collaborators are the same if len(self.selected_collaborators) != len(self.clones_dict): # Create list of selected collaborator clones - selected_clones = ([],) + selected_clones = ([], ) for name, clone in self.clones_dict.items(): - # Check if collaboraotr is in the list of selected collaborators + # Check if collaboraotr is in the list of selected + # collaborators if name in self.selected_collaborators: selected_clones[0].append(clone) else: - # Number of selected collaborators, and number of total collaborators - # are same - selected_clones = (list(self.clones_dict.values()),) + # Number of selected collaborators, and number of total + # collaborators are same + selected_clones = (list(self.clones_dict.values()), ) # Call the join function with selected collaborators # clones are arguments f(*selected_clones) self.__delete_agg_attrs_from_clone( - self.flow, "Private attributes: Not Available." - ) + self.flow, "Private attributes: Not Available.") # Take the checkpoint of executed step self.call_checkpoint(self.flow, f) self.__set_attributes_to_clone(self.flow) @@ -391,8 +375,7 @@ def do_task(self, f_name: str) -> Any: self.clones_dict, self.instance_snapshot, self.kwargs = temp self.selected_collaborators = getattr( - self.flow, self.kwargs["foreach"] - ) + self.flow, self.kwargs["foreach"]) else: self.kwargs = self.flow.execute_task_args[3] @@ -411,13 +394,13 @@ def send_task_results( next_step: str, clone_bytes: bytes, ) -> None: - """ - After collaborator execution, collaborator will call this function via gRPc - to send next function. + """After collaborator execution, collaborator will call this function + via gRPc to send next function. Args: collab_name (str): Collaborator name which is sending results - round_number (int): Round number for which collaborator is sending results + round_number (int): Round number for which collaborator is sending + results next_step (str): Next aggregator step in the flow clone_bytes (bytes): Collaborator FLSpec object @@ -428,13 +411,10 @@ def send_task_results( if round_number is not self.current_round: self.logger.warning( f"Collaborator {collab_name} is reporting results" - f" for the wrong round: {round_number}. Ignoring..." - ) + f" for the wrong round: {round_number}. Ignoring...") else: - self.logger.info( - f"Collaborator {collab_name} sent task results" - f" for round {round_number}." - ) + self.logger.info(f"Collaborator {collab_name} sent task results" + f" for round {round_number}.") # Unpickle the clone (FLSpec object) clone = pickle.loads(clone_bytes) # Update the clone in clones_dict dictionary @@ -449,15 +429,13 @@ def send_task_results( self.collaborator_task_results.set() # Empty tasks_sent_to_collaborators list for next time. 
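# --- Editor's sketch (not part of this diff) ------------------------------
# The collaborator side of this RPC, mirroring Collaborator.run/do_task later
# in this diff. Unpacking do_task's return value as (next_step, clone) is an
# assumption based on its Tuple annotation.
import time

def collaborator_round(collaborator):
    next_step, clone, sleep_time, time_to_quit = collaborator.get_tasks()
    if time_to_quit:
        return False                  # federation finished
    if sleep_time:
        time.sleep(sleep_time)        # aggregator not ready for this collaborator yet
        return True
    next_agg_step, updated_clone = collaborator.do_task(next_step, clone)
    collaborator.send_task_results(next_agg_step, updated_clone)
    return True
# ---------------------------------------------------------------------------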
if self.tasks_sent_to_collaborators == len( - self.selected_collaborators - ): + self.selected_collaborators): self.tasks_sent_to_collaborators = 0 - def valid_collaborator_cn_and_id( - self, cert_common_name: str, collaborator_common_name: str - ) -> bool: - """ - Determine if the collaborator certificate and ID are valid for this federation. + def valid_collaborator_cn_and_id(self, cert_common_name: str, + collaborator_common_name: str) -> bool: + """Determine if the collaborator certificate and ID are valid for this + federation. Args: cert_common_name: Common name for security certificate @@ -472,17 +450,13 @@ def valid_collaborator_cn_and_id( # FIXME: "" instead of None is just for protobuf compatibility. # Cleaner solution? if self.single_col_cert_common_name == "": - return ( - cert_common_name == collaborator_common_name - and collaborator_common_name in self.authorized_cols - ) + return (cert_common_name == collaborator_common_name + and collaborator_common_name in self.authorized_cols) # otherwise, common_name must be in whitelist and # collaborator_common_name must be in authorized_cols else: - return ( - cert_common_name == self.single_col_cert_common_name - and collaborator_common_name in self.authorized_cols - ) + return (cert_common_name == self.single_col_cert_common_name + and collaborator_common_name in self.authorized_cols) def all_quit_jobs_sent(self) -> bool: """Assert all quit jobs are sent to collaborators.""" diff --git a/openfl/experimental/component/collaborator/collaborator.py b/openfl/experimental/component/collaborator/collaborator.py index be84ffe2e8..11a6e6959e 100644 --- a/openfl/experimental/component/collaborator/collaborator.py +++ b/openfl/experimental/component/collaborator/collaborator.py @@ -54,9 +54,8 @@ def __init__( self.__initialize_private_attributes(private_attributes_kwargs) def __initialize_private_attributes(self, kwargs: Dict) -> None: - """ - Call private_attrs_callable function set - attributes to self.__private_attrs + """Call private_attrs_callable function set attributes to + self.__private_attrs. Args: kwargs (Dict): Private attributes callable function arguments @@ -67,8 +66,7 @@ def __initialize_private_attributes(self, kwargs: Dict) -> None: self.__private_attrs = self.__private_attrs_callable(**kwargs) def __set_attributes_to_clone(self, clone: Any) -> None: - """ - Set private_attrs to clone as attributes. + """Set private_attrs to clone as attributes. Args: clone (FLSpec): Clone to which private attributes are to be @@ -81,12 +79,11 @@ def __set_attributes_to_clone(self, clone: Any) -> None: for name, attr in self.__private_attrs.items(): setattr(clone, name, attr) - def __delete_agg_attrs_from_clone( - self, clone: Any, replace_str: str = None - ) -> None: - """ - Remove aggregator private attributes from FLSpec clone before - transition from Aggregator step to collaborator steps + def __delete_agg_attrs_from_clone(self, + clone: Any, + replace_str: str = None) -> None: + """Remove aggregator private attributes from FLSpec clone before + transition from Aggregator step to collaborator steps. 
Args: clone (FLSpec): Clone from which private attributes are to be @@ -101,18 +98,15 @@ def __delete_agg_attrs_from_clone( for attr_name in self.__private_attrs: if hasattr(clone, attr_name): self.__private_attrs.update( - {attr_name: getattr(clone, attr_name)} - ) + {attr_name: getattr(clone, attr_name)}) if replace_str: setattr(clone, attr_name, replace_str) else: delattr(clone, attr_name) - def call_checkpoint( - self, ctx: Any, f: Callable, stream_buffer: Any - ) -> None: - """ - Call checkpoint gRPC. + def call_checkpoint(self, ctx: Any, f: Callable, + stream_buffer: Any) -> None: + """Call checkpoint gRPC. Args: ctx (FLSpec): FLSPec object. @@ -130,8 +124,7 @@ def call_checkpoint( ) def run(self) -> None: - """ - Run the collaborator. + """Run the collaborator. Args: None @@ -153,28 +146,24 @@ def run(self) -> None: self.logger.info("End of Federation reached. Exiting...") def send_task_results(self, next_step: str, clone: Any) -> None: - """ - After collaborator is executed, send next aggregator - step to Aggregator for continue execution. + """After collaborator is executed, send next aggregator step to + Aggregator for continue execution. Args: next_step (str): Send next function to aggregator - clone (FLSpec): Updated clone object (Private attributes atr not included) + clone (FLSpec): Updated clone object (Private attributes atr not + included) Returns: None """ - self.logger.info( - f"Round {self.round_number}," - f" collaborator {self.name} is sending results..." - ) - self.client.send_task_results( - self.name, self.round_number, next_step, pickle.dumps(clone) - ) + self.logger.info(f"Round {self.round_number}," + f" collaborator {self.name} is sending results...") + self.client.send_task_results(self.name, self.round_number, next_step, + pickle.dumps(clone)) def get_tasks(self) -> Tuple: - """ - Get tasks from the aggregator. + """Get tasks from the aggregator. Args: None @@ -188,14 +177,12 @@ def get_tasks(self) -> Tuple: self.logger.info("Waiting for tasks...") temp = self.client.get_tasks(self.name) self.round_number, next_step, clone_bytes, sleep_time, time_to_quit = ( - temp - ) + temp) return next_step, pickle.loads(clone_bytes), sleep_time, time_to_quit def do_task(self, f_name: str, ctx: Any) -> Tuple: - """ - Run collaborator steps until transition. + """Run collaborator steps until transition. Args: f_name (str): Function name which is to be executed. @@ -214,8 +201,7 @@ def do_task(self, f_name: str, ctx: Any) -> Tuple: f() # Checkpoint the function self.__delete_agg_attrs_from_clone( - ctx, "Private attributes: Not Available." - ) + ctx, "Private attributes: Not Available.") self.call_checkpoint(ctx, f, f._stream_buffer) self.__set_attributes_to_clone(ctx) diff --git a/openfl/experimental/federated/plan/plan.py b/openfl/experimental/federated/plan/plan.py index 3ba1a75649..01fed0164b 100644 --- a/openfl/experimental/federated/plan/plan.py +++ b/openfl/experimental/federated/plan/plan.py @@ -50,8 +50,7 @@ def ignore_aliases(self, data): plan = Plan() plan.config = config frozen_yaml_path = Path( - f"{yaml_path.parent}/{yaml_path.stem}_{plan.hash[:8]}.yaml" - ) + f"{yaml_path.parent}/{yaml_path.stem}_{plan.hash[:8]}.yaml") if frozen_yaml_path.exists(): Plan.logger.info(f"{yaml_path.name} is already frozen") return @@ -68,8 +67,7 @@ def parse( data_config_path: Path = None, resolve=True, ): - """ - Parse the Federated Learning plan. + """Parse the Federated Learning plan. 
Args: plan_config_path (string): The filepath to the federated learning @@ -84,7 +82,8 @@ def parse( """ try: plan = Plan() - plan.config = Plan.load(plan_config_path) # load plan configuration + plan.config = Plan.load( + plan_config_path) # load plan configuration plan.name = plan_config_path.name plan.files = [plan_config_path] # collect all the plan files @@ -114,8 +113,7 @@ def parse( if SETTINGS in defaults: # override defaults with section settings defaults[SETTINGS].update( - plan.config[section][SETTINGS] - ) + plan.config[section][SETTINGS]) plan.config[section][SETTINGS] = defaults[SETTINGS] defaults.update(plan.config[section]) @@ -123,8 +121,7 @@ def parse( plan.config[section] = defaults plan.authorized_cols = Plan.load(cols_config_path).get( - "collaborators", [] - ) + "collaborators", []) if resolve: plan.resolve() @@ -148,8 +145,7 @@ def parse( @staticmethod def accept_args(cls): - """ - Determines whether a class's constructor (__init__ method) accepts + """Determines whether a class's constructor (__init__ method) accepts variable positional arguments (*args). Returns: @@ -163,8 +159,8 @@ def accept_args(cls): @staticmethod def build(template, settings, **override): - """ - Create an instance of a openfl Component or Federated DataLoader/TaskRunner. + """Create an instance of a openfl Component or Federated + DataLoader/TaskRunner. Args: template: Fully qualified class template path @@ -193,8 +189,8 @@ def build(template, settings, **override): @staticmethod def import_(template): - """ - Import an instance of a openfl Component or Federated DataLoader/TaskRunner. + """Import an instance of a openfl Component or Federated + DataLoader/TaskRunner. Args: template: Fully qualified object path @@ -245,22 +241,23 @@ def resolve(self): self.aggregator_uuid = f"aggregator_{self.federation_uuid}" self.rounds_to_train = self.config["aggregator"][SETTINGS][ - "rounds_to_train" - ] + "rounds_to_train"] if self.config["network"][SETTINGS]["agg_addr"] == AUTO: self.config["network"][SETTINGS]["agg_addr"] = getfqdn_env() if self.config["network"][SETTINGS]["agg_port"] == AUTO: self.config["network"][SETTINGS]["agg_port"] = ( - int(self.hash[:8], 16) % (60999 - 49152) + 49152 - ) + int(self.hash[:8], 16) % (60999 - 49152) + 49152) def get_aggregator(self): """Get federation aggregator.""" defaults = self.config.get( "aggregator", - {TEMPLATE: "openfl.experimental.Aggregator", SETTINGS: {}}, + { + TEMPLATE: "openfl.experimental.Aggregator", + SETTINGS: {} + }, ) defaults[SETTINGS]["aggregator_uuid"] = self.aggregator_uuid @@ -268,11 +265,9 @@ def get_aggregator(self): defaults[SETTINGS]["authorized_cols"] = self.authorized_cols private_attrs_callable, private_attrs_kwargs, private_attributes = ( - self.get_private_attr( - "aggregator" - ) - ) - defaults[SETTINGS]["private_attributes_callable"] = private_attrs_callable + self.get_private_attr("aggregator")) + defaults[SETTINGS][ + "private_attributes_callable"] = private_attrs_callable defaults[SETTINGS]["private_attributes_kwargs"] = private_attrs_kwargs defaults[SETTINGS]["private_attributes"] = private_attributes @@ -289,8 +284,7 @@ def get_aggregator(self): elif not callable(log_metric_callback): raise TypeError( f"log_metric_callback should be callable object " - f"or be import from code part, get {log_metric_callback}" - ) + f"or be import from code part, get {log_metric_callback}") defaults[SETTINGS]["log_metric_callback"] = log_metric_callback if self.aggregator_ is None: @@ -309,7 +303,10 @@ def get_collaborator( """Get 
collaborator.""" defaults = self.config.get( "collaborator", - {TEMPLATE: "openfl.experimental.Collaborator", SETTINGS: {}}, + { + TEMPLATE: "openfl.experimental.Collaborator", + SETTINGS: {} + }, ) defaults[SETTINGS]["collaborator_name"] = collaborator_name @@ -317,11 +314,9 @@ def get_collaborator( defaults[SETTINGS]["federation_uuid"] = self.federation_uuid private_attrs_callable, private_attrs_kwargs, private_attributes = ( - self.get_private_attr( - collaborator_name - ) - ) - defaults[SETTINGS]["private_attributes_callable"] = private_attrs_callable + self.get_private_attr(collaborator_name)) + defaults[SETTINGS][ + "private_attributes_callable"] = private_attrs_callable defaults[SETTINGS]["private_attributes_kwargs"] = private_attrs_kwargs defaults[SETTINGS]["private_attributes"] = private_attributes @@ -406,10 +401,13 @@ def get_server( return self.server_ def get_flow(self): - """instantiates federated flow object""" + """Instantiates federated flow object.""" defaults = self.config.get( "federated_flow", - {TEMPLATE: self.config["federated_flow"]["template"], SETTINGS: {}}, + { + TEMPLATE: self.config["federated_flow"]["template"], + SETTINGS: {} + }, ) defaults = self.import_kwargs_modules(defaults) @@ -438,7 +436,8 @@ def import_nested_settings(settings): if not inspect.isclass(attr): settings[key] = attr else: - settings = Plan.build(**value_defaults_data) + settings = Plan.build( + **value_defaults_data) except ImportError: raise ImportError(f"Cannot import {value}.") return settings @@ -462,40 +461,33 @@ def get_private_attr(self, private_attr_name=None): d = Plan.load(Path(data_yaml).absolute()) if d.get(private_attr_name, None): - callable_func = d.get(private_attr_name, {}).get( - "callable_func" - ) - private_attributes = d.get(private_attr_name, {}).get( - "private_attributes" - ) + callable_func = d.get(private_attr_name, + {}).get("callable_func") + private_attributes = d.get(private_attr_name, + {}).get("private_attributes") if callable_func and private_attributes: logger = getLogger(__name__) logger.warning( f'Warning: {private_attr_name} private attributes ' 'will be initialized via callable and ' 'attributes directly specified ' - 'will be ignored' - ) + 'will be ignored') if callable_func is not None: private_attrs_callable = { - "template": d.get(private_attr_name)["callable_func"][ - "template" - ] + "template": + d.get(private_attr_name)["callable_func"]["template"] } private_attrs_kwargs = self.import_kwargs_modules( - d.get(private_attr_name)["callable_func"] - )["settings"] + d.get(private_attr_name)["callable_func"])["settings"] if isinstance(private_attrs_callable, dict): private_attrs_callable = Plan.import_( - **private_attrs_callable - ) + **private_attrs_callable) elif private_attributes: private_attributes = Plan.import_( - d.get(private_attr_name)["private_attributes"] - ) + d.get(private_attr_name)["private_attributes"]) elif not callable(private_attrs_callable): raise TypeError( f"private_attrs_callable should be callable object " diff --git a/openfl/experimental/interface/cli/aggregator.py b/openfl/experimental/interface/cli/aggregator.py index ec307e361a..d7295948e1 100644 --- a/openfl/experimental/interface/cli/aggregator.py +++ b/openfl/experimental/interface/cli/aggregator.py @@ -74,8 +74,7 @@ def start_(plan, authorized_cols, secure): if not os.path.exists("plan/data.yaml"): logger.warning( "Aggregator private attributes are set to None as plan/data.yaml not found" - + " in workspace." 
- ) + + " in workspace.") else: import yaml from yaml.loader import SafeLoader @@ -85,8 +84,7 @@ def start_(plan, authorized_cols, secure): if data.get("aggregator", None) is None: logger.warning( "Aggregator private attributes are set to None as no aggregator" - + " attributes found in plan/data.yaml." - ) + + " attributes found in plan/data.yaml.") logger.info("🧿 Starting the Aggregator Service.") @@ -127,20 +125,16 @@ def generate_cert_request(fqdn): subject_alternative_name = f"DNS:{common_name}" file_name = f"agg_{common_name}" - echo( - f"Creating AGGREGATOR certificate key pair with following settings: " - f'CN={style(common_name, fg="red")},' - f' SAN={style(subject_alternative_name, fg="red")}' - ) + echo(f"Creating AGGREGATOR certificate key pair with following settings: " + f'CN={style(common_name, fg="red")},' + f' SAN={style(subject_alternative_name, fg="red")}') server_private_key, server_csr = generate_csr(common_name, server=True) (CERT_DIR / "server").mkdir(parents=True, exist_ok=True) - echo( - " Writing AGGREGATOR certificate key pair to: " - + style(f"{CERT_DIR}/server", fg="green") - ) + echo(" Writing AGGREGATOR certificate key pair to: " + + style(f"{CERT_DIR}/server", fg="green")) # Print csr hash before writing csr to disk csr_hash = get_csr_hash(server_csr) @@ -187,23 +181,21 @@ def certify(fqdn, silent): csr_path_absolute_path = Path(CERT_DIR / f"{cert_name}.csr").absolute() if not csr_path_absolute_path.exists(): echo( - style("Aggregator certificate signing request not found.", fg="red") + style("Aggregator certificate signing request not found.", + fg="red") + " Please run `fx aggregator generate-cert-request`" - " to generate the certificate request." - ) + " to generate the certificate request.") csr, csr_hash = read_csr(csr_path_absolute_path) # Load private signing key - private_sign_key_absolute_path = Path( - CERT_DIR / signing_key_path - ).absolute() + private_sign_key_absolute_path = Path(CERT_DIR + / signing_key_path).absolute() if not private_sign_key_absolute_path.exists(): echo( style("Signing key not found.", fg="red") + " Please run `fx workspace certify`" - " to initialize the local certificate authority." - ) + " to initialize the local certificate authority.") signing_key = read_key(private_sign_key_absolute_path) @@ -213,17 +205,12 @@ def certify(fqdn, silent): echo( style("Signing certificate not found.", fg="red") + " Please run `fx workspace certify`" - " to initialize the local certificate authority." - ) + " to initialize the local certificate authority.") signing_crt = read_crt(signing_crt_absolute_path) - echo( - "The CSR Hash for file " - + style(f"{cert_name}.csr", fg="green") - + " = " - + style(f"{csr_hash}", fg="red") - ) + echo("The CSR Hash for file " + style(f"{cert_name}.csr", fg="green") + + " = " + style(f"{csr_hash}", fg="red")) crt_path_absolute_path = Path(CERT_DIR / f"{cert_name}.crt").absolute() @@ -232,9 +219,8 @@ def certify(fqdn, silent): " Warning: manual check of certificate hashes is bypassed in silent mode." 
) echo(" Signing AGGREGATOR certificate") - signed_agg_cert = sign_certificate( - csr, signing_key, signing_crt.subject - ) + signed_agg_cert = sign_certificate(csr, signing_key, + signing_crt.subject) write_crt(signed_agg_cert, crt_path_absolute_path) else: @@ -242,14 +228,12 @@ def certify(fqdn, silent): if confirm("Do you want to sign this certificate?"): echo(" Signing AGGREGATOR certificate") - signed_agg_cert = sign_certificate( - csr, signing_key, signing_crt.subject - ) + signed_agg_cert = sign_certificate(csr, signing_key, + signing_crt.subject) write_crt(signed_agg_cert, crt_path_absolute_path) else: echo( style("Not signing certificate.", fg="red") + " Please check with this AGGREGATOR to get the correct" - " certificate for this federation." - ) + " certificate for this federation.") diff --git a/openfl/experimental/interface/cli/cli_helper.py b/openfl/experimental/interface/cli/cli_helper.py index d8ddb2bd48..3c3f896c8f 100644 --- a/openfl/experimental/interface/cli/cli_helper.py +++ b/openfl/experimental/interface/cli/cli_helper.py @@ -74,9 +74,9 @@ def inner(dir_path: Path, prefix: str = "", level=-1): yield prefix + pointer + path.name directories += 1 extension = branch if pointer == tee else space - yield from inner( - path, prefix=prefix + extension, level=level - 1 - ) + yield from inner(path, + prefix=prefix + extension, + level=level - 1) elif not limit_to_directories: yield prefix + pointer + path.name files += 1 @@ -87,7 +87,8 @@ def inner(dir_path: Path, prefix: str = "", level=-1): echo(line) if next(iterator, None): echo(f"... length_limit, {length_limit}, reached, counted:") - echo(f"\n{directories} directories" + (f", {files} files" if files else "")) + echo(f"\n{directories} directories" + + (f", {files} files" if files else "")) def copytree( @@ -116,9 +117,8 @@ def _copytree(): os.makedirs(dst, exist_ok=dirs_exist_ok) errors = [] - use_srcentry = ( - copy_function is shutil.copy2 or copy_function is shutil.copy - ) + use_srcentry = (copy_function is shutil.copy2 + or copy_function is shutil.copy) for srcentry in entries: if srcentry.name in ignored_names: @@ -136,14 +136,12 @@ def _copytree(): linkto = os.readlink(srcname) if symlinks: os.symlink(linkto, dstname) - shutil.copystat( - srcobj, dstname, follow_symlinks=not symlinks - ) + shutil.copystat(srcobj, + dstname, + follow_symlinks=not symlinks) else: - if ( - not os.path.exists(linkto) - and ignore_dangling_symlinks - ): + if (not os.path.exists(linkto) + and ignore_dangling_symlinks): continue if srcentry.is_dir(): copytree( @@ -199,7 +197,8 @@ def get_workspace_parameter(name): def check_varenv(env: str = "", args: dict = None): - """Update "args" (dictionary) with if env has a defined value in the host.""" + """Update "args" (dictionary) with if env has a defined + value in the host.""" if args is None: args = {} env_val = environ.get(env) diff --git a/openfl/experimental/interface/cli/collaborator.py b/openfl/experimental/interface/cli/collaborator.py index c5f8c924ee..51b78e0555 100644 --- a/openfl/experimental/interface/cli/collaborator.py +++ b/openfl/experimental/interface/cli/collaborator.py @@ -67,10 +67,8 @@ def start_(plan, collaborator_name, secure, data_config="plan/data.yaml"): ) if not os.path.exists(data_config): - logger.warning( - "Collaborator private attributes are set to None as" - f" {data_config} not found in workspace." 
- ) + logger.warning("Collaborator private attributes are set to None as" + f" {data_config} not found in workspace.") else: import yaml from yaml.loader import SafeLoader @@ -80,8 +78,7 @@ def start_(plan, collaborator_name, secure, data_config="plan/data.yaml"): if data.get(collaborator_name, None) is None: logger.warning( f"Collaborator private attributes are set to None as no attributes" - f" for {collaborator_name} found in {data_config}." - ) + f" for {collaborator_name} found in {data_config}.") logger.info("🧿 Starting the Collaborator Service.") @@ -108,8 +105,7 @@ def generate_cert_request_(collaborator_name, silent, skip_package): def generate_cert_request(collaborator_name, silent, skip_package): - """ - Create collaborator certificate key pair. + """Create collaborator certificate key pair. Then create a package with the CSR to send for signing. """ @@ -124,17 +120,14 @@ def generate_cert_request(collaborator_name, silent, skip_package): echo( f"Creating COLLABORATOR certificate key pair with following settings: " f'CN={style(common_name, fg="red")},' - f' SAN={style(subject_alternative_name, fg="red")}' - ) + f' SAN={style(subject_alternative_name, fg="red")}') client_private_key, client_csr = generate_csr(common_name, server=False) (CERT_DIR / "client").mkdir(parents=True, exist_ok=True) - echo( - " Moving COLLABORATOR certificate to: " - + style(f"{CERT_DIR}/{file_name}", fg="green") - ) + echo(" Moving COLLABORATOR certificate to: " + + style(f"{CERT_DIR}/{file_name}", fg="green")) # Print csr hash before writing csr to disk csr_hash = get_csr_hash(client_csr) @@ -172,14 +165,10 @@ def generate_cert_request(collaborator_name, silent, skip_package): make_archive(archive_name, archive_type, tmp_dir) rmtree(tmp_dir) - echo( - f"Archive {archive_file_name} with certificate signing" - f" request created" - ) - echo( - "This file should be sent to the certificate authority" - " (typically hosted by the aggregator) for signing" - ) + echo(f"Archive {archive_file_name} with certificate signing" + f" request created") + echo("This file should be sent to the certificate authority" + " (typically hosted by the aggregator) for signing") def find_certificate_name(file_name): @@ -193,7 +182,6 @@ def register_collaborator(file_name): Args: file_name (str): The name of the collaborator in this federation - """ from os.path import isfile from pathlib import Path @@ -217,24 +205,16 @@ def register_collaborator(file_name): doc["collaborators"] = [] # Create empty list if col_name in doc["collaborators"]: - echo( - "\nCollaborator " - + style(f"{col_name}", fg="green") - + " is already in the " - + style(f"{cols_file}", fg="green") - ) + echo("\nCollaborator " + style(f"{col_name}", fg="green") + + " is already in the " + style(f"{cols_file}", fg="green")) else: doc["collaborators"].append(col_name) with open(cols_file, "w", encoding="utf-8") as f: dump(doc, f) - echo( - "\nRegistering " - + style(f"{col_name}", fg="green") - + " in " - + style(f"{cols_file}", fg="green") - ) + echo("\nRegistering " + style(f"{col_name}", fg="green") + " in " + + style(f"{cols_file}", fg="green")) @collaborator.command(name="certify") @@ -290,13 +270,11 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): csr = glob(f"{CERT_DIR}/client/*.csr")[0] else: if collaborator_name is None: - echo( - "collaborator_name can only be omitted if signing\n" - "a zipped request package.\n" - "\n" - "Example: fx collaborator certify --request-pkg " - "col_one_to_agg_cert_request.zip" - ) + 
echo("collaborator_name can only be omitted if signing\n" + "a zipped request package.\n" + "\n" + "Example: fx collaborator certify --request-pkg " + "col_one_to_agg_cert_request.zip") return csr = glob(f"{CERT_DIR}/client/col_{common_name}.csr")[0] copy(csr, CERT_DIR) @@ -311,10 +289,8 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): style( "Collaborator certificate signing request not found.", fg="red", - ) - + " Please run `fx collaborator generate-cert-request`" - " to generate the certificate request." - ) + ) + " Please run `fx collaborator generate-cert-request`" + " to generate the certificate request.") csr, csr_hash = read_csr(f"{cert_name}.csr") @@ -323,8 +299,7 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): echo( style("Signing key not found.", fg="red") + " Please run `fx workspace certify`" - " to initialize the local certificate authority." - ) + " to initialize the local certificate authority.") signing_key = read_key(CERT_DIR / signing_key_path) @@ -333,26 +308,20 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): echo( style("Signing certificate not found.", fg="red") + " Please run `fx workspace certify`" - " to initialize the local certificate authority." - ) + " to initialize the local certificate authority.") signing_crt = read_crt(CERT_DIR / signing_crt_path) - echo( - "The CSR Hash for file " - + style(f"{file_name}.csr", fg="green") - + " = " - + style(f"{csr_hash}", fg="red") - ) + echo("The CSR Hash for file " + style(f"{file_name}.csr", fg="green") + + " = " + style(f"{csr_hash}", fg="red")) if silent: echo(" Signing COLLABORATOR certificate") echo( " Warning: manual check of certificate hashes is bypassed in silent mode." ) - signed_col_cert = sign_certificate( - csr, signing_key, signing_crt.subject - ) + signed_col_cert = sign_certificate(csr, signing_key, + signing_crt.subject) write_crt(signed_col_cert, f"{cert_name}.crt") register_collaborator(CERT_DIR / "client" / f"{file_name}.crt") @@ -360,9 +329,8 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): echo("Make sure the two hashes above are the same.") if confirm("Do you want to sign this certificate?"): echo(" Signing COLLABORATOR certificate") - signed_col_cert = sign_certificate( - csr, signing_key, signing_crt.subject - ) + signed_col_cert = sign_certificate(csr, signing_key, + signing_crt.subject) write_crt(signed_col_cert, f"{cert_name}.crt") register_collaborator(CERT_DIR / "client" / f"{file_name}.crt") @@ -370,8 +338,7 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): echo( style("Not signing certificate.", fg="red") + " Please check with this collaborator to get the" - " correct certificate for this federation." 
- ) + " correct certificate for this federation.") return if len(common_name) == 0: diff --git a/openfl/experimental/interface/cli/experimental.py b/openfl/experimental/interface/cli/experimental.py index f6ed41e4d3..e985f5b820 100644 --- a/openfl/experimental/interface/cli/experimental.py +++ b/openfl/experimental/interface/cli/experimental.py @@ -18,8 +18,7 @@ def experimental(context): @experimental.command(name="deactivate") def deactivate(): """Deactivate experimental environment.""" - settings = ( - Path("~").expanduser().joinpath(".openfl", "experimental").resolve() - ) + settings = (Path("~").expanduser().joinpath(".openfl", + "experimental").resolve()) os.remove(settings) diff --git a/openfl/experimental/interface/cli/plan.py b/openfl/experimental/interface/cli/plan.py index f2ae1ede2c..7d18170cbd 100644 --- a/openfl/experimental/interface/cli/plan.py +++ b/openfl/experimental/interface/cli/plan.py @@ -51,14 +51,12 @@ def plan(context): required=False, help="The FQDN of the federation agregator", ) -def initialize( - context, plan_config, cols_config, data_config, aggregator_address -): - """ - Initialize Data Science plan. +def initialize(context, plan_config, cols_config, data_config, + aggregator_address): + """Initialize Data Science plan. - Create a protocol buffer file of the initial model weights for - the federation. + Create a protocol buffer file of the initial model weights for the + federation. """ from pathlib import Path @@ -82,18 +80,13 @@ def initialize( plan_origin = Plan.parse(plan_config, resolve=False).config - if ( - plan_origin["network"]["settings"]["agg_addr"] == "auto" - or aggregator_address - ): - plan_origin["network"]["settings"]["agg_addr"] = ( - aggregator_address or getfqdn_env() - ) - - logger.warn( - f"Patching Aggregator Addr in Plan" - f" 🠆 {plan_origin['network']['settings']['agg_addr']}" - ) + if (plan_origin["network"]["settings"]["agg_addr"] == "auto" + or aggregator_address): + plan_origin["network"]["settings"]["agg_addr"] = (aggregator_address + or getfqdn_env()) + + logger.warn(f"Patching Aggregator Addr in Plan" + f" 🠆 {plan_origin['network']['settings']['agg_addr']}") Plan.dump(plan_config, plan_origin) diff --git a/openfl/experimental/interface/cli/workspace.py b/openfl/experimental/interface/cli/workspace.py index 2aff2498bb..a37306a640 100644 --- a/openfl/experimental/interface/cli/workspace.py +++ b/openfl/experimental/interface/cli/workspace.py @@ -37,8 +37,7 @@ def create_dirs(prefix): (prefix / "data").mkdir(parents=True, exist_ok=True) # training data (prefix / "logs").mkdir(parents=True, exist_ok=True) # training logs (prefix / "save").mkdir( - parents=True, exist_ok=True - ) # model weight saves / initialization + parents=True, exist_ok=True) # model weight saves / initialization (prefix / "src").mkdir(parents=True, exist_ok=True) # model code copyfile(WORKSPACE / "workspace" / ".workspace", prefix / ".workspace") @@ -68,16 +67,16 @@ def get_templates(): from openfl.experimental.interface.cli.cli_helper import WORKSPACE return [ - d.name - for d in WORKSPACE.glob("*") + d.name for d in WORKSPACE.glob("*") if d.is_dir() and d.name not in ["__pycache__", "workspace"] ] @workspace.command(name="create") -@option( - "--prefix", required=True, help="Workspace name or path", type=ClickPath() -) +@option("--prefix", + required=True, + help="Workspace name or path", + type=ClickPath()) @option( "--custom_template", required=False, @@ -106,30 +105,21 @@ def create_(prefix, custom_template, template, notebook, template_output_dir): if 
custom_template and template and notebook: raise ValueError( "Please provide either `template`, `custom_template` or " - + "`notebook`. Not all are necessary" - ) - elif ( - (custom_template and template) - or (template and notebook) - or (custom_template and notebook) - ): - raise ValueError( - "Please provide only one of the following options: " - + "`template`, `custom_template`, or `notebook`." - ) + + "`notebook`. Not all are necessary") + elif ((custom_template and template) or (template and notebook) + or (custom_template and notebook)): + raise ValueError("Please provide only one of the following options: " + + "`template`, `custom_template`, or `notebook`.") if not (custom_template or template or notebook): - raise ValueError( - "Please provide one of the following options: " - + "`template`, `custom_template`, or `notebook`." - ) + raise ValueError("Please provide one of the following options: " + + "`template`, `custom_template`, or `notebook`.") if notebook: if not template_output_dir: raise ValueError( "Please provide output_workspace which is Destination directory to " - + "save your Jupyter Notebook workspace." - ) + + "save your Jupyter Notebook workspace.") from openfl.experimental.workspace_export import WorkspaceExport @@ -142,12 +132,10 @@ def create_(prefix, custom_template, template, notebook, template_output_dir): logger.warning( "The user should review the generated workspace for completeness " - + "before proceeding" - ) + + "before proceeding") else: - template = ( - Path(custom_template).resolve() if custom_template else template - ) + template = (Path(custom_template).resolve() + if custom_template else template) create(prefix, template) @@ -178,8 +166,7 @@ def create(prefix, template): "Participant private attributes shall be set to None as plan/data.yaml" + " was not found in the workspace.", fg="yellow", - ) - ) + )) if isfile(f"{str(prefix)}/{requirements_filename}"): check_call( @@ -193,14 +180,15 @@ def create(prefix, template): ], shell=False, ) - echo(f"Successfully installed packages from {prefix}/requirements.txt.") + echo( + f"Successfully installed packages from {prefix}/requirements.txt.") else: echo("No additional requirements for workspace defined. Skipping...") prefix_hash = _get_dir_hash(str(prefix.absolute())) with open( - OPENFL_USERDIR / f"requirements.{prefix_hash}.txt", - "w", - encoding="utf-8", + OPENFL_USERDIR / f"requirements.{prefix_hash}.txt", + "w", + encoding="utf-8", ) as f: check_call([executable, "-m", "pip", "freeze"], shell=False, stdout=f) @@ -237,8 +225,7 @@ def export_(pip_install_options: Tuple[str]): + " should review that these does not contain any information which is private and" + " not to be shared.", fg="yellow", - ) - ) + )) plan_file = Path("plan/plan.yaml").absolute() try: @@ -247,9 +234,8 @@ def export_(pip_install_options: Tuple[str]): echo(f'Plan file "{plan_file}" not found. 
No freeze performed.') # Dump requirements.txt - dump_requirements_file( - prefixes=pip_install_options, keep_original_prefixes=True - ) + dump_requirements_file(prefixes=pip_install_options, + keep_original_prefixes=True) archive_type = "zip" archive_name = basename(getcwd()) @@ -258,9 +244,8 @@ def export_(pip_install_options: Tuple[str]): # Aggregator workspace tmp_dir = join(mkdtemp(), "openfl", archive_name) - ignore = ignore_patterns( - "__pycache__", "*.crt", "*.key", "*.csr", "*.srl", "*.pem", "*.pbuf" - ) + ignore = ignore_patterns("__pycache__", "*.crt", "*.key", "*.csr", "*.srl", + "*.pem", "*.pbuf") # We only export the minimum required files to set up a collaborator makedirs(f"{tmp_dir}/save", exist_ok=True) @@ -277,10 +262,8 @@ def export_(pip_install_options: Tuple[str]): if confirm("Create a default '.workspace' file?"): copy2(WORKSPACE / "workspace" / ".workspace", tmp_dir) else: - echo( - "To proceed, you must have a '.workspace' " - "file in the current directory." - ) + echo("To proceed, you must have a '.workspace' " + "file in the current directory.") raise # Create Zip archive of directory @@ -351,29 +334,28 @@ def certify(): echo("1. Create Root CA") echo("1.1 Create Directories") - (CERT_DIR / "ca/root-ca/private").mkdir( - parents=True, exist_ok=True, mode=0o700 - ) + (CERT_DIR / "ca/root-ca/private").mkdir(parents=True, + exist_ok=True, + mode=0o700) (CERT_DIR / "ca/root-ca/db").mkdir(parents=True, exist_ok=True) echo("1.2 Create Database") - with open( - CERT_DIR / "ca/root-ca/db/root-ca.db", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/root-ca/db/root-ca.db", "w", + encoding="utf-8") as f: pass # write empty file - with open( - CERT_DIR / "ca/root-ca/db/root-ca.db.attr", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/root-ca/db/root-ca.db.attr", + "w", + encoding="utf-8") as f: pass # write empty file - with open( - CERT_DIR / "ca/root-ca/db/root-ca.crt.srl", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/root-ca/db/root-ca.crt.srl", + "w", + encoding="utf-8") as f: f.write("01") # write file with '01' - with open( - CERT_DIR / "ca/root-ca/db/root-ca.crl.srl", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/root-ca/db/root-ca.crl.srl", + "w", + encoding="utf-8") as f: f.write("01") # write file with '01' echo("1.3 Create CA Request and Certificate") @@ -385,11 +367,7 @@ def certify(): # Write root CA certificate to disk with open(CERT_DIR / root_crt_path, "wb") as f: - f.write( - root_cert.public_bytes( - encoding=serialization.Encoding.PEM, - ) - ) + f.write(root_cert.public_bytes(encoding=serialization.Encoding.PEM, )) with open(CERT_DIR / root_key_path, "wb") as f: f.write( @@ -397,35 +375,34 @@ def certify(): encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), - ) - ) + )) echo("2. 
Create Signing Certificate") echo("2.1 Create Directories") - (CERT_DIR / "ca/signing-ca/private").mkdir( - parents=True, exist_ok=True, mode=0o700 - ) + (CERT_DIR / "ca/signing-ca/private").mkdir(parents=True, + exist_ok=True, + mode=0o700) (CERT_DIR / "ca/signing-ca/db").mkdir(parents=True, exist_ok=True) echo("2.2 Create Database") - with open( - CERT_DIR / "ca/signing-ca/db/signing-ca.db", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/signing-ca/db/signing-ca.db", + "w", + encoding="utf-8") as f: pass # write empty file - with open( - CERT_DIR / "ca/signing-ca/db/signing-ca.db.attr", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/signing-ca/db/signing-ca.db.attr", + "w", + encoding="utf-8") as f: pass # write empty file - with open( - CERT_DIR / "ca/signing-ca/db/signing-ca.crt.srl", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/signing-ca/db/signing-ca.crt.srl", + "w", + encoding="utf-8") as f: f.write("01") # write file with '01' - with open( - CERT_DIR / "ca/signing-ca/db/signing-ca.crl.srl", "w", encoding="utf-8" - ) as f: + with open(CERT_DIR / "ca/signing-ca/db/signing-ca.crl.srl", + "w", + encoding="utf-8") as f: f.write("01") # write file with '01' echo("2.3 Create Signing Certificate CSR") @@ -438,11 +415,8 @@ def certify(): # Write Signing CA CSR to disk with open(CERT_DIR / signing_csr_path, "wb") as f: - f.write( - signing_csr.public_bytes( - encoding=serialization.Encoding.PEM, - ) - ) + f.write(signing_csr.public_bytes( + encoding=serialization.Encoding.PEM, )) with open(CERT_DIR / signing_key_path, "wb") as f: f.write( @@ -450,21 +424,18 @@ def certify(): encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), - ) - ) + )) echo("2.4 Sign Signing Certificate CSR") - signing_cert = sign_certificate( - signing_csr, root_private_key, root_cert.subject, ca=True - ) + signing_cert = sign_certificate(signing_csr, + root_private_key, + root_cert.subject, + ca=True) with open(CERT_DIR / signing_crt_path, "wb") as f: f.write( - signing_cert.public_bytes( - encoding=serialization.Encoding.PEM, - ) - ) + signing_cert.public_bytes(encoding=serialization.Encoding.PEM, )) echo("3 Create Certificate Chain") @@ -506,13 +477,14 @@ def _get_dir_hash(path): def apply_template_plan(prefix, template): """Copy plan file from template folder. - This function unfolds default values from template plan configuration - and writes the configuration to the current workspace. + This function unfolds default values from template plan configuration and + writes the configuration to the current workspace. """ from openfl.experimental.federated.plan import Plan from openfl.experimental.interface.cli.cli_helper import WORKSPACE - # Use the specified template if it's a Path, otherwise use WORKSPACE/template + # Use the specified template if it's a Path, otherwise use + # WORKSPACE/template source = template if isinstance(template, Path) else WORKSPACE / template template_plan = Plan.parse(source / "plan" / "plan.yaml") diff --git a/openfl/experimental/interface/fl_spec.py b/openfl/experimental/interface/fl_spec.py index 74ea7415af..ee7be0c986 100644 --- a/openfl/experimental/interface/fl_spec.py +++ b/openfl/experimental/interface/fl_spec.py @@ -26,31 +26,48 @@ class FLSpec: _initial_state = None def __init__(self, checkpoint: bool = False): + """Initializes the FLSpec object. + + Args: + checkpoint (bool, optional): Determines whether to checkpoint or + not. Defaults to False. 
+ """ self._foreach_methods = [] self._checkpoint = checkpoint @classmethod def _create_clones(cls, instance: Type[FLSpec], names: List[str]) -> None: - """Creates clones for instance for each collaborator in names""" + """Creates clones for instance for each collaborator in names. + + Args: + instance (Type[FLSpec]): The instance to be cloned. + names (List[str]): The list of names for the clones. + """ cls._clones = {name: deepcopy(instance) for name in names} @classmethod def _reset_clones(cls): - """Reset clones""" - cls._clones = {} + """Resets the clones of the class.""" + + cls._clones = [] @classmethod def save_initial_state(cls, instance: Type[FLSpec]) -> None: - """Save initial state of instance before executing the flow""" + """Saves the initial state of an instance before executing the flow. + + Args: + instance (Type[FLSpec]): The instance whose initial state is to be + saved. + """ cls._initial_state = deepcopy(instance) def run(self) -> None: - """Starts the execution of the flow""" + """Starts the execution of the flow.""" + # Submit flow to Runtime if str(self._runtime) == "LocalRuntime": self._metaflow_interface = MetaflowInterface( - self.__class__, self.runtime.backend - ) + self.__class__, self.runtime.backend) self._run_id = self._metaflow_interface.create_run() # Initialize aggregator private attributes self.runtime.initialize_aggregator() @@ -64,7 +81,8 @@ def run(self) -> None: try: # Execute all Participant (Aggregator & Collaborator) tasks and # retrieve the final attributes - # start step is the first task & invoked on aggregator through runtime.execute_task + # start step is the first task & invoked on aggregator through + # runtime.execute_task final_attributes = self.runtime.execute_task( self, self.start, @@ -79,8 +97,7 @@ def run(self) -> None: "\n or for more information about the original error," "\nPlease see the official Ray documentation" "\nhttps://docs.ray.io/en/releases-2.2.0/ray-core/\ - objects/serialization.html" - ) + objects/serialization.html") raise SerializationError(str(e) + msg) else: raise e @@ -93,25 +110,38 @@ def run(self) -> None: @property def runtime(self) -> Type[Runtime]: - """Returns flow runtime""" + """Returns flow runtime. + + Returns: + Type[Runtime]: The runtime of the flow. + """ return self._runtime @runtime.setter def runtime(self, runtime: Type[Runtime]) -> None: - """Sets flow runtime""" + """Sets flow runtime. + + Args: + runtime (Type[Runtime]): The runtime to be set. + + Raises: + TypeError: If the provided runtime is not a valid OpenFL Runtime. + """ if isinstance(runtime, Runtime): self._runtime = runtime else: raise TypeError(f"{runtime} is not a valid OpenFL Runtime") def _capture_instance_snapshot(self, kwargs): - """ - Takes backup of self before exclude or include filtering + """Takes backup of self before exclude or include filtering. Args: kwargs: Key word arguments originally passed to the next function. If include or exclude are in the kwargs, the state of the - aggregator needs to be retained + aggregator needs to be retained. + + Returns: + return_objs (list): A list of return objects. """ return_objs = [] if "exclude" in kwargs or "include" in kwargs: @@ -119,15 +149,17 @@ def _capture_instance_snapshot(self, kwargs): return_objs.append(backup) return return_objs - def _is_at_transition_point( - self, f: Callable, parent_func: Callable - ) -> bool: - """ - Has the collaborator finished its current sequence? 
+ def _is_at_transition_point(self, f: Callable, + parent_func: Callable) -> bool: + """Determines if the collaborator has finished its current sequence. Args: - f: The next function to be executed - parent_func: The previous function executed + f (Callable): The next function to be executed. + parent_func (Callable): The previous function executed. + + Returns: + bool: True if the collaborator has finished its current sequence, + False otherwise. """ if parent_func.__name__ in self._foreach_methods: self._foreach_methods.append(f.__name__) @@ -139,12 +171,14 @@ def _is_at_transition_point( return True return False - def _display_transition_logs( - self, f: Callable, parent_func: Callable - ) -> None: - """ - Prints aggregator to collaborators or - collaborators to aggregator state transition logs + def _display_transition_logs(self, f: Callable, + parent_func: Callable) -> None: + """Prints aggregator to collaborators or collaborators to aggregator + state transition logs. + + Args: + f (Callable): The next function to be executed. + parent_func (Callable): The previous function executed. """ if aggregator_to_collaborator(f, parent_func): print("Sending state from aggregator to collaborators") @@ -153,31 +187,44 @@ def _display_transition_logs( print("Sending state from collaborator to aggregator") def filter_exclude_include(self, f, **kwargs): - """ - This function filters exclude/include attributes + """Filters exclude/include attributes for a given task within the flow. Args: - flspec_obj : Reference to the FLSpec (flow) object - f : The task to be executed within the flow + f (Callable): The task to be executed within the flow. + **kwargs (dict): Additional keyword arguments. These should + include: + - "foreach" (str): The attribute name that contains the list + of selected collaborators. + - "exclude" (list, optional): List of attribute names to + exclude. If an attribute name is present in this list and the + clone has this attribute, it will be filtered out. + - "include" (list, optional): List of attribute names to + include. If an attribute name is present in this list and the + clone has this attribute, it will be included. """ selected_collaborators = getattr(self, kwargs["foreach"]) for col in selected_collaborators: clone = FLSpec._clones[col] clone.input = col - if ( - "exclude" in kwargs and hasattr(clone, kwargs["exclude"][0]) - ) or ("include" in kwargs and hasattr(clone, kwargs["include"][0])): + if ("exclude" in kwargs and hasattr(clone, kwargs["exclude"][0]) + ) or ("include" in kwargs + and hasattr(clone, kwargs["include"][0])): filter_attributes(clone, f, **kwargs) artifacts_iter, _ = generate_artifacts(ctx=self) for name, attr in artifacts_iter(): setattr(clone, name, deepcopy(attr)) clone._foreach_methods = self._foreach_methods - def restore_instance_snapshot( - self, ctx: FLSpec, instance_snapshot: List[FLSpec] - ): - """Restores attributes from backup (in instance snapshot) to ctx""" + def restore_instance_snapshot(self, ctx: FLSpec, + instance_snapshot: List[FLSpec]): + """Restores attributes from backup (in instance snapshot) to ctx. + + Args: + ctx (FLSpec): The context to restore the attributes to. + instance_snapshot (List[FLSpec]): The list of FLSpec instances + that serve as the backup. 
+ """ for backup in instance_snapshot: artifacts_iter, _ = generate_artifacts(ctx=backup) for name, attr in artifacts_iter(): @@ -185,9 +232,7 @@ def restore_instance_snapshot( setattr(ctx, name, attr) def get_clones(self, kwargs): - """ - Create, and prepare clones - """ + """Create, and prepare clones.""" FLSpec._reset_clones() FLSpec._create_clones(self, self.runtime.collaborators) selected_collaborators = self.__getattribute__(kwargs["foreach"]) @@ -203,8 +248,11 @@ def get_clones(self, kwargs): clone._metaflow_interface = self._metaflow_interface def next(self, f, **kwargs): - """ - Next task in the flow to execute + """Specifies the next task in the flow to execute. + + Args: + f (Callable): The next task that will be executed in the flow. + **kwargs: Additional keyword arguments. """ # Get the name and reference to the calling function parent = inspect.stack()[1][3] diff --git a/openfl/experimental/interface/participants.py b/openfl/experimental/interface/participants.py index d3c5210725..9d2d5d55fa 100644 --- a/openfl/experimental/interface/participants.py +++ b/openfl/experimental/interface/participants.py @@ -6,71 +6,101 @@ class Participant: + """Class for a participant. + + Attributes: + private_attributes (dict): The private attributes of the participant. + _name (str): The name of the participant. + """ def __init__(self, name: str = ""): + """Initializes the Participant object with an optional name. + + Args: + name (str, optional): The name of the participant. Defaults to "". + """ self.private_attributes = {} self._name = name @property def name(self): + """Returns the name of the participant. + + Returns: + str: The name of the participant. + """ return self._name @name.setter def name(self, name: str): + """Sets the name of the participant. + + Args: + name (str): The name to be set. + """ self._name = name def private_attributes(self, attrs: Dict[str, Any]) -> None: - """ - Set the private attributes of the participant. These attributes will + """Set the private attributes of the participant. These attributes will only be available within the tasks performed by the participants and will be filtered out prior to the task's state being transfered. Args: - attrs: dictionary of ATTRIBUTE_NAME (str) -> object that will be accessible - within the participant's task. + attrs (Dict[str, Any]): dictionary of ATTRIBUTE_NAME (str) -> + object that will be accessible within the participant's task. - Example: - {'train_loader' : torch.utils.data.DataLoader(...)} + Example: + {'train_loader' : torch.utils.data.DataLoader(...)} - In any task performed by this participant performed within the flow, - this attribute could be referenced with self.train_loader + In any task performed by this participant performed within the + flow, this attribute could be referenced with self.train_loader """ self.private_attributes = attrs class Collaborator(Participant): + """Class for a collaborator participant, derived from the Participant + class. + + Attributes: + name (str): Name of the collaborator. + private_attributes_callable (Callable): A function which returns + collaborator private attributes for each collaborator. + num_cpus (int): Specifies how many cores to use for the collaborator + step execution. + num_gpus (float): Specifies how many GPUs to use to accelerate the + collaborator step execution. + kwargs (dict): Parameters required to call private_attributes_callable + function. 
""" - Defines a collaborator participant - """ - - def __init__( - self, - name: str = "", - private_attributes_callable: Callable = None, - num_cpus: int = 0, - num_gpus: int = 0.0, - **kwargs - ): - """ - Create collaborator object with custom resources and a callable - function to assign private attributes - Parameters: - name (str): Name of the collaborator. default="" + def __init__(self, + name: str = "", + private_attributes_callable: Callable = None, + num_cpus: int = 0, + num_gpus: int = 0.0, + **kwargs): + """Initializes the Collaborator object. - private_attributes_callable (Callable): A function which returns collaborator - private attributes for each collaborator. In case private_attributes are not - required this can be omitted. default=None - - num_cpus (int): Specifies how many cores to use for the collaborator step exection. - This will only be used if backend is set to ray. default=0 - - num_gpus (float): Specifies how many GPUs to use to accerlerate the collaborator - step exection. This will only be used if backend is set to ray. default=0 + Create collaborator object with custom resources and a callable + function to assign private attributes. - kwargs (dict): Parameters required to call private_attributes_callable function. - The key of the dictionary must match the arguments to the private_attributes_callable. - default={} + Args: + name (str, optional): Name of the collaborator. Defaults to "". + private_attributes_callable (Callable, optional): A function which + returns collaborator private attributes for each collaborator. + In case private_attributes are not required this can be + omitted. Defaults to None. + num_cpus (int, optional): Specifies how many cores to use for the + collaborator step execution. This will only be used if backend + is set to ray. Defaults to 0. + num_gpus (float, optional): Specifies how many GPUs to use to + accelerate the collaborator step execution. This will only be + used if backend is set to ray. Defaults to 0.0. + **kwargs (dict): Parameters required to call + private_attributes_callable function. The key of the + dictionary must match the arguments to the + private_attributes_callable. Defaults to {}. """ super().__init__(name=name) self.num_cpus = num_cpus @@ -88,25 +118,28 @@ def __init__( self.private_attributes_callable = private_attributes_callable def get_name(self) -> str: - """Get collaborator name""" - return self._name + """Gets the name of the collaborator. - def initialize_private_attributes(self, private_attrs: Dict[Any, Any] = None) -> None: - """ - initialize private attributes of Collaborator object by invoking - the callable or by passing private_attrs argument + Returns: + str: The name of the collaborator. """ + return self._name + + def initialize_private_attributes(self) -> None: + """Initialize private attributes of Collaborator object by invoking the + callable specified by user.""" if self.private_attributes_callable is not None: self.private_attributes = self.private_attributes_callable( - **self.kwargs - ) + **self.kwargs) elif private_attrs: self.private_attributes = private_attrs def __set_collaborator_attrs_to_clone(self, clone: Any) -> None: - """ - Set collaborator private attributes to FLSpec clone before transitioning - from Aggregator step to collaborator steps + """Set collaborator private attributes to FLSpec clone before + transitioning from Aggregator step to collaborator steps. + + Args: + clone (Any): The clone to set attributes to. 
""" # set collaborator private attributes as # clone attributes @@ -114,22 +147,30 @@ def __set_collaborator_attrs_to_clone(self, clone: Any) -> None: setattr(clone, name, attr) def __delete_collab_attrs_from_clone(self, clone: Any) -> None: - """ - Remove collaborator private attributes from FLSpec clone before - transitioning from Collaborator step to Aggregator step + """Remove collaborator private attributes from FLSpec clone before + transitioning from Collaborator step to Aggregator step. + + Args: + clone (Any): The clone to remove attributes from. """ # Update collaborator private attributes by taking latest # parameters from clone, then delete attributes from clone. for attr_name in self.private_attributes: if hasattr(clone, attr_name): self.private_attributes.update( - {attr_name: getattr(clone, attr_name)} - ) + {attr_name: getattr(clone, attr_name)}) delattr(clone, attr_name) def execute_func(self, ctx: Any, f_name: str, callback: Callable) -> Any: - """ - Execute remote function f + """Execute remote function f. + + Args: + ctx (Any): The context to execute the function in. + f_name (str): The name of the function to execute. + callback (Callable): The callback to execute after the function. + + Returns: + Any: The result of the function execution. """ self.__set_collaborator_attrs_to_clone(ctx) @@ -141,38 +182,35 @@ def execute_func(self, ctx: Any, f_name: str, callback: Callable) -> Any: class Aggregator(Participant): - """ - Defines an aggregator participant - """ + """Class for an aggregator participant, derived from the Participant + class.""" - def __init__( - self, - name: str = "", - private_attributes_callable: Callable = None, - num_cpus: int = 0, - num_gpus: int = 0.0, - **kwargs - ): - """ - Create aggregator object with custom resources and a callable - function to assign private attributes - - Parameters: - name (str): Name of the aggregator. default="" - - private_attributes_callable (Callable): A function which returns aggregator - private attributes. In case private_attributes are not required this can be omitted. - default=None - - num_cpus (int): Specifies how many cores to use for the aggregator step exection. - This will only be used if backend is set to ray. default=0 + def __init__(self, + name: str = "", + private_attributes_callable: Callable = None, + num_cpus: int = 0, + num_gpus: int = 0.0, + **kwargs): + """Initializes the Aggregator object. - num_gpus (float): Specifies how many GPUs to use to accerlerate the aggregator - step exection. This will only be used if backend is set to ray. default=0 + Create aggregator object with custom resources and a callable + function to assign private attributes. - kwargs (dict): Parameters required to call private_attributes_callable function. - The key of the dictionary must match the arguments to the private_attributes_callable. - default={} + Args: + name (str, optional): Name of the aggregator. Defaults to "". + private_attributes_callable (Callable, optional): A function which + returns aggregator private attributes. In case + private_attributes are not required this can be omitted. + Defaults to None. + num_cpus (int, optional): Specifies how many cores to use for the + aggregator step execution. This will only be used if backend + is set to ray. Defaults to 0. + num_gpus (float, optional): Specifies how many GPUs to use to + accelerate the aggregator step execution. This will only be + used if backend is set to ray. Defaults to 0.0. 
+ **kwargs: Parameters required to call private_attributes_callable + function. The key of the dictionary must match the arguments + to the private_attributes_callable. Defaults to {}. """ super().__init__(name=name) self.num_cpus = num_cpus @@ -184,31 +222,33 @@ def __init__( else: if not callable(private_attributes_callable): raise Exception( - "private_attributes_callable parameter must be a callable" - ) + "private_attributes_callable parameter must be a callable") else: self.private_attributes_callable = private_attributes_callable def get_name(self) -> str: - """Get aggregator name""" - return self.name + """Gets the name of the aggregator. - def initialize_private_attributes(self, private_attrs: Dict[Any, Any] = None) -> None: - """ - initialize private attributes of Aggregator object by invoking - the callable or by passing private_attrs argument + Returns: + str: The name of the aggregator. """ + return self.name + + def initialize_private_attributes(self) -> None: + """Initialize private attributes of Aggregator object by invoking the + callable specified by user.""" if self.private_attributes_callable is not None: self.private_attributes = self.private_attributes_callable( - **self.kwargs - ) + **self.kwargs) elif private_attrs: self.private_attributes = private_attrs def __set_agg_attrs_to_clone(self, clone: Any) -> None: - """ - Set aggregator private attributes to FLSpec clone before transition - from Aggregator step to collaborator steps + """Set aggregator private attributes to FLSpec clone before transition + from Aggregator step to collaborator steps. + + Args: + clone (Any): The clone to set attributes to. """ # set aggregator private attributes as # clone attributes @@ -216,28 +256,36 @@ def __set_agg_attrs_to_clone(self, clone: Any) -> None: setattr(clone, name, attr) def __delete_agg_attrs_from_clone(self, clone: Any) -> None: - """ - Remove aggregator private attributes from FLSpec clone before - transition from Aggregator step to collaborator steps + """Remove aggregator private attributes from FLSpec clone before + transition from Aggregator step to collaborator steps. + + Args: + clone (Any): The clone to remove attributes from. """ # Update aggregator private attributes by taking latest # parameters from clone, then delete attributes from clone. for attr_name in self.private_attributes: if hasattr(clone, attr_name): self.private_attributes.update( - {attr_name: getattr(clone, attr_name)} - ) + {attr_name: getattr(clone, attr_name)}) delattr(clone, attr_name) - def execute_func( - self, - ctx: Any, - f_name: str, - callback: Callable, - clones: Optional[Any] = None, - ) -> Any: - """ - Execute remote function f + def execute_func(self, + ctx: Any, + f_name: str, + callback: Callable, + clones: Optional[Any] = None) -> Any: + """Executes remote function f. + + Args: + ctx (Any): The context to execute the function in. + f_name (str): The name of the function to execute. + callback (Callable): The callback to execute after the function. + clones (Optional[Any], optional): The clones to use in the + function. Defaults to None. + + Returns: + Any: The result of the function execution. 
""" self.__set_agg_attrs_to_clone(ctx) diff --git a/openfl/experimental/placement/placement.py b/openfl/experimental/placement/placement.py index f7ba1f16e2..1e9ec24397 100644 --- a/openfl/experimental/placement/placement.py +++ b/openfl/experimental/placement/placement.py @@ -8,9 +8,8 @@ def aggregator(f: Callable = None) -> Callable: - """ - Placement decorator that designates that the task will - run at the aggregator node + """Placement decorator that designates that the task will run at the + aggregator node. Usage: class MyFlow(FLSpec): @@ -20,6 +19,11 @@ def agg_task(self): ... ... + Args: + f (Callable, optional): The function to be decorated. + + Returns: + Callable: The decorated function. """ print(f'Aggregator step "{f.__name__}" registered') f.is_step = True @@ -43,9 +47,8 @@ def wrapper(*args, **kwargs): def collaborator(f: Callable = None) -> Callable: - """ - Placement decorator that designates that the task will - run at the collaborator node + """Placement decorator that designates that the task will run at the + collaborator node. Usage: class MyFlow(FLSpec): @@ -60,11 +63,15 @@ def collaborator_task(self): ... Args: - num_gpus: [Applicable for Ray backend only] + f (Callable, optional): The function to be decorated. + num_gpus (float, optional): [Applicable for Ray backend only] Defines how many GPUs will be made available to the task (Default = 0). Selecting a value < 1 (0.0-1.0] will result in sharing of GPUs between tasks. 1 >= results in exclusive GPU access for the task. + + Returns: + Callable: The decorated function. """ if f is None: return functools.partial(collaborator) diff --git a/openfl/experimental/protocols/interceptors.py b/openfl/experimental/protocols/interceptors.py index 02f9c1b6d1..5431465658 100644 --- a/openfl/experimental/protocols/interceptors.py +++ b/openfl/experimental/protocols/interceptors.py @@ -7,46 +7,40 @@ class _GenericClientInterceptor( - grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor, + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, ): def __init__(self, interceptor_function): self._fn = interceptor_function - def intercept_unary_unary(self, continuation, client_call_details, request): + def intercept_unary_unary(self, continuation, client_call_details, + request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, False - ) + client_call_details, iter((request, )), False, False) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response - def intercept_unary_stream( - self, continuation, client_call_details, request - ): + def intercept_unary_stream(self, continuation, client_call_details, + request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, True - ) + client_call_details, iter((request, )), False, True) response_it = continuation(new_details, next(new_request_iterator)) return postprocess(response_it) if postprocess else response_it - def intercept_stream_unary( - self, continuation, client_call_details, request_iterator - ): + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, False - ) + client_call_details, 
request_iterator, True, False) response = continuation(new_details, new_request_iterator) return postprocess(response) if postprocess else response - def intercept_stream_stream( - self, continuation, client_call_details, request_iterator - ): + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, True - ) + client_call_details, request_iterator, True, True) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it @@ -56,10 +50,10 @@ def _create_generic_interceptor(intercept_call): class _ClientCallDetails( - collections.namedtuple( - "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") - ), - grpc.ClientCallDetails, + collections.namedtuple( + "_ClientCallDetails", + ("method", "timeout", "metadata", "credentials")), + grpc.ClientCallDetails, ): pass @@ -77,12 +71,10 @@ def intercept_call( if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) for header, value in headers.items(): - metadata.append( - ( - header, - value, - ) - ) + metadata.append(( + header, + value, + )) client_call_details = _ClientCallDetails( client_call_details.method, client_call_details.timeout, diff --git a/openfl/experimental/protocols/utils.py b/openfl/experimental/protocols/utils.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfl/experimental/runtime/federated_runtime.py b/openfl/experimental/runtime/federated_runtime.py index bae51c8fe3..daf4756e2f 100644 --- a/openfl/experimental/runtime/federated_runtime.py +++ b/openfl/experimental/runtime/federated_runtime.py @@ -1,6 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -""" openfl.experimental.runtime package LocalRuntime class.""" +"""openfl.experimental.runtime package LocalRuntime class.""" from __future__ import annotations @@ -16,6 +16,13 @@ class FederatedRuntime(Runtime): + """Class for a federated runtime, derived from the Runtime class. + + Attributes: + aggregator (Type[Aggregator]): The aggregator participant. + collaborators (List[Type[Collaborator]]): The list of collaborator + participants. + """ def __init__( self, @@ -23,15 +30,16 @@ def __init__( collaborators: List[str] = None, **kwargs, ) -> None: - """ - Use single node to run the flow + """Initializes the FederatedRuntime object. - Args: - aggregator: Name of the aggregator. - collaborators: List of collaborator names. + Use single node to run the flow. - Returns: - None + Args: + aggregator (str, optional): Name of the aggregator. Defaults to + None. + collaborators (List[str], optional): List of collaborator names. + Defaults to None. + **kwargs: Additional keyword arguments. """ super().__init__() if aggregator is not None: @@ -42,24 +50,38 @@ def __init__( @property def aggregator(self) -> str: - """Returns name of _aggregator""" + """Returns name of _aggregator.""" return self._aggregator @aggregator.setter def aggregator(self, aggregator_name: Type[Aggregator]): - """Set LocalRuntime _aggregator""" + """Set LocalRuntime _aggregator. + + Args: + aggregator_name (Type[Aggregator]): The name of the aggregator to + set. + """ self._aggregator = aggregator_name @property def collaborators(self) -> List[str]: - """ - Return names of collaborators. Don't give direct access to private attributes + """Return names of collaborators. 
+ + Don't give direct access to private attributes. + + Returns: + List[str]: The names of the collaborators. """ return self.__collaborators @collaborators.setter def collaborators(self, collaborators: List[Type[Collaborator]]): - """Set LocalRuntime collaborators""" + """Set LocalRuntime collaborators. + + Args: + collaborators (List[Type[Collaborator]]): The list of + collaborators to set. + """ self.__collaborators = collaborators def __repr__(self): diff --git a/openfl/experimental/runtime/local_runtime.py b/openfl/experimental/runtime/local_runtime.py index 70f9956404..208fa84aff 100644 --- a/openfl/experimental/runtime/local_runtime.py +++ b/openfl/experimental/runtime/local_runtime.py @@ -1,6 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -""" openfl.experimental.runtime package LocalRuntime class.""" +"""openfl.experimental.runtime package LocalRuntime class.""" from __future__ import annotations @@ -33,9 +33,10 @@ class RayExecutor: + """Class for executing tasks using the Ray framework.""" def __init__(self): - """Create RayExecutor object""" + """Initializes the RayExecutor object.""" self.__remote_contexts = [] def ray_call_put( @@ -46,23 +47,31 @@ def ray_call_put( callback: Callable, clones: Optional[Any] = None, ) -> None: - """ - Execute f_name from inside participant (Aggregator or Collaborator) class with the context - of clone (ctx) + """Execute f_name from inside participant (Aggregator or Collaborator) + class with the context of clone (ctx). + + Args: + participant (Any): The participant (Aggregator or Collaborator) to + execute the function in. + ctx (Any): The context to execute the function in. + f_name (str): The name of the function to execute. + callback (Callable): The callback to execute after the function. + clones (Optional[Any], optional): The clones to use in the + function. Defaults to None. """ if clones is not None: self.__remote_contexts.append( - participant.execute_func.remote(ctx, f_name, callback, clones) - ) + participant.execute_func.remote(ctx, f_name, callback, clones)) else: self.__remote_contexts.append( - participant.execute_func.remote(ctx, f_name, callback) - ) + participant.execute_func.remote(ctx, f_name, callback)) def ray_call_get(self) -> List[Any]: - """ - Get remote clones and delete ray references of clone (ctx) and, - reclaim memory + """Get remote clones and delete ray references of clone (ctx) and, + reclaim memory. + + Returns: + List[Any]: The list of remote clones. """ clones = ray.get(self.__remote_contexts) del self.__remote_contexts @@ -72,29 +81,27 @@ def ray_call_get(self) -> List[Any]: def ray_group_assign(collaborators, num_actors=1): - """ - Assigns collaborators to resource groups which share a CUDA context. + """Assigns collaborators to resource groups which share a CUDA context. Args: collaborators (list): The list of collaborators. - num_actors (int, optional): Number of actors to distribute collaborators to. - Defaults to 3. + num_actors (int, optional): Number of actors to distribute + collaborators to. Defaults to 1. Returns: list: A list of GroupMember instances. """ class GroupMember: - """ - A utility class that manages the collaborator and its group. + """A utility class that manages the collaborator and its group. - This class maintains compatibility with runtime execution by assigning attributes for each - function in the Collaborator interface in conjunction with RemoteHelper. 
+ This class maintains compatibility with runtime execution by assigning + attributes for each function in the Collaborator interface in + conjunction with RemoteHelper. """ def __init__(self, collaborator_actor, collaborator): - """ - Initializes a new instance of the GroupMember class. + """Initializes a new instance of the GroupMember class. Args: collaborator_actor: The collaborator actor. @@ -103,8 +110,7 @@ def __init__(self, collaborator_actor, collaborator): from openfl.experimental.interface import Collaborator all_methods = [ - method - for method in dir(Collaborator) + method for method in dir(Collaborator) if callable(getattr(Collaborator, method)) ] external_methods = [ @@ -116,27 +122,26 @@ def __init__(self, collaborator_actor, collaborator): setattr( self, method, - RemoteHelper( - self.collaborator_actor, self.collaborator, method - ), + RemoteHelper(self.collaborator_actor, self.collaborator, + method), ) class RemoteHelper: - """ - A utility class to maintain compatibility with RayExecutor. + """A utility class to maintain compatibility with RayExecutor. - This class returns a lambda function that uses collaborator_actor.execute_from_col to run - a given function from the given collaborator. + This class returns a lambda function that uses + collaborator_actor.execute_from_col to run a given function from the + given collaborator. """ - # once ray_grouped replaces the current ray runtime this class can be replaced with a - # funtion that returns the lambda funtion, using a funtion is necesary because this is used - # in setting multiple funtions in a loop and lambda takes the reference to self.f_name and - # not the value so we need to change scope to avoid self.f_name from changing as the loop - # progresses + # once ray_grouped replaces the current ray runtime this class can be + # replaced with a function that returns the lambda function. Using a + # function is necessary because this is used to set multiple + # functions in a loop, and lambda takes the reference to self.f_name, + # not the value, so we need to change scope to avoid self.f_name + # changing as the loop progresses. def __init__(self, collaborator_actor, collaborator, f_name) -> None: - """ - Initializes a new instance of the RemoteHelper class. + """Initializes a new instance of the RemoteHelper class. Args: collaborator_actor: The collaborator actor. @@ -147,12 +152,11 @@ def __init__(self, collaborator_actor, collaborator, f_name) -> None: self.collaborator_actor = collaborator_actor self.collaborator = collaborator self.f = lambda *args, **kwargs: self.collaborator_actor.execute_from_col.remote( - self.collaborator, self.f_name, *args, **kwargs - ) + self.collaborator, self.f_name, *args, **kwargs) def remote(self, *args, **kwargs): - """ - Executes the function with the given arguments and keyword arguments. + """Executes the function with the given arguments and keyword + arguments. Args: *args: The arguments to pass to the function.
@@ -166,8 +170,8 @@ def remote(self, *args, **kwargs): collaborator_ray_refs = [] collaborators_per_group = math.ceil(len(collaborators) / num_actors) times_called = 0 - # logic to sort collaborators by gpus, if collaborators have the same number of gpu then they - # are sorted by cpu + # logic to sort collaborators by gpus, if collaborators have the same + # number of gpu then they are sorted by cpu cpu_magnitude = len(str(abs(max([i.num_cpus for i in collaborators])))) min_gpu = min([i.num_gpus for i in collaborators]) min_gpu = max(min_gpu, 0.0001) @@ -180,43 +184,30 @@ def remote(self, *args, **kwargs): for collaborator in collaborators_sorted_by_gpucpu: # initialize actor group if times_called % collaborators_per_group == 0: - max_num_cpus = max( - [ - i.num_cpus - for i in collaborators_sorted_by_gpucpu[ - times_called : times_called + collaborators_per_group - ] - ] - ) - max_num_gpus = max( - [ - i.num_gpus - for i in collaborators_sorted_by_gpucpu[ - times_called : times_called + collaborators_per_group - ] - ] - ) + max_num_cpus = max([ + i.num_cpus for i in + collaborators_sorted_by_gpucpu[times_called:times_called + + collaborators_per_group] + ]) + max_num_gpus = max([ + i.num_gpus for i in + collaborators_sorted_by_gpucpu[times_called:times_called + + collaborators_per_group] + ]) print(f"creating actor with {max_num_cpus}, {max_num_gpus}") collaborator_actor = ( - ray.remote(RayGroup) - .options( - num_cpus=max_num_cpus, num_gpus=max_num_gpus - ) # max_concurrency=max_concurrency) - .remote() - ) + ray.remote(RayGroup).options( + num_cpus=max_num_cpus, + num_gpus=max_num_gpus) # max_concurrency=max_concurrency) + .remote()) # add collaborator to actor group - initializations.append( - collaborator_actor.append.remote( - collaborator - ) - ) + initializations.append(collaborator_actor.append.remote(collaborator)) times_called += 1 # append GroupMember to output list collaborator_ray_refs.append( - GroupMember(collaborator_actor, collaborator.get_name()) - ) + GroupMember(collaborator_actor, collaborator.get_name())) # Wait for all collaborators to be created on actors ray.get(initializations) @@ -224,31 +215,27 @@ def remote(self, *args, **kwargs): class RayGroup: - """ - A Ray actor that manages a group of collaborators. + """A Ray actor that manages a group of collaborators. - This class allows for the execution of functions from a specified collaborator - using the execute_from_col method. The collaborators are stored in a dictionary - where the key is the collaborator's name. + This class allows for the execution of functions from a specified + collaborator using the execute_from_col method. The collaborators are + stored in a dictionary where the key is the collaborator's name. """ def __init__(self): - """ - Initializes a new instance of the RayGroup class. - """ + """Initializes a new instance of the RayGroup class.""" self.collaborators = {} def append( self, collaborator: Collaborator, ): - """ - Appends a new collaborator to the group. + """Appends a new collaborator to the group. Args: name (str): The name of the collaborator. - private_attributes_callable (Callable): A callable that sets the private attributes of - the collaborator. + private_attributes_callable (Callable): A callable that sets the + private attributes of the collaborator. **kwargs: Additional keyword arguments. 
""" from openfl.experimental.interface import Collaborator @@ -256,7 +243,8 @@ def append( if collaborator.private_attributes_callable is not None: self.collaborators[collaborator.name] = Collaborator( name=collaborator.name, - private_attributes_callable=collaborator.private_attributes_callable, + private_attributes_callable=collaborator. + private_attributes_callable, **collaborator.kwargs, ) elif collaborator.private_attributes is not None: @@ -264,13 +252,12 @@ def append( name=collaborator.name, **collaborator.kwargs, ) - self.collaborators[collaborator.name].initialize_private_attributes( - collaborator.private_attributes - ) + self.collaborators[ + collaborator.name].initialize_private_attributes( + collaborator.private_attributes) def execute_from_col(self, name, internal_f_name, *args, **kwargs): - """ - Executes a function from a specified collaborator. + """Executes a function from a specified collaborator. Args: name (str): The name of the collaborator. @@ -285,8 +272,7 @@ def execute_from_col(self, name, internal_f_name, *args, **kwargs): return f(*args, **kwargs) def get_collaborator(self, name): - """ - Retrieves a collaborator from the group by name. + """Retrieves a collaborator from the group by name. Args: name (str): The name of the collaborator. @@ -298,6 +284,14 @@ def get_collaborator(self, name): class LocalRuntime(Runtime): + """Class for a local runtime, derived from the Runtime class. + + Attributes: + aggregator (Type[Aggregator]): The aggregator participant. + __collaborators (dict): The collaborators, stored as a dictionary of + names to participants. + backend (str): The backend that will execute the tasks. + """ def __init__( self, @@ -306,47 +300,53 @@ def __init__( backend: str = "single_process", **kwargs, ) -> None: - """ - Use single node to run the flow + """Initializes the LocalRuntime object to run the flow on a single + node, with an optional aggregator, an optional list of collaborators, + an optional backend, and additional keyword arguments. Args: - aggregator: The aggregator instance that holds private attributes - collaborators: A list of collaborators; each with their own private attributes - backend: The backend that will execute the tasks. Available options are: - - 'single_process': (default) Executes every task within the same process - - 'ray': Executes tasks using the Ray library. We use ray - actors called RayGroups to runs tasks in their own - isolated process. Each participant is distributed - into a ray group. The RayGroups run concurrently - while participants in the group run serially. - The default is 1 RayGroup and can be changed by using - the num_actors=1 kwarg. By using more RayGroups more - concurency is allowed with the trade off being that - each RayGroup has extra memory overhead in the form - of extra CUDA CONTEXTS. - - Also the ray runtime supports GPU isolation using - Ray's 'num_gpus' argument, which can be passed in - through the collaborator placement decorator. - - Example: - @collaborator(num_gpus=1) - def some_collaborator_task(self): - ... - - - By selecting num_gpus=1, the task is guaranteed - exclusive GPU access. If the system has one GPU, - collaborator tasks will run sequentially. + aggregator (Type[Aggregator], optional): The aggregator instance + that holds private attributes. + collaborators (List[Type[Collaborator]], optional): A list of + collaborators; each with their own private attributes. + backend (str, optional): The backend that will execute the tasks. + Defaults to "single_process". 
+ Available options are: + - 'single_process': (default) Executes every task within the + same process. + - 'ray': Executes tasks using the Ray library. We use ray + actors called RayGroups to runs tasks in their own isolated + process. Each participant is distributed into a ray group. + The RayGroups run concurrently while participants in the + group run serially. + The default is 1 RayGroup and can be changed by using the + num_actors=1 kwarg. By using more RayGroups more concurency + is allowed with the trade off being that each RayGroup has + extra memory overhead in the form of extra CUDA CONTEXTS. + + Also the ray runtime supports GPU isolation using Ray's + 'num_gpus' argument, which can be passed in through the + collaborator placement decorator. + + Raises: + ValueError: If the provided backend value is not 'ray' or + 'single_process'. + + Example: + @collaborator(num_gpus=1) + def some_collaborator_task(self): + # Task implementation + ... + + By selecting num_gpus=1, the task is guaranteed exclusive GPU + access. If the system has one GPU, collaborator tasks will run + sequentially. """ super().__init__() if backend not in ["ray", "single_process"]: raise ValueError( f"Invalid 'backend' value '{backend}', accepted values are " - + "'ray', or 'single_process'" - ) + + "'ray', or 'single_process'") if backend == "ray": if not ray.is_initialized(): dh = kwargs.get("dashboard_host", "127.0.0.1") @@ -363,15 +363,30 @@ def some_collaborator_task(self): self.collaborators = self.__get_collaborator_object(collaborators) def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any: - """Get aggregator object based on localruntime backend""" + """Get aggregator object based on localruntime backend. + + If the backend is 'single_process', it returns the aggregator directly. + If the backend is 'ray', it creates a Ray actor for the aggregator + with the specified resources. + + Args: + aggregator (Type[Aggregator]): The aggregator class to instantiate. + + Returns: + Any: The aggregator object or a reference to the Ray actor + representing the aggregator. + + Raises: + ResourcesNotAvailableError: If the requested resources exceed the + available resources. + """ if aggregator.private_attributes and aggregator.private_attributes_callable: self.logger.warning( 'Warning: Aggregator private attributes ' + 'will be initialized via callable and ' + 'attributes via aggregator.private_attributes ' - + 'will be ignored' - ) + + 'will be ignored') if self.backend == "single_process": return aggregator @@ -391,27 +406,24 @@ def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any: if total_available_gpus < agg_gpus: raise ResourcesNotAvailableError( f"cannot assign more than available GPUs \ - ({agg_gpus} < {total_available_gpus})." - ) + ({agg_gpus} < {total_available_gpus}).") if total_available_cpus < agg_cpus: raise ResourcesNotAvailableError( f"cannot assign more than available CPUs \ - ({agg_cpus} < {total_available_cpus})." 
- ) + ({agg_cpus} < {total_available_cpus}).") interface_module = importlib.import_module( - "openfl.experimental.interface" - ) + "openfl.experimental.interface") aggregator_class = getattr(interface_module, "Aggregator") aggregator_actor = ray.remote(aggregator_class).options( - num_cpus=agg_cpus, num_gpus=agg_gpus - ) + num_cpus=agg_cpus, num_gpus=agg_gpus) if aggregator.private_attributes_callable is not None: aggregator_actor_ref = aggregator_actor.remote( name=aggregator.get_name(), - private_attributes_callable=aggregator.private_attributes_callable, + private_attributes_callable=aggregator. + private_attributes_callable, **aggregator.kwargs, ) elif aggregator.private_attributes is not None: @@ -420,14 +432,30 @@ def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any: **aggregator.kwargs, ) aggregator_actor_ref.initialize_private_attributes.remote( - aggregator.private_attributes - ) + aggregator.private_attributes) return aggregator_actor_ref def __get_collaborator_object(self, collaborators: List) -> Any: - """Get collaborator object based on localruntime backend""" + """Get collaborator object based on localruntime backend. + + If the backend is 'single_process', it returns the list of + collaborators directly. + If the backend is 'ray', it assigns collaborators to Ray actors using + the ray_group_assign function. + Args: + collaborators (List[Type[Collaborator]]): The list of collaborator + classes to instantiate. + + Returns: + Any: The list of collaborator objects or a list of references to + the Ray actors representing the collaborators. + + Raises: + ResourcesNotAvailableError: If the requested resources exceed the + available resources. + """ for collab in collaborators: if collab.private_attributes and collab.private_attributes_callable: self.logger.warning( @@ -442,40 +470,53 @@ def __get_collaborator_object(self, collaborators: List) -> Any: total_available_cpus = os.cpu_count() total_required_cpus = sum( - [collaborator.num_cpus for collaborator in collaborators] - ) + [collaborator.num_cpus for collaborator in collaborators]) if total_available_cpus < total_required_cpus: raise ResourcesNotAvailableError( f"cannot assign more than available CPUs \ - ({total_required_cpus} < {total_available_cpus})." - ) + ({total_required_cpus} < {total_available_cpus}).") if self.backend == "ray": collaborator_ray_refs = ray_group_assign( - collaborators, num_actors=self.num_actors - ) + collaborators, num_actors=self.num_actors) return collaborator_ray_refs @property def aggregator(self) -> str: - """Returns name of _aggregator""" + """Gets the name of the aggregator. + + Returns: + str: The name of the aggregator. + """ return self._aggregator.name @aggregator.setter def aggregator(self, aggregator: Type[Aggregator]): - """Set LocalRuntime _aggregator""" + """Set LocalRuntime _aggregator. + + Args: + aggregator (Type[Aggregator]): The aggregator to be set. + """ self._aggregator = aggregator @property def collaborators(self) -> List[str]: - """ - Return names of collaborators. Don't give direct access to private attributes + """Return names of collaborators. Don't give direct access to private + attributes. + + Returns: + List[str]: The names of the collaborators. """ return list(self.__collaborators.keys()) @collaborators.setter def collaborators(self, collaborators: List[Type[Collaborator]]): - """Set LocalRuntime collaborators""" + """Set LocalRuntime collaborators. + + Args: + collaborators (List[Type[Collaborator]]): The collaborators to be + set. 
+ """ if self.backend == "single_process": def get_collab_name(collab): @@ -492,11 +533,11 @@ def get_collab_name(collab): } def get_collaborator_kwargs(self, collaborator_name: str): - """ - Returns kwargs of collaborator + """Returns kwargs of collaborator. Args: - collaborator_name: Collaborator name for which kwargs is to be returned + collaborator_name: Collaborator name for which kwargs is to be + returned Returns: kwargs: Collaborator private_attributes_callable function name, and @@ -508,20 +549,19 @@ def get_collaborator_kwargs(self, collaborator_name: str): if collab.private_attributes_callable is not None: kwargs.update(collab.kwargs) kwargs["private_attributes_callable"] = ( - collab.private_attributes_callable.__name__ - ) + collab.private_attributes_callable.__name__) return kwargs def initialize_aggregator(self): - """initialize aggregator private attributes""" + """Initialize aggregator private attributes.""" if self.backend == "single_process": self._aggregator.initialize_private_attributes() else: ray.get(self._aggregator.initialize_private_attributes.remote()) def initialize_collaborators(self): - """initialize collaborator private attributes""" + """Initialize collaborator private attributes.""" if self.backend == "single_process": def init_private_attrs(collab): @@ -535,21 +575,32 @@ def init_private_attrs(collab): for collaborator in self.__collaborators.values(): init_private_attrs(collaborator) - def restore_instance_snapshot( - self, ctx: Type[FLSpec], instance_snapshot: List[Type[FLSpec]] - ): - """Restores attributes from backup (in instance snapshot) to ctx""" + def restore_instance_snapshot(self, ctx: Type[FLSpec], + instance_snapshot: List[Type[FLSpec]]): + """Restores attributes from backup (in instance snapshot) to context + (ctx). + + Args: + ctx (Type[FLSpec]): The context to restore the snapshot to. + instance_snapshot (List[Type[FLSpec]]): The snapshot of the + instance to be restored. + """ for backup in instance_snapshot: artifacts_iter, _ = generate_artifacts(ctx=backup) for name, attr in artifacts_iter(): if not hasattr(ctx, name): setattr(ctx, name, attr) - def execute_agg_steps( - self, ctx: Any, f_name: str, clones: Optional[Any] = None - ): - """ - Execute aggregator steps until at transition point + def execute_agg_steps(self, + ctx: Any, + f_name: str, + clones: Optional[Any] = None): + """Execute aggregator steps until at transition point. + + Args: + ctx (Any): The context in which the function is executed. + f_name (str): The name of the function to be executed. + clones (Optional[Any], optional): Clones if any. Defaults to None. """ if clones is not None: f = getattr(ctx, f_name) @@ -561,17 +612,18 @@ def execute_agg_steps( f() f, parent_func = ctx.execute_task_args[:2] - if ( - aggregator_to_collaborator(f, parent_func) - or f.__name__ == "end" - ): + if (aggregator_to_collaborator(f, parent_func) + or f.__name__ == "end"): not_at_transition_point = False f_name = f.__name__ def execute_collab_steps(self, ctx: Any, f_name: str): - """ - Execute collaborator steps until at transition point + """Execute collaborator steps until at transition point. + + Args: + ctx (Any): The context in which the function is executed. + f_name (str): The name of the function to be executed. 
""" not_at_transition_point = True while not_at_transition_point: @@ -585,14 +637,14 @@ def execute_collab_steps(self, ctx: Any, f_name: str): f_name = f.__name__ def execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs): - """ - Defines which function to be executed based on name and kwargs - Updates the arguments and executes until end is not reached + """Defines which function to be executed based on name and kwargs. + + Updates the arguments and executes until end is not reached. Args: - flspec_obj: Reference to the FLSpec (flow) object. Contains information - about task sequence, flow attributes. - f: The next task to be executed within the flow + flspec_obj: Reference to the FLSpec (flow) object. Contains + information about task sequence, flow attributes. + f: The next task to be executed within the flow. Returns: artifacts_iter: Iterator with updated sequence of values @@ -603,14 +655,14 @@ def execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs): while f.__name__ != "end": if "foreach" in kwargs: - flspec_obj = self.execute_collab_task( - flspec_obj, f, parent_func, instance_snapshot, **kwargs - ) + flspec_obj = self.execute_collab_task(flspec_obj, f, + parent_func, + instance_snapshot, + **kwargs) else: flspec_obj = self.execute_agg_task(flspec_obj, f) f, parent_func, instance_snapshot, kwargs = ( - flspec_obj.execute_task_args - ) + flspec_obj.execute_task_args) else: flspec_obj = self.execute_agg_task(flspec_obj, f) f = flspec_obj.execute_task_args[0] @@ -620,14 +672,14 @@ def execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs): return artifacts_iter() def execute_agg_task(self, flspec_obj, f): - """ - Performs execution of aggregator task + """Performs execution of aggregator task. + Args: - flspec_obj : Reference to the FLSpec (flow) object - f : The task to be executed within the flow + flspec_obj: Reference to the FLSpec (flow) object. + f: The task to be executed within the flow. Returns: - flspec_obj: updated FLSpec (flow) object + flspec_obj: updated FLSpec (flow) object. """ from openfl.experimental.interface import FLSpec @@ -652,17 +704,16 @@ def execute_agg_task(self, flspec_obj, f): flspec_obj = ray_executor.ray_call_get()[0] del ray_executor else: - aggregator.execute_func( - flspec_obj, f.__name__, self.execute_agg_steps, clones - ) + aggregator.execute_func(flspec_obj, f.__name__, + self.execute_agg_steps, clones) gc.collect() return flspec_obj - def execute_collab_task( - self, flspec_obj, f, parent_func, instance_snapshot, **kwargs - ): - """ + def execute_collab_task(self, flspec_obj, f, parent_func, + instance_snapshot, **kwargs): + """Performs execution of collaborator task. + Performs 1. Filter include/exclude 2. Set runtime, collab private attributes , metaflow_interface @@ -671,10 +722,10 @@ def execute_collab_task( 5. Execute the next function after transition Args: - flspec_obj : Reference to the FLSpec (flow) object - f : The task to be executed within the flow - parent_func : The prior task executed in the flow - instance_snapshot : A prior FLSpec state that needs to be restored + flspec_obj: Reference to the FLSpec (flow) object. + f: The task to be executed within the flow. + parent_func: The prior task executed in the flow. + instance_snapshot: A prior FLSpec state that needs to be restored. 
Returns: flspec_obj: updated FLSpec (flow) object @@ -687,9 +738,8 @@ def execute_collab_task( self.selected_collaborators = selected_collaborators # filter exclude/include attributes for clone - self.filter_exclude_include( - flspec_obj, f, selected_collaborators, **kwargs - ) + self.filter_exclude_include(flspec_obj, f, selected_collaborators, + **kwargs) if self.backend == "ray": ray_executor = RayExecutor() @@ -710,13 +760,11 @@ def execute_collab_task( collaborator = self.__collaborators[collab_name] if self.backend == "ray": - ray_executor.ray_call_put( - collaborator, clone, f.__name__, self.execute_collab_steps - ) + ray_executor.ray_call_put(collaborator, clone, f.__name__, + self.execute_collab_steps) else: - collaborator.execute_func( - clone, f.__name__, self.execute_collab_steps - ) + collaborator.execute_func(clone, f.__name__, + self.execute_collab_steps) if self.backend == "ray": clones = ray_executor.ray_call_get() @@ -735,15 +783,14 @@ def execute_collab_task( self.join_step = True return flspec_obj - def filter_exclude_include( - self, flspec_obj, f, selected_collaborators, **kwargs - ): - """ - This function filters exclude/include attributes + def filter_exclude_include(self, flspec_obj, f, selected_collaborators, + **kwargs): + """This function filters exclude/include attributes. + Args: - flspec_obj : Reference to the FLSpec (flow) object - f : The task to be executed within the flow - selected_collaborators : all collaborators + flspec_obj: Reference to the FLSpec (flow) object. + f: The task to be executed within the flow. + selected_collaborators: all collaborators. """ from openfl.experimental.interface import FLSpec @@ -751,9 +798,9 @@ def filter_exclude_include( for col in selected_collaborators: clone = FLSpec._clones[col] clone.input = col - if ( - "exclude" in kwargs and hasattr(clone, kwargs["exclude"][0]) - ) or ("include" in kwargs and hasattr(clone, kwargs["include"][0])): + if ("exclude" in kwargs and hasattr(clone, kwargs["exclude"][0]) + ) or ("include" in kwargs + and hasattr(clone, kwargs["include"][0])): filter_attributes(clone, f, **kwargs) artifacts_iter, _ = generate_artifacts(ctx=flspec_obj) for name, attr in artifacts_iter(): @@ -761,4 +808,9 @@ def filter_exclude_include( clone._foreach_methods = flspec_obj._foreach_methods def __repr__(self): + """Returns the string representation of the LocalRuntime object. + + Returns: + str: The string representation of the LocalRuntime object. + """ return "LocalRuntime" diff --git a/openfl/experimental/runtime/runtime.py b/openfl/experimental/runtime/runtime.py index a9e5a5d9e3..14b42a7bfd 100644 --- a/openfl/experimental/runtime/runtime.py +++ b/openfl/experimental/runtime/runtime.py @@ -1,6 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -""" openfl.experimental.runtime module Runtime class.""" +"""openfl.experimental.runtime module Runtime class.""" from __future__ import annotations from typing import TYPE_CHECKING @@ -14,32 +14,57 @@ class Runtime: def __init__(self): - """ - Base interface for runtimes that can run FLSpec flows + """Initializes the Runtime object. + This serves as a base interface for runtimes that can run FLSpec flows. """ pass @property def aggregator(self): - """Returns name of aggregator""" + """Returns the name of the aggregator. + + Raises: + NotImplementedError: If the method is not implemented in a + subclass. 
+ """ raise NotImplementedError @aggregator.setter def aggregator(self, aggregator: Aggregator): - """Set Runtime aggregator""" + """Sets the aggregator of the Runtime. + + Args: + aggregator (Aggregator): The aggregator to be set. + + Raises: + NotImplementedError: If the method is not implemented in a + subclass. + """ raise NotImplementedError @property def collaborators(self): - """ - Return names of collaborators. Don't give direct access to private attributes + """Return the names of the collaborators. Don't give direct access to + private attributes. + + Raises: + NotImplementedError: If the method is not implemented in a + subclass. """ raise NotImplementedError @collaborators.setter def collaborators(self, collaborators: List[Collaborator]): - """Set Runtime collaborators""" + """Sets the collaborators of the Runtime. + + Args: + collaborators (List[Collaborator]): The collaborators to be set. + + Raises: + NotImplementedError: If the method is not implemented in a + subclass. + """ raise NotImplementedError def execute_task( @@ -50,18 +75,22 @@ def execute_task( instance_snapshot: List[FLSpec] = [], **kwargs, ): - """ - Performs the execution of a task as defined by the - implementation and underlying backend (single_process, ray, etc) + """Performs the execution of a task as defined by the implementation + and underlying backend (single_process, ray, etc). Args: - flspec_obj: Reference to the FLSpec (flow) object. Contains information - about task sequence, flow attributes, that are needed to - execute a future task - f: The next task to be executed within the flow - parent_func: The prior task executed in the flow - instance_snapshot: A prior FLSpec state that needs to be restored from - (i.e. restoring aggregator state after collaborator - execution) + flspec_obj (FLSpec): Reference to the FLSpec (flow) object. + Contains information about task sequence, flow attributes, + that are needed to execute a future task. + f (Callable): The next task to be executed within the flow. + parent_func (Callable): The prior task executed in the flow. + instance_snapshot (List[FLSpec], optional): A prior FLSpec state + that needs to be restored from (i.e. restoring aggregator + state after collaborator execution). + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If the method is not implemented in a + subclass. 
""" raise NotImplementedError diff --git a/openfl/experimental/transport/grpc/aggregator_client.py b/openfl/experimental/transport/grpc/aggregator_client.py index ba04b7d629..6546a17a2e 100644 --- a/openfl/experimental/transport/grpc/aggregator_client.py +++ b/openfl/experimental/transport/grpc/aggregator_client.py @@ -30,9 +30,8 @@ def sleep(self): time.sleep(self.reconnect_interval) -class RetryOnRpcErrorClientInterceptor( - grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor -): +class RetryOnRpcErrorClientInterceptor(grpc.UnaryUnaryClientInterceptor, + grpc.StreamUnaryClientInterceptor): """Retry gRPC connection on failure.""" def __init__( @@ -44,9 +43,8 @@ def __init__( self.sleeping_policy = sleeping_policy self.status_for_retry = status_for_retry - def _intercept_call( - self, continuation, client_call_details, request_or_iterator - ): + def _intercept_call(self, continuation, client_call_details, + request_or_iterator): """Intercept the call to the gRPC server.""" while True: response = continuation(client_call_details, request_or_iterator) @@ -55,29 +53,25 @@ def _intercept_call( # If status code is not in retryable status codes self.sleeping_policy.logger.info( - f"Response code: {response.code()}" - ) - if ( - self.status_for_retry - and response.code() not in self.status_for_retry - ): + f"Response code: {response.code()}") + if (self.status_for_retry + and response.code() not in self.status_for_retry): return response self.sleeping_policy.sleep() else: return response - def intercept_unary_unary(self, continuation, client_call_details, request): + def intercept_unary_unary(self, continuation, client_call_details, + request): """Wrap intercept call for unary->unary RPC.""" return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary( - self, continuation, client_call_details, request_iterator - ): + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): """Wrap intercept call for stream->unary RPC.""" - return self._intercept_call( - continuation, client_call_details, request_iterator - ) + return self._intercept_call(continuation, client_call_details, + request_iterator) def _atomic_connection(func): @@ -140,8 +134,7 @@ def __init__( if not self.tls: self.logger.warn( - "gRPC is running on insecure channel with TLS disabled." - ) + "gRPC is running on insecure channel with TLS disabled.") self.channel = self.create_insecure_channel(self.uri) else: self.channel = self.create_tls_channel( @@ -158,25 +151,20 @@ def __init__( self.single_col_cert_common_name = single_col_cert_common_name # Adding an interceptor for RPC Errors - self.interceptors = ( - RetryOnRpcErrorClientInterceptor( - sleeping_policy=ConstantBackoff( - logger=self.logger, - reconnect_interval=int( - kwargs.get("client_reconnect_interval", 1) - ), - uri=self.uri, - ), - status_for_retry=(grpc.StatusCode.UNAVAILABLE,), + self.interceptors = (RetryOnRpcErrorClientInterceptor( + sleeping_policy=ConstantBackoff( + logger=self.logger, + reconnect_interval=int( + kwargs.get("client_reconnect_interval", 1)), + uri=self.uri, ), - ) + status_for_retry=(grpc.StatusCode.UNAVAILABLE, ), + ), ) self.stub = aggregator_pb2_grpc.AggregatorStub( - grpc.intercept_channel(self.channel, *self.interceptors) - ) + grpc.intercept_channel(self.channel, *self.interceptors)) def create_insecure_channel(self, uri): - """ - Set an insecure gRPC channel (i.e. no TLS) if desired. + """Set an insecure gRPC channel (i.e. no TLS) if desired. 
Warns user that this is not recommended. @@ -185,7 +173,6 @@ def create_insecure_channel(self, uri): Returns: An insecure gRPC channel object - """ return grpc.insecure_channel(uri, options=channel_options) @@ -197,8 +184,7 @@ def create_tls_channel( certificate, private_key, ): - """ - Set an secure gRPC channel (i.e. TLS). + """Set an secure gRPC channel (i.e. TLS). Args: uri: The uniform resource identifier fo the insecure channel @@ -247,9 +233,8 @@ def validate_response(self, reply, collaborator_name): check_equal(reply.header.sender, self.aggregator_uuid, self.logger) # check that federation id matches - check_equal( - reply.header.federation_uuid, self.federation_uuid, self.logger - ) + check_equal(reply.header.federation_uuid, self.federation_uuid, + self.logger) # check that there is aggrement on the single_col_cert_common_name check_equal( @@ -265,7 +250,8 @@ def disconnect(self): def reconnect(self): """Create a new channel with the gRPC server.""" - # channel.close() is idempotent. Call again here in case it wasn't issued previously + # channel.close() is idempotent. Call again here in case it wasn't + # issued previously self.disconnect() if not self.tls: @@ -282,14 +268,12 @@ def reconnect(self): self.logger.debug(f"Connecting to gRPC at {self.uri}") self.stub = aggregator_pb2_grpc.AggregatorStub( - grpc.intercept_channel(self.channel, *self.interceptors) - ) + grpc.intercept_channel(self.channel, *self.interceptors)) @_atomic_connection @_resend_data_on_reconnection - def send_task_results( - self, collaborator_name, round_number, next_step, clone_bytes - ): + def send_task_results(self, collaborator_name, round_number, next_step, + clone_bytes): """Send next function name to aggregator.""" self._set_header(collaborator_name) request = aggregator_pb2.TaskResultsRequest( @@ -325,9 +309,8 @@ def get_tasks(self, collaborator_name): @_atomic_connection @_resend_data_on_reconnection - def call_checkpoint( - self, collaborator_name, clone_bytes, function, stream_buffer - ): + def call_checkpoint(self, collaborator_name, clone_bytes, function, + stream_buffer): """Perform checkpoint for collaborator task.""" self._set_header(collaborator_name) diff --git a/openfl/experimental/transport/grpc/aggregator_server.py b/openfl/experimental/transport/grpc/aggregator_server.py index e85ed17e87..433b3e7187 100644 --- a/openfl/experimental/transport/grpc/aggregator_server.py +++ b/openfl/experimental/transport/grpc/aggregator_server.py @@ -20,7 +20,7 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer): - """gRPC server class for the Aggregator.""" + """GRPC server class for the Aggregator.""" def __init__( self, @@ -33,8 +33,7 @@ def __init__( private_key=None, **kwargs, ): - """ - Class initializer. + """Class initializer. Args: aggregator: The aggregator @@ -61,8 +60,7 @@ def __init__( self.logger = logging.getLogger(__name__) def validate_collaborator(self, request, context): - """ - Validate the collaborator. + """Validate the collaborator. Args: request: The gRPC message request @@ -70,17 +68,14 @@ def validate_collaborator(self, request, context): Raises: ValueError: If the collaborator or collaborator certificate is not - valid then raises error. - + valid then raises error. 
""" if self.tls: common_name = context.auth_context()["x509_common_name"][0].decode( - "utf-8" - ) + "utf-8") collaborator_common_name = request.header.sender if not self.aggregator.valid_collaborator_cn_and_id( - common_name, collaborator_common_name - ): + common_name, collaborator_common_name): # Random delay in authentication failures sleep(5 * random()) context.abort( @@ -90,8 +85,7 @@ def validate_collaborator(self, request, context): ) def get_header(self, collaborator_name): - """ - Compose and return MessageHeader. + """Compose and return MessageHeader. Args: collaborator_name : str @@ -101,21 +95,20 @@ def get_header(self, collaborator_name): sender=self.aggregator.uuid, receiver=collaborator_name, federation_uuid=self.aggregator.federation_uuid, - single_col_cert_common_name=self.aggregator.single_col_cert_common_name, + single_col_cert_common_name=self.aggregator. + single_col_cert_common_name, ) def check_request(self, request): - """ - Validate request header matches expected values. + """Validate request header matches expected values. Args: request : protobuf Request sent from a collaborator that requires validation """ # TODO improve this check. the sender name could be spoofed - check_is_in( - request.header.sender, self.aggregator.authorized_cols, self.logger - ) + check_is_in(request.header.sender, self.aggregator.authorized_cols, + self.logger) # check that the message is for me check_equal(request.header.receiver, self.aggregator.uuid, self.logger) @@ -135,32 +128,28 @@ def check_request(self, request): ) def SendTaskResults(self, request, context): # NOQA:N802 - """ - . + """. Args: request: The gRPC message request context: The gRPC context - """ self.validate_collaborator(request, context) self.check_request(request) collaborator_name = request.header.sender - round_number = (request.round_number,) - next_step = (request.next_step,) + round_number = (request.round_number, ) + next_step = (request.next_step, ) execution_environment = request.execution_environment - _ = self.aggregator.send_task_results( - collaborator_name, round_number[0], next_step, execution_environment - ) + _ = self.aggregator.send_task_results(collaborator_name, + round_number[0], next_step, + execution_environment) return aggregator_pb2.TaskResultsResponse( - header=self.get_header(collaborator_name) - ) + header=self.get_header(collaborator_name)) def GetTasks(self, request, context): # NOQA:N802 - """ - Request a job from aggregator. + """Request a job from aggregator. Args: request: The gRPC message request @@ -182,9 +171,7 @@ def GetTasks(self, request, context): # NOQA:N802 ) def CallCheckpoint(self, request, context): # NOQA:N802 - """ - Request aggregator to perform a checkpoint - for a given function. + """Request aggregator to perform a checkpoint for a given function. 
Args: request: The gRPC message request @@ -197,27 +184,23 @@ def CallCheckpoint(self, request, context): # NOQA:N802 function = request.function stream_buffer = request.stream_buffer - self.aggregator.call_checkpoint( - execution_environment, function, stream_buffer - ) + self.aggregator.call_checkpoint(execution_environment, function, + stream_buffer) return aggregator_pb2.CheckpointResponse( - header=self.get_header(collaborator_name) - ) + header=self.get_header(collaborator_name)) def get_server(self): """Return gRPC server.""" - self.server = server( - ThreadPoolExecutor(max_workers=cpu_count()), options=channel_options - ) + self.server = server(ThreadPoolExecutor(max_workers=cpu_count()), + options=channel_options) aggregator_pb2_grpc.add_AggregatorServicer_to_server(self, self.server) if not self.tls: self.logger.warn( - "gRPC is running on insecure channel with TLS disabled." - ) + "gRPC is running on insecure channel with TLS disabled.") port = self.server.add_insecure_port(self.uri) self.logger.info(f"Insecure port: {port}") @@ -234,7 +217,7 @@ def get_server(self): self.logger.warn("Client-side authentication is disabled.") self.server_credentials = ssl_server_credentials( - ((private_key_b, certificate_b),), + ((private_key_b, certificate_b), ), root_certificates=root_certificate_b, require_client_auth=not self.disable_client_auth, ) diff --git a/openfl/experimental/utilities/exceptions.py b/openfl/experimental/utilities/exceptions.py index caabaded18..e925792eed 100644 --- a/openfl/experimental/utilities/exceptions.py +++ b/openfl/experimental/utilities/exceptions.py @@ -3,21 +3,42 @@ class SerializationError(Exception): + """Raised when there is an error in serialization process.""" def __init__(self, *args: object) -> None: + """Initializes the SerializationError with the provided arguments. + + Args: + *args (object): Variable length argument list. + """ super().__init__(*args) pass class ResourcesNotAvailableError(Exception): + """Exception raised when the required resources are not available.""" def __init__(self, *args: object) -> None: + """Initializes the ResourcesNotAvailableError with the provided + arguments. + + Args: + *args (object): Variable length argument list. + """ super().__init__(*args) pass class ResourcesAllocationError(Exception): + """Exception raised when there is an error in the resources allocation + process.""" def __init__(self, *args: object) -> None: + """Initializes the ResourcesAllocationError with the provided + arguments. + + Args: + *args (object): Variable length argument list. + """ super().__init__(*args) pass diff --git a/openfl/experimental/utilities/metaflow_utils.py b/openfl/experimental/utilities/metaflow_utils.py index 36dd72a2f6..ec7604381f 100644 --- a/openfl/experimental/utilities/metaflow_utils.py +++ b/openfl/experimental/utilities/metaflow_utils.py @@ -58,14 +58,21 @@ class SystemMutex: + """Provides a system-wide mutex that locks a file until the lock is + released.""" def __init__(self, name): + """Initializes the SystemMutex with the provided name. + + Args: + name (str): The name of the mutex. 
+ """ self.name = name def __enter__(self): - lock_id = hashlib.new( - "md5", self.name.encode("utf8"), usedforsecurity=False - ).hexdigest() # nosec + lock_id = hashlib.new("md5", + self.name.encode("utf8"), + usedforsecurity=False).hexdigest() # nosec # MD5sum used for concurrency purposes, not security self.fp = open(f"/tmp/.lock-{lock_id}.lck", "wb") fcntl.flock(self.fp.fileno(), fcntl.LOCK_EX) @@ -76,36 +83,69 @@ def __exit__(self, _type, value, tb): class Flow: + """A mock class representing a flow for Metaflow's internal use.""" def __init__(self, name): - """Mock flow for metaflow internals""" + """Mock flow for metaflow internals. + + Args: + name (str): The name of the flow. + """ self.name = name @ray.remote class Counter(object): + """A remote class that maintains a counter.""" def __init__(self): + """Initializes the Counter with value set to 0.""" self.value = 0 def increment(self): + """Increments the counter by 1. + + Returns: + int: The incremented value of the counter. + """ self.value += 1 return self.value def get_counter(self): + """Retrieves the current value of the counter. + + Returns: + int: The current value of the counter. + """ return self.value class DAGnode(DAGNode): + """A custom DAGNode class for the Metaflow graph. + + Attributes: + name (str): The name of the DAGNode. + func_lineno (int): The line number of the function in the source code. + decorators (list): The decorators applied to the function. + doc (str): The docstring of the function. + parallel_step (bool): A flag indicating if the step is parallelized. + """ def __init__(self, func_ast, decos, doc): + """Initializes the DAGNode with the provided function AST, decorators, + and docstring. + + Args: + func_ast (ast.FunctionDef): The function's abstract syntax tree. + decos (list): The decorators applied to the function. + doc (str): The docstring of the function. + """ self.name = func_ast.name self.func_lineno = func_ast.lineno self.decorators = decos self.doc = deindent_docstring(doc) self.parallel_step = any( - getattr(deco, "IS_PARALLEL", False) for deco in decos - ) + getattr(deco, "IS_PARALLEL", False) for deco in decos) # these attributes are populated by _parse self.tail_next_lineno = 0 @@ -150,7 +190,8 @@ def _parse(self, func_ast): self.out_funcs = [e.attr for e in tail.value.args] keywords = { - k.arg: getattr(k.value, "s", None) for k in tail.value.keywords + k.arg: getattr(k.value, "s", None) + for k in tail.value.keywords } # Second condition in the folliwing line added, # To add the support for up to 2 keyword arguments in Flowgraph @@ -190,19 +231,40 @@ def _parse(self, func_ast): class StepVisitor(StepVisitor): + """A custom StepVisitor class for visiting the steps in a Metaflow + graph.""" def __init__(self, nodes, flow): + """Initializes the StepVisitor with the provided nodes and flow. + + Args: + nodes (dict): The nodes in the graph. + flow (Flow): The flow object. + """ super().__init__(nodes, flow) def visit_FunctionDef(self, node): # NOQA: N802 + """Visits a FunctionDef node in the flow and adds it to the nodes + dictionary if it's a step. + + Args: + node (ast.FunctionDef): The function definition node to visit. 
+ """ func = getattr(self.flow, node.name) if hasattr(func, "is_step"): - self.nodes[node.name] = DAGnode(node, func.decorators, func.__doc__) + self.nodes[node.name] = DAGnode(node, func.decorators, + func.__doc__) class FlowGraph(FlowGraph): + """A custom FlowGraph class for representing a Metaflow graph.""" def __init__(self, flow): + """Initializes the FlowGraph with the provided flow. + + Args: + flow (Flow): The flow object. + """ self.name = flow.__name__ self.nodes = self._create_nodes(flow) self.doc = deindent_docstring(flow.__doc__) @@ -210,11 +272,19 @@ def __init__(self, flow): self._postprocess() def _create_nodes(self, flow): + """Creates nodes for the flow graph by parsing the source code of the + flow's module. + + Args: + flow (Flow): The flow object. + + Returns: + nodes (dict): A dictionary of nodes in the graph. + """ module = __import__(flow.__module__) tree = ast.parse(getsource(module)).body root = [ - n - for n in tree + n for n in tree if isinstance(n, ast.ClassDef) and n.name == self.name ][0] nodes = {} @@ -223,6 +293,7 @@ def _create_nodes(self, flow): class TaskDataStore(TaskDataStore): + """A custom TaskDataStore class for storing task data in Metaflow.""" def __init__( self, @@ -235,6 +306,21 @@ def __init__( mode="r", allow_not_done=False, ): + """Initializes the TaskDataStore with the provided parameters. + + Args: + flow_datastore (FlowDataStore): The flow datastore. + run_id (str): The run id. + step_name (str): The step name. + task_id (str): The task id. + attempt (int, optional): The attempt number. Defaults to None. + data_metadata (DataMetadata, optional): The data metadata. + Defaults to None. + mode (str, optional): The mode (read 'r' or write 'w'). Defaults + to 'r'. + allow_not_done (bool, optional): A flag indicating whether to + allow tasks that are not done. Defaults to False. + """ super().__init__( flow_datastore, run_id, @@ -249,9 +335,8 @@ def __init__( @only_if_not_done @require_mode("w") def save_artifacts(self, artifacts_iter, force_v4=False, len_hint=0): - """ - Saves Metaflow Artifacts (Python objects) to the datastore and stores - any relevant metadata needed to retrieve them. + """Saves Metaflow Artifacts (Python objects) to the datastore and + stores any relevant metadata needed to retrieve them. Typically, objects are pickled but the datastore may perform any operation that it deems necessary. You should only access artifacts @@ -259,35 +344,30 @@ def save_artifacts(self, artifacts_iter, force_v4=False, len_hint=0): This method requires mode 'w'. - Parameters - ---------- - artifacts : Iterator[(string, object)] - Iterator over the human-readable name of the object to save - and the object itself - force_v4 : boolean or Dict[string -> boolean] - Indicates whether the artifact should be pickled using the v4 - version of pickle. If a single boolean, applies to all artifacts. - If a dictionary, applies to the object named only. Defaults to False - if not present or not specified - len_hint: integer - Estimated number of items in artifacts_iter + Args: + artifacts_iter (Iterator[(string, object)]): Iterator over the + human-readable name of the object to save and the object + itself. + force_v4 (Union[bool, Dict[string -> boolean]], optional): Indicates + whether the artifact should be pickled using the v4 version of + pickle. If a single boolean, applies to all artifacts. If a + dictionary, applies to the object named only. Defaults to False if + not present or not specified. 
+ len_hint (int, optional): Estimated number of items in artifacts_iter. + Defaults to 0. """ artifact_names = [] def pickle_iter(): for name, obj in artifacts_iter: - do_v4 = ( - force_v4 and force_v4 - if isinstance(force_v4, bool) - else force_v4.get(name, False) - ) + do_v4 = (force_v4 and force_v4 if isinstance(force_v4, bool) + else force_v4.get(name, False)) if do_v4: encode_type = "gzip+pickle-v4" if encode_type not in self._encodings: raise DataException( f"Artifact {name} requires a serialization encoding that " - + "requires Python 3.4 or newer." - ) + + "requires Python 3.4 or newer.") try: blob = pickle.dumps(obj, protocol=4) except TypeError: @@ -320,14 +400,14 @@ def pickle_iter(): yield blob # Use the content-addressed store to store all artifacts - save_result = self._ca_store.save_blobs( - pickle_iter(), len_hint=len_hint - ) + save_result = self._ca_store.save_blobs(pickle_iter(), + len_hint=len_hint) for name, result in zip(artifact_names, save_result): self._objects[name] = result.key class FlowDataStore(FlowDataStore): + """A custom FlowDataStore class for storing flow data in Metaflow.""" def __init__( self, @@ -339,6 +419,22 @@ def __init__( storage_impl=None, ds_root=None, ): + """Initializes the FlowDataStore with the provided parameters. + + Args: + flow_name (str): The name of the flow. + environment (MetaflowEnvironment): The Metaflow environment. + metadata (MetadataProvider, optional): The metadata provider. + Defaults to None. + event_logger (EventLogger, optional): The event logger. Defaults + to None. + monitor (Monitor, optional): The monitor. Defaults to None. + storage_impl (DataStore, optional): The storage implementation. + Defaults to None. + ds_root (str, optional): The root of the datastore. Defaults to + None. + """ + super().__init__( flow_name, environment, @@ -359,7 +455,23 @@ def get_task_datastore( mode="r", allow_not_done=False, ): + """Returns a TaskDataStore for the specified task. + + Args: + run_id (str): The run id. + step_name (str): The step name. + task_id (str): The task id. + attempt (int, optional): The attempt number. Defaults to None. + data_metadata (DataMetadata, optional): The data metadata. + Defaults to None. + mode (str, optional): The mode (read 'r' or write 'w'). Defaults + to 'r'. + allow_not_done (bool, optional): A flag indicating whether to + allow tasks that are not done. Defaults to False. + Returns: + TaskDataStore: A TaskDataStore for the specified task. + """ return TaskDataStore( self, run_id, @@ -373,18 +485,20 @@ def get_task_datastore( class MetaflowInterface: + """A wrapper class for Metaflow's tooling, modified to work with the + workflow interface.""" def __init__(self, flow: Type[FLSpec], backend: str = "ray"): - """ - Wrapper class for the metaflow tooling modified to work with the - workflow interface. Keeps track of the current flow run, tasks, - and data artifacts. + """Wrapper class for the metaflow tooling modified to work with the + workflow interface. Keeps track of the current flow run, tasks, and + data artifacts. Args: - flow: the current flow that will be serialized / tracked using - metaflow tooling - backend: Which backend is selected by the runtime. Permitted selections - are 'ray' and 'single_process' + flow (Type[FLSpec]): The current flow that will be serialized / + tracked using metaflow tooling. + backend (str, optional): The backend selected by the runtime. + Permitted selections are 'ray' and 'single_process'. Defaults + to 'ray'. 
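The force_v4 handling inside pickle_iter, a few hunks above, accepts either a single boolean or a per-artifact dictionary. A self-contained sketch of the resolution rule it implements (the helper name is illustrative, not part of the codebase):

    def resolve_force_v4(force_v4, name):
        # A bare bool applies to every artifact; a dict is looked up by
        # artifact name and defaults to False when the name is absent.
        if isinstance(force_v4, bool):
            return force_v4
        return force_v4.get(name, False)

    assert resolve_force_v4(True, "model") is True
    assert resolve_force_v4({"model": True}, "optimizer") is False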
""" self.backend = backend self.flow_name = flow.__name__ @@ -394,9 +508,8 @@ def __init__(self, flow: Type[FLSpec], backend: str = "ray"): self.counter = 0 def create_run(self) -> int: - """ - Creates a run for the current flow using metaflow - internal functions + """Creates a run for the current flow using metaflow internal + functions. Args: None @@ -419,13 +532,11 @@ def create_run(self) -> int: return self.run_id def create_task(self, task_name: str) -> int: - """ - Creates a task for the current run. The generated - task_id is unique for each task and can be recalled - later with the metaflow client + """Creates a task for the current run. The generated task_id is unique + for each task and can be recalled later with the metaflow client. Args: - task_name: The name of the new task + task_name (str): The name of the new task. Returns: task_id [int] @@ -453,23 +564,24 @@ def save_artifacts( buffer_out: Type[StringIO], buffer_err: Type[StringIO], ) -> None: - """ - Use metaflow task datastore to save flow attributes, stdout, and stderr - for a specific task (identified by the task_name + task_id) + """Use metaflow task datastore to save flow attributes, stdout, and + stderr for a specific task (identified by the task_name + task_id). Args: - data_pairs: Generator that returns the name of the attribute, - and it's corresponding object - task_name: The task that an artifact is being saved for - task_id: A unique id (within the flow) that will be used to recover - these data artifacts by the metaflow client - buffer_out: StringIO buffer containing stdout - buffer_err: StringIO buffer containing stderr - + data_pairs (Generator[str, Any]): Generator that returns the name + of the attribute, and it's corresponding object. + task_name (str): The name of the task for which an artifact is + being saved. + task_id (int): A unique id (within the flow) that will be used to + recover these data artifacts by the metaflow client. + buffer_out (StringIO): StringIO buffer containing stdout. + buffer_err (StringIO): StringIO buffer containing stderr. """ - task_datastore = self.flow_datastore.get_task_datastore( - self.run_id, task_name, str(task_id), attempt=0, mode="w" - ) + task_datastore = self.flow_datastore.get_task_datastore(self.run_id, + task_name, + str(task_id), + attempt=0, + mode="w") task_datastore.init_task() task_datastore.save_artifacts(data_pairs) @@ -513,10 +625,21 @@ def save_artifacts( task_datastore.done() def load_artifacts(self, artifact_names, task_name, task_id): - """Use metaflow task datastore to load flow attributes""" - task_datastore = self.flow_datastore.get_task_datastore( - self.run_id, task_name, str(task_id), attempt=0, mode="r" - ) + """Loads flow attributes from Metaflow's task datastore. + + Args: + artifact_names (list): The names of the artifacts to load. + task_name (str): The name of the task from which to load artifacts. + task_id (int): The id of the task from which to load artifacts. + + Returns: + dict: A dictionary of loaded artifacts. + """ + task_datastore = self.flow_datastore.get_task_datastore(self.run_id, + task_name, + str(task_id), + attempt=0, + mode="r") return task_datastore.load_artifacts(artifact_names) def emit_log( @@ -526,27 +649,27 @@ def emit_log( task_datastore: Type[TaskDataStore], system_msg: bool = False, ) -> None: - """ - This function writes the stdout and stderr to Metaflow TaskDatastore + """Writes stdout and stderr to Metaflow's TaskDatastore. 
+ Args: - msgbuffer_out: StringIO buffer containing stdout - msgbuffer_err: StringIO buffer containing stderr - task_datastore: Metaflow TaskDataStore instance + msgbuffer_out (StringIO): A StringIO buffer containing stdout. + msgbuffer_err (StringIO): A StringIO buffer containing stderr. + task_datastore (TaskDataStore): A Metaflow TaskDataStore instance. + system_msg (bool, optional): A flag indicating whether the message + is a system message. Defaults to False. """ stdout_buffer = TruncatedBuffer("stdout", MAX_LOG_SIZE) stderr_buffer = TruncatedBuffer("stderr", MAX_LOG_SIZE) for std_output in msgbuffer_out.readlines(): timestamp = datetime.utcnow() - stdout_buffer.write( - mflog_msg(std_output, now=timestamp), system_msg=system_msg - ) + stdout_buffer.write(mflog_msg(std_output, now=timestamp), + system_msg=system_msg) for std_error in msgbuffer_err.readlines(): timestamp = datetime.utcnow() - stderr_buffer.write( - mflog_msg(std_error, now=timestamp), system_msg=system_msg - ) + stderr_buffer.write(mflog_msg(std_error, now=timestamp), + system_msg=system_msg) task_datastore.save_logs( RUNTIME_LOG_SOURCE, @@ -558,12 +681,28 @@ def emit_log( class DefaultCard(DefaultCard): + """A custom DefaultCard class for Metaflow. + Attributes: + ALLOW_USER_COMPONENTS (bool): A flag indicating whether user + components are allowed. Defaults to True. + type (str): The type of the card. Defaults to "default". + """ ALLOW_USER_COMPONENTS = True type = "default" def __init__(self, options={"only_repr": True}, components=[], graph=None): + """Initializes the DefaultCard with the provided options, components, + and graph. + + Args: + options (dict, optional): A dictionary of options. Defaults to + {"only_repr": True}. + components (list, optional): A list of components. Defaults to an + empty list. + graph (any, optional): The graph to use. Defaults to None. + """ self._only_repr = True self._graph = None if graph is None else transform_flow_graph(graph) if "only_repr" in options: @@ -572,7 +711,16 @@ def __init__(self, options={"only_repr": True}, components=[], graph=None): # modified Defaultcard render function def render(self, task): - # :param: task instead of metaflow.client.Task object task.pathspec (string) is provided + """Renders the card with the provided task. + + Args: + task (any): The task to render the card with. + + Returns: + any: The rendered card. + """ + # :param: task instead of metaflow.client.Task object task.pathspec + # (string) is provided RENDER_TEMPLATE = read_file(RENDER_TEMPLATE_PATH) # NOQA: N806 JS_DATA = read_file(JS_PATH) # NOQA: N806 CSS_DATA = read_file(CSS_PATH) # NOQA: N806 @@ -585,22 +733,29 @@ def render(self, task): ).render() pt = self._get_mustache() data_dict = { - "task_data": base64.b64encode( - json.dumps(final_component_dict).encode("utf-8") - ).decode("utf-8"), - "javascript": JS_DATA, - "title": task, - "css": CSS_DATA, - "card_data_id": uuid.uuid4(), + "task_data": + base64.b64encode(json.dumps(final_component_dict).encode( + "utf-8")).decode("utf-8"), + "javascript": + JS_DATA, + "title": + task, + "css": + CSS_DATA, + "card_data_id": + uuid.uuid4(), } return pt.render(RENDER_TEMPLATE, data_dict) class TaskInfoComponent(TaskInfoComponent): - """ - Properties - page_content : a list of MetaflowCardComponents going as task info - final_component: the dictionary returned by the `render` function of this class. + """A custom TaskInfoComponent class for Metaflow. + + Properties: + page_content (list): A list of MetaflowCardComponents going as task + info. 
+ final_component (dict): The dictionary returned by the `render` + function of this class. """ def __init__( @@ -611,6 +766,19 @@ def __init__( graph=None, components=[], ): + """Initializes the TaskInfoComponent with the provided task, page + title, representation flag, graph, and components. + + Args: + task (any): The task to use. + page_title (str, optional): The title of the page. Defaults to + "Task Info". + only_repr (bool, optional): A flag indicating whether to only use + the representation. Defaults to True. + graph (any, optional): The graph to use. Defaults to None. + components (list, optional): A list of components. Defaults to an + empty list. + """ self._task = task self._only_repr = only_repr self._graph = graph @@ -621,11 +789,12 @@ def __init__( # modified TaskInfoComponent render function def render(self): - """ + """Renders the component and returns a dictionary of metadata and + components. Returns: - a dictionary of form: - dict(metadata = {},components= []) + final_component_dict (dict): A dictionary of the form: + dict(metadata={}, components=[]). """ final_component_dict = { "metadata": { @@ -637,8 +806,8 @@ def render(self): } dag_component = SectionComponent( - title="DAG", contents=[DagComponent(data=self._graph).render()] - ).render() + title="DAG", + contents=[DagComponent(data=self._graph).render()]).render() page_contents = [] page_contents.append(dag_component) diff --git a/openfl/experimental/utilities/resources.py b/openfl/experimental/utilities/resources.py index 6c0ed54c3c..bf775aef49 100644 --- a/openfl/experimental/utilities/resources.py +++ b/openfl/experimental/utilities/resources.py @@ -9,23 +9,27 @@ def get_number_of_gpus() -> int: - """ - Returns number of NVIDIA GPUs attached to the machine. + """Returns number of NVIDIA GPUs attached to the machine. + + This function executes the `nvidia-smi --list-gpus` command to get the + list of GPUs. + If the command fails (e.g., NVIDIA drivers are not installed), it logs a + warning and returns 0. - Args: - None Returns: - int: Number of NVIDIA GPUs + int: The number of NVIDIA GPUs attached to the machine. """ # Execute the nvidia-smi command. command = "nvidia-smi --list-gpus" try: - op = run(command.strip().split(), shell=False, stdout=PIPE, stderr=PIPE) + op = run(command.strip().split(), + shell=False, + stdout=PIPE, + stderr=PIPE) stdout = op.stdout.decode().strip() return len(stdout.split("\n")) except FileNotFoundError: logger.warning( f'No GPUs found! If this is a mistake please try running "{command}" ' - + "manually." - ) + + "manually.") return 0 diff --git a/openfl/experimental/utilities/runtime_utils.py b/openfl/experimental/utilities/runtime_utils.py index 3421c5d211..1aa7af2055 100644 --- a/openfl/experimental/utilities/runtime_utils.py +++ b/openfl/experimental/utilities/runtime_utils.py @@ -12,18 +12,27 @@ def parse_attrs(ctx, exclude=[], reserved_words=["next", "runtime", "input"]): - """Returns ctx attributes and artifacts""" + """Parses the context to get its attributes and artifacts, excluding those + specified. + + Args: + ctx (any): The context to parse. + exclude (list, optional): A list of attribute names to exclude. + Defaults to an empty list. + reserved_words (list, optional): A list of reserved words to exclude. + Defaults to ["next", "runtime", "input"]. + + Returns: + tuple: A tuple containing a list of attribute names and a list of + valid artifacts (pairs of attribute names and values). + """ # TODO Persist attributes to local disk, database, object store, etc. 
here cls_attrs = [] valid_artifacts = [] for i in inspect.getmembers(ctx): - if ( - not hasattr(i[1], "task") - and not i[0].startswith("_") - and i[0] not in reserved_words - and i[0] not in exclude - and i not in inspect.getmembers(type(ctx)) - ): + if (not hasattr(i[1], "task") and not i[0].startswith("_") + and i[0] not in reserved_words and i[0] not in exclude + and i not in inspect.getmembers(type(ctx))): if not isinstance(i[1], MethodType): cls_attrs.append(i[0]) valid_artifacts.append((i[0], i[1])) @@ -31,8 +40,20 @@ def parse_attrs(ctx, exclude=[], reserved_words=["next", "runtime", "input"]): def generate_artifacts(ctx, reserved_words=["next", "runtime", "input"]): - """Returns ctx artifacts, and artifacts_iter method""" - cls_attrs, valid_artifacts = parse_attrs(ctx, reserved_words=reserved_words) + """Generates artifacts from the given context, excluding specified reserved + words. + + Args: + ctx (any): The context to generate artifacts from. + reserved_words (list, optional): A list of reserved words to exclude. + Defaults to ["next", "runtime", "input"]. + + Returns: + tuple: A tuple containing a generator of artifacts and a list of + attribute names. + """ + cls_attrs, valid_artifacts = parse_attrs(ctx, + reserved_words=reserved_words) def artifacts_iter(): # Helper function from metaflow source @@ -44,14 +65,25 @@ def artifacts_iter(): def filter_attributes(ctx, f, **kwargs): - """ - Filter out explicitly included / excluded attributes from the next task - in the flow. + """Filters out attributes from the next task in the flow based on inclusion + or exclusion. + + Args: + ctx (any): The context to filter attributes from. + f (function): The next task function in the flow. + **kwargs: Optional arguments that specify the 'include' or 'exclude' + lists. + + Raises: + RuntimeError: If both 'include' and 'exclude' are present, or if an + attribute in 'include' or 'exclude' is not found in the context's + attributes. """ _, cls_attrs = generate_artifacts(ctx=ctx) if "include" in kwargs and "exclude" in kwargs: - raise RuntimeError("'include' and 'exclude' should not both be present") + raise RuntimeError( + "'include' and 'exclude' should not both be present") elif "include" in kwargs: assert isinstance(kwargs["include"], list) for in_attr in kwargs["include"]: @@ -75,20 +107,25 @@ def filter_attributes(ctx, f, **kwargs): def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]): - """ - [Optionally] save current state for the task just executed task + """Optionally saves the current state for the task just executed. + + Args: + ctx (any): The context to checkpoint. + parent_func (function): The function that was just executed. + chkpnt_reserved_words (list, optional): A list of reserved words to + exclude from checkpointing. Defaults to ["next", "runtime"]. 
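The filtering rule described for parse_attrs (public, non-callable attributes that are not reserved words) can be illustrated with a small standalone version. The sketch below is illustrative only; the real function additionally skips attributes that carry a task attribute or that are defined on the class itself.

    import inspect
    from types import MethodType

    def sketch_parse_attrs(ctx, reserved_words=("next", "runtime", "input")):
        valid_artifacts = []
        for name, value in inspect.getmembers(ctx):
            if (not name.startswith("_")
                    and name not in reserved_words
                    and not isinstance(value, MethodType)):
                valid_artifacts.append((name, value))
        return valid_artifacts

    class DemoCtx:
        def __init__(self):
            self.model = "weights"
            self.runtime = "local"  # reserved word, filtered out

    print(sketch_parse_attrs(DemoCtx()))  # [('model', 'weights')]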
""" # Extract the stdout & stderr from the buffer - # NOTE: Any prints in this method before this line will be recorded as step output/error + # NOTE: Any prints in this method before this line will be recorded as + # step output/error step_stdout, step_stderr = parent_func._stream_buffer.get_stdstream() if ctx._checkpoint: # all objects will be serialized using Metaflow interface print(f"Saving data artifacts for {parent_func.__name__}") artifacts_iter, _ = generate_artifacts( - ctx=ctx, reserved_words=chkpnt_reserved_words - ) + ctx=ctx, reserved_words=chkpnt_reserved_words) task_id = ctx._metaflow_interface.create_task(parent_func.__name__) ctx._metaflow_interface.save_artifacts( artifacts_iter(), @@ -102,28 +139,25 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]): def old_check_resource_allocation(num_gpus, each_participant_gpu_usage): remaining_gpu_memory = {} - # TODO for each GPU the funtion tries see if all participant usages fit into a GPU, it it - # doesn't it removes that - # participant from the participant list, and adds it to the remaining_gpu_memory dict. So any - # sum of GPU requirements above 1 - # triggers this. - # But at this point the funtion will raise an error because remaining_gpu_memory is never - # cleared. - # The participant list should remove the participant if it fits in the gpu and save the - # partipant if it doesn't and continue - # to the next GPU to see if it fits in that one, only if we run out of GPUs should this - # funtion raise an error. + # TODO for each GPU the funtion tries see if all participant usages fit + # into a GPU, it it doesn't it removes that participant from the + # participant list, and adds it to the remaining_gpu_memory dict. So any + # sum of GPU requirements above 1 triggers this. + # But at this point the funtion will raise an error because + # remaining_gpu_memory is never cleared. + # The participant list should remove the participant if it fits in the gpu + # and save the partipant if it doesn't and continue to the next GPU to see + # if it fits in that one, only if we run out of GPUs should this funtion + # raise an error. 
for gpu in np.ones(num_gpus, dtype=int): for i, (participant_name, participant_gpu_usage) in enumerate( - each_participant_gpu_usage.items() - ): + each_participant_gpu_usage.items()): if gpu == 0: break if gpu < participant_gpu_usage: remaining_gpu_memory.update({participant_name: gpu}) each_participant_gpu_usage = dict( - itertools.islice(each_participant_gpu_usage.items(), i) - ) + itertools.islice(each_participant_gpu_usage.items(), i)) else: gpu -= participant_gpu_usage if len(remaining_gpu_memory) > 0: @@ -138,11 +172,11 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage): need_assigned = each_participant_gpu_usage.copy() # cycle through all available GPU availability for gpu in np.ones(num_gpus, dtype=int): - # buffer to cycle though since need_assigned will change sizes as we assign participants + # buffer to cycle though since need_assigned will change sizes as we + # assign participants current_dict = need_assigned.copy() - for i, (participant_name, participant_gpu_usage) in enumerate( - current_dict.items() - ): + for i, (participant_name, + participant_gpu_usage) in enumerate(current_dict.items()): if gpu == 0: break if gpu < participant_gpu_usage: @@ -153,8 +187,8 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage): need_assigned.pop(participant_name) gpu -= participant_gpu_usage - # raise error if after going though all gpus there are still participants that needed to be - # assigned + # raise error if after going though all gpus there are still participants + # that needed to be assigned if len(need_assigned) > 0: raise ResourcesAllocationError( f"Failed to allocate Participant {list(need_assigned.keys())} " diff --git a/openfl/experimental/utilities/stream_redirect.py b/openfl/experimental/utilities/stream_redirect.py index 5f7a25fd3d..4a1b599c13 100644 --- a/openfl/experimental/utilities/stream_redirect.py +++ b/openfl/experimental/utilities/stream_redirect.py @@ -8,17 +8,25 @@ class RedirectStdStreamBuffer: - """ - Buffer object used to store stdout & stderr + """Buffer object used to store stdout and stderr. + + Attributes: + _stdoutbuff (io.StringIO): Buffer for stdout. + _stderrbuff (io.StringIO): Buffer for stderr. """ def __init__(self): + """Initializes the RedirectStdStreamBuffer with empty stdout and stderr + buffers.""" self._stdoutbuff = io.StringIO() self._stderrbuff = io.StringIO() def get_stdstream(self): - """ - Return the contents of stdout and stderr buffers + """Returns the contents of stdout and stderr buffers. + + Returns: + tuple: A tuple containing the contents of stdout and stderr + buffers. """ self._stdoutbuff.seek(0) self._stderrbuff.seek(0) @@ -33,16 +41,30 @@ def get_stdstream(self): class RedirectStdStream(object): - """ - This class used to intercept stdout and stderr, so that - stdout and stderr is written to buffer as well as terminal + """Class used to intercept stdout and stderr, so that stdout and stderr is + written to buffer as well as terminal. + + Attributes: + __stdDestination (io.TextIOWrapper): Destination for standard outputs. + __stdBuffer (RedirectStdStreamBuffer): Buffer for standard outputs. """ def __init__(self, buffer, destination): + """Initializes the RedirectStdStream with a buffer and a destination. + + Args: + buffer (RedirectStdStreamBuffer): Buffer for standard outputs. + destination (io.TextIOWrapper): Destination for standard outputs. 
+ """ self.__stdDestination = destination self.__stdBuffer = buffer def write(self, message): + """Writes the message to the standard destination and buffer. + + Args: + message (str): The message to write. + """ message = f"\33[94m{message}\33[0m" self.__stdDestination.write(message) self.__stdBuffer.write(message) @@ -52,31 +74,39 @@ def flush(self): class RedirectStdStreamContext: - """ - Context Manager that enables redirection of stdout & stderr + """Context Manager that enables redirection of stdout and stderr. + + Attributes: + stdstreambuffer (RedirectStdStreamBuffer): Buffer for standard outputs. """ def __init__(self): + """Initializes the RedirectStdStreamContext with a + RedirectStdStreamBuffer.""" self.stdstreambuffer = RedirectStdStreamBuffer() def __enter__(self): - """ - Create context to redirect stdout & stderr + """Creates a context to redirect stdout and stderr. + + Returns: + RedirectStdStreamBuffer: The buffer for standard outputs. """ self.__old_stdout = sys.stdout self.__old_stderr = sys.stderr - sys.stdout = RedirectStdStream( - self.stdstreambuffer._stdoutbuff, sys.stdout - ) - sys.stderr = RedirectStdStream( - self.stdstreambuffer._stderrbuff, sys.stderr - ) + sys.stdout = RedirectStdStream(self.stdstreambuffer._stdoutbuff, + sys.stdout) + sys.stderr = RedirectStdStream(self.stdstreambuffer._stderrbuff, + sys.stderr) return self.stdstreambuffer def __exit__(self, et, ev, tb): - """ - Exit the context and restore the stdout & stderr + """Exits the context and restores the stdout and stderr. + + Args: + et (type): The type of exception. + ev (BaseException): The instance of exception. + tb (traceback): A traceback object encapsulating the call stack. """ sys.stdout = self.__old_stdout sys.stderr = self.__old_stderr diff --git a/openfl/experimental/utilities/transitions.py b/openfl/experimental/utilities/transitions.py index b134a73690..4a31a08403 100644 --- a/openfl/experimental/utilities/transitions.py +++ b/openfl/experimental/utilities/transitions.py @@ -4,6 +4,15 @@ def should_transfer(func, parent_func): + """Determines if a transfer should occur from collaborator to aggregator. + + Args: + func (function): The current function. + parent_func (function): The parent function. + + Returns: + bool: True if a transfer should occur, False otherwise. + """ if collaborator_to_aggregator(func, parent_func): return True else: @@ -11,6 +20,15 @@ def should_transfer(func, parent_func): def aggregator_to_collaborator(func, parent_func): + """Checks if a transition from aggregator to collaborator is possible. + + Args: + func (function): The current function. + parent_func (function): The parent function. + + Returns: + bool: True if the transition is possible, False otherwise. + """ if parent_func.aggregator_step and func.collaborator_step: return True else: @@ -18,6 +36,15 @@ def aggregator_to_collaborator(func, parent_func): def collaborator_to_aggregator(func, parent_func): + """Checks if a transition from collaborator to aggregator is possible. + + Args: + func (function): The current function. + parent_func (function): The parent function. + + Returns: + bool: True if the transition is possible, False otherwise. + """ if parent_func.collaborator_step and func.aggregator_step: return True else: diff --git a/openfl/experimental/utilities/ui.py b/openfl/experimental/utilities/ui.py index ae10910ffd..8749c5e3a5 100644 --- a/openfl/experimental/utilities/ui.py +++ b/openfl/experimental/utilities/ui.py @@ -9,6 +9,17 @@ class InspectFlow: + """Class for inspecting a flow. 
+ + Attributes: + ds_root (str): The root directory for the data store. Defaults to + "~/.metaflow". + show_html (bool): Whether to show the UI in a web browser. Defaults to + False. + run_id (str): The run ID of the flow. + flow_name (str): The name of the flow. + graph_dict (dict): The graph of the flow. + """ def __init__( self, @@ -17,6 +28,18 @@ def __init__( show_html=False, ds_root=f"{Path.home()}/.metaflow", ): + """Initializes the InspectFlow with a flow object, run ID, an optional + flag to show the UI in a web browser, and an optional root directory + for the data store. + + Args: + flow_obj (Flow): The flow object to inspect. + run_id (str): The run ID of the flow. + show_html (bool, optional): Whether to show the UI in a web + browser. Defaults to False. + ds_root (str, optional): The root directory for the data store. + Defaults to "~/.metaflow". + """ self.ds_root = ds_root self.show_html = show_html self.run_id = run_id @@ -28,13 +51,26 @@ def __init__( self.show_ui() def get_pathspec(self): + """Gets the path specification of the flow. + + Returns: + str: The path specification of the flow. + """ return f"{self.ds_root}/{self.flow_name}/{self.run_id}" def open_in_browser(self, card_path): + """Opens the specified path in a web browser. + + Args: + card_path (str): The path to open. + """ url = "file://" + os.path.abspath(card_path) webbrowser.open(url) def show_ui(self): + """Shows the UI of the flow in a web browser if show_html is True, and + saves the UI as an HTML file.""" + default_card = DefaultCard(graph=self.graph_dict) pathspec = self.get_pathspec() diff --git a/openfl/experimental/workspace_export/export.py b/openfl/experimental/workspace_export/export.py index 4370cfe21e..cf52295ae3 100644 --- a/openfl/experimental/workspace_export/export.py +++ b/openfl/experimental/workspace_export/export.py @@ -19,13 +19,15 @@ class WorkspaceExport: - """ - Convert a LocalRuntime Jupyter Notebook to Aggregator based FederatedRuntime Workflow. + """Convert a LocalRuntime Jupyter Notebook to Aggregator based + FederatedRuntime Workflow. Args: notebook_path: Absolute path of jupyter notebook. - template_workspace_path: Path to template workspace provided with OpenFL. - output_dir: Output directory for new generated workspace (default="/tmp"). + template_workspace_path: Path to template workspace provided with + OpenFL. + output_dir: Output directory for new generated workspace + (default="/tmp"). 
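A hedged sketch of how InspectFlow would be driven once a run has finished; flow and the run id below are placeholders. Note that rendering happens in __init__ (show_ui is called there), so simply constructing the object writes the card and, with show_html=True, opens it in the default browser.

    InspectFlow(flow_obj=flow, run_id="1700000000000000", show_html=True)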
Returns: None @@ -39,42 +41,37 @@ def __init__(self, notebook_path: str, output_workspace: str) -> None: self.output_workspace_path.parent.mkdir(parents=True, exist_ok=True) self.template_workspace_path = ( - Path(f"{__file__}") - .parent.parent.parent.parent.joinpath( - "openfl-workspace", "experimental", "template_workspace" - ) - .resolve(strict=True) - ) + Path(f"{__file__}").parent.parent.parent.parent.joinpath( + "openfl-workspace", "experimental", + "template_workspace").resolve(strict=True)) # Copy template workspace to output directory self.created_workspace_path = Path( - copytree(self.template_workspace_path, self.output_workspace_path) - ) + copytree(self.template_workspace_path, self.output_workspace_path)) self.logger.info( - f"Copied template workspace to {self.created_workspace_path}" - ) + f"Copied template workspace to {self.created_workspace_path}") self.logger.info("Converting jupter notebook to python script...") export_filename = self.__get_exp_name() if export_filename is None: raise NameError( "Please include `#| default_exp ` in " - "the first cell of the notebook." - ) + "the first cell of the notebook.") self.script_path = Path( self.__convert_to_python( self.notebook_path, self.created_workspace_path.joinpath("src"), f"{export_filename}.py", - ) - ).resolve() + )).resolve() print_tree(self.created_workspace_path, level=2) # Generated python script name without .py extension self.script_name = self.script_path.name.split(".")[0].strip() - # Comment flow.run() so when script is imported flow does not start executing + # Comment flow.run() so when script is imported flow does not start + # executing self.__comment_flow_execution() - # This is required as Ray created actors too many actors when backend="ray" + # This is required as Ray created actors too many actors when + # backend="ray" self.__change_runtime() def __get_exp_name(self): @@ -88,22 +85,18 @@ def __get_exp_name(self): match = re.search(r"#\s*\|\s*default_exp\s+(\w+)", code) if match: self.logger.info( - f"Retrieved {match.group(1)} from default_exp" - ) + f"Retrieved {match.group(1)} from default_exp") return match.group(1) return None - def __convert_to_python( - self, notebook_path: Path, output_path: Path, export_filename - ): + def __convert_to_python(self, notebook_path: Path, output_path: Path, + export_filename): nb_export(notebook_path, output_path) return Path(output_path).joinpath(export_filename).resolve() def __comment_flow_execution(self): - """ - In the python script search for ".run()" and comment it - """ + """In the python script search for ".run()" and comment it.""" with open(self.script_path, "r") as f: data = f.readlines() for idx, line in enumerate(data): @@ -113,26 +106,21 @@ def __comment_flow_execution(self): f.writelines(data) def __change_runtime(self): - """ - Change the LocalRuntime backend from ray to single_process - """ + """Change the LocalRuntime backend from ray to single_process.""" with open(self.script_path, "r") as f: data = f.read() if "backend='ray'" in data or 'backend="ray"' in data: - data = data.replace( - "backend='ray'", "backend='single_process'" - ).replace( - 'backend="ray"', 'backend="single_process"' - ) + data = data.replace("backend='ray'", + "backend='single_process'").replace( + 'backend="ray"', + 'backend="single_process"') with open(self.script_path, "w") as f: f.write(data) def __get_class_arguments(self, class_name): - """ - Given the class name returns expected class arguments - """ + """Given the class name returns expected class 
arguments.""" # Import python script if not already if not hasattr(self, "exported_script_module"): self.__import_exported_script() @@ -153,10 +141,10 @@ def __get_class_arguments(self, class_name): # Check if the class has an __init__ method if "__init__" in cls.__dict__: init_signature = inspect.signature(cls.__init__) - # Extract the parameter names (excluding 'self', 'args', and 'kwargs') + # Extract the parameter names (excluding 'self', 'args', and + # 'kwargs') arg_names = [ - param - for param in init_signature.parameters + param for param in init_signature.parameters if param not in ("self", "args", "kwargs") ] return arg_names @@ -164,9 +152,8 @@ def __get_class_arguments(self, class_name): self.logger.error(f"{cls} is not a class") def __get_class_name_and_sourcecode_from_parent_class(self, parent_class): - """ - Provided the parent_class name returns derived class source code and name. - """ + """Provided the parent_class name returns derived class source code and + name.""" # Import python script if not already if not hasattr(self, "exported_script_module"): self.__import_exported_script() @@ -174,19 +161,15 @@ def __get_class_name_and_sourcecode_from_parent_class(self, parent_class): # Going though all attributes in imported python script for attr in self.available_modules_in_exported_script: t = getattr(self.exported_script_module, attr) - if ( - inspect.isclass(t) - and t != parent_class - and issubclass(t, parent_class) - ): + if (inspect.isclass(t) and t != parent_class + and issubclass(t, parent_class)): return inspect.getsource(t), attr return None, None def __extract_class_initializing_args(self, class_name): - """ - Provided name of the class returns expected arguments and it's values in form of dictionary - """ + """Provided name of the class returns expected arguments and it's + values in form of dictionary.""" instantiation_args = {"args": {}, "kwargs": {}} with open(self.script_path, "r") as s: @@ -194,8 +177,7 @@ def __extract_class_initializing_args(self, class_name): for node in ast.walk(tree): if isinstance(node, ast.Call) and isinstance( - node.func, ast.Name - ): + node.func, ast.Name): if node.func.id == class_name: # We found an instantiation of the class for arg in node.args: @@ -205,46 +187,44 @@ def __extract_class_initializing_args(self, class_name): instantiation_args["args"][arg.id] = arg.id elif isinstance(arg, ast.Constant): instantiation_args["args"][arg.s] = ( - astor.to_source(arg) - ) + astor.to_source(arg)) else: instantiation_args["args"][arg.arg] = ( - astor.to_source(arg).strip() - ) + astor.to_source(arg).strip()) for kwarg in node.keywords: # Iterate through keyword arguments value = astor.to_source(kwarg.value).strip() - # If paranthese or brackets around the value is found - # and it's not tuple or list remove paranthese or brackets + # If paranthese or brackets around the value is + # found and it's not tuple or list remove + # paranthese or brackets if value.startswith("(") and "," not in value: value = value.lstrip("(").rstrip(")") if value.startswith("[") and "," not in value: value = value.lstrip("[").rstrip("]") try: - # Evaluate the value to convert it from a string - # representation into its corresponding python object. + # Evaluate the value to convert it from a + # string representation into its corresponding + # python object. 
value = ast.literal_eval(value) except ValueError: - # ValueError is ignored because we want the value as a string + # ValueError is ignored because we want the + # value as a string pass instantiation_args["kwargs"][kwarg.arg] = value return instantiation_args def __import_exported_script(self): - """ - Imports generated python script with help of importlib - """ + """Imports generated python script with help of importlib.""" import importlib import sys sys.path.append(str(self.script_path.parent)) self.exported_script_module = importlib.import_module(self.script_name) self.available_modules_in_exported_script = dir( - self.exported_script_module - ) + self.exported_script_module) def __read_yaml(self, path): with open(path, "r") as y: @@ -256,14 +236,13 @@ def __write_yaml(self, path, data): @classmethod def export(cls, notebook_path: str, output_workspace: str) -> None: - """ - Exports workspace to `output_dir`. + """Exports workspace to `output_dir`. Args: notebook_path: Jupyter notebook path. output_dir: Path for generated workspace directory. - template_workspace_path: Path to template workspace provided with OpenFL - (default="/tmp"). + template_workspace_path: Path to template workspace provided with + OpenFL (default="/tmp"). Returns: None @@ -276,10 +255,8 @@ def export(cls, notebook_path: str, output_workspace: str) -> None: # Have to do generate_requirements before anything else # because these !pip commands needs to be removed from python script def generate_requirements(self): - """ - Finds pip libraries mentioned in exported python script and append in - workspace/requirements.txt - """ + """Finds pip libraries mentioned in exported python script and append + in workspace/requirements.txt.""" data = None with open(self.script_path, "r") as f: requirements = [] @@ -289,18 +266,14 @@ def generate_requirements(self): line = line.strip() if "pip install" in line: line_nos.append(i) - # Avoid commented lines, libraries from *.txt file, or openfl.git - # installation - if ( - not line.startswith("#") - and "-r" not in line - and "openfl.git" not in line - ): + # Avoid commented lines, libraries from *.txt file, or + # openfl.git installation + if (not line.startswith("#") and "-r" not in line + and "openfl.git" not in line): requirements.append(f"{line.split(' ')[-1].strip()}\n") requirements_filepath = str( - self.created_workspace_path.joinpath("requirements.txt").resolve() - ) + self.created_workspace_path.joinpath("requirements.txt").resolve()) # Write libraries found in requirements.txt with open(requirements_filepath, "a") as f: @@ -314,35 +287,27 @@ def generate_requirements(self): f.write(line) def generate_plan_yaml(self): - """ - Generates plan.yaml - """ + """Generates plan.yaml.""" flspec = getattr( - importlib.import_module("openfl.experimental.interface"), "FLSpec" - ) + importlib.import_module("openfl.experimental.interface"), "FLSpec") # Get flow classname _, self.flow_class_name = ( - self.__get_class_name_and_sourcecode_from_parent_class(flspec) - ) + self.__get_class_name_and_sourcecode_from_parent_class(flspec)) # Get expected arguments of flow class self.flow_class_expected_arguments = self.__get_class_arguments( - self.flow_class_name - ) + self.flow_class_name) # Get provided arguments to flow class self.arguments_passed_to_initialize = ( - self.__extract_class_initializing_args(self.flow_class_name) - ) + self.__extract_class_initializing_args(self.flow_class_name)) - plan = self.created_workspace_path.joinpath( - "plan", "plan.yaml" - ).resolve() + plan = 
self.created_workspace_path.joinpath("plan", + "plan.yaml").resolve() data = self.__read_yaml(plan) if data is None: data["federated_flow"] = {"settings": {}, "template": ""} data["federated_flow"][ - "template" - ] = f"src.{self.script_name}.{self.flow_class_name}" + "template"] = f"src.{self.script_name}.{self.flow_class_name}" def update_dictionary(args: dict, data: dict, dtype: str = "args"): for idx, (k, v) in enumerate(args.items()): @@ -366,9 +331,7 @@ def update_dictionary(args: dict, data: dict, dtype: str = "args"): self.__write_yaml(plan, data) def generate_data_yaml(self): - """ - Generates data.yaml - """ + """Generates data.yaml.""" # Import python script if not already if not hasattr(self, "exported_script_module"): self.__import_exported_script() @@ -380,14 +343,13 @@ def generate_data_yaml(self): "FLSpec", ) _, self.flow_class_name = ( - self.__get_class_name_and_sourcecode_from_parent_class(flspec) - ) + self.__get_class_name_and_sourcecode_from_parent_class(flspec)) # Import flow class - federated_flow_class = getattr( - self.exported_script_module, self.flow_class_name - ) - # Find federated_flow._runtime and federated_flow._runtime.collaborators + federated_flow_class = getattr(self.exported_script_module, + self.flow_class_name) + # Find federated_flow._runtime and + # federated_flow._runtime.collaborators for t in self.available_modules_in_exported_script: tempstring = t t = getattr(self.exported_script_module, t) @@ -395,18 +357,15 @@ def generate_data_yaml(self): flow_name = tempstring if not hasattr(t, "_runtime"): raise AttributeError( - "Unable to locate LocalRuntime instantiation" - ) + "Unable to locate LocalRuntime instantiation") runtime = t._runtime if not hasattr(runtime, "collaborators"): raise AttributeError( - "LocalRuntime instance does not have collaborators" - ) + "LocalRuntime instance does not have collaborators") break data_yaml = self.created_workspace_path.joinpath( - "plan", "data.yaml" - ).resolve() + "plan", "data.yaml").resolve() data = self.__read_yaml(data_yaml) if data is None: data = {} @@ -422,39 +381,39 @@ def generate_data_yaml(self): data["aggregator"] = { "callable_func": { "settings": {}, - "template": f"src.{self.script_name}.{private_attrs_callable.__name__}", + "template": + f"src.{self.script_name}.{private_attrs_callable.__name__}", } } # Find arguments expected by Aggregator arguments_passed_to_initialize = ( - self.__extract_class_initializing_args("Aggregator")["kwargs"] - ) + self.__extract_class_initializing_args("Aggregator")["kwargs"]) agg_kwargs = aggregator.kwargs for key, value in agg_kwargs.items(): if isinstance(value, (int, str, bool)): - data["aggregator"]["callable_func"]["settings"][key] = value + data["aggregator"]["callable_func"]["settings"][ + key] = value else: arg = arguments_passed_to_initialize[key] value = f"src.{self.script_name}.{arg}" - data["aggregator"]["callable_func"]["settings"][key] = value + data["aggregator"]["callable_func"]["settings"][ + key] = value elif aggregator_private_attributes: runtime_created = True with open(self.script_path, 'a') as f: f.write(f"\n{runtime_name} = {flow_name}._runtime\n") - f.write( - f"\naggregator_private_attributes = " - f"{runtime_name}._aggregator.private_attributes\n" - ) + f.write(f"\naggregator_private_attributes = " + f"{runtime_name}._aggregator.private_attributes\n") data["aggregator"] = { - "private_attributes": f"src.{self.script_name}.aggregator_private_attributes" + "private_attributes": + f"src.{self.script_name}.aggregator_private_attributes" } 
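For orientation, the aggregator entry that generate_data_yaml builds ends up with roughly one of the two shapes below, depending on whether private attributes come from a callable or were assigned directly; the script, function, and setting names are hypothetical.

    # Callable-based private attributes:
    data_aggregator_callable = {
        "aggregator": {
            "callable_func": {
                "template": "src.my_script.aggregator_attrs",
                "settings": {"batch_size": 32},
            }
        }
    }

    # Directly assigned private attributes:
    data_aggregator_direct = {
        "aggregator": {
            "private_attributes":
                "src.my_script.aggregator_private_attributes"
        }
    }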
# Get runtime collaborators collaborators = runtime._LocalRuntime__collaborators # Find arguments expected by Collaborator arguments_passed_to_initialize = self.__extract_class_initializing_args( - "Collaborator" - )["kwargs"] + "Collaborator")["kwargs"] runtime_collab_created = False for collab in collaborators.values(): collab_name = collab.get_name() @@ -464,7 +423,10 @@ def generate_data_yaml(self): if callable_func: if collab_name not in data: data[collab_name] = { - "callable_func": {"settings": {}, "template": None} + "callable_func": { + "settings": {}, + "template": None + } } # Find collaborator private_attributes callable details kw_args = runtime.get_collaborator_kwargs(collab_name) @@ -473,28 +435,29 @@ def generate_data_yaml(self): value = f"src.{self.script_name}.{value}" data[collab_name]["callable_func"]["template"] = value elif isinstance(value, (int, str, bool)): - data[collab_name]["callable_func"]["settings"][key] = value + data[collab_name]["callable_func"]["settings"][ + key] = value else: arg = arguments_passed_to_initialize[key] value = f"src.{self.script_name}.{arg}" - data[collab_name]["callable_func"]["settings"][key] = value + data[collab_name]["callable_func"]["settings"][ + key] = value elif private_attributes: with open(self.script_path, 'a') as f: if not runtime_created: f.write(f"\n{runtime_name} = {flow_name}._runtime\n") runtime_created = True if not runtime_collab_created: - f.write( - f"\nruntime_collaborators = " - f"{runtime_name}._LocalRuntime__collaborators" - ) + f.write(f"\nruntime_collaborators = " + f"{runtime_name}._LocalRuntime__collaborators") runtime_collab_created = True f.write( f"\n{collab_name}_private_attributes = " f"runtime_collaborators['{collab_name}'].private_attributes" ) data[collab_name] = { - "private_attributes": f"src." + "private_attributes": + f"src." f"{self.script_name}.{collab_name}_private_attributes" } diff --git a/openfl/federated/__init__.py b/openfl/federated/__init__.py index b2b4f4fd1f..16e474832e 100644 --- a/openfl/federated/__init__.py +++ b/openfl/federated/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.federated package.""" import pkgutil diff --git a/openfl/federated/data/__init__.py b/openfl/federated/data/__init__.py index 7cab3710c1..14956b798a 100644 --- a/openfl/federated/data/__init__.py +++ b/openfl/federated/data/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Data package.""" import pkgutil diff --git a/openfl/federated/data/federated_data.py b/openfl/federated/data/federated_data.py index d1fbcaac0c..ef60694925 100644 --- a/openfl/federated/data/federated_data.py +++ b/openfl/federated/data/federated_data.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """FederatedDataset module.""" import numpy as np @@ -11,53 +10,43 @@ class FederatedDataSet(PyTorchDataLoader): - """ - Data Loader for in memory Numpy data. - - Args: - X_train: np.array - Training Features - y_train: np.array - Training labels - X_val: np.array - Validation features - y_val: np.array - Validation labels - batch_size : int - The batch size for the data loader - num_classes : int - The number of classes the model will be trained on - **kwargs: Additional arguments to pass to the function - + """A Data Loader class used to represent a federated dataset for in-memory + Numpy data. 
+ + Attributes: + train_splitter (NumPyDataSplitter): An object that splits the training + data. + valid_splitter (NumPyDataSplitter): An object that splits the + validation data. """ train_splitter: NumPyDataSplitter valid_splitter: NumPyDataSplitter - def __init__(self, X_train, y_train, X_valid, y_valid, - batch_size=1, num_classes=None, train_splitter=None, valid_splitter=None): - """ - Initialize. + def __init__(self, + X_train, + y_train, + X_valid, + y_valid, + batch_size=1, + num_classes=None, + train_splitter=None, + valid_splitter=None): + """Initializes the FederatedDataSet object. Args: - X_train: np.array - Training Features - y_train: np.array - Training labels - X_val: np.array - Validation features - y_val: np.array - Validation labels - batch_size : int - The batch size for the data loader - num_classes : int - The number of classes the model will be trained on - train_splitter: NumPyDataSplitter - Data splitter for train dataset. - valid_splitter: NumPyDataSplitter - Data splitter for validation dataset. - **kwargs: Additional arguments to pass to the function - + X_train (np.array): The training features. + y_train (np.array): The training labels. + X_valid (np.array): The validation features. + y_valid (np.array): The validation labels. + batch_size (int, optional): The batch size for the data loader. + Defaults to 1. + num_classes (int, optional): The number of classes the model will + be trained on. Defaults to None. + train_splitter (NumPyDataSplitter, optional): The object that + splits the training data. Defaults to None. + valid_splitter (NumPyDataSplitter, optional): The object that + splits the validation data. Defaults to None. """ super().__init__(batch_size) @@ -68,47 +57,55 @@ def __init__(self, X_train, y_train, X_valid, y_valid, if num_classes is None: num_classes = np.unique(self.y_train).shape[0] - print(f'Inferred {num_classes} classes from the provided labels...') + print( + f'Inferred {num_classes} classes from the provided labels...') self.num_classes = num_classes self.train_splitter = self._get_splitter_or_default(train_splitter) self.valid_splitter = self._get_splitter_or_default(valid_splitter) @staticmethod def _get_splitter_or_default(value): + """Returns the provided splitter if it's a NumPyDataSplitter, otherwise + returns a default EqualNumPyDataSplitter. + + Args: + value (NumPyDataSplitter): The provided data splitter. + + Raises: + NotImplementedError: If the provided data splitter is not a + NumPyDataSplitter. + """ if value is None: return EqualNumPyDataSplitter() if isinstance(value, NumPyDataSplitter): return value else: - raise NotImplementedError(f'Data splitter {value} is not supported') + raise NotImplementedError( + f'Data splitter {value} is not supported') def split(self, num_collaborators): - """Create a Federated Dataset for each of the collaborators. + """Splits the dataset into equal parts for each collaborator and + returns a list of FederatedDataSet objects. Args: - num_collaborators: int - Collaborators to split the dataset between - shuffle: boolean - Should the dataset be randomized? - equally: boolean - Should each collaborator get the same amount of data? + num_collaborators (int): The number of collaborators to split the + dataset between. Returns: - list[FederatedDataSets] - A dataset slice for each collaborator + FederatedDataSets (list): A list of FederatedDataSet objects, each + representing a slice of the dataset for a collaborator. 
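A brief usage sketch of split with the default EqualNumPyDataSplitter; the array shapes and labels are arbitrary.

    import numpy as np

    X = np.random.rand(100, 10)
    y = np.random.randint(0, 2, size=100)
    fl_data = FederatedDataSet(X, y, X[:20], y[:20],
                               batch_size=32, num_classes=2)

    shard_one, shard_two = fl_data.split(num_collaborators=2)
    print(shard_one.get_train_data_size(),
          shard_two.get_train_data_size())  # roughly 50 and 50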
""" train_idx = self.train_splitter.split(self.y_train, num_collaborators) valid_idx = self.valid_splitter.split(self.y_valid, num_collaborators) return [ - FederatedDataSet( - self.X_train[train_idx[i]], - self.y_train[train_idx[i]], - self.X_valid[valid_idx[i]], - self.y_valid[valid_idx[i]], - batch_size=self.batch_size, - num_classes=self.num_classes, - train_splitter=self.train_splitter, - valid_splitter=self.valid_splitter - ) for i in range(num_collaborators) + FederatedDataSet(self.X_train[train_idx[i]], + self.y_train[train_idx[i]], + self.X_valid[valid_idx[i]], + self.y_valid[valid_idx[i]], + batch_size=self.batch_size, + num_classes=self.num_classes, + train_splitter=self.train_splitter, + valid_splitter=self.valid_splitter) + for i in range(num_collaborators) ] diff --git a/openfl/federated/data/loader.py b/openfl/federated/data/loader.py index 5f50705b72..8e4aa8f8f4 100644 --- a/openfl/federated/data/loader.py +++ b/openfl/federated/data/loader.py @@ -1,72 +1,79 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """DataLoader module.""" class DataLoader: - """Federated Learning Data Loader Class.""" + """A base class used to represent a Federated Learning Data Loader. + + This class should be inherited by any data loader class specific to a + machine learning framework. + + Attributes: + None + """ def __init__(self, **kwargs): - """ - Instantiate the data object. + """Initializes the DataLoader object. - Returns: - None + Args: + kwargs: Additional arguments to pass to the function. """ pass def get_feature_shape(self): - """ - Get the shape of an example feature array. + """Returns the shape of an example feature array. - Returns: - tuple: shape of an example feature array + Raises: + NotImplementedError: This method must be implemented by a child + class. """ raise NotImplementedError def get_train_loader(self, **kwargs): - """ - Get training data loader. + """Returns the data loader for the training data. - Returns: - loader object (class defined by inheritor) + Args: + kwargs: Additional arguments to pass to the function. + + Raises: + NotImplementedError: This method must be implemented by a child + class. """ raise NotImplementedError def get_valid_loader(self): - """ - Get validation data loader. + """Returns the data loader for the validation data. - Returns: - loader object (class defined by inheritor) + Raises: + NotImplementedError: This method must be implemented by a child + class. """ raise NotImplementedError def get_infer_loader(self): - """ - Get inferencing data loader. + """Returns the data loader for inferencing data. - Returns - ------- - loader object (class defined by inheritor) + Raises: + NotImplementedError: This method must be implemented by a child + class. """ return NotImplementedError def get_train_data_size(self): - """ - Get total number of training samples. + """Returns the total number of training samples. - Returns: - int: number of training samples + Raises: + NotImplementedError: This method must be implemented by a child + class. """ raise NotImplementedError def get_valid_data_size(self): - """ - Get total number of validation samples. + """Returns the total number of validation samples. - Returns: - int: number of validation samples + Raises: + NotImplementedError: This method must be implemented by a child + class. 
""" raise NotImplementedError diff --git a/openfl/federated/data/loader_gandlf.py b/openfl/federated/data/loader_gandlf.py index b29533e307..bd1e8a3ba1 100644 --- a/openfl/federated/data/loader_gandlf.py +++ b/openfl/federated/data/loader_gandlf.py @@ -1,14 +1,29 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """PyTorchDataLoader module.""" + from .loader import DataLoader class GaNDLFDataLoaderWrapper(DataLoader): - """Data Loader for the Generally Nuanced Deep Learning Framework (GaNDLF).""" + """A class used to represent a data loader for the Generally Nuanced Deep + Learning Framework (GaNDLF). + + Attributes: + train_csv (str): Path to the training CSV file. + val_csv (str): Path to the validation CSV file. + train_dataloader (DataLoader): DataLoader object for the training data. + val_dataloader (DataLoader): DataLoader object for the validation data. + feature_shape (tuple): Shape of an example feature array. + """ def __init__(self, data_path, feature_shape): + """Initializes the GaNDLFDataLoaderWrapper object. + + Args: + data_path (str): The path to the directory containing the data. + feature_shape (tuple): The shape of an example feature array. + """ self.train_csv = data_path + '/train.csv' self.val_csv = data_path + '/valid.csv' self.train_dataloader = None @@ -16,50 +31,63 @@ def __init__(self, data_path, feature_shape): self.feature_shape = feature_shape def set_dataloaders(self, train_dataloader, val_dataloader): + """Sets the data loaders for the training and validation data. + + Args: + train_dataloader (DataLoader): The DataLoader object for the + training data. + val_dataloader (DataLoader): The DataLoader object for the + validation data. + """ self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader def get_feature_shape(self): - """Get the shape of an example feature array. + """Returns the shape of an example feature array. Returns: - tuple: shape of an example feature array + tuple: The shape of an example feature array. """ return self.feature_shape def get_train_loader(self, batch_size=None, num_batches=None): - """ - Get training data loader. + """Returns the data loader for the training data. + + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). + num_batches (int, optional): The number of batches for the data + loader (default is None). - Returns - ------- - loader object + Returns: + DataLoader: The DataLoader object for the training data. """ return self.train_dataloader def get_valid_loader(self, batch_size=None): - """ - Get validation data loader. + """Returns the data loader for the validation data. + + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). Returns: - loader object + DataLoader: The DataLoader object for the validation data. """ return self.val_dataloader def get_train_data_size(self): - """ - Get total number of training samples. + """Returns the total number of training samples. Returns: - int: number of training samples + int: The total number of training samples. """ return len(self.train_dataloader.dataset) def get_valid_data_size(self): - """ - Get total number of validation samples. + """Returns the total number of validation samples. Returns: - int: number of validation samples + int: The total number of validation samples. 
""" return len(self.val_dataloader.dataset) diff --git a/openfl/federated/data/loader_keras.py b/openfl/federated/data/loader_keras.py index 9a3834cafe..546605b965 100644 --- a/openfl/federated/data/loader_keras.py +++ b/openfl/federated/data/loader_keras.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """KerasDataLoader module.""" import numpy as np @@ -9,18 +8,22 @@ class KerasDataLoader(DataLoader): - """Federation Data Loader for TensorFlow Models.""" + """A class used to represent a Federation Data Loader for Keras models. + + Attributes: + batch_size (int): Size of batches used for all data loaders. + X_train (np.array): Training features. + y_train (np.array): Training labels. + X_valid (np.array): Validation features. + y_valid (np.array): Validation labels. + """ def __init__(self, batch_size, **kwargs): - """ - Instantiate the data object. + """Initializes the KerasDataLoader object. Args: - batch_size: Size of batches used for all data loaders - kwargs: consumes all un-used kwargs - - Returns: - None + batch_size (int): The size of batches used for all data loaders. + kwargs: Additional arguments to pass to the function. """ self.batch_size = batch_size self.X_train = None @@ -33,66 +36,73 @@ def __init__(self, batch_size, **kwargs): # define self.X_train, self.y_train, self.X_valid, and self.y_valid def get_feature_shape(self): - """Get the shape of an example feature array. + """Returns the shape of an example feature array. Returns: - tuple: shape of an example feature array + tuple: The shape of an example feature array. """ return self.X_train[0].shape def get_train_loader(self, batch_size=None, num_batches=None): - """ - Get training data loader. + """Returns the data loader for the training data. + + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). + num_batches (int, optional): The number of batches for the data + loader (default is None). - Returns - ------- - loader object + Returns: + DataLoader: The DataLoader object for the training data. """ - return self._get_batch_generator(X=self.X_train, y=self.y_train, batch_size=batch_size, + return self._get_batch_generator(X=self.X_train, + y=self.y_train, + batch_size=batch_size, num_batches=num_batches) def get_valid_loader(self, batch_size=None): - """ - Get validation data loader. + """Returns the data loader for the validation data. + + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). Returns: - loader object + DataLoader: The DataLoader object for the validation data. """ - return self._get_batch_generator(X=self.X_valid, y=self.y_valid, batch_size=batch_size) + return self._get_batch_generator(X=self.X_valid, + y=self.y_valid, + batch_size=batch_size) def get_train_data_size(self): - """ - Get total number of training samples. + """Returns the total number of training samples. Returns: - int: number of training samples + int: The total number of training samples. """ return self.X_train.shape[0] def get_valid_data_size(self): - """ - Get total number of validation samples. + """Returns the total number of validation samples. Returns: - int: number of validation samples + int: The total number of validation samples. """ return self.X_valid.shape[0] @staticmethod def _batch_generator(X, y, idxs, batch_size, num_batches): - """ - Generate batch of data. + """Generates batches of data. 
Args: - X: input data - y: label data - idxs: The index of the dataset - batch_size: The batch size for the data loader - num_batches: The number of batches + X (np.array): The input data. + y (np.array): The label data. + idxs (np.array): The index of the dataset. + batch_size (int): The batch size for the data loader. + num_batches (int): The number of batches. Yields: - tuple: input data, label data - + tuple: The input data and label data for each batch. """ for i in range(num_batches): a = i * batch_size @@ -100,14 +110,17 @@ def _batch_generator(X, y, idxs, batch_size, num_batches): yield X[idxs[a:b]], y[idxs[a:b]] def _get_batch_generator(self, X, y, batch_size, num_batches=None): - """ - Return the dataset generator. + """Returns the dataset generator. Args: - X: input data - y: label data - batch_size: The batch size for the data loader + X (np.array): The input data. + y (np.array): The label data. + batch_size (int): The batch size for the data loader. + num_batches (int, optional): The number of batches (default is + None). + Returns: + generator: The dataset generator. """ if batch_size is None: batch_size = self.batch_size diff --git a/openfl/federated/data/loader_pt.py b/openfl/federated/data/loader_pt.py index af005e745c..457abd6c74 100644 --- a/openfl/federated/data/loader_pt.py +++ b/openfl/federated/data/loader_pt.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """PyTorchDataLoader module.""" from math import ceil @@ -11,18 +10,25 @@ class PyTorchDataLoader(DataLoader): - """Federation Data Loader for TensorFlow Models.""" + """A class used to represent a Federation Data Loader for PyTorch models. + + Attributes: + batch_size (int): Size of batches used for all data loaders. + X_train (np.array): Training features. + y_train (np.array): Training labels. + X_valid (np.array): Validation features. + y_valid (np.array): Validation labels. + random_seed (int, optional): Random seed for data shuffling. + """ def __init__(self, batch_size, random_seed=None, **kwargs): - """ - Instantiate the data object. + """Initializes the PyTorchDataLoader object with the batch size, random + seed, and any additional arguments. Args: - batch_size: Size of batches used for all data loaders - kwargs: consumes all un-used kwargs - - Returns: - None + batch_size (int): The size of batches used for all data loaders. + random_seed (int, optional): Random seed for data shuffling. + kwargs: Additional arguments to pass to the function. """ self.batch_size = batch_size self.X_train = None @@ -36,66 +42,73 @@ def __init__(self, batch_size, random_seed=None, **kwargs): # define self.X_train, self.y_train, self.X_valid, and self.y_valid def get_feature_shape(self): - """Get the shape of an example feature array. + """Returns the shape of an example feature array. Returns: - tuple: shape of an example feature array + tuple: The shape of an example feature array. """ return self.X_train[0].shape def get_train_loader(self, batch_size=None, num_batches=None): - """ - Get training data loader. + """Returns the data loader for the training data. + + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). + num_batches (int, optional): The number of batches for the data + loader (default is None). - Returns - ------- - loader object + Returns: + DataLoader: The DataLoader object for the training data. 
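Every `_batch_generator` in these loaders follows the same slicing pattern over a shuffled index array. A standalone sketch of that pattern with illustrative data, to make explicit what each yielded batch contains:

```python
import numpy as np

def batch_generator(X, y, idxs, batch_size, num_batches):
    # consecutive slices of the shuffled index array, as in the loaders above
    for i in range(num_batches):
        a, b = i * batch_size, (i + 1) * batch_size
        yield X[idxs[a:b]], y[idxs[a:b]]

X = np.arange(20).reshape(10, 2)
y = np.arange(10)
idxs = np.random.permutation(len(X))   # draw samples without replacement
for x_batch, y_batch in batch_generator(X, y, idxs, batch_size=4, num_batches=2):
    print(x_batch.shape, y_batch.shape)   # (4, 2) (4,)
```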
""" - return self._get_batch_generator( - X=self.X_train, y=self.y_train, batch_size=batch_size, num_batches=num_batches) + return self._get_batch_generator(X=self.X_train, + y=self.y_train, + batch_size=batch_size, + num_batches=num_batches) def get_valid_loader(self, batch_size=None): - """ - Get validation data loader. + """Returns the data loader for the validation data. + + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). Returns: - loader object + DataLoader: The DataLoader object for the validation data. """ - return self._get_batch_generator(X=self.X_valid, y=self.y_valid, batch_size=batch_size) + return self._get_batch_generator(X=self.X_valid, + y=self.y_valid, + batch_size=batch_size) def get_train_data_size(self): - """ - Get total number of training samples. + """Returns the total number of training samples. Returns: - int: number of training samples + int: The total number of training samples. """ return self.X_train.shape[0] def get_valid_data_size(self): - """ - Get total number of validation samples. + """Returns the total number of validation samples. Returns: - int: number of validation samples + int: The total number of validation samples. """ return self.X_valid.shape[0] @staticmethod def _batch_generator(X, y, idxs, batch_size, num_batches): - """ - Generate batch of data. + """Generates batches of data. Args: - X: input data - y: label data - idxs: The index of the dataset - batch_size: The batch size for the data loader - num_batches: The number of batches + X (np.array): The input data. + y (np.array): The label data. + idxs (np.array): The index of the dataset. + batch_size (int): The batch size for the data loader. + num_batches (int): The number of batches. Yields: - tuple: input data, label data - + tuple: The input data and label data for each batch. """ for i in range(num_batches): a = i * batch_size @@ -103,14 +116,17 @@ def _batch_generator(X, y, idxs, batch_size, num_batches): yield X[idxs[a:b]], y[idxs[a:b]] def _get_batch_generator(self, X, y, batch_size, num_batches=None): - """ - Return the dataset generator. + """Returns the dataset generator. Args: - X: input data - y: label data - batch_size: The batch size for the data loader + X (np.array): The input data. + y (np.array): The label data. + batch_size (int): The batch size for the data loader. + num_batches (int, optional): The number of batches (default is + None). + Returns: + generator: The dataset generator. """ if batch_size is None: batch_size = self.batch_size diff --git a/openfl/federated/data/loader_tf.py b/openfl/federated/data/loader_tf.py index e678ec8387..aa29fe1c97 100644 --- a/openfl/federated/data/loader_tf.py +++ b/openfl/federated/data/loader_tf.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """TensorflowDataLoader module.""" import numpy as np @@ -9,18 +8,24 @@ class TensorFlowDataLoader(DataLoader): - """Federation Data Loader for TensorFlow Models.""" + """A class used to represent a Federation Data Loader for TensorFlow + models. + + Attributes: + batch_size (int): Size of batches used for all data loaders. + X_train (np.array): Training features. + y_train (np.array): Training labels. + X_valid (np.array): Validation features. + y_valid (np.array): Validation labels. + """ def __init__(self, batch_size, **kwargs): - """ - Instantiate the data object. + """Initializes the TensorFlowDataLoader object with the batch size and + any additional arguments. 
Args: - batch_size: Size of batches used for all data loaders - kwargs: consumes all un-used kwargs - - Returns: - None + batch_size (int): The size of batches used for all data loaders. + kwargs: Additional arguments to pass to the function. """ self.batch_size = batch_size self.X_train = None @@ -33,66 +38,70 @@ def __init__(self, batch_size, **kwargs): # define self.X_train, self.y_train, self.X_valid, and self.y_valid def get_feature_shape(self): - """ - Get the shape of an example feature array. + """Returns the shape of an example feature array. Returns: - tuple: shape of an example feature array + tuple: The shape of an example feature array. """ return self.X_train[0].shape def get_train_loader(self, batch_size=None): - """ - Get training data loader. + """Returns the data loader for the training data. - Returns - ------- - loader object + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). + + Returns: + DataLoader: The DataLoader object for the training data. """ - return self._get_batch_generator(X=self.X_train, y=self.y_train, batch_size=batch_size) + return self._get_batch_generator(X=self.X_train, + y=self.y_train, + batch_size=batch_size) def get_valid_loader(self, batch_size=None): - """ - Get validation data loader. + """Returns the data loader for the validation data. + + Args: + batch_size (int, optional): The batch size for the data loader + (default is None). Returns: - loader object + DataLoader: The DataLoader object for the validation data. """ - return self._get_batch_generator(X=self.X_valid, y=self.y_valid, batch_size=batch_size) + return self._get_batch_generator(X=self.X_valid, + y=self.y_valid, + batch_size=batch_size) def get_train_data_size(self): - """ - Get total number of training samples. + """Returns the total number of training samples. Returns: - int: number of training samples + int: The total number of training samples. """ return self.X_train.shape[0] def get_valid_data_size(self): - """ - Get total number of validation samples. + """Returns the total number of validation samples. Returns: - int: number of validation samples + int: The total number of validation samples. """ return self.X_valid.shape[0] @staticmethod def _batch_generator(X, y, idxs, batch_size, num_batches): - """ - Generate batch of data. + """Generates batches of data. Args: - X: input data - y: label data - idxs: The index of the dataset - batch_size: The batch size for the data loader - num_batches: The number of batches + X (np.array): The input data. + y (np.array): The label data. + idxs (np.array): The index of the dataset. + batch_size (int): The batch size for the data loader. + num_batches (int): The number of batches. Yields: - tuple: input data, label data - + tuple: The input data and label data for each batch. """ for i in range(num_batches): a = i * batch_size @@ -100,14 +109,15 @@ def _batch_generator(X, y, idxs, batch_size, num_batches): yield X[idxs[a:b]], y[idxs[a:b]] def _get_batch_generator(self, X, y, batch_size): - """ - Return the dataset generator. + """Returns the dataset generator. Args: - X: input data - y: label data - batch_size: The batch size for the data loader + X (np.array): The input data. + y (np.array): The label data. + batch_size (int): The batch size for the data loader. + Returns: + generator: The dataset generator. 
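As the `# define self.X_train, ...` comments in these constructors indicate, concrete loaders are expected to populate the train/validation arrays themselves. A hypothetical subclass sketch (random data, illustrative shapes and class name) of how that is usually done:

```python
import numpy as np

from openfl.federated.data.loader_tf import TensorFlowDataLoader

class RandomImageDataLoader(TensorFlowDataLoader):
    """Toy loader that fills the expected attributes with random data."""

    def __init__(self, batch_size, **kwargs):
        super().__init__(batch_size, **kwargs)
        self.X_train = np.random.rand(100, 28, 28, 1).astype(np.float32)
        self.y_train = np.random.randint(0, 10, size=100)
        self.X_valid = np.random.rand(20, 28, 28, 1).astype(np.float32)
        self.y_valid = np.random.randint(0, 10, size=20)

loader = RandomImageDataLoader(batch_size=16)
print(loader.get_feature_shape())     # (28, 28, 1)
print(loader.get_train_data_size())   # 100
```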
""" if batch_size is None: batch_size = self.batch_size diff --git a/openfl/federated/plan/__init__.py b/openfl/federated/plan/__init__.py index 0733ba5d90..9c172a9d7b 100644 --- a/openfl/federated/plan/__init__.py +++ b/openfl/federated/plan/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Plan package.""" from .plan import Plan diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 33c1723475..96ae982eba 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Plan module.""" from hashlib import sha384 from importlib import import_module @@ -27,13 +26,44 @@ class Plan: - """Federated Learning plan.""" + """A class used to represent a Federated Learning plan. + + This class provides methods to manage and manipulate federated learning + plans. + + Attributes: + logger (Logger): Logger instance for the class. + config (dict): Dictionary containing patched plan definition. + authorized_cols (list): Authorized collaborator list. + cols_data_paths (dict): Collaborator data paths dictionary. + collaborator_ (Collaborator): Collaborator object. + aggregator_ (Aggregator): Aggregator object. + assigner_ (Assigner): Assigner object. + loader_ (DataLoader): Data loader object. + runner_ (TaskRunner): Task runner object. + server_ (AggregatorGRPCServer): gRPC server object. + client_ (AggregatorGRPCClient): gRPC client object. + pipe_ (CompressionPipeline): Compression pipeline object. + straggler_policy_ (StragglerHandlingPolicy): Straggler handling policy. + hash_ (str): Hash of the instance. + name_ (str): Name of the instance. + serializer_ (SerializerPlugin): Serializer plugin. + """ logger = getLogger(__name__) @staticmethod def load(yaml_path: Path, default: dict = None): - """Load the plan from YAML file.""" + """Load the plan from YAML file. + + Args: + yaml_path (Path): Path to the YAML file. + default (dict, optional): Default plan configuration. + Defaults to {}. + + Returns: + dict: Plan configuration loaded from the YAML file. + """ if default is None: default = {} if yaml_path and yaml_path.exists(): @@ -42,7 +72,14 @@ def load(yaml_path: Path, default: dict = None): @staticmethod def dump(yaml_path, config, freeze=False): - """Dump the plan config to YAML file.""" + """Dump the plan config to YAML file. + + Args: + yaml_path (Path): Path to the YAML file. + config (dict): Plan configuration to be dumped. + freeze (bool, optional): Flag to freeze the plan. Defaults to + False. + """ class NoAliasDumper(SafeDumper): @@ -64,30 +101,33 @@ def ignore_aliases(self, data): yaml_path.write_text(dump(config)) @staticmethod - def parse(plan_config_path: Path, cols_config_path: Path = None, - data_config_path: Path = None, gandlf_config_path=None, + def parse(plan_config_path: Path, + cols_config_path: Path = None, + data_config_path: Path = None, + gandlf_config_path=None, resolve=True): - """ - Parse the Federated Learning plan. + """Parse the Federated Learning plan. 
Args: - plan_config_path (string): The filepath to the federated learning - plan - cols_config_path (string): The filepath to the federation - collaborator list [optional] - data_config_path (string): The filepath to the federation - collaborator data configuration - [optional] - override_config_path (string): The filepath to a yaml file - that overrides the configuration - [optional] + plan_config_path (Path): The filepath to the Federated Learning + plan. + cols_config_path (Path, optional): The filepath to the Federation + collaborator list. Defaults to None. + data_config_path (Path, optional): The filepath to the Federation + collaborator data configuration. Defaults to None. + gandlf_config_path (Path, optional): The filepath to a yaml file + that overrides the configuration. Defaults to None. + resolve (bool, optional): Flag to resolve the plan settings. + Defaults to True. + Returns: - A federated learning plan object + Plan: A Federated Learning plan object. """ try: plan = Plan() - plan.config = Plan.load(plan_config_path) # load plan configuration + plan.config = Plan.load( + plan_config_path) # load plan configuration plan.name = plan_config_path.name plan.files = [plan_config_path] # collect all the plan files @@ -132,12 +172,13 @@ def parse(plan_config_path: Path, cols_config_path: Path = None, gandlf_config = Plan.load(Path(gandlf_config_path)) # check for some defaults - gandlf_config['output_dir'] = gandlf_config.get('output_dir', '.') - plan.config['task_runner']['settings']['gandlf_config'] = gandlf_config + gandlf_config['output_dir'] = gandlf_config.get( + 'output_dir', '.') + plan.config['task_runner']['settings'][ + 'gandlf_config'] = gandlf_config plan.authorized_cols = Plan.load(cols_config_path).get( - 'collaborators', [] - ) + 'collaborators', []) # TODO: Does this need to be a YAML file? Probably want to use key # value as the plan hash @@ -163,22 +204,24 @@ def parse(plan_config_path: Path, cols_config_path: Path = None, return plan except Exception: - Plan.logger.exception(f'Parsing Federated Learning Plan : ' - f'[red]FAILURE[/] : [blue]{plan_config_path}[/].', - extra={'markup': True}) + Plan.logger.exception( + f'Parsing Federated Learning Plan : ' + f'[red]FAILURE[/] : [blue]{plan_config_path}[/].', + extra={'markup': True}) raise @staticmethod def build(template, settings, **override): - """ - Create an instance of a openfl Component or Federated DataLoader/TaskRunner. + """Create an instance of a openfl Component or Federated + DataLoader/TaskRunner. Args: - template: Fully qualified class template path - settings: Keyword arguments to class constructor + template (str): Fully qualified class template path. + settings (dict): Keyword arguments to class constructor. + override (dict): Additional settings to override the default ones. Returns: - A Python object + object: A Python object. """ class_name = splitext(template)[1].strip('.') module_path = splitext(template)[0] @@ -196,27 +239,28 @@ def build(template, settings, **override): @staticmethod def import_(template): - """ - Import an instance of a openfl Component or Federated DataLoader/TaskRunner. + """Import an instance of a openfl Component or Federated + DataLoader/TaskRunner. Args: - template: Fully qualified object path + template (str): Fully qualified object path. Returns: - A Python object + object: A Python object. 
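`Plan.build` and `Plan.import_` both resolve plan `template` strings the same way: split off the trailing class name, import the module, and look the attribute up. A minimal self-contained sketch of that resolution; the `NoCompressionPipeline` template comes from the compression-pipeline defaults used later in this file:

```python
from importlib import import_module
from os.path import splitext

def build(template, settings, **override):
    # 'pkg.module.ClassName' -> module path + class name, then instantiate
    class_name = splitext(template)[1].strip('.')
    module_path = splitext(template)[0]
    module = import_module(module_path)
    return getattr(module, class_name)(**settings, **override)

pipe = build('openfl.pipelines.NoCompressionPipeline', {})
print(type(pipe).__name__)   # NoCompressionPipeline
```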
""" class_name = splitext(template)[1].strip('.') module_path = splitext(template)[0] - Plan.logger.info(f'Importing [red]🡆[/] Object [red]{class_name}[/] ' - f'from [red]{module_path}[/] Module.', - extra={'markup': True}) + Plan.logger.info( + f'Importing [red]🡆[/] Object [red]{class_name}[/] ' + f'from [red]{module_path}[/] Module.', + extra={'markup': True}) module = import_module(module_path) instance = getattr(module, class_name) return instance def __init__(self): - """Initialize.""" + """Initializes the Plan object.""" self.config = {} # dictionary containing patched plan definition self.authorized_cols = [] # authorized collaborator list self.cols_data_paths = {} # collaborator data paths dict @@ -261,20 +305,19 @@ def resolve(self): if self.config['network'][SETTINGS]['agg_port'] == AUTO: self.config['network'][SETTINGS]['agg_port'] = int( - self.hash[:8], 16 - ) % (60999 - 49152) + 49152 + self.hash[:8], 16) % (60999 - 49152) + 49152 def get_assigner(self): """Get the plan task assigner.""" aggregation_functions_by_task = None assigner_function = None try: - aggregation_functions_by_task = self.restore_object('aggregation_function_obj.pkl') + aggregation_functions_by_task = self.restore_object( + 'aggregation_function_obj.pkl') assigner_function = self.restore_object('task_assigner_obj.pkl') except Exception as exc: self.logger.error( - f'Failed to load aggregation and assigner functions: {exc}' - ) + f'Failed to load aggregation and assigner functions: {exc}') self.logger.info('Using Task Runner API workflow') if assigner_function: self.assigner_ = Assigner( @@ -285,13 +328,10 @@ def get_assigner(self): ) else: # Backward compatibility - defaults = self.config.get( - 'assigner', - { - TEMPLATE: 'openfl.component.Assigner', - SETTINGS: {} - } - ) + defaults = self.config.get('assigner', { + TEMPLATE: 'openfl.component.Assigner', + SETTINGS: {} + }) defaults[SETTINGS]['authorized_cols'] = self.authorized_cols defaults[SETTINGS]['rounds_to_train'] = self.rounds_to_train @@ -324,43 +364,58 @@ def get_tasks(self): return tasks def get_aggregator(self, tensor_dict=None): - """Get federation aggregator.""" - defaults = self.config.get('aggregator', - { - TEMPLATE: 'openfl.component.Aggregator', - SETTINGS: {} - }) + """Get federation aggregator. + + This method retrieves the federation aggregator. If the aggregator + does not exist, it is built using the configuration settings and the + provided tensor dictionary. + + Args: + tensor_dict (dict, optional): The initial tensor dictionary to use + when building the aggregator. Defaults to None. + + Returns: + self.aggregator_ (Aggregator): The federation aggregator. + + Raises: + TypeError: If the log_metric_callback is not a callable object or + cannot be imported from code. 
+ """ + defaults = self.config.get('aggregator', { + TEMPLATE: 'openfl.component.Aggregator', + SETTINGS: {} + }) defaults[SETTINGS]['aggregator_uuid'] = self.aggregator_uuid defaults[SETTINGS]['federation_uuid'] = self.federation_uuid defaults[SETTINGS]['authorized_cols'] = self.authorized_cols defaults[SETTINGS]['assigner'] = self.get_assigner() defaults[SETTINGS]['compression_pipeline'] = self.get_tensor_pipe() - defaults[SETTINGS]['straggler_handling_policy'] = self.get_straggler_handling_policy() + defaults[SETTINGS][ + 'straggler_handling_policy'] = self.get_straggler_handling_policy() log_metric_callback = defaults[SETTINGS].get('log_metric_callback') if log_metric_callback: if isinstance(log_metric_callback, dict): log_metric_callback = Plan.import_(**log_metric_callback) elif not callable(log_metric_callback): - raise TypeError(f'log_metric_callback should be callable object ' - f'or be import from code part, get {log_metric_callback}') + raise TypeError( + f'log_metric_callback should be callable object ' + f'or be import from code part, get {log_metric_callback}') defaults[SETTINGS]['log_metric_callback'] = log_metric_callback if self.aggregator_ is None: - self.aggregator_ = Plan.build(**defaults, initial_tensor_dict=tensor_dict) + self.aggregator_ = Plan.build(**defaults, + initial_tensor_dict=tensor_dict) return self.aggregator_ def get_tensor_pipe(self): """Get data tensor pipeline.""" - defaults = self.config.get( - 'compression_pipeline', - { - TEMPLATE: 'openfl.pipelines.NoCompressionPipeline', - SETTINGS: {} - } - ) + defaults = self.config.get('compression_pipeline', { + TEMPLATE: 'openfl.pipelines.NoCompressionPipeline', + SETTINGS: {} + }) if self.pipe_ is None: self.pipe_ = Plan.build(**defaults) @@ -370,13 +425,10 @@ def get_tensor_pipe(self): def get_straggler_handling_policy(self): """Get straggler handling policy.""" template = 'openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling' - defaults = self.config.get( - 'straggler_handling_policy', - { - TEMPLATE: template, - SETTINGS: {} - } - ) + defaults = self.config.get('straggler_handling_policy', { + TEMPLATE: template, + SETTINGS: {} + }) if self.straggler_policy_ is None: self.straggler_policy_ = Plan.build(**defaults) @@ -385,16 +437,21 @@ def get_straggler_handling_policy(self): # legacy api (TaskRunner subclassing) def get_data_loader(self, collaborator_name): - """Get data loader.""" - defaults = self.config.get('data_loader', - { - TEMPLATE: 'openfl.federation.DataLoader', - SETTINGS: {} - }) + """Get data loader for a specific collaborator. + + Args: + collaborator_name (str): Name of the collaborator. + + Returns: + DataLoader: Data loader for the specified collaborator. + """ + defaults = self.config.get('data_loader', { + TEMPLATE: 'openfl.federation.DataLoader', + SETTINGS: {} + }) defaults[SETTINGS]['data_path'] = self.cols_data_paths[ - collaborator_name - ] + collaborator_name] if self.loader_ is None: self.loader_ = Plan.build(**defaults) @@ -403,18 +460,32 @@ def get_data_loader(self, collaborator_name): # Python interactive api def initialize_data_loader(self, data_loader, shard_descriptor): - """Get data loader.""" + """Initialize data loader. + + Args: + data_loader (DataLoader): Data loader to initialize. + shard_descriptor (ShardDescriptor): Descriptor of the data shard. + + Returns: + DataLoader: Initialized data loader. 
+ """ data_loader.shard_descriptor = shard_descriptor return data_loader # legacy api (TaskRunner subclassing) def get_task_runner(self, data_loader): - """Get task runner.""" - defaults = self.config.get('task_runner', - { - TEMPLATE: 'openfl.federation.TaskRunner', - SETTINGS: {} - }) + """Get task runner. + + Args: + data_loader (DataLoader): Data loader for the tasks. + + Returns: + TaskRunner: Task runner for the tasks. + """ + defaults = self.config.get('task_runner', { + TEMPLATE: 'openfl.federation.TaskRunner', + SETTINGS: {} + }) defaults[SETTINGS]['data_loader'] = data_loader @@ -427,13 +498,25 @@ def get_task_runner(self, data_loader): return self.runner_ # Python interactive api - def get_core_task_runner(self, data_loader=None, + def get_core_task_runner(self, + data_loader=None, model_provider=None, task_keeper=None): - """Get task runner.""" + """Get core task runner. + + Args: + data_loader (DataLoader, optional): Data loader for the tasks. + Defaults to None. + model_provider (ModelProvider, optional): Provider for the model. + Defaults to None. + task_keeper (TaskKeeper, optional): Keeper for the tasks. Defaults + to None. + + Returns: + CoreTaskRunner: Core task runner for the tasks. + """ defaults = self.config.get( - 'task_runner', - { + 'task_runner', { TEMPLATE: 'openfl.federated.task.task_runner.CoreTaskRunner', SETTINGS: {} }) @@ -448,7 +531,8 @@ def get_core_task_runner(self, data_loader=None, self.runner_.set_task_provider(task_keeper) framework_adapter = Plan.build( - self.config['task_runner']['required_plugin_components']['framework_adapters'], {}) + self.config['task_runner']['required_plugin_components'] + ['framework_adapters'], {}) # This step initializes tensorkeys # Which have no sens if task provider is not set up @@ -456,16 +540,42 @@ def get_core_task_runner(self, data_loader=None, return self.runner_ - def get_collaborator(self, collaborator_name, root_certificate=None, private_key=None, - certificate=None, task_runner=None, client=None, shard_descriptor=None): - """Get collaborator.""" - defaults = self.config.get( - 'collaborator', - { - TEMPLATE: 'openfl.component.Collaborator', - SETTINGS: {} - } - ) + def get_collaborator(self, + collaborator_name, + root_certificate=None, + private_key=None, + certificate=None, + task_runner=None, + client=None, + shard_descriptor=None): + """Get collaborator. + + This method retrieves a collaborator. If the collaborator does not + exist, it is built using the configuration settings and the provided + parameters. + + Args: + collaborator_name (str): Name of the collaborator. + root_certificate (str, optional): Root certificate for the + collaborator. Defaults to None. + private_key (str, optional): Private key for the collaborator. + Defaults to None. + certificate (str, optional): Certificate for the collaborator. + Defaults to None. + task_runner (TaskRunner, optional): Task runner for the + collaborator. Defaults to None. + client (Client, optional): Client for the collaborator. Defaults + to None. + shard_descriptor (ShardDescriptor, optional): Descriptor of the + data shard. Defaults to None. + + Returns: + self.collaborator_ (Collaborator): The collaborator instance. 
+ """ + defaults = self.config.get('collaborator', { + TEMPLATE: 'openfl.component.Collaborator', + SETTINGS: {} + }) defaults[SETTINGS]['collaborator_name'] = collaborator_name defaults[SETTINGS]['aggregator_uuid'] = self.aggregator_uuid @@ -474,15 +584,21 @@ def get_collaborator(self, collaborator_name, root_certificate=None, private_key if task_runner is not None: defaults[SETTINGS]['task_runner'] = task_runner else: - # Here we support new interactive api as well as old task_runner subclassing interface - # If Task Runner class is placed incide openfl `task-runner` subpackage it is - # a part of the New API and it is a part of OpenFL kernel. - # If Task Runner is placed elsewhere, somewhere in user workspace, than it is - # a part of the old interface and we follow legacy initialization procedure. - if 'openfl.federated.task.task_runner' in self.config['task_runner']['template']: + # Here we support new interactive api as well as old task_runner + # subclassing interface. + # If Task Runner class is placed incide openfl `task-runner` + # subpackage it is a part of the New API and it is a part of + # OpenFL kernel. + # If Task Runner is placed elsewhere, somewhere in user workspace, + # than it is a part of the old interface and we follow legacy + # initialization procedure. + if 'openfl.federated.task.task_runner' in self.config[ + 'task_runner']['template']: # Interactive API - model_provider, task_keeper, data_loader = self.deserialize_interface_objects() - data_loader = self.initialize_data_loader(data_loader, shard_descriptor) + model_provider, task_keeper, data_loader = self.deserialize_interface_objects( + ) + data_loader = self.initialize_data_loader( + data_loader, shard_descriptor) defaults[SETTINGS]['task_runner'] = self.get_core_task_runner( data_loader=data_loader, model_provider=model_provider, @@ -490,7 +606,8 @@ def get_collaborator(self, collaborator_name, root_certificate=None, private_key else: # TaskRunner subclassing API data_loader = self.get_data_loader(collaborator_name) - defaults[SETTINGS]['task_runner'] = self.get_task_runner(data_loader) + defaults[SETTINGS]['task_runner'] = self.get_task_runner( + data_loader) defaults[SETTINGS]['compression_pipeline'] = self.get_tensor_pipe() defaults[SETTINGS]['task_config'] = self.config.get('tasks', {}) @@ -498,22 +615,37 @@ def get_collaborator(self, collaborator_name, root_certificate=None, private_key defaults[SETTINGS]['client'] = client else: defaults[SETTINGS]['client'] = self.get_client( - collaborator_name, - self.aggregator_uuid, - self.federation_uuid, - root_certificate, - private_key, - certificate - ) + collaborator_name, self.aggregator_uuid, self.federation_uuid, + root_certificate, private_key, certificate) if self.collaborator_ is None: self.collaborator_ = Plan.build(**defaults) return self.collaborator_ - def get_client(self, collaborator_name, aggregator_uuid, federation_uuid, - root_certificate=None, private_key=None, certificate=None): - """Get gRPC client for the specified collaborator.""" + def get_client(self, + collaborator_name, + aggregator_uuid, + federation_uuid, + root_certificate=None, + private_key=None, + certificate=None): + """Get gRPC client for the specified collaborator. + + Args: + collaborator_name (str): Name of the collaborator. + aggregator_uuid (str): UUID of the aggregator. + federation_uuid (str): UUID of the federation. + root_certificate (str, optional): Root certificate for the + collaborator. Defaults to None. 
+ private_key (str, optional): Private key for the collaborator. + Defaults to None. + certificate (str, optional): Certificate for the collaborator. + Defaults to None. + + Returns: + AggregatorGRPCClient: gRPC client for the specified collaborator. + """ common_name = collaborator_name if not root_certificate or not private_key or not certificate: root_certificate = 'cert/cert_chain.crt' @@ -536,8 +668,25 @@ def get_client(self, collaborator_name, aggregator_uuid, federation_uuid, return self.client_ - def get_server(self, root_certificate=None, private_key=None, certificate=None, **kwargs): - """Get gRPC server of the aggregator instance.""" + def get_server(self, + root_certificate=None, + private_key=None, + certificate=None, + **kwargs): + """Get gRPC server of the aggregator instance. + + Args: + root_certificate (str, optional): Root certificate for the server. + Defaults to None. + private_key (str, optional): Private key for the server. Defaults + to None. + certificate (str, optional): Certificate for the server. Defaults + to None. + **kwargs: Additional keyword arguments. + + Returns: + AggregatorGRPCServer: gRPC server of the aggregator instance. + """ common_name = self.config['network'][SETTINGS]['agg_addr'].lower() if not root_certificate or not private_key or not certificate: @@ -561,9 +710,20 @@ def get_server(self, root_certificate=None, private_key=None, certificate=None, return self.server_ - def interactive_api_get_server(self, *, tensor_dict, root_certificate, certificate, - private_key, tls): - """Get gRPC server of the aggregator instance.""" + def interactive_api_get_server(self, *, tensor_dict, root_certificate, + certificate, private_key, tls): + """Get gRPC server of the aggregator instance for interactive API. + + Args: + tensor_dict (dict): Dictionary of tensors. + root_certificate (str): Root certificate for the server. + certificate (str): Certificate for the server. + private_key (str): Private key for the server. + tls (bool): Whether to use Transport Layer Security. + + Returns: + AggregatorGRPCServer: gRPC server of the aggregator instance. + """ server_args = self.config['network'][SETTINGS] # patch certificates @@ -580,30 +740,49 @@ def interactive_api_get_server(self, *, tensor_dict, root_certificate, certifica return self.server_ def deserialize_interface_objects(self): - """Deserialize objects for TaskRunner.""" + """Deserialize objects for TaskRunner. + + Returns: + tuple: Tuple containing the deserialized objects. + """ api_layer = self.config['api_layer'] filenames = [ - 'model_interface_file', - 'tasks_interface_file', + 'model_interface_file', 'tasks_interface_file', 'dataloader_interface_file' ] - return (self.restore_object(api_layer['settings'][filename]) for filename in filenames) + return (self.restore_object(api_layer['settings'][filename]) + for filename in filenames) def get_serializer_plugin(self, **kwargs): """Get serializer plugin. - This plugin is used for serialization of interfaces in new interactive API + This plugin is used for serialization of interfaces in new interactive + API. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + SerializerPlugin: Serializer plugin. 
""" if self.serializer_ is None: if 'api_layer' not in self.config: # legacy API return None - required_plugin_components = self.config['api_layer']['required_plugin_components'] + required_plugin_components = self.config['api_layer'][ + 'required_plugin_components'] serializer_plugin = required_plugin_components['serializer_plugin'] self.serializer_ = Plan.build(serializer_plugin, kwargs) return self.serializer_ def restore_object(self, filename): - """Deserialize an object.""" + """Deserialize an object. + + Args: + filename (str): Name of the file. + + Returns: + object: Deserialized object. + """ serializer_plugin = self.get_serializer_plugin() if serializer_plugin is None: return None diff --git a/openfl/federated/task/__init__.py b/openfl/federated/task/__init__.py index b5efcdcd50..6567f03a72 100644 --- a/openfl/federated/task/__init__.py +++ b/openfl/federated/task/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Task package.""" import pkgutil @@ -15,7 +14,6 @@ from .runner import TaskRunner # NOQA - if pkgutil.find_loader('tensorflow'): from .runner_tf import TensorFlowTaskRunner # NOQA from .runner_keras import KerasTaskRunner # NOQA diff --git a/openfl/federated/task/fl_model.py b/openfl/federated/task/fl_model.py index 45a5181a42..4728d37da8 100644 --- a/openfl/federated/task/fl_model.py +++ b/openfl/federated/task/fl_model.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """FederatedModel module.""" import inspect @@ -9,29 +8,46 @@ class FederatedModel(TaskRunner): - """ - A wrapper that adapts to Tensorflow and Pytorch models to a federated context. + """A wrapper that adapts to Tensorflow and Pytorch models to a federated + context. + + This class provides methods to manage and manipulate federated models. - Args: - model : tensorflow/keras (function) , pytorch (class) + Attributes: + build_model (function or class): tensorflow/keras (function) , pytorch (class). For keras/tensorflow model, expects a function that returns the - model definition + model definition. For pytorch models, expects a class (not an instance) containing - the model definition and forward function - optimizer : lambda function (only required for pytorch) + the model definition and forward function. + lambda_opt (function): Lambda function for the optimizer (only + required for pytorch). The optimizer should be definied within a lambda function. This allows the optimizer to be attached to the federated models spawned for each collaborator. - loss_fn : pytorch loss_fun (only required for pytorch) + model (Model): The built model. + optimizer (Optimizer): Optimizer for the model. + runner (TaskRunner): Task runner for the model. + loss_fn (Loss): PyTorch Loss function for the model (only required for + pytorch). + tensor_dict_split_fn_kwargs (dict): Keyword arguments for the tensor + dict split function. """ def __init__(self, build_model, optimizer=None, loss_fn=None, **kwargs): - """Initialize. + """Initializes the FederatedModel object. - Args: - model: build_model function - **kwargs: Additional parameters to pass to the function + Sets up the initial state of the FederatedModel object, initializing + various components needed for the federated model. + Args: + build_model (function or class): Function that returns the model + definition or Class containing the model definition and + forward function. + optimizer (function, optional): Lambda function defining the + optimizer. 
Defaults to None. + loss_fn (function, optional): PyTorch loss function. Defaults to + None. + **kwargs: Additional parameters to pass to the function. """ super().__init__(**kwargs) @@ -47,8 +63,8 @@ def __init__(self, build_model, optimizer=None, loss_fn=None, **kwargs): if hasattr(self.model, 'forward'): self.runner.forward = self.model.forward else: - self.model = self.build_model( - self.feature_shape, self.data_loader.num_classes) + self.model = self.build_model(self.feature_shape, + self.data_loader.num_classes) from .runner_keras import KerasTaskRunner self.runner = KerasTaskRunner(**kwargs) self.optimizer = self.model.optimizer @@ -67,33 +83,38 @@ def __init__(self, build_model, optimizer=None, loss_fn=None, **kwargs): self.initialize_tensorkeys_for_functions() def __getattribute__(self, attr): - """Direct call into self.runner methods if necessary.""" - if attr in ['reset_opt_vars', 'initialize_globals', - 'set_tensor_dict', 'get_tensor_dict', - 'get_required_tensorkeys_for_function', - 'initialize_tensorkeys_for_functions', - 'save_native', 'load_native', 'rebuild_model', - 'set_optimizer_treatment', - 'train', 'train_batches', 'validate']: + """Direct call into self.runner methods if necessary. + Args: + attr (str): Attribute name. + + Returns: + attribute: Requested attribute from the runner or the class itself. + """ + if attr in [ + 'reset_opt_vars', 'initialize_globals', 'set_tensor_dict', + 'get_tensor_dict', 'get_required_tensorkeys_for_function', + 'initialize_tensorkeys_for_functions', 'save_native', + 'load_native', 'rebuild_model', 'set_optimizer_treatment', + 'train', 'train_batches', 'validate' + ]: return self.runner.__getattribute__(attr) return super(FederatedModel, self).__getattribute__(attr) def setup(self, num_collaborators, **kwargs): - """ - Create new models for all of the collaborators in the experiment. + """Create new models for all of the collaborators in the experiment. Args: - num_collaborators: Number of experiment collaborators + num_collaborators (int): Number of experiment collaborators. + **kwargs: Additional parameters to pass to the function. Returns: - List of models + List[FederatedModel]: List of models for each collaborator. """ return [ - FederatedModel( - self.build_model, - optimizer=self.lambda_opt, - loss_fn=self.loss_fn, - data_loader=data_slice, - **kwargs - ) - for data_slice in self.data_loader.split(num_collaborators)] + FederatedModel(self.build_model, + optimizer=self.lambda_opt, + loss_fn=self.loss_fn, + data_loader=data_slice, + **kwargs) + for data_slice in self.data_loader.split(num_collaborators) + ] diff --git a/openfl/federated/task/runner.py b/openfl/federated/task/runner.py index 8d8c8a885c..e8835a889e 100644 --- a/openfl/federated/task/runner.py +++ b/openfl/federated/task/runner.py @@ -1,8 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -""" -Mixin class for FL models. No default implementation. +"""Mixin class for FL models. No default implementation. Each framework will likely have its own baseclass implementation (e.g. TensorflowTaskRunner) that uses this mixin. @@ -15,16 +13,28 @@ class TaskRunner: - """Federated Learning Task Runner Class.""" - - def __init__(self, data_loader, tensor_dict_split_fn_kwargs: dict = None, **kwargs): - """ - Intialize. + """Federated Learning Task Runner Class. + + Attributes: + data_loader: The data_loader object. + tensor_dict_split_fn_kwargs (dict): Key word arguments for determining + which parameters to hold out from aggregation. 
+ logger (logging.Logger): Logger object for logging events. + opt_treatment (str): Treatment of current instance optimizer. + """ + + def __init__(self, + data_loader, + tensor_dict_split_fn_kwargs: dict = None, + **kwargs): + """Intializes the TaskRunner object. Args: data_loader: The data_loader object - tensor_dict_split_fn_kwargs: (Default=None) - **kwargs: Additional parameters to pass to the function + tensor_dict_split_fn_kwargs (dict, optional): Key word arguments + for determining which parameters to hold out from aggregation. + Default is None. + **kwargs: Additional parameters to pass to the function. """ self.data_loader = data_loader self.feature_shape = self.data_loader.get_feature_shape() @@ -34,7 +44,8 @@ def __init__(self, data_loader, tensor_dict_split_fn_kwargs: dict = None, **kwar # If set to none, an empty dict will be passed, currently resulting in # the defaults: # be held out - # holdout_tensor_names=[] # params with these names will be held out # NOQA:E800 + # holdout_tensor_names=[] # NOQA:E800 + # params with these names will be held out # NOQA:E800 # TODO: params are restored from protobufs as float32 numpy arrays, so # non-floats arrays and non-arrays are not currently supported for # passing to and from protobuf (and as a result for aggregation) - for @@ -46,16 +57,26 @@ def __init__(self, data_loader, tensor_dict_split_fn_kwargs: dict = None, **kwar self.set_logger() def set_logger(self): - """Set up the log object.""" + """Set up the log object. + + Returns: + None + """ self.logger = getLogger(__name__) def set_optimizer_treatment(self, opt_treatment): - """Change the treatment of current instance optimizer.""" + """Change the treatment of current instance optimizer. + + Args: + opt_treatment (str): The optimizer treatment. + + Returns: + None + """ self.opt_treatment = opt_treatment def get_data_loader(self): - """ - Get the data_loader object. + """Get the data_loader object. Serves up batches and provides info regarding data_loader. @@ -68,19 +89,20 @@ def set_data_loader(self, data_loader): """Set data_loader object. Args: - data_loader: data_loader object to set + data_loader: data_loader object to set. + Returns: None """ - if data_loader.get_feature_shape() != self.data_loader.get_feature_shape(): + if data_loader.get_feature_shape( + ) != self.data_loader.get_feature_shape(): raise ValueError( 'The data_loader feature shape is not compatible with model.') self.data_loader = data_loader def get_train_data_size(self): - """ - Get the number of training examples. + """Get the number of training examples. It will be used for weighted averaging in aggregation. @@ -90,8 +112,7 @@ def get_train_data_size(self): return self.data_loader.get_train_data_size() def get_valid_data_size(self): - """ - Get the number of examples. + """Get the number of examples. It will be used for weighted averaging in aggregation. @@ -101,63 +122,65 @@ def get_valid_data_size(self): return self.data_loader.get_valid_data_size() def train_batches(self, num_batches=None, use_tqdm=False): - """ - Perform the training for a specified number of batches. + """Perform the training for a specified number of batches. Is expected to perform draws randomly, without replacement until data is exausted. Then data is replaced and shuffled and draws continue. Args: - num_batches: Number of batches to train - use_tdqm (bool): True = use tqdm progress bar (Default=False) + num_batches (int, optional): Number of batches to train. Default + is None. 
+ use_tqdm (bool, optional): If True, use tqdm to print a progress + bar. Default is False. Returns: - dict: {: } + dict: {: }. """ raise NotImplementedError def validate(self): - """ - Run validation. + """Run validation. Returns: - dict: {: } + dict: {: }. """ raise NotImplementedError def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - When running a task, a map of named tensorkeys \ - must be provided to the function as dependencies. + """When running a task, a map of named tensorkeys must be provided to + the function as dependencies. + + Args: + func_name (str): The function name. + **kwargs: Additional parameters to pass to the function. Returns: - list: (TensorKey(tensor_name, origin, round_number)) + list: List of required TensorKey. (TensorKey(tensor_name, origin, + round_number)) """ raise NotImplementedError def get_tensor_dict(self, with_opt_vars): - """ - Get the weights. + """Get the weights. Args: with_opt_vars (bool): Specify if we also want to get the variables - of the optimizer. + of the optimizer. Returns: - dict: The weight dictionary {: } + dict: The weight dictionary {: }. """ raise NotImplementedError def set_tensor_dict(self, tensor_dict, with_opt_vars): - """ - Set the model weights with a tensor dictionary:\ + """Set the model weights with a tensor dictionary: {: }. Args: tensor_dict (dict): The model weights dictionary. with_opt_vars (bool): Specify if we also want to set the variables - of the optimizer. + of the optimizer. Returns: None @@ -165,12 +188,15 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars): raise NotImplementedError def reset_opt_vars(self): - """Reinitialize the optimizer variables.""" + """Reinitialize the optimizer variables. + + Returns: + None + """ raise NotImplementedError def initialize_globals(self): - """ - Initialize all global variables. + """Initialize all global variables. Returns: None @@ -178,18 +204,18 @@ def initialize_globals(self): raise NotImplementedError def load_native(self, filepath, **kwargs): - """ - Load model state from a filepath in ML-framework "native" format, \ - e.g. PyTorch pickled models. + """Load model state from a filepath in ML-framework "native" format, + e.g. PyTorch pickled models. May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs. Args: - filepath (string): Path to frame-work specific file to load. For - frameworks that use multiple files, this string must be used to - derive the other filepaths. - kwargs : For future-proofing + filepath (str): Path to frame-work specific file to load. + For frameworks that use multiple files, this string must be + used to derive the other filepaths. + **kwargs: Additional parameters to pass to the function. For + future-proofing. Returns: None @@ -197,17 +223,17 @@ def load_native(self, filepath, **kwargs): raise NotImplementedError def save_native(self, filepath, **kwargs): - """ - Save model state in ML-framework "native" format, e.g. PyTorch pickled models. + """Save model state in ML-framework "native" format, e.g. PyTorch + pickled models. May save one file or multiple files, depending on the framework. Args: - filepath (string): If framework stores a single file, this should - be a single file path. - Frameworks that store multiple files may need to derive the other - paths from this path. - kwargs : For future-proofing + filepath (str): If framework stores a single file, this should be + a single file path. 
Frameworks that store multiple files may + need to derive the other paths from this path. + **kwargs: Additional parameters to pass to the function. For + future-proofing. Returns: None diff --git a/openfl/federated/task/runner_gandlf.py b/openfl/federated/task/runner_gandlf.py index 5157ceb7a3..a758e0922f 100644 --- a/openfl/federated/task/runner_gandlf.py +++ b/openfl/federated/task/runner_gandlf.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """GaNDLFTaskRunner module.""" from copy import deepcopy @@ -23,18 +22,42 @@ class GaNDLFTaskRunner(TaskRunner): - """GaNDLF Model class for Federated Learning.""" - - def __init__( - self, - gandlf_config: Union[str, dict] = None, - device: str = None, - **kwargs - ): - """Initialize. + """GaNDLF Model class for Federated Learning. + + This class provides methods to manage and manipulate GaNDLF models in a + federated learning context. + + Attributes: + build_model (function or class): Function or Class to build the model. + lambda_opt (function): Lambda function for the optimizer. + model (Model): The built model. + optimizer (Optimizer): Optimizer for the model. + scheduler (Scheduler): Scheduler for the model. + params (Parameters): Parameters for the model. + device (str): Device for the model. + training_round_completed (bool): Whether the training round has been + completed. + required_tensorkeys_for_function (dict): Required tensorkeys for + function. + tensor_dict_split_fn_kwargs (dict): Keyword arguments for the tensor + dict split function. + """ + + def __init__(self, + gandlf_config: Union[str, dict] = None, + device: str = None, + **kwargs): + """Initializes the GaNDLFTaskRunner object. + + Sets up the initial state of the GaNDLFTaskRunner object, initializing + various components needed for the federated model. + Args: - device (string): Compute device (default="cpu") - **kwargs: Additional parameters to pass to the functions + gandlf_config (Union[str, dict], optional): GaNDLF configuration. + Can be a string (file path) or a dictionary. Defaults to None. + device (str, optional): Compute device. Defaults to None + (default="cpu"). + **kwargs: Additional parameters to pass to the function. """ super().__init__(**kwargs) @@ -57,9 +80,10 @@ def __init__( val_loader, scheduler, params, - ) = create_pytorch_objects( - gandlf_config, train_csv=train_csv, val_csv=val_csv, device=device - ) + ) = create_pytorch_objects(gandlf_config, + train_csv=train_csv, + val_csv=val_csv, + device=device) self.model = model self.optimizer = optimizer self.scheduler = scheduler @@ -79,13 +103,20 @@ def __init__( # overwrite attribute to account for one optimizer param (in every # child model that does not overwrite get and set tensordict) that is # not a numpy array - self.tensor_dict_split_fn_kwargs.update({ - 'holdout_tensor_names': ['__opt_state_needed'] - }) + self.tensor_dict_split_fn_kwargs.update( + {'holdout_tensor_names': ['__opt_state_needed']}) def rebuild_model(self, round_num, input_tensor_dict, validation=False): - """ - Parse tensor names and update weights of model. Handles the optimizer treatment. + """Parse tensor names and update weights of model. Handles the + optimizer treatment. + + Args: + round_num: The current round number. + input_tensor_dict (dict): The input tensor dictionary used to + update the weights of the model. + validation (bool, optional): A flag indicating whether the model + is in validation. Defaults to False. 
+ Returns: None """ @@ -99,29 +130,36 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False): else: self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - def validate(self, col_name, round_num, input_tensor_dict, - use_tqdm=False, **kwargs): - """Validate. - Run validation of the model on the local data. + def validate(self, + col_name, + round_num, + input_tensor_dict, + use_tqdm=False, + **kwargs): + """Run validation of the model on the local data. + Args: - col_name: Name of the collaborator - round_num: What round is it - input_tensor_dict: Required input tensors (for model) - use_tqdm (bool): Use tqdm to print a progress bar (Default=True) - kwargs: Key word arguments passed to GaNDLF main_run + col_name (str): Name of the collaborator. + round_num (int): Current round number. + input_tensor_dict (dict): Required input tensors (for model). + use_tqdm (bool, optional): Use tqdm to print a progress bar. + Defaults to False. + **kwargs: Key word arguments passed to GaNDLF main_run. + Returns: - global_output_dict: Tensors to send back to the aggregator - local_output_dict: Tensors to maintain in the local TensorDB + output_tensor_dict (dict): Tensors to send back to the aggregator. + {} (dict): Tensors to maintain in the local TensorDB. """ self.rebuild_model(round_num, input_tensor_dict, validation=True) self.model.eval() - epoch_valid_loss, epoch_valid_metric = validate_network(self.model, - self.data_loader.val_dataloader, - self.scheduler, - self.params, - round_num, - mode="validation") + epoch_valid_loss, epoch_valid_metric = validate_network( + self.model, + self.data_loader.val_dataloader, + self.scheduler, + self.params, + round_num, + mode="validation") self.logger.info(epoch_valid_loss) self.logger.info(epoch_valid_metric) @@ -135,7 +173,8 @@ def validate(self, col_name, round_num, input_tensor_dict, tags = ('metric', suffix) output_tensor_dict = {} - valid_loss_tensor_key = TensorKey('valid_loss', origin, round_num, True, tags) + valid_loss_tensor_key = TensorKey('valid_loss', origin, round_num, + True, tags) output_tensor_dict[valid_loss_tensor_key] = np.array(epoch_valid_loss) for k, v in epoch_valid_metric.items(): tensor_key = TensorKey(f'valid_{k}', origin, round_num, True, tags) @@ -144,30 +183,28 @@ def validate(self, col_name, round_num, input_tensor_dict, # Empty list represents metrics that should only be stored locally return output_tensor_dict, {} - def train(self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1, **kwargs): - """Train batches. - Train the model on the requested number of batches. + def train(self, + col_name, + round_num, + input_tensor_dict, + use_tqdm=False, + epochs=1, + **kwargs): + """Train the model on the requested number of batches. + Args: - col_name : Name of the collaborator - round_num : What round is it - input_tensor_dict : Required input tensors (for model) - use_tqdm (bool) : Use tqdm to print a progress bar (Default=True) - epochs : The number of epochs to train - crossfold_test : Whether or not to use cross fold trainval/test - to evaluate the quality of the model under fine tuning - (this uses a separate prameter to pass in the data and - config used) - crossfold_test_data_csv : Data csv used to define data used in crossfold test. - This csv does not itself define the folds, just - defines the total data to be used. - crossfold_val_n : number of folds to use for the train,val level - of the nested crossfold. 
- corssfold_test_n : number of folds to use for the trainval,test level - of the nested crossfold. - kwargs : Key word arguments passed to GaNDLF main_run + col_name (str): Name of the collaborator. + round_num (int): Current round number. + input_tensor_dict (dict): Required input tensors (for model). + use_tqdm (bool, optional): Use tqdm to print a progress bar. + Defaults to False. + epochs (int, optional): The number of epochs to train. Defaults to 1. + **kwargs: Key word arguments passed to GaNDLF main_run. + Returns: - global_output_dict : Tensors to send back to the aggregator - local_output_dict : Tensors to maintain in the local TensorDB + global_tensor_dict (dict): Tensors to send back to the aggregator. + local_tensor_dict (dict): Tensors to maintain in the local + TensorDB. """ self.rebuild_model(round_num, input_tensor_dict) # set to "training" mode @@ -176,10 +213,9 @@ def train(self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1 self.logger.info(f'Run {epoch} epoch of {round_num} round') # FIXME: do we want to capture these in an array # rather than simply taking the last value? - epoch_train_loss, epoch_train_metric = train_network(self.model, - self.data_loader.train_dataloader, - self.optimizer, - self.params) + epoch_train_loss, epoch_train_metric = train_network( + self.model, self.data_loader.train_dataloader, self.optimizer, + self.params) # output model tensors (Doesn't include TensorKey) tensor_dict = self.get_tensor_dict(with_opt_vars=True) @@ -221,11 +257,13 @@ def train(self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1 def get_tensor_dict(self, with_opt_vars=False): """Return the tensor dictionary. + Args: with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) + optimizer tensors (Default=False). + Returns: - dict: Tensor dictionary {**dict, **optimizer_dict} + state (dict): Tensor dictionary {**dict, **optimizer_dict} """ # Gets information regarding tensor model layers and optimizer state. # FIXME: self.parameters() instead? Unclear if load_state_dict() or @@ -242,6 +280,15 @@ def get_tensor_dict(self, with_opt_vars=False): return state def _get_weights_names(self, with_opt_vars=False): + """Get the names of the weights. + + Args: + with_opt_vars (bool, optional): Include the optimizer variables. + Defaults to False. + + Returns: + list: List of weight names. + """ # Gets information regarding tensor model layers and optimizer state. # FIXME: self.parameters() instead? Unclear if load_state_dict() or # simple assignment is better @@ -258,41 +305,53 @@ def _get_weights_names(self, with_opt_vars=False): def set_tensor_dict(self, tensor_dict, with_opt_vars=False): """Set the tensor dictionary. + Args: - tensor_dict: The tensor dictionary - with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) + tensor_dict (dict): The tensor dictionary. + with_opt_vars (bool, optional): Include the optimizer tensors. + Defaults to False. """ - set_pt_model_from_tensor_dict(self.model, tensor_dict, self.device, with_opt_vars) + set_pt_model_from_tensor_dict(self.model, tensor_dict, self.device, + with_opt_vars) def get_optimizer(self): - """Get the optimizer of this instance.""" + """Get the optimizer of this instance. + + Returns: + Optimizer: The optimizer of this instance. 
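The metric dictionaries returned by `validate()` and `train()` above are keyed with `TensorKey` tuples. A small illustration of that shape; the metric names, values, tag suffix, and the `openfl.utilities` import path are inferred from the surrounding code and common OpenFL usage, so treat them as assumptions:

```python
import numpy as np

from openfl.utilities import TensorKey   # assumed import path

origin = 'collaborator1'
round_num = 0
tags = ('metric', 'validate_agg')        # suffix reflects which model was validated

output_tensor_dict = {
    TensorKey('valid_loss', origin, round_num, True, tags): np.array(0.42),
    TensorKey('valid_dice', origin, round_num, True, tags): np.array(0.87),
}
for key, value in output_tensor_dict.items():
    print(key.tensor_name, float(value))
```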
+ """ return self.optimizer def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - Get the required tensors for specified function that could be called \ - as part of a task. By default, this is just all of the layers and \ - optimizer of the model. + """Get the required tensors for specified function that could be called + as part of a task. + + By default, this is just all of the layers and optimizer of the model. + Args: - func_name + func_name (str): Function name. + **kwargs: Additional keyword arguments. + Returns: - list : [TensorKey] + required_tensorkeys_for_function (list): List of required + TensorKey. """ if func_name == 'validate': local_model = 'apply=' + str(kwargs['apply']) - return self.required_tensorkeys_for_function[func_name][local_model] + return self.required_tensorkeys_for_function[func_name][ + local_model] else: return self.required_tensorkeys_for_function[func_name] def initialize_tensorkeys_for_functions(self, with_opt_vars=False): """Set the required tensors for all publicly accessible task methods. + By default, this is just all of the layers and optimizer of the model. Custom tensors should be added to this function. + Args: - None - Returns: - None + with_opt_vars (bool, optional): Include the optimizer tensors. + Defaults to False. """ # TODO there should be a way to programmatically iterate through # all of the methods in the class and declare the tensors. @@ -300,29 +359,23 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) if not with_opt_vars: global_model_dict_val = global_model_dict local_model_dict_val = local_model_dict else: output_model_dict = self.get_tensor_dict(with_opt_vars=False) global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( - self.logger, - output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, + **self.tensor_dict_split_fn_kwargs) self.required_tensorkeys_for_function['train'] = [ - TensorKey( - tensor_name, 'GLOBAL', 0, False, ('model',) - ) for tensor_name in global_model_dict + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) + for tensor_name in global_model_dict ] self.required_tensorkeys_for_function['train'] += [ - TensorKey( - tensor_name, 'LOCAL', 0, False, ('model',) - ) for tensor_name in local_model_dict + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) + for tensor_name in local_model_dict ] # Validation may be performed on local or aggregated (global) model, @@ -330,57 +383,58 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): self.required_tensorkeys_for_function['validate'] = {} # TODO This is not stateless. 
The optimizer will not be self.required_tensorkeys_for_function['validate']['apply=local'] = [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('trained', )) for tensor_name in { **global_model_dict_val, **local_model_dict_val - }] + } + ] self.required_tensorkeys_for_function['validate']['apply=global'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) for tensor_name in global_model_dict_val ] self.required_tensorkeys_for_function['validate']['apply=global'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) for tensor_name in local_model_dict_val ] - def load_native(self, filepath, model_state_dict_key='model_state_dict', - optimizer_state_dict_key='optimizer_state_dict', **kwargs): - """ - Load model and optimizer states from a pickled file specified by \ - filepath. model_/optimizer_state_dict args can be specified if needed. \ + def load_native(self, + filepath, + model_state_dict_key='model_state_dict', + optimizer_state_dict_key='optimizer_state_dict', + **kwargs): + """Load model and optimizer states from a pickled file specified by + filepath. model_/optimizer_state_dict args can be specified if needed. Uses pt.load(). + Args: - filepath (string) : Path to pickle file created - by pt.save(). - model_state_dict_key (string) : key for model state dict - in pickled file. - optimizer_state_dict_key (string) : key for optimizer state dict - in picked file. - kwargs : unused - Returns: - None + filepath (str): Path to pickle file created by pt.save(). + model_state_dict_key (str, optional): Key for model state dict in + pickled file. Defaults to 'model_state_dict'. + optimizer_state_dict_key (str, optional): Key for optimizer state + dict in picked file. Defaults to 'optimizer_state_dict'. + **kwargs: Additional keyword arguments. """ pickle_dict = pt.load(filepath) self.model.load_state_dict(pickle_dict[model_state_dict_key]) self.optimizer.load_state_dict(pickle_dict[optimizer_state_dict_key]) - def save_native(self, filepath, model_state_dict_key='model_state_dict', - optimizer_state_dict_key='optimizer_state_dict', **kwargs): - """ - Save model and optimizer states in a picked file specified by the \ - filepath. model_/optimizer_state_dicts are stored in the keys provided. \ + def save_native(self, + filepath, + model_state_dict_key='model_state_dict', + optimizer_state_dict_key='optimizer_state_dict', + **kwargs): + """Save model and optimizer states in a picked file specified by the + filepath. model_/optimizer_state_dicts are stored in the keys provided. Uses pt.save(). + Args: - filepath (string) : Path to pickle file to be - created by pt.save(). - model_state_dict_key (string) : key for model state dict - in pickled file. - optimizer_state_dict_key (string) : key for optimizer state - dict in picked file. - kwargs : unused - Returns: - None + filepath (str): Path to pickle file to be created by pt.save(). + model_state_dict_key (str, optional): Key for model state dict in + pickled file. Defaults to 'model_state_dict'. + optimizer_state_dict_key (str, optional): Key for optimizer state + dict in picked file. Defaults to 'optimizer_state_dict'. + **kwargs: Additional keyword arguments. 
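# Illustrative sketch, not part of the patch: the torch.save / torch.load round
# trip that save_native / load_native wrap, using the same two default
# dictionary keys. The checkpoint filename is arbitrary.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    },
    'checkpoint_example.pt')

restored = torch.load('checkpoint_example.pt')
model.load_state_dict(restored['model_state_dict'])
optimizer.load_state_dict(restored['optimizer_state_dict'])
print(sorted(restored.keys()))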
""" pickle_dict = { model_state_dict_key: self.model.state_dict(), @@ -389,51 +443,57 @@ def save_native(self, filepath, model_state_dict_key='model_state_dict', pt.save(pickle_dict, filepath) def reset_opt_vars(self): - """ - Reset optimizer variables. - Resets the optimizer variables - """ + """Reset optimizer variables.""" pass -def create_tensorkey_dicts(tensor_dict, - metric_dict, - col_name, - round_num, - logger, - tensor_dict_split_fn_kwargs): +def create_tensorkey_dicts(tensor_dict, metric_dict, col_name, round_num, + logger, tensor_dict_split_fn_kwargs): + """Create dictionaries of TensorKeys for global and local tensors. + + Args: + tensor_dict (dict): Dictionary of tensors. + metric_dict (dict): Dictionary of metrics. + col_name (str): Name of the collaborator. + round_num (int): Current round number. + logger (Logger): Logger instance. + tensor_dict_split_fn_kwargs (dict): Keyword arguments for the tensor + dict split function. + + Returns: + global_tensor_dict (dict): Dictionary of global TensorKeys. + local_tensor_dict (dict): Dictionary of local TensorKeys. + """ origin = col_name - tags = ('trained',) + tags = ('trained', ) output_metric_dict = {} for k, v in metric_dict.items(): - tk = TensorKey(k, origin, round_num, True, ('metric',)) + tk = TensorKey(k, origin, round_num, True, ('metric', )) output_metric_dict[tk] = np.array(v) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - logger, tensor_dict, **tensor_dict_split_fn_kwargs - ) + logger, tensor_dict, **tensor_dict_split_fn_kwargs) # Create global tensorkeys global_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in global_model_dict.items() + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in global_model_dict.items() } # Create tensorkeys that should stay local local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in local_model_dict.items() + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in local_model_dict.items() } # The train/validate aggregated function of the next round will look # for the updated model parameters. # This ensures they will be resolved locally next_local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num + 1, False, ('model',)): nparray - for tensor_name, nparray in local_model_dict.items()} - - global_tensor_dict = { - **output_metric_dict, - **global_tensorkey_model_dict + TensorKey(tensor_name, origin, round_num + 1, False, ('model', )): + nparray + for tensor_name, nparray in local_model_dict.items() } + + global_tensor_dict = {**output_metric_dict, **global_tensorkey_model_dict} local_tensor_dict = { **local_tensorkey_model_dict, **next_local_tensorkey_model_dict @@ -442,14 +502,18 @@ def create_tensorkey_dicts(tensor_dict, return global_tensor_dict, local_tensor_dict -def set_pt_model_from_tensor_dict(model, tensor_dict, device, with_opt_vars=False): - """Set the tensor dictionary. +def set_pt_model_from_tensor_dict(model, + tensor_dict, + device, + with_opt_vars=False): + """Set the tensor dictionary for the PyTorch model. + Args: - model: the pytorch nn.module object - tensor_dict: The tensor dictionary - device: the device where the tensor values need to be sent - with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) + model (Model): The PyTorch model. 
+ tensor_dict (dict): Tensor dictionary. + device (str): Device for the model. + with_opt_vars (bool, optional): Include the optimizer tensors. + Defaults to False. """ # Sets tensors for model layers and optimizer state. # FIXME: model.parameters() instead? Unclear if load_state_dict() or @@ -477,12 +541,17 @@ def set_pt_model_from_tensor_dict(model, tensor_dict, device, with_opt_vars=Fals def _derive_opt_state_dict(opt_state_dict): """Separate optimizer tensors from the tensor dictionary. + Flattens the optimizer state dict so as to have key, value pairs with values as numpy arrays. The keys have sufficient info to restore opt_state_dict using expand_derived_opt_state_dict. + Args: - opt_state_dict: The optimizer state dictionary + opt_state_dict (dict): Optimizer state dictionary. + + Returns: + derived_opt_state_dict (dict): Optimizer state dictionary. """ derived_opt_state_dict = {} @@ -497,8 +566,7 @@ def _derive_opt_state_dict(opt_state_dict): # dictionary value. example_state_key = opt_state_dict['param_groups'][0]['params'][0] example_state_subkeys = set( - opt_state_dict['state'][example_state_key].keys() - ) + opt_state_dict['state'][example_state_key].keys()) # We assume that the state collected for all params in all param groups is # the same. @@ -507,12 +575,12 @@ def _derive_opt_state_dict(opt_state_dict): # Using assert statements to break the routine if these assumptions are # incorrect. for state_key in opt_state_dict['state'].keys(): - assert example_state_subkeys == set(opt_state_dict['state'][state_key].keys()) + assert example_state_subkeys == set( + opt_state_dict['state'][state_key].keys()) for state_subkey in example_state_subkeys: assert (isinstance( opt_state_dict['state'][example_state_key][state_subkey], - pt.Tensor) - == isinstance( + pt.Tensor) == isinstance( opt_state_dict['state'][state_key][state_subkey], pt.Tensor)) @@ -522,10 +590,8 @@ def _derive_opt_state_dict(opt_state_dict): # tensor or not. state_subkey_tags = [] for state_subkey in state_subkeys: - if isinstance( - opt_state_dict['state'][example_state_key][state_subkey], - pt.Tensor - ): + if isinstance(opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor): state_subkey_tags.append('istensor') else: state_subkey_tags.append('') @@ -539,33 +605,36 @@ def _derive_opt_state_dict(opt_state_dict): for idx, param_id in enumerate(group['params']): for subkey, tag in state_subkeys_and_tags: if tag == 'istensor': - new_v = opt_state_dict['state'][param_id][ - subkey].cpu().numpy() + new_v = opt_state_dict['state'][param_id][subkey].cpu( + ).numpy() else: new_v = np.array( - [opt_state_dict['state'][param_id][subkey]] - ) - derived_opt_state_dict[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v + [opt_state_dict['state'][param_id][subkey]]) + derived_opt_state_dict[ + f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v nb_params_per_group.append(idx + 1) # group lengths are also helpful for reconstructing # original opt_state_dict structure derived_opt_state_dict['__opt_group_lengths'] = np.array( - nb_params_per_group - ) + nb_params_per_group) return derived_opt_state_dict def expand_derived_opt_state_dict(derived_opt_state_dict, device): """Expand the optimizer state dictionary. + Takes a derived opt_state_dict and creates an opt_state_dict suitable as input for load_state_dict for restoring optimizer state. - Reconstructing state_subkeys_and_tags using the example key - prefix, "__opt_state_0_0_", certain to be present. 
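# Illustrative sketch, not part of the patch: a look at the nested structure
# that _derive_opt_state_dict flattens. SGD with momentum keeps a per-parameter
# 'momentum_buffer' tensor under state_dict()['state'], keyed by parameter
# index, and that is what gets serialised into numpy arrays under
# '__opt_state_{group}_{param}_{tag}_{subkey}' names.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the momentum buffers exist.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

state_dict = optimizer.state_dict()
print(list(state_dict.keys()))  # ['state', 'param_groups']
for param_id, subdict in state_dict['state'].items():
    for subkey, value in subdict.items():
        print(param_id, subkey, type(value).__name__, tuple(value.shape))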
+ Reconstructing state_subkeys_and_tags using the example key prefix, + "__opt_state_0_0_", certain to be present. + Args: - derived_opt_state_dict: Optimizer state dictionary + derived_opt_state_dict (dict): Derived optimizer state dictionary. + device (str): Device for the model. + Returns: - dict: Optimizer state dictionary + opt_state_dict (dict): Expanded optimizer state dictionary. """ state_subkeys_and_tags = [] for key in derived_opt_state_dict: @@ -581,8 +650,7 @@ def expand_derived_opt_state_dict(derived_opt_state_dict, device): opt_state_dict = {'param_groups': [], 'state': {}} nb_params_per_group = list( - derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32) - ) + derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32)) # Construct the expanded dict. for group_idx, nb_params in enumerate(nb_params_per_group): @@ -609,9 +677,13 @@ def expand_derived_opt_state_dict(derived_opt_state_dict, device): def _get_optimizer_state(optimizer): - """Return the optimizer state. + """Get the state of the optimizer. + Args: - optimizer + optimizer (Optimizer): Optimizer. + + Returns: + derived_opt_state_dict (dict): State of the optimizer. """ opt_state_dict = deepcopy(optimizer.state_dict()) @@ -629,14 +701,15 @@ def _get_optimizer_state(optimizer): def _set_optimizer_state(optimizer, device, derived_opt_state_dict): - """Set the optimizer state. + """Set the state of the optimizer. + Args: - optimizer: - device: - derived_opt_state_dict: + optimizer (Optimizer): Optimizer. + device (str): Device for the model. + derived_opt_state_dict (dict): Derived optimizer state dictionary. """ - temp_state_dict = expand_derived_opt_state_dict( - derived_opt_state_dict, device) + temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, + device) # FIXME: Figure out whether or not this breaks learning rate # scheduling and the like. @@ -651,9 +724,13 @@ def _set_optimizer_state(optimizer, device, derived_opt_state_dict): def to_cpu_numpy(state): - """Send data to CPU as Numpy array. + """Convert state to CPU as Numpy array. + Args: - state + state (State): State to be converted. + + Returns: + state (dict): State as Numpy array. """ # deep copy so as to decouple from active model state = deepcopy(state) diff --git a/openfl/federated/task/runner_keras.py b/openfl/federated/task/runner_keras.py index c7daaa3d33..a3556ca460 100644 --- a/openfl/federated/task/runner_keras.py +++ b/openfl/federated/task/runner_keras.py @@ -1,8 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -""" -Base classes for developing a ke.Model() Federated Learning model. +"""Base classes for developing a ke.Model() Federated Learning model. You may copy this file as the starting point of your own keras model. """ @@ -24,11 +22,17 @@ class KerasTaskRunner(TaskRunner): - """The base model for Keras models in the federation.""" + """The base model for Keras models in the federation. + + Attributes: + model (ke.Model): The Keras model. + model_tensor_names (list): List of model tensor names. + required_tensorkeys_for_function (dict): A map of all of the required + tensors for each of the public functions in KerasTaskRunner. + """ def __init__(self, **kwargs): - """ - Initialize. + """Initializes the KerasTaskRunner instance. 
Args: **kwargs: Additional parameters to pass to the function @@ -45,12 +49,14 @@ def __init__(self, **kwargs): ke.backend.clear_session() def rebuild_model(self, round_num, input_tensor_dict, validation=False): - """ - Parse tensor names and update weights of model. Handles the optimizer treatment. + """Parse tensor names and update weights of model. Handles the + optimizer treatment. - Returns - ------- - None + Args: + round_num (int): The round number. + input_tensor_dict (dict): The input tensor dictionary. + validation (bool, optional): If True, validate the model. Defaults + to False. """ if self.opt_treatment == 'RESET': self.reset_opt_vars() @@ -61,18 +67,30 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False): else: self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - def train(self, col_name, round_num, input_tensor_dict, - metrics, epochs=1, batch_size=1, **kwargs): - """ - Perform the training. + def train(self, + col_name, + round_num, + input_tensor_dict, + metrics, + epochs=1, + batch_size=1, + **kwargs): + """Perform the training. Is expected to perform draws randomly, without + replacement until data is exausted. Then data is replaced and shuffled + and draws continue. - Is expected to perform draws randomly, without replacement until data is exausted. - Then data is replaced and shuffled and draws continue. + Args: + col_name (str): The collaborator name. + round_num (int): The round number. + input_tensor_dict (dict): The input tensor dictionary. + metrics (list): List of metrics. + epochs (int, optional): Number of epochs to train. Defaults to 1. + batch_size (int, optional): Batch size. Defaults to 1. + **kwargs: Additional parameters. - Returns - ------- - dict - 'TensorKey: nparray' + Returns: + global_tensor_dict (dict): Dictionary of 'TensorKey: nparray'. + local_tensor_dict (dict): Dictionary of 'TensorKey: nparray'. 
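# Illustrative sketch, not part of the patch: the opt_treatment decision that
# rebuild_model makes, reduced to a standalone helper. 'wants_opt_vars' is a
# hypothetical name; the 'RESET' / 'CONTINUE_GLOBAL' values mirror the
# treatments referenced above.
def wants_opt_vars(opt_treatment, training_round_completed, validation):
    if opt_treatment == 'RESET':
        return False
    return (training_round_completed
            and opt_treatment == 'CONTINUE_GLOBAL'
            and not validation)


print(wants_opt_vars('CONTINUE_GLOBAL', True, validation=False))  # True
print(wants_opt_vars('CONTINUE_GLOBAL', True, validation=True))   # False
print(wants_opt_vars('RESET', True, validation=False))            # False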
""" if metrics is None: raise KeyError('metrics must be defined') @@ -81,44 +99,42 @@ def train(self, col_name, round_num, input_tensor_dict, self.rebuild_model(round_num, input_tensor_dict) for epoch in range(epochs): self.logger.info(f'Run {epoch} epoch of {round_num} round') - results = self.train_iteration(self.data_loader.get_train_loader(batch_size), - metrics=metrics, - **kwargs) + results = self.train_iteration( + self.data_loader.get_train_loader(batch_size), + metrics=metrics, + **kwargs) # output metric tensors (scalar) origin = col_name - tags = ('trained',) + tags = ('trained', ) output_metric_dict = { - TensorKey( - metric_name, origin, round_num, True, ('metric',) - ): metric_value + TensorKey(metric_name, origin, round_num, True, ('metric', )): + metric_value for (metric_name, metric_value) in results } # output model tensors (Doesn't include TensorKey) output_model_dict = self.get_tensor_dict(with_opt_vars=True) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) # create global tensorkeys global_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in global_model_dict.items() + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in global_model_dict.items() } # create tensorkeys that should stay local local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in local_model_dict.items() + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in local_model_dict.items() } # the train/validate aggregated function of the next round will look # for the updated model parameters. # this ensures they will be resolved locally next_local_tensorkey_model_dict = { - TensorKey( - tensor_name, origin, round_num + 1, False, ('model',) - ): nparray for tensor_name, nparray in local_model_dict.items() + TensorKey(tensor_name, origin, round_num + 1, False, ('model', )): + nparray + for tensor_name, nparray in local_model_dict.items() } global_tensor_dict = { @@ -146,17 +162,16 @@ def train(self, col_name, round_num, input_tensor_dict, return global_tensor_dict, local_tensor_dict def train_iteration(self, batch_generator, metrics: list = None, **kwargs): - """Train single epoch. - - Override this function for custom training. + """Train single epoch. Override this function for custom training. Args: - batch_generator: Generator of training batches. - Each batch is a tuple of N train images and N train labels - where N is the batch size of the DataLoader of the current TaskRunner instance. + batch_generator (generator): Generator of training batches. + metrics (list, optional): Names of metrics to save. Defaults to + None. + **kwargs: Additional parameters. - epochs: Number of epochs to train. - metrics: Names of metrics to save. + Returns: + results (list): List of Metric objects. 
""" if metrics is None: metrics = [] @@ -176,9 +191,7 @@ def train_iteration(self, batch_generator, metrics: list = None, **kwargs): f'Param_metrics = {metrics}, model_metrics_names = {model_metrics_names}' ) - history = self.model.fit(batch_generator, - verbose=1, - **kwargs) + history = self.model.fit(batch_generator, verbose=1, **kwargs) results = [] for metric in metrics: value = np.mean([history.history[metric]]) @@ -186,17 +199,19 @@ def train_iteration(self, batch_generator, metrics: list = None, **kwargs): return results def validate(self, col_name, round_num, input_tensor_dict, **kwargs): - """ - Run the trained model on validation data; report results. + """Run the trained model on validation data; report results. - Parameters - ---------- - input_tensor_dict : either the last aggregated or locally trained model + Args: + col_name (str): The collaborator name. + round_num (int): The round number. + input_tensor_dict (dict): The input tensor dictionary. Either the + last aggregated or locally trained model + **kwargs: Additional parameters. - Returns - ------- - output_tensor_dict : {TensorKey: nparray} (these correspond to acc, - precision, f1_score, etc.) + Returns: + output_tensor_dict (dict): Dictionary of 'TensorKey: nparray'. + These correspond to acc, precision, f1_score, etc. + dict: Empty dictionary. """ if 'batch_size' in kwargs: batch_size = kwargs['batch_size'] @@ -207,9 +222,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): param_metrics = kwargs['metrics'] vals = self.model.evaluate( - self.data_loader.get_valid_loader(batch_size), - verbose=1 - ) + self.data_loader.get_valid_loader(batch_size), verbose=1) model_metrics_names = self.model.metrics_names if type(vals) is not list: vals = [vals] @@ -231,55 +244,57 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): suffix += '_local' else: suffix += '_agg' - tags = ('metric',) + tags = ('metric', ) tags = change_tags(tags, add_field=suffix) output_tensor_dict = { TensorKey(metric, origin, round_num, True, tags): - np.array(ret_dict[metric]) - for metric in param_metrics} + np.array(ret_dict[metric]) + for metric in param_metrics + } return output_tensor_dict, {} def save_native(self, filepath): - """Save model.""" + """Save model. + + Args: + filepath (str): The file path to save the model. + """ self.model.save(filepath) def load_native(self, filepath): - """Load model.""" + """Load model. + + Args: + filepath (str): The file path to load the model. + """ self.model = ke.models.load_model(filepath) @staticmethod def _get_weights_names(obj): - """ - Get the list of weight names. + """Get the list of weight names. - Parameters - ---------- - obj : Model or Optimizer - The target object that we want to get the weights. + Args: + obj (Model or Optimizer): The target object that we want to get + the weights. - Returns - ------- - dict - The weight name list + Returns: + weight_names (list): The weight name list. """ weight_names = [weight.name for weight in obj.weights] return weight_names @staticmethod def _get_weights_dict(obj, suffix=''): - """ - Get the dictionary of weights. + """Get the dictionary of weights. - Parameters - ---------- - obj : Model or Optimizer - The target object that we want to get the weights. + Args: + obj (Model or Optimizer): The target object that we want to get + the weights. + suffix (str, optional): Suffix for weight names. Defaults to ''. - Returns - ------- - dict - The weight dictionary. + Returns: + weights_dict (dict): The weight dictionary. 
""" weights_dict = {} weight_names = [weight.name for weight in obj.weights] @@ -292,33 +307,24 @@ def _get_weights_dict(obj, suffix=''): def _set_weights_dict(obj, weights_dict): """Set the object weights with a dictionary. - The obj can be a model or an optimizer. - Args: obj (Model or Optimizer): The target object that we want to set - the weights. + the weights. weights_dict (dict): The weight dictionary. - - Returns: - None """ weight_names = [weight.name for weight in obj.weights] weight_values = [weights_dict[name] for name in weight_names] obj.set_weights(weight_values) def get_tensor_dict(self, with_opt_vars, suffix=''): - """ - Get the model weights as a tensor dictionary. + """Get the model weights as a tensor dictionary. - Parameters - ---------- - with_opt_vars : bool - If we should include the optimizer's status. - suffix : string - Universally + Args: + with_opt_vars (bool): If we should include the optimizer's status. + suffix (str): Universally. Returns: - dict: The tensor dictionary. + model_weights (dict): The tensor dictionary. """ model_weights = self._get_weights_dict(self.model, suffix) @@ -332,65 +338,55 @@ def get_tensor_dict(self, with_opt_vars, suffix=''): return model_weights def set_tensor_dict(self, tensor_dict, with_opt_vars): - """ - Set the model weights with a tensor dictionary. + """Set the model weights with a tensor dictionary. Args: - tensor_dict: the tensor dictionary + tensor_dict (dict): The tensor dictionary. with_opt_vars (bool): True = include the optimizer's status. """ if with_opt_vars is False: - # It is possible to pass in opt variables from the input tensor dict - # This will make sure that the correct layers are updated + # It is possible to pass in opt variables from the input tensor + # dict. This will make sure that the correct layers are updated model_weight_names = [weight.name for weight in self.model.weights] model_weights_dict = { - name: tensor_dict[name] for name in model_weight_names + name: tensor_dict[name] + for name in model_weight_names } self._set_weights_dict(self.model, model_weights_dict) else: - model_weight_names = [ - weight.name for weight in self.model.weights - ] + model_weight_names = [weight.name for weight in self.model.weights] model_weights_dict = { - name: tensor_dict[name] for name in model_weight_names + name: tensor_dict[name] + for name in model_weight_names } opt_weight_names = [ weight.name for weight in self.model.optimizer.weights ] opt_weights_dict = { - name: tensor_dict[name] for name in opt_weight_names + name: tensor_dict[name] + for name in opt_weight_names } self._set_weights_dict(self.model, model_weights_dict) self._set_weights_dict(self.model.optimizer, opt_weights_dict) def reset_opt_vars(self): - """ - Reset optimizer variables. - - Resets the optimizer variables - - """ + """Resets the optimizer variables.""" for var in self.model.optimizer.variables(): var.assign(tf.zeros_like(var)) self.logger.debug('Optimizer variables reset') - def set_required_tensorkeys_for_function(self, func_name, - tensor_key, **kwargs): - """ - Set the required tensors for specified function that could be called as part of a task. + def set_required_tensorkeys_for_function(self, func_name, tensor_key, + **kwargs): + """Set the required tensors for specified function that could be called + as part of a task. By default, this is just all of the layers and optimizer of the model. 
- Custom tensors should be added to this function - - Parameters - ---------- - func_name: string - tensor_key: TensorKey (namedtuple) - **kwargs: Any function arguments {} + Custom tensors should be added to this function. - Returns - ------- - None + Args: + func_name (str): The function name. + tensor_key (TensorKey): The tensor key. + **kwargs: Any function arguments. """ # TODO there should be a way to programmatically iterate through all # of the methods in the class and declare the tensors. @@ -405,41 +401,31 @@ def set_required_tensorkeys_for_function(self, func_name, self.required_tensorkeys_for_function[func_name].append(tensor_key) def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - Get the required tensors for specified function that could be called as part of a task. + """Get the required tensors for specified function that could be called + as part of a task. By default, this is just all of the layers and optimizer of the model. - Parameters - ---------- - None + Args: + func_name (str): The function name. + **kwargs: Any function arguments. - Returns - ------- - List - [TensorKey] + Returns: + list: List of TensorKey objects. """ if func_name == 'validate': local_model = 'apply=' + str(kwargs['apply']) - return self.required_tensorkeys_for_function[func_name][local_model] + return self.required_tensorkeys_for_function[func_name][ + local_model] else: return self.required_tensorkeys_for_function[func_name] def update_tensorkeys_for_functions(self): - """ - Update the required tensors for all publicly accessible methods \ - that could be called as part of a task. + """Update the required tensors for all publicly accessible methods that + could be called as part of a task. By default, this is just all of the layers and optimizer of the model. Custom tensors should be added to this function - - Parameters - ---------- - None - - Returns - ------- - None """ # TODO complete this function. It is only needed for opt_treatment, # and making the model stateless @@ -450,7 +436,7 @@ def update_tensorkeys_for_functions(self): tensor_names = model_layer_names + opt_names self.logger.debug(f'Updating model tensor names: {tensor_names}') self.required_tensorkeys_for_function['train'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, ('model',)) + TensorKey(tensor_name, 'GLOBAL', 0, ('model', )) for tensor_name in tensor_names ] @@ -467,20 +453,15 @@ def update_tensorkeys_for_functions(self): ] def initialize_tensorkeys_for_functions(self, with_opt_vars=False): - """ - Set the required tensors for all publicly accessible methods \ - that could be called as part of a task. + """Set the required tensors for all publicly accessible methods that + could be called as part of a task. By default, this is just all of the layers and optimizer of the model. Custom tensors should be added to this function - Parameters - ---------- - None - - Returns - ------- - None + Args: + with_opt_vars (bool, optional): If True, include the optimizer's + status. Defaults to False. """ # TODO there should be a way to programmatically iterate through all # of the methods in the class and declare the tensors. 
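# Illustrative sketch, not part of the patch: the shape of the
# required_tensorkeys_for_function map built above -- a flat list per task,
# except 'validate', which is keyed again by the 'apply' kwarg. The tensor
# names here are hypothetical placeholders.
required_tensorkeys_for_function = {
    'train': ['dense/kernel:0', 'dense/bias:0'],
    'validate': {
        'apply=local': ['dense/kernel:0', 'dense/bias:0'],
        'apply=global': ['dense/kernel:0', 'dense/bias:0'],
    },
}


def lookup_required(func_name, **kwargs):
    # Mirrors get_required_tensorkeys_for_function above.
    if func_name == 'validate':
        return required_tensorkeys_for_function[func_name][
            'apply=' + str(kwargs['apply'])]
    return required_tensorkeys_for_function[func_name]


print(lookup_required('validate', apply='global'))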
@@ -488,26 +469,22 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) if not with_opt_vars: global_model_dict_val = global_model_dict local_model_dict_val = local_model_dict else: output_model_dict = self.get_tensor_dict(with_opt_vars=False) global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( - self.logger, - output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, + **self.tensor_dict_split_fn_kwargs) self.required_tensorkeys_for_function['train'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) for tensor_name in global_model_dict ] self.required_tensorkeys_for_function['train'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) for tensor_name in local_model_dict ] @@ -516,17 +493,17 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): self.required_tensorkeys_for_function['validate'] = {} # TODO This is not stateless. The optimizer will not be self.required_tensorkeys_for_function['validate']['apply=local'] = [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('trained', )) for tensor_name in { **global_model_dict_val, **local_model_dict_val } ] self.required_tensorkeys_for_function['validate']['apply=global'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) for tensor_name in global_model_dict_val ] self.required_tensorkeys_for_function['validate']['apply=global'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) for tensor_name in local_model_dict_val ] diff --git a/openfl/federated/task/runner_pt.py b/openfl/federated/task/runner_pt.py index eee4410c78..bbb10bf638 100644 --- a/openfl/federated/task/runner_pt.py +++ b/openfl/federated/task/runner_pt.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """PyTorchTaskRunner module.""" from copy import deepcopy @@ -20,21 +19,40 @@ class PyTorchTaskRunner(nn.Module, TaskRunner): - """PyTorch Model class for Federated Learning.""" + """PyTorch Model class for Federated Learning. + + Attributes: + device (str): Compute device (default="cpu") + required_tensorkeys_for_function (dict): A map of all the required + tensors for each of the public functions in PyTorchTaskRunner. + optimizer (Optimizer): The optimizer for the model. + loss_fn (function): The loss function for the model. + training_round_completed (bool): Flag to check if training round is + completed. + tensor_dict_split_fn_kwargs (dict): Arguments for the tensor + dictionary split function. + """ - def __init__(self, device: str = None, loss_fn=None, optimizer=None, **kwargs): - """Initialize. + def __init__(self, + device: str = None, + loss_fn=None, + optimizer=None, + **kwargs): + """Initializes the PyTorchTaskRunner object. Args: - device (string): Compute device (default="cpu") - **kwargs: Additional parameters to pass to the functions + device (str): Compute device (default="cpu"). 
+ loss_fn (function): The loss function for the model. + optimizer (Optimizer): The optimizer for the model. + **kwargs: Additional parameters to pass to the functions. """ super().__init__() TaskRunner.__init__(self, **kwargs) if device: self.device = device else: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") # This is a map of all the required tensors for each of the public # functions in PyTorchTaskRunner @@ -48,12 +66,16 @@ def __init__(self, device: str = None, loss_fn=None, optimizer=None, **kwargs): # child model that does not overwrite get and set tensordict) that is # not a numpy array self.tensor_dict_split_fn_kwargs.update( - {"holdout_tensor_names": ["__opt_state_needed"]} - ) + {"holdout_tensor_names": ["__opt_state_needed"]}) def rebuild_model(self, round_num, input_tensor_dict, validation=False): - """ - Parse tensor names and update weights of model. Handles the optimizer treatment. + """Parse tensor names and update weights of model. Handles the + optimizer treatment. + + Args: + round_num (int): The current round number + input_tensor_dict (dict): The input tensor dictionary + validation (bool): Flag to check if it's validation Returns: None @@ -61,32 +83,33 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False): if self.opt_treatment == "RESET": self.reset_opt_vars() self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - elif ( - self.training_round_completed - and self.opt_treatment == "CONTINUE_GLOBAL" - and not validation - ): + elif (self.training_round_completed + and self.opt_treatment == "CONTINUE_GLOBAL" and not validation): self.set_tensor_dict(input_tensor_dict, with_opt_vars=True) else: self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - def validate_task( - self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs - ): + def validate_task(self, + col_name, + round_num, + input_tensor_dict, + use_tqdm=False, + **kwargs): """Validate Task. Run validation of the model on the local data. Args: - col_name: Name of the collaborator - round_num: What round is it - input_tensor_dict: Required input tensors (for model) - use_tqdm (bool): Use tqdm to print a progress bar (Default=True) + col_name (str): Name of the collaborator. + round_num (int): What round is it. + input_tensor_dict (dict): Required input tensors (for model). + use_tqdm (bool): Use tqdm to print a progress bar (Default=True). + **kwargs: Additional parameters. Returns: - global_output_dict: Tensors to send back to the aggregator - local_output_dict: Tensors to maintain in the local TensorDB - + global_output_dict (dict): Tensors to send back to the aggregator. + local_output_dict (dict): Tensors to maintain in the local + TensorDB. """ self.rebuild_model(round_num, input_tensor_dict, validation=True) self.eval() @@ -104,7 +127,7 @@ def validate_task( suffix += "_local" else: suffix += "_agg" - tags = ("metric",) + tags = ("metric", ) tags = change_tags(tags, add_field=suffix) # TODO figure out a better way to pass in metric for this pytorch # validate function @@ -115,23 +138,29 @@ def validate_task( # Empty list represents metrics that should only be stored locally return output_tensor_dict, {} - def train_task( - self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1, **kwargs - ): + def train_task(self, + col_name, + round_num, + input_tensor_dict, + use_tqdm=False, + epochs=1, + **kwargs): """Train batches task. 
Train the model on the requested number of batches. Args: - col_name: Name of the collaborator - round_num: What round is it - input_tensor_dict: Required input tensors (for model) - use_tqdm (bool): Use tqdm to print a progress bar (Default=True) - epochs: The number of epochs to train + col_name (str): Name of the collaborator. + round_num (int): What round is it. + input_tensor_dict (dict): Required input tensors (for model). + use_tqdm (bool): Use tqdm to print a progress bar (Default=True). + epochs (int): The number of epochs to train. + **kwargs: Additional parameters. Returns: - global_output_dict: Tensors to send back to the aggregator - local_output_dict: Tensors to maintain in the local TensorDB + global_output_dict (dict): Tensors to send back to the aggregator. + local_output_dict (dict): Tensors to maintain in the local + TensorDB. """ self.rebuild_model(round_num, input_tensor_dict) # set to "training" mode @@ -145,16 +174,16 @@ def train_task( metric = self.train_(loader) # Output metric tensors (scalar) origin = col_name - tags = ("trained",) + tags = ("trained", ) output_metric_dict = { - TensorKey(metric.name, origin, round_num, True, ("metric",)): metric.value + TensorKey(metric.name, origin, round_num, True, ("metric", )): + metric.value } # output model tensors (Doesn't include TensorKey) output_model_dict = self.get_tensor_dict(with_opt_vars=True) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) # Create global tensorkeys global_tensorkey_model_dict = { @@ -170,11 +199,15 @@ def train_task( # for the updated model parameters. # This ensures they will be resolved locally next_local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num + 1, False, ("model",)): nparray + TensorKey(tensor_name, origin, round_num + 1, False, ("model", )): + nparray for tensor_name, nparray in local_model_dict.items() } - global_tensor_dict = {**output_metric_dict, **global_tensorkey_model_dict} + global_tensor_dict = { + **output_metric_dict, + **global_tensorkey_model_dict + } local_tensor_dict = { **local_tensorkey_model_dict, **next_local_tensorkey_model_dict, @@ -205,11 +238,10 @@ def get_tensor_dict(self, with_opt_vars=False): Args: with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) + optimizer tensors (Default=False) Returns: - dict: Tensor dictionary {**dict, **optimizer_dict} - + state (dict): Tensor dictionary {**dict, **optimizer_dict} """ # Gets information regarding tensor model layers and optimizer state. # FIXME: self.parameters() instead? Unclear if load_state_dict() or @@ -226,7 +258,15 @@ def get_tensor_dict(self, with_opt_vars=False): return state def _get_weights_names(self, with_opt_vars=False): - # Gets information regarding tensor model layers and optimizer state. + """Get information regarding tensor model layers and optimizer state. + + Args: + with_opt_vars (bool): Flag to check if optimizer variables are + included (Default=False). + + Returns: + state (list): List of state keys. + """ # FIXME: self.parameters() instead? Unclear if load_state_dict() or # simple assignment is better # for now, state dict gives us names which is good @@ -244,10 +284,9 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars=False): """Set the tensor dictionary. Args: - tensor_dict: The tensor dictionary + tensor_dict (dict): The tensor dictionary. 
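# Illustrative sketch, not part of the patch: the kind of dictionary
# get_tensor_dict returns for a tiny module -- model weights as CPU numpy
# arrays, plus (when with_opt_vars is True) flattened optimizer entries. A
# single hypothetical '__opt_state_needed' marker stands in for those entries
# here, matching the holdout name listed in tensor_dict_split_fn_kwargs above.
import numpy as np
import torch

model = torch.nn.Linear(3, 2)
state = {k: v.cpu().numpy() for k, v in model.state_dict().items()}
state['__opt_state_needed'] = np.array(['true'])  # hypothetical stand-in marker
print(sorted(state.keys()))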
with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) - + optimizer tensors (Default=False). """ # Sets tensors for model layers and optimizer state. # FIXME: self.parameters() instead? Unclear if load_state_dict() or @@ -276,24 +315,28 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars=False): assert len(tensor_dict) == 0 def get_optimizer(self): - """Get the optimizer of this instance.""" + """Get the optimizer of this instance. + + Returns: + Optimizer: The optimizer of the instance. + """ return self.optimizer def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - Get the required tensors for specified function that could be called \ - as part of a task. By default, this is just all of the layers and \ + """Get the required tensors for specified function that could be called + as part of a task. By default, this is just all of the layers and optimizer of the model. Args: - func_name + func_name (str): The function name. Returns: - list : [TensorKey] + list : [TensorKey]. """ if func_name == "validate_task": local_model = "apply=" + str(kwargs["apply"]) - return self.required_tensorkeys_for_function[func_name][local_model] + return self.required_tensorkeys_for_function[func_name][ + local_model] else: return self.required_tensorkeys_for_function[func_name] @@ -304,7 +347,8 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): Custom tensors should be added to this function. Args: - None + with_opt_vars (bool): Flag to check if optimizer variables are + included. Defaults to False. Returns: None @@ -315,8 +359,7 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) if not with_opt_vars: global_model_dict_val = global_model_dict local_model_dict_val = local_model_dict @@ -324,25 +367,24 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): output_model_dict = self.get_tensor_dict(with_opt_vars=False) global_model_dict_val, local_model_dict_val = ( split_tensor_dict_for_holdouts( - self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs - ) - ) + self.logger, output_model_dict, + **self.tensor_dict_split_fn_kwargs)) self.required_tensorkeys_for_function["train_task"] = [ - TensorKey(tensor_name, "GLOBAL", 0, False, ("model",)) + TensorKey(tensor_name, "GLOBAL", 0, False, ("model", )) for tensor_name in global_model_dict ] self.required_tensorkeys_for_function["train_task"] += [ - TensorKey(tensor_name, "LOCAL", 0, False, ("model",)) + TensorKey(tensor_name, "LOCAL", 0, False, ("model", )) for tensor_name in local_model_dict ] self.required_tensorkeys_for_function["train_task"] = [ - TensorKey(tensor_name, "GLOBAL", 0, False, ("model",)) + TensorKey(tensor_name, "GLOBAL", 0, False, ("model", )) for tensor_name in global_model_dict ] self.required_tensorkeys_for_function["train_task"] += [ - TensorKey(tensor_name, "LOCAL", 0, False, ("model",)) + TensorKey(tensor_name, "LOCAL", 0, False, ("model", )) for tensor_name in local_model_dict ] @@ -363,26 +405,22 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): for tensor_name in local_model_dict_val ] - def load_native( - self, - filepath, - model_state_dict_key="model_state_dict", - 
optimizer_state_dict_key="optimizer_state_dict", - **kwargs, - ): - """ - Load model and optimizer states from a pickled file specified by \ - filepath. model_/optimizer_state_dict args can be specified if needed. \ + def load_native(self, + filepath, + model_state_dict_key='model_state_dict', + optimizer_state_dict_key='optimizer_state_dict', + **kwargs): + """Load model and optimizer states from a pickled file specified by + filepath. model_/optimizer_state_dict args can be specified if needed. Uses pt.load(). Args: - filepath (string) : Path to pickle file created - by torch.save(). - model_state_dict_key (string) : key for model state dict - in pickled file. - optimizer_state_dict_key (string) : key for optimizer state dict - in picked file. - kwargs : unused + filepath (str): Path to pickle file created by pt.save(). + model_state_dict_key (str): key for model state dict in pickled + file. + optimizer_state_dict_key (str): key for optimizer state dict in + picked file. + **kwargs: Additional parameters. Returns: None @@ -391,26 +429,22 @@ def load_native( self.load_state_dict(pickle_dict[model_state_dict_key]) self.optimizer.load_state_dict(pickle_dict[optimizer_state_dict_key]) - def save_native( - self, - filepath, - model_state_dict_key="model_state_dict", - optimizer_state_dict_key="optimizer_state_dict", - **kwargs, - ): - """ - Save model and optimizer states in a picked file specified by the \ - filepath. model_/optimizer_state_dicts are stored in the keys provided. \ - Uses torch.save(). + def save_native(self, + filepath, + model_state_dict_key='model_state_dict', + optimizer_state_dict_key='optimizer_state_dict', + **kwargs): + """Save model and optimizer states in a picked file specified by the + filepath. model_/optimizer_state_dicts are stored in the keys provided. + Uses pt.save(). Args: - filepath (string) : Path to pickle file to be - created by torch.save(). - model_state_dict_key (string) : key for model state dict - in pickled file. - optimizer_state_dict_key (string) : key for optimizer state - dict in picked file. - kwargs : unused + filepath (str): Path to pickle file to be created by pt.save(). + model_state_dict_key (str): key for model state dict in pickled + file. + optimizer_state_dict_key (str): key for optimizer state dict in + picked file. + **kwargs: Additional parameters. Returns: None @@ -426,27 +460,30 @@ def reset_opt_vars(self): Resets the optimizer state variables. + Returns: + None """ pass def train_( - self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]] - ) -> Metric: + self, train_dataloader: Iterator[Tuple[np.ndarray, + np.ndarray]]) -> Metric: """Train single epoch. Override this function in order to use custom training. Args: - batch_generator: Train dataset batch generator. Yields (samples, targets) tuples of - size = `self.data_loader.batch_size`. + batch_generator (Iterator): Train dataset batch generator. Yields + (samples, targets) tuples of + size = `self.data_loader.batch_size`. + Returns: Metric: An object containing name and np.ndarray value. 
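# Illustrative sketch, not part of the patch: the per-batch pattern train_
# follows (zero_grad -> forward -> loss -> backward -> step), on a toy module
# and random data standing in for the train dataloader.
import numpy as np
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []
for _ in range(3):  # stand-in for iterating the train dataloader
    data, target = torch.randn(8, 4), torch.randn(8, 1)
    optimizer.zero_grad()
    output = model(data)
    loss = torch.nn.functional.mse_loss(output, target)
    loss.backward()
    optimizer.step()
    losses.append(loss.detach().cpu().numpy())
print('mean loss:', np.mean(losses))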
""" losses = [] for data, target in train_dataloader: - data, target = torch.tensor(data).to(self.device), torch.tensor(target).to( - self.device - ) + data, target = torch.tensor(data).to( + self.device), torch.tensor(target).to(self.device) self.optimizer.zero_grad() output = self(data) loss = self.loss_fn(output=output, target=target) @@ -457,10 +494,9 @@ def train_( return Metric(name=self.loss_fn.__name__, value=np.array(loss)) def validate_( - self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]] - ) -> Metric: - """ - Perform validation on PyTorch Model + self, validation_dataloader: Iterator[Tuple[np.ndarray, + np.ndarray]]) -> Metric: + """Perform validation on PyTorch Model. Override this function for your own custom validation function @@ -477,9 +513,9 @@ def validate_( for data, target in validation_dataloader: samples = target.shape[0] total_samples += samples - data, target = torch.tensor(data).to(self.device), torch.tensor( - target - ).to(self.device, dtype=torch.int64) + data, target = torch.tensor(data).to( + self.device), torch.tensor(target).to(self.device, + dtype=torch.int64) output = self(data) # get the index of the max log-probability pred = output.argmax(dim=1) @@ -498,8 +534,10 @@ def _derive_opt_state_dict(opt_state_dict): expand_derived_opt_state_dict. Args: - opt_state_dict: The optimizer state dictionary + opt_state_dict (dict): The optimizer state dictionary. + Returns: + derived_opt_state_dict (dict): Derived optimizer state dictionary. """ derived_opt_state_dict = {} @@ -513,7 +551,8 @@ def _derive_opt_state_dict(opt_state_dict): # Using one example state key, we collect keys for the corresponding # dictionary value. example_state_key = opt_state_dict["param_groups"][0]["params"][0] - example_state_subkeys = set(opt_state_dict["state"][example_state_key].keys()) + example_state_subkeys = set( + opt_state_dict["state"][example_state_key].keys()) # We assume that the state collected for all params in all param groups is # the same. @@ -522,13 +561,14 @@ def _derive_opt_state_dict(opt_state_dict): # Using assert statements to break the routine if these assumptions are # incorrect. for state_key in opt_state_dict["state"].keys(): - assert example_state_subkeys == set(opt_state_dict["state"][state_key].keys()) + assert example_state_subkeys == set( + opt_state_dict["state"][state_key].keys()) for state_subkey in example_state_subkeys: assert isinstance( - opt_state_dict["state"][example_state_key][state_subkey], torch.Tensor - ) == isinstance( - opt_state_dict["state"][state_key][state_subkey], torch.Tensor - ) + opt_state_dict["state"][example_state_key][state_subkey], + torch.Tensor) == isinstance( + opt_state_dict["state"][state_key][state_subkey], + torch.Tensor) state_subkeys = list(opt_state_dict["state"][example_state_key].keys()) @@ -536,9 +576,8 @@ def _derive_opt_state_dict(opt_state_dict): # tensor or not. 
state_subkey_tags = [] for state_subkey in state_subkeys: - if isinstance( - opt_state_dict["state"][example_state_key][state_subkey], torch.Tensor - ): + if isinstance(opt_state_dict["state"][example_state_key][state_subkey], + torch.Tensor): state_subkey_tags.append("istensor") else: state_subkey_tags.append("") @@ -552,16 +591,18 @@ def _derive_opt_state_dict(opt_state_dict): for idx, param_id in enumerate(group["params"]): for subkey, tag in state_subkeys_and_tags: if tag == "istensor": - new_v = opt_state_dict["state"][param_id][subkey].cpu().numpy() + new_v = opt_state_dict["state"][param_id][subkey].cpu( + ).numpy() else: - new_v = np.array([opt_state_dict["state"][param_id][subkey]]) + new_v = np.array( + [opt_state_dict["state"][param_id][subkey]]) derived_opt_state_dict[ - f"__opt_state_{group_idx}_{idx}_{tag}_{subkey}" - ] = new_v + f"__opt_state_{group_idx}_{idx}_{tag}_{subkey}"] = new_v nb_params_per_group.append(idx + 1) # group lengths are also helpful for reconstructing # original opt_state_dict structure - derived_opt_state_dict["__opt_group_lengths"] = np.array(nb_params_per_group) + derived_opt_state_dict["__opt_group_lengths"] = np.array( + nb_params_per_group) return derived_opt_state_dict @@ -576,10 +617,11 @@ def expand_derived_opt_state_dict(derived_opt_state_dict, device): prefix, "__opt_state_0_0_", certain to be present. Args: - derived_opt_state_dict: Optimizer state dictionary + derived_opt_state_dict (dict): Optimizer state dictionary. + device (str): The device to be used. Returns: - dict: Optimizer state dictionary + opt_state_dict (dict): Optimizer state dictionary. """ state_subkeys_and_tags = [] for key in derived_opt_state_dict: @@ -595,8 +637,7 @@ def expand_derived_opt_state_dict(derived_opt_state_dict, device): opt_state_dict = {"param_groups": [], "state": {}} nb_params_per_group = list( - derived_opt_state_dict.pop("__opt_group_lengths").astype(np.int32) - ) + derived_opt_state_dict.pop("__opt_group_lengths").astype(np.int32)) # Construct the expanded dict. for group_idx, nb_params in enumerate(nb_params_per_group): @@ -626,7 +667,10 @@ def _get_optimizer_state(optimizer): """Return the optimizer state. Args: - optimizer + optimizer (Optimizer): The optimizer for the model. + + Returns: + derived_opt_state_dict (dict): Derived optimizer state dictionary. """ opt_state_dict = deepcopy(optimizer.state_dict()) @@ -647,12 +691,15 @@ def _set_optimizer_state(optimizer, device, derived_opt_state_dict): """Set the optimizer state. Args: - optimizer: - device: - derived_opt_state_dict: + optimizer (Optimizer): The optimizer for the model. + device (str): The device to be used. + derived_opt_state_dict (dict): Derived optimizer state dictionary. + Returns: + None """ - temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, device) + temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, + device) # FIXME: Figure out whether or not this breaks learning rate # scheduling and the like. @@ -670,8 +717,10 @@ def to_cpu_numpy(state): """Send data to CPU as Numpy array. Args: - state + state (dict): The state dictionary. + Returns: + state (dict): State dictionary with values as numpy arrays. """ # deep copy so as to decouple from active model state = deepcopy(state) @@ -679,10 +728,8 @@ def to_cpu_numpy(state): for k, v in state.items(): # When restoring, we currently assume all values are tensors. 
if not torch.is_tensor(v): - raise ValueError( - "We do not currently support non-tensors " - "coming from model.state_dict()" - ) + raise ValueError("We do not currently support non-tensors " + "coming from model.state_dict()") # get as a numpy array, making sure is on cpu state[k] = v.cpu().numpy() return state diff --git a/openfl/federated/task/runner_tf.py b/openfl/federated/task/runner_tf.py index f63ffce3f8..9380e3429c 100644 --- a/openfl/federated/task/runner_tf.py +++ b/openfl/federated/task/runner_tf.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """TensorFlowTaskRunner module.""" import numpy as np @@ -13,19 +12,41 @@ class TensorFlowTaskRunner(TaskRunner): - """ - Base class for TensorFlow models in the Federated Learning solution. - - child classes should have __init__ function signature (self, data, kwargs), - and should overwrite at least the following while defining the model + """Base class for TensorFlow models in the Federated Learning solution. + + Attributes: + assign_ops (tf.Operation): TensorFlow operations for assignment. + placeholders (tf.Tensor): TensorFlow placeholders for tensors. + tvar_assign_ops (tf.Operation): TensorFlow operations for assignment + of trainable variables. + tvar_placeholders (tf.Tensor): TensorFlow placeholders for trainable + variables. + input_shape (tuple): Shape of the input features. + required_tensorkeys_for_function (dict): Required tensorkeys for all + public functions in TensorFlowTaskRunner. + sess (tf.Session): TensorFlow session. + X (tf.Tensor): Input features to the model. + y (tf.Tensor): Input labels to the model. + train_step (tf.Operation): Optimizer train step operation. + loss (tf.Tensor): Model loss function. + output (tf.Tensor): Model output tensor. + validation_metric (tf.Tensor): Function used to validate the model + outputs against labels. + tvars (list): TensorFlow trainable variables. + opt_vars (list): Optimizer variables. + fl_vars (list): Trainable variables and optimizer variables. + + .. note:: + Child classes should have __init__ function signature (self, data, + kwargs), + and should overwrite at least the following while defining the model. """ def __init__(self, **kwargs): - """ - Initialize. + """Initializes the TensorFlowTaskRunner object. Args: - **kwargs: Additional parameters to pass to the function + **kwargs: Additional parameters to pass to the function. """ tf.disable_v2_behavior() @@ -38,10 +59,7 @@ def __init__(self, **kwargs): self.tvar_placeholders = None # construct the shape needed for the input features - self.input_shape = (None,) + self.data_loader.get_feature_shape() - - # Required tensorkeys for all public functions in TensorFlowTaskRunner - self.required_tensorkeys_for_function = {} + self.input_shape = (None, ) + self.data_loader.get_feature_shape() # Required tensorkeys for all public functions in TensorFlowTaskRunner self.required_tensorkeys_for_function = {} @@ -68,8 +86,13 @@ def __init__(self, **kwargs): self.fl_vars = None def rebuild_model(self, round_num, input_tensor_dict, validation=False): - """ - Parse tensor names and update weights of model. Handles the optimizer treatment. + """Parse tensor names and update weights of model. Handles the + optimizer treatment. + + Args: + round_num (int): The round number. + input_tensor_dict (dict): The input tensor dictionary. + validation (bool): If True, perform validation. Default is False. 
Returns: None @@ -83,20 +106,29 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False): else: self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - def train_batches(self, col_name, round_num, input_tensor_dict, - epochs=1, use_tqdm=False, **kwargs): - """ - Perform the training. + def train_batches(self, + col_name, + round_num, + input_tensor_dict, + epochs=1, + use_tqdm=False, + **kwargs): + """Perform the training. - Is expected to perform draws randomly, without replacement until data is exausted. Then - data is replaced and shuffled and draws continue. + Is expected to perform draws randomly, without replacement until data + is exhausted. Then data is replaced and shuffled and draws continue. Args: - use_tqdm (bool): True = use tqdm to print a progress - bar (Default=False) - epochs (int): Number of epochs to train + col_name (str): The column name. + round_num (int): The round number. + input_tensor_dict (dict): The input tensor dictionary. + epochs (int): Number of epochs to train. Default is 1. + use_tqdm (bool): If True, use tqdm to print a progress bar. + Default is False. + **kwargs: Additional parameters to pass to the function. + Returns: - float: loss metric + float: loss metric. """ batch_size = self.data_loader.batch_size @@ -121,37 +153,35 @@ def train_batches(self, col_name, round_num, input_tensor_dict, # Output metric tensors (scalar) origin = col_name - tags = ('trained',) + tags = ('trained', ) output_metric_dict = { - TensorKey( - self.loss_name, origin, round_num, True, ('metric',) - ): np.array(np.mean(losses)) + TensorKey(self.loss_name, origin, round_num, True, ('metric', )): + np.array(np.mean(losses)) } # output model tensors (Doesn't include TensorKey) output_model_dict = self.get_tensor_dict(with_opt_vars=True) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) # Create global tensorkeys global_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in global_model_dict.items() + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in global_model_dict.items() } # Create tensorkeys that should stay local local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in local_model_dict.items() + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in local_model_dict.items() } # The train/validate aggregated function of the next round will # look for the updated model parameters. # This ensures they will be resolved locally next_local_tensorkey_model_dict = { - TensorKey( - tensor_name, origin, round_num + 1, False, ('model',) - ): nparray for tensor_name, nparray in local_model_dict.items()} + TensorKey(tensor_name, origin, round_num + 1, False, ('model', )): + nparray + for tensor_name, nparray in local_model_dict.items() + } global_tensor_dict = { **output_metric_dict, @@ -178,30 +208,41 @@ def train_batches(self, col_name, round_num, input_tensor_dict, return global_tensor_dict, local_tensor_dict def train_batch(self, X, y): - """ - Train the model on a single batch. + """Train the model on a single batch. Args: - X: Input to the model - y: Ground truth label to the model + X (tf.Tensor): Input to the model. + y (tf.Tensor): Ground truth label to the model.
Returns: - float: loss metric + loss (float): loss metric. """ feed_dict = {self.X: X, self.y: y} # run the train step and return the loss - _, loss = self.sess.run([self.train_step, self.loss], feed_dict=feed_dict) + _, loss = self.sess.run([self.train_step, self.loss], + feed_dict=feed_dict) return loss - def validate(self, col_name, round_num, - input_tensor_dict, use_tqdm=False, **kwargs): - """ - Run validation. + def validate(self, + col_name, + round_num, + input_tensor_dict, + use_tqdm=False, + **kwargs): + """Run validation. + + Args: + col_name (str): The column name. + round_num (int): The round number. + input_tensor_dict (dict): The input tensor dictionary. + use_tqdm (bool): If True, use tqdm to print a progress bar. + Default is False. + **kwargs: Additional parameters to pass to the function. Returns: - dict: {: } + output_tensor_dict (dict): {: }. """ batch_size = self.data_loader.batch_size @@ -231,9 +272,9 @@ def validate(self, col_name, round_num, suffix += '_agg' tags = ('metric', suffix) output_tensor_dict = { - TensorKey( - self.validation_metric_name, origin, round_num, True, tags - ): np.array(score)} + TensorKey(self.validation_metric_name, origin, round_num, True, tags): + np.array(score) + } # return empty dict for local metrics return output_tensor_dict, {} @@ -242,30 +283,28 @@ def validate_batch(self, X, y): """Validate the model on a single local batch. Args: - X: Input to the model - y: Ground truth label to the model + X (tf.Tensor): Input to the model. + y (tf.Tensor): Ground truth label to the model. Returns: - float: loss metric - + float: loss metric. """ feed_dict = {self.X: X, self.y: y} - return self.sess.run( - [self.output, self.validation_metric], feed_dict=feed_dict) + return self.sess.run([self.output, self.validation_metric], + feed_dict=feed_dict) def get_tensor_dict(self, with_opt_vars=True): """Get the dictionary weights. - Get the weights from the tensor + Get the weights from the tensor. Args: with_opt_vars (bool): Specify if we also want to get the variables - of the optimizer + of the optimizer. Default is True. Returns: - dict: The weight dictionary {: } - + dict: The weight dictionary {: }. """ if with_opt_vars is True: variables = self.fl_vars @@ -273,39 +312,40 @@ def get_tensor_dict(self, with_opt_vars=True): variables = self.tvars # FIXME: do this in one call? - return {var.name: val for var, val in zip( - variables, self.sess.run(variables))} + return { + var.name: val + for var, val in zip(variables, self.sess.run(variables)) + } def set_tensor_dict(self, tensor_dict, with_opt_vars): """Set the tensor dictionary. - Set the model weights with a tensor - dictionary: {: }. + Set the model weights with a tensor dictionary: + {: }. Args: - tensor_dict (dict): The model weights dictionary + tensor_dict (dict): The model weights dictionary. with_opt_vars (bool): Specify if we also want to set the variables - of the optimizer + of the optimizer. 
Returns: None """ if with_opt_vars: self.assign_ops, self.placeholders = tf_set_tensor_dict( - tensor_dict, self.sess, self.fl_vars, - self.assign_ops, self.placeholders - ) + tensor_dict, self.sess, self.fl_vars, self.assign_ops, + self.placeholders) else: self.tvar_assign_ops, self.tvar_placeholders = tf_set_tensor_dict( - tensor_dict, - self.sess, - self.tvars, - self.tvar_assign_ops, - self.tvar_placeholders - ) + tensor_dict, self.sess, self.tvars, self.tvar_assign_ops, + self.tvar_placeholders) def reset_opt_vars(self): - """Reinitialize the optimizer variables.""" + """Reinitialize the optimizer variables. + + Returns: + None + """ for v in self.opt_vars: v.initializer.run(session=self.sess) @@ -324,10 +364,10 @@ def _get_weights_names(self, with_opt_vars=True): Args: with_opt_vars (bool): Specify if we also want to get the variables - of the optimizer. + of the optimizer. Default is True. Returns: - list : The weight names list + list: The weight names list. """ if with_opt_vars is True: variables = self.fl_vars @@ -337,28 +377,39 @@ def _get_weights_names(self, with_opt_vars=True): return [var.name for var in variables] def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - Get the required tensors for specified function that could be called as part of a task. + """Get the required tensors for specified function that could be called + as part of a task. By default, this is just all of the layers and optimizer of the model. + Args: + func_name (str): The function name. + **kwargs: Additional parameters to pass to the function. + Returns: - list : [TensorKey] + required_tensorkeys_for_function (list): List of required + TensorKey. [TensorKey]. """ if func_name == 'validate': local_model = 'apply=' + str(kwargs['apply']) - return self.required_tensorkeys_for_function[func_name][local_model] + return self.required_tensorkeys_for_function[func_name][ + local_model] else: return self.required_tensorkeys_for_function[func_name] def initialize_tensorkeys_for_functions(self, with_opt_vars=False): - """ - Set the required tensors for all publicly accessible methods \ - that could be called as part of a task. + """Set the required tensors for all publicly accessible methods that + could be called as part of a task. By default, this is just all of the layers and optimizer of the model. Custom tensors should be added to this function + Args: + with_opt_vars (bool): Specify if we also want to set the variables + of the optimizer. Default is False. + + Returns: + None """ # TODO there should be a way to programmatically iterate through # all of the methods in the class and declare the tensors. 
@@ -366,44 +417,42 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) if not with_opt_vars: global_model_dict_val = global_model_dict local_model_dict_val = local_model_dict else: output_model_dict = self.get_tensor_dict(with_opt_vars=False) global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( - self.logger, - output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, + **self.tensor_dict_split_fn_kwargs) self.required_tensorkeys_for_function['train_batches'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) - for tensor_name in global_model_dict] + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) + for tensor_name in global_model_dict + ] self.required_tensorkeys_for_function['train_batches'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) - for tensor_name in local_model_dict] + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) + for tensor_name in local_model_dict + ] # Validation may be performed on local or aggregated (global) # model, so there is an extra lookup dimension for kwargs self.required_tensorkeys_for_function['validate'] = {} # TODO This is not stateless. The optimizer will not be self.required_tensorkeys_for_function['validate']['apply=local'] = [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('trained', )) for tensor_name in { **global_model_dict_val, **local_model_dict_val } ] self.required_tensorkeys_for_function['validate']['apply=global'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) for tensor_name in global_model_dict_val ] self.required_tensorkeys_for_function['validate']['apply=global'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) for tensor_name in local_model_dict_val ] @@ -415,28 +464,35 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): # to avoid inflating the graph, caller should keep these and pass them back # What if we want to set a different group of vars in the middle? # It is good if it is the subset of the original variables. -def tf_set_tensor_dict(tensor_dict, session, variables, - assign_ops=None, placeholders=None): - """Tensorflow set tensor dictionary. +def tf_set_tensor_dict(tensor_dict, + session, + variables, + assign_ops=None, + placeholders=None): + """Set the tensor dictionary in TensorFlow. Args: - tensor_dict: Dictionary of tensors - session: TensorFlow session - variables: TensorFlow variables - assign_ops: TensorFlow operations (Default=None) - placeholders: TensorFlow placeholders (Default=None) + tensor_dict (dict): Dictionary of tensors. + session (tf.Session): TensorFlow session. + variables (list): List of TensorFlow variables. + assign_ops (tf.Operation, optional): TensorFlow operations for + assignment. Default is None. + placeholders (tf.Tensor, optional): TensorFlow placeholders for + tensors. Default is None. Returns: - assign_ops, placeholders - + assign_ops (tf.Operation): TensorFlow operations for assignment. + placeholders (tf.Tensor): TensorFlow placeholders for tensors. 
""" if placeholders is None: placeholders = { - v.name: tf.placeholder(v.dtype, shape=v.shape) for v in variables + v.name: tf.placeholder(v.dtype, shape=v.shape) + for v in variables } if assign_ops is None: assign_ops = { - v.name: tf.assign(v, placeholders[v.name]) for v in variables + v.name: tf.assign(v, placeholders[v.name]) + for v in variables } for k, v in tensor_dict.items(): diff --git a/openfl/federated/task/task_runner.py b/openfl/federated/task/task_runner.py index 7bf6340cad..167e5e4153 100644 --- a/openfl/federated/task/task_runner.py +++ b/openfl/federated/task/task_runner.py @@ -12,45 +12,72 @@ class CoreTaskRunner: - """Federated Learning Task Runner Class.""" + """Federated Learning Task Runner Class. + + Attributes: + kwargs (dict): Additional parameters passed to the function. + TASK_REGISTRY (dict): Registry of tasks. + training_round_completed (bool): Flag indicating if a training round + has been completed. + tensor_dict_split_fn_kwargs (dict): Key word arguments for determining + which parameters to hold out from aggregation. + required_tensorkeys_for_function (dict): Required tensorkeys for all + public functions in CoreTaskRunner. + logger (logging.Logger): Logger object for logging events. + opt_treatment (str): Treatment of current instance optimizer. + """ + + def _prepare_tensorkeys_for_agggregation(self, metric_dict, + validation_flag, col_name, + round_num): + """Prepare tensorkeys for aggregation. - def _prepare_tensorkeys_for_agggregation(self, metric_dict, validation_flag, - col_name, round_num): - """ - Prepare tensorkeys for aggregation. + Args: + metric_dict (dict): Dictionary of metrics. + validation_flag (bool): Flag indicating if validation is to be + performed. + col_name (str): The column name. + round_num (int): The round number. - Returns (global_tensor_dict, local_tensor_dict) + Returns: + tuple: Tuple containing global_tensor_dict and local_tensor_dict. """ global_tensor_dict, local_tensor_dict = {}, {} origin = col_name if not validation_flag: # Output metric tensors (scalar) - tags = ('trained',) + tags = ('trained', ) # output model tensors (Doesn't include TensorKey) output_model_dict = self.get_tensor_dict(with_opt_vars=True) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + **self.tensor_dict_split_fn_kwargs) # Create global tensorkeys global_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in global_model_dict.items()} + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in global_model_dict.items() + } # Create tensorkeys that should stay local local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in local_model_dict.items()} + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in local_model_dict.items() + } # The train/validate aggregated function of the next # round will look for the updated model parameters. 
# This ensures they will be resolved locally - next_local_tensorkey_model_dict = {TensorKey( - tensor_name, origin, round_num + 1, False, ('model',)): nparray - for tensor_name, nparray in local_model_dict.items()} + next_local_tensorkey_model_dict = { + TensorKey(tensor_name, origin, round_num + 1, False, ('model', )): + nparray + for tensor_name, nparray in local_model_dict.items() + } global_tensor_dict = global_tensorkey_model_dict - local_tensor_dict = {**local_tensorkey_model_dict, **next_local_tensorkey_model_dict} + local_tensor_dict = { + **local_tensorkey_model_dict, + **next_local_tensorkey_model_dict + } # Update the required tensors if they need to be # pulled from the aggregator @@ -73,40 +100,48 @@ def _prepare_tensorkeys_for_agggregation(self, metric_dict, validation_flag, else: suffix = 'validate' + validation_flag - tags = (suffix,) + tags = (suffix, ) tags = change_tags(tags, add_field='metric') metric_dict = { - TensorKey(metric, origin, round_num, True, tags): - np.array(value) for metric, value in metric_dict.items() + TensorKey(metric, origin, round_num, True, tags): np.array(value) + for metric, value in metric_dict.items() } global_tensor_dict = {**global_tensor_dict, **metric_dict} return global_tensor_dict, local_tensor_dict def adapt_tasks(self): - """ - Prepare tasks for the collaborator. + """Prepare tasks for the collaborator. - Using functions from a task provider (deserialized interface object) and - registered task contracts prepares callable tasks to be invoked by the collaborator. + Using functions from a task provider (deserialized interface object) + and registered task contracts prepares callable tasks to be invoked by + the collaborator. - Preparing includes conditional model rebuilding and filling output dicts - with tensors for aggregation and storing in local DB. + Preparing includes conditional model rebuilding and filling output + dicts with tensors for aggregation and storing in local DB. There is an assumption that any training task accepts optimizer as one of the arguments, thus the model should be aggregated after such tasks. + + Returns: + None """ def task_binder(task_name, callable_task): - def collaborator_adapted_task(col_name, round_num, input_tensor_dict, **kwargs): + + def collaborator_adapted_task(col_name, round_num, + input_tensor_dict, **kwargs): task_contract = self.task_provider.task_contract[task_name] # Validation flag can be [False, '_local', '_agg'] - validation_flag = True if task_contract['optimizer'] is None else False + validation_flag = True if task_contract[ + 'optimizer'] is None else False task_settings = self.task_provider.task_settings[task_name] device = kwargs.get('device', 'cpu') - self.rebuild_model(input_tensor_dict, validation=validation_flag, device=device) + self.rebuild_model(input_tensor_dict, + validation=validation_flag, + device=device) task_kwargs = {} if validation_flag: loader = self.data_loader.get_valid_loader() @@ -137,16 +172,20 @@ def collaborator_adapted_task(col_name, round_num, input_tensor_dict, **kwargs): return collaborator_adapted_task - for task_name, callable_task in self.task_provider.task_registry.items(): - self.TASK_REGISTRY[task_name] = task_binder(task_name, callable_task) + for task_name, callable_task in self.task_provider.task_registry.items( + ): + self.TASK_REGISTRY[task_name] = task_binder( + task_name, callable_task) def __init__(self, **kwargs): - """ - Initialize Task Runner. + """Initializes the Task Runner object. 
This class is a part of the Interactive python API release. It is no longer a user interface entity that should be subclassed but a part of OpenFL kernel. + + Args: + **kwargs: Additional parameters to pass to the function. """ self.set_logger() @@ -164,17 +203,21 @@ def __init__(self, **kwargs): # overwrite attribute to account for one optimizer param (in every # child model that does not overwrite get and set tensordict) that is # not a numpy array - self.tensor_dict_split_fn_kwargs.update({ - 'holdout_tensor_names': ['__opt_state_needed'] - }) + self.tensor_dict_split_fn_kwargs.update( + {'holdout_tensor_names': ['__opt_state_needed']}) def set_task_provider(self, task_provider): - """ - Set task registry. + """Set task registry. This method recieves Task Interface object as an argument and uses provided callables and information to prepare tasks that may be called by the collaborator component. + + Args: + task_provider: Task provider object. + + Returns: + None """ if task_provider is None: return @@ -182,84 +225,127 @@ def set_task_provider(self, task_provider): self.adapt_tasks() def set_data_loader(self, data_loader): - """Register a data loader initialized with local data path.""" + """Register a data loader initialized with local data path. + + Args: + data_loader: Data loader object. + + Returns: + None + """ self.data_loader = data_loader def set_model_provider(self, model_provider): - """Retrieve a model and an optimizer from the interface object.""" + """Retrieve a model and an optimizer from the interface object. + + Args: + model_provider: Model provider object. + + Returns: + None + """ self.model_provider = model_provider self.model = self.model_provider.provide_model() self.optimizer = self.model_provider.provide_optimizer() def set_framework_adapter(self, framework_adapter): - """ - Set framework adapter. + """Set framework adapter. Setting a framework adapter allows first extraction of the weigths - of the model with the purpose to make a list of parameters to be aggregated. + of the model with the purpose to make a list of parameters to be + aggregated. + + Args: + framework_adapter: Framework adapter object. + + Returns: + None """ self.framework_adapter = framework_adapter if self.opt_treatment == 'CONTINUE_GLOBAL': aggregate_optimizer_parameters = True else: aggregate_optimizer_parameters = False - self.initialize_tensorkeys_for_functions(with_opt_vars=aggregate_optimizer_parameters) + self.initialize_tensorkeys_for_functions( + with_opt_vars=aggregate_optimizer_parameters) def set_logger(self): - """Set up the log object.""" + """Set up the log object. + + Returns: + None + """ self.logger = getLogger(__name__) def set_optimizer_treatment(self, opt_treatment): # SHould be removed! We have this info at the initialization time # and do not change this one during training. - """Change the treatment of current instance optimizer.""" + """Change the treatment of current instance optimizer. + + Args: + opt_treatment (str): The optimizer treatment. + + Returns: + None + """ self.opt_treatment = opt_treatment def rebuild_model(self, input_tensor_dict, validation=False, device='cpu'): - """ - Parse tensor names and update weights of model. Handles the optimizer treatment. + """Parse tensor names and update weights of model. Handles the + optimizer treatment. + + Args: + input_tensor_dict (dict): The input tensor dictionary. + validation (bool): If True, perform validation. Default is False. + device (str): The device to use. Default is 'cpu'. 
Returns: None """ if self.opt_treatment == 'RESET': self.reset_opt_vars() - self.set_tensor_dict(input_tensor_dict, with_opt_vars=False, device=device) + self.set_tensor_dict(input_tensor_dict, + with_opt_vars=False, + device=device) elif (self.training_round_completed and self.opt_treatment == 'CONTINUE_GLOBAL' and not validation): - self.set_tensor_dict(input_tensor_dict, with_opt_vars=True, device=device) + self.set_tensor_dict(input_tensor_dict, + with_opt_vars=True, + device=device) else: - self.set_tensor_dict(input_tensor_dict, with_opt_vars=False, device=device) + self.set_tensor_dict(input_tensor_dict, + with_opt_vars=False, + device=device) def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - Get the required tensors for specified function that could be called as part of a task. + """Get the required tensors for specified function that could be called + as part of a task. By default, this is just all of the layers and optimizer of the model. - Parameters - ---------- - None + Args: + func_name (str): The function name. + **kwargs: Additional parameters to pass to the function. - Returns - ------- - List - [TensorKey] + Returns: + list: List of required TensorKey. """ # We rely on validation type tasks parameter `apply` # In the interface layer we add those parameters automatically if 'apply' not in kwargs: return [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) - for tensor_name in self.required_tensorkeys_for_function['global_model_dict'] + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) + for tensor_name in + self.required_tensorkeys_for_function['global_model_dict'] ] + [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) - for tensor_name in self.required_tensorkeys_for_function['local_model_dict'] + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) + for tensor_name in + self.required_tensorkeys_for_function['local_model_dict'] ] if kwargs['apply'] == 'local': return [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) + TensorKey(tensor_name, 'LOCAL', 0, False, ('trained', )) for tensor_name in { **self.required_tensorkeys_for_function['local_model_dict_val'], **self.required_tensorkeys_for_function['global_model_dict_val'] @@ -268,11 +354,13 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): elif kwargs['apply'] == 'global': return [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) - for tensor_name in self.required_tensorkeys_for_function['global_model_dict_val'] + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model', )) + for tensor_name in + self.required_tensorkeys_for_function['global_model_dict_val'] ] + [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) - for tensor_name in self.required_tensorkeys_for_function['local_model_dict_val'] + TensorKey(tensor_name, 'LOCAL', 0, False, ('model', )) + for tensor_name in + self.required_tensorkeys_for_function['local_model_dict_val'] ] def initialize_tensorkeys_for_functions(self, with_opt_vars=False): @@ -282,47 +370,46 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): Custom tensors should be added to this function. Args: - None + with_opt_vars (bool): Specify if we also want to set the variables + of the optimizer. Default is False. Returns: None """ - # TODO: Framework adapters should have separate methods for dealing with optimizer - # Set model dict for validation tasks + # TODO: Framework adapters should have separate methods for dealing + # with optimizer. 
Set model dict for validation tasks output_model_dict = self.get_tensor_dict(with_opt_vars=False) global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( - self.logger, - output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs) # Now set model dict for training tasks if with_opt_vars: output_model_dict = self.get_tensor_dict(with_opt_vars=True) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) + **self.tensor_dict_split_fn_kwargs) else: global_model_dict = global_model_dict_val local_model_dict = local_model_dict_val - self.required_tensorkeys_for_function['global_model_dict'] = global_model_dict - self.required_tensorkeys_for_function['local_model_dict'] = local_model_dict - self.required_tensorkeys_for_function['global_model_dict_val'] = global_model_dict_val - self.required_tensorkeys_for_function['local_model_dict_val'] = local_model_dict_val + self.required_tensorkeys_for_function[ + 'global_model_dict'] = global_model_dict + self.required_tensorkeys_for_function[ + 'local_model_dict'] = local_model_dict + self.required_tensorkeys_for_function[ + 'global_model_dict_val'] = global_model_dict_val + self.required_tensorkeys_for_function[ + 'local_model_dict_val'] = local_model_dict_val def reset_opt_vars(self): - """ - Reset optimizer variables. - - Resets the optimizer variables + """Reset optimizer variables. + Returns: + None """ self.optimizer = self.model_provider.provide_optimizer() def get_train_data_size(self): - """ - Get the number of training examples. + """Get the number of training examples. It will be used for weighted averaging in aggregation. @@ -332,8 +419,7 @@ def get_train_data_size(self): return self.data_loader.get_train_data_size() def get_valid_data_size(self): - """ - Get the number of examples. + """Get the number of examples. It will be used for weighted averaging in aggregation. @@ -347,11 +433,10 @@ def get_tensor_dict(self, with_opt_vars=False): Args: with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) + optimizer tensors (Default=False). Returns: - dict: Tensor dictionary {**dict, **optimizer_dict} - + dict: Tensor dictionary {**dict, **optimizer_dict}. """ args = [self.model] if with_opt_vars: @@ -365,8 +450,7 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars=False, device='cpu'): Args: tensor_dict: The tensor dictionary with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) - + optimizer tensors (Default=False). """ # Sets tensors for model layers and optimizer state. # FIXME: self.parameters() instead? 
Unclear if load_state_dict() or @@ -377,6 +461,8 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars=False, device='cpu'): if with_opt_vars: args.append(self.optimizer) - kwargs = {'device': device, } + kwargs = { + 'device': device, + } return self.framework_adapter.set_tensor_dict(*args, **kwargs) diff --git a/openfl/interface/aggregation_functions/__init__.py b/openfl/interface/aggregation_functions/__init__.py index e99dbc15f9..b9121ac1dd 100644 --- a/openfl/interface/aggregation_functions/__init__.py +++ b/openfl/interface/aggregation_functions/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Aggregation functions package.""" from .adagrad_adaptive_aggregation import AdagradAdaptiveAggregation @@ -12,11 +11,8 @@ from .weighted_average import WeightedAverage from .yogi_adaptive_aggregation import YogiAdaptiveAggregation -__all__ = ['Median', - 'WeightedAverage', - 'GeometricMedian', - 'AdagradAdaptiveAggregation', - 'AdamAdaptiveAggregation', - 'YogiAdaptiveAggregation', - 'AggregationFunction', - 'FedCurvWeightedAverage'] +__all__ = [ + 'Median', 'WeightedAverage', 'GeometricMedian', + 'AdagradAdaptiveAggregation', 'AdamAdaptiveAggregation', + 'YogiAdaptiveAggregation', 'AggregationFunction', 'FedCurvWeightedAverage' +] diff --git a/openfl/interface/aggregation_functions/adagrad_adaptive_aggregation.py b/openfl/interface/aggregation_functions/adagrad_adaptive_aggregation.py index 27aee4f867..3360410b36 100644 --- a/openfl/interface/aggregation_functions/adagrad_adaptive_aggregation.py +++ b/openfl/interface/aggregation_functions/adagrad_adaptive_aggregation.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Adagrad adaptive aggregation module.""" from typing import Dict @@ -13,7 +12,6 @@ from .core import AggregationFunction from .weighted_average import WeightedAverage - DEFAULT_AGG_FUNC = WeightedAverage() @@ -30,17 +28,20 @@ def __init__( initial_accumulator_value: float = 0.1, epsilon: float = 1e-10, ) -> None: - """Initialize. + """Initialize the AdagradAdaptiveAggregation object. Args: - agg_func: Aggregate function for aggregating - parameters that are not inside the optimizer (default: WeightedAverage()). - params: Parameters to be stored for optimization. + agg_func (AggregationFunction): Aggregate function for aggregating + parameters that are not inside the optimizer (default: + WeightedAverage()). + params (Optional[Dict[str, np.ndarray]]): Parameters to be stored + for optimization. model_interface: Model interface instance to provide parameters. - learning_rate: Tuning parameter that determines + learning_rate (float): Tuning parameter that determines the step size at each iteration. - initial_accumulator_value: Initial value for squared gradients. - epsilon: Value for computational stability. + initial_accumulator_value (float): Initial value for squared + gradients. + epsilon (float): Value for computational stability. 
""" opt = NumPyAdagrad(params=params, model_interface=model_interface, diff --git a/openfl/interface/aggregation_functions/adam_adaptive_aggregation.py b/openfl/interface/aggregation_functions/adam_adaptive_aggregation.py index 6c6ad125e6..dd7c067a81 100644 --- a/openfl/interface/aggregation_functions/adam_adaptive_aggregation.py +++ b/openfl/interface/aggregation_functions/adam_adaptive_aggregation.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Adam adaptive aggregation module.""" from typing import Dict @@ -14,7 +13,6 @@ from .core import AggregationFunction from .weighted_average import WeightedAverage - DEFAULT_AGG_FUNC = WeightedAverage() @@ -32,20 +30,22 @@ def __init__( initial_accumulator_value: float = 0.0, epsilon: float = 1e-8, ) -> None: - """Initialize. + """Initialize the AdamAdaptiveAggregation object. Args: - agg_func: Aggregate function for aggregating - parameters that are not inside the optimizer (default: WeightedAverage()). - params: Parameters to be stored for optimization. + agg_func (AggregationFunction): Aggregate function for aggregating + parameters that are not inside the optimizer (default: + WeightedAverage()). + params (Optional[Dict[str, np.ndarray]]): Parameters to be stored + for optimization. model_interface: Model interface instance to provide parameters. - learning_rate: Tuning parameter that determines + learning_rate (float): Tuning parameter that determines the step size at each iteration. - betas: Coefficients used for computing running - averages of gradient and its square. - initial_accumulator_value: Initial value for gradients + betas (Tuple[float, float]): Coefficients used for computing + running averages of gradient and its square. + initial_accumulator_value (float): Initial value for gradients and squared gradients. - epsilon: Value for computational stability. + epsilon (float): Value for computational stability. """ opt = NumPyAdam(params=params, model_interface=model_interface, diff --git a/openfl/interface/aggregation_functions/core/__init__.py b/openfl/interface/aggregation_functions/core/__init__.py index 7bd173d33f..c344fb8f8b 100644 --- a/openfl/interface/aggregation_functions/core/__init__.py +++ b/openfl/interface/aggregation_functions/core/__init__.py @@ -1,10 +1,8 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Aggregation functions core package.""" from .adaptive_aggregation import AdaptiveAggregation from .interface import AggregationFunction -__all__ = ['AggregationFunction', - 'AdaptiveAggregation'] +__all__ = ['AggregationFunction', 'AdaptiveAggregation'] diff --git a/openfl/interface/aggregation_functions/core/adaptive_aggregation.py b/openfl/interface/aggregation_functions/core/adaptive_aggregation.py index bd175116f4..e4165f64a6 100644 --- a/openfl/interface/aggregation_functions/core/adaptive_aggregation.py +++ b/openfl/interface/aggregation_functions/core/adaptive_aggregation.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Adaptive aggregation module.""" from typing import List @@ -23,11 +22,11 @@ def __init__( optimizer: Optimizer, agg_func: AggregationFunction, ) -> None: - """Initialize. + """Initialize the AdaptiveAggregation class. Args: - optimizer: One of numpy optimizer class instance. - agg_func: Aggregate function for aggregating + optimizer (Optimizer): One of numpy optimizer class instance. 
+ agg_func (AggregationFunction): Aggregate function for aggregating parameters that are not inside the optimizer. """ super().__init__() @@ -35,38 +34,46 @@ def __init__( self.default_agg_func = agg_func @staticmethod - def _make_gradient( - base_model_nparray: np.ndarray, - local_tensors: List[LocalTensor] - ) -> np.ndarray: - """Make gradient.""" - return sum([local_tensor.weight * (base_model_nparray - local_tensor.tensor) - for local_tensor in local_tensors]) - - def call( - self, - local_tensors, - db_iterator, - tensor_name, - fl_round, - tags - ) -> np.ndarray: + def _make_gradient(base_model_nparray: np.ndarray, + local_tensors: List[LocalTensor]) -> np.ndarray: + """Make gradient. + + Args: + base_model_nparray (np.ndarray): The base model tensor. + local_tensors (List[LocalTensor]): List of local tensors. + + Returns: + np.ndarray: The gradient tensor. + """ + return sum([ + local_tensor.weight * (base_model_nparray - local_tensor.tensor) + for local_tensor in local_tensors + ]) + + def call(self, local_tensors, db_iterator, tensor_name, fl_round, + tags) -> np.ndarray: """Aggregate tensors. Args: - local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate. + local_tensors (list[openfl.utilities.LocalTensor]): List of local + tensors to aggregate. db_iterator: An iterator over history of all tensors. Columns: - 'tensor_name': name of the tensor. - Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'. - - 'fl_round': 0-based number of round corresponding to this tensor. + Examples for `torch.nn.Module`s: 'conv1.weight','fc2.bias'. + - 'fl_round': 0-based number of round corresponding to this + tensor. - 'tags': tuple of tensor tags. Tags that can appear: - 'model' indicates that the tensor is a model parameter. - - 'trained' indicates that tensor is a part of a training result. - These tensors are passed to the aggregator node after local learning. - - 'aggregated' indicates that tensor is a result of aggregation. - These tensors are sent to collaborators for the next round. - - 'delta' indicates that value is a difference between rounds - for a specific tensor. + - 'trained' indicates that tensor is a part of a training + result. + These tensors are passed to the aggregator node after + local learning. + - 'aggregated' indicates that tensor is a result of + aggregation. + These tensors are sent to collaborators for the next + round. + - 'delta' indicates that value is a difference between + rounds for a specific tensor. also one of the tags is a collaborator name if it corresponds to a result of a local task. 
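The `_make_gradient` helper above turns collaborator updates into a pseudo-gradient for the server-side numpy optimizer: the weighted sum of differences between the current global tensor and each local tensor. A minimal numpy sketch, assuming a stand-in `LocalTensor` namedtuple (the real type lives in `openfl.utilities` and may differ in detail):

    from collections import namedtuple

    import numpy as np

    # Stand-in for openfl.utilities.LocalTensor; field names assumed for illustration.
    LocalTensor = namedtuple("LocalTensor", ["col_name", "tensor", "weight"])

    def make_gradient(base_model_nparray, local_tensors):
        """Weighted sum of (global - local) differences, as in _make_gradient."""
        return sum(
            lt.weight * (base_model_nparray - lt.tensor) for lt in local_tensors
        )

    base = np.array([1.0, 2.0, 3.0])
    updates = [
        LocalTensor("col_one", np.array([0.8, 2.2, 2.9]), 0.5),
        LocalTensor("col_two", np.array([1.2, 1.8, 3.3]), 0.5),
    ]

    print(make_gradient(base, updates))  # approximately [0.  0.  -0.1]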
@@ -78,26 +85,22 @@ def call( np.ndarray: aggregated tensor """ if tensor_name not in self.optimizer.params: - return self.default_agg_func(local_tensors, - db_iterator, - tensor_name, - fl_round, - tags) + return self.default_agg_func(local_tensors, db_iterator, + tensor_name, fl_round, tags) base_model_nparray = None search_tag = 'aggregated' if fl_round != 0 else 'model' for record in db_iterator: - if ( - record['round'] == fl_round - and record['tensor_name'] == tensor_name - and search_tag in record['tags'] - and 'delta' not in record['tags'] - ): + if (record['round'] == fl_round + and record['tensor_name'] == tensor_name + and search_tag in record['tags'] + and 'delta' not in record['tags']): base_model_nparray = record['nparray'] if base_model_nparray is None: raise KeyError( - f'There is no current global model in TensorDB for tensor name: {tensor_name}') + f'There is no current global model in TensorDB for tensor name: {tensor_name}' + ) gradient = self._make_gradient(base_model_nparray, local_tensors) gradients = {tensor_name: gradient} diff --git a/openfl/interface/aggregation_functions/core/interface.py b/openfl/interface/aggregation_functions/core/interface.py index 499837d1d7..0e77c68c75 100644 --- a/openfl/interface/aggregation_functions/core/interface.py +++ b/openfl/interface/aggregation_functions/core/interface.py @@ -17,35 +17,38 @@ class AggregationFunction(metaclass=SingletonABCMeta): """Interface for specifying aggregation function.""" def __init__(self): - """Initialize common AggregationFunction params + """Initialize common AggregationFunction params. - Default: Read only access to TensorDB + Default: Read only access to TensorDB """ self._privileged = False @abstractmethod - def call(self, - local_tensors: List[LocalTensor], - db_iterator: Iterator[pd.Series], - tensor_name: str, - fl_round: int, + def call(self, local_tensors: List[LocalTensor], + db_iterator: Iterator[pd.Series], tensor_name: str, fl_round: int, tags: Tuple[str]) -> np.ndarray: """Aggregate tensors. Args: - local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate. + local_tensors (list[openfl.utilities.LocalTensor]): List of local + tensors to aggregate. db_iterator: An iterator over history of all tensors. Columns: - 'tensor_name': name of the tensor. - Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'. - - 'round': 0-based number of round corresponding to this tensor. + Examples for `torch.nn.Module`s: 'conv1.weight','fc2.bias'. + - 'round': 0-based number of round corresponding to this + tensor. - 'tags': tuple of tensor tags. Tags that can appear: - 'model' indicates that the tensor is a model parameter. - - 'trained' indicates that tensor is a part of a training result. - These tensors are passed to the aggregator node after local learning. - - 'aggregated' indicates that tensor is a result of aggregation. - These tensors are sent to collaborators for the next round. - - 'delta' indicates that value is a difference between rounds - for a specific tensor. + - 'trained' indicates that tensor is a part of a training + result. + These tensors are passed to the aggregator node after + local learning. + - 'aggregated' indicates that tensor is a result of + aggregation. + These tensors are sent to collaborators for the next + round. + - 'delta' indicates that value is a difference between + rounds for a specific tensor. also one of the tags is a collaborator name if it corresponds to a result of a local task. 
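The `call` signature documented above is the extension point for custom server-side aggregation. The sketch below implements an illustrative coordinate-wise trimmed mean following that signature; the import path follows the package `__all__` shown earlier in this diff, while the trimming logic itself is only an example, not an OpenFL built-in.

    import numpy as np

    from openfl.interface.aggregation_functions import AggregationFunction


    class TrimmedAverage(AggregationFunction):
        """Illustrative aggregator: drop extreme values per coordinate, then average."""

        def __init__(self, trim_ratio=0.1):
            super().__init__()
            self.trim_ratio = trim_ratio

        def call(self, local_tensors, db_iterator, tensor_name, fl_round, tags):
            # Stack collaborator tensors along a new leading axis.
            stacked = np.stack([lt.tensor for lt in local_tensors])
            n = stacked.shape[0]
            k = int(n * self.trim_ratio)
            if k == 0 or 2 * k >= n:
                # Too few collaborators to trim; fall back to a plain mean.
                return np.mean(stacked, axis=0)
            # Sort coordinate-wise and discard the k smallest and k largest values.
            trimmed = np.sort(stacked, axis=0)[k:n - k]
            return np.mean(trimmed, axis=0)

Such a class would be referenced the same way the built-in WeightedAverage or Median functions are; the exact wiring into a plan depends on the workspace configuration and is not shown here.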
@@ -58,10 +61,8 @@ def call(self, """ raise NotImplementedError - def __call__(self, local_tensors, - db_iterator, - tensor_name, - fl_round, + def __call__(self, local_tensors, db_iterator, tensor_name, fl_round, tags): """Use magic function for ease.""" - return self.call(local_tensors, db_iterator, tensor_name, fl_round, tags) + return self.call(local_tensors, db_iterator, tensor_name, fl_round, + tags) diff --git a/openfl/interface/aggregation_functions/experimental/__init__.py b/openfl/interface/aggregation_functions/experimental/__init__.py index 3cc4d3907e..23482e7a92 100644 --- a/openfl/interface/aggregation_functions/experimental/__init__.py +++ b/openfl/interface/aggregation_functions/experimental/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Aggregation functions experimental package.""" from .privileged_aggregation import PrivilegedAggregationFunction diff --git a/openfl/interface/aggregation_functions/experimental/privileged_aggregation.py b/openfl/interface/aggregation_functions/experimental/privileged_aggregation.py index e1c76d89ac..7da79cc157 100644 --- a/openfl/interface/aggregation_functions/experimental/privileged_aggregation.py +++ b/openfl/interface/aggregation_functions/experimental/privileged_aggregation.py @@ -13,40 +13,39 @@ class PrivilegedAggregationFunction(AggregationFunction): - """Privileged Aggregation Function interface provides write access to TensorDB Dataframe. + """Privileged Aggregation Function interface provides write access to + TensorDB Dataframe.""" - """ - - def __init__( - self - ) -> None: - """Initialize with TensorDB write access""" + def __init__(self) -> None: + """Initialize with TensorDB write access.""" super().__init__() self._privileged = True @abstractmethod - def call(self, - local_tensors: List[LocalTensor], - tensor_db: pd.DataFrame, - tensor_name: str, - fl_round: int, - tags: Tuple[str]) -> np.ndarray: + def call(self, local_tensors: List[LocalTensor], tensor_db: pd.DataFrame, + tensor_name: str, fl_round: int, tags: Tuple[str]) -> np.ndarray: """Aggregate tensors. Args: - local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate. + local_tensors (list[openfl.utilities.LocalTensor]): List of local + tensors to aggregate. tensor_db: Raw TensorDB dataframe (for write access). Columns: - 'tensor_name': name of the tensor. - Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'. - - 'round': 0-based number of round corresponding to this tensor. + Examples for `torch.nn.Module`s: 'conv1.weight','fc2.bias'. + - 'round': 0-based number of round corresponding to this + tensor. - 'tags': tuple of tensor tags. Tags that can appear: - 'model' indicates that the tensor is a model parameter. - - 'trained' indicates that tensor is a part of a training result. - These tensors are passed to the aggregator node after local learning. - - 'aggregated' indicates that tensor is a result of aggregation. - These tensors are sent to collaborators for the next round. - - 'delta' indicates that value is a difference between rounds - for a specific tensor. + - 'trained' indicates that tensor is a part of a training + result. + These tensors are passed to the aggregator node after + local learning. + - 'aggregated' indicates that tensor is a result of + aggregation. + These tensors are sent to collaborators for the next + round. + - 'delta' indicates that value is a difference between + rounds for a specific tensor. 
also one of the tags is a collaborator name if it corresponds to a result of a local task. diff --git a/openfl/interface/aggregation_functions/fedcurv_weighted_average.py b/openfl/interface/aggregation_functions/fedcurv_weighted_average.py index 75fa2417d3..acf0ce1f2c 100644 --- a/openfl/interface/aggregation_functions/fedcurv_weighted_average.py +++ b/openfl/interface/aggregation_functions/fedcurv_weighted_average.py @@ -18,11 +18,8 @@ class FedCurvWeightedAverage(WeightedAverage): def call(self, local_tensors, tensor_db, tensor_name, fl_round, tags): """Apply aggregation.""" - if ( - tensor_name.endswith('_u') - or tensor_name.endswith('_v') - or tensor_name.endswith('_w') - ): + if (tensor_name.endswith('_u') or tensor_name.endswith('_v') + or tensor_name.endswith('_w')): tensors = [local_tensor.tensor for local_tensor in local_tensors] agg_result = np.sum(tensors, axis=0) return agg_result diff --git a/openfl/interface/aggregation_functions/geometric_median.py b/openfl/interface/aggregation_functions/geometric_median.py index edba7fecb4..f3f7726b85 100644 --- a/openfl/interface/aggregation_functions/geometric_median.py +++ b/openfl/interface/aggregation_functions/geometric_median.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Geometric median module.""" import numpy as np @@ -10,12 +9,33 @@ def _geometric_median_objective(median, tensors, weights): - """Compute geometric median objective.""" + """Compute geometric median objective. + + Args: + median (np.ndarray): The median tensor. + tensors (list): List of tensors. + weights (list): List of weights corresponding to the tensors. + + Returns: + float: The geometric median objective. + """ return sum([w * _l2dist(median, x) for w, x in zip(weights, tensors)]) def geometric_median(tensors, weights, maxiter=4, eps=1e-5, ftol=1e-6): - """Compute geometric median of tensors with weights using Weiszfeld's Algorithm.""" + """Compute geometric median of tensors with weights using Weiszfeld's + Algorithm. + + Args: + tensors (list): List of tensors. + weights (list): List of weights corresponding to the tensors. + maxiter (int, optional): Maximum number of iterations. Defaults to 4. + eps (float, optional): Epsilon value for stability. Defaults to 1e-5. + ftol (float, optional): Tolerance for convergence. Defaults to 1e-6. + + Returns: + median (np.ndarray): The geometric median of the tensors. + """ weights = np.asarray(weights) / sum(weights) median = weighted_average(tensors, weights) num_oracle_calls = 1 @@ -24,7 +44,9 @@ def geometric_median(tensors, weights, maxiter=4, eps=1e-5, ftol=1e-6): for _ in range(maxiter): prev_obj_val = obj_val - weights = np.asarray([w / max(eps, _l2dist(median, x)) for w, x in zip(weights, tensors)]) + weights = np.asarray([ + w / max(eps, _l2dist(median, x)) for w, x in zip(weights, tensors) + ]) weights = weights / weights.sum() median = weighted_average(tensors, weights) num_oracle_calls += 1 @@ -35,7 +57,15 @@ def geometric_median(tensors, weights, maxiter=4, eps=1e-5, ftol=1e-6): def _l2dist(p1, p2): - """L2 distance between p1, p2, each of which is a list of nd-arrays.""" + """L2 distance between p1, p2, each of which is a list of nd-arrays. + + Args: + p1 (np.ndarray): First tensor. + p2 (np.ndarray): Second tensor. + + Returns: + float: The L2 distance between the two tensors. 
+ """ if p1.ndim != p2.ndim: raise RuntimeError('Tensor shapes should be equal') if p1.ndim < 2: @@ -50,19 +80,23 @@ def call(self, local_tensors, *_) -> np.ndarray: """Aggregate tensors. Args: - local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate. + local_tensors (list[openfl.utilities.LocalTensor]): List of local + tensors to aggregate. db_iterator: iterator over history of all tensors. Columns: - 'tensor_name': name of the tensor. Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'. - - 'round': 0-based number of round corresponding to this tensor. + - 'round': 0-based number of round corresponding to this + tensor. - 'tags': tuple of tensor tags. Tags that can appear: - 'model' indicates that the tensor is a model parameter. - - 'trained' indicates that tensor is a part of a training result. - These tensors are passed to the aggregator node after local learning. - - 'aggregated' indicates that tensor is a result of aggregation. - These tensors are sent to collaborators for the next round. - - 'delta' indicates that value is a difference between rounds - for a specific tensor. + - 'trained' indicates that tensor is a part of a training + result. These tensors are passed to the aggregator + node after local learning. + - 'aggregated' indicates that tensor is a result of + aggregation. These tensors are sent to collaborators + for the next round. + - 'delta' indicates that value is a difference between + rounds for a specific tensor. also one of the tags is a collaborator name if it corresponds to a result of a local task. @@ -71,7 +105,7 @@ def call(self, local_tensors, *_) -> np.ndarray: fl_round: round number tags: tuple of tags for this tensor Returns: - np.ndarray: aggregated tensor + geometric_median (np.ndarray): aggregated tensor """ tensors, weights = zip(*[(x.tensor, x.weight) for x in local_tensors]) tensors, weights = np.array(tensors), np.array(weights) diff --git a/openfl/interface/aggregation_functions/median.py b/openfl/interface/aggregation_functions/median.py index aff44bceb4..7231c5c16e 100644 --- a/openfl/interface/aggregation_functions/median.py +++ b/openfl/interface/aggregation_functions/median.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Median module.""" import numpy as np @@ -15,19 +14,25 @@ def call(self, local_tensors, *_) -> np.ndarray: """Aggregate tensors. Args: - local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate. + local_tensors (list[openfl.utilities.LocalTensor]): List of local + tensors to aggregate. db_iterator: iterator over history of all tensors. Columns: - 'tensor_name': name of the tensor. - Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'. - - 'round': 0-based number of round corresponding to this tensor. + Examples for `torch.nn.Module`s: 'conv1.weight','fc2.bias'. + - 'round': 0-based number of round corresponding to this + tensor. - 'tags': tuple of tensor tags. Tags that can appear: - 'model' indicates that the tensor is a model parameter. - - 'trained' indicates that tensor is a part of a training result. - These tensors are passed to the aggregator node after local learning. - - 'aggregated' indicates that tensor is a result of aggregation. - These tensors are sent to collaborators for the next round. - - 'delta' indicates that value is a difference between rounds - for a specific tensor. + - 'trained' indicates that tensor is a part of a training + result. 
+ These tensors are passed to the aggregator node after + local learning. + - 'aggregated' indicates that tensor is a result of + aggregation. + These tensors are sent to collaborators for the next + round. + - 'delta' indicates that value is a difference between + rounds for a specific tensor. also one of the tags is a collaborator name if it corresponds to a result of a local task. diff --git a/openfl/interface/aggregation_functions/weighted_average.py b/openfl/interface/aggregation_functions/weighted_average.py index b8793432bc..10b8f40ac4 100644 --- a/openfl/interface/aggregation_functions/weighted_average.py +++ b/openfl/interface/aggregation_functions/weighted_average.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Federated averaging module.""" import numpy as np @@ -20,19 +19,25 @@ def call(self, local_tensors, *_) -> np.ndarray: """Aggregate tensors. Args: - local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate. + local_tensors (list[openfl.utilities.LocalTensor]): List of local + tensors to aggregate. db_iterator: iterator over history of all tensors. Columns: - 'tensor_name': name of the tensor. - Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'. - - 'round': 0-based number of round corresponding to this tensor. + Examples for `torch.nn.Module`s: 'conv1.weight','fc2.bias'. + - 'round': 0-based number of round corresponding to this + tensor. - 'tags': tuple of tensor tags. Tags that can appear: - 'model' indicates that the tensor is a model parameter. - - 'trained' indicates that tensor is a part of a training result. - These tensors are passed to the aggregator node after local learning. - - 'aggregated' indicates that tensor is a result of aggregation. - These tensors are sent to collaborators for the next round. - - 'delta' indicates that value is a difference between rounds - for a specific tensor. + - 'trained' indicates that tensor is a part of a training + result. + These tensors are passed to the aggregator node after + local learning. + - 'aggregated' indicates that tensor is a result of + aggregation. + These tensors are sent to collaborators for the next + round. + - 'delta' indicates that value is a difference between + rounds for a specific tensor. also one of the tags is a collaborator name if it corresponds to a result of a local task. diff --git a/openfl/interface/aggregation_functions/yogi_adaptive_aggregation.py b/openfl/interface/aggregation_functions/yogi_adaptive_aggregation.py index 0245818114..abb5628659 100644 --- a/openfl/interface/aggregation_functions/yogi_adaptive_aggregation.py +++ b/openfl/interface/aggregation_functions/yogi_adaptive_aggregation.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Yogi adaptive aggregation module.""" from typing import Dict @@ -14,7 +13,6 @@ from .core import AggregationFunction from .weighted_average import WeightedAverage - DEFAULT_AGG_FUNC = WeightedAverage() @@ -32,20 +30,22 @@ def __init__( initial_accumulator_value: float = 0.0, epsilon: float = 1e-8, ) -> None: - """Initialize. + """Initialize the YogiAdaptiveAggregation object. Args: - agg_func: Aggregate function for aggregating - parameters that are not inside the optimizer (default: WeightedAverage()). - params: Parameters to be stored for optimization. + agg_func (AggregationFunction): Aggregate function for aggregating + parameters that are not inside the optimizer (default: + WeightedAverage()). 
+ params (Optional[Dict[str, np.ndarray]]): Parameters to be stored + for optimization. model_interface: Model interface instance to provide parameters. - learning_rate: Tuning parameter that determines + learning_rate (float): Tuning parameter that determines the step size at each iteration. - betas: Coefficients used for computing running - averages of gradient and its square. - initial_accumulator_value: Initial value for gradients + betas (Tuple[float, float]): Coefficients used for computing + running averages of gradient and its square. + initial_accumulator_value (float): Initial value for gradients and squared gradients. - epsilon: Value for computational stability. + epsilon (float): Value for computational stability. """ opt = NumPyYogi(params=params, model_interface=model_interface, diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 9a39c75cee..7a82d2a228 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -22,22 +22,41 @@ @group() @pass_context def aggregator(context): - """Manage Federated Learning Aggregator.""" + """Manage Federated Learning Aggregator. + + Args: + context (click.Context): The context passed from the CLI. + """ context.obj['group'] = 'aggregator' @aggregator.command(name='start') -@option('-p', '--plan', required=False, +@option('-p', + '--plan', + required=False, help='Federated learning plan [plan/plan.yaml]', default='plan/plan.yaml', type=ClickPath(exists=True)) -@option('-c', '--authorized_cols', required=False, +@option('-c', + '--authorized_cols', + required=False, help='Authorized collaborator list [plan/cols.yaml]', - default='plan/cols.yaml', type=ClickPath(exists=True)) -@option('-s', '--secure', required=False, - help='Enable Intel SGX Enclave', is_flag=True, default=False) + default='plan/cols.yaml', + type=ClickPath(exists=True)) +@option('-s', + '--secure', + required=False, + help='Enable Intel SGX Enclave', + is_flag=True, + default=False) def start_(plan, authorized_cols, secure): - """Start the aggregator service.""" + """Start the aggregator service. + + Args: + plan (str): Path to the federated learning plan. + authorized_cols (str): Path to the authorized collaborator list. + secure (bool): Flag to enable Intel SGX Enclave. + """ from pathlib import Path from openfl.federated import Plan @@ -58,16 +77,27 @@ def start_(plan, authorized_cols, secure): @aggregator.command(name='generate-cert-request') -@option('--fqdn', required=False, type=click_types.FQDN, +@option('--fqdn', + required=False, + type=click_types.FQDN, help=f'The fully qualified domain name of' - f' aggregator node [{getfqdn_env()}]', + f' aggregator node [{getfqdn_env()}]', default=getfqdn_env()) def _generate_cert_request(fqdn): + """Create aggregator certificate key pair. + + Args: + fqdn (str): The fully qualified domain name of aggregator node. + """ generate_cert_request(fqdn) def generate_cert_request(fqdn): - """Create aggregator certificate key pair.""" + """Create aggregator certificate key pair. + + Args: + fqdn (str): The fully qualified domain name of aggregator node. 
+ """ from openfl.cryptography.participant import generate_csr from openfl.cryptography.io import write_crt from openfl.cryptography.io import write_key @@ -89,8 +119,8 @@ def generate_cert_request(fqdn): (CERT_DIR / 'server').mkdir(parents=True, exist_ok=True) - echo(' Writing AGGREGATOR certificate key pair to: ' + style( - f'{CERT_DIR}/server', fg='green')) + echo(' Writing AGGREGATOR certificate key pair to: ' + + style(f'{CERT_DIR}/server', fg='green')) # Print csr hash before writing csr to disk csr_hash = get_csr_hash(server_csr) @@ -103,7 +133,14 @@ def generate_cert_request(fqdn): # TODO: function not used def find_certificate_name(file_name): - """Search the CRT for the actual aggregator name.""" + """Search the CRT for the actual aggregator name. + + Args: + file_name (str): The name of the file to search. + + Returns: + str: The name of the aggregator found in the CRT. + """ # This loop looks for the collaborator name in the key with open(file_name, 'r', encoding='utf-8') as f: for line in f: @@ -119,11 +156,22 @@ def find_certificate_name(file_name): default=getfqdn_env()) @option('-s', '--silent', help='Do not prompt', is_flag=True) def _certify(fqdn, silent): + """Sign/certify the aggregator certificate key pair. + + Args: + fqdn (str): The fully qualified domain name of aggregator node. + silent (bool): Flag to enable silent mode. + """ certify(fqdn, silent) def certify(fqdn, silent): - """Sign/certify the aggregator certificate key pair.""" + """Sign/certify the aggregator certificate key pair. + + Args: + fqdn (str): The fully qualified domain name of aggregator node. + silent (bool): Flag to enable silent mode. + """ from pathlib import Path from click import confirm @@ -147,34 +195,37 @@ def certify(fqdn, silent): # Load CSR csr_path_absolute_path = Path(CERT_DIR / f'{cert_name}.csr').absolute() if not csr_path_absolute_path.exists(): - echo(style('Aggregator certificate signing request not found.', fg='red') - + ' Please run `fx aggregator generate-cert-request`' - ' to generate the certificate request.') + echo( + style('Aggregator certificate signing request not found.', + fg='red') + + ' Please run `fx aggregator generate-cert-request`' + ' to generate the certificate request.') csr, csr_hash = read_csr(csr_path_absolute_path) # Load private signing key - private_sign_key_absolute_path = Path(CERT_DIR / signing_key_path).absolute() + private_sign_key_absolute_path = Path(CERT_DIR + / signing_key_path).absolute() if not private_sign_key_absolute_path.exists(): - echo(style('Signing key not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + echo( + style('Signing key not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') signing_key = read_key(private_sign_key_absolute_path) # Load signing cert signing_crt_absolute_path = Path(CERT_DIR / signing_crt_path).absolute() if not signing_crt_absolute_path.exists(): - echo(style('Signing certificate not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + echo( + style('Signing certificate not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') signing_crt = read_crt(signing_crt_absolute_path) - echo('The CSR Hash for file ' - + style(f'{cert_name}.csr', fg='green') - + ' = ' - + style(f'{csr_hash}', fg='red')) + echo('The CSR Hash for file ' + style(f'{cert_name}.csr', fg='green') 
+ + ' = ' + style(f'{csr_hash}', fg='red')) crt_path_absolute_path = Path(CERT_DIR / f'{cert_name}.crt').absolute() @@ -182,7 +233,8 @@ def certify(fqdn, silent): echo(' Warning: manual check of certificate hashes is bypassed in silent mode.') echo(' Signing AGGREGATOR certificate') - signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) + signed_agg_cert = sign_certificate(csr, signing_key, + signing_crt.subject) write_crt(signed_agg_cert, crt_path_absolute_path) else: @@ -191,10 +243,12 @@ def certify(fqdn, silent): if confirm('Do you want to sign this certificate?'): echo(' Signing AGGREGATOR certificate') - signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) + signed_agg_cert = sign_certificate(csr, signing_key, + signing_crt.subject) write_crt(signed_agg_cert, crt_path_absolute_path) else: - echo(style('Not signing certificate.', fg='red') - + ' Please check with this AGGREGATOR to get the correct' - ' certificate for this federation.') + echo( + style('Not signing certificate.', fg='red') + + ' Please check with this AGGREGATOR to get the correct' + ' certificate for this federation.') diff --git a/openfl/interface/cli.py b/openfl/interface/cli.py index 25a0eed1eb..3f1dee05bb 100755 --- a/openfl/interface/cli.py +++ b/openfl/interface/cli.py @@ -21,7 +21,12 @@ def setup_logging(level='info', log_file=None): - """Initialize logging settings.""" + """Initialize logging settings. + + Args: + level (str, optional): Logging verbosity level. Defaults to 'info'. + log_file (str, optional): The log file. Defaults to None. + """ import logging from logging import basicConfig @@ -38,15 +43,16 @@ def setup_logging(level='info', log_file=None): if log_file: fh = logging.FileHandler(log_file) formatter = logging.Formatter( - '%(asctime)s %(levelname)s %(message)s %(filename)s:%(lineno)d' - ) + '%(asctime)s %(levelname)s %(message)s %(filename)s:%(lineno)d') fh.setFormatter(formatter) handlers.append(fh) console = Console(width=160) handlers.append(RichHandler(console=console)) - basicConfig(level=level, format='%(message)s', - datefmt='[%X]', handlers=handlers) + basicConfig(level=level, + format='%(message)s', + datefmt='[%X]', + handlers=handlers) def disable_warnings(): @@ -63,35 +69,54 @@ class CLI(Group): """CLI class.""" def __init__(self, name=None, commands=None, **kwargs): - """Initialize.""" + """Initialize CLI object. + + Args: + name (str, optional): Name of the CLI group. Defaults to None. + commands (dict, optional): Commands for the CLI group. Defaults + to None. + **kwargs: Arbitrary keyword arguments. + """ super(CLI, self).__init__(name, commands, **kwargs) self.commands = commands or {} def list_commands(self, ctx): - """Display all available commands.""" + """Display all available commands. + + Args: + ctx (click.core.Context): Click context. + + Returns: + dict: Available commands. + """ return self.commands def format_help(self, ctx, formatter): - """Dislpay user-friendly help.""" + """Display user-friendly help. + + Args: + ctx (click.core.Context): Click context. + formatter (click.formatting.HelpFormatter): Click help formatter. 
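For a quick feel of how `setup_logging` and the custom `CLI` group behave, the `fx` entry point can be exercised from Python with Click's test runner; this is only an illustrative sketch, and the log file name is a placeholder:

from click.testing import CliRunner  # ships with Click itself

from openfl.interface.cli import cli, setup_logging

setup_logging(log_file='cli.log')   # default level, plus the optional file handler

runner = CliRunner()
result = runner.invoke(cli, ['--help'])   # rendered by the overridden format_help
print(result.output)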
+ """ show_header() uses = [ - f'{ctx.command_path}', - '[options]', + f'{ctx.command_path}', '[options]', style('[command]', fg='blue'), - style('[subcommand]', fg='cyan'), - '[args]' + style('[subcommand]', fg='cyan'), '[args]' ] - formatter.write(style('BASH COMPLETE ACTIVATION\n\n', bold=True, fg='bright_black')) + formatter.write( + style('BASH COMPLETE ACTIVATION\n\n', bold=True, + fg='bright_black')) formatter.write( 'Run in terminal:\n' ' _FX_COMPLETE=bash_source fx > ~/.fx-autocomplete.sh\n' ' source ~/.fx-autocomplete.sh\n' 'If ~/.fx-autocomplete.sh has already exist:\n' - ' source ~/.fx-autocomplete.sh\n\n' - ) + ' source ~/.fx-autocomplete.sh\n\n') - formatter.write(style('CORRECT USAGE\n\n', bold=True, fg='bright_black')) + formatter.write( + style('CORRECT USAGE\n\n', bold=True, fg='bright_black')) formatter.write(' '.join(uses) + '\n') opts = [] @@ -100,8 +125,8 @@ def format_help(self, ctx, formatter): if rv is not None: opts.append(rv) - formatter.write(style( - '\nGLOBAL OPTIONS\n\n', bold=True, fg='bright_black')) + formatter.write( + style('\nGLOBAL OPTIONS\n\n', bold=True, fg='bright_black')) formatter.write_dl(opts) cmds = [] @@ -113,8 +138,8 @@ def format_help(self, ctx, formatter): sub = cmd.get_command(ctx, sub) cmds.append((sub.name, sub, 1)) - formatter.write(style( - '\nAVAILABLE COMMANDS\n', bold=True, fg='bright_black')) + formatter.write( + style('\nAVAILABLE COMMANDS\n', bold=True, fg='bright_black')) for name, cmd, level in cmds: help_str = cmd.get_short_help_str() @@ -134,7 +159,13 @@ def format_help(self, ctx, formatter): @option('--no-warnings', is_flag=True, help='Disable third-party warnings.') @pass_context def cli(context, log_level, no_warnings): - """Command-line Interface.""" + """Command-line Interface. + + Args: + context (click.core.Context): Click context. + log_level (str): Logging verbosity level. + no_warnings (bool): Flag to disable third-party warnings. + """ import os from sys import argv @@ -156,7 +187,13 @@ def cli(context, log_level, no_warnings): @cli.result_callback() @pass_context def end(context, result, **kwargs): - """Print the result of the operation.""" + """Print the result of the operation. + + Args: + context (click.core.Context): Click context. + result: Result of the operation. + **kwargs: Arbitrary keyword arguments. + """ if context.obj['fail']: echo('\n ❌ :(') else: @@ -167,39 +204,64 @@ def end(context, result, **kwargs): @pass_context @argument('subcommand', required=False) def help_(context, subcommand): - """Display help.""" + """Display help. + + Args: + context (click.core.Context): Click context. + subcommand (str, optional): Subcommand to display help for. Defaults to None. + """ pass def error_handler(error): - """Handle the error.""" + """Handle the error. + + Args: + error (Exception): Error to handle. 
+ """ if 'cannot import' in str(error): if 'TensorFlow' in str(error): - echo(style('EXCEPTION', fg='red', bold=True) + ' : ' + style( - 'Tensorflow must be installed prior to running this command', - fg='red')) + echo( + style('EXCEPTION', fg='red', bold=True) + ' : ' + style( + 'Tensorflow must be installed prior to running this command', + fg='red')) if 'PyTorch' in str(error): - echo(style('EXCEPTION', fg='red', bold=True) + ' : ' + style( - 'Torch must be installed prior to running this command', - fg='red')) - echo(style('EXCEPTION', fg='red', bold=True) - + ' : ' + style(f'{error}', fg='red')) + echo( + style('EXCEPTION', fg='red', bold=True) + ' : ' + + style('Torch must be installed prior to running this command', + fg='red')) + echo( + style('EXCEPTION', fg='red', bold=True) + ' : ' + + style(f'{error}', fg='red')) raise error def review_plan_callback(file_name, file_path): - """Review plan callback for Director and Envoy.""" - echo(style( - f'Please review the contents of {file_name} before proceeding...', - fg='green', - bold=True)) - # Wait for users to read the question before flashing the contents of the file. + """Review plan callback for Director and Envoy. + + Args: + file_name (str): Name of the file to review. + file_path (str): Path of the file to review. + + Returns: + bool: True if the file is accepted, False otherwise. + """ + echo( + style( + f'Please review the contents of {file_name} before proceeding...', + fg='green', + bold=True)) + # Wait for users to read the question before flashing the contents of the + # file. time.sleep(3) with open_file(file_path, 'r') as f: echo(f.read()) - if confirm(style(f'Do you want to accept the {file_name}?', fg='green', bold=True)): + if confirm( + style(f'Do you want to accept the {file_name}?', + fg='green', + bold=True)): echo(style(f'{file_name} accepted!', fg='green', bold=True)) return True else: @@ -235,7 +297,8 @@ def entry(): root = Path(__file__).parent.resolve() if experimental.exists(): - root = root.parent.joinpath("experimental", "interface", "cli").resolve() + root = root.parent.joinpath("experimental", "interface", + "cli").resolve() work = Path.cwd().resolve() path.append(str(root)) diff --git a/openfl/interface/cli_helper.py b/openfl/interface/cli_helper.py index 748bf990ce..a1bb500d5e 100644 --- a/openfl/interface/cli_helper.py +++ b/openfl/interface/cli_helper.py @@ -23,7 +23,11 @@ def pretty(o): - """Pretty-print the dictionary given.""" + """Pretty-print the dictionary given. + + Args: + o (dict): The dictionary to be printed. + """ m = max(map(len, o.keys())) for k, v in o.items(): @@ -31,7 +35,11 @@ def pretty(o): def tree(path): - """Print current directory file tree.""" + """Print current directory file tree. + + Args: + path (str): The path of the directory. + """ echo(f'+ {path}') for path in sorted(path.rglob('*')): @@ -45,10 +53,19 @@ def tree(path): echo(f'{space}d {path.name}') -def print_tree(dir_path: Path, level: int = -1, +def print_tree(dir_path: Path, + level: int = -1, limit_to_directories: bool = False, length_limit: int = 1000): - """Given a directory Path object print a visual tree structure.""" + """Given a directory Path object print a visual tree structure. + + Args: + dir_path (Path): The directory path. + level (int, optional): The level of the directory. Defaults to -1. + limit_to_directories (bool, optional): Limit to directories. Defaults + to False. + length_limit (int, optional): The length limit. Defaults to 1000. 
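`review_plan_callback` is the hook the Director and Envoy use when experiment review is enabled. A minimal interactive usage sketch, assuming `plan/plan.yaml` exists in the current workspace:

from openfl.interface.cli import review_plan_callback

accepted = review_plan_callback('experiment plan', 'plan/plan.yaml')
if not accepted:
    raise SystemExit('Plan was rejected by the operator.')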
+ """ space = ' ' branch = '│ ' tee = '├── ' @@ -74,7 +91,8 @@ def inner(dir_path: Path, prefix: str = '', level=-1): yield prefix + pointer + path.name directories += 1 extension = branch if pointer == tee else space - yield from inner(path, prefix=prefix + extension, + yield from inner(path, + prefix=prefix + extension, level=level - 1) elif not limit_to_directories: yield prefix + pointer + path.name @@ -86,12 +104,31 @@ def inner(dir_path: Path, prefix: str = '', level=-1): echo(line) if next(iterator, None): echo(f'... length_limit, {length_limit}, reached, counted:') - echo(f'\n{directories} directories' + (f', {files} files' if files else '')) - - -def copytree(src, dst, symlinks=False, ignore=None, - ignore_dangling_symlinks=False, dirs_exist_ok=False): - """From Python 3.8 'shutil' which include 'dirs_exist_ok' option.""" + echo(f'\n{directories} directories' + + (f', {files} files' if files else '')) + + +def copytree(src, + dst, + symlinks=False, + ignore=None, + ignore_dangling_symlinks=False, + dirs_exist_ok=False): + """From Python 3.8 'shutil' which include 'dirs_exist_ok' option. + + Args: + src (str): The source directory. + dst (str): The destination directory. + symlinks (bool, optional): Whether to copy symlinks. Defaults to False. + ignore (callable, optional): A function that takes a directory name + and filenames as input parameters and returns a list of names to + ignore. Defaults to None. + ignore_dangling_symlinks (bool, optional): Whether to ignore dangling + symlinks. Defaults to False. + dirs_exist_ok (bool, optional): Whether to raise an exception in case + dst or any missing parent directory already exists. Defaults to + False. + """ import os import shutil @@ -127,19 +164,26 @@ def _copytree(): linkto = os.readlink(srcname) if symlinks: os.symlink(linkto, dstname) - shutil.copystat(srcobj, dstname, + shutil.copystat(srcobj, + dstname, follow_symlinks=not symlinks) else: if (not os.path.exists(linkto) and ignore_dangling_symlinks): continue if srcentry.is_dir(): - copytree(srcobj, dstname, symlinks, ignore, + copytree(srcobj, + dstname, + symlinks, + ignore, dirs_exist_ok=dirs_exist_ok) else: copy_function(srcobj, dstname) elif srcentry.is_dir(): - copytree(srcobj, dstname, symlinks, ignore, + copytree(srcobj, + dstname, + symlinks, + ignore, dirs_exist_ok=dirs_exist_ok) else: copy_function(srcobj, dstname) @@ -160,7 +204,14 @@ def _copytree(): def get_workspace_parameter(name): - """Get a parameter from the workspace config file (.workspace).""" + """Get a parameter from the workspace config file (.workspace). + + Args: + name (str): The name of the parameter. + + Returns: + str: The value of the parameter. + """ # Update the .workspace file to show the current workspace plan workspace_file = '.workspace' @@ -177,7 +228,16 @@ def get_workspace_parameter(name): def check_varenv(env: str = '', args: dict = None): - """Update "args" (dictionary) with if env has a defined value in the host.""" + """Update "args" (dictionary) with if env has a defined + value in the host. + + Args: + env (str, optional): The environment variable. Defaults to ''. + args (dict, optional): The dictionary to be updated. Defaults to None. + + Returns: + args (dict): The updated dictionary. + """ if args is None: args = {} env_val = environ.get(env) @@ -188,7 +248,14 @@ def check_varenv(env: str = '', args: dict = None): def get_fx_path(curr_path=''): - """Return the absolute path to fx binary.""" + """Return the absolute path to fx binary. 
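A short sketch of the helpers documented here, with placeholder paths and environment variables:

from pathlib import Path

from openfl.interface.cli_helper import check_varenv, copytree, print_tree

# Show only the top two directory levels of the current workspace.
print_tree(Path.cwd(), level=2, limit_to_directories=True)

# Copy a template into place, tolerating an existing destination directory.
copytree('workspace_template', 'my_workspace', dirs_exist_ok=True)

# Pick up proxy settings from the environment, if defined.
proxy_args = check_varenv('HTTPS_PROXY', {})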
+ + Args: + curr_path (str, optional): The current path. Defaults to ''. + + Returns: + str: The absolute path to fx binary. + """ import re import os @@ -202,7 +269,12 @@ def get_fx_path(curr_path=''): def remove_line_from_file(pkg, filename): - """Remove line that contains `pkg` from the `filename` file.""" + """Remove line that contains `pkg` from the `filename` file. + + Args: + pkg (str): The package to be removed. + filename (str): The name of the file. + """ with open(filename, 'r+', encoding='utf-8') as f: d = f.readlines() f.seek(0) @@ -213,7 +285,13 @@ def remove_line_from_file(pkg, filename): def replace_line_in_file(line, line_num_to_replace, filename): - """Replace line at `line_num_to_replace` with `line`.""" + """Replace line at `line_num_to_replace` with `line`. + + Args: + line (str): The new line. + line_num_to_replace (int): The line number to be replaced. + filename (str): The name of the file. + """ with open(filename, 'r+', encoding='utf-8') as f: d = f.readlines() f.seek(0) diff --git a/openfl/interface/collaborator.py b/openfl/interface/collaborator.py index c28d4f194a..e654d39a28 100644 --- a/openfl/interface/collaborator.py +++ b/openfl/interface/collaborator.py @@ -21,33 +21,59 @@ @group() @pass_context def collaborator(context): - """Manage Federated Learning Collaborators.""" + """Manage Federated Learning Collaborators. + + Args: + context (click.core.Context): Click context. + """ context.obj['group'] = 'service' @collaborator.command(name='start') -@option('-p', '--plan', required=False, +@option('-p', + '--plan', + required=False, help='Federated learning plan [plan/plan.yaml]', default='plan/plan.yaml', type=ClickPath(exists=True)) -@option('-d', '--data_config', required=False, +@option('-d', + '--data_config', + required=False, help='The data set/shard configuration file [plan/data.yaml]', - default='plan/data.yaml', type=ClickPath(exists=True)) -@option('-n', '--collaborator_name', required=True, + default='plan/data.yaml', + type=ClickPath(exists=True)) +@option('-n', + '--collaborator_name', + required=True, help='The certified common name of the collaborator') -@option('-s', '--secure', required=False, - help='Enable Intel SGX Enclave', is_flag=True, default=False) +@option('-s', + '--secure', + required=False, + help='Enable Intel SGX Enclave', + is_flag=True, + default=False) def start_(plan, collaborator_name, data_config, secure): - """Start a collaborator service.""" + """Start a collaborator service. + + Args: + plan (str): Federated learning plan. + collaborator_name (str): The certified common name of the collaborator. + data_config (str): The data set/shard configuration file. + secure (bool): Enable Intel SGX Enclave. + """ from pathlib import Path from openfl.federated import Plan if plan and is_directory_traversal(plan): - echo('Federated learning plan path is out of the openfl workspace scope.') + echo( + 'Federated learning plan path is out of the openfl workspace scope.' + ) sys.exit(1) if data_config and is_directory_traversal(data_config): - echo('The data set/shard configuration file path is out of the openfl workspace scope.') + echo( + 'The data set/shard configuration file path is out of the openfl workspace scope.' 
+ ) sys.exit(1) plan = Plan.parse(plan_config_path=Path(plan).absolute(), @@ -62,18 +88,23 @@ def start_(plan, collaborator_name, data_config, secure): @collaborator.command(name='create') -@option('-n', '--collaborator_name', required=True, +@option('-n', + '--collaborator_name', + required=True, help='The certified common name of the collaborator') -@option('-d', '--data_path', +@option('-d', + '--data_path', help='The data path to be associated with the collaborator') @option('-s', '--silent', help='Do not prompt', is_flag=True) def create_(collaborator_name, data_path, silent): - """Creates a user for an experiment.""" - create(collaborator_name, data_path, silent) + """Creates a user for an experiment. + Args: + collaborator_name (str): The certified common name of the collaborator. + data_path (str): The data path to be associated with the collaborator. + silent (bool): Do not prompt. + """ -def create(collaborator_name, data_path, silent): - """Creates a user for an experiment.""" if data_path and is_directory_traversal(data_path): echo('Data path is out of the openfl workspace scope.') sys.exit(1) @@ -88,9 +119,11 @@ def register_data_path(collaborator_name, data_path=None, silent=False): """Register dataset path in the plan/data.yaml file. Args: - collaborator_name (str): The collaborator whose data path to be defined - data_path (str) : Data path (optional) - silent (bool) : Silent operation (don't prompt) + collaborator_name (str): The collaborator whose data path to be + defined. + data_path (str, optional): Data path. Defaults to None. + silent (bool, optional): Silent operation (don't prompt). Defaults to + False. """ from click import prompt from os.path import isfile @@ -104,8 +137,8 @@ def register_data_path(collaborator_name, data_path=None, silent=False): if not silent and data_path is None: dir_path = prompt('\nWhere is the data (or what is the rank)' ' for collaborator ' - + style(f'{collaborator_name}', fg='green') - + ' ? ', default=default_data_path) + + style(f'{collaborator_name}', fg='green') + ' ? ', + default=default_data_path) elif data_path is not None: dir_path = data_path else: @@ -133,23 +166,37 @@ def register_data_path(collaborator_name, data_path=None, silent=False): @collaborator.command(name='generate-cert-request') -@option('-n', '--collaborator_name', required=True, +@option('-n', + '--collaborator_name', + required=True, help='The certified common name of the collaborator') @option('-s', '--silent', help='Do not prompt', is_flag=True) -@option('-x', '--skip-package', +@option('-x', + '--skip-package', help='Do not package the certificate signing request for export', is_flag=True) -def generate_cert_request_(collaborator_name, - silent, skip_package): - """Generate certificate request for the collaborator.""" +def generate_cert_request_(collaborator_name, silent, skip_package): + """Generate certificate request for the collaborator. + + Args: + collaborator_name (str): The certified common name of the collaborator. + silent (bool): Do not prompt. + skip_package (bool): Do not package the certificate signing request + for export. + """ generate_cert_request(collaborator_name, silent, skip_package) def generate_cert_request(collaborator_name, silent, skip_package): - """ - Create collaborator certificate key pair. + """Create collaborator certificate key pair. Then create a package with the CSR to send for signing. + + Args: + collaborator_name (str): The certified common name of the collaborator. + silent (bool): Do not prompt. 
+ skip_package (bool): Do not package the certificate signing request + for export. """ from openfl.cryptography.participant import generate_csr from openfl.cryptography.io import write_crt @@ -161,9 +208,10 @@ def generate_cert_request(collaborator_name, silent, skip_package): subject_alternative_name = f'DNS:{common_name}' file_name = f'col_{common_name}' - echo(f'Creating COLLABORATOR certificate key pair with following settings: ' - f'CN={style(common_name, fg="red")},' - f' SAN={style(subject_alternative_name, fg="red")}') + echo( + f'Creating COLLABORATOR certificate key pair with following settings: ' + f'CN={style(common_name, fg="red")},' + f' SAN={style(subject_alternative_name, fg="red")}') client_private_key, client_csr = generate_csr(common_name, server=False) @@ -218,7 +266,14 @@ def generate_cert_request(collaborator_name, silent, skip_package): def find_certificate_name(file_name): - """Parse the collaborator name.""" + """Parse the collaborator name. + + Args: + file_name (str): The name of the collaborator in this federation. + + Returns: + col_name (str): The collaborator name. + """ col_name = str(file_name).split(os.sep)[-1].split('.')[0][4:] return col_name @@ -227,8 +282,7 @@ def register_collaborator(file_name): """Register the collaborator name in the cols.yaml list. Args: - file_name (str): The name of the collaborator in this federation - + file_name (str): The name of the collaborator in this federation. """ from os.path import isfile from yaml import dump @@ -254,10 +308,8 @@ def register_collaborator(file_name): if col_name in doc['collaborators']: - echo('\nCollaborator ' - + style(f'{col_name}', fg='green') - + ' is already in the ' - + style(f'{cols_file}', fg='green')) + echo('\nCollaborator ' + style(f'{col_name}', fg='green') + + ' is already in the ' + style(f'{cols_file}', fg='green')) else: @@ -265,30 +317,52 @@ def register_collaborator(file_name): with open(cols_file, 'w', encoding='utf-8') as f: dump(doc, f) - echo('\nRegistering ' - + style(f'{col_name}', fg='green') - + ' in ' + echo('\nRegistering ' + style(f'{col_name}', fg='green') + ' in ' + style(f'{cols_file}', fg='green')) @collaborator.command(name='certify') -@option('-n', '--collaborator_name', +@option('-n', + '--collaborator_name', help='The certified common name of the collaborator. This is only' - ' needed for single node expiriments') + ' needed for single node expiriments') @option('-s', '--silent', help='Do not prompt', is_flag=True) -@option('-r', '--request-pkg', type=ClickPath(exists=True), +@option('-r', + '--request-pkg', + type=ClickPath(exists=True), help='The archive containing the certificate signing' - ' request (*.zip) for a collaborator') -@option('-i', '--import', 'import_', type=ClickPath(exists=True), + ' request (*.zip) for a collaborator') +@option('-i', + '--import', + 'import_', + type=ClickPath(exists=True), help='Import the archive containing the collaborator\'s' - ' certificate (signed by the CA)') + ' certificate (signed by the CA)') def certify_(collaborator_name, silent, request_pkg, import_): - """Certify the collaborator.""" + """Certify the collaborator. + + Args: + collaborator_name (str): The certified common name of the collaborator. + silent (bool): Do not prompt. + request_pkg (str): The archive containing the certificate signing + request (*.zip) for a collaborator. + import_ (str): Import the archive containing the collaborator's + certificate (signed by the CA). 
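The collaborator-side counterpart to the aggregator flow above: register the data shard, produce a CSR package, and (on the CA side) sign it. The collaborator name, data path, and archive name below are placeholders:

from openfl.interface.collaborator import certify, generate_cert_request, register_data_path

# Collaborator side: record the shard location and create the CSR package.
register_data_path('one', data_path='data/one', silent=True)
generate_cert_request('one', silent=True, skip_package=False)

# CA side: sign the received CSR package so the certificate can be imported back.
certify('one', silent=True, request_pkg='col_one_to_agg_cert.zip')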
+ """ certify(collaborator_name, silent, request_pkg, import_) def certify(collaborator_name, silent, request_pkg=None, import_=False): - """Sign/certify collaborator certificate key pair.""" + """Sign/certify collaborator certificate key pair. + + Args: + collaborator_name (str): The certified common name of the collaborator. + silent (bool): Do not prompt. + request_pkg (str, optional): The archive containing the certificate + signing request (*.zip) for a collaborator. Defaults to None. + import_ (bool, optional): Import the archive containing the + collaborator's certificate (signed by the CA). Defaults to False. + """ from click import confirm from pathlib import Path from shutil import copy @@ -373,14 +447,16 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): if confirm('Do you want to sign this certificate?'): echo(' Signing COLLABORATOR certificate') - signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) + signed_col_cert = sign_certificate(csr, signing_key, + signing_crt.subject) write_crt(signed_col_cert, f'{cert_name}.crt') register_collaborator(CERT_DIR / 'client' / f'{file_name}.crt') else: - echo(style('Not signing certificate.', fg='red') - + ' Please check with this collaborator to get the' - ' correct certificate for this federation.') + echo( + style('Not signing certificate.', fg='red') + + ' Please check with this collaborator to get the' + ' correct certificate for this federation.') return if len(common_name) == 0: diff --git a/openfl/interface/director.py b/openfl/interface/director.py index a3fd229e25..79e9753bd2 100644 --- a/openfl/interface/director.py +++ b/openfl/interface/director.py @@ -25,26 +25,56 @@ @group() @pass_context def director(context): - """Manage Federated Learning Director.""" + """Manage Federated Learning Director. + + Args: + context (click.core.Context): Click context. + """ context.obj['group'] = 'director' @director.command(name='start') -@option('-c', '--director-config-path', default='director.yaml', - help='The director config file path', type=ClickPath(exists=True)) -@option('--tls/--disable-tls', default=True, - is_flag=True, help='Use TLS or not (By default TLS is enabled)') -@option('-rc', '--root-cert-path', 'root_certificate', required=False, - type=ClickPath(exists=True), default=None, +@option('-c', + '--director-config-path', + default='director.yaml', + help='The director config file path', + type=ClickPath(exists=True)) +@option('--tls/--disable-tls', + default=True, + is_flag=True, + help='Use TLS or not (By default TLS is enabled)') +@option('-rc', + '--root-cert-path', + 'root_certificate', + required=False, + type=ClickPath(exists=True), + default=None, help='Path to a root CA cert') -@option('-pk', '--private-key-path', 'private_key', required=False, - type=ClickPath(exists=True), default=None, +@option('-pk', + '--private-key-path', + 'private_key', + required=False, + type=ClickPath(exists=True), + default=None, help='Path to a private key') -@option('-oc', '--public-cert-path', 'certificate', required=False, - type=ClickPath(exists=True), default=None, +@option('-oc', + '--public-cert-path', + 'certificate', + required=False, + type=ClickPath(exists=True), + default=None, help='Path to a signed certificate') -def start(director_config_path, tls, root_certificate, private_key, certificate): - """Start the director service.""" +def start(director_config_path, tls, root_certificate, private_key, + certificate): + """Start the director service. 
+ + Args: + director_config_path (str): The director config file path. + tls (bool): Use TLS or not. + root_certificate (str): Path to a root CA cert. + private_key (str): Path to a private key. + certificate (str): Path to a signed certificate. + """ from openfl.component.director import Director from openfl.transport import DirectorGRPCServer @@ -52,7 +82,9 @@ def start(director_config_path, tls, root_certificate, private_key, certificate) director_config_path = Path(director_config_path).absolute() logger.info('🧿 Starting the Director Service.') if is_directory_traversal(director_config_path): - click.echo('The director config file path is out of the openfl workspace scope.') + click.echo( + 'The director config file path is out of the openfl workspace scope.' + ) sys.exit(1) config = merge_configs( settings_files=director_config_path, @@ -63,13 +95,18 @@ def start(director_config_path, tls, root_certificate, private_key, certificate) }, validators=[ Validator('settings.listen_host', default='localhost'), - Validator('settings.listen_port', default=50051, gte=1024, lte=65535), + Validator('settings.listen_port', + default=50051, + gte=1024, + lte=65535), Validator('settings.sample_shape', default=[]), Validator('settings.target_shape', default=[]), Validator('settings.install_requirements', default=False), - Validator('settings.envoy_health_check_period', - default=60, # in seconds - gte=1, lte=24 * 60 * 60), + Validator( + 'settings.envoy_health_check_period', + default=60, # in seconds + gte=1, + lte=24 * 60 * 60), Validator('settings.review_experiment', default=False), ], value_transform=[ @@ -78,10 +115,8 @@ def start(director_config_path, tls, root_certificate, private_key, certificate) ], ) - logger.info( - f'Sample shape: {config.settings.sample_shape}, ' - f'target shape: {config.settings.target_shape}' - ) + logger.info(f'Sample shape: {config.settings.sample_shape}, ' + f'target shape: {config.settings.target_shape}') if config.root_certificate: config.root_certificate = Path(config.root_certificate).absolute() @@ -110,24 +145,32 @@ def start(director_config_path, tls, root_certificate, private_key, certificate) listen_port=config.settings.listen_port, review_plan_callback=overwritten_review_plan_callback, envoy_health_check_period=config.settings.envoy_health_check_period, - install_requirements=config.settings.install_requirements - ) + install_requirements=config.settings.install_requirements) director_server.start() @director.command(name='create-workspace') -@option('-p', '--director-path', required=True, - help='The director path', type=ClickPath()) +@option('-p', + '--director-path', + required=True, + help='The director path', + type=ClickPath()) def create(director_path): - """Create a director workspace.""" + """Create a director workspace. + + Args: + director_path (str): The director path. + """ if is_directory_traversal(director_path): click.echo('The director path is out of the openfl workspace scope.') sys.exit(1) director_path = Path(director_path).absolute() if director_path.exists(): - if not click.confirm('Director workspace already exists. Recreate?', default=True): + if not click.confirm('Director workspace already exists. 
Recreate?', + default=True): sys.exit(1) shutil.rmtree(director_path) (director_path / 'cert').mkdir(parents=True, exist_ok=True) (director_path / 'logs').mkdir(parents=True, exist_ok=True) - shutil.copyfile(WORKSPACE / 'default/director.yaml', director_path / 'director.yaml') + shutil.copyfile(WORKSPACE / 'default/director.yaml', + director_path / 'director.yaml') diff --git a/openfl/interface/envoy.py b/openfl/interface/envoy.py index b974fc6b11..a5ee16c783 100644 --- a/openfl/interface/envoy.py +++ b/openfl/interface/envoy.py @@ -27,36 +27,74 @@ @group() @pass_context def envoy(context): - """Manage Federated Learning Envoy.""" + """Manage Federated Learning Envoy. + + Args: + context (click.core.Context): Click context. + """ context.obj['group'] = 'envoy' @envoy.command(name='start') -@option('-n', '--shard-name', required=True, - help='Current shard name') -@option('-dh', '--director-host', required=True, - help='The FQDN of the federation director', type=click_types.FQDN) -@option('-dp', '--director-port', required=True, - help='The federation director port', type=click.IntRange(1, 65535)) -@option('--tls/--disable-tls', default=True, - is_flag=True, help='Use TLS or not (By default TLS is enabled)') -@option('-ec', '--envoy-config-path', default='envoy_config.yaml', - help='The envoy config path', type=ClickPath(exists=True)) -@option('-rc', '--root-cert-path', 'root_certificate', default=None, - help='Path to a root CA cert', type=ClickPath(exists=True)) -@option('-pk', '--private-key-path', 'private_key', default=None, - help='Path to a private key', type=ClickPath(exists=True)) -@option('-oc', '--public-cert-path', 'certificate', default=None, - help='Path to a signed certificate', type=ClickPath(exists=True)) +@option('-n', '--shard-name', required=True, help='Current shard name') +@option('-dh', + '--director-host', + required=True, + help='The FQDN of the federation director', + type=click_types.FQDN) +@option('-dp', + '--director-port', + required=True, + help='The federation director port', + type=click.IntRange(1, 65535)) +@option('--tls/--disable-tls', + default=True, + is_flag=True, + help='Use TLS or not (By default TLS is enabled)') +@option('-ec', + '--envoy-config-path', + default='envoy_config.yaml', + help='The envoy config path', + type=ClickPath(exists=True)) +@option('-rc', + '--root-cert-path', + 'root_certificate', + default=None, + help='Path to a root CA cert', + type=ClickPath(exists=True)) +@option('-pk', + '--private-key-path', + 'private_key', + default=None, + help='Path to a private key', + type=ClickPath(exists=True)) +@option('-oc', + '--public-cert-path', + 'certificate', + default=None, + help='Path to a signed certificate', + type=ClickPath(exists=True)) def start_(shard_name, director_host, director_port, tls, envoy_config_path, root_certificate, private_key, certificate): - """Start the Envoy.""" + """Start the Envoy. + + Args: + shard_name (str): Current shard name. + director_host (str): The FQDN of the federation director. + director_port (int): The federation director port. + tls (bool): Use TLS or not. + envoy_config_path (str): The envoy config path. + root_certificate (str): Path to a root CA cert. + private_key (str): Path to a private key. + certificate (str): Path to a signed certificate. 
+ """ from openfl.component.envoy.envoy import Envoy logger.info('🧿 Starting the Envoy.') if is_directory_traversal(envoy_config_path): - click.echo('The shard config path is out of the openfl workspace scope.') + click.echo( + 'The shard config path is out of the openfl workspace scope.') sys.exit(1) config = merge_configs( @@ -107,28 +145,34 @@ def start_(shard_name, director_host, director_port, tls, envoy_config_path, del envoy_params.review_experiment # Instantiate Shard Descriptor - shard_descriptor = shard_descriptor_from_config(config.get('shard_descriptor', {})) - envoy = Envoy( - shard_name=shard_name, - director_host=director_host, - director_port=director_port, - tls=tls, - shard_descriptor=shard_descriptor, - root_certificate=config.root_certificate, - private_key=config.private_key, - certificate=config.certificate, - review_plan_callback=overwritten_review_plan_callback, - **envoy_params - ) + shard_descriptor = shard_descriptor_from_config( + config.get('shard_descriptor', {})) + envoy = Envoy(shard_name=shard_name, + director_host=director_host, + director_port=director_port, + tls=tls, + shard_descriptor=shard_descriptor, + root_certificate=config.root_certificate, + private_key=config.private_key, + certificate=config.certificate, + review_plan_callback=overwritten_review_plan_callback, + **envoy_params) envoy.start() @envoy.command(name='create-workspace') -@option('-p', '--envoy-path', required=True, - help='The Envoy path', type=ClickPath()) +@option('-p', + '--envoy-path', + required=True, + help='The Envoy path', + type=ClickPath()) def create(envoy_path): - """Create an envoy workspace.""" + """Create an envoy workspace. + + Args: + envoy_path (str): The Envoy path. + """ if is_directory_traversal(envoy_path): click.echo('The Envoy path is out of the openfl workspace scope.') sys.exit(1) @@ -150,7 +194,14 @@ def create(envoy_path): def shard_descriptor_from_config(shard_config: dict): - """Build a shard descriptor from config.""" + """Build a shard descriptor from config. + + Args: + shard_config (dict): Shard configuration. + + Returns: + instance: Shard descriptor instance. 
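`shard_descriptor_from_config` receives the `shard_descriptor` section of `envoy_config.yaml` as a dict. A hedged sketch of what that section could look like; the class path and the 'params' key are illustrative assumptions, and only 'template' is checked in the function body below:

# Typical launch, for reference:
#   fx envoy start -n env_one -dh director.example.com -dp 50051 -ec envoy_config.yaml

from openfl.interface.envoy import shard_descriptor_from_config

shard_config = {
    'template': 'shard_descriptor.LocalShardDescriptor',      # assumed module.Class path
    'params': {'data_path': 'data', 'rank_worldsize': '1,2'},  # assumed keyword arguments
}
shard_descriptor = shard_descriptor_from_config(shard_config)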
+ """ template = shard_config.get('template') if not template: raise Exception('You should define a shard ' diff --git a/openfl/interface/experimental.py b/openfl/interface/experimental.py index d7622ea25f..fbf8646b2e 100644 --- a/openfl/interface/experimental.py +++ b/openfl/interface/experimental.py @@ -20,8 +20,7 @@ def experimental(context): @experimental.command(name="activate") def activate(): """Activate experimental environment.""" - settings = Path("~").expanduser().joinpath( - ".openfl").resolve() + settings = Path("~").expanduser().joinpath(".openfl").resolve() settings.mkdir(parents=False, exist_ok=True) settings = settings.joinpath("experimental").resolve() @@ -30,13 +29,11 @@ def activate(): import openfl rf = Path(openfl.__file__).parent.parent.resolve().joinpath( - "openfl-tutorials", "experimental", "requirements_workflow_interface.txt").resolve() + "openfl-tutorials", "experimental", + "requirements_workflow_interface.txt").resolve() if rf.is_file(): - check_call( - [executable, '-m', 'pip', 'install', '-r', rf], - shell=False - ) + check_call([executable, '-m', 'pip', 'install', '-r', rf], shell=False) else: logger.warning(f"Requirements file {rf} not found.") diff --git a/openfl/interface/interactive_api/experiment.py b/openfl/interface/interactive_api/experiment.py index d68baf11cc..a3c6d47fa7 100644 --- a/openfl/interface/interactive_api/experiment.py +++ b/openfl/interface/interactive_api/experiment.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Python low-level API module.""" import os import time @@ -27,7 +26,17 @@ class ModelStatus: - """Model statuses.""" + """Model statuses. + + This class defines the various statuses a model can have during an + experiment. + + Attributes: + INITIAL (str): Initial status of the model. + BEST (str): Status indicating the model with the best score. + LAST (str): Status indicating the last model used in the experiment. + RESTORED (str): Status indicating a model that has been restored. + """ INITIAL = 'initial' BEST = 'best' @@ -36,23 +45,46 @@ class ModelStatus: class FLExperiment: - """Central class for FL experiment orchestration.""" + """Central class for FL experiment orchestration. + + This class is responsible for orchestrating the federated learning + experiment. It manages + the experiment's lifecycle and interacts with the federation. + + Attributes: + federation: The federation that this experiment is part of. + experiment_name (str): The name of the experiment. + summary_writer (SummaryWriter): The summary writer. + serializer_plugin (str): The serializer plugin to use. + experiment_submitted (bool): Whether the experiment has been submitted. + is_validate_task_exist (bool): Whether a validate task exists. + logger (Logger): The logger to use. + plan (Plan): The plan for the experiment. + """ def __init__( - self, - federation, - experiment_name: str = None, - serializer_plugin: str = 'openfl.plugins.interface_serializer.' - 'cloudpickle_serializer.CloudpickleSerializer' + self, + federation, + experiment_name: str = None, + serializer_plugin: str = 'openfl.plugins.interface_serializer.' + 'cloudpickle_serializer.CloudpickleSerializer' ) -> None: - """ - Initialize an experiment inside a federation. + """Initialize an experiment inside a federation. Experiment makes sense in a scope of some machine learning problem. - Information about the data on collaborators is contained on the federation level. 
+ Information about the data on collaborators is contained on the + federation level. + + Args: + federation: The federation that this experiment is part of. + experiment_name (str, optional): The name of the experiment. + Defaults to None. + serializer_plugin (str, optional): The serializer plugin. Defaults + to 'openfl.plugins.interface_serializer.cloudpickle_serializer.CloudpickleSerializer'. """ self.federation = federation - self.experiment_name = experiment_name or 'test-' + time.strftime('%Y%m%d-%H%M%S') + self.experiment_name = experiment_name or 'test-' + time.strftime( + '%Y%m%d-%H%M%S') self.summary_writer = None self.serializer_plugin = serializer_plugin @@ -82,11 +114,9 @@ def _assert_experiment_submitted(self): """Assure experiment is sent to director and accepted.""" if not self.experiment_submitted: self.logger.error( - 'The experiment was not submitted to a Director service.' - ) - self.logger.error( - 'Report the experiment first: ' - 'use the Experiment.start() method.') + 'The experiment was not submitted to a Director service.') + self.logger.error('Report the experiment first: ' + 'use the Experiment.start() method.') return False return True @@ -95,8 +125,7 @@ def get_experiment_status(self): if not self._assert_experiment_submitted(): return exp_status = self.federation.dir_client.get_experiment_status( - experiment_name=self.experiment_name - ) + experiment_name=self.experiment_name) return exp_status.experiment_status def get_best_model(self): @@ -106,7 +135,8 @@ def get_best_model(self): tensor_dict = self.federation.dir_client.get_best_model( experiment_name=self.experiment_name) - return self._rebuild_model(tensor_dict, upcoming_model_status=ModelStatus.BEST) + return self._rebuild_model(tensor_dict, + upcoming_model_status=ModelStatus.BEST) def get_last_model(self): """Retrieve the aggregated model after the last round.""" @@ -115,10 +145,28 @@ def get_last_model(self): tensor_dict = self.federation.dir_client.get_last_model( experiment_name=self.experiment_name) - return self._rebuild_model(tensor_dict, upcoming_model_status=ModelStatus.LAST) + return self._rebuild_model(tensor_dict, + upcoming_model_status=ModelStatus.LAST) + + def _rebuild_model(self, + tensor_dict, + upcoming_model_status=ModelStatus.BEST): + """Use tensor dict to update model weights. - def _rebuild_model(self, tensor_dict, upcoming_model_status=ModelStatus.BEST): - """Use tensor dict to update model weights.""" + This method updates the model weights using the provided tensor + dictionary. If the tensor dictionary is empty, it logs a warning and + returns the current model. Otherwise, it rebuilds the model with the + new weights and updates the current model status. + + Args: + tensor_dict (dict): A dictionary containing tensor names as keys + and tensor values as values. + upcoming_model_status (ModelStatus, optional): The upcoming status + of the model. Defaults to ModelStatus.BEST. + + Returns: + The updated model. 
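A sketch of how these accessors are typically reached from user code. The `Federation` import path and constructor arguments are assumptions based on the interactive API and are not part of this diff; host, port, and experiment name are placeholders:

from openfl.interface.interactive_api.experiment import FLExperiment
from openfl.interface.interactive_api.federation import Federation  # assumed import path

federation = Federation(client_id='api', director_node_fqdn='director.example.com',
                        director_port=50051, tls=False)
fl_experiment = FLExperiment(federation=federation, experiment_name='mnist_demo')

# After the experiment has been submitted with FLExperiment.start(...):
print(fl_experiment.get_experiment_status())
best_model = fl_experiment.get_best_model()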
+ """ if len(tensor_dict) == 0: warning_msg = ('No tensors received from director\n' 'Possible reasons:\n' @@ -133,7 +181,9 @@ def _rebuild_model(self, tensor_dict, upcoming_model_status=ModelStatus.BEST): self.logger.warning(warning_msg) else: - self.task_runner_stub.rebuild_model(tensor_dict, validation=True, device='cpu') + self.task_runner_stub.rebuild_model(tensor_dict, + validation=True, + device='cpu') self.current_model_status = upcoming_model_status return deepcopy(self.task_runner_stub.model) @@ -142,7 +192,8 @@ def stream_metrics(self, tensorboard_logs: bool = True) -> None: """Stream metrics.""" if not self._assert_experiment_submitted(): return - for metric_message_dict in self.federation.dir_client.stream_metrics(self.experiment_name): + for metric_message_dict in self.federation.dir_client.stream_metrics( + self.experiment_name): self.logger.metric( f'Round {metric_message_dict["round"]}, ' f'collaborator {metric_message_dict["metric_origin"]} ' @@ -156,7 +207,8 @@ def stream_metrics(self, tensorboard_logs: bool = True) -> None: def write_tensorboard_metric(self, metric: dict) -> None: """Write metric callback.""" if not self.summary_writer: - self.summary_writer = SummaryWriter(f'./logs/{self.experiment_name}', flush_secs=5) + self.summary_writer = SummaryWriter( + f'./logs/{self.experiment_name}', flush_secs=5) self.summary_writer.add_scalar( f'{metric["metric_origin"]}/{metric["task_name"]}/{metric["metric_name"]}', @@ -168,8 +220,7 @@ def remove_experiment_data(self): return log_message = 'Removing experiment data ' if self.federation.dir_client.remove_experiment_data( - name=self.experiment_name - ): + name=self.experiment_name): log_message += 'succeed.' self.experiment_submitted = False else: @@ -177,14 +228,37 @@ def remove_experiment_data(self): self.logger.info(log_message) - def prepare_workspace_distribution(self, model_provider, task_keeper, data_loader, + def prepare_workspace_distribution(self, + model_provider, + task_keeper, + data_loader, task_assigner, pip_install_options: Tuple[str] = ()): - """Prepare an archive from a user workspace.""" + """Prepare an archive from a user workspace. + + This method serializes interface objects and saves them to disk, + dumps the prepared plan, prepares a requirements file to restore + the Python environment, and compresses the workspace to restore it on + a collaborator. + + Args: + model_provider: The model provider object. + task_keeper: The task keeper object. + data_loader: The data loader object. + task_assigner: The task assigner object. + pip_install_options (tuple, optional): A tuple of options for pip + install. Defaults to an empty tuple. + + Returns: + None + """ # Save serialized python objects to disc - self._serialize_interface_objects(model_provider, task_keeper, data_loader, task_assigner) + self._serialize_interface_objects(model_provider, task_keeper, + data_loader, task_assigner) # Save the prepared plan - Plan.dump(Path(f'./plan/{self.plan.name}'), self.plan.config, freeze=False) + Plan.dump(Path(f'./plan/{self.plan.name}'), + self.plan.config, + freeze=False) # PACK the WORKSPACE! 
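Continuing the sketch above, the remaining lifecycle calls documented in this class: stream metrics (optionally mirrored to TensorBoard under ./logs/<experiment_name>), fetch the final model, and clean up the Director-side data:

fl_experiment.stream_metrics(tensorboard_logs=True)
last_model = fl_experiment.get_last_model()
fl_experiment.remove_experiment_data()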
# Prepare requirements file to restore python env @@ -194,56 +268,71 @@ def prepare_workspace_distribution(self, model_provider, task_keeper, data_loade # Compress te workspace to restore it on collaborator self.arch_path = self._pack_the_workspace() - def start(self, *, model_provider, task_keeper, data_loader, - rounds_to_train: int, - task_assigner=None, - override_config: dict = None, - delta_updates: bool = False, - opt_treatment: str = 'RESET', - device_assignment_policy: str = 'CPU_ONLY', - pip_install_options: Tuple[str] = ()) -> None: - """ - Prepare workspace distribution and send to Director. - - A successful call of this function will result in sending the experiment workspace - to the Director service and experiment start. - - Parameters: - model_provider - Model Interface instance. - task_keeper - Task Interface instance. - data_loader - Data Interface instance. - rounds_to_train - required number of training rounds for the experiment. - delta_updates - [bool] Tells if collaborators should send delta updates - for the locally tuned models. If set to False, whole checkpoints will be sent. - opt_treatment - Optimizer state treatment policy. - Valid options: 'RESET' - reinitialize optimizer for every round, - 'CONTINUE_LOCAL' - keep local optimizer state, - 'CONTINUE_GLOBAL' - aggregate optimizer state. - device_assignment_policy - device assignment policy. - Valid options: 'CPU_ONLY' - device parameter passed to tasks - will always be 'cpu', - 'CUDA_PREFERRED' - enable passing CUDA device identifiers to tasks - by collaborators, works with cuda-device-monitor plugin equipped Envoys. - pip_install_options - tuple of options for the remote `pip install` calls, - example: ('-f some.website', '--no-index') + def start( + self, + *, + model_provider, + task_keeper, + data_loader, + rounds_to_train: int, + task_assigner=None, + override_config: dict = None, + delta_updates: bool = False, + opt_treatment: str = 'RESET', + device_assignment_policy: str = 'CPU_ONLY', + pip_install_options: Tuple[str] = () + ) -> None: + """Prepare workspace distribution and send to Director. + + A successful call of this function will result in sending the + experiment workspace to the Director service and experiment start. + + Args: + model_provider: Model Interface instance. + task_keeper: Task Interface instance. + data_loader: Data Interface instance. + rounds_to_train (int): Required number of training rounds for the + experiment. + task_assigner (optional): Task assigner instance. Defaults to None. + override_config (dict, optional): Configuration to override the + default settings. Defaults to None. + delta_updates (bool, optional): Flag to indicate if delta updates + should be sent. Defaults to False. + opt_treatment (str, optional): Optimizer state treatment policy. + Defaults to 'RESET'. + Valid options: 'RESET' - reinitialize optimizer for every + round, + 'CONTINUE_LOCAL' - keep local optimizer state, + 'CONTINUE_GLOBAL' - aggregate optimizer state. + device_assignment_policy (str, optional): Device assignment policy. + Defaults to 'CPU_ONLY'. + Valid options: 'CPU_ONLY' - device parameter passed to tasks + will always be 'cpu', + 'CUDA_PREFERRED' - enable passing CUDA device identifiers to + tasks by collaborators, works with cuda-device-monitor plugin + equipped Envoys. + pip_install_options (Tuple[str], optional): Options for the remote + `pip install` calls. Defaults to (). 
+ example: ('-f some.website', '--no-index') """ if not task_assigner: - task_assigner = self.define_task_assigner(task_keeper, rounds_to_train) + task_assigner = self.define_task_assigner(task_keeper, + rounds_to_train) - self._prepare_plan(model_provider, data_loader, + self._prepare_plan(model_provider, + data_loader, rounds_to_train, - delta_updates=delta_updates, opt_treatment=opt_treatment, + delta_updates=delta_updates, + opt_treatment=opt_treatment, device_assignment_policy=device_assignment_policy, override_config=override_config, model_interface_file='model_obj.pkl', tasks_interface_file='tasks_obj.pkl', dataloader_interface_file='loader_obj.pkl') - self.prepare_workspace_distribution( - model_provider, task_keeper, data_loader, - task_assigner, - pip_install_options - ) + self.prepare_workspace_distribution(model_provider, task_keeper, + data_loader, task_assigner, + pip_install_options) self.logger.info('Starting experiment!') self.plan.resolve() @@ -253,8 +342,7 @@ def start(self, *, model_provider, task_keeper, data_loader, name=self.experiment_name, col_names=self.plan.authorized_cols, arch_path=self.arch_path, - initial_tensor_dict=initial_tensor_dict - ) + initial_tensor_dict=initial_tensor_dict) finally: self.remove_workspace_archive() @@ -262,10 +350,34 @@ def start(self, *, model_provider, task_keeper, data_loader, self.logger.info('Experiment was submitted to the director!') self.experiment_submitted = True else: - self.logger.info('Experiment could not be submitted to the director.') + self.logger.info( + 'Experiment could not be submitted to the director.') def define_task_assigner(self, task_keeper, rounds_to_train): - """Define task assigner by registered tasks.""" + """Define task assigner by registered tasks. + + This method defines a task assigner based on the registered tasks. + It checks if there are 'train' and 'validate' tasks among the + registered tasks and defines the task assigner accordingly. If there + are both 'train' and 'validate' tasks, the task assigner assigns these + tasks to each collaborator. If there are only 'validate' tasks, the + task assigner assigns only these tasks to each collaborator. + If there are no 'train' or 'validate' tasks, an exception is raised. + + Args: + task_keeper: The task keeper object that holds the registered + tasks. + rounds_to_train (int): The number of rounds to train. + + Returns: + assigner: A function that assigns tasks to each collaborator. + + Raises: + Exception: If there are no 'train' tasks and rounds_to_train is + not 1. + Exception: If there are no 'validate' tasks. + Exception: If there are no 'train' or 'validate' tasks. 
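A sketch of a typical `start()` call. `model_interface`, `task_interface`, and `data_interface` stand in for the ModelInterface, TaskInterface, and DataInterface objects built elsewhere with the interactive API; the round count and option values are placeholders:

fl_experiment.start(
    model_provider=model_interface,
    task_keeper=task_interface,
    data_loader=data_interface,
    rounds_to_train=10,
    delta_updates=False,
    opt_treatment='CONTINUE_GLOBAL',
    device_assignment_policy='CUDA_PREFERRED',
    pip_install_options=('--no-index',),
)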
+ """ tasks = task_keeper.get_registered_tasks() is_train_task_exist = False self.is_validate_task_exist = False @@ -276,10 +388,12 @@ def define_task_assigner(self, task_keeper, rounds_to_train): self.is_validate_task_exist = True if not is_train_task_exist and rounds_to_train != 1: - # Since we have only validation tasks, we do not have to train it multiple times + # Since we have only validation tasks, we do not have to train it + # multiple times raise Exception('Variable rounds_to_train must be equal 1, ' 'because only validation tasks were given') if is_train_task_exist and self.is_validate_task_exist: + def assigner(collaborators, round_number, **kwargs): tasks_by_collaborator = {} for collaborator in collaborators: @@ -289,8 +403,10 @@ def assigner(collaborators, round_number, **kwargs): tasks['aggregated_model_validate'], ] return tasks_by_collaborator + return assigner elif not is_train_task_exist and self.is_validate_task_exist: + def assigner(collaborators, round_number, **kwargs): tasks_by_collaborator = {} for collaborator in collaborators: @@ -298,6 +414,7 @@ def assigner(collaborators, round_number, **kwargs): tasks['aggregated_model_validate'], ] return tasks_by_collaborator + return assigner elif is_train_task_exist and not self.is_validate_task_exist: raise Exception('You should define validate task!') @@ -305,8 +422,19 @@ def assigner(collaborators, round_number, **kwargs): raise Exception('You should define train and validate tasks!') def restore_experiment_state(self, model_provider): - """Restore accepted experiment object.""" - self.task_runner_stub = self.plan.get_core_task_runner(model_provider=model_provider) + """Restores the state of an accepted experiment object. + + This method restores the state of an accepted experiment object by + getting the core task runner from the plan and setting the current + model status to RESTORED. It also sets the experiment_submitted + attribute to True. + + Args: + model_provider: The provider of the model used in the experiment. + + """ + self.task_runner_stub = self.plan.get_core_task_runner( + model_provider=model_provider) self.current_model_status = ModelStatus.RESTORED self.experiment_submitted = True @@ -328,13 +456,14 @@ def _pack_the_workspace(): tmp_dir = 'temp_' + archive_name makedirs(tmp_dir, exist_ok=True) - ignore = ignore_patterns( - '__pycache__', 'data', 'cert', tmp_dir, '*.crt', '*.key', - '*.csr', '*.srl', '*.pem', '*.pbuf', '*zip') + ignore = ignore_patterns('__pycache__', 'data', 'cert', tmp_dir, + '*.crt', '*.key', '*.csr', '*.srl', '*.pem', + '*.pbuf', '*zip') copytree('./', tmp_dir + '/workspace', ignore=ignore) - arch_path = make_archive(archive_name, archive_type, tmp_dir + '/workspace') + arch_path = make_archive(archive_name, archive_type, + tmp_dir + '/workspace') rmtree(tmp_dir) @@ -346,50 +475,97 @@ def remove_workspace_archive(self): del self.arch_path def _get_initial_tensor_dict(self, model_provider): - """Extract initial weights from the model.""" - self.task_runner_stub = self.plan.get_core_task_runner(model_provider=model_provider) + """Extracts initial weights from the model. + + This method extracts the initial weights from the model by getting the + core task runner from the plan and setting the current model status to + INITIAL. It then splits the tensor dictionary for holdouts and returns + the tensor dictionary. + + Args: + model_provider: The provider of the model used in the experiment. + + Returns: + dict: The tensor dictionary. 
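The built-in assigner defined in this method can also be replaced by a user-defined one of the same shape, passed via `task_assigner=` to `start()`. A sketch mirroring the task names used above; `task_keeper` is the TaskInterface holding the registered tasks:

tasks = task_keeper.get_registered_tasks()

def custom_assigner(collaborators, round_number, **kwargs):
    """Give every collaborator the full train/validate pipeline each round."""
    tasks_by_collaborator = {}
    for collaborator in collaborators:
        tasks_by_collaborator[collaborator] = [
            tasks['train'],
            tasks['locally_tuned_model_validate'],
            tasks['aggregated_model_validate'],
        ]
    return tasks_by_collaborator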
+ """ + self.task_runner_stub = self.plan.get_core_task_runner( + model_provider=model_provider) self.current_model_status = ModelStatus.INITIAL tensor_dict, _ = split_tensor_dict_for_holdouts( - self.logger, - self.task_runner_stub.get_tensor_dict(False), - **self.task_runner_stub.tensor_dict_split_fn_kwargs - ) + self.logger, self.task_runner_stub.get_tensor_dict(False), + **self.task_runner_stub.tensor_dict_split_fn_kwargs) return tensor_dict - def _prepare_plan(self, model_provider, data_loader, - rounds_to_train, - delta_updates, opt_treatment, - device_assignment_policy, - override_config=None, - model_interface_file='model_obj.pkl', tasks_interface_file='tasks_obj.pkl', - dataloader_interface_file='loader_obj.pkl', - aggregation_function_interface_file='aggregation_function_obj.pkl', - task_assigner_file='task_assigner_obj.pkl'): - """Fill plan.yaml file using user provided setting.""" + def _prepare_plan( + self, + model_provider, + data_loader, + rounds_to_train, + delta_updates, + opt_treatment, + device_assignment_policy, + override_config=None, + model_interface_file='model_obj.pkl', + tasks_interface_file='tasks_obj.pkl', + dataloader_interface_file='loader_obj.pkl', + aggregation_function_interface_file='aggregation_function_obj.pkl', + task_assigner_file='task_assigner_obj.pkl'): + """Fills the plan.yaml file using user-provided settings. + + It sets up the network, aggregator, collaborator, data loader, task + runner, and API layer according to the user's specifications. + + Args: + model_provider: The provider of the model used in the experiment. + data_loader: The data loader to be used in the experiment. + rounds_to_train (int): The number of rounds to train. + delta_updates (bool): Whether to use delta updates. + opt_treatment (str): The optimization treatment to be used. + device_assignment_policy (str): The device assignment policy to be + used. + override_config (dict, optional): The configuration to override + the default settings. + model_interface_file (str, optional): The file for the model + interface. Defaults to 'model_obj.pkl'. + tasks_interface_file (str, optional): The file for the tasks + interface. Defaults to 'tasks_obj.pkl'. + dataloader_interface_file (str, optional): The file for the data + loader interface. Defaults to 'loader_obj.pkl'. + aggregation_function_interface_file (str, optional): The file for + the aggregation function interface. Defaults to + 'aggregation_function_obj.pkl'. + task_assigner_file (str, optional): The file for the task assigner. + Defaults to 'task_assigner_obj.pkl'. + """ # Seems like we still need to fill authorized_cols list # So aggregator know when to start sending tasks - # We also could change the aggregator logic so it will send tasks to aggregator - # as soon as it connects. This change should be a part of a bigger PR - # brining in fault tolerance changes + # We also could change the aggregator logic so it will send tasks to + # aggregator as soon as it connects. 
This change should be a part of a + # bigger PR bringing in fault tolerance changes shard_registry = self.federation.get_shard_registry() self.plan.authorized_cols = [ name for name, info in shard_registry.items() if info['is_online'] ] # Network part of the plan - # We keep in mind that an aggregator FQND will be the same as the directors FQDN + # We keep in mind that an aggregator FQDN will be the same as the + # director's FQDN # We just choose a port randomly from plan hash - director_fqdn = self.federation.director_node_fqdn.split(':')[0] # We drop the port + director_fqdn = self.federation.director_node_fqdn.split(':')[ + 0] # We drop the port self.plan.config['network']['settings']['agg_addr'] = director_fqdn self.plan.config['network']['settings']['tls'] = self.federation.tls # Aggregator part of the plan - self.plan.config['aggregator']['settings']['rounds_to_train'] = rounds_to_train + self.plan.config['aggregator']['settings'][ + 'rounds_to_train'] = rounds_to_train # Collaborator part - self.plan.config['collaborator']['settings']['delta_updates'] = delta_updates - self.plan.config['collaborator']['settings']['opt_treatment'] = opt_treatment + self.plan.config['collaborator']['settings'][ + 'delta_updates'] = delta_updates + self.plan.config['collaborator']['settings'][ + 'opt_treatment'] = opt_treatment self.plan.config['collaborator']['settings'][ 'device_assignment_policy'] = device_assignment_policy @@ -398,8 +574,8 @@ def _prepare_plan(self, model_provider, data_loader, self.plan.config['data_loader']['settings'][setting] = value # TaskRunner framework plugin - # ['required_plugin_components'] should be already in the default plan with all the fields - # filled with the default values + # ['required_plugin_components'] should be already in the default plan + # with all the fields filled with the default values self.plan.config['task_runner']['required_plugin_components'] = { 'framework_adapters': model_provider.framework_plugin } @@ -413,24 +589,34 @@ def _prepare_plan(self, model_provider, data_loader, 'model_interface_file': model_interface_file, 'tasks_interface_file': tasks_interface_file, 'dataloader_interface_file': dataloader_interface_file, - 'aggregation_function_interface_file': aggregation_function_interface_file, + 'aggregation_function_interface_file': + aggregation_function_interface_file, 'task_assigner_file': task_assigner_file } } if override_config: - self.plan = update_plan(override_config, plan=self.plan, resolve=False) + self.plan = update_plan(override_config, + plan=self.plan, + resolve=False) - def _serialize_interface_objects( - self, - model_provider, - task_keeper, - data_loader, - task_assigner - ): - """Save python objects to be restored on collaborators.""" + def _serialize_interface_objects(self, model_provider, task_keeper, + data_loader, task_assigner): + """Save python objects to be restored on collaborators. + + This method serializes the provided python objects and saves them for + later use. The objects are serialized using the serializer plugin + specified in the plan configuration. + + Args: + model_provider: The ModelInterface instance to be serialized. + task_keeper: The TaskKeeper instance to be serialized. + data_loader: The DataInterface instance to be serialized. + task_assigner: The task assigner to be serialized.
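As a rough illustration of the plan sections _prepare_plan() above fills in, a sketch assembled only from the assignments shown; real plans carry many more keys, and the address, round count and policy values are placeholders.

plan_config_sketch = {
    'network': {'settings': {'agg_addr': 'director.example.com', 'tls': True}},
    'aggregator': {'settings': {'rounds_to_train': 10}},
    'collaborator': {'settings': {'delta_updates': False,
                                  'opt_treatment': 'CONTINUE_GLOBAL',
                                  'device_assignment_policy': 'CPU_ONLY'}},
    'api_layer': {'settings': {'model_interface_file': 'model_obj.pkl',
                               'tasks_interface_file': 'tasks_obj.pkl',
                               'dataloader_interface_file': 'loader_obj.pkl',
                               'aggregation_function_interface_file': 'aggregation_function_obj.pkl',
                               'task_assigner_file': 'task_assigner_obj.pkl'}},
}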
+ """ serializer = self.plan.build( - self.plan.config['api_layer']['required_plugin_components']['serializer_plugin'], {}) + self.plan.config['api_layer']['required_plugin_components'] + ['serializer_plugin'], {}) framework_adapter = Plan.build(model_provider.framework_plugin, {}) # Model provider serialization may need preprocessing steps framework_adapter.serialization_setup() @@ -439,25 +625,43 @@ def _serialize_interface_objects( 'model_interface_file': model_provider, 'tasks_interface_file': task_keeper, 'dataloader_interface_file': data_loader, - 'aggregation_function_interface_file': task_keeper.aggregation_functions, + 'aggregation_function_interface_file': + task_keeper.aggregation_functions, 'task_assigner_file': task_assigner } for filename, object_ in obj_dict.items(): - serializer.serialize(object_, self.plan.config['api_layer']['settings'][filename]) + serializer.serialize( + object_, self.plan.config['api_layer']['settings'][filename]) class TaskKeeper: - """ - Task keeper class. + """Task keeper class. + + This class is responsible for managing tasks in a federated learning + experiment. It keeps track of registered tasks, their settings, and + aggregation functions. Task should accept the following entities that exist on collaborator nodes: - 1. model - will be rebuilt with relevant weights for every task by `TaskRunner` - 2. data_loader - data loader equipped with `repository adapter` that provides local data - 3. device - a device to be used on collaborator machines - 4. optimizer (optional) + 1. model - will be rebuilt with relevant weights for every task by + `TaskRunner`. + 2. data_loader - data loader equipped with `repository adapter` that + provides local data. + 3. device - a device to be used on collaborator machines. + 4. optimizer (optional). Task returns a dictionary {metric name: metric value for this task} + + Attributes: + task_registry (dict): A dictionary mapping task names to callable + functions. + task_contract (dict): A dictionary mapping task names to their + contract. + task_settings (dict): A dictionary mapping task names to their + settings. + aggregation_functions (dict): A dictionary mapping task names to their + aggregation functions. + _tasks (dict): A dictionary mapping task aliases to Task objects. """ def __init__(self) -> None: @@ -473,15 +677,20 @@ def __init__(self) -> None: # Mapping 'task_alias' -> Task self._tasks: Dict[str, Task] = {} - def register_fl_task(self, model, data_loader, device, optimizer=None, round_num=None): - """ - Register FL tasks. + def register_fl_task(self, + model, + data_loader, + device, + optimizer=None, + round_num=None): + """Register FL tasks. The task contract should be set up by providing variable names: [model, data_loader, device] - necessarily and optimizer - optionally - All tasks should accept contract entities to be run on collaborator node. + All tasks should accept contract entities to be run on collaborator + node. Moreover we ask users return dict{'metric':value} in every task ` TI = TaskInterface() @@ -493,13 +702,36 @@ def register_fl_task(self, model, data_loader, device, optimizer=None, round_num @TI.add_kwargs(**task_settings) @TI.register_fl_task(model='my_model', data_loader='train_loader', device='device', optimizer='my_Adam_opt') - def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=356) + def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, + some_arg=356) ... 
return {'metric_name': metric, 'metric_name_2': metric_2,} ` + + Args: + model: The model to be used in the task. + data_loader: The data loader to be used in the task. + device: The device to be used in the task. + optimizer (optional): The optimizer to be used in the task. + Defaults to None. + round_num (optional): The round number for the task. Defaults + to None. """ - # The highest level wrapper for allowing arguments for the decorator + def decorator_with_args(training_method): + """A high-level wrapper that allows arguments for the decorator. + + This function is a decorator that wraps a training method. It + saves the task and the contract for later serialization. It also + defines tasks based on whether an optimizer is provided. + + Args: + training_method: The training method to be wrapped. + + Returns: + function: The wrapped training method. + """ + # We could pass hooks to the decorator # @functools.wraps(training_method) @@ -510,8 +742,13 @@ def wrapper_decorator(**task_keywords): # Saving the task and the contract for later serialization function_name = training_method.__name__ self.task_registry[function_name] = wrapper_decorator - contract = {'model': model, 'data_loader': data_loader, - 'device': device, 'optimizer': optimizer, 'round_num': round_num} + contract = { + 'model': model, + 'data_loader': data_loader, + 'device': device, + 'optimizer': optimizer, + 'round_num': round_num + } self.task_contract[function_name] = contract # define tasks if optimizer: @@ -536,14 +773,17 @@ def wrapper_decorator(**task_keywords): return decorator_with_args def add_kwargs(self, **task_kwargs): - """ - Register tasks settings. + """Register tasks settings. Warning! We do not actually need to register additional kwargs, we ust serialize them. This one is a decorator because we need task name and to be consistent with the main registering method + + Args: + **task_kwargs: Keyword arguments for the task settings. """ + # The highest level wrapper for allowing arguments for the decorator def decorator_with_args(training_method): # Saving the task's settings to be written in plan @@ -553,7 +793,8 @@ def decorator_with_args(training_method): return decorator_with_args - def set_aggregation_function(self, aggregation_function: AggregationFunction): + def set_aggregation_function(self, + aggregation_function: AggregationFunction): """Set aggregation function for the task. To be serialized and sent to aggregator node. @@ -562,27 +803,37 @@ def set_aggregation_function(self, aggregation_function: AggregationFunction): containing logic from workspace-related libraries that are not present on director yet. - Args: - aggregation_function: Aggregation function. - - You might need to override default FedAvg aggregation with built-in aggregation types: + You might need to override default FedAvg aggregation with built-in + aggregation types: - openfl.interface.aggregation_functions.GeometricMedian - openfl.interface.aggregation_functions.Median or define your own AggregationFunction subclass. - See more details on `Overriding the aggregation function`_ documentation page. + See more details on `Overriding the aggregation function`_ + documentation page. .. _Overriding the aggregation function: https://openfl.readthedocs.io/en/latest/overriding_agg_fn.html + + Args: + aggregation_function: The aggregation function to be used for + the task. 
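Putting the decorators documented above together, a hedged usage sketch: the task body, the batch_size setting and the choice of Median are illustrative, while the decorator contract follows the docstrings in this file.

from openfl.interface.aggregation_functions import Median
from openfl.interface.interactive_api.experiment import TaskInterface

TI = TaskInterface()

@TI.add_kwargs(batch_size=32)
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
@TI.set_aggregation_function(Median())
def train(model, train_loader, optimizer, device, batch_size):
    # A real task would iterate over train_loader on `device` here.
    return {'train_loss': 0.0}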
""" + def decorator_with_args(training_method): if not isinstance(aggregation_function, AggregationFunction): raise Exception('aggregation_function must implement ' 'AggregationFunction interface.') - self.aggregation_functions[training_method.__name__] = aggregation_function + self.aggregation_functions[ + training_method.__name__] = aggregation_function return training_method + return decorator_with_args def get_registered_tasks(self) -> Dict[str, Task]: - """Return registered tasks.""" + """Return registered tasks. + + Returns: + A dictionary mapping task names to Task objects. + """ return self._tasks @@ -591,49 +842,71 @@ def get_registered_tasks(self) -> Dict[str, Task]: class ModelInterface: - """ - Registers model graph and optimizer. + """Registers model graph and optimizer. - To be serialized and sent to collaborator nodes + This class is responsible for managing the model and optimizer in a + federated learning experiment. + To be serialized and sent to collaborator nodes. This is the place to determine correct framework adapter - as they are needed to fill the model graph with trained tensors. + as they are needed to fill the model graph with trained tensors. There is no support for several models / optimizers yet. + + Attributes: + model: The model to be used in the experiment. + optimizer: The optimizer to be used in the experiment. + framework_plugin: The framework plugin to be used in the experiment. """ def __init__(self, model, optimizer, framework_plugin) -> None: - """ - Initialize model keeper. + """Initialize model keeper. Tensors in provided graphs will be used for initialization of the global model. - Arguments: - model: Union[tuple, graph] - optimizer: Union[tuple, optimizer] + Args: + model (Union[Path, str]) : The model to be used in the experiment. + optimizer (Union[tuple, optimizer]) : The optimizer to be used in + the experiment. + framework_plugin: The framework plugin to be used in the + experiment. """ self.model = model self.optimizer = optimizer self.framework_plugin = framework_plugin def provide_model(self): - """Retrieve model.""" + """Retrieve model. + + Returns: + The model used in the experiment. + """ return self.model def provide_optimizer(self): - """Retrieve optimizer.""" + """Retrieve optimizer. + + Returns: + The optimizer used in the experiment. + """ return self.optimizer class DataInterface: - """ - The class to define dataloaders. + """The class to define dataloaders. + + This class is responsible for managing the data loaders in a federated + learning experiment. In the future users will have to adapt `unified data interface hook` - in their dataloaders. + in their dataloaders. For now, we can provide `data_path` variable on every collaborator node - at initialization time for dataloader customization + at initialization time for dataloader customization. + + Attributes: + kwargs (dict): The keyword arguments for the data loaders. + shard_descriptor: The shard descriptor for the dataloader. """ def __init__(self, **kwargs): @@ -642,26 +915,34 @@ def __init__(self, **kwargs): @property def shard_descriptor(self): - """Return shard descriptor.""" + """Return shard descriptor. + + Returns: + The shard descriptor for the data loaders. + """ return self._shard_descriptor @shard_descriptor.setter def shard_descriptor(self, shard_descriptor): - """ - Describe per-collaborator procedures or sharding. + """Describe per-collaborator procedures or sharding. This method will be called during a collaborator initialization. 
Local shard_descriptor will be set by Envoy. + + Args: + shard_descriptor: The shard descriptor for the data loaders. """ self._shard_descriptor = shard_descriptor raise NotImplementedError def get_train_loader(self, **kwargs): - """Output of this method will be provided to tasks with optimizer in contract.""" + """Output of this method will be provided to tasks with optimizer in + contract.""" raise NotImplementedError def get_valid_loader(self, **kwargs): - """Output of this method will be provided to tasks without optimizer in contract.""" + """Output of this method will be provided to tasks without optimizer + in contract.""" raise NotImplementedError def get_train_data_size(self): diff --git a/openfl/interface/interactive_api/federation.py b/openfl/interface/interactive_api/federation.py index f8fce7a8d9..fef08ff5a4 100644 --- a/openfl/interface/interactive_api/federation.py +++ b/openfl/interface/interactive_api/federation.py @@ -8,15 +8,32 @@ class Federation: - """ - Federation class. + """Federation class. + + Manages information about collaborator settings, local data, and network settings. + + The Federation class is used to maintain information about collaborator-related settings, + their local data, and network settings to enable communication within the federation. - Federation entity exists to keep information about collaborator related settings, - their local data and network setting to enable communication in federation. + Attributes: + director_node_fqdn (str): The fully qualified domain name (FQDN) of the director node. + tls (bool): A boolean indicating whether mTLS (mutual Transport Layer Security) is enabled. + cert_chain (str): The path to the certificate chain to the Certificate Authority (CA). + api_cert (str): The path to the API certificate. + api_private_key (str): The path to the API private key. + dir_client (DirectorClient): An instance of the DirectorClient class. + sample_shape (tuple): The shape of the samples in the dataset. + target_shape (tuple): The shape of the targets in the dataset. """ - def __init__(self, client_id=None, director_node_fqdn=None, director_port=None, tls=True, - cert_chain=None, api_cert=None, api_private_key=None) -> None: + def __init__(self, + client_id=None, + director_node_fqdn=None, + director_port=None, + tls=True, + cert_chain=None, + api_cert=None, + api_private_key=None) -> None: """ Initialize federation. @@ -26,10 +43,15 @@ def __init__(self, client_id=None, director_node_fqdn=None, director_port=None, pricate key to enable mTLS. Args: - - client_id: name of created Frontend API instance. - The same name user certify. - - director_node_fqdn: Address and port a director's service is running on. - User passes here an address with a port. + client_id (str): Name of created Frontend API instance. + The same name user certify. + director_node_fqdn (str): Address and port a director's service is running on. + User passes here an address with a port. + director_port (int): Port a director's service is running on. + tls (bool): Enable mTLS. + cert_chain (str): Path to a certificate chain to CA. + api_cert (str): Path to API certificate. + api_private_key (str): Path to API private key. 
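A hedged construction example matching the signature documented above; the host name, port and certificate paths are placeholders, and a reachable Director is assumed since the constructor fetches dataset shapes from it.

from openfl.interface.interactive_api.federation import Federation

federation = Federation(
    client_id='api_user',
    director_node_fqdn='director.example.com',
    director_port=50051,
    tls=True,
    cert_chain='cert/root_ca.crt',
    api_cert='cert/api_user.crt',
    api_private_key='cert/api_user.key',
)
# Shapes are requested from the Director at construction time:
# federation.sample_shape, federation.target_shape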
""" if director_node_fqdn is None: self.director_node_fqdn = getfqdn_env() @@ -43,24 +65,34 @@ def __init__(self, client_id=None, director_node_fqdn=None, director_port=None, self.api_private_key = api_private_key # Create Director client - self.dir_client = DirectorClient( - client_id=client_id, - director_host=director_node_fqdn, - director_port=director_port, - tls=tls, - root_certificate=cert_chain, - private_key=api_private_key, - certificate=api_cert - ) + self.dir_client = DirectorClient(client_id=client_id, + director_host=director_node_fqdn, + director_port=director_port, + tls=tls, + root_certificate=cert_chain, + private_key=api_private_key, + certificate=api_cert) # Request sample and target shapes from Director. # This is an internal method for finding out dataset properties in a Federation. - self.sample_shape, self.target_shape = self.dir_client.get_dataset_info() + self.sample_shape, self.target_shape = self.dir_client.get_dataset_info( + ) def get_dummy_shard_descriptor(self, size): - """Return a dummy shard descriptor.""" + """Return a dummy shard descriptor. + + Args: + size (int): Size of the shard descriptor. + + Returns: + DummyShardDescriptor: A dummy shard descriptor. + """ return DummyShardDescriptor(self.sample_shape, self.target_shape, size) def get_shard_registry(self): - """Return a shard registry.""" + """Return a shard registry. + + Returns: + list: A list of envoys. + """ return self.dir_client.get_envoys() diff --git a/openfl/interface/interactive_api/shard_descriptor.py b/openfl/interface/interactive_api/shard_descriptor.py index 806beefb75..93e2e28b18 100644 --- a/openfl/interface/interactive_api/shard_descriptor.py +++ b/openfl/interface/interactive_api/shard_descriptor.py @@ -24,81 +24,134 @@ class ShardDescriptor: """Shard descriptor class.""" def get_dataset(self, dataset_type: str) -> ShardDataset: - """Return a shard dataset by type.""" + """Return a shard dataset by type. + + Args: + dataset_type (str): The type of the dataset. + + Returns: + ShardDataset: The shard dataset. + """ raise NotImplementedError @property def sample_shape(self) -> List[int]: - """Return the sample shape info.""" + """Return the sample shape info. + + Returns: + List[int]: The sample shape. + """ raise NotImplementedError @property def target_shape(self) -> List[int]: - """Return the target shape info.""" + """Return the target shape info. + + Returns: + List[int]: The target shape. + """ raise NotImplementedError @property def dataset_description(self) -> str: - """Return the dataset description.""" + """Return the dataset description. + + Returns: + str: The dataset description. + """ return '' class DummyShardDataset(ShardDataset): """Dummy shard dataset class.""" - def __init__( - self, *, - size: int, - sample_shape: List[int], - target_shape: List[int] - ): - """Initialize DummyShardDataset.""" + def __init__(self, *, size: int, sample_shape: List[int], + target_shape: List[int]): + """Initialize DummyShardDataset. + + Args: + size (int): The size of the dataset. + sample_shape (List[int]): The shape of the samples. + target_shape (List[int]): The shape of the targets. 
+ """ self.size = size - self.samples = np.random.randint(0, 255, (self.size, *sample_shape), np.uint8) - self.targets = np.random.randint(0, 255, (self.size, *target_shape), np.uint8) + self.samples = np.random.randint(0, 255, (self.size, *sample_shape), + np.uint8) + self.targets = np.random.randint(0, 255, (self.size, *target_shape), + np.uint8) def __len__(self) -> int: - """Return the len of the dataset.""" + """Return the len of the dataset. + + Returns: + int: The length of the dataset. + """ return self.size def __getitem__(self, index: int): - """Return a item by the index.""" + """Return a item by the index. + + Args: + index (int): The index of the item. + + Returns: + tuple: The sample and target at the given index. + """ return self.samples[index], self.targets[index] class DummyShardDescriptor(ShardDescriptor): """Dummy shard descriptor class.""" - def __init__( - self, - sample_shape: Iterable[str], - target_shape: Iterable[str], - size: int - ) -> None: - """Initialize DummyShardDescriptor.""" + def __init__(self, sample_shape: Iterable[str], + target_shape: Iterable[str], size: int) -> None: + """Initialize DummyShardDescriptor. + + Args: + sample_shape (Iterable[str]): The shape of the samples. + target_shape (Iterable[str]): The shape of the targets. + size (int): The size of the dataset. + """ self._sample_shape = [int(dim) for dim in sample_shape] self._target_shape = [int(dim) for dim in target_shape] self.size = size def get_dataset(self, dataset_type: str) -> ShardDataset: - """Return a shard dataset by type.""" - return DummyShardDataset( - size=self.size, - sample_shape=self._sample_shape, - target_shape=self._target_shape - ) + """Return a shard dataset by type. + + Args: + dataset_type (str): The type of the dataset. + + Returns: + ShardDataset: The shard dataset. + """ + return DummyShardDataset(size=self.size, + sample_shape=self._sample_shape, + target_shape=self._target_shape) @property def sample_shape(self) -> List[int]: - """Return the sample shape info.""" + """Return the sample shape info. + + Returns: + List[int]: The sample shape. + """ return self._sample_shape @property def target_shape(self) -> List[int]: - """Return the target shape info.""" + """Return the target shape info. + + Returns: + List[int]: The target shape. + """ return self._target_shape @property def dataset_description(self) -> str: - """Return the dataset description.""" + """Return the dataset description. + + Returns: + str: The dataset description. + """ return 'Dummy shard descriptor' diff --git a/openfl/interface/model.py b/openfl/interface/model.py index b14d50ecc0..1cc4f88a6d 100644 --- a/openfl/interface/model.py +++ b/openfl/interface/model.py @@ -17,58 +17,94 @@ @group() @pass_context def model(context): - """Manage Federated Learning Models.""" + """Manage Federated Learning Models. + + Args: + context (click.core.Context): Click context. 
+ """ context.obj['group'] = 'model' @model.command(name='save') @pass_context -@option('-i', '--input', 'model_protobuf_path', required=True, +@option('-i', + '--input', + 'model_protobuf_path', + required=True, help='The model protobuf to convert', type=ClickPath(exists=True)) -@option('-o', '--output', 'output_filepath', required=False, +@option('-o', + '--output', + 'output_filepath', + required=False, help='Filename the model will be saved to in native format', - default='output_model', type=ClickPath(writable=True)) -@option('-p', '--plan-config', required=False, + default='output_model', + type=ClickPath(writable=True)) +@option('-p', + '--plan-config', + required=False, help='Federated learning plan [plan/plan.yaml]', - default='plan/plan.yaml', type=ClickPath(exists=True)) -@option('-c', '--cols-config', required=False, + default='plan/plan.yaml', + type=ClickPath(exists=True)) +@option('-c', + '--cols-config', + required=False, help='Authorized collaborator list [plan/cols.yaml]', - default='plan/cols.yaml', type=ClickPath(exists=True)) -@option('-d', '--data-config', required=False, + default='plan/cols.yaml', + type=ClickPath(exists=True)) +@option('-d', + '--data-config', + required=False, help='The data set/shard configuration file [plan/data.yaml]', - default='plan/data.yaml', type=ClickPath(exists=True)) -def save_(context, plan_config, cols_config, data_config, model_protobuf_path, output_filepath): - """ - Save the model in native format (PyTorch / Keras). + default='plan/data.yaml', + type=ClickPath(exists=True)) +def save_(context, plan_config, cols_config, data_config, model_protobuf_path, + output_filepath): + """Save the model in native format (PyTorch / Keras). + + Args: + context (click.core.Context): Click context. + plan_config (str): Federated learning plan. + cols_config (str): Authorized collaborator list. + data_config (str): The data set/shard configuration file. + model_protobuf_path (str): The model protobuf to convert. + output_filepath (str): Filename the model will be saved to in native + format. """ output_filepath = Path(output_filepath).absolute() if output_filepath.exists(): - if not confirm(style( - f'Do you want to overwrite the {output_filepath}?', fg='red', bold=True - )): + if not confirm( + style(f'Do you want to overwrite the {output_filepath}?', + fg='red', + bold=True)): logger.info('Exiting') context.obj['fail'] = True return - task_runner = get_model(plan_config, cols_config, data_config, model_protobuf_path) + task_runner = get_model(plan_config, cols_config, data_config, + model_protobuf_path) task_runner.save_native(output_filepath) logger.info(f'Saved model in native format: 🠆 {output_filepath}') -def get_model( - plan_config: str, - cols_config: str, - data_config: str, - model_protobuf_path: str -): - """ - Initialize TaskRunner and load it with provided model.pbuf. +def get_model(plan_config: str, cols_config: str, data_config: str, + model_protobuf_path: str): + """Initialize TaskRunner and load it with provided model.pbuf. Contrary to its name, this function returns a TaskRunner instance. - The reason for this behavior is the flexibility of the TaskRunner interface and - the diversity of the ways we store models in our template workspaces. + The reason for this behavior is the flexibility of the TaskRunner + interface and the diversity of the ways we store models in our template + workspaces. + + Args: + plan_config (str): Federated learning plan. + cols_config (str): Authorized collaborator list. 
+ data_config (str): The data set/shard configuration file. + model_protobuf_path (str): The model protobuf to convert. + + Returns: + task_runner (instance): TaskRunner instance. """ from openfl.federated import Plan @@ -84,11 +120,9 @@ def get_model( data_config = Path(data_config).resolve().relative_to(workspace_path) with set_directory(workspace_path): - plan = Plan.parse( - plan_config_path=plan_config, - cols_config_path=cols_config, - data_config_path=data_config - ) + plan = Plan.parse(plan_config_path=plan_config, + cols_config_path=cols_config, + data_config_path=data_config) collaborator_name = list(plan.cols_data_paths)[0] data_loader = plan.get_data_loader(collaborator_name) task_runner = plan.get_task_runner(data_loader=data_loader) @@ -98,7 +132,8 @@ def get_model( model_protobuf = utils.load_proto(model_protobuf_path) - tensor_dict, _ = utils.deconstruct_model_proto(model_protobuf, NoCompressionPipeline()) + tensor_dict, _ = utils.deconstruct_model_proto(model_protobuf, + NoCompressionPipeline()) # This may break for multiple models. # task_runner.set_tensor_dict will need to handle multiple models diff --git a/openfl/interface/pki.py b/openfl/interface/pki.py index 272f4edc88..bbad6400a1 100644 --- a/openfl/interface/pki.py +++ b/openfl/interface/pki.py @@ -32,19 +32,31 @@ @group() @pass_context def pki(context): - """Manage Step-ca PKI.""" + """Manage Step-ca PKI. + + Args: + context (click.core.Context): Click context. + """ context.obj['group'] = 'pki' @pki.command(name='run') -@option('-p', '--ca-path', required=True, - help='The ca path', type=ClickPath()) +@option('-p', '--ca-path', required=True, help='The ca path', type=ClickPath()) def run_(ca_path): + """Run CA server. + + Args: + ca_path (str): The ca path. + """ run(ca_path) def run(ca_path): - """Run CA server.""" + """Run CA server. + + Args: + ca_path (str): The ca path. + """ ca_path = Path(ca_path).absolute() step_config_dir = ca_path / CA_STEP_CONFIG_DIR pki_dir = ca_path / CA_PKI_DIR @@ -54,27 +66,37 @@ def run(ca_path): if (not os.path.exists(step_config_dir) or not os.path.exists(pki_dir) or not os.path.exists(password_file) or not os.path.exists(ca_json) or not os.path.exists(step_ca_path)): - logger.error('CA is not installed or corrupted, please install it first') + logger.error( + 'CA is not installed or corrupted, please install it first') sys.exit(1) run_ca(step_ca_path, password_file, ca_json) @pki.command(name='install') -@option('-p', '--ca-path', required=True, - help='The ca path', type=ClickPath()) -@password_option(prompt='The password will encrypt some ca files \nEnter the password') +@option('-p', '--ca-path', required=True, help='The ca path', type=ClickPath()) +@password_option( + prompt='The password will encrypt some ca files \nEnter the password') @option('--ca-url', required=False, default=CA_URL) def install_(ca_path, password, ca_url): - """Create a ca workspace.""" + """Create a ca workspace. + + Args: + ca_path (str): The ca path. + password (str): The password will encrypt some ca files. + ca_url (str): CA URL. + """ ca_path = Path(ca_path).absolute() install(ca_path, ca_url, password) @pki.command(name='uninstall') -@option('-p', '--ca-path', required=True, - help='The CA path', type=ClickPath()) +@option('-p', '--ca-path', required=True, help='The CA path', type=ClickPath()) def uninstall(ca_path): - """Remove step-CA.""" + """Remove step-CA. + + Args: + ca_path (str): The CA path. 
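Relating back to openfl/interface/model.py above: get_model() can also be used programmatically, as in this hedged sketch. The plan, cols and data paths are the defaults declared on the CLI options; the protobuf path is a placeholder.

from openfl.interface.model import get_model

task_runner = get_model(plan_config='plan/plan.yaml',
                        cols_config='plan/cols.yaml',
                        data_config='plan/data.yaml',
                        model_protobuf_path='save/model.pbuf')
task_runner.save_native('output_model')  # same call the `fx model save` command makes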
+ """ ca_path = Path(ca_path).absolute() remove_ca(ca_path) @@ -82,17 +104,19 @@ def uninstall(ca_path): @pki.command(name='get-token') @option('-n', '--name', required=True) @option('--ca-url', required=False, default=CA_URL) -@option('-p', '--ca-path', default='.', - help='The CA path', type=ClickPath(exists=True)) +@option('-p', + '--ca-path', + default='.', + help='The CA path', + type=ClickPath(exists=True)) def get_token_(name, ca_url, ca_path): - """ - Create authentication token. + """Create authentication token. Args: - name: common name for following certificate - (aggregator fqdn or collaborator name) - ca_url: full url of CA server - ca_path: the path to CA binaries + name (str): Common name for following certificate (aggregator fqdn or + collaborator name). + ca_url (str): Full URL of CA server. + ca_path (str): The path to CA binaries. """ ca_path = Path(ca_path).absolute() token = get_token(name, ca_url, ca_path) @@ -103,12 +127,28 @@ def get_token_(name, ca_url, ca_path): @pki.command(name='certify') @option('-n', '--name', required=True) @option('-t', '--token', 'token_with_cert', required=True) -@option('-c', '--certs-path', required=False, default=Path('.') / 'cert', - help='The path where certificates will be stored', type=ClickPath()) -@option('-p', '--ca-path', default='.', help='The path to CA client', - type=ClickPath(exists=True), required=False) +@option('-c', + '--certs-path', + required=False, + default=Path('.') / 'cert', + help='The path where certificates will be stored', + type=ClickPath()) +@option('-p', + '--ca-path', + default='.', + help='The path to CA client', + type=ClickPath(exists=True), + required=False) def certify_(name, token_with_cert, certs_path, ca_path): - """Create an envoy workspace.""" + """Create an envoy workspace. + + Args: + name (str): Common name for following certificate (aggregator fqdn or + collaborator name). + token_with_cert (str): Authentication token. + certs_path (str): The path where certificates will be stored. + ca_path (str): The path to CA client. + """ certs_path = Path(certs_path).absolute() ca_path = Path(ca_path).absolute() certs_path.mkdir(parents=True, exist_ok=True) diff --git a/openfl/interface/plan.py b/openfl/interface/plan.py index 9e1618f742..5ec6338ee6 100644 --- a/openfl/interface/plan.py +++ b/openfl/interface/plan.py @@ -21,37 +21,67 @@ @group() @pass_context def plan(context): - """Manage Federated Learning Plans.""" + """Manage Federated Learning Plans. + + Args: + context (click.core.Context): Click context. 
+ """ context.obj['group'] = 'plan' @plan.command() @pass_context -@option('-p', '--plan_config', required=False, +@option('-p', + '--plan_config', + required=False, help='Federated learning plan [plan/plan.yaml]', - default='plan/plan.yaml', type=ClickPath(exists=True)) -@option('-c', '--cols_config', required=False, + default='plan/plan.yaml', + type=ClickPath(exists=True)) +@option('-c', + '--cols_config', + required=False, help='Authorized collaborator list [plan/cols.yaml]', - default='plan/cols.yaml', type=ClickPath(exists=True)) -@option('-d', '--data_config', required=False, + default='plan/cols.yaml', + type=ClickPath(exists=True)) +@option('-d', + '--data_config', + required=False, help='The data set/shard configuration file [plan/data.yaml]', - default='plan/data.yaml', type=ClickPath(exists=True)) -@option('-a', '--aggregator_address', required=False, + default='plan/data.yaml', + type=ClickPath(exists=True)) +@option('-a', + '--aggregator_address', + required=False, help='The FQDN of the federation agregator') -@option('-f', '--input_shape', cls=InputSpec, required=False, - help="The input shape to the model. May be provided as a list:\n\n" - "--input_shape [1,28,28]\n\n" - "or as a dictionary for multihead models (must be passed in quotes):\n\n" - "--input_shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n ") -@option('-g', '--gandlf_config', required=False, +@option( + '-f', + '--input_shape', + cls=InputSpec, + required=False, + help="The input shape to the model. May be provided as a list:\n\n" + "--input_shape [1,28,28]\n\n" + "or as a dictionary for multihead models (must be passed in quotes):\n\n" + "--input_shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n " +) +@option('-g', + '--gandlf_config', + required=False, help='GaNDLF Configuration File Path') def initialize(context, plan_config, cols_config, data_config, - aggregator_address, input_shape, gandlf_config): - """ - Initialize Data Science plan. - - Create a protocol buffer file of the initial model weights for - the federation. + aggregator_address, feature_shape, gandlf_config): + """Initialize Data Science plan. + + Create a protocol buffer file of the initial model weights for the + federation. + + Args: + context (click.core.Context): Click context. + plan_config (str): Federated learning plan. + cols_config (str): Authorized collaborator list. + data_config (str): The data set/shard configuration file. + aggregator_address (str): The FQDN of the federation aggregator. + feature_shape (str): The input shape to the model. + gandlf_config (str): GaNDLF Configuration File Path. 
""" from pathlib import Path @@ -91,10 +121,8 @@ def initialize(context, plan_config, cols_config, data_config, tensor_pipe = plan.get_tensor_pipe() tensor_dict, holdout_params = split_tensor_dict_for_holdouts( - logger, - task_runner.get_tensor_dict(False), - **task_runner.tensor_dict_split_fn_kwargs - ) + logger, task_runner.get_tensor_dict(False), + **task_runner.tensor_dict_split_fn_kwargs) logger.warn(f'Following parameters omitted from global initial model, ' f'local initialization will determine' @@ -114,10 +142,12 @@ def initialize(context, plan_config, cols_config, data_config, if (plan_origin.config['network']['settings']['agg_addr'] == 'auto' or aggregator_address): - plan_origin.config['network']['settings']['agg_addr'] = aggregator_address or getfqdn_env() + plan_origin.config['network']['settings'][ + 'agg_addr'] = aggregator_address or getfqdn_env() - logger.warn(f'Patching Aggregator Addr in Plan' - f" 🠆 {plan_origin.config['network']['settings']['agg_addr']}") + logger.warn( + f'Patching Aggregator Addr in Plan' + f" 🠆 {plan_origin.config['network']['settings']['agg_addr']}") Plan.dump(plan_config, plan_origin.config) @@ -133,7 +163,11 @@ def initialize(context, plan_config, cols_config, data_config, # TODO: looks like Plan.method def freeze_plan(plan_config): - """Dump the plan to YAML file.""" + """Dump the plan to YAML file. + + Args: + plan_config (str): Federated learning plan. + """ from pathlib import Path from openfl.federated import Plan @@ -152,15 +186,20 @@ def freeze_plan(plan_config): @plan.command(name='freeze') -@option('-p', '--plan_config', required=False, +@option('-p', + '--plan_config', + required=False, help='Federated learning plan [plan/plan.yaml]', - default='plan/plan.yaml', type=ClickPath(exists=True)) + default='plan/plan.yaml', + type=ClickPath(exists=True)) def freeze(plan_config): - """ - Finalize the Data Science plan. + """Finalize the Data Science plan. Create a new plan file that embeds its hash in the file name - (plan.yaml -> plan_{hash}.yaml) and changes the permissions to read only + (plan.yaml -> plan_{hash}.yaml) and changes the permissions to read only. + + Args: + plan_config (str): Federated learning plan. """ if is_directory_traversal(plan_config): echo('Plan config path is out of the openfl workspace scope.') @@ -169,7 +208,11 @@ def freeze(plan_config): def switch_plan(name): - """Switch the FL plan to this one.""" + """Switch the FL plan to this one. + + Args: + name (str): Name of the Federated learning plan. + """ from shutil import copyfile from os.path import isfile @@ -205,20 +248,34 @@ def switch_plan(name): @plan.command(name='switch') -@option('-n', '--name', required=False, +@option('-n', + '--name', + required=False, help='Name of the Federated learning plan', - default='default', type=str) + default='default', + type=str) def switch_(name): - """Switch the current plan to this plan.""" + """Switch the current plan to this plan. + + Args: + name (str): Name of the Federated learning plan. + """ switch_plan(name) @plan.command(name='save') -@option('-n', '--name', required=False, +@option('-n', + '--name', + required=False, help='Name of the Federated learning plan', - default='default', type=str) + default='default', + type=str) def save_(name): - """Save the current plan to this plan and switch.""" + """Save the current plan to this plan and switch. + + Args: + name (str): Name of the Federated learning plan. 
+ """ from os import makedirs from shutil import copyfile @@ -233,11 +290,18 @@ def save_(name): @plan.command(name='remove') -@option('-n', '--name', required=False, +@option('-n', + '--name', + required=False, help='Name of the Federated learning plan', - default='default', type=str) + default='default', + type=str) def remove_(name): - """Remove this plan.""" + """Remove this plan. + + Args: + name (str): Name of the Federated learning plan. + """ from shutil import rmtree if name != 'default': diff --git a/openfl/interface/tutorial.py b/openfl/interface/tutorial.py index 85a0cb3f66..2537e2b221 100644 --- a/openfl/interface/tutorial.py +++ b/openfl/interface/tutorial.py @@ -17,17 +17,32 @@ @group() @pass_context def tutorial(context): - """Manage Jupyter notebooks.""" + """Manage Jupyter notebooks. + + Args: + context (click.core.Context): Click context. + """ context.obj['group'] = 'tutorial' @tutorial.command() -@option('-ip', '--ip', required=False, type=click_types.IP_ADDRESS, +@option('-ip', + '--ip', + required=False, + type=click_types.IP_ADDRESS, help='IP address the Jupyter Lab that should start') -@option('-port', '--port', required=False, type=IntRange(1, 65535), +@option('-port', + '--port', + required=False, + type=IntRange(1, 65535), help='The port the Jupyter Lab server will listen on') def start(ip, port): - """Start the Jupyter Lab from the tutorials directory.""" + """Start the Jupyter Lab from the tutorials directory. + + Args: + ip (str): IP address the Jupyter Lab that should start. + port (int): The port the Jupyter Lab server will listen on. + """ from os import environ from os import sep from subprocess import check_call # nosec @@ -38,9 +53,10 @@ def start(ip, port): if 'VIRTUAL_ENV' in environ: venv = environ['VIRTUAL_ENV'].split(sep)[-1] check_call([ - executable, '-m', 'ipykernel', 'install', - '--user', '--name', f'{venv}' - ], shell=False) + executable, '-m', 'ipykernel', 'install', '--user', '--name', + f'{venv}' + ], + shell=False) jupyter_command = ['jupyter', 'lab', '--notebook-dir', f'{TUTORIALS}'] diff --git a/openfl/interface/workspace.py b/openfl/interface/workspace.py index 31b5ff647d..83e6f6c0e1 100644 --- a/openfl/interface/workspace.py +++ b/openfl/interface/workspace.py @@ -20,12 +20,23 @@ @group() @pass_context def workspace(context): - """Manage Federated Learning Workspaces.""" + """Manage Federated Learning Workspaces. + + Args: + context: The context in which the command is being invoked. + """ context.obj['group'] = 'workspace' def is_directory_traversal(directory: Union[str, Path]) -> bool: - """Check for directory traversal.""" + """Check for directory traversal. + + Args: + directory (Union[str, Path]): The directory to check. + + Returns: + bool: True if directory traversal is detected, False otherwise. + """ cwd = os.path.abspath(os.getcwd()) requested_path = os.path.relpath(directory, start=cwd) requested_path = os.path.abspath(requested_path) @@ -34,7 +45,11 @@ def is_directory_traversal(directory: Union[str, Path]) -> bool: def create_dirs(prefix): - """Create workspace directories.""" + """Create workspace directories. + + Args: + prefix: The prefix for the directories to be created. 
+ """ from shutil import copyfile from openfl.interface.cli_helper import WORKSPACE @@ -44,14 +59,20 @@ def create_dirs(prefix): (prefix / 'cert').mkdir(parents=True, exist_ok=True) # certifications (prefix / 'data').mkdir(parents=True, exist_ok=True) # training data (prefix / 'logs').mkdir(parents=True, exist_ok=True) # training logs - (prefix / 'save').mkdir(parents=True, exist_ok=True) # model weight saves / initialization + (prefix / 'save').mkdir( + parents=True, exist_ok=True) # model weight saves / initialization (prefix / 'src').mkdir(parents=True, exist_ok=True) # model code copyfile(WORKSPACE / 'workspace' / '.workspace', prefix / '.workspace') def create_temp(prefix, template): - """Create workspace templates.""" + """Create workspace templates. + + Args: + prefix: The prefix for the directories to be created. + template: The template to use for creating the workspace. + """ from shutil import ignore_patterns from openfl.interface.cli_helper import copytree @@ -59,24 +80,39 @@ def create_temp(prefix, template): echo('Creating Workspace Templates') - copytree(src=WORKSPACE / template, dst=prefix, dirs_exist_ok=True, + copytree(src=WORKSPACE / template, + dst=prefix, + dirs_exist_ok=True, ignore=ignore_patterns('__pycache__')) # from template workspace def get_templates(): - """Grab the default templates from the distribution.""" + """Grab the default templates from the distribution. + + Returns: + list: A list of default templates. + """ from openfl.interface.cli_helper import WORKSPACE - return [d.name for d in WORKSPACE.glob('*') if d.is_dir() - and d.name not in ['__pycache__', 'workspace', 'experimental']] + return [ + d.name for d in WORKSPACE.glob('*') if d.is_dir() + and d.name not in ['__pycache__', 'workspace', 'experimental'] + ] @workspace.command(name='create') -@option('--prefix', required=True, - help='Workspace name or path', type=ClickPath()) +@option('--prefix', + required=True, + help='Workspace name or path', + type=ClickPath()) @option('--template', required=True, type=Choice(get_templates())) def create_(prefix, template): - """Create the workspace.""" + """Create the workspace. + + Args: + prefix: The prefix for the directories to be created. + template: The template to use for creating the workspace. + """ if is_directory_traversal(prefix): echo('Workspace name or path is out of the openfl workspace scope.') sys.exit(1) @@ -84,7 +120,12 @@ def create_(prefix, template): def create(prefix, template): - """Create federated learning workspace.""" + """Create federated learning workspace. + + Args: + prefix: The prefix for the directories to be created. + template: The template to use for creating the workspace. + """ from os.path import isfile from subprocess import check_call # nosec from sys import executable @@ -105,12 +146,17 @@ def create(prefix, template): if isfile(f'{str(prefix)}/{requirements_filename}'): check_call([ executable, '-m', 'pip', 'install', '-r', - f'{prefix}/requirements.txt'], shell=False) - echo(f'Successfully installed packages from {prefix}/requirements.txt.') + f'{prefix}/requirements.txt' + ], + shell=False) + echo( + f'Successfully installed packages from {prefix}/requirements.txt.') else: echo('No additional requirements for workspace defined. 
Skipping...') prefix_hash = _get_dir_hash(str(prefix.absolute())) - with open(OPENFL_USERDIR / f'requirements.{prefix_hash}.txt', 'w', encoding='utf-8') as f: + with open(OPENFL_USERDIR / f'requirements.{prefix_hash}.txt', + 'w', + encoding='utf-8') as f: check_call([executable, '-m', 'pip', 'freeze'], shell=False, stdout=f) apply_template_plan(prefix, template) @@ -119,13 +165,22 @@ def create(prefix, template): @workspace.command(name='export') -@option('-o', '--pip-install-options', required=False, - type=str, multiple=True, default=tuple, - help='Options for remote pip install. ' - 'You may pass several options in quotation marks alongside with arguments, ' - 'e.g. -o "--find-links source.site"') +@option( + '-o', + '--pip-install-options', + required=False, + type=str, + multiple=True, + default=tuple, + help='Options for remote pip install. ' + 'You may pass several options in quotation marks alongside with arguments, ' + 'e.g. -o "--find-links source.site"') def export_(pip_install_options: Tuple[str]): - """Export federated learning workspace.""" + """Export federated learning workspace. + + Args: + pip_install_options (Tuple[str]): Options for remote pip install. + """ from os import getcwd from os import makedirs from os.path import isfile @@ -154,8 +209,8 @@ def export_(pip_install_options: Tuple[str]): # Aggregator workspace tmp_dir = join(mkdtemp(), 'openfl', archive_name) - ignore = ignore_patterns( - '__pycache__', '*.crt', '*.key', '*.csr', '*.srl', '*.pem', '*.pbuf') + ignore = ignore_patterns('__pycache__', '*.crt', '*.key', '*.csr', '*.srl', + '*.pem', '*.pbuf') # We only export the minimum required files to set up a collaborator makedirs(f'{tmp_dir}/save', exist_ok=True) @@ -164,7 +219,8 @@ def export_(pip_install_options: Tuple[str]): copytree('./src', f'{tmp_dir}/src', ignore=ignore) # code copytree('./plan', f'{tmp_dir}/plan', ignore=ignore) # plan if isfile('./requirements.txt'): - copy2('./requirements.txt', f'{tmp_dir}/requirements.txt') # requirements + copy2('./requirements.txt', + f'{tmp_dir}/requirements.txt') # requirements else: echo('No requirements.txt file found.') @@ -187,11 +243,16 @@ def export_(pip_install_options: Tuple[str]): @workspace.command(name='import') -@option('--archive', required=True, +@option('--archive', + required=True, help='Zip file containing workspace to import', type=ClickPath(exists=True)) def import_(archive): - """Import federated learning workspace.""" + """Import federated learning workspace. + + Args: + archive: The archive file containing the workspace to import. + """ from os import chdir from os.path import basename from os.path import isfile @@ -208,11 +269,10 @@ def import_(archive): requirements_filename = 'requirements.txt' if isfile(requirements_filename): - check_call([ - executable, '-m', 'pip', 'install', '--upgrade', 'pip'], - shell=False) - check_call([ - executable, '-m', 'pip', 'install', '-r', 'requirements.txt'], + check_call([executable, '-m', 'pip', 'install', '--upgrade', 'pip'], + shell=False) + check_call( + [executable, '-m', 'pip', 'install', '-r', 'requirements.txt'], shell=False) else: echo('No ' + requirements_filename + ' file found.') @@ -241,20 +301,28 @@ def certify(): echo('1. 
Create Root CA') echo('1.1 Create Directories') - (CERT_DIR / 'ca/root-ca/private').mkdir( - parents=True, exist_ok=True, mode=0o700) + (CERT_DIR / 'ca/root-ca/private').mkdir(parents=True, + exist_ok=True, + mode=0o700) (CERT_DIR / 'ca/root-ca/db').mkdir(parents=True, exist_ok=True) echo('1.2 Create Database') - with open(CERT_DIR / 'ca/root-ca/db/root-ca.db', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/root-ca/db/root-ca.db', 'w', + encoding='utf-8') as f: pass # write empty file - with open(CERT_DIR / 'ca/root-ca/db/root-ca.db.attr', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/root-ca/db/root-ca.db.attr', + 'w', + encoding='utf-8') as f: pass # write empty file - with open(CERT_DIR / 'ca/root-ca/db/root-ca.crt.srl', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/root-ca/db/root-ca.crt.srl', + 'w', + encoding='utf-8') as f: f.write('01') # write file with '01' - with open(CERT_DIR / 'ca/root-ca/db/root-ca.crl.srl', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/root-ca/db/root-ca.crl.srl', + 'w', + encoding='utf-8') as f: f.write('01') # write file with '01' echo('1.3 Create CA Request and Certificate') @@ -266,34 +334,41 @@ def certify(): # Write root CA certificate to disk with open(CERT_DIR / root_crt_path, 'wb') as f: - f.write(root_cert.public_bytes( - encoding=serialization.Encoding.PEM, - )) + f.write(root_cert.public_bytes(encoding=serialization.Encoding.PEM, )) with open(CERT_DIR / root_key_path, 'wb') as f: - f.write(root_private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) + f.write( + root_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption())) echo('2. 
Create Signing Certificate') echo('2.1 Create Directories') - (CERT_DIR / 'ca/signing-ca/private').mkdir( - parents=True, exist_ok=True, mode=0o700) + (CERT_DIR / 'ca/signing-ca/private').mkdir(parents=True, + exist_ok=True, + mode=0o700) (CERT_DIR / 'ca/signing-ca/db').mkdir(parents=True, exist_ok=True) echo('2.2 Create Database') - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.db', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.db', + 'w', + encoding='utf-8') as f: pass # write empty file - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.db.attr', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.db.attr', + 'w', + encoding='utf-8') as f: pass # write empty file - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.crt.srl', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.crt.srl', + 'w', + encoding='utf-8') as f: f.write('01') # write file with '01' - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.crl.srl', 'w', encoding='utf-8') as f: + with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.crl.srl', + 'w', + encoding='utf-8') as f: f.write('01') # write file with '01' echo('2.3 Create Signing Certificate CSR') @@ -307,24 +382,25 @@ def certify(): # Write Signing CA CSR to disk with open(CERT_DIR / signing_csr_path, 'wb') as f: f.write(signing_csr.public_bytes( - encoding=serialization.Encoding.PEM, - )) + encoding=serialization.Encoding.PEM, )) with open(CERT_DIR / signing_key_path, 'wb') as f: - f.write(signing_private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) + f.write( + signing_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption())) echo('2.4 Sign Signing Certificate CSR') - signing_cert = sign_certificate(signing_csr, root_private_key, root_cert.subject, ca=True) + signing_cert = sign_certificate(signing_csr, + root_private_key, + root_cert.subject, + ca=True) with open(CERT_DIR / signing_crt_path, 'wb') as f: - f.write(signing_cert.public_bytes( - encoding=serialization.Encoding.PEM, - )) + f.write( + signing_cert.public_bytes(encoding=serialization.Encoding.PEM, )) echo('3 Create Certificate Chain') @@ -339,6 +415,14 @@ def certify(): def _get_requirements_dict(txtfile): + """Get requirements from a text file. + + Args: + txtfile (str): The text file containing the requirements. + + Returns: + snapshot_dict (dict): A dictionary containing the requirements. + """ with open(txtfile, 'r', encoding='utf-8') as snapshot: snapshot_dict = {} for line in snapshot: @@ -352,6 +436,14 @@ def _get_requirements_dict(txtfile): def _get_dir_hash(path): + """Get the hash of a directory. + + Args: + path (str): The path of the directory. + + Returns: + str: The hash of the directory. + """ from hashlib import sha256 hash_ = sha256() hash_.update(path.encode('utf-8')) @@ -360,7 +452,9 @@ def _get_dir_hash(path): @workspace.command(name='dockerize') -@option('-b', '--base_image', required=False, +@option('-b', + '--base_image', + required=False, help='The tag for openfl base image', default='openfl') @option('--save/--no-save', @@ -369,14 +463,19 @@ def _get_dir_hash(path): default=True) @pass_context def dockerize_(context, base_image, save): - """ - Pack the workspace as a Docker image. + """Pack the workspace as a Docker image. 
This command is the alternative to `workspace export`. It should be called after plan initialization from the workspace dir. User is expected to be in docker group. - If your machine is behind a proxy, make sure you set it up in ~/.docker/config.json. + If your machine is behind a proxy, make sure you set it up in + ~/.docker/config.json. + + Args: + context: The context in which the command is being invoked. + base_image (str): The tag for openfl base image. + save (bool): Whether to save the Docker image into the workspace. """ import docker import sys @@ -389,7 +488,8 @@ def dockerize_(context, base_image, save): dockerfile_workspace = 'Dockerfile.workspace' # Apparently, docker's python package does not support # scenarios when the dockerfile is placed outside the build context - copyfile(os.path.join(openfl_docker_dir, dockerfile_workspace), dockerfile_workspace) + copyfile(os.path.join(openfl_docker_dir, dockerfile_workspace), + dockerfile_workspace) workspace_path = os.getcwd() workspace_name = os.path.basename(workspace_path) @@ -398,22 +498,17 @@ def dockerize_(context, base_image, save): context.invoke(export_) workspace_archive = workspace_name + '.zip' - build_args = { - 'WORKSPACE_NAME': workspace_name, - 'BASE_IMAGE': base_image - } + build_args = {'WORKSPACE_NAME': workspace_name, 'BASE_IMAGE': base_image} cli = docker.APIClient() echo('Building the Docker image') try: - for line in cli.build( - path=str(workspace_path), - tag=workspace_name, - buildargs=build_args, - dockerfile=dockerfile_workspace, - timeout=3600, - decode=True - ): + for line in cli.build(path=str(workspace_path), + tag=workspace_name, + buildargs=build_args, + dockerfile=dockerfile_workspace, + timeout=3600, + decode=True): if 'stream' in line: print(f'> {line["stream"]}', end='') elif 'error' in line: @@ -435,62 +530,95 @@ def dockerize_(context, base_image, save): with open(workspace_image_tar, 'wb') as f: for chunk in resp: f.write(chunk) - echo(f'{workspace_name} image saved to {workspace_path}/{workspace_image_tar}') + echo( + f'{workspace_name} image saved to {workspace_path}/{workspace_image_tar}' + ) @workspace.command(name='graminize') -@option('-s', '--signing-key', required=False, - type=lambda p: Path(p).absolute(), default='/', - help='A 3072-bit RSA private key (PEM format) is required for signing the manifest.\n' - 'If a key is passed the gramine-sgx manifest fill be prepared.\n' - 'In option is ignored this command will build an image that can only run ' - 'with gramine-direct (not in enclave).', - ) -@option('-e', '--enclave_size', required=False, - type=str, default='16G', +@option( + '-s', + '--signing-key', + required=False, + type=lambda p: Path(p).absolute(), + default='/', + help= + 'A 3072-bit RSA private key (PEM format) is required for signing the manifest.\n' + 'If a key is passed the gramine-sgx manifest fill be prepared.\n' + 'In option is ignored this command will build an image that can only run ' + 'with gramine-direct (not in enclave).', +) +@option('-e', + '--enclave_size', + required=False, + type=str, + default='16G', help='Memory size of the enclave, defined as number with size suffix. ' - 'Must be a power-of-2.\n' - 'Default is 16G.' - ) -@option('-t', '--tag', required=False, - type=str, multiple=False, default='', + 'Must be a power-of-2.\n' + 'Default is 16G.') +@option('-t', + '--tag', + required=False, + type=str, + multiple=False, + default='', help='Tag of the built image.\n' - 'By default, the workspace name is used.' 
- ) -@option('-o', '--pip-install-options', required=False, - type=str, multiple=True, default=tuple, - help='Options for remote pip install. ' - 'You may pass several options in quotation marks alongside with arguments, ' - 'e.g. -o "--find-links source.site"') -@option('--save/--no-save', required=False, - default=True, type=bool, + 'By default, the workspace name is used.') +@option( + '-o', + '--pip-install-options', + required=False, + type=str, + multiple=True, + default=tuple, + help='Options for remote pip install. ' + 'You may pass several options in quotation marks alongside with arguments, ' + 'e.g. -o "--find-links source.site"') +@option('--save/--no-save', + required=False, + default=True, + type=bool, help='Dump the Docker image to an archive') @option('--rebuild', help='Build images with `--no-cache`', is_flag=True) @pass_context def graminize_(context, signing_key: Path, enclave_size: str, tag: str, - pip_install_options: Tuple[str], save: bool, rebuild: bool) -> None: - """ - Build gramine app inside a docker image. + pip_install_options: Tuple[str], save: bool, + rebuild: bool) -> None: + """Build gramine app inside a docker image. This command is the alternative to `workspace export`. It should be called after `fx plan initialize` inside the workspace dir. User is expected to be in docker group. - If your machine is behind a proxy, make sure you set it up in ~/.docker/config.json. + If your machine is behind a proxy, make sure you set it up in + ~/.docker/config.json. TODO: 1. gramine-direct, check if a key is provided 2. make a standalone function with `export` parametr + + Args: + context: The context in which the command is being invoked. + signing_key (Path): A 3072-bit RSA private key (PEM format) is + required for signing the manifest. + enclave_size (str): Memory size of the enclave, defined as number with + size suffix. + tag (str): Tag of the built image. + pip_install_options (Tuple[str]): Options for remote pip install. + save (bool): Whether to dump the Docker image to an archive. + rebuild (bool): Whether to build images with `--no-cache`. """ + def open_pipe(command: str): echo(f'\n 📦 Executing command:\n{command}\n') - process = subprocess.Popen( - command, - shell=True, stderr=subprocess.STDOUT, - stdout=subprocess.PIPE) + process = subprocess.Popen(command, + shell=True, + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE) for line in process.stdout: echo(line) - _ = process.communicate() # pipe is already empty, used to get `returncode` + _ = process.communicate( + ) # pipe is already empty, used to get `returncode` if process.returncode != 0: raise Exception('\n ❌ Execution failed\n') @@ -549,6 +677,10 @@ def apply_template_plan(prefix, template): This function unfolds default values from template plan configuration and writes the configuration to the current workspace. + + Args: + prefix: The prefix for the directories to be created. + template: The template to use for creating the workspace. """ from openfl.federated.plan import Plan from openfl.interface.cli_helper import WORKSPACE diff --git a/openfl/native/fastestimator.py b/openfl/native/fastestimator.py index e2d659563a..cea24c7feb 100644 --- a/openfl/native/fastestimator.py +++ b/openfl/native/fastestimator.py @@ -15,10 +15,24 @@ class FederatedFastEstimator: - """A wrapper for fastestimator.estimator that allows running in federated mode.""" + """A wrapper for fastestimator.estimator that allows running in federated + mode. + + Attributes: + estimator: The FastEstimator to be used. 
+ logger: A logger to record events. + rounds: The number of rounds to train. + """ def __init__(self, estimator, override_config: dict = None, **kwargs): - """Initialize.""" + """Initializes a new instance of the FederatedFastEstimator class. + + Args: + estimator: The FastEstimator to be used. + override_config (dict, optional): A dictionary to override the + default configuration. Defaults to None. + **kwargs: Additional keyword arguments. + """ self.estimator = estimator self.logger = getLogger(__name__) fx.init(**kwargs) @@ -26,7 +40,7 @@ def __init__(self, estimator, override_config: dict = None, **kwargs): fx.update_plan(override_config) def fit(self): - """Run the estimator.""" + """Runs the estimator in federated mode.""" import fastestimator as fe from fastestimator.trace.io.best_model_saver import BestModelSaver from sys import path @@ -52,8 +66,8 @@ def fit(self): self.rounds = plan.config['aggregator']['settings']['rounds_to_train'] data_loader = FastEstimatorDataLoader(self.estimator.pipeline) - runner = FastEstimatorTaskRunner( - self.estimator, data_loader=data_loader) + runner = FastEstimatorTaskRunner(self.estimator, + data_loader=data_loader) # Overwrite plan values tensor_pipe = plan.get_tensor_pipe() # Initialize model weights @@ -76,7 +90,8 @@ def fit(self): aggregator = plan.get_aggregator() model_states = { - collaborator: None for collaborator in plan.authorized_cols + collaborator: None + for collaborator in plan.authorized_cols } runners = {} save_dir = {} @@ -84,12 +99,14 @@ def fit(self): for col in plan.authorized_cols: data = self.estimator.pipeline.data train_data, eval_data, test_data = split_data( - data['train'], data['eval'], data['test'], - data_path, len(plan.authorized_cols)) + data['train'], data['eval'], data['test'], data_path, + len(plan.authorized_cols)) pipeline_kwargs = {} for k, v in self.estimator.pipeline.__dict__.items(): - if k in ['batch_size', 'ops', 'num_process', - 'drop_last', 'pad_value', 'collate_fn']: + if k in [ + 'batch_size', 'ops', 'num_process', 'drop_last', + 'pad_value', 'collate_fn' + ]: pipeline_kwargs[k] = v pipeline_kwargs.update({ 'train_data': train_data, @@ -101,8 +118,8 @@ def fit(self): data_loader = FastEstimatorDataLoader(pipeline) self.estimator.system.pipeline = pipeline - runners[col] = FastEstimatorTaskRunner( - estimator=self.estimator, data_loader=data_loader) + runners[col] = FastEstimatorTaskRunner(estimator=self.estimator, + data_loader=data_loader) runners[col].set_optimizer_treatment('CONTINUE_LOCAL') for trace in runners[col].estimator.system.traces: @@ -114,9 +131,12 @@ def fit(self): data_path += 1 # Create the collaborators - collaborators = {collaborator: fx.create_collaborator( - plan, collaborator, runners[collaborator], aggregator) - for collaborator in plan.authorized_cols} + collaborators = { + collaborator: + fx.create_collaborator(plan, collaborator, runners[collaborator], + aggregator) + for collaborator in plan.authorized_cols + } model = None for round_num in range(self.rounds): @@ -154,7 +174,19 @@ def fit(self): def split_data(train, eva, test, rank, collaborator_count): - """Split data into N parts, where N is the collaborator count.""" + """Split data into N parts, where N is the collaborator count. + + Args: + train : The training data. + eva : The evaluation data. + test : The testing data. + rank (int): The rank of the current collaborator. + collaborator_count (int): The total number of collaborators. 
+ + Returns: + tuple: The training, evaluation, and testing data for the current + collaborator. + """ if collaborator_count == 1: return train, eva, test diff --git a/openfl/native/native.py b/openfl/native/native.py index 1842ec6412..dcc1ce04d9 100644 --- a/openfl/native/native.py +++ b/openfl/native/native.py @@ -1,8 +1,9 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""openfl Native functions module. +"""Openfl Native functions module. -This file defines openfl entrypoints to be used directly through python (not CLI) +This file defines openfl entrypoints to be used directly through python (not +CLI) """ import logging @@ -27,15 +28,15 @@ def setup_plan(log_level='CRITICAL'): - """ - Dump the plan with all defaults + overrides set. + """Dump the plan with all defaults and overrides set. Args: - save : bool (default=True) - Whether to save the plan to disk + log_level (str, optional): The log level to use while setting up + the plan. + Defaults to 'CRITICAL'. Returns: - plan : Plan object + plan: Plan object. """ plan_config = 'plan/plan.yaml' cols_config = 'plan/cols.yaml' @@ -53,12 +54,22 @@ def setup_plan(log_level='CRITICAL'): def flatten(config, return_complete=False): - """Flatten nested config.""" + """Flatten nested config. + + Args: + config (dict): The configuration dictionary to flatten. + return_complete (bool, optional): Whether to return the complete + flattened config. Defaults to False. + + Returns: + flattened_config (dict): The flattened configuration dictionary. + """ flattened_config = flatten_json.flatten(config, '.') if not return_complete: keys_to_remove = [ k for k, v in flattened_config.items() - if ('defaults' in k or v is None)] + if ('defaults' in k or v is None) + ] else: keys_to_remove = [k for k, v in flattened_config.items() if v is None] for k in keys_to_remove: @@ -68,16 +79,19 @@ def flatten(config, return_complete=False): def update_plan(override_config, plan=None, resolve=True): - """ - Update the plan with the provided override and save it to disk. + """Updates the plan with the provided override and saves it to disk. For a list of available override options, call `fx.get_plan()` Args: - override_config : dict {"COMPONENT.settings.variable" : value or list of values} + override_config (dict): A dictionary of values to override in the plan. + plan (Plan, optional): The plan to update. If None, a new plan is set + up. Defaults to None. + resolve (bool, optional): Whether to resolve the plan. Defaults to + True. Returns: - None + plan (object): The updated plan. """ if plan is None: plan = setup_plan() @@ -102,7 +116,9 @@ def update_plan(override_config, plan=None, resolve=True): logger.info(f'Updating {key} to {val}... ') else: # TODO: We probably need to validate the new key somehow - logger.info(f'Did not find {key} in config. Make sure it should exist. Creating...') + logger.info( + f'Did not find {key} in config. Make sure it should exist. Creating...' + ) if type(val) is list: for idx, v in enumerate(val): flat_plan_config[f'{key}.{idx}'] = v @@ -116,13 +132,28 @@ def update_plan(override_config, plan=None, resolve=True): def unflatten(config, separator='.'): - """Unfold `config` settings that have `separator` in their names.""" + """Unfolds `config` settings that have `separator` in their names. + + Args: + config (dict): The flattened configuration dictionary to unfold. + separator (str, optional): The separator used in the flattened config. + Defaults to '.'.
+ + Returns: + config (dict): The unfolded configuration dictionary. + """ config = flatten_json.unflatten_list(config, separator) return config def setup_logging(level='INFO', log_file=None): - """Initialize logging settings.""" + """Initializes logging settings. + + Args: + level (str, optional): The log level. Defaults to 'INFO'. + log_file (str, optional): The name of the file to log to. + If None, logs are not saved to a file. Defaults to None. + """ # Setup logging from logging import basicConfig from rich.console import Console @@ -141,52 +172,53 @@ def setup_logging(level='INFO', log_file=None): if log_file: fh = logging.FileHandler(log_file) formatter = logging.Formatter( - '%(asctime)s %(levelname)s %(message)s %(filename)s:%(lineno)d' - ) + '%(asctime)s %(levelname)s %(message)s %(filename)s:%(lineno)d') fh.setFormatter(formatter) handlers.append(fh) console = Console(width=160) handlers.append(RichHandler(console=console)) - basicConfig(level=level, format='%(message)s', - datefmt='[%X]', handlers=handlers) + basicConfig(level=level, + format='%(message)s', + datefmt='[%X]', + handlers=handlers) -def init(workspace_template: str = 'default', log_level: str = 'INFO', - log_file: str = None, agg_fqdn: str = None, col_names=None): - """ - Initialize the openfl package. +def init(workspace_template: str = 'default', + log_level: str = 'INFO', + log_file: str = None, + agg_fqdn: str = None, + col_names=None): + """Initializes the openfl package. It performs the following tasks: - 1. Creates a workspace in ~/.local/workspace (Equivalent to `fx - workspace create --prefix ~/.local/workspace --template - $workspace_template) - 2. Setup certificate authority (equivalent to `fx workspace certify`) - 3. Setup aggregator PKI (equivalent to `fx aggregator - generate-cert-request` followed by `fx aggregator certify`) - 4. Setup list of collaborators (col_names) and their PKI. (Equivalent - to running `fx collaborator generate-cert-request` followed by `fx - collaborator certify` for each of the collaborators in col_names) - 5. Setup logging + 1. Creates a workspace in ~/.local/workspace (Equivalent to `fx + workspace create --prefix ~/.local/workspace --template + $workspace_template) + 2. Setup certificate authority (equivalent to `fx workspace certify`) + 3. Setup aggregator PKI (equivalent to `fx aggregator + generate-cert-request` followed by `fx aggregator certify`) + 4. Setup list of collaborators (col_names) and their PKI. (Equivalent + to running `fx collaborator generate-cert-request` followed by `fx + collaborator certify` for each of the collaborators in col_names) + 5. Setup logging Args: - workspace_template : str (default='default') - The template that should be used as the basis for the experiment. - Other options include are any of the template names [ - keras_cnn_mnist, tf_2dunet, tf_cnn_histology, mtorch_cnn_histology, - torch_cnn_mnist] - log_level : str - Log level for logging. METRIC level is available - log_file : str - Name of the file in which the log will be duplicated - agg_fqdn : str - The local node's fully qualified domain name (if it can't be - resolved automatically) - col_names: list[str] - The names of the collaborators that will be created. These - collaborators will be set up to participate in the experiment, but - are not required to + workspace_template (str): The template that should be used as the + basis for the experiment. Defaults to 'default'. 
+ Other options include are any of the template names + [keras_cnn_mnist, tf_2dunet, tf_cnn_histology, + mtorch_cnn_histology, torch_cnn_mnist]. + log_level (str): Log level for logging. METRIC level is available. + Defaults to 'INFO'. + log_file (str): Name of the file in which the log will be duplicated. + If None, logs are not saved to a file. Defaults to None. + agg_fqdn (str): The local node's fully qualified domain name (if it + can't be resolved automatically). Defaults to None. + col_names (list[str]): The names of the collaborators that will be + created. These collaborators will be set up to participate in the + experiment, but are not required to. Defaults to None. Returns: None @@ -200,10 +232,10 @@ def init(workspace_template: str = 'default', log_level: str = 'INFO', aggregator.certify(agg_fqdn, silent=True) data_path = 1 for col_name in col_names: - collaborator.create( - col_name, str(data_path), silent=True) - collaborator.generate_cert_request( - col_name, silent=True, skip_package=True) + collaborator.create(col_name, str(data_path), silent=True) + collaborator.generate_cert_request(col_name, + silent=True, + skip_package=True) collaborator.certify(col_name, silent=True) data_path += 1 @@ -211,12 +243,20 @@ def init(workspace_template: str = 'default', log_level: str = 'INFO', def get_collaborator(plan, name, model, aggregator): - """ - Create the collaborator. + """Create the collaborator. Using the same plan object to create multiple collaborators leads to identical collaborator objects. This function can be removed once collaborator generation is fixed in openfl/federated/plan/plan.py + + Args: + plan (Plan): The plan to use to create the collaborator. + name (str): The name of the collaborator. + model (Model): The model to use for the collaborator. + aggregator (Aggregator): The aggregator to use for the collaborator. + + Returns: + Collaborator: The created collaborator. """ plan = copy(plan) @@ -224,22 +264,25 @@ def get_collaborator(plan, name, model, aggregator): def run_experiment(collaborator_dict: dict, override_config: dict = None): - """ - Core function that executes the FL Plan. + """Core function that executes the FL Plan. Args: - collaborator_dict : dict {collaborator_name(str): FederatedModel} - This dictionary defines which collaborators will participate in the - experiment, as well as a reference to that collaborator's + collaborator_dict (dict): A dictionary mapping collaborator names to + their federated models. + Example: {collaborator_name(str): FederatedModel} + This dictionary defines which collaborators will participate in + the experiment, as well as a reference to that collaborator's federated model. - override_config : dict {flplan.key : flplan.value} + override_config (dict, optional): A dictionary of values to override + in the plan. Defaults to None. + Example: dict {flplan.key : flplan.value} Override any of the plan parameters at runtime using this dictionary. To get a list of the available options, execute `fx.get_plan()` Returns: - final_federated_model : FederatedModel - The final model resulting from the federated learning experiment + model: Final Federated model. 
The model resulting from the federated + learning experiment """ from sys import path @@ -268,9 +311,7 @@ def run_experiment(collaborator_dict: dict, override_config: dict = None): init_state_path = plan.config['aggregator']['settings']['init_state_path'] rounds_to_train = plan.config['aggregator']['settings']['rounds_to_train'] tensor_dict, holdout_params = split_tensor_dict_for_holdouts( - logger, - model.get_tensor_dict(False) - ) + logger, model.get_tensor_dict(False)) model_snap = utils.construct_model_proto(tensor_dict=tensor_dict, round_number=0, @@ -286,9 +327,10 @@ def run_experiment(collaborator_dict: dict, override_config: dict = None): # get the collaborators collaborators = { - collaborator: get_collaborator( - plan, collaborator, collaborator_dict[collaborator], aggregator - ) for collaborator in plan.authorized_cols + collaborator: + get_collaborator(plan, collaborator, collaborator_dict[collaborator], + aggregator) + for collaborator in plan.authorized_cols } for _ in range(rounds_to_train): @@ -297,13 +339,26 @@ def run_experiment(collaborator_dict: dict, override_config: dict = None): collaborator.run_simulation() # Set the weights for the final model - model.rebuild_model( - rounds_to_train - 1, aggregator.last_tensor_dict, validation=True) + model.rebuild_model(rounds_to_train - 1, + aggregator.last_tensor_dict, + validation=True) return model def get_plan(fl_plan=None, indent=4, sort_keys=True): - """Get string representation of current Plan.""" + """Returns a string representation of the current Plan. + + Args: + fl_plan (Plan): The plan to get a string representation of. If None, a + new plan is set up. Defaults to None. + indent (int): The number of spaces to use for indentation in the + string representation. Defaults to 4. + sort_keys (bool): Whether to sort the keys in the string + representation. Defaults to True. + + Returns: + str: A string representation of the plan. + """ import json if fl_plan is None: plan = setup_plan() diff --git a/openfl/pipelines/eden_pipeline.py b/openfl/pipelines/eden_pipeline.py index e522e9c233..7317600b34 100644 --- a/openfl/pipelines/eden_pipeline.py +++ b/openfl/pipelines/eden_pipeline.py @@ -3,10 +3,11 @@ # Copyright 2022 VMware, Inc. # SPDX-License-Identifier: Apache-2.0 - """ -@author: Shay Vargaftik (VMware Research), shayv@vmware.com; vargaftik@gmail.com -@author: Yaniv Ben-Itzhak (VMware Research), ybenitzhak@vmware.com; yaniv.benizhak@gmail.com +@author: Shay Vargaftik (VMware Research), +shayv@vmware.com; vargaftik@gmail.com +@author: Yaniv Ben-Itzhak (VMware Research), +ybenitzhak@vmware.com; yaniv.benizhak@gmail.com EdenPipeline module. @@ -14,11 +15,15 @@ EDEN is an unbiased lossy compression method that uses a random rotation followed by deterministic quantization and scaling. -EDEN provides strong theoretical guarantees, as described in the following ICML 2022 paper: +EDEN provides strong theoretical guarantees, as described in the following + ICML 2022 paper: -"EDEN: Communication-Efficient and Robust Distributed Mean Estimation for Federated Learning" -Shay Vargaftik, Ran Ben Basat, Amit Portnoy, Gal Mendelson, Yaniv Ben Itzhak, Michael Mitzenmacher, -Proceedings of the 39th International Conference on Machine Learning, PMLR 162:21984-22014, 2022. 
+"EDEN: Communication-Efficient and Robust Distributed Mean Estimation for + Federated Learning" +Shay Vargaftik, Ran Ben Basat, Amit Portnoy, Gal Mendelson, Yaniv Ben Itzhak, +Michael Mitzenmacher, +Proceedings of the 39th International Conference on Machine Learning, +PMLR 162:21984-22014, 2022. https://proceedings.mlr.press/v162/vargaftik22a.html @@ -32,8 +37,8 @@ settings : n_bits : device: - dim_threshold: 1000 #EDEN compresses layers that their dimension is above the dim_threshold, - use 1000 as default + dim_threshold: 1000 #EDEN compresses layers that their dimension is above + the dim_threshold, use 1000 as default """ import torch @@ -46,10 +51,51 @@ class Eden: + """Eden class for quantization. + + This class is responsible for quantizing tensors using the Eden method. + + Attributes: + device (str): The device to be used for quantization ('cpu' or 'cuda'). + centroids (dict): A dictionary mapping the number of bits to the + corresponding centroids. + boundaries (dict): A dictionary mapping the number of bits to the + corresponding boundaries. + nbits (int): The number of bits per coordinate for quantization. + num_hadamard (int): The number of Hadamard transforms to employ. + max_padding_overhead (float): The maximum overhead that is allowed for + padding the vector. + """ def __init__(self, nbits=8, device='cpu'): + """Initialize Eden. + + Args: + nbits (int, optional): The number of bits per coordinate for + quantization. Defaults to 8. + device (str, optional): The device to be used for quantization + ('cpu' or 'cuda'). Defaults to 'cpu'. + """ def gen_normal_centroids_and_boundaries(device): + """Generates the centroids and boundaries for the quantization + process. + + This function generates the centroids and boundaries used in the + quantization process based on the specified device. The centroids + are generated for different numbers of bits, and the boundaries + are calculated based on these centroids. + + Args: + device (str): The device to be used for quantization + ('cpu' or 'cuda'). + + Returns: + tuple: A tuple containing two dictionaries. The first + dictionary maps the number of bits to the corresponding + centroids, and the second dictionary maps the number of + bits to the corresponding boundaries. + """ # half-normal lloyd-max centroids centroids = {} @@ -172,7 +218,15 @@ def gen_boundaries(centroids): # but 1 can be used for non-adversarial inputs and thus run faster def rand_diag(self, size, seed): + """Generate a random diagonal matrix. + + Args: + size (int): The size of the matrix. + seed (int): The seed for the random number generator. + Returns: + A random diagonal matrix. + """ bools_in_float32 = 8 shift = 32 // bools_in_float32 @@ -212,7 +266,14 @@ def rand_diag(self, size, seed): return res[:size] def hadamard(self, vec): + """Apply the Hadamard transform to a vector. + + Args: + vec: The vector to be transformed. + Returns: + The transformed vector. + """ d = vec.numel() if d & (d - 1) != 0: raise Exception("input numel must be a power of 2") @@ -228,24 +289,46 @@ def hadamard(self, vec): return vec.view(-1) - # randomized Hadamard transform def rht(self, vec, seed): + """Apply the randomized Hadamard transform to a vector. + + Args: + vec: The vector to be transformed. + seed (int): The seed for the random number generator. + Returns: + The transformed vector. 
+ """ vec = vec * self.rand_diag(size=vec.numel(), seed=seed) vec = self.hadamard(vec) return vec - # inverse randomized Hadamard transform def irht(self, vec, seed): + """Apply the inverse randomized Hadamard transform to a vector. + Args: + vec: The vector to be transformed. + seed (int): The seed for the random number generator. + + Returns: + The transformed vector. + """ vec = self.hadamard(vec) vec = vec * self.rand_diag(size=vec.numel(), seed=seed) return vec def quantize(self, vec): + """Quantize a vector. + + Args: + vec: The vector to be quantized. + Returns: + bins: The quantized values of the vector. + scale: The scale factor for the quantization. + """ vec_norm = torch.norm(vec, 2) if vec_norm > 0: @@ -260,7 +343,17 @@ def quantize(self, vec): return torch.zeros(vec.numel(), device=self.device), torch.tensor([0]) def compress_slice(self, vec, seed): + """Compress a slice of a vector. + + Args: + vec: The slice of the vector to be compressed. + seed (int): The seed for the random number generator. + Returns: + bins: The compressed values of the slice. + scale: The scale factor for the compression. + dim: The dimension of the slice. + """ dim = vec.numel() if not dim & (dim - 1) == 0 or dim < 8: @@ -279,16 +372,28 @@ def compress_slice(self, vec, seed): return bins, float(scale.cpu().numpy()), vec.numel() def compress(self, vec, seed): + """Compress a vector. + + Args: + vec: The vector to be compressed. + seed (int): The seed for the random number generator. + + Returns: + int_array: The compressed values of the vector. + scale_list: The list of scale factors for the compression. + dim_list: The list of dimensions for the compression. + total_dim: The total dimension of the vector. + """ def low_po2(n): if not n: return 0 - return 2 ** int(np.log2(n)) + return 2**int(np.log2(n)) def high_po2(n): if not n: return 0 - return 2 ** (int(np.ceil(np.log2(n)))) + return 2**(int(np.ceil(np.log2(n)))) vec = torch.Tensor(vec.flatten()).to(self.device) @@ -324,7 +429,17 @@ def high_po2(n): return all_bins.cpu().numpy(), res_scale, res_dim, vec.numel() def decompress_slice(self, bins, scale, dim, seed): + """Decompress a slice of a vector. + Args: + bins: The compressed values of the slice. + scale: The scale factor for the decompression. + dim: The dimension of the slice. + seed (int): The seed for the random number generator. + + Returns: + The decompressed slice of the vector. + """ vec = torch.take(self.centroids[self.nbits], bins) for i in range(self.num_hadamard): @@ -333,7 +448,15 @@ def decompress_slice(self, bins, scale, dim, seed): return (scale * vec)[:dim] def decompress(self, bins, metadata): + """Decompress a vector. + + Args: + bins: The compressed values of the vector. + metadata: The metadata for the decompression. + Returns: + The decompressed vector. + """ bins = self.from_bits(torch.Tensor(bins).to(self.device)).long().flatten() seed = int(metadata[0]) @@ -353,8 +476,17 @@ def decompress(self, bins, metadata): return vec.cpu().numpy() - # packing the quantization values to bytes def to_bits(self, int_bool_vec): + """Convert a vector of integers to bits. + + Packing the quantization values to bytes. + + Args: + int_bool_vec: The vector of integers to be converted. + + Returns: + The bit vector. + """ def to_bits_h(ibv): @@ -376,8 +508,17 @@ def to_bits_h(ibv): return bit_vec - # unpacking bytes to quantization values def from_bits(self, bit_vec): + """Convert a bit vector to integers. + + Unpacking bytes to quantization values. 
+ + Args: + bit_vec: The bit vector to be converted. + + Returns: + The vector of integers. + """ def from_bits_h(bv): @@ -400,10 +541,31 @@ def from_bits_h(bv): class EdenTransformer(Transformer): - """Eden transformer class to quantize input data.""" + """Eden transformer class for quantizing input data. + + This class is a transformer that uses the Eden method for quantization. + + Attributes: + n_bits (int): The number of bits per coordinate for quantization. + dim_threshold (int): The threshold for the dimension of the data. Data + with dimensions less than this threshold are not compressed. + device (str): The device to be used for quantization ('cpu' or 'cuda'). + eden (Eden): The Eden object for quantization. + no_comp (Float32NumpyArrayToBytes): The transformer for data that are + not compressed. + """ def __init__(self, n_bits=8, dim_threshold=100, device='cpu'): - """Class initializer. + """Initialize EdenTransformer. + + Args: + n_bits (int, optional): The number of bits per coordinate for + quantization. Defaults to 8. + dim_threshold (int, optional): The threshold for the dimension of + the data. Data with dimensions less than this threshold are + not compressed. Defaults to 100. + device (str, optional): The device to be used for quantization + ('cpu' or 'cuda'). Defaults to 'cpu'. """ self.lossy = True self.eden = Eden(nbits=n_bits, device=device) @@ -415,16 +577,23 @@ def __init__(self, n_bits=8, dim_threshold=100, device='cpu'): self.no_comp = Float32NumpyArrayToBytes() def forward(self, data, **kwargs): - """ - Quantize data. + """Quantize data. + + Args: + data: The data to be quantized. + + Returns: + The quantized data and the metadata for the quantization. """ # TODO: can be simplified if have access to a unique feature of the participant (e.g., ID) - seed = (hash(sum(data.flatten()) * 13 + 7) + np.random.randint(1, 2**16)) % (2**16) + seed = (hash(sum(data.flatten()) * 13 + 7) + + np.random.randint(1, 2**16)) % (2**16) seed = int(float(seed)) metadata = {'int_list': list(data.shape)} if data.size > self.dim_threshold: - int_array, scale_list, dim_list, total_dim = self.eden.compress(data, seed) + int_array, scale_list, dim_list, total_dim = self.eden.compress( + data, seed) # TODO: workaround: using the int to float dictionary to pass eden's metadata metadata['int_to_float'] = {0: float(seed), 1: float(total_dim)} @@ -445,23 +614,23 @@ def forward(self, data, **kwargs): return return_values def backward(self, data, metadata, **kwargs): - """Recover data array back to the original numerical type and the shape. + """Recover data array back to the original numerical type and the + shape. Args: - data: an flattened numpy array - metadata: dictionary to contain information for recovering to original data array + data: an flattened numpy array. + metadata: dictionary to contain information for recovering to + original data array. Returns: - data: Numpy array with original numerical type and shape + data: Numpy array with original numerical type and shape. 
""" - if np.prod(metadata['int_list']) >= self.dim_threshold: # compressed data + if np.prod( + metadata['int_list']) >= self.dim_threshold: # compressed data data = np.frombuffer(data, dtype=np.uint8) data = co.deepcopy(data) - data = self.eden.decompress( - data, - metadata['int_to_float'] - ) + data = self.eden.decompress(data, metadata['int_to_float']) data_shape = list(metadata['int_list']) data = data.reshape(data_shape) else: @@ -472,19 +641,31 @@ def backward(self, data, metadata, **kwargs): class EdenPipeline(TransformationPipeline): - """A pipeline class to compress data lossy using EDEN.""" + """A pipeline class for compressing data using the Eden method. + + This class is a pipeline of transformers that use the Eden method for + quantization. + + Attributes: + n_bits (int): The number of bits per coordinate for quantization. + dim_threshold (int): The threshold for the dimension of the data. Data + with dimensions less than this threshold are not compressed. + device (str): The device to be used for quantization ('cpu' or 'cuda'). + """ def __init__(self, n_bits=8, dim_threshold=100, device='cpu', **kwargs): """Initialize a pipeline of transformers. Args: - n_bits (int): Number of bits per coordinate (1-8 bits are supported) - dim_threshold (int): Layers with less than dim_threshold params are not compressed - device: Device for executing the compression and decompression - (e.g., 'cpu', 'cuda:0', 'cuda:1') + n_bits (int): Number of bits per coordinate (1-8 bits are + supported). + dim_threshold (int): Layers with less than dim_threshold params + are not compressed. + device: Device for executing the compression and decompressionc + (e.g., 'cpu', 'cuda:0', 'cuda:1'). Return: - Transformer class object + Transformer class object. """ # instantiate each transformer diff --git a/openfl/pipelines/kc_pipeline.py b/openfl/pipelines/kc_pipeline.py index d29b1345d0..cf1b12627e 100644 --- a/openfl/pipelines/kc_pipeline.py +++ b/openfl/pipelines/kc_pipeline.py @@ -1,9 +1,7 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """KCPipeline module.""" - import copy as co import gzip as gz @@ -15,33 +13,44 @@ class KmeansTransformer(Transformer): - """A K-means transformer class to quantize input data.""" + """K-means transformer class for quantizing input data. + + This class is a transformer that uses the K-means method for quantization. + + Attributes: + n_cluster (int): The number of clusters for the K-means. + lossy (bool): Indicates if the transformer is lossy. + """ def __init__(self, n_cluster=6): - """Class initializer. + """Initialize KmeansTransformer. Args: - n_cluster (int): Number of clusters for the K-means + n_cluster (int, optional): The number of clusters for the K-means. + Defaults to 6. """ self.lossy = True self.n_cluster = n_cluster def forward(self, data, **kwargs): - """ - Quantize data into n_cluster levels of values. + """Quantize data into n_cluster levels of values. Args: - data: an numpy array from the model tensor_dict - data: an numpy array being quantized - **kwargs: Variable arguments to pass + data: The data to be quantized. + **kwargs: Variable arguments to pass. + + Returns: + int_array: The quantized data. + metadata: The metadata for the quantization. 
""" metadata = {'int_list': list(data.shape)} # clustering - k_means = cluster.KMeans(n_clusters=self.n_cluster, n_init=self.n_cluster) + k_means = cluster.KMeans(n_clusters=self.n_cluster, + n_init=self.n_cluster) data = data.reshape((-1, 1)) if data.shape[0] >= self.n_cluster: - k_means = cluster.KMeans( - n_clusters=self.n_cluster, n_init=self.n_cluster) + k_means = cluster.KMeans(n_clusters=self.n_cluster, + n_init=self.n_cluster) k_means.fit(data) quantized_values = k_means.cluster_centers_.squeeze() indices = k_means.labels_ @@ -55,14 +64,16 @@ def forward(self, data, **kwargs): return int_array, metadata def backward(self, data, metadata, **kwargs): - """Recover data array back to the original numerical type and the shape. + """Recover data array back to the original numerical type and the + shape. Args: - data: an flattened numpy array - metadata: dictionary to contain information for recovering ack to original data array + data: The flattened numpy array. + metadata: The dictionary containing information for recovering to + original data array. Returns: - data: Numpy array with original numerical type and shape + data: The numpy array with original numerical type and shape. """ # convert back to float # TODO @@ -77,14 +88,16 @@ def backward(self, data, metadata, **kwargs): @staticmethod def _float_to_int(np_array): - """Create look-up table for conversion between floating and integer types. + """Create look-up table for conversion between floating and integer + types. Args: - np_array: A Numpy array + np_array: A Numpy array. Returns: - int_array: The input Numpy float array converted to an integer array - int_to_float_map + int_array: The input Numpy float array converted to an integer + array. + int_to_float_map: The dictionary mapping integers to floats. """ flatten_array = np_array.reshape(-1) unique_value_array = np.unique(flatten_array) @@ -103,20 +116,25 @@ def _float_to_int(np_array): class GZIPTransformer(Transformer): - """A GZIP transformer class to losslessly compress data.""" + """GZIP transformer class for losslessly compressing data. + + Attributes: + lossy (bool): Indicates if the transformer is lossy. + """ def __init__(self): - """Initialize.""" + """Initialize GZIPTransformer.""" self.lossy = False def forward(self, data, **kwargs): """Compress data into bytes. Args: - data: A Numpy array + data: A Numpy array. Returns: - GZIP compressed data object + compressed_bytes_: The GZIP compressed data object. + metadata: An empty dictionary. """ bytes_ = data.astype(np.float32).tobytes() compressed_bytes_ = gz.compress(bytes_) @@ -128,11 +146,11 @@ def backward(self, data, metadata, **kwargs): Args: data: Compressed GZIP data - metadata: + metadata: An empty dictionary. **kwargs: Additional parameters to pass to the function Returns: - data: Numpy array + data: The decompressed data as a numpy array. """ decompressed_bytes_ = gz.decompress(data) data = np.frombuffer(decompressed_bytes_, dtype=np.float32) @@ -140,17 +158,22 @@ def backward(self, data, metadata, **kwargs): class KCPipeline(TransformationPipeline): - """A pipeline class to compress data lossly using k-means and GZIP methods.""" + """A pipeline class to compress data lossly using k-means and GZIP methods. + + Attributes: + p (float): The amount of sparsity for compression. + n_cluster (int): The number of K-mean clusters. + """ def __init__(self, p_sparsity=0.01, n_clusters=6, **kwargs): """Initialize a pipeline of transformers. 
Args: - p_sparsity (float): Amount of sparsity for compression (Default = 0.01) - n_clusters (int): Number of K-mean cluster - - Return: - Transformer class object + p_sparsity (float, optional): The amount of sparsity for + compression. Defaults to 0.01. + n_clusters (int, optional): The number of K-mean clusters. + Defaults to 6. + **kwargs: Additional keyword arguments. """ # instantiate each transformer self.p = p_sparsity diff --git a/openfl/pipelines/no_compression_pipeline.py b/openfl/pipelines/no_compression_pipeline.py index b0f7527ebe..5cc189f45d 100644 --- a/openfl/pipelines/no_compression_pipeline.py +++ b/openfl/pipelines/no_compression_pipeline.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """NoCompressionPipeline module.""" from .pipeline import Float32NumpyArrayToBytes @@ -12,5 +11,6 @@ class NoCompressionPipeline(TransformationPipeline): def __init__(self, **kwargs): """Initialize.""" - super(NoCompressionPipeline, self).__init__( - transformers=[Float32NumpyArrayToBytes()], **kwargs) + super(NoCompressionPipeline, + self).__init__(transformers=[Float32NumpyArrayToBytes()], + **kwargs) diff --git a/openfl/pipelines/pipeline.py b/openfl/pipelines/pipeline.py index a5a6479914..4d0d9d1b67 100644 --- a/openfl/pipelines/pipeline.py +++ b/openfl/pipelines/pipeline.py @@ -1,26 +1,30 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Pipeline module.""" import numpy as np class Transformer: - """Data transformation class.""" + """Base class for data transformation. + + This class defines the basic structure of a data transformer. It should be + subclassed when implementing new types of data transformations. + """ def forward(self, data, **kwargs): """Forward pass data transformation. Implement the data transformation. + This method should be overridden by all subclasses. Args: - data: + data: The data to be transformed. **kwargs: Additional parameters to pass to the function Returns: - transformed_data - metadata + transformed_data: The transformed data. + metadata: The metadata for the transformation. """ raise NotImplementedError @@ -29,35 +33,37 @@ def backward(self, data, metadata, **kwargs): Implement the data transformation needed when going the opposite direction to the forward method. + This method should be overridden by all subclasses. Args: - data: - metadata: - **kwargs: Additional parameters to pass to the function + data: The transformed data. + metadata: The metadata for the transformation. + **kwargs: Additional keyword arguments for the transformation. Returns: - transformed_data + transformed_data: The original data before the transformation. """ raise NotImplementedError class Float32NumpyArrayToBytes(Transformer): - """Converts float32 Numpy array to Bytes array.""" + """Transformer class for converting float32 Numpy arrays to bytes + arrays.""" def __init__(self): - """Initialize.""" + """Initialize Float32NumpyArrayToBytes.""" self.lossy = False def forward(self, data, **kwargs): - """Forward pass. + """Convert a float32 Numpy array to bytes. Args: - data: - **kwargs: Additional arguments to pass to the function + data: The float32 Numpy array to be converted. + **kwargs: Additional keyword arguments for the conversion. Returns: - data_bytes: - metadata: + data_bytes: The data converted to bytes. + metadata: The metadata for the conversion. """ # TODO: Warn when this casting is being performed. 
if data.dtype != np.float32: @@ -69,15 +75,14 @@ def forward(self, data, **kwargs): return data_bytes, metadata def backward(self, data, metadata, **kwargs): - """Backward pass. + """Convert bytes back to a float32 Numpy array. Args: - data: - metadata: + data: The data in bytes. + metadata: The metadata for the conversion. Returns: - Numpy Array - + The data converted back to a float32 Numpy array. """ array_shape = tuple(metadata['int_list']) flat_array = np.frombuffer(data, dtype=np.float32) @@ -89,17 +94,23 @@ def backward(self, data, metadata, **kwargs): class TransformationPipeline: """Data Transformer Pipeline Class. + This class is a pipeline of transformers that transform data in a + sequential manner. + A sequential pipeline to transform (e.x. compress) data (e.x. layer of model_weights) as well as return metadata (if needed) for the reconstruction process carried out by the backward method. + + Attributes: + transformers (list): The list of transformers in the pipeline. """ def __init__(self, transformers, **kwargs): - """Initialize. + """Initialize TransformationPipeline. Args: - transformers: - **kwargs: Additional parameters to pass to the function + transformers (list): The list of transformers in the pipeline. + **kwargs: Additional keyword arguments for the pipeline. """ self.transformers = transformers @@ -107,13 +118,12 @@ def forward(self, data, **kwargs): """Forward pass of pipeline data transformer. Args: - data: Data to transform - **kwargs: Additional parameters to pass to the function + data: The data to be transformed. + **kwargs: Additional keyword arguments for the transformation. Returns: - data: - transformer_metadata: - + data: The transformed data. + transformer_metadata: The metadata for the transformation. """ transformer_metadata = [] @@ -139,19 +149,24 @@ def backward(self, data, transformer_metadata, **kwargs): """Backward pass of pipeline data transformer. Args: - data: Data to transform - transformer_metadata: - **kwargs: Additional parameters to pass to the function + data: The transformed data. + transformer_metadata: The metadata for the transformation. + **kwargs: Additional keyword arguments for the transformation. Returns: - data: - + The original data before the transformation. """ for transformer in self.transformers[::-1]: - data = transformer.backward( - data=data, metadata=transformer_metadata.pop(), **kwargs) + data = transformer.backward(data=data, + metadata=transformer_metadata.pop(), + **kwargs) return data def is_lossy(self): - """If any of the transformers are lossy, then the pipeline is lossy.""" + """If any of the transformers are lossy, then the pipeline is lossy. + + Returns: + True if any of the transformers in the pipeline are lossy, False + otherwise. 
+ """ return any(transformer.lossy for transformer in self.transformers) diff --git a/openfl/pipelines/random_shift_pipeline.py b/openfl/pipelines/random_shift_pipeline.py index f9b3785f35..07f49106d1 100644 --- a/openfl/pipelines/random_shift_pipeline.py +++ b/openfl/pipelines/random_shift_pipeline.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """RandomShiftPipeline module.""" import numpy as np @@ -11,10 +10,10 @@ class RandomShiftTransformer(Transformer): - """Random Shift Transformer.""" + """Random Shift Transformer class.""" def __init__(self): - """Initialize.""" + """Initialize RandomShiftTransformer.""" self.lossy = False return @@ -25,16 +24,15 @@ def forward(self, data, **kwargs): Implement the data transformation. Args: - data: an array value from a model tensor_dict + data: The data to be transformed. Returns: - transformed_data: - metadata: - + transformed_data: The data after the random shift. + metadata: The metadata for the transformation. """ shape = data.shape - random_shift = np.random.uniform( - low=-20, high=20, size=shape).astype(np.float32) + random_shift = np.random.uniform(low=-20, high=20, + size=shape).astype(np.float32) transformed_data = data + random_shift # construct metadata @@ -51,11 +49,11 @@ def backward(self, data, metadata, **kwargs): direction to the forward method. Args: - data: - metadata: + data: The transformed data. + metadata: The metadata for the transformation. Returns: - transformed_data: + The original data before the random shift. """ shape = tuple(metadata['int_list']) # this is an awkward use of the metadata into to float dict, usually diff --git a/openfl/pipelines/skc_pipeline.py b/openfl/pipelines/skc_pipeline.py index cbb138bcde..419c70beb1 100644 --- a/openfl/pipelines/skc_pipeline.py +++ b/openfl/pipelines/skc_pipeline.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """SKCPipeline module.""" import copy as co @@ -14,26 +13,32 @@ class SparsityTransformer(Transformer): - """A transformer class to sparsify input data.""" + """A transformer class to sparsify input data. + + Attributes: + p (float): The sparsity ratio. + lossy (bool): A flag indicating if the transformation is lossy. + """ def __init__(self, p=0.01): """Initialize. Args: - p (float): sparsity ratio (Default=0.01) + p (float, optional): The sparsity ratio. Defaults to 0.01. """ self.lossy = True self.p = p def forward(self, data, **kwargs): - """ - Sparsify data and pass over only non-sparsified elements by reducing the array size. + """Sparsify data and pass over only non-sparsified elements by reducing + the array size. Args: data: an numpy array from the model tensor_dict. Returns: - sparse_data: a flattened, sparse representation of the input tensor + sparse_data: a flattened, sparse representation of the input + tensor. metadata: dictionary to store a list of meta information. """ metadata = {'int_list': list(data.shape)} @@ -53,7 +58,7 @@ def backward(self, data, metadata, **kwargs): Args: data: an numpy array with non-zero values. metadata: dictionary to contain information for recovering back - to original data array. + to original data array. Returns: recovered_data: an numpy array with original shape. @@ -89,10 +94,20 @@ def _topk_func(x, k): class KmeansTransformer(Transformer): - """A transformer class to quantize input data.""" + """A transformer class to quantize input data. + + Attributes: + n_cluster (int): The number of clusters for the K-means. 
+ lossy (bool): A flag indicating if the transformation is lossy. + """ def __init__(self, n_cluster=6): - """Initialize.""" + """Initialize KmeansTransformer. + + Args: + n_cluster (int, optional): The number of clusters for the K-means. + Defaults to 6. + """ self.n_cluster = n_cluster self.lossy = True @@ -109,8 +124,8 @@ def forward(self, data, **kwargs): # clustering data = data.reshape((-1, 1)) if data.shape[0] >= self.n_cluster: - k_means = cluster.KMeans( - n_clusters=self.n_cluster, n_init=self.n_cluster) + k_means = cluster.KMeans(n_clusters=self.n_cluster, + n_init=self.n_cluster) k_means.fit(data) quantized_values = k_means.cluster_centers_.squeeze() indices = k_means.labels_ @@ -126,12 +141,14 @@ def backward(self, data, metadata, **kwargs): """Recover data array back to the original numerical type. Args: - data: an numpy array with non-zero values + data: an numpy array with non-zero values. metadata: dictionary to contain information for recovering back - to original data array + to original data array. Returns: - data: an numpy array with original numerical type + data: an numpy array with original numerical type, reconstructed + from the quantized values using the int-to-float map in the + metadata. """ # convert back to float data = co.deepcopy(data) @@ -143,14 +160,16 @@ def backward(self, data, metadata, **kwargs): @staticmethod def _float_to_int(np_array): - """ - Create look-up table for conversion between floating and integer types. + """Create look-up table for conversion between floating and integer + types. Args: - np_array + np_array: A numpy array. Returns: - int_array, int_to_float_map + int_array: The input numpy float array converted to an integer + array. + int_to_float_map: The dictionary mapping integers to floats. """ flatten_array = np_array.reshape(-1) unique_value_array = np.unique(flatten_array) @@ -169,7 +188,11 @@ def _float_to_int(np_array): class GZIPTransformer(Transformer): - """A transformer class to losslessly compress data.""" + """GZIP transformer class for losslessly compressing data. + + Attributes: + lossy (bool): A flag indicating if the transformation is lossy. + """ def __init__(self): """Initialize.""" @@ -179,7 +202,11 @@ def forward(self, data, **kwargs): """Compress data into bytes. Args: - data: an numpy array with non-zero values + data: an numpy array with non-zero values. + + Returns: + compressed_bytes_: The compressed data. + metadata: An empty dictionary. """ bytes_ = data.astype(np.float32).tobytes() compressed_bytes_ = gz.compress(bytes_) @@ -190,12 +217,13 @@ def backward(self, data, metadata, **kwargs): """Decompress data into numpy of float32. Args: - data: an numpy array with non-zero values + data: an numpy array with non-zero values. metadata: dictionary to contain information for recovering back - to original data array + to original data array. Returns: - data: + data: A numpy array with the original numerical type after + decompression. """ decompressed_bytes_ = gz.decompress(data) data = np.frombuffer(decompressed_bytes_, dtype=np.float32) @@ -203,17 +231,25 @@ def backward(self, data, metadata, **kwargs): class SKCPipeline(TransformationPipeline): - """A pipeline class to compress data lossly using sparsity and k-means methods.""" + """A pipeline class to compress data lossly using sparsity and k-means + methods. + + Attributes: + p (float): The sparsity factor. + n_cluster (int): The number of K-mean clusters.
+ """ def __init__(self, p_sparsity=0.1, n_clusters=6, **kwargs): """Initialize a pipeline of transformers. Args: - p_sparsity (float): Sparsity factor (Default=0.1) - n_cluster (int): Number of K-Means clusters (Default=6) + p_sparsity (float, optional): The sparsity factor. Defaults to 0.1. + n_clusters (int, optional): The number of K-mean clusters. + Defaults to 6. + **kwargs: Additional keyword arguments for the pipeline. Returns: - Data compression transformer pipeline object + Data compression transformer pipeline object. """ # instantiate each transformer self.p = p_sparsity diff --git a/openfl/pipelines/stc_pipeline.py b/openfl/pipelines/stc_pipeline.py index 7198502050..91ad5a78af 100644 --- a/openfl/pipelines/stc_pipeline.py +++ b/openfl/pipelines/stc_pipeline.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """STCPipelinemodule.""" import gzip as gz @@ -12,26 +11,32 @@ class SparsityTransformer(Transformer): - """A transformer class to sparsify input data.""" + """A transformer class to sparsify input data. + + Attributes: + p (float): The sparsity ratio. + lossy (bool): A flag indicating if the transformation is lossy. + """ def __init__(self, p=0.01): """Initialize. Args: - p (float): sparsity ratio (Default=0.01) + p (float, optional): The sparsity ratio. Defaults to 0.01. """ self.lossy = True self.p = p def forward(self, data, **kwargs): - """ - Sparsify data and pass over only non-sparsified elements by reducing the array size. + """Sparsify data and pass over only non-sparsified elements by reducing + the array size. Args: data: an numpy array from the model tensor_dict. Returns: - sparse_data: a flattened, sparse representation of the input tensor + sparse_data: a flattened, sparse representation of the input + tensor. metadata: dictionary to store a list of meta information. """ metadata = {'int_list': list(data.shape)} @@ -51,7 +56,7 @@ def backward(self, data, metadata, **kwargs): Args: data: an numpy array with non-zero values. metadata: dictionary to contain information for recovering back - to original data array. + to original data array. Returns: recovered_data: an numpy array with original shape. @@ -87,14 +92,19 @@ def _topk_func(x, k): class TernaryTransformer(Transformer): - """A transformer class to ternerize input data.""" + """A transformer class to ternarize input data. + + Attributes: + lossy (bool): A flag indicating if the transformation is lossy. + """ def __init__(self): """Initialize.""" self.lossy = True def forward(self, data, **kwargs): - """Ternerize data into positive mean value, negative mean value and zero value. + """Ternerize data into positive mean value, negative mean value and + zero value. Args: data: an flattened numpy array @@ -116,10 +126,13 @@ def backward(self, data, metadata, **kwargs): Args: data: an numpy array with non-zero values. + metadata: dictionary to contain information for recovering back + to original data array. Returns: - metadata: dictionary to contain information for recovering back to original data array. - data (return): an numpy array with original numerical type. + metadata: dictionary to contain information for recovering back + to original data array. + data: an numpy array with original numerical type. """ # TODO import copy @@ -132,15 +145,16 @@ def backward(self, data, metadata, **kwargs): @staticmethod def _float_to_int(np_array): - """Create look-up table for conversion between floating and integer types. 
+ """Create look-up table for conversion between floating and integer + types. Args: - np_array: + np_array: A numpy array. Returns: - int_array: - int_to_float_map: - + int_array: The input numpy float array converted to an integer + array. + int_to_float_map: The dictionary mapping integers to floats. """ flatten_array = np_array.reshape(-1) unique_value_array = np.unique(flatten_array) @@ -159,7 +173,11 @@ def _float_to_int(np_array): class GZIPTransformer(Transformer): - """A transformer class to losslessly compress data.""" + """GZIP transformer class for losslessly compressing data. + + Attributes: + lossy (bool): A flag indicating if the transformation is lossy. + """ def __init__(self): """Initialize.""" @@ -169,12 +187,12 @@ def forward(self, data, **kwargs): """Compress data into numpy of float32. Args: - data: an numpy array with non-zero values + data: an numpy array with non-zero values. Returns: - compressed_bytes : - metadata: dictionary to contain information for recovering back to original data array - + compressed_bytes: The compressed data. + metadata: dictionary to contain information for recovering back + to original data array """ bytes_ = data.astype(np.float32).tobytes() compressed_bytes = gz.compress(bytes_) @@ -185,11 +203,13 @@ def backward(self, data, metadata, **kwargs): """Decompress data into numpy of float32. Args: - data: an numpy array with non-zero values - metadata: dictionary to contain information for recovering back to original data array + data: an numpy array with non-zero values. + metadata: dictionary to contain information for recovering back + to original data array. Returns: - data: + data: A numpy array with the original numerical type after + decompression. """ decompressed_bytes_ = gz.decompress(data) data = np.frombuffer(decompressed_bytes_, dtype=np.float32) @@ -197,19 +217,30 @@ def backward(self, data, metadata, **kwargs): class STCPipeline(TransformationPipeline): - """A pipeline class to compress data lossly using sparsity and ternerization methods.""" + """A pipeline class to compress data lossly using sparsity and + ternarization methods. + + Attributes: + p (float): The sparsity factor. + """ def __init__(self, p_sparsity=0.1, n_clusters=6, **kwargs): """Initialize a pipeline of transformers. Args: - p_sparsity (float): Sparsity factor (Default=0.01) - n_cluster (int): Number of K-Means clusters (Default=6) + p_sparsity (float, optional): The sparsity factor. Defaults to 0.1. + n_clusters (int, optional): The number of K-mean clusters. + Defaults to 6. + **kwargs: Additional keyword arguments for the pipeline. Returns: - Data compression transformer pipeline object + Data compression transformer pipeline object. """ # instantiate each transformer self.p = p_sparsity - transformers = [SparsityTransformer(self.p), TernaryTransformer(), GZIPTransformer()] + transformers = [ + SparsityTransformer(self.p), + TernaryTransformer(), + GZIPTransformer() + ] super(STCPipeline, self).__init__(transformers=transformers, **kwargs) diff --git a/openfl/pipelines/tensor_codec.py b/openfl/pipelines/tensor_codec.py index 907f8840e8..ba3c3596a7 100644 --- a/openfl/pipelines/tensor_codec.py +++ b/openfl/pipelines/tensor_codec.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """TensorCodec module.""" import numpy as np @@ -13,12 +12,22 @@ class TensorCodec: """TensorCodec is responsible for the following. - 1. Tracking the compression/decompression related dependencies of a given tensor - 2. 
Acting as a TensorKey aware wrapper for the compression_pipeline functionality + 1. Tracking the compression/decompression related dependencies of a given + tensor. + 2. Acting as a TensorKey aware wrapper for the compression_pipeline + functionality. + + Attributes: + compression_pipeline: The pipeline used for compression. + lossless_pipeline: The pipeline used for lossless compression. """ def __init__(self, compression_pipeline): - """Initialize.""" + """Initialize the TensorCodec. + + Args: + compression_pipeline: The pipeline used for compression. + """ self.compression_pipeline = compression_pipeline if self.compression_pipeline.is_lossy(): self.lossless_pipeline = NoCompressionPipeline() @@ -26,37 +35,37 @@ def __init__(self, compression_pipeline): self.lossless_pipeline = compression_pipeline def set_lossless_pipeline(self, lossless_pipeline): - """Set lossless pipeline.""" + """Set lossless pipeline. + + Args: + lossless_pipeline: The pipeline to be set as the lossless pipeline. + It should be a pipeline that is not lossy. + + Raises: + AssertionError: If the provided pipeline is not lossless. + """ assert lossless_pipeline.is_lossy() is False, ( 'The provided pipeline is not lossless') self.lossless_pipeline = lossless_pipeline def compress(self, tensor_key, data, require_lossless=False, **kwargs): - """ - Function-wrapper around the tensor_pipeline.forward function. + """Function-wrapper around the tensor_pipeline.forward function. - It also keeps track of the tensorkeys associated with the compressed nparray + It also keeps track of the tensorkeys associated with the compressed + nparray. Args: - tensor_key: TensorKey is provided to verify it should - be compressed, and new TensorKeys returned - will be derivatives of the existing - tensor_name - - data: (uncompressed) numpy array associated with - the tensor_key - - require_lossless: boolean. Does tensor require - compression + tensor_key: TensorKey is provided to verify it should be + compressed, and new TensorKeys returned will be derivatives of + the existing tensor_name. + data: (uncompressed) numpy array associated with the tensor_key. + require_lossless: boolean. Does tensor require compression. Returns: - compressed_tensor_key: Tensorkey corresponding to the decompressed - tensor - - compressed_nparray: The compressed tensor - - metadata: metadata associated with compressed tensor - + compressed_tensor_key: Tensorkey corresponding to the decompressed + tensor. + compressed_nparray: The compressed tensor. + metadata: metadata associated with compressed tensor. """ if require_lossless: compressed_nparray, metadata = self.lossless_pipeline.forward( @@ -71,48 +80,46 @@ def compress(self, tensor_key, data, require_lossless=False, **kwargs): new_tags = change_tags(tags, add_field='compressed') else: new_tags = change_tags(tags, add_field='lossy_compressed') - compressed_tensor_key = TensorKey( - tensor_name, origin, round_number, report, new_tags) + compressed_tensor_key = TensorKey(tensor_name, origin, round_number, + report, new_tags) return compressed_tensor_key, compressed_nparray, metadata - def decompress(self, tensor_key, data, transformer_metadata, - require_lossless=False, **kwargs): - """ - Function-wrapper around the tensor_pipeline.backward function. + def decompress(self, + tensor_key, + data, + transformer_metadata, + require_lossless=False, + **kwargs): + """Function-wrapper around the tensor_pipeline.backward function. 
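Beyond calling the pipeline, the compress/decompress wrappers are mostly tag bookkeeping on the five-field TensorKey. The toy version below follows the branching shown in `compress` and `decompress`, but uses simplified stand-ins for `TensorKey` and `change_tags`; the real helpers live in `openfl.utilities` and may differ in ordering and de-duplication details.

```python
from collections import namedtuple

TensorKey = namedtuple('TensorKey',
                       ['tensor_name', 'origin', 'round_number', 'report', 'tags'])

def change_tags(tags, add_field=None, remove_field=None):
    """Return a new tag tuple with one tag added and/or one removed."""
    new_tags = [t for t in tags if t != remove_field]
    if add_field is not None and add_field not in new_tags:
        new_tags.append(add_field)
    return tuple(new_tags)

tk = TensorKey('conv1.weight', 'collaborator-1', 3, False, ('trained',))

lossless = False  # pretend the configured pipeline is lossy
if lossless:
    new_tags = change_tags(tk.tags, add_field='compressed')
else:
    new_tags = change_tags(tk.tags, add_field='lossy_compressed')
compressed_tk = TensorKey(tk.tensor_name, tk.origin, tk.round_number,
                          tk.report, new_tags)
print(compressed_tk.tags)  # ('trained', 'lossy_compressed')

# After decompression the lossy tag is swapped for its decompressed counterpart.
decompressed_tags = change_tags(compressed_tk.tags,
                                add_field='lossy_decompressed',
                                remove_field='lossy_compressed')
print(decompressed_tags)  # ('trained', 'lossy_decompressed')
```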
- It also keeps track of the tensorkeys associated with the decompressed nparray + It also keeps track of the tensorkeys associated with the decompressed + nparray. Args: - tensor_key: TensorKey is provided to verify it should - be decompressed, and new TensorKeys - returned will be derivatives of the - existing tensor_name - - data: (compressed) numpy array associated with - the tensor_key - - transformer_metadata: metadata associated with the compressed - tensor - - require_lossless: boolean, does data require lossless - decompression + tensor_key: TensorKey is provided to verify it should be + decompressed, and new TensorKeys returned will be derivatives + of the existing tensor_name. + data: (compressed) numpy array associated with the tensor_key. + transformer_metadata: metadata associated with the compressed + tensor. + require_lossless: boolean, does data require lossless + decompression. Returns: - decompressed_tensor_key: Tensorkey corresponding to the - decompressed tensor - - decompressed_nparray: The decompressed tensor - + decompressed_tensor_key: Tensorkey corresponding to the + decompressed tensor. + decompressed_nparray: The decompressed tensor. """ tensor_name, origin, round_number, report, tags = tensor_key - assert (len(transformer_metadata) > 0), ( - 'metadata must be included for decompression') - assert (('compressed' in tags) or ('lossy_compressed' in tags)), ( - 'Cannot decompress an uncompressed tensor') + assert (len(transformer_metadata) + > 0), ('metadata must be included for decompression') + assert (('compressed' in tags) + or ('lossy_compressed' + in tags)), ('Cannot decompress an uncompressed tensor') if require_lossless: - assert ('compressed' in tags), ( - 'Cannot losslessly decompress lossy tensor') + assert ('compressed' + in tags), ('Cannot losslessly decompress lossy tensor') if require_lossless or 'compressed' in tags: decompressed_nparray = self.lossless_pipeline.backward( @@ -122,16 +129,17 @@ def decompress(self, tensor_key, data, transformer_metadata, data, transformer_metadata, **kwargs) # Define the decompressed tensorkey that should be returned if 'lossy_compressed' in tags: - new_tags = change_tags( - tags, add_field='lossy_decompressed', remove_field='lossy_compressed') - decompressed_tensor_key = TensorKey( - tensor_name, origin, round_number, report, new_tags) + new_tags = change_tags(tags, + add_field='lossy_decompressed', + remove_field='lossy_compressed') + decompressed_tensor_key = TensorKey(tensor_name, origin, + round_number, report, new_tags) elif 'compressed' in tags: # 'compressed' == lossless compression; no need for # compression related tag after decompression new_tags = change_tags(tags, remove_field='compressed') - decompressed_tensor_key = TensorKey( - tensor_name, origin, round_number, report, new_tags) + decompressed_tensor_key = TensorKey(tensor_name, origin, + round_number, report, new_tags) else: raise NotImplementedError( 'Decompression is only supported on compressed data') @@ -140,82 +148,84 @@ def decompress(self, tensor_key, data, transformer_metadata, @staticmethod def generate_delta(tensor_key, nparray, base_model_nparray): - """ - Create delta from the updated layer and base layer. + """Create delta from the updated layer and base layer. Args: - tensor_key: This is the tensor_key associated with the - nparray. - Should have a tag of 'trained' or 'aggregated' - - nparray: The nparray that corresponds to the tensorkey - + tensor_key: This is the tensor_key associated with the nparray. 
+ Should have a tag of 'trained' or 'aggregated' + nparray: The nparray that corresponds to the tensorkey. base_model_nparray: The base model tensor that will be subtracted - from the new weights + from the new weights. Returns: - delta_tensor_key: Tensorkey that corresponds to the delta weight - array - - delta: Difference between the provided tensors - + delta_tensor_key: Tensorkey that corresponds to the delta weight + array. + delta: Difference between the provided tensors. """ tensor_name, origin, round_number, report, tags = tensor_key if not np.isscalar(nparray): assert nparray.shape == base_model_nparray.shape, ( f'Shape of updated layer ({nparray.shape}) is not equal to base ' - f'layer shape of ({base_model_nparray.shape})' - ) + f'layer shape of ({base_model_nparray.shape})') assert 'model' not in tags, ( 'The tensorkey should be provided ' 'from the layer with new weights, not the base model') new_tags = change_tags(tags, add_field='delta') - delta_tensor_key = TensorKey( - tensor_name, origin, round_number, report, new_tags) + delta_tensor_key = TensorKey(tensor_name, origin, round_number, report, + new_tags) return delta_tensor_key, nparray - base_model_nparray @staticmethod - def apply_delta(tensor_key, delta, base_model_nparray, creates_model=False): - """ - Add delta to the nparray. + def apply_delta(tensor_key, + delta, + base_model_nparray, + creates_model=False): + """Add delta to the nparray. Args: - tensor_key: This is the tensor_key associated with the - delta. Should have a tag of 'trained' or - 'aggregated' - delta: Weight delta between the new model and - old model - base_model_nparray: The nparray that corresponds to the prior - weights - creates_model: If flag is set, the tensorkey returned - will correspond to the aggregator model + tensor_key: This is the tensor_key associated with the delta. + Should have a tag of 'trained' or 'aggregated'. + delta: Weight delta between the new model and old model. + base_model_nparray: The nparray that corresponds to the prior + weights. + creates_model: If flag is set, the tensorkey returned will + correspond to the aggregator model. Returns: - new_model_tensor_key: Latest model layer tensorkey - new_model_nparray: Latest layer weights - + new_model_tensor_key: Latest model layer tensorkey. + new_model_nparray: Latest layer weights. """ tensor_name, origin, round_number, report, tags = tensor_key if not np.isscalar(base_model_nparray): assert (delta.shape == base_model_nparray.shape), ( f'Shape of delta ({delta.shape}) is not equal to shape of model' - f' layer ({base_model_nparray.shape})' - ) + f' layer ({base_model_nparray.shape})') # assert('model' in tensor_key[3]), 'The tensorkey should be provided # from the base model' # Aggregator UUID has the prefix 'aggregator' if 'aggregator' in origin and not creates_model: new_tags = change_tags(tags, remove_field='delta') - new_model_tensor_key = TensorKey( - tensor_name, origin, round_number, report, new_tags) + new_model_tensor_key = TensorKey(tensor_name, origin, round_number, + report, new_tags) else: - new_model_tensor_key = TensorKey( - tensor_name, origin, round_number, report, ('model',)) + new_model_tensor_key = TensorKey(tensor_name, origin, round_number, + report, ('model', )) return new_model_tensor_key, base_model_nparray + delta def find_dependencies(self, tensor_key, send_model_deltas): - """Resolve the tensors required to do the specified operation.""" + """Resolve the tensors required to do the specified operation. 
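Stripped of the TensorKey tag handling, `generate_delta` and `apply_delta` reduce to an element-wise subtraction and addition with a shape check. A minimal sketch:

```python
import numpy as np

def generate_delta(updated: np.ndarray, base: np.ndarray) -> np.ndarray:
    """Delta between a newly trained layer and the base model layer."""
    assert updated.shape == base.shape, 'updated and base layers must match'
    return updated - base

def apply_delta(delta: np.ndarray, base: np.ndarray) -> np.ndarray:
    """Recover the updated layer from the base layer plus the delta."""
    assert delta.shape == base.shape, 'delta and base layers must match'
    return base + delta

base = np.zeros((2, 3), dtype=np.float32)
trained = np.ones((2, 3), dtype=np.float32)
delta = generate_delta(trained, base)
assert np.array_equal(apply_delta(delta, base), trained)
```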
+ + Args: + tensor_key: A tuple containing the tensor name, origin, round + number, report, and tags. + send_model_deltas: A boolean flag indicating whether to send model + deltas. + + Returns: + tensor_key_dependencies: A list of tensor keys that are + dependencies of the given tensor key. + """ tensor_key_dependencies = [] tensor_name, origin, round_number, report, tags = tensor_key @@ -224,18 +234,14 @@ def find_dependencies(self, tensor_key, send_model_deltas): if round_number >= 1: # The new model can be generated by previous model + delta tensor_key_dependencies.append( - TensorKey( - tensor_name, origin, round_number - 1, report, tags - ) - ) + TensorKey(tensor_name, origin, round_number - 1, report, + tags)) if self.compression_pipeline.is_lossy(): new_tags = ('aggregated', 'delta', 'lossy_compressed') else: new_tags = ('aggregated', 'delta', 'compressed') tensor_key_dependencies.append( - TensorKey( - tensor_name, origin, round_number, report, new_tags - ) - ) + TensorKey(tensor_name, origin, round_number, report, + new_tags)) return tensor_key_dependencies diff --git a/openfl/plugins/frameworks_adapters/flax_adapter.py b/openfl/plugins/frameworks_adapters/flax_adapter.py index 9ad077a9b1..0c758fc480 100644 --- a/openfl/plugins/frameworks_adapters/flax_adapter.py +++ b/openfl/plugins/frameworks_adapters/flax_adapter.py @@ -15,12 +15,15 @@ class FrameworkAdapterPlugin(FrameworkAdapterPluginInterface): @staticmethod def get_tensor_dict(model, optimizer=None): - """ - Extract tensor dict from a model.params and model.opt_state (optimizer). + """Extract tensor dict from a model.params and model.opt_state + (optimizer). - Returns: - dict {weight name: numpy ndarray} + Args: + model (object): The model object. + optimizer (object, optional): The optimizer object. Defaults to None. + Returns: + params_dict (dict): A dictionary with weight name as key and numpy ndarray as value. """ # Convert PyTree Structure DeviceArray to Numpy @@ -28,7 +31,8 @@ def get_tensor_dict(model, optimizer=None): params_dict = _get_weights_dict(model_params, 'param') # If optimizer is initialized - # Optax Optimizer agnostic state processing (TraceState, AdamScaleState, any...) + # Optax Optimizer agnostic state processing (TraceState, + # AdamScaleState, any...) if not isinstance(model.opt_state[0], optax.EmptyState): opt_state = jax.tree_util.tree_map(np.array, model.opt_state)[0] opt_vars = filter(_get_opt_vars, dir(opt_state)) @@ -42,11 +46,17 @@ def get_tensor_dict(model, optimizer=None): @staticmethod def set_tensor_dict(model, tensor_dict, optimizer=None, device='cpu'): - """ - Set the `model.params and model.opt_state` with a flattened tensor dictionary. - Choice of JAX platform (device) cpu/gpu/gpu is initialized at start. + """Set the `model.params and model.opt_state` with a flattened tensor + dictionary. Choice of JAX platform (device) cpu/gpu/gpu is initialized + at start. + Args: - tensor_dict: flattened {weight name: numpy ndarray} tensor dictionary + model (object): The model object. + tensor_dict (dict): Flattened dictionary with weight name as key + and numpy ndarray as value. + optimizer (object, optional): The optimizer object. Defaults to + None. + device (str, optional): The device to be used. Defaults to 'cpu'. Returns: None @@ -61,6 +71,14 @@ def set_tensor_dict(model, tensor_dict, optimizer=None, device='cpu'): def _get_opt_vars(x): + """Helper function to filter out unwanted variables. + + Args: + x (str): The variable name. 
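The dependency resolution above encodes the rule that, from round 1 onward, a model tensor can be rebuilt from the previous round's model plus an aggregated delta whose tags depend on whether the pipeline is lossy. The sketch below is a simplification: plain tuples stand in for TensorKey, and the outer gating on `send_model_deltas` and the `'model'` tag is my reading of the partially shown method body, not a verbatim copy.

```python
def find_dependencies(tensor_key, send_model_deltas, lossy):
    """Return the tensor keys needed to reconstruct `tensor_key`, if any."""
    tensor_name, origin, round_number, report, tags = tensor_key
    deps = []
    if send_model_deltas and 'model' in tags and round_number >= 1:
        # previous round's model ...
        deps.append((tensor_name, origin, round_number - 1, report, tags))
        # ... plus the compressed aggregated delta for this round
        delta_tags = ('aggregated', 'delta',
                      'lossy_compressed' if lossy else 'compressed')
        deps.append((tensor_name, origin, round_number, report, delta_tags))
    return deps

deps = find_dependencies(('conv1.weight', 'aggregator', 5, False, ('model',)),
                         send_model_deltas=True, lossy=False)
assert deps[0][2] == 4 and 'compressed' in deps[1][4]
```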
+ + Returns: + bool: True if the variable is wanted, False otherwise. + """ return False if x.startswith('_') or x in ['index', 'count'] else True @@ -70,9 +88,11 @@ def _set_weights_dict(obj, weights_dict, prefix=''): The obj can be a model or an optimizer. Args: - obj (Model or Optimizer): The target object that we want to set - the weights. + obj (Model or Optimizer): The target object that we want to set the + weights. weights_dict (dict): The weight dictionary. + prefix (str, optional): The prefix for the weight dictionary keys. + Defaults to ''. Returns: None @@ -90,8 +110,21 @@ def _set_weights_dict(obj, weights_dict, prefix=''): def _update_weights(state_dict, tensor_dict, prefix, suffix=None): - # Re-assignment of the state variable(s) is restricted. - # Instead update the nested layers weights iteratively. + """Helper function to update the weights of the state dictionary. + + Re-assignment of the state variable(s) is restricted. + Instead update the nested layers weights iteratively. + + Args: + state_dict (dict): The state dictionary. + tensor_dict (dict): The tensor dictionary. + prefix (str): The prefix for the weight dictionary keys. + suffix (str, optional): The suffix for the weight dictionary keys. + Defaults to None. + + Returns: + None + """ dict_prefix = f'{prefix}_{suffix}' if suffix is not None else f'{prefix}' for layer_name, param_obj in state_dict.items(): for param_name, value in param_obj.items(): @@ -101,18 +134,15 @@ def _update_weights(state_dict, tensor_dict, prefix, suffix=None): def _get_weights_dict(obj, prefix): - """ - Get the dictionary of weights. + """Get the dictionary of weights. - Parameters - ---------- - obj : Model or Optimizer - The target object that we want to get the weights. + Args: + obj (Model or Optimizer): The target object that we want to get the + weights. + prefix (str): The prefix for the weight dictionary keys. - Returns - ------- - dict - The weight dictionary. + Returns: + flat_params (dict): The weight dictionary. """ weights_dict = {prefix: obj} # Flatten the dictionary with a given separator for diff --git a/openfl/plugins/frameworks_adapters/framework_adapter_interface.py b/openfl/plugins/frameworks_adapters/framework_adapter_interface.py index 95de6d09fd..be4a449947 100644 --- a/openfl/plugins/frameworks_adapters/framework_adapter_interface.py +++ b/openfl/plugins/frameworks_adapters/framework_adapter_interface.py @@ -17,20 +17,43 @@ def serialization_setup(): @staticmethod def get_tensor_dict(model, optimizer=None) -> dict: - """ - Extract tensor dict from a model and an optimizer. + """Extract tensor dict from a model and an optimizer. + + Args: + model (object): The model object. + optimizer (object, optional): The optimizer object. Defaults to + None. Returns: - dict {weight name: numpy ndarray} + dict: A dictionary with weight name as key and numpy ndarray as + value. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. """ raise NotImplementedError @staticmethod def set_tensor_dict(model, tensor_dict, optimizer=None, device='cpu'): - """ - Set tensor dict from a model and an optimizer. + """Set tensor dict from a model and an optimizer. Given a dict {weight name: numpy ndarray} sets weights to the model and optimizer objects inplace. + + Args: + model (object): The model object. + tensor_dict (dict): Dictionary with weight name as key and numpy + ndarray as value. + optimizer (object, optional): The optimizer object. Defaults to + None. 
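Both the Flax adapter and the generic adapter interface ultimately flatten a nested parameter tree into a flat `{prefixed name: array}` dictionary and walk it back when setting weights. The generic flatten step looks roughly like the sketch below; the separator, helper name, and example tree are illustrative only, since the adapter uses its own flatten utility.

```python
def flatten_params(nested: dict, prefix: str, sep: str = '.') -> dict:
    """Recursively flatten a nested params dict into {prefixed name: leaf}."""
    flat = {}
    for key, value in nested.items():
        name = f'{prefix}{sep}{key}'
        if isinstance(value, dict):
            flat.update(flatten_params(value, name, sep))
        else:
            flat[name] = value
    return flat

params = {'Dense_0': {'kernel': [[1.0, 2.0]], 'bias': [0.0]}}
print(flatten_params(params, 'param'))
# {'param.Dense_0.kernel': [[1.0, 2.0]], 'param.Dense_0.bias': [0.0]}
```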
+ device (str, optional): The device to be used. Defaults to 'cpu'. + + Returns: + None + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. """ raise NotImplementedError diff --git a/openfl/plugins/frameworks_adapters/keras_adapter.py b/openfl/plugins/frameworks_adapters/keras_adapter.py index 9508fceaf1..71d0b41fd4 100644 --- a/openfl/plugins/frameworks_adapters/keras_adapter.py +++ b/openfl/plugins/frameworks_adapters/keras_adapter.py @@ -32,9 +32,7 @@ def unpack(model, training_config, weights): if training_config is not None: restored_model.compile( **saving_utils.compile_args_from_training_config( - training_config - ) - ) + training_config)) restored_model.set_weights(weights) return restored_model @@ -54,10 +52,13 @@ def __reduce__(self): # NOQA:N807 # Run the function if version.parse(tf.__version__) <= version.parse('2.7.1'): - logger.warn('Applying hotfix for model serialization.' - 'Please consider updating to tensorflow>=2.8 to silence this warning.') + logger.warn( + 'Applying hotfix for model serialization.' + 'Please consider updating to tensorflow>=2.8 to silence this warning.' + ) make_keras_picklable() if version.parse(tf.__version__) >= version.parse('2.13'): + def build(self, var_list): pass @@ -66,11 +67,18 @@ def build(self, var_list): @staticmethod def get_tensor_dict(model, optimizer=None, suffix=''): - """ - Extract tensor dict from a model and an optimizer. + """Extract tensor dict from a model and an optimizer. + + Args: + model (object): The model object. + optimizer (object, optional): The optimizer object. Defaults to + None. + suffix (str, optional): The suffix for the weight dictionary keys. + Defaults to ''. Returns: - dict {weight name: numpy ndarray} + model_weights (dict): A dictionary with weight name as key and + numpy ndarray as value. """ model_weights = _get_weights_dict(model, suffix) @@ -86,42 +94,45 @@ def get_tensor_dict(model, optimizer=None, suffix=''): @staticmethod def set_tensor_dict(model, tensor_dict, optimizer=None, device='cpu'): - """ - Set the model weights with a tensor dictionary. + """Set the model weights with a tensor dictionary. Args: - tensor_dict: the tensor dictionary - with_opt_vars (bool): True = include the optimizer's status. + model (object): The model object. + tensor_dict (dict): The tensor dictionary. + optimizer (object, optional): The optimizer object. Defaults to + None. + device (str, optional): The device to be used. Defaults to 'cpu'. + + Returns: + None """ model_weight_names = [weight.name for weight in model.weights] model_weights_dict = { - name: tensor_dict[name] for name in model_weight_names + name: tensor_dict[name] + for name in model_weight_names } _set_weights_dict(model, model_weights_dict) if optimizer is not None: - opt_weight_names = [ - weight.name for weight in optimizer.weights - ] + opt_weight_names = [weight.name for weight in optimizer.weights] opt_weights_dict = { - name: tensor_dict[name] for name in opt_weight_names + name: tensor_dict[name] + for name in opt_weight_names } _set_weights_dict(optimizer, opt_weights_dict) def _get_weights_dict(obj, suffix=''): - """ - Get the dictionary of weights. + """Get the dictionary of weights. - Parameters - ---------- - obj : Model or Optimizer - The target object that we want to get the weights. + Args: + obj (Model or Optimizer): The target object that we want to get the + weights. + suffix (str, optional): The suffix for the weight dictionary keys. + Defaults to ''. 
- Returns - ------- - dict - The weight dictionary. + Returns: + weights_dict (dict): The weight dictionary. """ weights_dict = {} weight_names = [weight.name for weight in obj.weights] @@ -134,11 +145,9 @@ def _get_weights_dict(obj, suffix=''): def _set_weights_dict(obj, weights_dict): """Set the object weights with a dictionary. - The obj can be a model or an optimizer. - Args: - obj (Model or Optimizer): The target object that we want to set - the weights. + obj (Model or Optimizer): The target object that we want to set the + weights. weights_dict (dict): The weight dictionary. Returns: diff --git a/openfl/plugins/frameworks_adapters/pytorch_adapter.py b/openfl/plugins/frameworks_adapters/pytorch_adapter.py index 2ecaecc710..068f71b3bc 100644 --- a/openfl/plugins/frameworks_adapters/pytorch_adapter.py +++ b/openfl/plugins/frameworks_adapters/pytorch_adapter.py @@ -18,11 +18,16 @@ def __init__(self) -> None: @staticmethod def get_tensor_dict(model, optimizer=None): - """ - Extract tensor dict from a model and an optimizer. + """Extract tensor dict from a model and an optimizer. + + Args: + model (object): The model object. + optimizer (object, optional): The optimizer object. Defaults to + None. Returns: - dict {weight name: numpy ndarray} + dict: A dictionary with weight name as key and numpy ndarray as + value. """ state = to_cpu_numpy(model.state_dict()) @@ -34,11 +39,20 @@ def get_tensor_dict(model, optimizer=None): @staticmethod def set_tensor_dict(model, tensor_dict, optimizer=None, device='cpu'): - """ - Set tensor dict from a model and an optimizer. + """Set tensor dict from a model and an optimizer. Given a dict {weight name: numpy ndarray} sets weights to the model and optimizer objects inplace. + + Args: + model (object): The model object. + tensor_dict (dict): The tensor dictionary. + optimizer (object, optional): The optimizer object. Defaults to + None. + device (str, optional): The device to be used. Defaults to 'cpu'. + + Returns: + None """ new_state = {} # Grabbing keys from model's state_dict helps to confirm we have @@ -62,13 +76,15 @@ def _set_optimizer_state(optimizer, device, derived_opt_state_dict): """Set the optimizer state. Args: - optimizer: - device: - derived_opt_state_dict: + optimizer (object): The optimizer object. + device (str): The device to be used. + derived_opt_state_dict (dict): The derived optimizer state dictionary. + Returns: + None """ - temp_state_dict = expand_derived_opt_state_dict( - derived_opt_state_dict, device) + temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, + device) # Setting other items from the param_groups # getting them from the local optimizer @@ -85,7 +101,10 @@ def _get_optimizer_state(optimizer): """Return the optimizer state. Args: - optimizer + optimizer (object): The optimizer object. + + Returns: + derived_opt_state_dict (dict): The optimizer state dictionary. """ opt_state_dict = deepcopy(optimizer.state_dict()) @@ -111,8 +130,10 @@ def _derive_opt_state_dict(opt_state_dict): expand_derived_opt_state_dict. Args: - opt_state_dict: The optimizer state dictionary + opt_state_dict (dict): The optimizer state dictionary. + Returns: + derived_opt_state_dict (dict): The derived optimizer state dictionary. """ derived_opt_state_dict = {} @@ -127,8 +148,7 @@ def _derive_opt_state_dict(opt_state_dict): # dictionary value. 
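For the PyTorch adapter, the model half of get/set is a `state_dict` round trip through CPU numpy arrays. The sketch below shows only that half and deliberately skips the optimizer-state handling documented next; it uses standard PyTorch APIs but is not the adapter's exact code.

```python
import numpy as np
import torch
import torch.nn as nn

def get_tensor_dict(model: nn.Module) -> dict:
    """Model weights as {name: numpy array}, moved to CPU first."""
    return {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}

def set_tensor_dict(model: nn.Module, tensor_dict: dict, device: str = 'cpu') -> None:
    """Load {name: numpy array} back into the model in place."""
    new_state = {k: torch.as_tensor(v, device=device)
                 for k, v in tensor_dict.items()}
    model.load_state_dict(new_state)

net = nn.Linear(4, 2)
weights = get_tensor_dict(net)
weights['weight'] = np.zeros_like(weights['weight'])
set_tensor_dict(net, weights)
assert torch.count_nonzero(net.weight) == 0
```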
example_state_key = opt_state_dict['param_groups'][0]['params'][0] example_state_subkeys = set( - opt_state_dict['state'][example_state_key].keys() - ) + opt_state_dict['state'][example_state_key].keys()) # We assume that the state collected for all params in all param groups is # the same. @@ -137,12 +157,12 @@ def _derive_opt_state_dict(opt_state_dict): # Using assert statements to break the routine if these assumptions are # incorrect. for state_key in opt_state_dict['state'].keys(): - assert example_state_subkeys == set(opt_state_dict['state'][state_key].keys()) + assert example_state_subkeys == set( + opt_state_dict['state'][state_key].keys()) for state_subkey in example_state_subkeys: assert (isinstance( opt_state_dict['state'][example_state_key][state_subkey], - pt.Tensor) - == isinstance( + pt.Tensor) == isinstance( opt_state_dict['state'][state_key][state_subkey], pt.Tensor)) @@ -152,10 +172,8 @@ def _derive_opt_state_dict(opt_state_dict): # tensor or not. state_subkey_tags = [] for state_subkey in state_subkeys: - if isinstance( - opt_state_dict['state'][example_state_key][state_subkey], - pt.Tensor - ): + if isinstance(opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor): state_subkey_tags.append('istensor') else: state_subkey_tags.append('') @@ -169,19 +187,18 @@ def _derive_opt_state_dict(opt_state_dict): for idx, param_id in enumerate(group['params']): for subkey, tag in state_subkeys_and_tags: if tag == 'istensor': - new_v = opt_state_dict['state'][param_id][ - subkey].cpu().numpy() + new_v = opt_state_dict['state'][param_id][subkey].cpu( + ).numpy() else: new_v = np.array( - [opt_state_dict['state'][param_id][subkey]] - ) - derived_opt_state_dict[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v + [opt_state_dict['state'][param_id][subkey]]) + derived_opt_state_dict[ + f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v nb_params_per_group.append(idx + 1) # group lengths are also helpful for reconstructing # original opt_state_dict structure derived_opt_state_dict['__opt_group_lengths'] = np.array( - nb_params_per_group - ) + nb_params_per_group) return derived_opt_state_dict @@ -196,10 +213,11 @@ def expand_derived_opt_state_dict(derived_opt_state_dict, device): prefix, "__opt_state_0_0_", certain to be present. Args: - derived_opt_state_dict: Optimizer state dictionary + derived_opt_state_dict (dict): The derived optimizer state dictionary. + device (str): The device to be used. Returns: - dict: Optimizer state dictionary + opt_state_dict (dict): The expanded optimizer state dictionary. """ state_subkeys_and_tags = [] for key in derived_opt_state_dict: @@ -215,14 +233,11 @@ def expand_derived_opt_state_dict(derived_opt_state_dict, device): opt_state_dict = {'param_groups': [], 'state': {}} nb_params_per_group = list( - derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32) - ) + derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32)) # Construct the expanded dict. for group_idx, nb_params in enumerate(nb_params_per_group): - these_group_ids = [ - f'{group_idx}_{idx}' for idx in range(nb_params) - ] + these_group_ids = [f'{group_idx}_{idx}' for idx in range(nb_params)] opt_state_dict['param_groups'].append({'params': these_group_ids}) for this_id in these_group_ids: opt_state_dict['state'][this_id] = {} @@ -248,8 +263,10 @@ def to_cpu_numpy(state): """Send data to CPU as Numpy array. Args: - state + state (dict): The state dictionary. + Returns: + state (dict): The state dictionary with all values as numpy arrays. 
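The derive/expand pair serializes optimizer state under flat keys of the form `__opt_state_{group}_{param}_{tag}_{subkey}`, plus an `__opt_group_lengths` array used to rebuild the parameter groups. The expansion sketch below assumes exactly that naming and only the `istensor` tag; the real code also reconstructs tensors on the target device and handles non-tensor subkeys and empty tags.

```python
import numpy as np

derived = {
    '__opt_state_0_0_istensor_momentum_buffer': np.zeros(3),
    '__opt_state_0_1_istensor_momentum_buffer': np.zeros(3),
    '__opt_state_1_0_istensor_momentum_buffer': np.zeros(3),
    '__opt_group_lengths': np.array([2, 1]),
}

def expand(derived_opt_state_dict: dict) -> dict:
    """Rebuild the {'param_groups': [...], 'state': {...}} layout."""
    d = dict(derived_opt_state_dict)
    group_lengths = list(d.pop('__opt_group_lengths').astype(np.int32))
    opt_state_dict = {'param_groups': [], 'state': {}}
    for group_idx, nb_params in enumerate(group_lengths):
        ids = [f'{group_idx}_{idx}' for idx in range(nb_params)]
        opt_state_dict['param_groups'].append({'params': ids})
        for this_id in ids:
            opt_state_dict['state'][this_id] = {}
    for key, value in d.items():
        # key layout assumed: __opt_state_<group>_<param>_<tag>_<subkey>
        group_idx, idx, tag, subkey = key[len('__opt_state_'):].split('_', 3)
        opt_state_dict['state'][f'{group_idx}_{idx}'][subkey] = value
    return opt_state_dict

expanded = expand(derived)
assert [len(g['params']) for g in expanded['param_groups']] == [2, 1]
assert 'momentum_buffer' in expanded['state']['0_0']
```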
""" # deep copy so as to decouple from active model state = deepcopy(state) diff --git a/openfl/plugins/interface_serializer/cloudpickle_serializer.py b/openfl/plugins/interface_serializer/cloudpickle_serializer.py index 98fb888c26..5c70ce9a7e 100644 --- a/openfl/plugins/interface_serializer/cloudpickle_serializer.py +++ b/openfl/plugins/interface_serializer/cloudpickle_serializer.py @@ -16,12 +16,29 @@ def __init__(self) -> None: @staticmethod def serialize(object_, filename): - """Serialize an object and save to disk.""" + """Serialize an object and save to disk. + + Args: + object_ (object): The object to be serialized. + filename (str): The name of the file where the serialized object + will be saved. + + Returns: + None + """ with open(filename, 'wb') as f: cloudpickle.dump(object_, f) @staticmethod def restore_object(filename): - """Load and deserialize an object.""" + """Load and deserialize an object. + + Args: + filename (str): The name of the file where the serialized object + is saved. + + Returns: + object: The deserialized object. + """ with open(filename, 'rb') as f: return cloudpickle.load(f) diff --git a/openfl/plugins/interface_serializer/dill_serializer.py b/openfl/plugins/interface_serializer/dill_serializer.py index f4bb9ffd58..50b043babc 100644 --- a/openfl/plugins/interface_serializer/dill_serializer.py +++ b/openfl/plugins/interface_serializer/dill_serializer.py @@ -16,12 +16,29 @@ def __init__(self) -> None: @staticmethod def serialize(object_, filename): - """Serialize an object and save to disk.""" + """Serialize an object and save to disk. + + Args: + object_ (object): The object to be serialized. + filename (str): The name of the file where the serialized object + will be saved. + + Returns: + None + """ with open(filename, 'wb') as f: dill.dump(object_, f, recurse=True) @staticmethod def restore_object(filename): - """Load and deserialize an object.""" + """Load and deserialize an object. + + Args: + filename (str): The name of the file where the serialized object + is saved. + + Returns: + object: The deserialized object. + """ with open(filename, 'rb') as f: return dill.load(f) # nosec diff --git a/openfl/plugins/interface_serializer/keras_serializer.py b/openfl/plugins/interface_serializer/keras_serializer.py index ec36f38d25..e5a0fb48bc 100644 --- a/openfl/plugins/interface_serializer/keras_serializer.py +++ b/openfl/plugins/interface_serializer/keras_serializer.py @@ -16,13 +16,30 @@ def __init__(self) -> None: @staticmethod def serialize(object_, filename): - """Serialize an object and save to disk.""" + """Serialize an object and save to disk. + + Args: + object_ (object): The object to be serialized. + filename (str): The name of the file where the serialized object + will be saved. + + Returns: + None + """ with open(filename, 'wb') as f: cloudpickle.dump(object_, f) @staticmethod def restore_object(filename): - """Load and deserialize an object.""" + """Load and deserialize an object. + + Args: + filename (str): The name of the file where the serialized object + is saved. + + Returns: + object: The deserialized object. 
+ """ from tensorflow.keras.optimizers.legacy import Optimizer def build(self, var_list): diff --git a/openfl/plugins/interface_serializer/serializer_interface.py b/openfl/plugins/interface_serializer/serializer_interface.py index b72d970a1c..becdbd0774 100644 --- a/openfl/plugins/interface_serializer/serializer_interface.py +++ b/openfl/plugins/interface_serializer/serializer_interface.py @@ -12,10 +12,39 @@ def __init__(self) -> None: @staticmethod def serialize(object_, filename): - """Serialize an object and save to disk.""" + """Serialize an object and save to disk. + + This is a static method that is not implemented. + + Args: + object_ (object): The object to be serialized. + filename (str): The name of the file where the serialized object + will be saved. + + Returns: + None + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError @staticmethod def restore_object(filename): - """Load and deserialize an object.""" + """Load and deserialize an object. + + This is a static method that is not implemented. + + Args: + filename (str): The name of the file where the serialized object + is saved. + + Returns: + object: The deserialized object. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError diff --git a/openfl/plugins/processing_units_monitor/cuda_device_monitor.py b/openfl/plugins/processing_units_monitor/cuda_device_monitor.py index 4cf9d8b8e5..c6a20bbe83 100644 --- a/openfl/plugins/processing_units_monitor/cuda_device_monitor.py +++ b/openfl/plugins/processing_units_monitor/cuda_device_monitor.py @@ -9,29 +9,99 @@ class CUDADeviceMonitor(DeviceMonitor): """CUDA Device monitor plugin.""" def get_driver_version(self) -> str: - """Get CUDA driver version.""" + """Get CUDA driver version. + + This method is not implemented. + + Returns: + str: The CUDA driver version. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError def get_device_memory_total(self, index: int) -> int: - """Get total memory available on the device.""" + """Get total memory available on the device. + + This method is not implemented. + + Args: + index (int): The index of the device. + + Returns: + int: The total memory available on the device. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError def get_device_memory_utilized(self, index: int) -> int: - """Get utilized memory on the device.""" + """Get utilized memory on the device. + + This method is not implemented. + + Args: + index (int): The index of the device. + + Returns: + int: The utilized memory on the device. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError def get_device_utilization(self, index: int) -> str: - """ - Get device utilization. + """Get device utilization. + + It is just a general method that returns a string that may be shown to + the frontend user. + This method is not implemented. + + Args: + index (int): The index of the device. + + Returns: + str: The device utilization. - It is just a general method that returns a string that may be shown to the frontend user. + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. 
""" raise NotImplementedError def get_device_name(self, index: int) -> str: - """Get device name.""" + """Get device name. + + This method is not implemented. + + Args: + index (int): The index of the device. + + Returns: + str: The device name. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError def get_cuda_version(self) -> str: - """Get CUDA driver version.""" + """Get CUDA driver version. + + This method is not implemented. + + Returns: + str: The CUDA driver version. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError diff --git a/openfl/plugins/processing_units_monitor/device_monitor.py b/openfl/plugins/processing_units_monitor/device_monitor.py index c1ffe991db..5154d91a63 100644 --- a/openfl/plugins/processing_units_monitor/device_monitor.py +++ b/openfl/plugins/processing_units_monitor/device_monitor.py @@ -7,13 +7,33 @@ class DeviceMonitor: """Device monitor plugin interface.""" def get_driver_version(self) -> str: - """Get device's driver version.""" + """Get device's driver version. + + This method is not implemented. + + Returns: + str: The device's driver version. + + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. + """ raise NotImplementedError def get_device_utilization(self, index: int) -> str: - """ - Get device utilization method. + """Get device utilization method. + + It is just a general method that returns a string that may be shown to + the frontend user. + + Args: + index (int): The index of the device. + + Returns: + str: The device utilization. - It is just a general method that returns a string that may be shown to the frontend user. + Raises: + NotImplementedError: This is a placeholder method that needs to be + implemented in subclasses. """ raise NotImplementedError diff --git a/openfl/plugins/processing_units_monitor/pynvml_monitor.py b/openfl/plugins/processing_units_monitor/pynvml_monitor.py index e7f34e0a12..930862b8cf 100644 --- a/openfl/plugins/processing_units_monitor/pynvml_monitor.py +++ b/openfl/plugins/processing_units_monitor/pynvml_monitor.py @@ -1,7 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -""" -pynvml CUDA Device monitor plugin module. +"""Pynvml CUDA Device monitor plugin module. Required package: pynvml """ @@ -21,44 +20,77 @@ def __init__(self) -> None: super().__init__() def get_driver_version(self) -> str: - """Get Nvidia driver version.""" + """Get Nvidia driver version. + + Returns: + str: The Nvidia driver version. + """ return pynvml.nvmlSystemGetDriverVersion().decode('utf-8') def get_device_memory_total(self, index: int) -> int: - """Get total memory available on the device.""" + """Get total memory available on the device. + + Args: + index (int): The index of the device. + + Returns: + int: The total memory available on the device. + """ handle = pynvml.nvmlDeviceGetHandleByIndex(index) info = pynvml.nvmlDeviceGetMemoryInfo(handle) return info.total def get_device_memory_utilized(self, index: int) -> int: - """Get utilized memory on the device.""" + """Get utilized memory on the device. + + Args: + index (int): The index of the device. + + Returns: + int: The utilized memory on the device. 
+ """ handle = pynvml.nvmlDeviceGetHandleByIndex(index) info = pynvml.nvmlDeviceGetMemoryInfo(handle) return info.used def get_device_utilization(self, index: int) -> str: - """ - Get device utilization. + """Get device utilization. + + It is just a general method that returns a string that may be shown to + the frontend user. - It is just a general method that returns a string that may be shown to the frontend user. + Args: + index (int): The index of the device. + + Returns: + str: The device utilization. """ handle = pynvml.nvmlDeviceGetHandleByIndex(index) info_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) return f'{info_utilization.gpu}%' def get_device_name(self, index: int) -> str: - """Get device utilization method.""" + """Get device utilization method. + + Args: + index (int): The index of the device. + + Returns: + device_name (str): The device name. + """ handle = pynvml.nvmlDeviceGetHandleByIndex(index) device_name = pynvml.nvmlDeviceGetName(handle) return device_name def get_cuda_version(self) -> str: - """ - Get CUDA driver version. + """Get CUDA driver version. The CUDA version is specified as (1000 * major + 10 * minor), so CUDA 11.2 should be specified as 11020. https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DRIVER__ENTRY__POINT.html + + Returns: + str: The CUDA driver version. """ cuda_version = pynvml.nvmlSystemGetCudaDriverVersion() major_version = int(cuda_version / 1000) diff --git a/openfl/protocols/interceptors.py b/openfl/protocols/interceptors.py index a54ff76d82..c9773bf627 100644 --- a/openfl/protocols/interceptors.py +++ b/openfl/protocols/interceptors.py @@ -1,7 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""gRPC interceptors module.""" +"""GRPC interceptors module.""" import collections import grpc @@ -15,16 +14,17 @@ class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, def __init__(self, interceptor_function): self._fn = interceptor_function - def intercept_unary_unary(self, continuation, client_call_details, request): + def intercept_unary_unary(self, continuation, client_call_details, + request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, False) + client_call_details, iter((request, )), False, False) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response def intercept_unary_stream(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, True) + client_call_details, iter((request, )), False, True) response_it = continuation(new_details, next(new_request_iterator)) return postprocess(response_it) if postprocess else response_it @@ -48,20 +48,18 @@ def _create_generic_interceptor(intercept_call): class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials') - ), - grpc.ClientCallDetails -): + collections.namedtuple( + '_ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials')), + grpc.ClientCallDetails): pass def headers_adder(headers): """Create interceptor with added headers.""" - def intercept_call(client_call_details, request_iterator, request_streaming, - response_streaming): + def intercept_call(client_call_details, request_iterator, + request_streaming, response_streaming): metadata = [] if client_call_details.metadata is not None: metadata = 
list(client_call_details.metadata) diff --git a/openfl/protocols/utils.py b/openfl/protocols/utils.py index fc6edc7bae..3c0cd1830e 100644 --- a/openfl/protocols/utils.py +++ b/openfl/protocols/utils.py @@ -10,11 +10,14 @@ def model_proto_to_bytes_and_metadata(model_proto): """Convert the model protobuf to bytes and metadata. Args: - model_proto: Protobuf of the model + model_proto: The protobuf of the model. Returns: - bytes_dict: Dictionary of the bytes contained in the model protobuf - metadata_dict: Dictionary of the meta data in the model protobuf + bytes_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensor data in bytes. + metadata_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensor metadata. + round_number: The round number for the model. """ bytes_dict = {} metadata_dict = {} @@ -25,22 +28,32 @@ def model_proto_to_bytes_and_metadata(model_proto): 'int_to_float': proto.int_to_float, 'int_list': proto.int_list, 'bool_list': proto.bool_list - } - for proto in tensor_proto.transformer_metadata - ] + } for proto in tensor_proto.transformer_metadata] if round_number is None: round_number = tensor_proto.round_number else: assert round_number == tensor_proto.round_number, ( f'Round numbers in model are inconsistent: {round_number} ' - f'and {tensor_proto.round_number}' - ) + f'and {tensor_proto.round_number}') return bytes_dict, metadata_dict, round_number def bytes_and_metadata_to_model_proto(bytes_dict, model_id, model_version, is_delta, metadata_dict): - """Convert bytes and metadata to model protobuf.""" + """Convert bytes and metadata to model protobuf. + + Args: + bytes_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensor data in bytes. + model_id: The ID of the model. + model_version: The version of the model. + is_delta: A flag indicating whether the model is a delta model. + metadata_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensor metadata. + + Returns: + model_proto: The protobuf of the model. + """ model_header = ModelHeader(id=model_id, version=model_version, is_delta=is_delta) # NOQA:F821 tensor_protos = [] @@ -62,19 +75,31 @@ def bytes_and_metadata_to_model_proto(bytes_dict, model_id, model_version, bool_list = metadata.get('bool_list') else: bool_list = [] - metadata_protos.append(base_pb2.MetadataProto( - int_to_float=int_to_float, - int_list=int_list, - bool_list=bool_list, - )) + metadata_protos.append( + base_pb2.MetadataProto( + int_to_float=int_to_float, + int_list=int_list, + bool_list=bool_list, + )) tensor_protos.append(TensorProto(name=key, # NOQA:F821 data_bytes=data_bytes, transformer_metadata=metadata_protos)) return base_pb2.ModelProto(header=model_header, tensors=tensor_protos) -def construct_named_tensor(tensor_key, nparray, transformer_metadata, lossless): - """Construct named tensor.""" +def construct_named_tensor(tensor_key, nparray, transformer_metadata, + lossless): + """Construct named tensor. + + Args: + tensor_key: The key of the tensor. + nparray: The numpy array representing the tensor data. + transformer_metadata: The transformer metadata for the tensor. + lossless: A flag indicating whether the tensor is lossless. + + Returns: + named_tensor: The named tensor. 
+ """ metadata_protos = [] for metadata in transformer_metadata: if metadata.get('int_to_float') is not None: @@ -91,11 +116,12 @@ def construct_named_tensor(tensor_key, nparray, transformer_metadata, lossless): bool_list = metadata.get('bool_list') else: bool_list = [] - metadata_protos.append(base_pb2.MetadataProto( - int_to_float=int_to_float, - int_list=int_list, - bool_list=bool_list, - )) + metadata_protos.append( + base_pb2.MetadataProto( + int_to_float=int_to_float, + int_list=int_list, + bool_list=bool_list, + )) tensor_name, origin, round_number, report, tags = tensor_key @@ -110,66 +136,112 @@ def construct_named_tensor(tensor_key, nparray, transformer_metadata, lossless): ) -def construct_proto(tensor_dict, model_id, model_version, is_delta, compression_pipeline): - """Construct proto.""" +def construct_proto(tensor_dict, model_id, model_version, is_delta, + compression_pipeline): + """Construct proto. + + Args: + tensor_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensors. + model_id: The ID of the model. + model_version: The version of the model. + is_delta: A flag indicating whether the model is a delta model. + compression_pipeline: The compression pipeline for the model. + + Returns: + model_proto: The protobuf of the model. + """ # compress the arrays in the tensor_dict, and form the model proto # TODO: Hold-out tensors from the compression pipeline. bytes_dict = {} metadata_dict = {} for key, array in tensor_dict.items(): - bytes_dict[key], metadata_dict[key] = compression_pipeline.forward(data=array) - - # convert the compressed_tensor_dict and metadata to protobuf, and make the new model proto - model_proto = bytes_and_metadata_to_model_proto(bytes_dict=bytes_dict, - model_id=model_id, - model_version=model_version, - is_delta=is_delta, - metadata_dict=metadata_dict) + bytes_dict[key], metadata_dict[key] = compression_pipeline.forward( + data=array) + + # convert the compressed_tensor_dict and metadata to protobuf, and make + # the new model proto + model_proto = bytes_and_metadata_to_model_proto( + bytes_dict=bytes_dict, + model_id=model_id, + model_version=model_version, + is_delta=is_delta, + metadata_dict=metadata_dict) return model_proto def construct_model_proto(tensor_dict, round_number, tensor_pipe): - """Construct model proto from tensor dict.""" + """Construct model proto from tensor dict. + + Args: + tensor_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensors. + round_number: The round number for the model. + tensor_pipe: The tensor pipe for the model. + + Returns: + model_proto: The protobuf of the model. + """ # compress the arrays in the tensor_dict, and form the model proto # TODO: Hold-out tensors from the tensor compression pipeline. named_tensors = [] for key, nparray in tensor_dict.items(): bytes_data, transformer_metadata = tensor_pipe.forward(data=nparray) - tensor_key = TensorKey(key, 'agg', round_number, False, ('model',)) - named_tensors.append(construct_named_tensor( - tensor_key, - bytes_data, - transformer_metadata, - lossless=True, - )) + tensor_key = TensorKey(key, 'agg', round_number, False, ('model', )) + named_tensors.append( + construct_named_tensor( + tensor_key, + bytes_data, + transformer_metadata, + lossless=True, + )) return base_pb2.ModelProto(tensors=named_tensors) def deconstruct_model_proto(model_proto, compression_pipeline): - """Deconstruct model proto.""" + """Deconstruct model proto. 
+ + This function takes a model protobuf and a compression pipeline, + and deconstructs the protobuf into a dictionary of tensors and a round + number. + + Args: + model_proto: The protobuf of the model. + compression_pipeline: The compression pipeline for the model. + + Returns: + tensor_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensors. + round_number: The round number for the model. + """ # extract the tensor_dict and metadata - bytes_dict, metadata_dict, round_number = model_proto_to_bytes_and_metadata(model_proto) + bytes_dict, metadata_dict, round_number = model_proto_to_bytes_and_metadata( + model_proto) # decompress the tensors # TODO: Handle tensors meant to be held-out from the compression pipeline # (currently none are held out). tensor_dict = {} for key in bytes_dict: - tensor_dict[key] = compression_pipeline.backward(data=bytes_dict[key], - transformer_metadata=metadata_dict[key]) + tensor_dict[key] = compression_pipeline.backward( + data=bytes_dict[key], transformer_metadata=metadata_dict[key]) return tensor_dict, round_number def deconstruct_proto(model_proto, compression_pipeline): """Deconstruct the protobuf. + This function takes a model protobuf and a compression pipeline, and + deconstructs the protobuf into a dictionary of tensors. + Args: - model_proto: The protobuf of the model - compression_pipeline: The compression pipeline object + model_proto: The protobuf of the model. + compression_pipeline: The compression pipeline for the model. Returns: - protobuf: A protobuf of the model + tensor_dict: A dictionary where the keys are tensor names and the + values are the corresponding tensors. """ # extract the tensor_dict and metadata bytes_dict, metadata_dict = model_proto_to_bytes_and_metadata(model_proto) @@ -179,8 +251,8 @@ def deconstruct_proto(model_proto, compression_pipeline): # (currently none are held out). tensor_dict = {} for key in bytes_dict: - tensor_dict[key] = compression_pipeline.backward(data=bytes_dict[key], - transformer_metadata=metadata_dict[key]) + tensor_dict[key] = compression_pipeline.backward( + data=bytes_dict[key], transformer_metadata=metadata_dict[key]) return tensor_dict @@ -188,10 +260,10 @@ def load_proto(fpath): """Load the protobuf. Args: - fpath: The filepath for the protobuf + fpath: The file path of the protobuf. Returns: - protobuf: A protobuf of the model + model: The protobuf of the model. """ with open(fpath, 'rb') as f: loaded = f.read() @@ -203,9 +275,8 @@ def dump_proto(model_proto, fpath): """Dump the protobuf to a file. Args: - model_proto: The protobuf of the model - fpath: The filename to save the model protobuf - + model_proto: The protobuf of the model. + fpath: The file path to dump the protobuf. """ s = model_proto.SerializeToString() with open(fpath, 'wb') as f: @@ -216,12 +287,12 @@ def datastream_to_proto(proto, stream, logger=None): """Convert the datastream to the protobuf. Args: - model_proto: The protobuf of the model - stream: The data stream from the remote connection - logger: (Optional) The log object + proto: The protobuf to be filled with the data stream. + stream: The data stream. + logger (optional): The logger for logging information. Returns: - protobuf: A protobuf of the model + proto: The protobuf filled with the data stream. 
""" npbytes = b'' for chunk in stream: @@ -233,30 +304,43 @@ def datastream_to_proto(proto, stream, logger=None): logger.debug(f'datastream_to_proto parsed a {type(proto)}.') return proto else: - raise RuntimeError(f'Received empty stream message of type {type(proto)}') + raise RuntimeError( + f'Received empty stream message of type {type(proto)}') def proto_to_datastream(proto, logger, max_buffer_size=(2 * 1024 * 1024)): """Convert the protobuf to the datastream for the remote connection. Args: - model_proto: The protobuf of the model - logger: The log object - max_buffer_size: The buffer size (Default= 2*1024*1024) - Returns: - reply: The message for the remote connection. + proto: The protobuf to be converted into a data stream. + logger: The logger for logging information. + max_buffer_size (optional): The maximum buffer size for the data + stream. Defaults to 2*1024*1024. + + Yields: + reply: Chunks of the data stream for the remote connection. """ npbytes = proto.SerializeToString() data_size = len(npbytes) buffer_size = data_size if max_buffer_size > data_size else max_buffer_size - logger.debug(f'Setting stream chunks with size {buffer_size} for proto of type {type(proto)}') + logger.debug( + f'Setting stream chunks with size {buffer_size} for proto of type {type(proto)}' + ) for i in range(0, data_size, buffer_size): - chunk = npbytes[i: i + buffer_size] + chunk = npbytes[i:i + buffer_size] reply = base_pb2.DataStream(npbytes=chunk, size=len(chunk)) yield reply def get_headers(context) -> dict: - """Get headers from context.""" + """Get headers from context. + + Args: + context: The context containing the headers. + + Returns: + headers: A dictionary where the keys are header names and the + values are the corresponding header values. + """ return {header[0]: header[1] for header in context.invocation_metadata()} diff --git a/openfl/transport/__init__.py b/openfl/transport/__init__.py index 474178432f..d3b0cc477d 100644 --- a/openfl/transport/__init__.py +++ b/openfl/transport/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.transport package.""" from .grpc import AggregatorGRPCClient diff --git a/openfl/transport/grpc/__init__.py b/openfl/transport/grpc/__init__.py index 784c9acf66..26e0c7d712 100644 --- a/openfl/transport/grpc/__init__.py +++ b/openfl/transport/grpc/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.transport.grpc package.""" from .aggregator_client import AggregatorGRPCClient diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index b6de77eb1e..ec9c9f3e59 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """AggregatorGRPCClient module.""" import time @@ -20,10 +19,26 @@ class ConstantBackoff: - """Constant Backoff policy.""" + """Constant Backoff policy. + + This class implements a backoff policy that waits for a constant amount of + time between retries. + + Attributes: + reconnect_interval (int): The interval between connection attempts. + logger (Logger): The logger to use for reporting connection attempts. + uri (str): The URI to connect to. + """ def __init__(self, reconnect_interval, logger, uri): - """Initialize Constant Backoff.""" + """Initialize Constant Backoff. 
+ + Args: + reconnect_interval (int): The interval between connection attempts. + logger (Logger): The logger to use for reporting connection + attempts. + uri (str): The URI to connect to. + """ self.reconnect_interval = reconnect_interval self.logger = logger self.uri = uri @@ -34,51 +49,99 @@ def sleep(self): time.sleep(self.reconnect_interval) -class RetryOnRpcErrorClientInterceptor( - grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor -): - """Retry gRPC connection on failure.""" +class RetryOnRpcErrorClientInterceptor(grpc.UnaryUnaryClientInterceptor, + grpc.StreamUnaryClientInterceptor): + """Retry gRPC connection on failure. + + This class implements a gRPC client interceptor that retries failed RPC + calls. + + Attributes: + sleeping_policy (ConstantBackoff): The backoff policy to use between + retries. + status_for_retry (Tuple[grpc.StatusCode]): The gRPC status codes that + should trigger a retry. + """ def __init__( - self, - sleeping_policy, - status_for_retry: Optional[Tuple[grpc.StatusCode]] = None, + self, + sleeping_policy, + status_for_retry: Optional[Tuple[grpc.StatusCode]] = None, ): - """Initialize function for gRPC retry.""" + """Initialize function for gRPC retry. + + Args: + sleeping_policy (ConstantBackoff): The backoff policy to use + between retries. + status_for_retry (Tuple[grpc.StatusCode], optional): The gRPC + status codes that should trigger a retry. + """ self.sleeping_policy = sleeping_policy self.status_for_retry = status_for_retry - def _intercept_call(self, continuation, client_call_details, request_or_iterator): - """Intercept the call to the gRPC server.""" + def _intercept_call(self, continuation, client_call_details, + request_or_iterator): + """Intercept the call to the gRPC server. + + Args: + continuation (function): The original RPC call. + client_call_details (grpc.ClientCallDetails): The details of the + call. + request_or_iterator (object): The request message for the RPC call. + + Returns: + response (grpc.Call): The result of the RPC call. + """ while True: response = continuation(client_call_details, request_or_iterator) if isinstance(response, grpc.RpcError): # If status code is not in retryable status codes - self.sleeping_policy.logger.info(f'Response code: {response.code()}') - if ( - self.status_for_retry - and response.code() not in self.status_for_retry - ): + self.sleeping_policy.logger.info( + f'Response code: {response.code()}') + if (self.status_for_retry + and response.code() not in self.status_for_retry): return response self.sleeping_policy.sleep() else: return response - def intercept_unary_unary(self, continuation, client_call_details, request): - """Wrap intercept call for unary->unary RPC.""" + def intercept_unary_unary(self, continuation, client_call_details, + request): + """Wrap intercept call for unary->unary RPC. + + Args: + continuation (function): The original RPC call. + client_call_details (grpc.ClientCallDetails): The details of the + call. + request (object): The request message for the RPC call. + + Returns: + grpc.Call: The result of the RPC call. + """ return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary( - self, continuation, client_call_details, request_iterator - ): - """Wrap intercept call for stream->unary RPC.""" - return self._intercept_call(continuation, client_call_details, request_iterator) + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + """Wrap intercept call for stream->unary RPC. 
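To take effect, the retry interceptor has to be wrapped around a channel with `grpc.intercept_channel`, exactly as the client does further down. The wiring sketch below assumes the two classes are importable from `openfl.transport.grpc.aggregator_client` (the module shown in this diff) and uses a placeholder URI.

```python
import logging
import grpc

from openfl.transport.grpc.aggregator_client import (ConstantBackoff,
                                                     RetryOnRpcErrorClientInterceptor)

logger = logging.getLogger('retry-demo')
uri = 'localhost:50051'  # placeholder aggregator address

retry = RetryOnRpcErrorClientInterceptor(
    sleeping_policy=ConstantBackoff(reconnect_interval=1, logger=logger, uri=uri),
    status_for_retry=(grpc.StatusCode.UNAVAILABLE,),
)
channel = grpc.intercept_channel(grpc.insecure_channel(uri), retry)
# Any stub built on `channel` now retries calls that fail with UNAVAILABLE,
# sleeping a constant interval between attempts, e.g.:
# stub = aggregator_pb2_grpc.AggregatorStub(channel)
```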
+ + Args: + continuation (function): The original RPC call. + client_call_details (grpc.ClientCallDetails): The details of the + call. + request_iterator (iterator): The request messages for the RPC call. + + Returns: + grpc.Call: The result of the RPC call. + """ + return self._intercept_call(continuation, client_call_details, + request_iterator) def _atomic_connection(func): + def wrapper(self, *args, **kwargs): self.reconnect() response = func(self, *args, **kwargs) @@ -89,6 +152,7 @@ def wrapper(self, *args, **kwargs): def _resend_data_on_reconnection(func): + def wrapper(self, *args, **kwargs): while True: try: @@ -108,7 +172,27 @@ def wrapper(self, *args, **kwargs): class AggregatorGRPCClient: - """Client to the aggregator over gRPC-TLS.""" + """Client to the aggregator over gRPC-TLS. + + This class implements a gRPC client for communicating with an aggregator + over a secure (TLS) connection. + + Attributes: + uri (str): The URI of the aggregator. + tls (bool): Whether to use TLS for the connection. + disable_client_auth (bool): Whether to disable client-side + authentication. + root_certificate (str): The path to the root certificate for the TLS + connection. + certificate (str): The path to the client's certificate for the TLS + connection. + private_key (str): The path to the client's private key for the TLS + connection. + aggregator_uuid (str): The UUID of the aggregator. + federation_uuid (str): The UUID of the federation. + single_col_cert_common_name (str): The common name on the + collaborator's certificate. + """ def __init__(self, agg_addr, @@ -122,7 +206,26 @@ def __init__(self, federation_uuid=None, single_col_cert_common_name=None, **kwargs): - """Initialize.""" + """Initialize. + + Args: + agg_addr (str): The address of the aggregator. + agg_port (int): The port of the aggregator. + tls (bool): Whether to use TLS for the connection. + disable_client_auth (bool): Whether to disable client-side + authentication. + root_certificate (str): The path to the root certificate for the + TLS connection. + certificate (str): The path to the client's certificate for the + TLS connection. + private_key (str): The path to the client's private key for the + TLS connection. + aggregator_uuid (str,optional): The UUID of the aggregator. + federation_uuid (str, optional): The UUID of the federation. + single_col_cert_common_name (str, optional): The common name on + the collaborator's certificate. + **kwargs: Additional keyword arguments. 
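+
+        Example:
+            Illustrative sketch only; the address, UUIDs and certificate
+            paths below are placeholders::
+
+                client = AggregatorGRPCClient(
+                    agg_addr='agg.example.com',
+                    agg_port=50051,
+                    tls=True,
+                    disable_client_auth=False,
+                    root_certificate='cert/cert_chain.crt',
+                    certificate='cert/col_one.crt',
+                    private_key='cert/col_one.key',
+                    aggregator_uuid='aggregator_uuid',
+                    federation_uuid='federation_uuid')
+                tasks, round_number, sleep_time, time_to_quit = \
+                    client.get_tasks('col_one')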
+ """ self.uri = f'{agg_addr}:{agg_port}' self.tls = tls self.disable_client_auth = disable_client_auth @@ -137,13 +240,11 @@ def __init__(self, 'gRPC is running on insecure channel with TLS disabled.') self.channel = self.create_insecure_channel(self.uri) else: - self.channel = self.create_tls_channel( - self.uri, - self.root_certificate, - self.disable_client_auth, - self.certificate, - self.private_key - ) + self.channel = self.create_tls_channel(self.uri, + self.root_certificate, + self.disable_client_auth, + self.certificate, + self.private_key) self.header = None self.aggregator_uuid = aggregator_uuid @@ -151,49 +252,46 @@ def __init__(self, self.single_col_cert_common_name = single_col_cert_common_name # Adding an interceptor for RPC Errors - self.interceptors = ( - RetryOnRpcErrorClientInterceptor( - sleeping_policy=ConstantBackoff( - logger=self.logger, - reconnect_interval=int(kwargs.get('client_reconnect_interval', 1)), - uri=self.uri), - status_for_retry=(grpc.StatusCode.UNAVAILABLE,), - ), - ) + self.interceptors = (RetryOnRpcErrorClientInterceptor( + sleeping_policy=ConstantBackoff( + logger=self.logger, + reconnect_interval=int( + kwargs.get('client_reconnect_interval', 1)), + uri=self.uri), + status_for_retry=(grpc.StatusCode.UNAVAILABLE, ), + ), ) self.stub = aggregator_pb2_grpc.AggregatorStub( - grpc.intercept_channel(self.channel, *self.interceptors) - ) + grpc.intercept_channel(self.channel, *self.interceptors)) def create_insecure_channel(self, uri): - """ - Set an insecure gRPC channel (i.e. no TLS) if desired. + """Set an insecure gRPC channel (i.e. no TLS) if desired. Warns user that this is not recommended. Args: - uri: The uniform resource identifier fo the insecure channel + uri (str): The uniform resource identifier for the insecure channel Returns: - An insecure gRPC channel object - + grpc.Channel: An insecure gRPC channel object """ return grpc.insecure_channel(uri, options=channel_options) def create_tls_channel(self, uri, root_certificate, disable_client_auth, certificate, private_key): - """ - Set an secure gRPC channel (i.e. TLS). + """Set an secure gRPC channel (i.e. TLS). Args: - uri: The uniform resource identifier fo the insecure channel - root_certificate: The Certificate Authority filename - disable_client_auth (boolean): True disabled client-side - authentication (not recommended, throws warning to user) - certificate: The client certficate filename from the collaborator - (signed by the certificate authority) + uri (str): The uniform resource identifier for the secure channel. + root_certificate (str): The Certificate Authority filename. + disable_client_auth (bool): True disables client-side + authentication (not recommended, throws warning to user). + certificate (str): The client certificate filename from the + collaborator (signed by the certificate authority). + private_key (str): The private key filename for the client + certificate. Returns: - An insecure gRPC channel object + grpc.Channel: A secure gRPC channel object """ with open(root_certificate, 'rb') as f: root_certificate_b = f.read() @@ -214,36 +312,38 @@ def create_tls_channel(self, uri, root_certificate, disable_client_auth, certificate_chain=certificate_b, ) - return grpc.secure_channel( - uri, credentials, options=channel_options) + return grpc.secure_channel(uri, credentials, options=channel_options) def _set_header(self, collaborator_name): + """Set the header for gRPC messages. + + Args: + collaborator_name (str): The name of the collaborator. 
+ """ self.header = aggregator_pb2.MessageHeader( sender=collaborator_name, receiver=self.aggregator_uuid, federation_uuid=self.federation_uuid, - single_col_cert_common_name=self.single_col_cert_common_name or '' - ) + single_col_cert_common_name=self.single_col_cert_common_name or '') def validate_response(self, reply, collaborator_name): - """Validate the aggregator response.""" + """Validate the aggregator response. + + Args: + reply (aggregator_pb2.MessageReply): The reply from the aggregator. + collaborator_name (str): The name of the collaborator. + """ # check that the message was intended to go to this collaborator check_equal(reply.header.receiver, collaborator_name, self.logger) check_equal(reply.header.sender, self.aggregator_uuid, self.logger) # check that federation id matches - check_equal( - reply.header.federation_uuid, - self.federation_uuid, - self.logger - ) + check_equal(reply.header.federation_uuid, self.federation_uuid, + self.logger) # check that there is aggrement on the single_col_cert_common_name - check_equal( - reply.header.single_col_cert_common_name, - self.single_col_cert_common_name or '', - self.logger - ) + check_equal(reply.header.single_col_cert_common_name, + self.single_col_cert_common_name or '', self.logger) def disconnect(self): """Close the gRPC channel.""" @@ -252,30 +352,37 @@ def disconnect(self): def reconnect(self): """Create a new channel with the gRPC server.""" - # channel.close() is idempotent. Call again here in case it wasn't issued previously + # channel.close() is idempotent. Call again here in case it wasn't + # issued previously self.disconnect() if not self.tls: self.channel = self.create_insecure_channel(self.uri) else: - self.channel = self.create_tls_channel( - self.uri, - self.root_certificate, - self.disable_client_auth, - self.certificate, - self.private_key - ) + self.channel = self.create_tls_channel(self.uri, + self.root_certificate, + self.disable_client_auth, + self.certificate, + self.private_key) self.logger.debug(f'Connecting to gRPC at {self.uri}') self.stub = aggregator_pb2_grpc.AggregatorStub( - grpc.intercept_channel(self.channel, *self.interceptors) - ) + grpc.intercept_channel(self.channel, *self.interceptors)) @_atomic_connection @_resend_data_on_reconnection def get_tasks(self, collaborator_name): - """Get tasks from the aggregator.""" + """Get tasks from the aggregator. + + Args: + collaborator_name (str): The name of the collaborator. + + Returns: + Tuple[List[str], int, int, bool]: A tuple containing a list of + tasks, the round number, the sleep time, and a boolean + indicating whether to quit. + """ self._set_header(collaborator_name) request = aggregator_pb2.GetTasksRequest(header=self.header) response = self.stub.GetTasks(request) @@ -285,9 +392,21 @@ def get_tasks(self, collaborator_name): @_atomic_connection @_resend_data_on_reconnection - def get_aggregated_tensor(self, collaborator_name, tensor_name, round_number, - report, tags, require_lossless): - """Get aggregated tensor from the aggregator.""" + def get_aggregated_tensor(self, collaborator_name, tensor_name, + round_number, report, tags, require_lossless): + """Get aggregated tensor from the aggregator. + + Args: + collaborator_name (str): The name of the collaborator. + tensor_name (str): The name of the tensor. + round_number (int): The round number. + report (str): The report. + tags (List[str]): The tags. + require_lossless (bool): Whether lossless compression is required. + + Returns: + aggregator_pb2.TensorProto: The aggregated tensor. 
+ """ self._set_header(collaborator_name) request = aggregator_pb2.GetAggregatedTensorRequest( @@ -296,8 +415,7 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, round_number, round_number=round_number, report=report, tags=tags, - require_lossless=require_lossless - ) + require_lossless=require_lossless) response = self.stub.GetAggregatedTensor(request) # also do other validation, like on the round_number self.validate_response(response, collaborator_name) @@ -308,15 +426,22 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, round_number, @_resend_data_on_reconnection def send_local_task_results(self, collaborator_name, round_number, task_name, data_size, named_tensors): - """Send task results to the aggregator.""" + """Send task results to the aggregator. + + Args: + collaborator_name (str): The name of the collaborator. + round_number (int): The round number. + task_name (str): The name of the task. + data_size (int): The size of the data. + named_tensors (List[aggregator_pb2.NamedTensorProto]): The list of + named tensors. + """ self._set_header(collaborator_name) - request = aggregator_pb2.TaskResults( - header=self.header, - round_number=round_number, - task_name=task_name, - data_size=data_size, - tensors=named_tensors - ) + request = aggregator_pb2.TaskResults(header=self.header, + round_number=round_number, + task_name=task_name, + data_size=data_size, + tensors=named_tensors) # convert (potentially) long list of tensors into stream stream = [] @@ -327,7 +452,15 @@ def send_local_task_results(self, collaborator_name, round_number, self.validate_response(response, collaborator_name) def _get_trained_model(self, experiment_name, model_type): - """Get trained model RPC.""" + """Get trained model RPC. + + Args: + experiment_name (str): The name of the experiment. + model_type (str): The type of the model. + + Returns: + Dict[str, numpy.ndarray]: The trained model. + """ get_model_request = self.stub.GetTrainedModelRequest( experiment_name=experiment_name, model_type=model_type, diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index 39fde16445..93088cd8b0 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """AggregatorGRPCServer module.""" import logging @@ -25,7 +24,26 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer): - """gRPC server class for the Aggregator.""" + """GRPC server class for the Aggregator. + + This class implements a gRPC server for the Aggregator, allowing it to + communicate with collaborators. + + Attributes: + aggregator (Aggregator): The aggregator that this server is serving. + uri (str): The URI that the server is serving on. + tls (bool): Whether to use TLS for the connection. + disable_client_auth (bool): Whether to disable client-side + authentication. + root_certificate (str): The path to the root certificate for the TLS + connection. + certificate (str): The path to the server's certificate for the TLS + connection. + private_key (str): The path to the server's private key for the TLS + connection. + server (grpc.Server): The gRPC server. + server_credentials (grpc.ServerCredentials): The server's credentials. + """ def __init__(self, aggregator, @@ -36,20 +54,22 @@ def __init__(self, certificate=None, private_key=None, **kwargs): - """ - Class initializer. + """Initialize the AggregatorGRPCServer. 
Args: - aggregator: The aggregator - Args: - fltask (FLtask): The gRPC service task. - tls (bool): To disable the TLS. (Default: True) - disable_client_auth (bool): To disable the client side - authentication. (Default: False) - root_certificate (str): File path to the CA certificate. - certificate (str): File path to the server certificate. - private_key (str): File path to the private key. - kwargs (dict): Additional arguments to pass into function + aggregator (Aggregator): The aggregator that this server is + serving. + agg_port (int): The port that the server is serving on. + tls (bool): Whether to use TLS for the connection. + disable_client_auth (bool): Whether to disable client-side + authentication. + root_certificate (str): The path to the root certificate for the + TLS connection. + certificate (str): The path to the server's certificate for the + TLS connection. + private_key (str): The path to the server's private key for the + TLS connection. + **kwargs: Additional keyword arguments. """ self.aggregator = aggregator self.uri = f'[::]:{agg_port}' @@ -64,21 +84,23 @@ def __init__(self, self.logger = logging.getLogger(__name__) def validate_collaborator(self, request, context): - """ - Validate the collaborator. + """Validate the collaborator. + + This method checks that the collaborator who sent the request is + authorized to do so. Args: - request: The gRPC message request - context: The gRPC context + request (aggregator_pb2.MessageHeader): The request from the + collaborator. + context (grpc.ServicerContext): The context of the request. Raises: - ValueError: If the collaborator or collaborator certificate is not - valid then raises error. - + grpc.RpcError: If the collaborator or collaborator certificate is + not authorized. """ if self.tls: - common_name = context.auth_context()[ - 'x509_common_name'][0].decode('utf-8') + common_name = context.auth_context()['x509_common_name'][0].decode( + 'utf-8') collaborator_common_name = request.header.sender if not self.aggregator.valid_collaborator_cn_and_id( common_name, collaborator_common_name): @@ -90,53 +112,65 @@ def validate_collaborator(self, request, context): f'collaborator_common_name: |{collaborator_common_name}|') def get_header(self, collaborator_name): - """ - Compose and return MessageHeader. + """Compose and return MessageHeader. + + This method creates a MessageHeader for a message to the specified + collaborator. Args: - collaborator_name : str - The collaborator the message is intended for + collaborator_name (str): The name of the collaborator to send the + message to. + + Returns: + aggregator_pb2.MessageHeader: The header for the message. """ return aggregator_pb2.MessageHeader( sender=self.aggregator.uuid, receiver=collaborator_name, federation_uuid=self.aggregator.federation_uuid, - single_col_cert_common_name=self.aggregator.single_col_cert_common_name - ) + single_col_cert_common_name=self.aggregator. + single_col_cert_common_name) def check_request(self, request): - """ - Validate request header matches expected values. + """Validate request header matches expected values. + + This method checks that the request is valid and was sent by an + authorized collaborator. Args: - request : protobuf - Request sent from a collaborator that requires validation + request (aggregator_pb2.MessageHeader): Request sent from a + collaborator that requires validation. + + Raises: + ValueError: If the request is not valid. """ # TODO improve this check. 
the sender name could be spoofed - check_is_in(request.header.sender, self.aggregator.authorized_cols, self.logger) + check_is_in(request.header.sender, self.aggregator.authorized_cols, + self.logger) # check that the message is for me check_equal(request.header.receiver, self.aggregator.uuid, self.logger) # check that the message is for my federation - check_equal( - request.header.federation_uuid, self.aggregator.federation_uuid, self.logger) + check_equal(request.header.federation_uuid, + self.aggregator.federation_uuid, self.logger) # check that we agree on the single cert common name - check_equal( - request.header.single_col_cert_common_name, - self.aggregator.single_col_cert_common_name, - self.logger - ) + check_equal(request.header.single_col_cert_common_name, + self.aggregator.single_col_cert_common_name, self.logger) def GetTasks(self, request, context): # NOQA:N802 - """ - Request a job from aggregator. + """Request a job from aggregator. + + This method handles a request from a collaborator for a job. Args: - request: The gRPC message request - context: The gRPC context + request (aggregator_pb2.GetTasksRequest): The request from the + collaborator. + context (grpc.ServicerContext): The context of the request. + Returns: + aggregator_pb2.GetTasksResponse: The response to the request. """ self.validate_collaborator(request, context) self.check_request(request) @@ -147,18 +181,15 @@ def GetTasks(self, request, context): # NOQA:N802 if isinstance(tasks[0], str): # backward compatibility tasks_proto = [ - aggregator_pb2.Task( - name=task, - ) for task in tasks + aggregator_pb2.Task(name=task, ) for task in tasks ] else: tasks_proto = [ - aggregator_pb2.Task( - name=task.name, - function_name=task.function_name, - task_type=task.task_type, - apply_local=task.apply_local - ) for task in tasks + aggregator_pb2.Task(name=task.name, + function_name=task.function_name, + task_type=task.task_type, + apply_local=task.apply_local) + for task in tasks ] else: tasks_proto = [] @@ -168,17 +199,22 @@ def GetTasks(self, request, context): # NOQA:N802 round_number=round_number, tasks=tasks_proto, sleep_time=sleep_time, - quit=time_to_quit - ) + quit=time_to_quit) def GetAggregatedTensor(self, request, context): # NOQA:N802 - """ - Request a job from aggregator. + """Request a job from aggregator. + + This method handles a request from a collaborator for an aggregated + tensor. Args: - request: The gRPC message request - context: The gRPC context + request (aggregator_pb2.GetAggregatedTensorRequest): The request + from the collaborator. + context (grpc.ServicerContext): The context of the request. + Returns: + aggregator_pb2.GetAggregatedTensorResponse: The response to the + request. """ self.validate_collaborator(request, context) self.check_request(request) @@ -190,22 +226,28 @@ def GetAggregatedTensor(self, request, context): # NOQA:N802 tags = tuple(request.tags) named_tensor = self.aggregator.get_aggregated_tensor( - collaborator_name, tensor_name, round_number, report, tags, require_lossless) + collaborator_name, tensor_name, round_number, report, tags, + require_lossless) return aggregator_pb2.GetAggregatedTensorResponse( header=self.get_header(collaborator_name), round_number=round_number, - tensor=named_tensor - ) + tensor=named_tensor) def SendLocalTaskResults(self, request, context): # NOQA:N802 - """ - Request a model download from aggregator. + """Request a model download from aggregator. + + This method handles a request from a collaborator to send the results + of a local task. 
Args: - request: The gRPC message request - context: The gRPC context + request (aggregator_pb2.SendLocalTaskResultsRequest): The request + from the collaborator. + context (grpc.ServicerContext): The context of the request. + Returns: + aggregator_pb2.SendLocalTaskResultsResponse: The response to the + request. """ try: proto = aggregator_pb2.TaskResults() @@ -224,15 +266,22 @@ def SendLocalTaskResults(self, request, context): # NOQA:N802 round_number = proto.round_number data_size = proto.data_size named_tensors = proto.tensors - self.aggregator.send_local_task_results( - collaborator_name, round_number, task_name, data_size, named_tensors) + self.aggregator.send_local_task_results(collaborator_name, + round_number, task_name, + data_size, named_tensors) # turn data stream into local model update return aggregator_pb2.SendLocalTaskResultsResponse( - header=self.get_header(collaborator_name) - ) + header=self.get_header(collaborator_name)) def get_server(self): - """Return gRPC server.""" + """Return gRPC server. + + This method creates a gRPC server if it does not already exist and + returns it. + + Returns: + grpc.Server: The gRPC server. + """ self.server = server(ThreadPoolExecutor(max_workers=cpu_count()), options=channel_options) @@ -258,17 +307,20 @@ def get_server(self): self.logger.warn('Client-side authentication is disabled.') self.server_credentials = ssl_server_credentials( - ((private_key_b, certificate_b),), + ((private_key_b, certificate_b), ), root_certificates=root_certificate_b, - require_client_auth=not self.disable_client_auth - ) + require_client_auth=not self.disable_client_auth) self.server.add_secure_port(self.uri, self.server_credentials) return self.server def serve(self): - """Start an aggregator gRPC service.""" + """Start an aggregator gRPC service. + + This method starts the gRPC server and handles requests until all quit + jobs havebeen sent. + """ self.get_server() self.logger.info('Starting Aggregator gRPC Server') diff --git a/openfl/transport/grpc/director_client.py b/openfl/transport/grpc/director_client.py index 8f82af1341..284b92de27 100644 --- a/openfl/transport/grpc/director_client.py +++ b/openfl/transport/grpc/director_client.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Director clients module.""" import logging @@ -26,16 +25,46 @@ class ShardDirectorClient: - """The internal director client class.""" - - def __init__(self, *, director_host, director_port, shard_name, tls=True, - root_certificate=None, private_key=None, certificate=None) -> None: - """Initialize a shard director client object.""" + """The internal director client class. + + This class communicates with the director to manage the shard's + participation in the federation. + + Attributes: + shard_name (str): The name of the shard. + stub (director_pb2_grpc.DirectorStub): The gRPC stub for communication + with the director. + """ + + def __init__(self, + *, + director_host, + director_port, + shard_name, + tls=True, + root_certificate=None, + private_key=None, + certificate=None) -> None: + """Initialize a shard director client object. + + Args: + director_host (str): The host of the director. + director_port (int): The port of the director. + shard_name (str): The name of the shard. + tls (bool): Whether to use TLS for the connection. + root_certificate (str): The path to the root certificate for the + TLS connection. + private_key (str): The path to the private key for the TLS + connection. 
+ certificate (str): The path to the certificate for the TLS + connection. + """ self.shard_name = shard_name director_addr = f'{director_host}:{director_port}' logger.info(f'Director address: {director_addr}') if not tls: - channel = grpc.insecure_channel(director_addr, options=channel_options) + channel = grpc.insecure_channel(director_addr, + options=channel_options) else: if not (root_certificate and private_key and certificate): raise Exception('No certificates provided') @@ -47,39 +76,53 @@ def __init__(self, *, director_host, director_port, shard_name, tls=True, with open(certificate, 'rb') as f: certificate_b = f.read() except FileNotFoundError as exc: - raise Exception(f'Provided certificate file is not exist: {exc.filename}') + raise Exception( + f'Provided certificate file is not exist: {exc.filename}') credentials = grpc.ssl_channel_credentials( root_certificates=root_certificate_b, private_key=private_key_b, - certificate_chain=certificate_b - ) - channel = grpc.secure_channel(director_addr, credentials, options=channel_options) + certificate_chain=certificate_b) + channel = grpc.secure_channel(director_addr, + credentials, + options=channel_options) self.stub = director_pb2_grpc.DirectorStub(channel) def report_shard_info(self, shard_descriptor: Type[ShardDescriptor], cuda_devices: tuple) -> bool: - """Report shard info to the director.""" + """Report shard info to the director. + + Args: + shard_descriptor (Type[ShardDescriptor]): The descriptor of the + shard. + cuda_devices (tuple): The CUDA devices available on the shard. + + Returns: + acknowledgement (bool): Whether the report was accepted by the + director. + """ logger.info(f'Sending {self.shard_name} shard info to director') # True considered as successful registration shard_info = director_pb2.ShardInfo( shard_description=shard_descriptor.dataset_description, sample_shape=shard_descriptor.sample_shape, - target_shape=shard_descriptor.target_shape - ) + target_shape=shard_descriptor.target_shape) shard_info.node_info.name = self.shard_name shard_info.node_info.cuda_devices.extend( director_pb2.CudaDeviceInfo(index=cuda_device) - for cuda_device in cuda_devices - ) + for cuda_device in cuda_devices) request = director_pb2.UpdateShardInfoRequest(shard_info=shard_info) acknowledgement = self.stub.UpdateShardInfo(request) return acknowledgement.accepted def wait_experiment(self): - """Wait an experiment data from the director.""" + """Wait an experiment data from the director. + + Returns: + experiment_name (str): The name of the experiment. + """ logger.info('Waiting for an experiment to run...') response = self.stub.WaitExperiment(self._get_experiment_data()) logger.info(f'New experiment received: {response}') @@ -90,43 +133,71 @@ def wait_experiment(self): return experiment_name def get_experiment_data(self, experiment_name): - """Get an experiment data from the director.""" + """Get an experiment data from the director. + + Args: + experiment_name (str): The name of the experiment. + + Returns: + data_stream (grpc._channel._MultiThreadedRendezvous): The data + stream of the experiment data. 
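+
+        Example:
+            Illustrative sketch only; the experiment name and output path
+            are placeholders::
+
+                data_stream = client.get_experiment_data('my_experiment')
+                with open('my_experiment.zip', 'wb') as archive:
+                    for chunk in data_stream:
+                        archive.write(chunk.npbytes)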
+ """ logger.info(f'Getting experiment data for {experiment_name}...') request = director_pb2.GetExperimentDataRequest( - experiment_name=experiment_name, - collaborator_name=self.shard_name - ) + experiment_name=experiment_name, collaborator_name=self.shard_name) data_stream = self.stub.GetExperimentData(request) return data_stream - def set_experiment_failed( - self, - experiment_name: str, - error_code: int = 1, - error_description: str = '' - ): - """Set the experiment failed.""" + def set_experiment_failed(self, + experiment_name: str, + error_code: int = 1, + error_description: str = ''): + """Set the experiment failed. + + Args: + experiment_name (str): The name of the experiment. + error_code (int, optional): The error code. Defaults to 1. + error_description (str, optional): The description of the error. + Defaults to ''. + """ logger.info(f'Experiment {experiment_name} failed') request = director_pb2.SetExperimentFailedRequest( experiment_name=experiment_name, collaborator_name=self.shard_name, error_code=error_code, - error_description=error_description - ) + error_description=error_description) self.stub.SetExperimentFailed(request) def _get_experiment_data(self): - """Generate the experiment data request.""" - return director_pb2.WaitExperimentRequest(collaborator_name=self.shard_name) + """Generate the experiment data request. + + Returns: + director_pb2.WaitExperimentRequest: The request for experiment + data. + """ + return director_pb2.WaitExperimentRequest( + collaborator_name=self.shard_name) def send_health_check( - self, *, - envoy_name: str, - is_experiment_running: bool, - cuda_devices_info: List[dict] = None, + self, + *, + envoy_name: str, + is_experiment_running: bool, + cuda_devices_info: List[dict] = None, ) -> int: - """Send envoy health check.""" + """Send envoy health check. + + Args: + envoy_name (str): The name of the envoy. + is_experiment_running (bool): Whether an experiment is currently + running. + cuda_devices_info (List[dict], optional): Information about the + CUDA devices. Defaults to None. + + Returns: + health_check_period (int): The period for health checks. + """ status = director_pb2.UpdateEnvoyStatusRequest( name=envoy_name, is_experiment_running=is_experiment_running, @@ -159,24 +230,47 @@ def send_health_check( class DirectorClient: - """Director client class for users.""" + """Director client class for users. + + This class communicates with the director to manage the user's + participation in the federation. + + Attributes: + stub (director_pb2_grpc.DirectorStub): The gRPC stub for communication + with the director. + """ def __init__( - self, *, - client_id: str, - director_host: str, - director_port: int, - tls: bool, - root_certificate: str, - private_key: str, - certificate: str, + self, + *, + client_id: str, + director_host: str, + director_port: int, + tls: bool, + root_certificate: str, + private_key: str, + certificate: str, ) -> None: - """Initialize director client object.""" + """Initialize director client object. + + Args: + client_id (str): The ID of the client. + director_host (str): The host of the director. + director_port (int): The port of the director. + tls (bool): Whether to use TLS for the connection. + root_certificate (str): The path to the root certificate for the + TLS connection. + private_key (str): The path to the private key for the TLS + connection. + certificate (str): The path to the certificate for the TLS + connection. 
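+
+        Example:
+            Illustrative sketch only; the host and certificate paths are
+            placeholders::
+
+                director_client = DirectorClient(
+                    client_id='frontend',
+                    director_host='director.example.com',
+                    director_port=50051,
+                    tls=True,
+                    root_certificate='cert/root_ca.crt',
+                    private_key='cert/frontend.key',
+                    certificate='cert/frontend.crt')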
+ """ director_addr = f'{director_host}:{director_port}' if not tls: if not client_id: client_id = CLIENT_ID_DEFAULT - channel = grpc.insecure_channel(director_addr, options=channel_options) + channel = grpc.insecure_channel(director_addr, + options=channel_options) headers = { 'client_id': client_id, } @@ -193,23 +287,41 @@ def __init__( with open(certificate, 'rb') as f: certificate_b = f.read() except FileNotFoundError as exc: - raise Exception(f'Provided certificate file is not exist: {exc.filename}') + raise Exception( + f'Provided certificate file is not exist: {exc.filename}') credentials = grpc.ssl_channel_credentials( root_certificates=root_certificate_b, private_key=private_key_b, - certificate_chain=certificate_b - ) + certificate_chain=certificate_b) - channel = grpc.secure_channel(director_addr, credentials, options=channel_options) + channel = grpc.secure_channel(director_addr, + credentials, + options=channel_options) self.stub = director_pb2_grpc.DirectorStub(channel) - def set_new_experiment(self, name, col_names, arch_path, + def set_new_experiment(self, + name, + col_names, + arch_path, initial_tensor_dict=None): - """Send the new experiment to director to launch.""" + """Send the new experiment to director to launch. + + Args: + name (str): The name of the experiment. + col_names (List[str]): The names of the collaborators. + arch_path (str): The path to the architecture. + initial_tensor_dict (dict, optional): The initial tensor + dictionary. Defaults to None. + + Returns: + resp (director_pb2.SetNewExperimentResponse): The response from + the director. + """ logger.info(f'Submitting new experiment {name} to director') if initial_tensor_dict: - model_proto = construct_model_proto(initial_tensor_dict, 0, NoCompressionPipeline()) + model_proto = construct_model_proto(initial_tensor_dict, 0, + NoCompressionPipeline()) experiment_info_gen = self._get_experiment_info( arch_path=arch_path, name=name, @@ -220,6 +332,20 @@ def set_new_experiment(self, name, col_names, arch_path, return resp def _get_experiment_info(self, arch_path, name, col_names, model_proto): + """Generate the experiment data request. + + This method generates a stream of experiment data to be sent to the + director. + + Args: + arch_path (str): The path to the architecture. + name (str): The name of the experiment. + col_names (List[str]): The names of the collaborators. + model_proto (ModelProto): The initial model. + + Yields: + director_pb2.ExperimentInfo: The experiment data. + """ with open(arch_path, 'rb') as arch: max_buffer_size = 2 * 1024 * 1024 chunk = arch.read(max_buffer_size) @@ -230,27 +356,49 @@ def _get_experiment_info(self, arch_path, name, col_names, model_proto): experiment_info = director_pb2.ExperimentInfo( name=name, collaborator_names=col_names, - model_proto=model_proto - ) + model_proto=model_proto) experiment_info.experiment_data.size = len(chunk) experiment_info.experiment_data.npbytes = chunk yield experiment_info chunk = arch.read(max_buffer_size) def get_experiment_status(self, experiment_name): - """Check if the experiment was accepted by the director""" + """Check if the experiment was accepted by the director. + + Args: + experiment_name (str): The name of the experiment. + + Returns: + resp (director_pb2.GetExperimentStatusResponse): The response from + the director. 
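+
+        Example:
+            Illustrative sketch only; the experiment name is a placeholder::
+
+                resp = director_client.get_experiment_status('my_experiment')
+                print(resp.experiment_status)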
+ """ logger.info('Getting experiment Status...') - request = director_pb2.GetExperimentStatusRequest(experiment_name=experiment_name) + request = director_pb2.GetExperimentStatusRequest( + experiment_name=experiment_name) resp = self.stub.GetExperimentStatus(request) return resp def get_dataset_info(self): - """Request the dataset info from the director.""" + """Request the dataset info from the director. + + Returns: + Tuple[List[int], List[int]]: The sample shape and target shape of + the dataset. + """ resp = self.stub.GetDatasetInfo(director_pb2.GetDatasetInfoRequest()) return resp.shard_info.sample_shape, resp.shard_info.target_shape def _get_trained_model(self, experiment_name, model_type): - """Get trained model RPC.""" + """Get trained model RPC. + + Args: + experiment_name (str): The name of the experiment. + model_type (director_pb2.GetTrainedModelRequest.ModelType): The + type of the model. + + Returns: + tensor_dict (Dict[str, numpy.ndarray]): The trained model. + """ get_model_request = director_pb2.GetTrainedModelRequest( experiment_name=experiment_name, model_type=model_type, @@ -263,18 +411,40 @@ def _get_trained_model(self, experiment_name, model_type): return tensor_dict def get_best_model(self, experiment_name): - """Get best model method.""" + """Get best model method. + + Args: + experiment_name (str): The name of the experiment. + + Returns: + Dict[str, numpy.ndarray]: The best model. + """ model_type = director_pb2.GetTrainedModelRequest.BEST_MODEL return self._get_trained_model(experiment_name, model_type) def get_last_model(self, experiment_name): - """Get last model method.""" + """Get last model method. + + Args: + experiment_name (str): The name of the experiment. + + Returns: + Dict[str, numpy.ndarray]: The last model. + """ model_type = director_pb2.GetTrainedModelRequest.LAST_MODEL return self._get_trained_model(experiment_name, model_type) def stream_metrics(self, experiment_name): - """Stream metrics RPC.""" - request = director_pb2.GetMetricStreamRequest(experiment_name=experiment_name) + """Stream metrics RPC. + + Args: + experiment_name (str): The name of the experiment. + + Yields: + Dict[str, Any]: The metrics. + """ + request = director_pb2.GetMetricStreamRequest( + experiment_name=experiment_name) for metric_message in self.stub.GetMetricStream(request): yield { 'metric_origin': metric_message.metric_origin, @@ -285,13 +455,29 @@ def stream_metrics(self, experiment_name): } def remove_experiment_data(self, name): - """Remove experiment data RPC.""" + """Remove experiment data RPC. + + Args: + name (str): The name of the experiment. + + Returns: + bool: Whether the removal was acknowledged. + """ request = director_pb2.RemoveExperimentRequest(experiment_name=name) response = self.stub.RemoveExperimentData(request) return response.acknowledgement def get_envoys(self, raw_result=False): - """Get envoys info.""" + """Get envoys info. + + Args: + raw_result (bool, optional): Whether to return the raw result. + Defaults to False. + + Returns: + result (Union[director_pb2.GetEnvoysResponse, + Dict[str, Dict[str, Any]]]): The envoys info. 
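+
+        Example:
+            Illustrative sketch only; ``director_client`` stands for an
+            instance of this class::
+
+                envoys = director_client.get_envoys()
+                for name, info in envoys.items():
+                    print(name, info['is_online'])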
+ """ envoys = self.stub.GetEnvoys(director_pb2.GetEnvoysRequest()) if raw_result: return envoys @@ -299,27 +485,44 @@ def get_envoys(self, raw_result=False): result = {} for envoy in envoys.envoy_infos: result[envoy.shard_info.node_info.name] = { - 'shard_info': envoy.shard_info, - 'is_online': envoy.is_online or False, - 'is_experiment_running': envoy.is_experiment_running or False, - 'last_updated': datetime.fromtimestamp( + 'shard_info': + envoy.shard_info, + 'is_online': + envoy.is_online or False, + 'is_experiment_running': + envoy.is_experiment_running or False, + 'last_updated': + datetime.fromtimestamp( envoy.last_updated.seconds).strftime('%Y-%m-%d %H:%M:%S'), - 'current_time': now, - 'valid_duration': envoy.valid_duration, - 'experiment_name': 'ExperimentName Mock', + 'current_time': + now, + 'valid_duration': + envoy.valid_duration, + 'experiment_name': + 'ExperimentName Mock', } return result def get_experiments_list(self): - """Get experiments list.""" + """Get experiments list. + + Returns: + List[str]: The list of experiments. + """ response = self.stub.GetExperimentsList( - director_pb2.GetExperimentsListRequest() - ) + director_pb2.GetExperimentsListRequest()) return response.experiments def get_experiment_description(self, name): - """Get experiment info.""" + """Get experiment info. + + Args: + name (str): The name of the experiment. + + Returns: + director_pb2.ExperimentDescription: The description of the + experiment. + """ response = self.stub.GetExperimentDescription( - director_pb2.GetExperimentDescriptionRequest(name=name) - ) + director_pb2.GetExperimentDescriptionRequest(name=name)) return response.experiment diff --git a/openfl/transport/grpc/director_server.py b/openfl/transport/grpc/director_server.py index 7dd1a3a6c9..679f8110cc 100644 --- a/openfl/transport/grpc/director_server.py +++ b/openfl/transport/grpc/director_server.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Director server.""" import asyncio @@ -34,22 +33,61 @@ class DirectorGRPCServer(director_pb2_grpc.DirectorServicer): - """Director transport class.""" - - def __init__( - self, *, - director_cls, - tls: bool = True, - root_certificate: Optional[Union[Path, str]] = None, - private_key: Optional[Union[Path, str]] = None, - certificate: Optional[Union[Path, str]] = None, - review_plan_callback: Union[None, Callable] = None, - listen_host: str = '[::]', - listen_port: int = 50051, - envoy_health_check_period: int = 0, - **kwargs - ) -> None: - """Initialize a director object.""" + """Director transport class. + + This class implements a gRPC server for the Director, allowing it to + communicate with collaborators. + + Attributes: + director (Director): The director that this server is serving. + listen_uri (str): The URI that the server is serving on. + tls (bool): Whether to use TLS for the connection. + root_certificate (Path): The path to the root certificate for the TLS + connection. + private_key (Path): The path to the server's private key for the TLS + connection. + certificate (Path): The path to the server's certificate for the TLS + connection. + server (grpc.Server): The gRPC server. 
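+
+    Example:
+        Illustrative sketch only; ``Director`` refers to the director class
+        supplied by the caller, and the certificate paths are placeholders::
+
+            director_server = DirectorGRPCServer(
+                director_cls=Director,
+                tls=True,
+                root_certificate='cert/root_ca.crt',
+                private_key='cert/director.key',
+                certificate='cert/director.crt',
+                listen_port=50051)
+            director_server.start()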
+ """ + + def __init__(self, + *, + director_cls, + tls: bool = True, + root_certificate: Optional[Union[Path, str]] = None, + private_key: Optional[Union[Path, str]] = None, + certificate: Optional[Union[Path, str]] = None, + review_plan_callback: Union[None, Callable] = None, + listen_host: str = '[::]', + listen_port: int = 50051, + envoy_health_check_period: int = 0, + **kwargs) -> None: + """Initialize a director object. + + Args: + director_cls (Type[Director]): The class of the director. + tls (bool, optional): Whether to use TLS for the connection. + Defaults to True. + root_certificate (Optional[Union[Path, str]], optional): The path + to the root certificate for the TLS connection. Defaults to + None. + private_key (Optional[Union[Path, str]], optional): The path to + the server's private key for the TLS connection. Defaults to + None. + certificate (Optional[Union[Path, str]], optional): The path to + the server's certificate for the TLS connection. Defaults to + None. + review_plan_callback (Union[None, Callable], optional): The + callback for reviewing the plan. Defaults to None. + listen_host (str, optional): The host to listen on. Defaults to + '[::]'. + listen_port (int, optional): The port to listen on. Defaults to + 50051. + envoy_health_check_period (int, optional): The period for health + checks. Defaults to 0. + **kwargs: Additional keyword arguments. + """ # TODO: add working directory super().__init__() @@ -68,11 +106,19 @@ def __init__( certificate=self.certificate, review_plan_callback=review_plan_callback, envoy_health_check_period=envoy_health_check_period, - **kwargs - ) + **kwargs) def _fill_certs(self, root_certificate, private_key, certificate): - """Fill certificates.""" + """Fill certificates. + + Args: + root_certificate (Union[Path, str]): The path to the root + certificate for the TLS connection. + private_key (Union[Path, str]): The path to the server's private + key for the TLS connection. + certificate (Union[Path, str]): The path to the server's + certificate for the TLS connection. + """ if self.tls: if not (root_certificate and private_key and certificate): raise Exception('No certificates provided') @@ -81,25 +127,33 @@ def _fill_certs(self, root_certificate, private_key, certificate): self.certificate = Path(certificate).absolute() def get_caller(self, context): - """ - Get caller name from context. + """Get caller name from context. + + if tls == True: get caller name from auth_context + if tls == False: get caller name from context header 'client_id' - if tls == True: get caller name from auth_context - if tls == False: get caller name from context header 'client_id' + Args: + context (grpc.ServicerContext): The context of the request. + + Returns: + str: The name of the caller. 
""" if self.tls: - return context.auth_context()['x509_common_name'][0].decode('utf-8') + return context.auth_context()['x509_common_name'][0].decode( + 'utf-8') headers = get_headers(context) client_id = headers.get('client_id', CLIENT_ID_DEFAULT) return client_id def start(self): """Launch the director GRPC server.""" - loop = asyncio.get_event_loop() # TODO: refactor after end of support for python3.6 + loop = asyncio.get_event_loop( + ) # TODO: refactor after end of support for python3.6 loop.create_task(self.director.start_experiment_execution_loop()) loop.run_until_complete(self._run_server()) async def _run_server(self): + """Run the gRPC server.""" self.server = aio.server(options=channel_options) director_pb2_grpc.add_DirectorServicer_to_server(self, self.server) @@ -113,41 +167,59 @@ async def _run_server(self): with open(self.root_certificate, 'rb') as f: root_certificate_b = f.read() server_credentials = ssl_server_credentials( - ((private_key_b, certificate_b),), + ((private_key_b, certificate_b), ), root_certificates=root_certificate_b, - require_client_auth=True - ) + require_client_auth=True) self.server.add_secure_port(self.listen_uri, server_credentials) logger.info(f'Starting director server on {self.listen_uri}') await self.server.start() await self.server.wait_for_termination() async def UpdateShardInfo(self, request, context): # NOQA:N802 - """Receive acknowledge shard info.""" + """Receive acknowledge shard info. + + Args: + request (director_pb2.UpdateShardInfoRequest): The request from + the shard. + context (grpc.ServicerContext): The context of the request. + + Returns: + reply (director_pb2.UpdateShardInfoResponse): The response to the + request. + """ logger.info(f'Updating shard info: {request.shard_info}') - dict_shard_info = MessageToDict( - request.shard_info, - preserving_proto_field_name=True - ) + dict_shard_info = MessageToDict(request.shard_info, + preserving_proto_field_name=True) is_accepted = self.director.acknowledge_shard(dict_shard_info) reply = director_pb2.UpdateShardInfoResponse(accepted=is_accepted) return reply async def SetNewExperiment(self, stream, context): # NOQA:N802 - """Request to set new experiment.""" + """Request to set new experiment. + + Args: + stream (grpc.aio._MultiThreadedRendezvous): The stream of + experiment data. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.SetNewExperimentResponse: The response to the request. 
+ """ # TODO: add streaming reader data_file_path = self.root_dir / str(uuid.uuid4()) with open(data_file_path, 'wb') as data_file: async for request in stream: - if request.experiment_data.size == len(request.experiment_data.npbytes): + if request.experiment_data.size == len( + request.experiment_data.npbytes): data_file.write(request.experiment_data.npbytes) else: raise Exception('Could not register new experiment') tensor_dict = None if request.model_proto: - tensor_dict, _ = deconstruct_model_proto(request.model_proto, NoCompressionPipeline()) + tensor_dict, _ = deconstruct_model_proto(request.model_proto, + NoCompressionPipeline()) caller = self.get_caller(context) @@ -156,25 +228,42 @@ async def SetNewExperiment(self, stream, context): # NOQA:N802 sender_name=caller, tensor_dict=tensor_dict, collaborator_names=request.collaborator_names, - experiment_archive_path=data_file_path - ) + experiment_archive_path=data_file_path) logger.info(f'Experiment {request.name} registered') return director_pb2.SetNewExperimentResponse(accepted=is_accepted) async def GetExperimentStatus(self, request, context): # NOQA: N802 - """Get experiment status and update if experiment was approved.""" + """Get experiment status and update if experiment was approved. + + Args: + request (director_pb2.GetExperimentStatusRequest): The request + from the collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.GetExperimentStatusResponse: The response to the + request. + """ logger.debug('GetExperimentStatus request received') caller = self.get_caller(context) experiment_status = await self.director.get_experiment_status( - experiment_name=request.experiment_name, - caller=caller - ) + experiment_name=request.experiment_name, caller=caller) logger.debug('Sending GetExperimentStatus response') - return director_pb2.GetExperimentStatusResponse(experiment_status=experiment_status) + return director_pb2.GetExperimentStatusResponse( + experiment_status=experiment_status) async def GetTrainedModel(self, request, context): # NOQA:N802 - """RPC for retrieving trained models.""" + """RPC for retrieving trained models. + + Args: + request (director_pb2.GetTrainedModelRequest): The request from + the collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.TrainedModelResponse: The response to the request. + """ logger.debug('Received request for trained model...') if request.model_type == director_pb2.GetTrainedModelRequest.BEST_MODEL: @@ -190,24 +279,34 @@ async def GetTrainedModel(self, request, context): # NOQA:N802 trained_model_dict = self.director.get_trained_model( experiment_name=request.experiment_name, caller=caller, - model_type=model_type - ) + model_type=model_type) if trained_model_dict is None: return director_pb2.TrainedModelResponse() - model_proto = construct_model_proto(trained_model_dict, 0, NoCompressionPipeline()) + model_proto = construct_model_proto(trained_model_dict, 0, + NoCompressionPipeline()) logger.debug('Sending trained model') return director_pb2.TrainedModelResponse(model_proto=model_proto) async def GetExperimentData(self, request, context): # NOQA:N802 - """Receive experiment data.""" + """Receive experiment data. + + Args: + request (director_pb2.GetExperimentDataRequest): The request from + the collaborator. + context (grpc.ServicerContext): The context of the request. + + Yields: + director_pb2.ExperimentData: The experiment data. 
+ """ # TODO: add size filling # TODO: add experiment name field # TODO: rename npbytes to data - data_file_path = self.director.get_experiment_data(request.experiment_name) + data_file_path = self.director.get_experiment_data( + request.experiment_name) max_buffer_size = (2 * 1024 * 1024) with open(data_file_path, 'rb') as df: while True: @@ -217,42 +316,82 @@ async def GetExperimentData(self, request, context): # NOQA:N802 yield director_pb2.ExperimentData(size=len(data), npbytes=data) async def WaitExperiment(self, request, context): # NOQA:N802 - """Request for wait an experiment.""" - logger.debug(f'Request WaitExperiment received from envoy {request.collaborator_name}') - experiment_name = await self.director.wait_experiment(request.collaborator_name) - logger.debug(f'Experiment {experiment_name} is ready for {request.collaborator_name}') + """Request for wait an experiment. + + Args: + request (director_pb2.WaitExperimentRequest): The request from the + collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.WaitExperimentResponse: The response to the request. + """ + logger.debug( + f'Request WaitExperiment received from envoy {request.collaborator_name}' + ) + experiment_name = await self.director.wait_experiment( + request.collaborator_name) + logger.debug( + f'Experiment {experiment_name} is ready for {request.collaborator_name}' + ) - return director_pb2.WaitExperimentResponse(experiment_name=experiment_name) + return director_pb2.WaitExperimentResponse( + experiment_name=experiment_name) async def GetDatasetInfo(self, request, context): # NOQA:N802 - """Request the info about target and sample shapes in the dataset.""" + """Request the info about target and sample shapes in the dataset. + + Args: + request (director_pb2.GetDatasetInfoRequest): The request from the + collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.GetDatasetInfoResponse: The response to the request. + """ logger.debug('Received request for dataset info...') sample_shape, target_shape = self.director.get_dataset_info() - shard_info = director_pb2.ShardInfo( - sample_shape=sample_shape, - target_shape=target_shape - ) + shard_info = director_pb2.ShardInfo(sample_shape=sample_shape, + target_shape=target_shape) resp = director_pb2.GetDatasetInfoResponse(shard_info=shard_info) logger.debug('Sending dataset info') return resp async def GetMetricStream(self, request, context): # NOQA:N802 - """Request to stream metrics from the aggregator to frontend.""" + """Request to stream metrics from the aggregator to frontend. + + Args: + request (director_pb2.GetMetricStreamRequest): The request from + the collaborator. + context (grpc.ServicerContext): The context of the request. + + Yields: + director_pb2.GetMetricStreamResponse: The metrics. + """ logger.info(f'Getting metrics for {request.experiment_name}...') caller = self.get_caller(context) async for metric_dict in self.director.stream_metrics( - experiment_name=request.experiment_name, caller=caller - ): + experiment_name=request.experiment_name, caller=caller): if metric_dict is None: await asyncio.sleep(1) continue yield director_pb2.GetMetricStreamResponse(**metric_dict) async def RemoveExperimentData(self, request, context): # NOQA:N802 - """Remove experiment data RPC.""" + """Remove experiment data RPC. + + Args: + request (director_pb2.RemoveExperimentRequest): The request from + the collaborator. + context (grpc.ServicerContext): The context of the request. 
+ + Returns: + response (director_pb2.RemoveExperimentResponse): The response to + the request. + """ response = director_pb2.RemoveExperimentResponse(acknowledgement=False) caller = self.get_caller(context) self.director.remove_experiment_data( @@ -264,22 +403,42 @@ async def RemoveExperimentData(self, request, context): # NOQA:N802 return response async def SetExperimentFailed(self, request, context): # NOQA:N802 - """Set the experiment failed.""" + """Set the experiment failed. + + Args: + request (director_pb2.SetExperimentFailedRequest): The request + from the collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + response (director_pb2.SetExperimentFailedResponse): The response + to the request. + """ response = director_pb2.SetExperimentFailedResponse() if self.get_caller(context) != CLIENT_ID_DEFAULT: return response - logger.error(f'Collaborator {request.collaborator_name} failed with error code:' - f' {request.error_code}, error_description: {request.error_description}' - f'Stopping experiment.') + logger.error( + f'Collaborator {request.collaborator_name} failed with error code:' + f' {request.error_code}, error_description: {request.error_description}' + f'Stopping experiment.') self.director.set_experiment_failed( experiment_name=request.experiment_name, - collaborator_name=request.collaborator_name - ) + collaborator_name=request.collaborator_name) return response async def UpdateEnvoyStatus(self, request, context): # NOQA:N802 - """Accept health check from envoy.""" + """Accept health check from envoy. + + Args: + request (director_pb2.UpdateEnvoyStatusRequest): The request from + the envoy. + context (grpc.ServicerContext): The context of the request. + + Returns: + resp (director_pb2.UpdateEnvoyStatusResponse): The response to the + request. + """ logger.debug(f'Updating envoy status: {request}') cuda_devices_info = [ MessageToDict(message, preserving_proto_field_name=True) @@ -289,8 +448,7 @@ async def UpdateEnvoyStatus(self, request, context): # NOQA:N802 health_check_period = self.director.update_envoy_status( envoy_name=request.name, is_experiment_running=request.is_experiment_running, - cuda_devices_status=cuda_devices_info - ) + cuda_devices_status=cuda_devices_info) except ShardNotFoundError as exc: logger.error(exc) await context.abort(grpc.StatusCode.NOT_FOUND, str(exc)) @@ -301,51 +459,75 @@ async def UpdateEnvoyStatus(self, request, context): # NOQA:N802 return resp async def GetEnvoys(self, request, context): # NOQA:N802 - """Get a status information about envoys.""" + """Get a status information about envoys. + + Args: + request (director_pb2.GetEnvoysRequest): The request from the + collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.GetEnvoysResponse: The response to the request. 
+ """ envoy_infos = self.director.get_envoys() envoy_statuses = [] for envoy_info in envoy_infos: envoy_info_message = director_pb2.EnvoyInfo( - shard_info=ParseDict( - envoy_info['shard_info'], director_pb2.ShardInfo(), - ignore_unknown_fields=True), + shard_info=ParseDict(envoy_info['shard_info'], + director_pb2.ShardInfo(), + ignore_unknown_fields=True), is_online=envoy_info['is_online'], is_experiment_running=envoy_info['is_experiment_running']) - envoy_info_message.valid_duration.seconds = envoy_info['valid_duration'] - envoy_info_message.last_updated.seconds = int(envoy_info['last_updated']) + envoy_info_message.valid_duration.seconds = envoy_info[ + 'valid_duration'] + envoy_info_message.last_updated.seconds = int( + envoy_info['last_updated']) envoy_statuses.append(envoy_info_message) return director_pb2.GetEnvoysResponse(envoy_infos=envoy_statuses) async def GetExperimentsList(self, request, context): # NOQA:N802 - """Get list of experiments description.""" + """Get list of experiments description. + + Args: + request (director_pb2.GetExperimentsListRequest): The request from + the collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.GetExperimentsListResponse: The response to the + request. + """ caller = self.get_caller(context) experiments = self.director.get_experiments_list(caller) experiment_list = [ - director_pb2.ExperimentListItem(**exp) - for exp in experiments + director_pb2.ExperimentListItem(**exp) for exp in experiments ] return director_pb2.GetExperimentsListResponse( - experiments=experiment_list - ) + experiments=experiment_list) async def GetExperimentDescription(self, request, context): # NOQA:N802 - """Get an experiment description.""" + """Get an experiment description. + + Args: + request (director_pb2.GetExperimentDescriptionRequest): The + request from the collaborator. + context (grpc.ServicerContext): The context of the request. + + Returns: + director_pb2.GetExperimentDescriptionResponse: The response to the + request. 
+ """ caller = self.get_caller(context) - experiment = self.director.get_experiment_description(caller, request.name) + experiment = self.director.get_experiment_description( + caller, request.name) models_statuses = [ - base_pb2.DownloadStatus( - name=ms['name'], - status=ms['status'] - ) + base_pb2.DownloadStatus(name=ms['name'], status=ms['status']) for ms in experiment['download_statuses']['models'] ] logs_statuses = [ - base_pb2.DownloadStatus( - name=ls['name'], - status=ls['status'] - ) + base_pb2.DownloadStatus(name=ls['name'], status=ls['status']) for ls in experiment['download_statuses']['logs'] ] download_statuses = base_pb2.DownloadStatuses( @@ -353,21 +535,17 @@ async def GetExperimentDescription(self, request, context): # NOQA:N802 logs=logs_statuses, ) collaborators = [ - base_pb2.CollaboratorDescription( - name=col['name'], - status=col['status'], - progress=col['progress'], - round=col['round'], - current_task=col['current_task'], - next_task=col['next_task'] - ) + base_pb2.CollaboratorDescription(name=col['name'], + status=col['status'], + progress=col['progress'], + round=col['round'], + current_task=col['current_task'], + next_task=col['next_task']) for col in experiment['collaborators'] ] tasks = [ - base_pb2.TaskDescription( - name=task['name'], - description=task['description'] - ) + base_pb2.TaskDescription(name=task['name'], + description=task['description']) for task in experiment['tasks'] ] @@ -381,5 +559,4 @@ async def GetExperimentDescription(self, request, context): # NOQA:N802 download_statuses=download_statuses, collaborators=collaborators, tasks=tasks, - ), - ) + ), ) diff --git a/openfl/transport/grpc/exceptions.py b/openfl/transport/grpc/exceptions.py index 5bd19315c0..3af78b9a23 100644 --- a/openfl/transport/grpc/exceptions.py +++ b/openfl/transport/grpc/exceptions.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Exceptions that occur during service interaction.""" diff --git a/openfl/transport/grpc/grpc_channel_options.py b/openfl/transport/grpc/grpc_channel_options.py index 229dd45e51..5c4d6f01fa 100644 --- a/openfl/transport/grpc/grpc_channel_options.py +++ b/openfl/transport/grpc/grpc_channel_options.py @@ -1,11 +1,9 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -max_metadata_size = 32 * 2 ** 20 -max_message_length = 2 ** 30 +max_metadata_size = 32 * 2**20 +max_message_length = 2**30 -channel_options = [ - ('grpc.max_metadata_size', max_metadata_size), - ('grpc.max_send_message_length', max_message_length), - ('grpc.max_receive_message_length', max_message_length) -] +channel_options = [('grpc.max_metadata_size', max_metadata_size), + ('grpc.max_send_message_length', max_message_length), + ('grpc.max_receive_message_length', max_message_length)] diff --git a/openfl/utilities/__init__.py b/openfl/utilities/__init__.py index 9cfc001eaa..af38a74ee7 100644 --- a/openfl/utilities/__init__.py +++ b/openfl/utilities/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.utilities package.""" from .types import * # NOQA diff --git a/openfl/utilities/ca.py b/openfl/utilities/ca.py index a35210ffef..1277adbdda 100644 --- a/openfl/utilities/ca.py +++ b/openfl/utilities/ca.py @@ -5,7 +5,21 @@ def get_credentials(folder_path): - """Get credentials from folder by template.""" + """Get credentials from folder by template. 
+ + This function retrieves the root certificate, key, and certificate from + the specified folder. + The files are identified by their extensions: '.key' for the key, '.crt' + for the certificate, and 'root_ca' for the root certificate. + + Args: + folder_path (str): The path to the folder containing the credentials. + + Returns: + Tuple[Optional[str], Optional[str], Optional[str]]: The paths to the + root certificate, key, and certificate. + If a file is not found, its corresponding value is None. + """ root_ca, key, cert = None, None, None if os.path.exists(folder_path): for f in os.listdir(folder_path): diff --git a/openfl/utilities/ca/__init__.py b/openfl/utilities/ca/__init__.py index 3277f66c42..a7dbfd24f2 100644 --- a/openfl/utilities/ca/__init__.py +++ b/openfl/utilities/ca/__init__.py @@ -1,4 +1,3 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """CA package.""" diff --git a/openfl/utilities/ca/ca.py b/openfl/utilities/ca/ca.py index 1ff88c5742..76ab58f28a 100644 --- a/openfl/utilities/ca/ca.py +++ b/openfl/utilities/ca/ca.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """CA module.""" import base64 @@ -30,21 +29,24 @@ def get_token(name, ca_url, ca_path='.'): - """ - Create authentication token. + """Generate an authentication token. Args: - name: common name for following certificate - (aggregator fqdn or collaborator name) - ca_url: full url of CA server - ca_path: path to ca folder + name (str): Common name for the following certificate + (aggregator fqdn or collaborator name). + ca_url (str): Full URL of the CA server. + ca_path (str, optional): Path to the CA folder. Defaults to '.'. + + Returns: + str: The generated authentication token. """ ca_path = Path(ca_path) step_config_dir = ca_path / CA_STEP_CONFIG_DIR pki_dir = ca_path / CA_PKI_DIR step_path, _ = get_ca_bin_paths(ca_path) if not step_path: - raise Exception('Step-CA is not installed!\nRun `fx pki install` first') + raise Exception( + 'Step-CA is not installed!\nRun `fx pki install` first') priv_json = step_config_dir / 'secrets' / 'priv.json' pass_file = pki_dir / CA_PASSWORD_FILE @@ -53,7 +55,9 @@ def get_token(name, ca_url, ca_path='.'): token = subprocess.check_output( f'{step_path} ca token {name} ' f'--key {priv_json} --root {root_crt} ' - f'--password-file {pass_file} 'f'--ca-url {ca_url}', shell=True) + f'--password-file {pass_file} ' + f'--ca-url {ca_url}', + shell=True) except subprocess.CalledProcessError as exc: logger.error(f'Error code {exc.returncode}: {exc.output}') sys.exit(1) @@ -72,7 +76,14 @@ def get_token(name, ca_url, ca_path='.'): def get_ca_bin_paths(ca_path): - """Get paths of step binaries.""" + """Get the paths of the step binaries. + + Args: + ca_path (str): Path to the CA directory. + + Returns: + tuple: Paths to the step and step-ca binaries. + """ ca_path = Path(ca_path) step = None step_ca = None @@ -93,7 +104,14 @@ def get_ca_bin_paths(ca_path): def certify(name, cert_path: Path, token_with_cert, ca_path: Path): - """Create an envoy workspace.""" + """Create a certificate for a given name. + + Args: + name (str): Name for the certificate. + cert_path (Path): Path to store the certificate. + token_with_cert (str): Authentication token with certificate. + ca_path (Path): Path to the CA directory. 
+ """ os.makedirs(cert_path, exist_ok=True) token, root_certificate = token_with_cert.split(TOKEN_DELIMITER) @@ -105,29 +123,34 @@ def certify(name, cert_path: Path, token_with_cert, ca_path: Path): download_step_bin(prefix=ca_path) step_path, _ = get_ca_bin_paths(ca_path) if not step_path: - raise Exception('Step-CA is not installed!\nRun `fx pki install` first') + raise Exception( + 'Step-CA is not installed!\nRun `fx pki install` first') with open(f'{cert_path}/root_ca.crt', mode='wb') as file: file.write(root_certificate) - check_call(f'{step_path} ca certificate {name} {cert_path}/{name}.crt ' - f'{cert_path}/{name}.key --kty EC --curve P-384 -f --token {token}', shell=True) + check_call( + f'{step_path} ca certificate {name} {cert_path}/{name}.crt ' + f'{cert_path}/{name}.key --kty EC --curve P-384 -f --token {token}', + shell=True) def remove_ca(ca_path): - """Kill step-ca process and rm ca directory.""" + """Remove the CA directory and kill the step-ca process. + + Args: + ca_path (str): Path to the CA directory. + """ _check_kill_process('step-ca') shutil.rmtree(ca_path, ignore_errors=True) def install(ca_path, ca_url, password): - """ - Create certificate authority for federation. + """Create a certificate authority for the federation. Args: - ca_path: path to ca directory - ca_url: url for ca server like: 'host:port' - password: Simple password for encrypting root private keys - + ca_path (str): Path to the CA directory. + ca_url (str): URL for the CA server. Like: 'host:port' + password (str): Password for encrypting root private keys. """ logger.info('Creating CA') @@ -137,28 +160,46 @@ def install(ca_path, ca_url, password): os.environ['STEPPATH'] = str(step_config_dir) step_path, step_ca_path = get_ca_bin_paths(ca_path) - if not (step_path and step_ca_path and step_path.exists() and step_ca_path.exists()): + if not (step_path and step_ca_path and step_path.exists() + and step_ca_path.exists()): download_step_bin(prefix=ca_path, confirmation=True) download_step_ca_bin(prefix=ca_path, confirmation=False) step_config_dir = ca_path / CA_STEP_CONFIG_DIR - if (not step_config_dir.exists() - or confirm('CA exists, do you want to recreate it?', default=True)): + if (not step_config_dir.exists() or confirm( + 'CA exists, do you want to recreate it?', default=True)): _create_ca(ca_path, ca_url, password) _configure(step_config_dir) def run_ca(step_ca, pass_file, ca_json): - """Run CA server.""" + """Run the CA server. + + Args: + step_ca (str): Path to the step-ca binary. + pass_file (str): Path to the password file. + ca_json (str): Path to the CA configuration JSON file. + """ if _check_kill_process('step-ca', confirmation=True): logger.info('Up CA server') - check_call(f'{step_ca} --password-file {pass_file} {ca_json}', shell=True) + check_call(f'{step_ca} --password-file {pass_file} {ca_json}', + shell=True) def _check_kill_process(pstring, confirmation=False): - """Kill process by name.""" + """Kill a process by its name. + + Args: + pstring (str): Name of the process. + confirmation (bool, optional): If True, ask for confirmation before + killing the process. Defaults to False. + + Returns: + bool: True if the process was killed, False otherwise. 
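# Illustrative sketch, not part of the patch: a possible end-to-end use of the
# CA helpers documented above (install, get_token, certify). The import path,
# the './ca' and './cert' directories, the URL and the password are assumptions
# made for this example; the real workflow also requires the step binaries.
from pathlib import Path

from openfl.utilities.ca.ca import certify, get_token, install

ca_path = './ca'
ca_url = 'localhost:9123'

install(ca_path, ca_url, password='example-password')           # set up the CA workspace
token = get_token('aggregator.example.com', ca_url, ca_path)    # token for one node
certify('aggregator.example.com', Path('./cert'), token, Path(ca_path))  # writes .crt/.key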
+ """ pids = [] proc = subprocess.Popen(f'ps ax | grep {pstring} | grep -v grep', - shell=True, stdout=subprocess.PIPE) + shell=True, + stdout=subprocess.PIPE) text = proc.communicate()[0].decode('utf-8') for line in text.splitlines(): @@ -166,7 +207,8 @@ def _check_kill_process(pstring, confirmation=False): pids.append(fields[0]) if len(pids): - if confirmation and not confirm('CA server is already running. Stop him?', default=True): + if confirmation and not confirm( + 'CA server is already running. Stop him?', default=True): return False for pid in pids: os.kill(int(pid), signal.SIGKILL) @@ -175,7 +217,13 @@ def _check_kill_process(pstring, confirmation=False): def _create_ca(ca_path: Path, ca_url: str, password: str): - """Create a ca workspace.""" + """Create a certificate authority workspace. + + Args: + ca_path (Path): Path to the CA directory. + ca_url (str): URL for the CA server. + password (str): Password for encrypting root private keys. + """ import os pki_dir = ca_path / CA_PKI_DIR step_config_dir = ca_path / CA_STEP_CONFIG_DIR @@ -187,7 +235,8 @@ def _create_ca(ca_path: Path, ca_url: str, password: str): f.write(password) os.chmod(f'{pki_dir}/pass_file', 0o600) step_path, step_ca_path = get_ca_bin_paths(ca_path) - if not (step_path and step_ca_path and step_path.exists() and step_ca_path.exists()): + if not (step_path and step_ca_path and step_path.exists() + and step_ca_path.exists()): logger.error('Could not find step-ca binaries in the path specified') sys.exit(1) @@ -199,22 +248,24 @@ def _create_ca(ca_path: Path, ca_url: str, password: str): f'{step_path} ca init --name name --dns {name} ' f'--address {ca_url} --provisioner prov ' f'--password-file {pki_dir}/pass_file', - shell=True - ) + shell=True) check_call(f'{step_path} ca provisioner remove prov --all', shell=True) check_call( f'{step_path} crypto jwk create {step_config_dir}/certs/pub.json ' f'{step_config_dir}/secrets/priv.json --password-file={pki_dir}/pass_file', - shell=True - ) + shell=True) check_call( f'{step_path} ca provisioner add provisioner {step_config_dir}/certs/pub.json', - shell=True - ) + shell=True) def _configure(step_config_dir): + """Configure the certificate authority. + + Args: + step_config_dir (str): Path to the step configuration directory. + """ conf_file = step_config_dir / CA_CONFIG_JSON with open(conf_file, 'r+', encoding='utf-8') as f: data = json.load(f) diff --git a/openfl/utilities/ca/downloader.py b/openfl/utilities/ca/downloader.py index 9331d1fd4a..bf6c1eda2f 100644 --- a/openfl/utilities/ca/downloader.py +++ b/openfl/utilities/ca/downloader.py @@ -16,14 +16,15 @@ 'aarch64': 'arm64' } -FILE_EXTENSIONS = { - 'windows': 'zip', - 'linux': 'tar.gz' -} +FILE_EXTENSIONS = {'windows': 'zip', 'linux': 'tar.gz'} def get_system_and_architecture(): - """Get system and architecture of machine.""" + """Get the system and architecture of the machine. + + Returns: + tuple: The system and architecture of the machine. + """ uname_res = platform.uname() system = uname_res.system.lower() @@ -34,12 +35,12 @@ def get_system_and_architecture(): def download_step_bin(prefix='.', confirmation=True): - """ - Download step binaries. + """Download step binaries. Args: - prefix: folder path to download - confirmation: request user confirmation or not + prefix (str, optional): Folder path to download. Defaults to '.'. + confirmation (bool, optional): Request user confirmation or not. + Defaults to True. 
""" system, arch = get_system_and_architecture() ext = FILE_EXTENSIONS[system] @@ -49,12 +50,12 @@ def download_step_bin(prefix='.', confirmation=True): def download_step_ca_bin(prefix='.', confirmation=True): - """ - Download step-ca binaries. + """Download step-ca binaries. Args: - prefix: folder path to download - confirmation: request user confirmation or not + prefix (str, optional): Folder path to download. Defaults to '.'. + confirmation (bool, optional): Request user confirmation or not. + Defaults to True. """ system, arch = get_system_and_architecture() ext = FILE_EXTENSIONS[system] @@ -64,6 +65,13 @@ def download_step_ca_bin(prefix='.', confirmation=True): def _download(url, prefix, confirmation): + """Download a file from a URL. + + Args: + url (str): URL of the file to download. + prefix (str): Folder path to download. + confirmation (bool): Request user confirmation or not. + """ if confirmation: confirm('CA binaries will be downloaded now', default=True, abort=True) name = url.split('/')[-1] diff --git a/openfl/utilities/checks.py b/openfl/utilities/checks.py index 6aacd1ea26..fe4681c072 100644 --- a/openfl/utilities/checks.py +++ b/openfl/utilities/checks.py @@ -4,15 +4,34 @@ def check_type(obj, expected_type, logger): - """Assert `obj` is of `expected_type` type.""" + """Assert `obj` is of `expected_type` type. + + Args: + obj (Any): The object to check. + expected_type (type): The expected type of the object. + logger (Logger): The logger to use for reporting the error. + + Raises: + TypeError: If the object is not of the expected type. + """ if not isinstance(obj, expected_type): - exception = TypeError(f'Expected type {type(obj)}, got type {str(expected_type)}') + exception = TypeError( + f'Expected type {type(obj)}, got type {str(expected_type)}') logger.exception(repr(exception)) raise exception def check_equal(x, y, logger): - """Assert `x` and `y` are equal.""" + """Assert `x` and `y` are equal. + + Args: + x (Any): The first value to compare. + y (Any): The second value to compare. + logger (Logger): The logger to use for reporting the error. + + Raises: + ValueError: If the values are not equal. + """ if x != y: exception = ValueError(f'{x} != {y}') logger.exception(repr(exception)) @@ -20,24 +39,56 @@ def check_equal(x, y, logger): def check_not_equal(x, y, logger, name='None provided'): - """Assert `x` and `y` are not equal.""" + """Assert `x` and `y` are not equal. + + Args: + x (Any): The first value to compare. + y (Any): The second value to compare. + logger (Logger): The logger to use for reporting the error. + name (str, optional): The name of the values. Defaults to + 'None provided'. + + Raises: + ValueError: If the values are equal. + """ if x == y: - exception = ValueError(f'Name {name}. Expected inequality, but {x} == {y}') + exception = ValueError( + f'Name {name}. Expected inequality, but {x} == {y}') logger.exception(repr(exception)) raise exception def check_is_in(element, _list, logger): - """Assert `element` is in `_list` collection.""" + """Assert `element` is in `_list` collection. + + Args: + element (Any): The element to check. + _list (Iterable): The collection to check in. + logger (Logger): The logger to use for reporting the error. + + Raises: + ValueError: If the element is not in the collection. 
+ """ if element not in _list: - exception = ValueError(f'Expected sequence membership, but {element} is not in {_list}') + exception = ValueError( + f'Expected sequence membership, but {element} is not in {_list}') logger.exception(repr(exception)) raise exception def check_not_in(element, _list, logger): - """Assert `element` is not in `_list` collection.""" + """Assert `element` is not in `_list` collection. + + Args: + element (Any): The element to check. + _list (Iterable): The collection to check in. + logger (Logger): The logger to use for reporting the error. + + Raises: + ValueError: If the element is in the collection. + """ if element in _list: - exception = ValueError(f'Expected not in sequence, but {element} is in {_list}') + exception = ValueError( + f'Expected not in sequence, but {element} is in {_list}') logger.exception(repr(exception)) raise exception diff --git a/openfl/utilities/click_types.py b/openfl/utilities/click_types.py index 4847ff5ed1..9955fff461 100644 --- a/openfl/utilities/click_types.py +++ b/openfl/utilities/click_types.py @@ -1,6 +1,6 @@ # Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Custom input types definition for Click""" +"""Custom input types definition for Click.""" import click import ast @@ -9,31 +9,74 @@ class FqdnParamType(click.ParamType): - """Domain Type for click arguments.""" + """Domain Type for click arguments. + + This class is used to validate that a command line argument is a fully + qualified domain name. + + Attributes: + name (str): The name of the parameter type. + """ name = 'fqdn' def convert(self, value, param, ctx): - """Validate value, if value is valid, return it.""" + """Validate value, if value is valid, return it. + + Args: + value (str): The value to validate. + param (click.core.Option): The option that this value was supplied + to. + ctx (click.core.Context): The context for the parameter. + + Returns: + str: The value, if it is valid. + + Raises: + value (click.exceptions.BadParameter): If the value is not a valid + domain name. + """ if not utils.is_fqdn(value): self.fail(f'{value} is not a valid domain name', param, ctx) return value class IpAddressParamType(click.ParamType): - """IpAddress Type for click arguments.""" + """IpAddress Type for click arguments. + + This class is used to validate that a command line argument is an IP + address. + + Attributes: + name (str): The name of the parameter type. + """ name = 'IpAddress type' def convert(self, value, param, ctx): - """Validate value, if value is valid, return it.""" + """Validate value, if value is valid, return it. + + Args: + value (str): The value to validate. + param (click.core.Option): The option that this value was supplied + to. + ctx (click.core.Context): The context for the parameter. + + Returns: + str: The value, if it is valid. + + Raises: + click.exceptions.BadParameter: If the value is not a valid IP + address. 
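# Illustrative sketch, not part of the patch: wiring the custom Click parameter
# types documented above into CLI options. The command and option names below
# are made up for the example.
import click

from openfl.utilities.click_types import FqdnParamType, IpAddressParamType

@click.command()
@click.option('--fqdn', type=FqdnParamType(), help='Fully qualified domain name.')
@click.option('--ip', type=IpAddressParamType(), help='IP address.')
def describe(fqdn, ip):
    click.echo(f'fqdn={fqdn}, ip={ip}')

if __name__ == '__main__':
    describe()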
+ """ if not utils.is_api_adress(value): self.fail(f'{value} is not a valid ip adress name', param, ctx) return value class InputSpec(click.Option): - """List or dictionary that corresponds to the input shape for a model""" + """List or dictionary that corresponds to the input shape for a model.""" + def type_cast_value(self, ctx, value): try: if value is None: diff --git a/openfl/utilities/data_splitters/__init__.py b/openfl/utilities/data_splitters/__init__.py index 3aec457b4d..bf123f66f3 100644 --- a/openfl/utilities/data_splitters/__init__.py +++ b/openfl/utilities/data_splitters/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0- - """openfl.utilities.data package.""" from openfl.utilities.data_splitters.data_splitter import DataSplitter from openfl.utilities.data_splitters.numpy import DirichletNumPyDataSplitter diff --git a/openfl/utilities/data_splitters/data_splitter.py b/openfl/utilities/data_splitters/data_splitter.py index cd1a29927d..9f8459be66 100644 --- a/openfl/utilities/data_splitters/data_splitter.py +++ b/openfl/utilities/data_splitters/data_splitter.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.utilities.data_splitters.data_splitter module.""" from abc import ABC from abc import abstractmethod @@ -12,9 +11,27 @@ class DataSplitter(ABC): - """Base class for data splitting.""" + """Base class for data splitting. + + This class should be subclassed when creating specific data splitter + classes. + """ @abstractmethod - def split(self, data: Iterable[T], num_collaborators: int) -> List[Iterable[T]]: - """Split the data.""" + def split(self, data: Iterable[T], + num_collaborators: int) -> List[Iterable[T]]: + """Split the data into a specified number of parts. + + Args: + data (Iterable[T]): The data to be split. + num_collaborators (int): The number of parts to split the data + into. + + Returns: + List[Iterable[T]]: The split data. + + Raises: + NotImplementedError: This is an abstract method and must be + overridden in a subclass. + """ raise NotImplementedError diff --git a/openfl/utilities/data_splitters/numpy.py b/openfl/utilities/data_splitters/numpy.py index 6d8cf22fc9..585a894d09 100644 --- a/openfl/utilities/data_splitters/numpy.py +++ b/openfl/utilities/data_splitters/numpy.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """UnbalancedFederatedDataset module.""" from abc import abstractmethod @@ -13,35 +12,63 @@ def get_label_count(labels, label): - """Count samples with label `label` in `labels` array.""" + """Count the number of samples with a specific label in a labels array. + + Args: + labels (np.ndarray): Array of labels. + label (int or str): The label to count. + + Returns: + int: The count of the label in the labels array. + """ return len(np.nonzero(labels == label)[0]) def one_hot(labels, classes): - """Apply One-Hot encoding to labels.""" + """Apply One-Hot encoding to labels. + + Args: + labels (np.ndarray): Array of labels. + classes (int): The total number of classes. + + Returns: + np.ndarray: The one-hot encoded labels. + """ return np.eye(classes)[labels] class NumPyDataSplitter(DataSplitter): - """Base class for splitting numpy arrays of data.""" + """Base class for splitting numpy arrays of data. + + This class should be subclassed when creating specific data splitter + classes. 
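# Illustrative sketch, not part of the patch: what the two label helpers
# documented above compute on a small array.
import numpy as np

from openfl.utilities.data_splitters.numpy import get_label_count, one_hot

labels = np.array([0, 2, 1, 2, 2])
print(get_label_count(labels, 2))   # 3 samples carry label 2
print(one_hot(labels, classes=3))   # 5x3 matrix with a single 1 per row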
+ """ @abstractmethod - def split(self, data: np.ndarray, num_collaborators: int) -> List[List[int]]: + def split(self, data: np.ndarray, + num_collaborators: int) -> List[List[int]]: """Split the data.""" raise NotImplementedError class EqualNumPyDataSplitter(NumPyDataSplitter): - """Splits the data evenly.""" + """Class for splitting numpy arrays of data evenly. + + Args: + shuffle (bool, optional): Flag determining whether to shuffle the + dataset before splitting. Defaults to True. + seed (int, optional): Random numbers generator seed. Defaults to 0. + """ def __init__(self, shuffle=True, seed=0): """Initialize. Args: - shuffle(bool): Flag determining whether to shuffle the dataset before splitting. - seed(int): Random numbers generator seed. - For different splits on envoys, try setting different values for this parameter - on each shard descriptor. + shuffle (bool): Flag determining whether to shuffle the dataset + before splitting. Defaults to True. + seed (int): Random numbers generator seed. Defaults to 0. + For different splits on envoys, try setting different values + for this parameter on each shard descriptor. """ self.shuffle = shuffle self.seed = seed @@ -57,16 +84,23 @@ def split(self, data, num_collaborators): class RandomNumPyDataSplitter(NumPyDataSplitter): - """Splits the data randomly.""" + """Class for splitting numpy arrays of data randomly. + + Args: + shuffle (bool, optional): Flag determining whether to shuffle the + dataset before splitting. Defaults to True. + seed (int, optional): Random numbers generator seed. Defaults to 0. + """ def __init__(self, shuffle=True, seed=0): """Initialize. Args: - shuffle(bool): Flag determining whether to shuffle the dataset before splitting. - seed(int): Random numbers generator seed. - For different splits on envoys, try setting different values for this parameter - on each shard descriptor. + shuffle (bool): Flag determining whether to shuffle the dataset + before splitting. Defaults to True. + seed (int): Random numbers generator seed. Defaults to 0. + For different splits on envoys, try setting different values + for this parameter on each shard descriptor. """ self.shuffle = shuffle self.seed = seed @@ -77,29 +111,45 @@ def split(self, data, num_collaborators): idx = range(len(data)) if self.shuffle: idx = np.random.permutation(idx) - random_idx = np.sort(np.random.choice(len(data), num_collaborators - 1, replace=False)) + random_idx = np.sort( + np.random.choice(len(data), num_collaborators - 1, replace=False)) return np.split(idx, random_idx) class LogNormalNumPyDataSplitter(NumPyDataSplitter): - """Unbalanced (LogNormal) dataset split. + """Class for splitting numpy arrays of data according to a LogNormal + distribution. + Unbalanced (LogNormal) dataset split. This split assumes only several classes are assigned to each collaborator. - Firstly, it assigns classes_per_col * min_samples_per_class items of dataset - to each collaborator so all of collaborators will have some data after the split. + Firstly, it assigns classes_per_col * min_samples_per_class items of + dataset to each collaborator so all of collaborators will have some data + after the split. Then, it generates positive integer numbers by log-normal (power) law. - These numbers correspond to numbers of dataset items picked each time from dataset - and assigned to a collaborator. + These numbers correspond to numbers of dataset items picked each time from + dataset and assigned to a collaborator. 
Generation is repeated for each class assigned to a collaborator. - This is a parametrized version of non-i.i.d. data split in FedProx algorithm. + This is a parametrized version of non-i.i.d. data split in FedProx + algorithm. Origin source: https://github.com/litian96/FedProx/blob/master/data/mnist/generate_niid.py#L30 - NOTE: This split always drops out some part of the dataset! - Non-deterministic behavior selects only random subpart of class items. + Args: + mu (float): Distribution hyperparameter. + sigma (float): Distribution hyperparameter. + num_classes (int): Number of classes. + classes_per_col (int): Number of classes assigned to each collaborator. + min_samples_per_class (int): Minimum number of collaborator samples of + each class. + seed (int, optional): Random numbers generator seed. Defaults to 0. + + .. note:: + This split always drops out some part of the dataset! + Non-deterministic behavior selects only random subpart of class items. """ - def __init__(self, mu, + def __init__(self, + mu, sigma, num_classes, classes_per_col, @@ -108,13 +158,15 @@ def __init__(self, mu, """Initialize the generator. Args: - mu(float): Distribution hyperparameter. - sigma(float): Distribution hyperparameter. - classes_per_col(int): Number of classes assigned to each collaborator. - min_samples_per_class(int): Minimum number of collaborator samples of each class. - seed(int): Random numbers generator seed. - For different splits on envoys, try setting different values for this parameter - on each shard descriptor. + mu (float): Distribution hyperparameter. + sigma (float): Distribution hyperparameter. + classes_per_col (int): Number of classes assigned to each + collaborator. + min_samples_per_class (int): Minimum number of collaborator + samples of each class. + seed (int): Random numbers generator seed. Defaults to 0. + For different splits on envoys, try setting different values + for this parameter on each shard descriptor. """ self.mu = mu self.sigma = sigma @@ -127,8 +179,9 @@ def split(self, data, num_collaborators): """Split the data. Args: - data(np.ndarray): numpy-like label array. - num_collaborators(int): number of collaborators to split data across. + data (np.ndarray): numpy-like label array. + num_collaborators (int): number of collaborators to split data + across. Should be divisible by number of classes in ``data``. """ np.random.seed(self.seed) @@ -141,33 +194,38 @@ def split(self, data, num_collaborators): slice_start = col // self.num_classes * samples_per_col slice_start += self.min_samples_per_class * c slice_end = slice_start + self.min_samples_per_class - print(f'Assigning {slice_start}:{slice_end} of class {label} to {col} col...') + print( + f'Assigning {slice_start}:{slice_end} of class {label} to {col} col...' 
+ ) idx[col] += list(label_idx[slice_start:slice_end]) if any(len(i) != samples_per_col for i in idx): - raise SystemError(f'''All collaborators should have {samples_per_col} elements + raise SystemError( + f'''All collaborators should have {samples_per_col} elements but distribution is {[len(i) for i in idx]}''') - props_shape = ( - self.num_classes, - num_collaborators // self.num_classes, - self.classes_per_col - ) + props_shape = (self.num_classes, num_collaborators // self.num_classes, + self.classes_per_col) props = np.random.lognormal(self.mu, self.sigma, props_shape) - num_samples_per_class = [[[get_label_count(data, label) - self.min_samples_per_class]] - for label in range(self.num_classes)] + num_samples_per_class = [[[ + get_label_count(data, label) - self.min_samples_per_class + ]] for label in range(self.num_classes)] num_samples_per_class = np.array(num_samples_per_class) - props = num_samples_per_class * props / np.sum(props, (1, 2), keepdims=True) + props = num_samples_per_class * props / np.sum(props, (1, 2), + keepdims=True) for col in trange(num_collaborators): for j in range(self.classes_per_col): label = (col + j) % self.num_classes num_samples = int(props[label, col // self.num_classes, j]) - print(f'Trying to append {num_samples} samples of {label} class to {col} col...') + print( + f'Trying to append {num_samples} samples of {label} class to {col} col...' + ) slice_start = np.count_nonzero(data[np.hstack(idx)] == label) slice_end = slice_start + num_samples label_count = get_label_count(data, label) if slice_end < label_count: - label_subset = np.nonzero(data == (col + j) % self.num_classes)[0] + label_subset = np.nonzero(data == (col + j) % + self.num_classes)[0] idx_to_append = label_subset[slice_start:slice_end] idx[col] = np.append(idx[col], idx_to_append) else: @@ -178,23 +236,33 @@ def split(self, data, num_collaborators): class DirichletNumPyDataSplitter(NumPyDataSplitter): - """Numpy splitter according to dirichlet distribution. + """Class for splitting numpy arrays of data according to a Dirichlet + distribution. Generates the random sample of integer numbers from dirichlet distribution until minimum subset length exceeds the specified threshold. - This behavior is a parametrized version of non-i.i.d. split in FedMA algorithm. + This behavior is a parametrized version of non-i.i.d. split in FedMA + algorithm. Origin source: https://github.com/IBM/FedMA/blob/master/utils.py#L96 + + Args: + alpha (float, optional): Dirichlet distribution parameter. Defaults + to 0.5. + min_samples_per_col (int, optional): Minimal amount of samples per + collaborator. Defaults to 10. + seed (int, optional): Random numbers generator seed. Defaults to 0. """ def __init__(self, alpha=0.5, min_samples_per_col=10, seed=0): """Initialize. Args: - alpha(float): Dirichlet distribution parameter. - min_samples_per_col(int): Minimal amount of samples per collaborator. - seed(int): Random numbers generator seed. - For different splits on envoys, try setting different values for this parameter - on each shard descriptor. + alpha (float): Dirichlet distribution parameter. Defaults to 0.5. + min_samples_per_col (int): Minimal amount of samples per + collaborator. Defaults to 10. + seed (int): Random numbers generator seed. Defaults to 0. + For different splits on envoys, try setting different values + for this parameter on each shard descriptor. 
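# Illustrative sketch, not part of the patch: splitting a label array across
# collaborators with the splitters documented above. The label array and the
# number of collaborators are made up for the example.
import numpy as np

from openfl.utilities.data_splitters.numpy import (DirichletNumPyDataSplitter,
                                                   EqualNumPyDataSplitter)

labels = np.random.randint(0, 10, size=5000)

equal_shards = EqualNumPyDataSplitter(shuffle=True, seed=0).split(labels, 4)
dirichlet_shards = DirichletNumPyDataSplitter(alpha=0.5, seed=0).split(labels, 4)
# Both calls return one collection of sample indices per collaborator.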
""" self.alpha = alpha self.min_samples_per_col = min_samples_per_col @@ -212,13 +280,20 @@ def split(self, data, num_collaborators): for k in range(classes): idx_k = np.where(data == k)[0] np.random.shuffle(idx_k) - proportions = np.random.dirichlet(np.repeat(self.alpha, num_collaborators)) - proportions = [p * (len(idx_j) < n / num_collaborators) - for p, idx_j in zip(proportions, idx_batch)] + proportions = np.random.dirichlet( + np.repeat(self.alpha, num_collaborators)) + proportions = [ + p * (len(idx_j) < n / num_collaborators) + for p, idx_j in zip(proportions, idx_batch) + ] proportions = np.array(proportions) proportions = proportions / proportions.sum() - proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + proportions = (np.cumsum(proportions) + * len(idx_k)).astype(int)[:-1] idx_splitted = np.split(idx_k, proportions) - idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, idx_splitted)] + idx_batch = [ + idx_j + idx.tolist() + for idx_j, idx in zip(idx_batch, idx_splitted) + ] min_size = min([len(idx_j) for idx_j in idx_batch]) return idx_batch diff --git a/openfl/utilities/fed_timer.py b/openfl/utilities/fed_timer.py index 4540e9bb10..e2f4e68fd3 100644 --- a/openfl/utilities/fed_timer.py +++ b/openfl/utilities/fed_timer.py @@ -1,6 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Components Timeout Configuration Module""" +"""Components Timeout Configuration Module.""" import asyncio import logging @@ -15,45 +15,77 @@ class CustomThread(Thread): - ''' - The CustomThread object implements `threading.Thread` class. - Allows extensibility and stores the returned result from threaded execution. + """Custom Thread class. + + This class extends the `threading.Thread` class and allows for the storage + of the result returned by the target function. Attributes: - target (function): decorated function - name (str): Name of the decorated function - *args (tuple): Arguments passed as a parameter to decorated function. - **kwargs (dict): Keyword arguments passed as a parameter to decorated function. + target (function): The function to be executed in a separate thread. + name (str): The name of the thread. + args (tuple): The positional arguments to pass to the target function. + kwargs (dict): The keyword arguments to pass to the target function. + """ - ''' def __init__(self, group=None, target=None, name=None, args=(), kwargs={}): + """Initialize a CustomThread object. + + Args: + group (None, optional): Reserved for future extension when a + ThreadGroup class is implemented. + target (function, optional): The function to be executed in a + separate thread. + name (str, optional): The name of the thread. + args (tuple, optional): The positional arguments to pass to the + target function. + kwargs (dict, optional): The keyword arguments to pass to the + target function. + """ Thread.__init__(self, group, target, name, args, kwargs) self._result = None def run(self): - ''' - `run()` Invoked by `thread.start()` - ''' + """Method representing the thread's activity. + + This method is invoked by `thread.start()`. + """ if self._target is not None: self._result = self._target(*self._args, **self._kwargs) def result(self): + """Get the result of the thread's activity. + + Returns: + Any: The result of the target function. 
+ """ return self._result class PrepareTask(): - ''' - `PrepareTask` class stores the decorated function metadata and instantiates - either the `asyncio` or `thread` tasks to handle asynchronous - and synchronous execution of the decorated function respectively. + """Prepare Task class. + + This class stores the decorated function metadata and instantiates. + Prepares a task for execution, either synchronously or asynchronously. + Attributes: - target (function): decorated function - timeout (int): Timeout duration in second(s). - *args (tuple): Arguments passed as a parameter to decorated function. - **kwargs (dict): Keyword arguments passed as a parameter to decorated function. - ''' + target_fn (function): The function to be executed. + max_timeout (int): The maximum time to allow for the function's + execution. + args (tuple): The positional arguments to pass to the function. + kwargs (dict): The keyword arguments to pass to the function. + """ + def __init__(self, target_fn, timeout, args, kwargs) -> None: + """Initialize a PrepareTask object. + + Args: + target_fn (function): The function to be executed. + timeout (int): The maximum time to allow for the function's + execution. + args (tuple): The positional arguments to pass to the function. + kwargs (dict): The keyword arguments to pass to the function. + """ self._target_fn = target_fn self._fn_name = target_fn.__name__ self._max_timeout = timeout @@ -61,53 +93,58 @@ def __init__(self, target_fn, timeout, args, kwargs) -> None: self._kwargs = kwargs async def async_execute(self): - '''Handles asynchronous execution of the - decorated function referenced by `self._target_fn`. + """Execute the task asynchronously of the decorated function referenced + by `self._target_fn`. Raises: - asyncio.TimeoutError: If the async execution exceeds permitted time limit. - Exception: Captures generic exceptions. + asyncio.TimeoutError: If the execution exceeds the maximum time. + Exception: If any other error occurs during execution. Returns: - Any: The returned value from `task.results()` depends on the decorated function. - ''' - task = asyncio.create_task( - name=self._fn_name, - coro=self._target_fn(*self._args, **self._kwargs) - ) + Any: The result of the function's execution. + The returned value from `task.results()` depends on the + decorated function. + """ + task = asyncio.create_task(name=self._fn_name, + coro=self._target_fn( + *self._args, **self._kwargs)) try: await asyncio.wait_for(task, timeout=self._max_timeout) except asyncio.TimeoutError: - raise asyncio.TimeoutError(f"Timeout after {self._max_timeout} second(s), " - f"Exception method: ({self._fn_name})") + raise asyncio.TimeoutError( + f"Timeout after {self._max_timeout} second(s), " + f"Exception method: ({self._fn_name})") except Exception: raise Exception(f"Generic Exception: {self._fn_name}") return task.result() def sync_execute(self): - '''Handles synchronous execution of the - decorated function referenced by `self._target_fn`. + """Execute the task synchronously of the decorated function referenced + by `self._target_fn`. Raises: - TimeoutError: If the synchronous execution exceeds permitted time limit. + TimeoutError: If the execution exceeds the maximum time. Returns: - Any: The returned value from `task.results()` depends on the decorated function. - ''' + Any: The result of the function's execution. + The returned value from `task.results()` depends on the + decorated function. 
+ """ task = CustomThread(target=self._target_fn, name=self._fn_name, args=self._args, kwargs=self._kwargs) task.start() # Execution continues if the decorated function completes within the timelimit. - # If the execution exceeds time limit then - # the spawned thread is force joined to current/main thread. + # If the execution exceeds time limit then the spawned thread is force + # joined to current/main thread. task.join(self._max_timeout) # If the control is back to current/main thread - # and the spawned thread is still alive then timeout and raise exception. + # and the spawned thread is still alive then timeout and raise + # exception. if task.is_alive(): raise TimeoutError(f"Timeout after {self._max_timeout} second(s), " f"Exception method: ({self._fn_name})") @@ -116,76 +153,123 @@ def sync_execute(self): class SyncAsyncTaskDecoFactory: - ''' - `Sync` and `Async` Task decorator factory allows creation of - concrete implementation of `wrapper` interface and `contextmanager` to - setup a common functionality/resources shared by `async_wrapper` and `sync_wrapper`. + """Sync and Async Task decorator factory. + + This class is a factory for creating decorators for synchronous and + asynchronous tasks. + Task decorator factory allows creation of concrete implementation of + `wrapper` interface and `contextmanager` to setup a common + functionality/resources shared by `async_wrapper` and `sync_wrapper`. - ''' + Attributes: + is_coroutine (bool): Whether the decorated function is a coroutine. + """ @contextmanager def wrapper(self, func, *args, **kwargs): + """Create a context for the decorated function. + + Args: + func (function): The function to be decorated. + args (tuple): The positional arguments to pass to the function. + kwargs (dict): The keyword arguments to pass to the function. + + Yields: + None + """ yield def __call__(self, func): - ''' - Call to `@fedtiming()` executes `__call__()` method - delegated from the derived class `fedtiming` implementing `SyncAsyncTaskDecoFactory`. - ''' + """Decorate the function. Call to `@fedtiming()` executes `__call__()` + method delegated from the derived class `fedtiming` implementing + `SyncAsyncTaskDecoFactory`. + Args: + func (function): The function to be decorated. + + Returns: + function: The decorated function. + """ # Closures self.is_coroutine = asyncio.iscoroutinefunction(func) str_fmt = "{} Method ({}); Co-routine {}" @wraps(func) def sync_wrapper(*args, **kwargs): - ''' - Wrapper for synchronous execution of decorated function. - ''' - logger.debug(str_fmt.format("sync", func.__name__, self.is_coroutine)) + """Wrapper for synchronous execution of decorated function.""" + logger.debug( + str_fmt.format("sync", func.__name__, self.is_coroutine)) with self.wrapper(func, *args, **kwargs): return self.task.sync_execute() @wraps(func) async def async_wrapper(*args, **kwargs): - ''' - Wrapper for asynchronous execution of decorated function. - ''' - logger.debug(str_fmt.format("async", func.__name__, self.is_coroutine)) + """Wrapper for asynchronous execution of decorated function.""" + logger.debug( + str_fmt.format("async", func.__name__, self.is_coroutine)) with self.wrapper(func, *args, **kwargs): return await self.task.async_execute() - # Identify if the decorated function is `async` or `sync` and return appropriate wrapper. + # Identify if the decorated function is `async` or `sync` and return + # appropriate wrapper. 
if self.is_coroutine: return async_wrapper return sync_wrapper class fedtiming(SyncAsyncTaskDecoFactory): # noqa: N801 + """FedTiming decorator factory. + + This class is a factory for creating decorators for timing synchronous and + asynchronous tasks. + + Attributes: + timeout (int): The maximum time to allow for the function's execution. + """ + def __init__(self, timeout): + """Initialize a FedTiming object. + + Args: + timeout (int): The maximum time to allow for the function's + execution. + """ self.timeout = timeout @contextmanager def wrapper(self, func, *args, **kwargs): - ''' - Concrete implementation of setup and teardown logic, yields the control back to - `async_wrapper` or `sync_wrapper` function call. + """Create a context for the decorated function. + + This method sets up the task for execution and measures its execution + time. + Yields the control back to `async_wrapper` or `sync_wrapper` function + call. + + Args: + func (function): The function to be decorated. + args (tuple): The positional arguments to pass to the function. + kwargs (dict): The keyword arguments to pass to the function. + + Yields: + None Raises: - Exception: Captures the exception raised by `async_wrapper` - or `sync_wrapper` and terminates the execution. - ''' - self.task = PrepareTask( - target_fn=func, - timeout=self.timeout, - args=args, - kwargs=kwargs - ) + Exception: If an error occurs during the function's execution + raised by `async_wrapper` or `sync_wrapper` and terminates the + execution.. + """ + self.task = PrepareTask(target_fn=func, + timeout=self.timeout, + args=args, + kwargs=kwargs) try: start = time.perf_counter() yield - logger.info(f"({self.task._fn_name}) Elapsed Time: {time.perf_counter() - start}") + logger.info( + f"({self.task._fn_name}) Elapsed Time: {time.perf_counter() - start}" + ) except Exception as e: - logger.exception(f"An exception of type {type(e).__name__} occurred. " - f"Arguments:\n{e.args[0]!r}") + logger.exception( + f"An exception of type {type(e).__name__} occurred. " + f"Arguments:\n{e.args[0]!r}") os._exit(status=os.EX_TEMPFAIL) diff --git a/openfl/utilities/fedcurv/torch/fedcurv.py b/openfl/utilities/fedcurv/torch/fedcurv.py index 0e18de1a3a..16455a3aca 100644 --- a/openfl/utilities/fedcurv/torch/fedcurv.py +++ b/openfl/utilities/fedcurv/torch/fedcurv.py @@ -9,12 +9,13 @@ def register_buffer(module: torch.nn.Module, name: str, value: torch.Tensor): - """Add a buffer to module. + """Add a buffer to a module. Args: - module: Module - name: Buffer name. Supports complex module names like 'model.conv1.bias'. - value: Buffer value + module (torch.nn.Module): The module to add the buffer to. + name (str): The name of the buffer. Supports complex module names like + 'model.conv1.bias'. + value (torch.Tensor): The value of the buffer. """ module_path, _, name = name.rpartition('.') mod = module.get_submodule(module_path) @@ -22,18 +23,26 @@ def register_buffer(module: torch.nn.Module, name: str, value: torch.Tensor): def get_buffer(module, target): - """Get module buffer. + """Get a buffer from a module. Remove after pinning to a version where https://github.com/pytorch/pytorch/pull/61429 is included. Use module.get_buffer() instead. + + Args: + module (torch.nn.Module): The module to get the buffer from. + target (str): The name of the buffer to get. + + Returns: + torch.Tensor: The buffer. 
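# Illustrative sketch, not part of the patch: applying the @fedtiming decorator
# from openfl/utilities/fed_timer.py (documented above) to bound a task's run
# time. The task functions below are made up for the example.
import asyncio
import time

from openfl.utilities.fed_timer import fedtiming

@fedtiming(timeout=5)
def sync_task():
    time.sleep(1)           # finishes well inside the 5 s budget
    return 'done'

@fedtiming(timeout=5)
async def async_task():
    await asyncio.sleep(1)  # wrapped in asyncio.wait_for with a 5 s timeout
    return 'done'

sync_task()                 # executed in a CustomThread and joined with a timeout
asyncio.run(async_task())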
""" module_path, _, buffer_name = target.rpartition('.') mod: torch.nn.Module = module.get_submodule(module_path) if not hasattr(mod, buffer_name): - raise AttributeError(f'{mod._get_name()} has no attribute `{buffer_name}`') + raise AttributeError( + f'{mod._get_name()} has no attribute `{buffer_name}`') buffer: torch.Tensor = getattr(mod, buffer_name) @@ -46,21 +55,35 @@ def get_buffer(module, target): class FedCurv: """Federated Curvature class. + This class implements the FedCurv algorithm for federated learning. Requires torch>=1.9.0. + + Args: + model (torch.nn.Module): The base model. Parameters of it are used in + loss penalty calculation. + importance (float): The lambda coefficient of the FedCurv algorithm. """ def __init__(self, model: torch.nn.Module, importance: float): - """Initialize. + """Initialize the FedCurv object. Args: - model: Base model. Parameters of it are used in loss penalty calculation. - importance: Lambda coefficient of FedCurv algorithm. + model (torch.nn.Module): The base model. Parameters of it are used + in loss penalty calculation. + importance (float): The lambda coefficient of the FedCurv + algorithm. """ self.importance = importance self._params = {} self._register_fisher_parameters(model) def _register_fisher_parameters(self, model): + """Register the Fisher parameters of the model. + + Args: + model (torch.nn.Module): The model to register the Fisher + parameters for. + """ params = list(model.named_parameters()) for n, p in params: u = torch.zeros_like(p, requires_grad=False) @@ -78,9 +101,28 @@ def _register_fisher_parameters(self, model): setattr(self, f'{n}_w', w) def _update_params(self, model): - self._params = deepcopy({n: p for n, p in model.named_parameters() if p.requires_grad}) + """Update the parameters of the model. + + Args: + model (torch.nn.Module): The model to update the parameters for. + """ + self._params = deepcopy({ + n: p + for n, p in model.named_parameters() if p.requires_grad + }) def _diag_fisher(self, model, data_loader, device): + """Calculate the diagonal of the Fisher information matrix. + + Args: + model (torch.nn.Module): The model to calculate the Fisher + information matrix for. + data_loader (Iterable): The data loader for the training data. + device (str): The device to perform the calculations on. + + Returns: + dict: The diagonal of the Fisher information matrix. + """ precision_matrices = {} for n, p in self._params.items(): p.data.zero_() @@ -98,7 +140,8 @@ def _diag_fisher(self, model, data_loader, device): for n, p in model.named_parameters(): if p.requires_grad: - precision_matrices[n].data = p.grad.data ** 2 / len(data_loader) + precision_matrices[n].data = p.grad.data**2 / len( + data_loader) return precision_matrices @@ -106,28 +149,30 @@ def get_penalty(self, model): """Calculate the penalty term for the loss function. Args: - model(torch.nn.Module): Model that stores global u_t and v_t values as buffers. + model (torch.nn.Module): The model to calculate the penalty for. + Stores global u_t and v_t values as buffers. Returns: - float: Penalty term. + float: The penalty term. 
""" penalty = 0 if not self._params: return penalty for name, param in model.named_parameters(): if param.requires_grad: - u_global, v_global, w_global = ( - get_buffer(model, target).detach() - for target in (f'{name}_u', f'{name}_v', f'{name}_w') - ) - u_local, v_local, w_local = ( - getattr(self, name).detach() - for name in (f'{name}_u', f'{name}_v', f'{name}_w') - ) + u_global, v_global, w_global = (get_buffer(model, + target).detach() + for target in (f'{name}_u', + f'{name}_v', + f'{name}_w')) + u_local, v_local, w_local = (getattr(self, name).detach() + for name in (f'{name}_u', + f'{name}_v', + f'{name}_w')) u = u_global - u_local v = v_global - v_local w = w_global - w_local - _penalty = param ** 2 * u - 2 * param * v + w + _penalty = param**2 * u - 2 * param * v + w penalty += _penalty.sum() penalty = self.importance * penalty return penalty.float() @@ -136,25 +181,24 @@ def on_train_begin(self, model): """Pre-train steps. Args: - model(torch.nn.Module): model for training. + model (torch.nn.Module): The model for training. """ self._update_params(model) def on_train_end(self, model: torch.nn.Module, data_loader, device): - """Post-train steps. + """Perform post-training steps. Args: - model(torch.nn.Module): Trained model. - data_loader(Iterable): Train dataset iterator. - device(str): Model device. - loss_fn(Callable): Train loss function. + model (torch.nn.Module): The trained model. + data_loader (Iterable): The data loader for the training data. + device (str): The device that the model was trained on. """ precision_matrices = self._diag_fisher(model, data_loader, device) for n, m in precision_matrices.items(): u = m.data.to(device) v = m.data * model.get_parameter(n) v = v.to(device) - w = m.data * model.get_parameter(n) ** 2 + w = m.data * model.get_parameter(n)**2 w = w.to(device) register_buffer(model, f'{n}_u', u.clone().detach()) register_buffer(model, f'{n}_v', v.clone().detach()) diff --git a/openfl/utilities/logs.py b/openfl/utilities/logs.py index 1d804fe0be..756c2f42ad 100644 --- a/openfl/utilities/logs.py +++ b/openfl/utilities/logs.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Logs utilities.""" import logging @@ -13,26 +12,48 @@ def get_writer(): - """Create global writer object.""" + """Create global writer object. + + This function creates a global `SummaryWriter` object for logging to + TensorBoard. + """ global writer if not writer: writer = SummaryWriter('./logs/tensorboard', flush_secs=5) def write_metric(node_name, task_name, metric_name, metric, round_number): - """Write metric callback.""" + """Write metric callback. + + This function logs a metric to TensorBoard. + + Args: + node_name (str): The name of the node. + task_name (str): The name of the task. + metric_name (str): The name of the metric. + metric (float): The value of the metric. + round_number (int): The current round number. + """ get_writer() - writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number) + writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, + round_number) def setup_loggers(log_level=logging.INFO): - """Configure loggers.""" + """Configure loggers. + + This function sets up the root logger to log messages with a certain + minimum level and a specific format. + + Args: + log_level (int, optional): The minimum level of messages to log. + Defaults to logging.INFO. 
+ """ root = logging.getLogger() root.setLevel(log_level) console = Console(width=160) handler = RichHandler(console=console) formatter = logging.Formatter( - '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' - ) + '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s') handler.setFormatter(formatter) root.addHandler(handler) diff --git a/openfl/utilities/mocks.py b/openfl/utilities/mocks.py index a6b6206b71..77f751ac4c 100644 --- a/openfl/utilities/mocks.py +++ b/openfl/utilities/mocks.py @@ -1,10 +1,11 @@ # Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Mock objects to eliminate extraneous dependencies""" +"""Mock objects to eliminate extraneous dependencies.""" class MockDataLoader: - """Placeholder dataloader for when data is not available""" + """Placeholder dataloader for when data is not available.""" + def __init__(self, feature_shape): self.feature_shape = feature_shape diff --git a/openfl/utilities/optimizers/__init__.py b/openfl/utilities/optimizers/__init__.py index 57170411de..41365cb222 100644 --- a/openfl/utilities/optimizers/__init__.py +++ b/openfl/utilities/optimizers/__init__.py @@ -1,4 +1,3 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Optimizers package.""" diff --git a/openfl/utilities/optimizers/keras/__init__.py b/openfl/utilities/optimizers/keras/__init__.py index 82a6941e6b..364e483f39 100644 --- a/openfl/utilities/optimizers/keras/__init__.py +++ b/openfl/utilities/optimizers/keras/__init__.py @@ -1,8 +1,7 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Keras optimizers package.""" import pkgutil if pkgutil.find_loader('tensorflow'): - from .fedprox import FedProxOptimizer # NOQA + from .fedprox import FedProxOptimizer # NOQA diff --git a/openfl/utilities/optimizers/keras/fedprox.py b/openfl/utilities/optimizers/keras/fedprox.py index 3e50ae620d..c29701fc59 100644 --- a/openfl/utilities/optimizers/keras/fedprox.py +++ b/openfl/utilities/optimizers/keras/fedprox.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """FedProx Keras optimizer module.""" import tensorflow as tf import tensorflow.keras as keras @@ -11,11 +10,34 @@ class FedProxOptimizer(keras.optimizers.Optimizer): """FedProx optimizer. + Implements the FedProx algorithm as a Keras optimizer. FedProx is a + federated learning optimization algorithm designed to handle non-IID data. + It introduces a proximal term to the federated averaging algorithm to + reduce the impact of devices with outlying updates. + Paper: https://arxiv.org/pdf/1812.06127.pdf + + Attributes: + learning_rate (float): The learning rate for the optimizer. + mu (float): The proximal term coefficient. """ - def __init__(self, learning_rate=0.01, mu=0.01, name='FedProxOptimizer', **kwargs): - """Initialize.""" + def __init__(self, + learning_rate=0.01, + mu=0.01, + name='FedProxOptimizer', + **kwargs): + """Initialize the FedProxOptimizer. + + Args: + learning_rate (float, optional): The learning rate for the + optimizer. Defaults to 0.01. + mu (float, optional): The proximal term coefficient. Defaults} + to 0.01. + name (str, optional): The name of the optimizer. Defaults to + 'FedProxOptimizer'. + **kwargs: Additional keyword arguments. 
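# Illustrative sketch, not part of the patch: compiling a Keras model with the
# FedProx optimizer documented above. The toy model and loss are placeholders,
# and newer TensorFlow releases may require the legacy Keras optimizer API.
import tensorflow as tf

from openfl.utilities.optimizers.keras import FedProxOptimizer

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(10,))])
model.compile(optimizer=FedProxOptimizer(learning_rate=0.01, mu=0.01),
              loss='sparse_categorical_crossentropy')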
+ """ super().__init__(name=name, **kwargs) self._set_hyper('learning_rate', learning_rate) @@ -25,35 +47,80 @@ def __init__(self, learning_rate=0.01, mu=0.01, name='FedProxOptimizer', **kwarg self._mu_t = None def _prepare(self, var_list): - self._lr_t = tf.convert_to_tensor(self._get_hyper('learning_rate'), name='lr') + """Prepare the optimizer's state. + + Args: + var_list (list): List of variables to be optimized. + """ + self._lr_t = tf.convert_to_tensor(self._get_hyper('learning_rate'), + name='lr') self._mu_t = tf.convert_to_tensor(self._get_hyper('mu'), name='mu') def _create_slots(self, var_list): + """Create slots for the optimizer's state. + + Args: + var_list (list): List of variables to be optimized. + """ for v in var_list: self.add_slot(v, 'vstar') def _resource_apply_dense(self, grad, var): + """Apply gradients to variables. + + Args: + grad: Gradients. + var: Variables. + + Returns: + A tf.Operation that applies the specified gradients. + """ lr_t = tf.cast(self._lr_t, var.dtype.base_dtype) mu_t = tf.cast(self._mu_t, var.dtype.base_dtype) vstar = self.get_slot(var, 'vstar') var_update = var.assign_sub(lr_t * (grad + mu_t * (var - vstar))) - return tf.group(*[var_update, ]) + return tf.group(*[ + var_update, + ]) def _apply_sparse_shared(self, grad, var, indices, scatter_add): + """Apply sparse gradients to variables. + + Args: + grad: Gradients. + var: Variables. + indices: A tensor of indices into the first dimension of `var`. + scatter_add: A scatter operation. + + Returns: + A tf.Operation that applies the specified gradients. + """ lr_t = tf.cast(self._lr_t, var.dtype.base_dtype) mu_t = tf.cast(self._mu_t, var.dtype.base_dtype) vstar = self.get_slot(var, 'vstar') - v_diff = vstar.assign(mu_t * (var - vstar), use_locking=self._use_locking) + v_diff = vstar.assign(mu_t * (var - vstar), + use_locking=self._use_locking) with tf.control_dependencies([v_diff]): scaled_grad = scatter_add(vstar, indices, grad) var_update = var.assign_sub(lr_t * scaled_grad) - return tf.group(*[var_update, ]) + return tf.group(*[ + var_update, + ]) def _resource_apply_sparse(self, grad, var): + """Apply sparse gradients to variables. + + Args: + grad: Gradients. + var: Variables. + + Returns: + A tf.Operation that applies the specified gradients. + """ return self._apply_sparse_shared( grad.values, var, grad.indices, lambda x, i, v: standard_ops.scatter_add(x, i, v)) @@ -67,11 +134,11 @@ def get_config(self): (without any saved state) from this configuration. Returns: - Python dictionary. + dict: The optimizer configuration. 
""" base_config = super(FedProxOptimizer, self).get_config() return { - **base_config, - 'lr': self._serialize_hyperparameter('learning_rate'), + **base_config, 'lr': + self._serialize_hyperparameter('learning_rate'), 'mu': self._serialize_hyperparameter('mu') } diff --git a/openfl/utilities/optimizers/numpy/__init__.py b/openfl/utilities/optimizers/numpy/__init__.py index b6498c36b8..1f0ff52965 100644 --- a/openfl/utilities/optimizers/numpy/__init__.py +++ b/openfl/utilities/optimizers/numpy/__init__.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Numpy optimizers package.""" from .adagrad_optimizer import NumPyAdagrad from .adam_optimizer import NumPyAdam diff --git a/openfl/utilities/optimizers/numpy/adagrad_optimizer.py b/openfl/utilities/optimizers/numpy/adagrad_optimizer.py index 92f0f08042..77561b89b4 100644 --- a/openfl/utilities/optimizers/numpy/adagrad_optimizer.py +++ b/openfl/utilities/optimizers/numpy/adagrad_optimizer.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Adagrad optimizer module.""" from typing import Dict @@ -14,7 +13,21 @@ class NumPyAdagrad(Optimizer): """Adagrad optimizer implementation. + Implements the Adagrad optimization algorithm using NumPy. Adagrad is an + algorithm for gradient-based optimization that adapts the learning rate to + the parameters, performing smaller updates for parameters associated with + frequently occurring features, and larger updates for parameters + associated with infrequent features. + Original paper: http://jmlr.org/papers/v12/duchi11a.html + + Attributes: + params (dict, optional): Parameters to be stored for optimization. + model_interface: Model interface instance to provide parameters. + learning_rate (float): Tuning parameter that determines the step size + at each iteration. + initial_accumulator_value (float): Initial value for squared gradients. + epsilon (float): Value for computational stability. """ def __init__( @@ -26,31 +39,44 @@ def __init__( initial_accumulator_value: float = 0.1, epsilon: float = 1e-10, ) -> None: - """Initialize. + """Initialize the Adagrad optimizer. Args: - params: Parameters to be stored for optimization. + params (dict, optional): Parameters to be stored for optimization. + Defaults to None. model_interface: Model interface instance to provide parameters. - learning_rate: Tuning parameter that determines - the step size at each iteration. - initial_accumulator_value: Initial value for squared gradients. - epsilon: Value for computational stability. + Defaults to None. + learning_rate (float, optional): Tuning parameter that determines + the step size at each iteration. Defaults to 0.01. + initial_accumulator_value (float, optional): Initial value for + squared gradients. Defaults to 0.1. + epsilon (float, optional): Value for computational stability. + Defaults to 1e-10. + + Raises: + ValueError: If both params and model_interface are None. + ValueError: If learning_rate is less than 0. + ValueError: If initial_accumulator_value is less than 0. + ValueError: If epsilon is less than or equal to 0. """ super().__init__() if model_interface is None and params is None: - raise ValueError('Should provide one of the params or model_interface') + raise ValueError( + 'Should provide one of the params or model_interface') if learning_rate < 0: raise ValueError( - f'Invalid learning rate: {learning_rate}. Learning rate must be >= 0.') + f'Invalid learning rate: {learning_rate}. 
Learning rate must be >= 0.' + ) if initial_accumulator_value < 0: raise ValueError( f'Invalid initial_accumulator_value value: {initial_accumulator_value}.' 'Initial accumulator value must be >= 0.') if epsilon <= 0: raise ValueError( - f'Invalid epsilon value: {epsilon}. Epsilon avalue must be > 0.') + f'Invalid epsilon value: {epsilon}. Epsilon avalue must be > 0.' + ) self.params = params @@ -63,27 +89,39 @@ def __init__( self.grads_squared = {} for param_name in self.params: - self.grads_squared[param_name] = np.full_like(self.params[param_name], - self.initial_accumulator_value) + self.grads_squared[param_name] = np.full_like( + self.params[param_name], self.initial_accumulator_value) def _update_param(self, grad_name: str, grad: np.ndarray) -> None: - """Update papams by given gradients.""" - self.params[grad_name] -= (self.learning_rate * grad - / (np.sqrt(self.grads_squared[grad_name]) + self.epsilon)) + """Update parameters by given gradients. - def step(self, gradients: Dict[str, np.ndarray]) -> None: + Args: + grad_name (str): The name of the gradient. + grad (np.ndarray): The gradient values. """ - Perform a single step for parameter update. + self.params[grad_name] -= ( + self.learning_rate * grad + / (np.sqrt(self.grads_squared[grad_name]) + self.epsilon)) + + def step(self, gradients: Dict[str, np.ndarray]) -> None: + """Perform a single step for parameter update. Implement Adagrad optimizer weights update rule. Args: - gradients: Partial derivatives with respect to optimized parameters. + gradients (dict): Partial derivatives with respect to optimized + parameters. + + Raises: + KeyError: If a key in gradients does not exist in optimized + parameters. """ for grad_name in gradients: if grad_name not in self.grads_squared: - raise KeyError(f"Key {grad_name} doesn't exist in optimized parameters") + raise KeyError( + f"Key {grad_name} doesn't exist in optimized parameters") grad = gradients[grad_name] - self.grads_squared[grad_name] = self.grads_squared[grad_name] + grad**2 + self.grads_squared[ + grad_name] = self.grads_squared[grad_name] + grad**2 self._update_param(grad_name, grad) diff --git a/openfl/utilities/optimizers/numpy/adam_optimizer.py b/openfl/utilities/optimizers/numpy/adam_optimizer.py index 8660a59855..14f0cd9550 100644 --- a/openfl/utilities/optimizers/numpy/adam_optimizer.py +++ b/openfl/utilities/optimizers/numpy/adam_optimizer.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Adam optimizer module.""" from typing import Dict @@ -15,7 +14,23 @@ class NumPyAdam(Optimizer): """Adam optimizer implementation. + Implements the Adam optimization algorithm using NumPy. + Adam is an algorithm for first-order gradient-based optimization of + stochastic objective functions, based on adaptive estimates of lower-order + moments. + Original paper: https://openreview.net/forum?id=ryQu7f-RZ + + Attributes: + params (dict, optional): Parameters to be stored for optimization. + model_interface: Model interface instance to provide parameters. + learning_rate (float): Tuning parameter that determines the step size + at each iteration. + betas (tuple): Coefficients used for computing running averages of + gradient and its square. + initial_accumulator_value (float): Initial value for gradients and + squared gradients. + epsilon (float): Value for computational stability. """ def __init__( @@ -28,40 +43,56 @@ def __init__( initial_accumulator_value: float = 0.0, epsilon: float = 1e-8, ) -> None: - """Initialize. 
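# Illustrative usage sketch (not part of the diff): driving the NumPyAdagrad
# optimizer defined above with an explicit params dict; names are hypothetical.
import numpy as np

from openfl.utilities.optimizers.numpy import NumPyAdagrad

params = {'w': np.zeros(3), 'b': np.zeros(1)}
opt = NumPyAdagrad(params=params, learning_rate=0.01,
                   initial_accumulator_value=0.1, epsilon=1e-10)

# One update: squared gradients are accumulated per parameter, then each
# parameter moves by learning_rate * grad / (sqrt(accumulated) + epsilon).
grads = {'w': np.array([0.5, -0.2, 0.1]), 'b': np.array([0.05])}
opt.step(grads)
print(params['w'])   # the params dict is updated in place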
+ """Initialize the Adam optimizer. Args: - params: Parameters to be stored for optimization. + params (dict, optional): Parameters to be stored for optimization. + Defaults to None. model_interface: Model interface instance to provide parameters. - learning_rate: Tuning parameter that determines - the step size at each iteration. - betas: Coefficients used for computing running - averages of gradient and its square. - initial_accumulator_value: Initial value for gradients - and squared gradients. - epsilon: Value for computational stability. + Defaults to None. + learning_rate (float, optional): Tuning parameter that determines + the step size at each iteration. Defaults to 0.01. + betas (tuple, optional): Coefficients used for computing running + averages of gradient and its square. Defaults to (0.9, 0.999). + initial_accumulator_value (float, optional): Initial value for + gradients and squared gradients. Defaults to 0.0. + epsilon (float, optional): Value for computational stability. + Defaults to 1e-8. + + Raises: + ValueError: If both params and model_interface are None. + ValueError: If learning_rate is less than 0. + ValueError: If betas[0] is not in [0, 1). + ValueError: If betas[1] is not in [0, 1). + ValueError: If initial_accumulator_value is less than 0. + ValueError: If epsilon is less than or equal to 0. """ super().__init__() if model_interface is None and params is None: - raise ValueError('Should provide one of the params or model_interface') + raise ValueError( + 'Should provide one of the params or model_interface') if learning_rate < 0: raise ValueError( - f'Invalid learning rate: {learning_rate}. Learning rate must be >= 0.') + f'Invalid learning rate: {learning_rate}. Learning rate must be >= 0.' + ) if not 0.0 <= betas[0] < 1: raise ValueError( - f'Invalid betas[0] value: {betas[0]}. betas[0] must be in [0, 1).') + f'Invalid betas[0] value: {betas[0]}. betas[0] must be in [0, 1).' + ) if not 0.0 <= betas[1] < 1: raise ValueError( - f'Invalid betas[1] value: {betas[1]}. betas[1] must be in [0, 1).') + f'Invalid betas[1] value: {betas[1]}. betas[1] must be in [0, 1).' + ) if initial_accumulator_value < 0: raise ValueError( f'Invalid initial_accumulator_value value: {initial_accumulator_value}. \ Initial accumulator value must be >= 0.') if epsilon <= 0: raise ValueError( - f'Invalid epsilon value: {epsilon}. Epsilon avalue must be > 0.') + f'Invalid epsilon value: {epsilon}. Epsilon avalue must be > 0.' + ) self.params = params @@ -77,35 +108,50 @@ def __init__( self.grads_first_moment, self.grads_second_moment = {}, {} for param_name in self.params: - self.grads_first_moment[param_name] = np.full_like(self.params[param_name], - self.initial_accumulator_value) - self.grads_second_moment[param_name] = np.full_like(self.params[param_name], - self.initial_accumulator_value) + self.grads_first_moment[param_name] = np.full_like( + self.params[param_name], self.initial_accumulator_value) + self.grads_second_moment[param_name] = np.full_like( + self.params[param_name], self.initial_accumulator_value) def _update_first_moment(self, grad_name: str, grad: np.ndarray) -> None: - """Update gradients first moment.""" + """Update gradients first moment. + + Args: + grad_name (str): The name of the gradient. + grad (np.ndarray): The gradient values. 
+ """ self.grads_first_moment[grad_name] = (self.beta_1 * self.grads_first_moment[grad_name] + ((1.0 - self.beta_1) * grad)) def _update_second_moment(self, grad_name: str, grad: np.ndarray) -> None: - """Update gradients second moment.""" + """Update gradients second moment. + + Args: + grad_name (str): The name of the gradient. + grad (np.ndarray): The gradient values. + """ self.grads_second_moment[grad_name] = (self.beta_2 * self.grads_second_moment[grad_name] + ((1.0 - self.beta_2) * grad**2)) def step(self, gradients: Dict[str, np.ndarray]) -> None: - """ - Perform a single step for parameter update. + """Perform a single step for parameter update. Implement Adam optimizer weights update rule. Args: - gradients: Partial derivatives with respect to optimized parameters. + gradients (dict): Partial derivatives with respect to optimized + parameters. + + Raises: + KeyError: If a key in gradients does not exist in optimized + parameters. """ for grad_name in gradients: if grad_name not in self.grads_first_moment: - raise KeyError(f"Key {grad_name} doesn't exist in optimized parameters") + raise KeyError( + f"Key {grad_name} doesn't exist in optimized parameters") grad = gradients[grad_name] @@ -116,11 +162,12 @@ def step(self, gradients: Dict[str, np.ndarray]) -> None: mean = self.grads_first_moment[grad_name] var = self.grads_second_moment[grad_name] - grads_first_moment_normalized = mean / (1. - self.beta_1 ** t) - grads_second_moment_normalized = var / (1. - self.beta_2 ** t) + grads_first_moment_normalized = mean / (1. - self.beta_1**t) + grads_second_moment_normalized = var / (1. - self.beta_2**t) # Make an update for a group of parameters - self.params[grad_name] -= (self.learning_rate * grads_first_moment_normalized - / (np.sqrt(grads_second_moment_normalized) + self.epsilon)) + self.params[grad_name] -= ( + self.learning_rate * grads_first_moment_normalized / + (np.sqrt(grads_second_moment_normalized) + self.epsilon)) self.current_step[grad_name] += 1 diff --git a/openfl/utilities/optimizers/numpy/base_optimizer.py b/openfl/utilities/optimizers/numpy/base_optimizer.py index 26e701c152..be095ed5d8 100644 --- a/openfl/utilities/optimizers/numpy/base_optimizer.py +++ b/openfl/utilities/optimizers/numpy/base_optimizer.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Base abstract optimizer class module.""" import abc from importlib import import_module @@ -10,28 +9,46 @@ from numpy import ndarray from openfl.plugins.frameworks_adapters.framework_adapter_interface import ( - FrameworkAdapterPluginInterface -) + FrameworkAdapterPluginInterface) class Optimizer(abc.ABC): - """Base abstract optimizer class.""" + """Base abstract optimizer class. + + This class serves as a base class for all optimizers. It defines the basic + structure that all derived optimizer classes should follow. + It includes an abstract method `step` that must be implemented by any + concrete optimizer class. + """ @abc.abstractmethod def step(self, gradients: Dict[str, ndarray]) -> None: """Perform a single step for parameter update. + This method should be overridden by all subclasses to implement the + specific optimization algorithm. + Args: - gradients: Partial derivatives with respect to optimized parameters. + gradients (dict): Partial derivatives with respect to optimized + parameters. """ pass def _set_params_from_model(self, model_interface): - """Eject and store model parameters.""" + """Eject and store model parameters. 
+ + This method is used to extract the parameters from the provided model + interface and store them in the optimizer. + + Args: + model_interface: The model interface instance to provide + parameters. + """ class_name = splitext(model_interface.framework_plugin)[1].strip('.') module_path = splitext(model_interface.framework_plugin)[0] framework_adapter = import_module(module_path) framework_adapter_plugin: FrameworkAdapterPluginInterface = getattr( framework_adapter, class_name, None) - self.params: Dict[str, ndarray] = framework_adapter_plugin.get_tensor_dict( - model_interface.provide_model()) + self.params: Dict[str, + ndarray] = framework_adapter_plugin.get_tensor_dict( + model_interface.provide_model()) diff --git a/openfl/utilities/optimizers/numpy/yogi_optimizer.py b/openfl/utilities/optimizers/numpy/yogi_optimizer.py index a9984a8613..2a9683ccec 100644 --- a/openfl/utilities/optimizers/numpy/yogi_optimizer.py +++ b/openfl/utilities/optimizers/numpy/yogi_optimizer.py @@ -1,7 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""Adam optimizer module.""" +"""Yogi optimizer module.""" from typing import Dict from typing import Optional @@ -15,8 +14,25 @@ class NumPyYogi(NumPyAdam): """Yogi optimizer implementation. + Implements the Yogi optimization algorithm using NumPy. + Yogi is an algorithm for first-order gradient-based optimization of + stochastic objective functions, based on adaptive estimates of lower-order + moments. It is a variant of Adam and it is more robust to large learning + rates. + Original paper: https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization + + Attributes: + params (dict, optional): Parameters to be stored for optimization. + model_interface: Model interface instance to provide parameters. + learning_rate (float): Tuning parameter that determines the step size + at each iteration. + betas (tuple): Coefficients used for computing running averages of + gradient and its square. + initial_accumulator_value (float): Initial value for gradients and + squared gradients. + epsilon (float): Value for computational stability. """ def __init__( @@ -29,18 +45,21 @@ def __init__( initial_accumulator_value: float = 0.0, epsilon: float = 1e-8, ) -> None: - """Initialize. + """Initialize the Yogi optimizer. Args: - params: Parameters to be stored for optimization. + params (dict, optional): Parameters to be stored for optimization. + Defaults to None. model_interface: Model interface instance to provide parameters. - learning_rate: Tuning parameter that determines - the step size at each iteration. - betas: Coefficients used for computing running - averages of gradient and its square. - initial_accumulator_value: Initial value for gradients - and squared gradients. - epsilon: Value for computational stability. + Defaults to None. + learning_rate (float, optional): Tuning parameter that determines + the step size at each iteration. Defaults to 0.01. + betas (tuple, optional): Coefficients used for computing running + averages of gradient and its square. Defaults to (0.9, 0.999). + initial_accumulator_value (float, optional): Initial value for + gradients and squared gradients. Defaults to 0.0. + epsilon (float, optional): Value for computational stability. + Defaults to 1e-8. 
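# Illustrative sketch (not part of the diff): the smallest concrete subclass of
# the abstract Optimizer base shown above, implementing the required step().
# The class name and plain-SGD rule are hypothetical, for illustration only.
from typing import Dict

import numpy as np

from openfl.utilities.optimizers.numpy.base_optimizer import Optimizer


class NumPySGD(Optimizer):
    """Toy gradient-descent optimizer over a dict of NumPy parameters."""

    def __init__(self, params: Dict[str, np.ndarray], learning_rate: float = 0.01):
        super().__init__()
        self.params = params
        self.learning_rate = learning_rate

    def step(self, gradients: Dict[str, np.ndarray]) -> None:
        for name, grad in gradients.items():
            self.params[name] -= self.learning_rate * grad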
""" super().__init__(params=params, model_interface=model_interface, @@ -50,19 +69,24 @@ def __init__( epsilon=epsilon) def _update_second_moment(self, grad_name: str, grad: np.ndarray) -> None: - """Override second moment update rule for Yogi optimization updates.""" + """Override second moment update rule for Yogi optimization updates. + + Args: + grad_name (str): The name of the gradient. + grad (np.ndarray): The gradient values. + """ sign = np.sign(grad**2 - self.grads_second_moment[grad_name]) - self.grads_second_moment[grad_name] = (self.beta_2 - * self.grads_second_moment[grad_name] - + (1.0 - self.beta_2) * sign * grad**2) + self.grads_second_moment[grad_name] = ( + self.beta_2 * self.grads_second_moment[grad_name] + + (1.0 - self.beta_2) * sign * grad**2) def step(self, gradients: Dict[str, np.ndarray]) -> None: - """ - Perform a single step for parameter update. + """Perform a single step for parameter update. Implement Yogi optimizer weights update rule. Args: - gradients: Partial derivatives with respect to optimized parameters. + gradients (dict): Partial derivatives with respect to optimized + parameters. """ super().step(gradients) diff --git a/openfl/utilities/optimizers/torch/__init__.py b/openfl/utilities/optimizers/torch/__init__.py index 0facde5af4..73c41ada26 100644 --- a/openfl/utilities/optimizers/torch/__init__.py +++ b/openfl/utilities/optimizers/torch/__init__.py @@ -1,9 +1,8 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """PyTorch optimizers package.""" import pkgutil if pkgutil.find_loader('torch'): - from .fedprox import FedProxOptimizer # NOQA - from .fedprox import FedProxAdam # NOQA + from .fedprox import FedProxOptimizer # NOQA + from .fedprox import FedProxAdam # NOQA diff --git a/openfl/utilities/optimizers/torch/fedprox.py b/openfl/utilities/optimizers/torch/fedprox.py index caa6254b5d..1f0afb1e23 100644 --- a/openfl/utilities/optimizers/torch/fedprox.py +++ b/openfl/utilities/optimizers/torch/fedprox.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """PyTorch FedProx optimizer module.""" import math @@ -13,7 +12,22 @@ class FedProxOptimizer(Optimizer): """FedProx optimizer. + Implements the FedProx optimization algorithm using PyTorch. + FedProx is a federated learning optimization algorithm designed to handle + non-IID data. + It introduces a proximal term to the federated averaging algorithm to + reduce the impact of devices with outlying updates. + Paper: https://arxiv.org/pdf/1812.06127.pdf + + Attributes: + params: Parameters to be stored for optimization. + lr: Learning rate. + mu: Proximal term coefficient. + momentum: Momentum factor. + dampening: Dampening for momentum. + weight_decay: Weight decay (L2 penalty). + nesterov: Enables Nesterov momentum. """ def __init__(self, @@ -24,7 +38,23 @@ def __init__(self, dampening=0, weight_decay=0, nesterov=False): - """Initialize.""" + """Initialize the FedProx optimizer. + + Args: + params: Parameters to be stored for optimization. + lr: Learning rate. + mu: Proximal term coefficient. Defaults to 0.0. + momentum: Momentum factor. Defaults to 0. + dampening: Dampening for momentum. Defaults to 0. + weight_decay: Weight decay (L2 penalty). Defaults to 0. + nesterov: Enables Nesterov momentum. Defaults to False + + Raises: + ValueError: If momentum is less than 0. + ValueError: If learning rate is less than 0. + ValueError: If weight decay is less than 0. + ValueError: If mu is less than 0. 
+ """ if momentum < 0.0: raise ValueError(f'Invalid momentum value: {momentum}') if lr is not required and lr < 0.0: @@ -43,12 +73,17 @@ def __init__(self, } if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError('Nesterov momentum requires a momentum and zero dampening') + raise ValueError( + 'Nesterov momentum requires a momentum and zero dampening') super(FedProxOptimizer, self).__init__(params, defaults) def __setstate__(self, state): - """Set optimizer state.""" + """Set optimizer state. + + Args: + state: State dictionary. + """ super(FedProxOptimizer, self).__setstate__(state) for group in self.param_groups: group.setdefault('nesterov', False) @@ -57,9 +92,12 @@ def __setstate__(self, state): def step(self, closure=None): """Perform a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + Loss value if closure is provided. None otherwise. """ loss = None if closure is not None: @@ -81,7 +119,8 @@ def step(self, closure=None): if momentum != 0: param_state = self.state[p] if 'momentum_buffer' not in param_state: - buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() + buf = param_state['momentum_buffer'] = torch.clone( + d_p).detach() else: buf = param_state['momentum_buffer'] buf.mul_(momentum).add_(d_p, alpha=1 - dampening) @@ -96,17 +135,59 @@ def step(self, closure=None): return loss def set_old_weights(self, old_weights): - """Set the global weights parameter to `old_weights` value.""" + """Set the global weights parameter to `old_weights` value. + + Args: + old_weights: The old weights to be set. + """ for param_group in self.param_groups: param_group['w_old'] = old_weights class FedProxAdam(Optimizer): - """FedProxAdam optimizer.""" + """FedProxAdam optimizer. + + Implements the FedProx optimization algorithm with Adam optimizer. + + Attributes: + params: Parameters to be stored for optimization. + mu: Proximal term coefficient. + lr: Learning rate. + betas: Coefficients used for computing running averages of gradient + and its square. + eps: Value for computational stability. + weight_decay: Weight decay (L2 penalty). + amsgrad: Whether to use the AMSGrad variant of this algorithm. + """ + + def __init__(self, + params, + mu=0, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False): + """Initialize the FedProxAdam optimizer. - def __init__(self, params, mu=0, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False): - """Initialize.""" + Args: + params: Parameters to be stored for optimization. + mu: Proximal term coefficient. Defaults to 0. + lr: Learning rate. Defaults to 1e-3. + betas: Coefficients used for computing running averages of + gradient and its square. Defaults to (0.9, 0.999). + eps: Value for computational stability. Defaults to 1e-8. + weight_decay: Weight decay (L2 penalty). Defaults to 0. + amsgrad: Whether to use the AMSGrad variant of this algorithm. + Defaults to False. + + Raises: + ValueError: If learning rate is less than 0. + ValueError: If betas[0] is not in [0, 1). + ValueError: If betas[1] is not in [0, 1). + ValueError: If weight decay is less than 0. + ValueError: If mu is less than 0. 
+ """ if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -119,18 +200,32 @@ def __init__(self, params, mu=0, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, raise ValueError(f'Invalid weight_decay value: {weight_decay}') if mu < 0.0: raise ValueError(f'Invalid mu value: {mu}') - defaults = {'lr': lr, 'betas': betas, 'eps': eps, - 'weight_decay': weight_decay, 'amsgrad': amsgrad, 'mu': mu} + defaults = { + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, + 'amsgrad': amsgrad, + 'mu': mu + } super(FedProxAdam, self).__init__(params, defaults) def __setstate__(self, state): - """Set optimizer state.""" + """Set optimizer state. + + Args: + state: State dictionary. + """ super(FedProxAdam, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsgrad', False) def set_old_weights(self, old_weights): - """Set the global weights parameter to `old_weights` value.""" + """Set the global weights parameter to `old_weights` value. + + Args: + old_weights: The old weights to be set. + """ for param_group in self.param_groups: param_group['w_old'] = old_weights @@ -141,6 +236,9 @@ def step(self, closure=None): Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + Loss value if closure is provided. None otherwise. """ loss = None if closure is not None: @@ -169,7 +267,8 @@ def step(self, closure=None): if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( p, memory_format=torch.preserve_format) @@ -190,38 +289,35 @@ def step(self, closure=None): state_steps.append(state['step']) beta1, beta2 = group['betas'] - self.adam(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - group['amsgrad'], - beta1, - beta2, - group['lr'], - group['weight_decay'], - group['eps'], - group['mu'], - group['w_old'] - ) + self.adam(params_with_grad, grads, exp_avgs, exp_avg_sqs, + max_exp_avg_sqs, state_steps, group['amsgrad'], beta1, + beta2, group['lr'], group['weight_decay'], group['eps'], + group['mu'], group['w_old']) return loss - def adam(self, params, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad, - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float, - mu: float, - w_old): - """Updtae optimizer parameters.""" + def adam(self, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps, amsgrad, beta1: float, beta2: float, lr: float, + weight_decay: float, eps: float, mu: float, w_old): + """Update optimizer parameters. + + Args: + params: Parameters to be stored for optimization. + grads: Gradients. + exp_avgs: Exponential moving average of gradient values. + exp_avg_sqs: Exponential moving average of squared gradient values. + max_exp_avg_sqs: Maintains max of all exp. moving avg. of sq. grad. values. + state_steps: Steps for each param group update. + amsgrad: Whether to use the AMSGrad variant of this algorithm. + beta1 (float): Coefficient used for computing running averages of + gradient. + beta2 (float): Coefficient used for computing running averages of + squared gradient. + lr (float): Learning rate. + weight_decay (float): Weight decay (L2 penalty). 
+ eps (float): Value for computational stability. + mu (float): Proximal term coefficient. + w_old: The old weights. + """ for i, param in enumerate(params): w_old_p = w_old[i] grad = grads[i] @@ -230,8 +326,8 @@ def adam(self, params, exp_avg_sq = exp_avg_sqs[i] step = state_steps[i] - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step if weight_decay != 0: grad = grad.add(param, alpha=weight_decay) @@ -241,7 +337,9 @@ def adam(self, params, exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now - torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + torch.maximum(max_exp_avg_sqs[i], + exp_avg_sq, + out=max_exp_avg_sqs[i]) # Use the max. for normalizing running avg. of gradient denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) else: diff --git a/openfl/utilities/path_check.py b/openfl/utilities/path_check.py index bdd272b05a..175953c9d2 100644 --- a/openfl/utilities/path_check.py +++ b/openfl/utilities/path_check.py @@ -1,7 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""openfl path checks.""" +"""Openfl path checks.""" import os from pathlib import Path @@ -9,7 +8,20 @@ def is_directory_traversal(directory: Union[str, Path]) -> bool: - """Check for directory traversal.""" + """Check for directory traversal. + + This function checks if the provided directory is a subdirectory of the + current working directory. + It returns `True` if the directory is not a subdirectory (i.e., it is a + directory traversal), and `False` otherwise. + + Args: + directory (Union[str, Path]): The directory to check. + + Returns: + bool: `True` if the directory is a directory traversal, `False` + otherwise. + """ cwd = os.path.abspath(os.getcwd()) requested_path = os.path.relpath(directory, start=cwd) requested_path = os.path.abspath(requested_path) diff --git a/openfl/utilities/split.py b/openfl/utilities/split.py index 9692d8e33e..c50e2fbd55 100644 --- a/openfl/utilities/split.py +++ b/openfl/utilities/split.py @@ -1,23 +1,24 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""split tensors module.""" +"""Split tensors module.""" import numpy as np def split_tensor_dict_into_floats_and_non_floats(tensor_dict): - """ - Split the tensor dictionary into float and non-floating point values. + """Split the tensor dictionary into float and non-floating point values. - Splits a tensor dictionary into float and non-float values. + This function splits a tensor dictionary into two dictionaries: one + containing all the floating point tensors and the other containing all the + non-floating point tensors. Args: - tensor_dict: A dictionary of tensors + tensor_dict (dict): A dictionary of tensors. Returns: - Two dictionaries: the first contains all of the floating point tensors - and the second contains all of the non-floating point tensors - + Tuple[dict, dict]: The first dictionary contains all of the floating + point tensors and the second dictionary contains all of the + non-floating point tensors. """ float_dict = {} non_float_dict = {} @@ -30,16 +31,16 @@ def split_tensor_dict_into_floats_and_non_floats(tensor_dict): def split_tensor_dict_by_types(tensor_dict, keep_types): - """ - Split the tensor dictionary into supported and not supported types. 
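# Illustrative usage sketch (not part of the diff) for the float/non-float
# split documented above; tensor names are hypothetical.
import numpy as np

from openfl.utilities.split import split_tensor_dict_into_floats_and_non_floats

tensors = {
    'conv.weight': np.random.rand(3, 3).astype(np.float32),
    'epoch': np.array(5),   # integer scalar tensor
}
float_dict, non_float_dict = split_tensor_dict_into_floats_and_non_floats(tensors)
# float_dict keeps 'conv.weight'; non_float_dict keeps 'epoch'.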
+ """Split the tensor dictionary into supported and not supported types. Args: - tensor_dict: A dictionary of tensors - keep_types: An iterable of supported types - Returns: - Two dictionaries: the first contains all of the supported tensors - and the second contains all of the not supported tensors + tensor_dict (dict): A dictionary of tensors. + keep_types (Iterable[type]): An iterable of supported types. + Returns: + Tuple[dict, dict]: The first dictionary contains all of the supported + tensors and the second dictionary contains all of the not + supported tensors. """ keep_dict = {} holdout_dict = {} @@ -51,24 +52,28 @@ def split_tensor_dict_by_types(tensor_dict, keep_types): return keep_dict, holdout_dict -def split_tensor_dict_for_holdouts(logger, tensor_dict, +def split_tensor_dict_for_holdouts(logger, + tensor_dict, keep_types=(np.floating, np.integer), holdout_tensor_names=()): - """ - Split a tensor according to tensor types. + """Split a tensor according to tensor types. + + This function splits a tensor dictionary into two dictionaries: one + containing the tensors to send and the other containing the holdout + tensors. Args: - logger: The log object - tensor_dict: A dictionary of tensors - keep_types: A list of types to keep in dictionary of tensors - holdout_tensor_names: A list of tensor names to extract from the - dictionary of tensors + logger (Logger): The logger to use for reporting warnings. + tensor_dict (dict): A dictionary of tensors. + keep_types (Tuple[type, ...], optional): A tuple of types to keep in + the dictionary of tensors. Defaults to (np.floating, np.integer). + holdout_tensor_names (Iterable[str], optional): An iterable of tensor + names to extract from the dictionary of tensors. Defaults to (). Returns: - Two dictionaries: the first is the original tensor dictionary minus - the holdout tenors and the second is a tensor dictionary with only the - holdout tensors - + Tuple[dict, dict]: The first dictionary is the original tensor + dictionary minus the holdout tensors and the second dictionary is + a tensor dictionary with only the holdout tensors. 
""" # initialization tensors_to_send = tensor_dict.copy() @@ -81,15 +86,14 @@ def split_tensor_dict_for_holdouts(logger, tensor_dict, try: holdout_tensors[tensor_name] = tensors_to_send.pop(tensor_name) except KeyError: - logger.warn(f'tried to remove tensor: {tensor_name} not present ' - f'in the tensor dict') + logger.warn( + f'tried to remove tensor: {tensor_name} not present ' + f'in the tensor dict') continue # filter holdout_types from tensors_to_send and add to holdout_tensors tensors_to_send, not_supported_tensors_dict = split_tensor_dict_by_types( - tensors_to_send, - keep_types - ) + tensors_to_send, keep_types) holdout_tensors = {**holdout_tensors, **not_supported_tensors_dict} return tensors_to_send, holdout_tensors diff --git a/openfl/utilities/types.py b/openfl/utilities/types.py index 369a5f985f..8ce3053444 100644 --- a/openfl/utilities/types.py +++ b/openfl/utilities/types.py @@ -1,25 +1,42 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - -"""openfl common object types.""" +"""Openfl common object types.""" from abc import ABCMeta from collections import namedtuple -TensorKey = namedtuple('TensorKey', ['tensor_name', 'origin', 'round_number', 'report', 'tags']) -TaskResultKey = namedtuple('TaskResultKey', ['task_name', 'owner', 'round_number']) +TensorKey = namedtuple( + 'TensorKey', ['tensor_name', 'origin', 'round_number', 'report', 'tags']) +TaskResultKey = namedtuple('TaskResultKey', + ['task_name', 'owner', 'round_number']) Metric = namedtuple('Metric', ['name', 'value']) LocalTensor = namedtuple('LocalTensor', ['col_name', 'tensor', 'weight']) class SingletonABCMeta(ABCMeta): - """Metaclass for singleton instances.""" + """Metaclass for singleton instances. + + This metaclass ensures that only one instance of any class using it can be + created. + + Attributes: + _instances (dict): A dictionary mapping classes to their instances. + """ _instances = {} def __call__(cls, *args, **kwargs): - """Use the singleton instance if it has already been created.""" + """Use the singleton instance if it has already been created. + + Args: + *args: Positional arguments to pass to the class constructor. + **kwargs: Keyword arguments to pass to the class constructor. + + Returns: + Any: The singleton instance of the class. + """ if cls not in cls._instances: - cls._instances[cls] = super(SingletonABCMeta, cls).__call__(*args, **kwargs) + cls._instances[cls] = super(SingletonABCMeta, + cls).__call__(*args, **kwargs) return cls._instances[cls] diff --git a/openfl/utilities/utils.py b/openfl/utilities/utils.py index 015e067c91..2e92865946 100644 --- a/openfl/utilities/utils.py +++ b/openfl/utilities/utils.py @@ -21,14 +21,19 @@ def getfqdn_env(name: str = '') -> str: - """ - Get the system FQDN, with priority given to environment variables. + """Get the system FQDN, with priority given to environment variables. + + This function retrieves the fully qualified domain name (FQDN) of the + system. + If the 'FQDN' environment variable is set, its value is returned. + Otherwise,the FQDN is determined based on the system's hostname. Args: - name: The name from which to extract the FQDN. + name (str, optional): The name from which to extract the FQDN. + Defaults to ''. Returns: - The FQDN of the system. + str: The FQDN of the system. 
""" fqdn = os.environ.get('FQDN', None) if fqdn is not None: @@ -37,7 +42,18 @@ def getfqdn_env(name: str = '') -> str: def is_fqdn(hostname: str) -> bool: - """https://en.m.wikipedia.org/wiki/Fully_qualified_domain_name.""" + """Check if a hostname is a fully qualified domain name. + + This function checks if a hostname is a fully qualified domain name (FQDN) + according to the rules specified on Wikipedia. + https://en.m.wikipedia.org/wiki/Fully_qualified_domain_name. + + Args: + hostname (str): The hostname to check. + + Returns: + bool: `True` if the hostname is a FQDN, `False` otherwise. + """ if not 1 < len(hostname) < 253: return False @@ -51,14 +67,24 @@ def is_fqdn(hostname: str) -> bool: # Can begin and end with a number or letter only # Can contain hyphens, a-z, A-Z, 0-9 # 1 - 63 chars allowed - fqdn = re.compile(r'^[a-z0-9]([a-z-0-9-]{0,61}[a-z0-9])?$', re.IGNORECASE) # noqa FS003 + fqdn = re.compile(r'^[a-z0-9]([a-z-0-9-]{0,61}[a-z0-9])?$', + re.IGNORECASE) # noqa FS003 # Check that all labels match that pattern. return all(fqdn.match(label) for label in labels) def is_api_adress(address: str) -> bool: - """Validate ip address value.""" + """Validate IP address value. + + This function checks if a string is a valid IP address. + + Args: + address (str): The string to check. + + Returns: + bool: `True` if the string is a valid IP address, `False` otherwise. + """ try: ipaddress.ip_address(address) return True @@ -67,14 +93,16 @@ def is_api_adress(address: str) -> bool: def add_log_level(level_name, level_num, method_name=None): - """ - Add a new logging level to the logging module. + """Add a new logging level to the logging module. - Args: - level_name: name of log level. - level_num: log level value. - method_name: log method wich will use new log level (default = level_name.lower()) + This function adds a new logging level to the logging module with a + specified name, value, and method name. + Args: + level_name (str): The name of the new logging level. + level_num (int): The value of the new logging level. + method_name (str, optional): The name of the method to use for + the new logging level. Defaults to None. """ if not method_name: method_name = level_name.lower() @@ -95,13 +123,19 @@ def log_to_root(message, *args, **kwargs): def validate_file_hash(file_path, expected_hash, chunk_size=8192): """Validate SHA384 hash for file specified. + This function validates the SHA384 hash of a file against an expected hash. + Args: - file_path(path-like): path-like object giving the pathname - (absolute or relative to the current working directory) - of the file to be opened or an integer file descriptor of the file to be wrapped. - expected_hash(str): hash string to compare with. - hasher(_Hash): hash algorithm. Default value: `hashlib.sha384()` - chunk_size(int): Buffer size for file reading. + file_path (str): The path to the file to validate. + (absolute or relative to the current working directory) of the file + to be opened or an integer file descriptor of the file to be + wrapped. + expected_hash (str): The expected SHA384 hash of the file. + chunk_size (int, optional): The size of the chunks to read from the + file. Defaults to 8192. + + Raises: + SystemError: If the hash of the file does not match the expected hash. """ h = hashlib.sha384() with open(file_path, 'rb') as file: @@ -117,7 +151,14 @@ def validate_file_hash(file_path, expected_hash, chunk_size=8192): def tqdm_report_hook(): - """Visualize downloading.""" + """Visualize downloading. 
+ + This function creates a progress bar for visualizing the progress of a + download. + + Returns: + Callable: A function that updates the progress bar. + """ def report_hook(pbar, count, block_size, total_size): """Update progressbar.""" @@ -131,11 +172,29 @@ def report_hook(pbar, count, block_size, total_size): def merge_configs( - overwrite_dict: Optional[dict] = None, - value_transform: Optional[List[Tuple[str, Callable]]] = None, - **kwargs, + overwrite_dict: Optional[dict] = None, + value_transform: Optional[List[Tuple[str, Callable]]] = None, + **kwargs, ) -> Dynaconf: - """Create Dynaconf settings, merge its with `overwrite_dict` and validate result.""" + """Create Dynaconf settings, merge its with `overwrite_dict` and validate + result. + + This function creates a Dynaconf settings object, merges it with an + optional dictionary, applies an optional value transformation, and + validates the result. + + Args: + overwrite_dict (Optional[dict], optional): A dictionary to merge with + the settings. Defaults to None. + value_transform (Optional[List[Tuple[str, Callable]]], optional): A + list of tuples, each containing a key and a function to apply to + the value of that key. Defaults to None. + **kwargs: Additional keyword arguments to pass to the Dynaconf + constructor. + + Returns: + Dynaconf: The merged and validated settings. + """ settings = Dynaconf(**kwargs, YAML_LOADER='safe_load') if overwrite_dict: for key, value in overwrite_dict.items(): @@ -152,10 +211,20 @@ def merge_configs( def change_tags(tags, *, add_field=None, remove_field=None) -> Tuple[str, ...]: """Change tensor tags to add or remove fields. + This function adds or removes fields from tensor tags. + Args: - tags(tuple): tensor tags. - add_field(str): add a new tensor tag field. - remove_field(str): remove a tensor tag field. + tags (Tuple[str, ...]): The tensor tags. + add_field (str, optional): A new tensor tag field to add. Defaults to + None. + remove_field (str, optional): A tensor tag field to remove. Defaults + to None. + + Returns: + Tuple[str, ...]: The modified tensor tags. + + Raises: + Exception: If `remove_field` is not in `tags`. """ tags = list(set(tags)) @@ -172,9 +241,27 @@ def change_tags(tags, *, add_field=None, remove_field=None) -> Tuple[str, ...]: def rmtree(path, ignore_errors=False): + """Remove a directory tree. + + This function removes a directory tree. If a file in the directory tree is + read-only, its read-only attribute is cleared before it is removed. + + Args: + path (str): The path to the directory tree to remove. + ignore_errors (bool, optional): Whether to ignore errors. Defaults to + False. + + Returns: + str: The path to the removed directory tree. + """ + def remove_readonly(func, path, _): "Clear the readonly bit and reattempt the removal" if os.name == 'nt': - os.chmod(path, stat.S_IWRITE) # Windows can not remove read-only files. + os.chmod(path, + stat.S_IWRITE) # Windows can not remove read-only files. 
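# Illustrative usage sketch (not part of the diff) for change_tags documented
# above. Tag values are hypothetical; the returned tuple's ordering is not
# guaranteed because the implementation deduplicates through a set.
from openfl.utilities.utils import change_tags

tags = ('model', 'trained')
new_tags = change_tags(tags, add_field='aggregated', remove_field='trained')
assert set(new_tags) == {'model', 'aggregated'}
# Asking to remove a field that is not present raises an Exception.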
func(path) - return shutil.rmtree(path, ignore_errors=ignore_errors, onerror=remove_readonly) + + return shutil.rmtree(path, + ignore_errors=ignore_errors, + onerror=remove_readonly) diff --git a/openfl/utilities/workspace.py b/openfl/utilities/workspace.py index c1561b726e..0a15d850c1 100644 --- a/openfl/utilities/workspace.py +++ b/openfl/utilities/workspace.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Workspace utils module.""" import logging @@ -20,16 +19,37 @@ class ExperimentWorkspace: - """Experiment workspace context manager.""" - - def __init__( - self, - experiment_name: str, - data_file_path: Path, - install_requirements: bool = False, - remove_archive: bool = True - ) -> None: - """Initialize workspace context manager.""" + """Experiment workspace context manager. + + This class is a context manager for creating a workspace for an experiment. + + Attributes: + experiment_name (str): The name of the experiment. + data_file_path (Path): The path to the data file for the experiment. + install_requirements (bool): Whether to install the requirements for + the experiment. + cwd (Path): The current working directory. + experiment_work_dir (Path): The working directory for the experiment. + remove_archive (bool): Whether to remove the archive after the + experiment. + """ + + def __init__(self, + experiment_name: str, + data_file_path: Path, + install_requirements: bool = False, + remove_archive: bool = True) -> None: + """Initialize workspace context manager. + + Args: + experiment_name (str): The name of the experiment. + data_file_path (Path): The path to the data file for the + experiment. + install_requirements (bool, optional): Whether to install the + requirements for the experiment. Defaults to False. + remove_archive (bool, optional): Whether to remove the archive + after the experiment. Defaults to True. + """ self.experiment_name = experiment_name self.data_file_path = data_file_path self.install_requirements = install_requirements @@ -64,7 +84,9 @@ def __enter__(self): shutil.rmtree(self.experiment_work_dir, ignore_errors=True) os.makedirs(self.experiment_work_dir) - shutil.unpack_archive(self.data_file_path, self.experiment_work_dir, format='zip') + shutil.unpack_archive(self.data_file_path, + self.experiment_work_dir, + format='zip') if self.install_requirements: self._install_requirements() @@ -82,20 +104,32 @@ def __exit__(self, exc_type, exc_value, traceback): sys.path.remove(str(self.experiment_work_dir)) if self.remove_archive: + logger.debug('Exiting from the workspace context manager' + f' for {self.experiment_name} experiment') logger.debug( - 'Exiting from the workspace context manager' - f' for {self.experiment_name} experiment' - ) - logger.debug(f'Archive still exists: {self.data_file_path.exists()}') + f'Archive still exists: {self.data_file_path.exists()}') self.data_file_path.unlink(missing_ok=False) def dump_requirements_file( - path: Union[str, Path] = './requirements.txt', - keep_original_prefixes: bool = True, - prefixes: Optional[Union[Tuple[str], str]] = None, + path: Union[str, Path] = './requirements.txt', + keep_original_prefixes: bool = True, + prefixes: Optional[Union[Tuple[str], str]] = None, ) -> None: - """Prepare and save requirements.txt.""" + """Prepare and save requirements.txt. + + This function prepares a requirements.txt file and saves it to a specified + path. + + Args: + path (Union[str, Path], optional): The path to save the + requirements.txt file. 
+ Defaults to './requirements.txt'. + keep_original_prefixes (bool, optional): Whether to keep the original + prefixes in the requirements.txt file. Defaults to True. + prefixes (Optional[Union[Tuple[str], str]], optional): The prefixes to + add to the requirements.txt file. Defaults to None. + """ from pip._internal.operations import freeze path = Path(path).absolute() @@ -103,7 +137,7 @@ def dump_requirements_file( if prefixes is None: prefixes = set() elif type(prefixes) is str: - prefixes = set(prefixes,) + prefixes = set(prefixes, ) else: prefixes = set(prefixes) @@ -131,20 +165,25 @@ def dump_requirements_file( def _is_package_versioned(package: str) -> bool: - """Check if the package has a version.""" - return ('==' in package - and package not in ['pkg-resources==0.0.0', 'pkg_resources==0.0.0'] - and '-e ' not in package - ) + """Check if a package has a version. + + Args: + package (str): The package to check. + + Returns: + bool: `True` if the package has a version, `False` otherwise. + """ + return ('==' in package and package + not in ['pkg-resources==0.0.0', 'pkg_resources==0.0.0'] + and '-e ' not in package) @contextmanager def set_directory(path: Path): - """ - Sets provided path as the cwd within the context. + """Set the current working directory within the context. Args: - path (Path): The path to the cwd + path (Path): The path to set as the current working directory. Yields: None
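# Illustrative usage sketch (not part of the diff) for the workspace helpers
# documented above; the scratch directory is hypothetical, and restoring the
# previous cwd on exit is assumed from the set_directory docstring.
from pathlib import Path

from openfl.utilities.workspace import dump_requirements_file, set_directory

workdir = Path('./scratch_workspace')
workdir.mkdir(exist_ok=True)

with set_directory(workdir):
    # cwd is switched to ./scratch_workspace only inside this block.
    dump_requirements_file(path='./requirements.txt')
# After the block, cwd is back to the original directory and
# scratch_workspace/requirements.txt should hold the frozen package list.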