From 813f563f9898134be2128559b0219a778a3a16bd Mon Sep 17 00:00:00 2001
From: Shailesh Pant
Date: Wed, 18 Dec 2024 21:24:20 +0530
Subject: [PATCH] - extend plan initialize to have additional optional
 argument to take init model path (pbuf format)

- added new function to utils.py
- rebased 21.Jan.1
- reduce cyclo-complexity of initialize function
- address review comments

Signed-off-by: Shailesh Pant
---
 openfl/federated/plan/plan.py |  33 +++++++++++
 openfl/interface/plan.py      | 104 ++++++++++++++++++++++------------
 openfl/protocols/utils.py     |   4 ++
 3 files changed, 105 insertions(+), 36 deletions(-)

diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py
index 13d446e145..1683f976c7 100644
--- a/openfl/federated/plan/plan.py
+++ b/openfl/federated/plan/plan.py
@@ -777,3 +777,36 @@ def restore_object(self, filename):
             return None
         obj = serializer_plugin.restore_object(filename)
         return obj
+
+    def save_model_to_state_file(self, tensor_dict, round_number, output_path):
+        """Save model weights to a protobuf state file.
+
+        This method serializes the model weights into a protobuf format and saves
+        them to a file. The serialization is done using the tensor pipe to ensure
+        proper compression and formatting.
+
+        Args:
+            tensor_dict (dict): Dictionary containing model weights and their
+                corresponding tensors.
+            round_number (int): The current federation round number.
+            output_path (str): Path where the serialized model state will be
+                saved.
+
+        Raises:
+            Exception: If there is an error during model proto creation or saving
+                to file.
+        """
+        from openfl.protocols import utils  # Import here to avoid circular imports
+
+        # Get tensor pipe to properly serialize the weights
+        tensor_pipe = self.get_tensor_pipe()
+
+        # Create and save the protobuf message
+        try:
+            model_proto = utils.construct_model_proto(
+                tensor_dict=tensor_dict, round_number=round_number, tensor_pipe=tensor_pipe
+            )
+            utils.dump_proto(model_proto=model_proto, fpath=output_path)
+        except Exception as e:
+            self.logger.error(f"Failed to create or save model proto: {e}")
+            raise
diff --git a/openfl/interface/plan.py b/openfl/interface/plan.py
index 503693e581..8fadf7cdf3 100644
--- a/openfl/interface/plan.py
+++ b/openfl/interface/plan.py
@@ -95,6 +95,13 @@ def plan(context):
     help="Install packages listed under 'requirements.txt'. True/False [Default: True]",
     default=True,
 )
+@option(
+    "-i",
+    "--init_model_path",
+    required=False,
+    help="Path to initial model protobuf file",
+    type=ClickPath(exists=True),
+)
 def initialize(
     context,
     plan_config,
@@ -104,6 +111,7 @@ def initialize(
     input_shape,
     gandlf_config,
     install_reqs,
+    init_model_path,
 ):
     """Initialize Data Science plan.
 
@@ -119,6 +127,7 @@ def initialize(
         feature_shape (str): The input shape to the model.
         gandlf_config (str): GaNDLF Configuration File Path.
        install_reqs (bool): Whether to install packages listed under 'requirements.txt'.
+        init_model_path (str): Optional path to initialization model protobuf file.
""" for p in [plan_config, cols_config, data_config]: @@ -133,29 +142,8 @@ def initialize( gandlf_config = Path(gandlf_config).absolute() if install_reqs: - requirements_filename = "requirements.txt" - requirements_path = Path(requirements_filename).absolute() - - if isfile(f"{str(requirements_path)}"): - check_call( - [ - sys.executable, - "-m", - "pip", - "install", - "-r", - f"{str(requirements_path)}", - ], - shell=False, - ) - echo(f"Successfully installed packages from {requirements_path}.") - - # Required to restart the process for newly installed packages to be recognized - args_restart = [arg for arg in sys.argv if not arg.startswith("--install_reqs")] - args_restart.append("--install_reqs=False") - os.execv(args_restart[0], args_restart) - else: - echo("No additional requirements for workspace defined. Skipping...") + requirements_path = Path("requirements.txt").absolute() + _handle_requirements_install(requirements_path) plan = Plan.parse( plan_config_path=plan_config, @@ -165,21 +153,20 @@ def initialize( ) init_state_path = plan.config["aggregator"]["settings"]["init_state_path"] - # This is needed to bypass data being locally available if input_shape is not None: logger.info( f"Attempting to generate initial model weights with custom input shape {input_shape}" ) - data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape) - - task_runner = plan.get_task_runner(data_loader) - tensor_pipe = plan.get_tensor_pipe() + # Initialize tensor dictionary + init_tensor_dict, task_runner, round_number = _initialize_tensor_dict( + plan, input_shape, init_model_path + ) tensor_dict, holdout_params = split_tensor_dict_for_holdouts( logger, - task_runner.get_tensor_dict(False), + init_tensor_dict, **task_runner.tensor_dict_split_fn_kwargs, ) @@ -189,13 +176,15 @@ def initialize( f" values: {list(holdout_params.keys())}" ) - model_snap = utils.construct_model_proto( - tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe - ) - - logger.info("Creating Initial Weights File 🠆 %s", init_state_path) - - utils.dump_proto(model_proto=model_snap, fpath=init_state_path) + # Save the model state + try: + logger.info(f"Saving model state to {init_state_path}") + plan.save_model_to_state_file( + tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path + ) + except Exception as e: + logger.error(f"Failed to save model state: {e}") + raise plan_origin = Plan.parse( plan_config_path=plan_config, @@ -223,6 +212,49 @@ def initialize( logger.info(f"{context.obj['plans']}") +def _handle_requirements_install(requirements_path): + """Handle the installation of requirements and process restart if needed.""" + if isfile(str(requirements_path)): + check_call( + [sys.executable, "-m", "pip", "install", "-r", str(requirements_path)], + shell=False, + ) + echo(f"Successfully installed packages from {requirements_path}.") + + # Required to restart the process for newly installed packages to be recognized + args_restart = [arg for arg in sys.argv if not arg.startswith("--install_reqs")] + args_restart.append("--install_reqs=False") + os.execv(args_restart[0], args_restart) + else: + echo("No additional requirements for workspace defined. Skipping...") + + +def _initialize_tensor_dict(plan, input_shape, init_model_path): + """Initialize and return the tensor dictionary. 
+
+    Args:
+        plan: The federation plan object
+        input_shape: The input shape to the model
+        init_model_path: Path to initial model protobuf file
+
+    Returns:
+        Tuple of (tensor_dict, task_runner, round_number)
+    """
+    data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)
+    task_runner = plan.get_task_runner(data_loader)
+    tensor_pipe = plan.get_tensor_pipe()
+    round_number = 0
+
+    if init_model_path and isfile(init_model_path):
+        logger.info(f"Loading initial model from {init_model_path}")
+        model_proto = utils.load_proto(init_model_path)
+        init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
+    else:
+        init_tensor_dict = task_runner.get_tensor_dict(False)
+
+    return init_tensor_dict, task_runner, round_number
+
+
 # TODO: looks like Plan.method
 def freeze_plan(plan_config):
     """Dump the plan to YAML file.
diff --git a/openfl/protocols/utils.py b/openfl/protocols/utils.py
index e1d3da888a..1ccaf5b534 100644
--- a/openfl/protocols/utils.py
+++ b/openfl/protocols/utils.py
@@ -4,9 +4,13 @@
 
 """Proto utils."""
 
+import logging
+
 from openfl.protocols import base_pb2
 from openfl.utilities import TensorKey
 
+logger = logging.getLogger(__name__)
+
 
 def model_proto_to_bytes_and_metadata(model_proto):
     """Convert the model protobuf to bytes and metadata.
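
Usage note: with this patch, `fx plan initialize` can seed the plan's initial
weights from a previously saved protobuf model instead of freshly initializing
the task runner. Below is a minimal sketch of the round trip the new code paths
perform, assuming a standard OpenFL workspace layout; the plan/plan.yaml and
save/*.pbuf paths are illustrative assumptions, not taken from the patch:

    # CLI entry point added by this patch (the .pbuf path is an assumption):
    #   fx plan initialize --init_model_path save/init_model.pbuf

    from pathlib import Path

    from openfl.federated import Plan
    from openfl.protocols import utils

    # Parse the workspace plan; plan/plan.yaml is the conventional location
    plan = Plan.parse(plan_config_path=Path("plan/plan.yaml"))
    tensor_pipe = plan.get_tensor_pipe()

    # Load a saved model, as _initialize_tensor_dict() does when
    # --init_model_path is given
    model_proto = utils.load_proto("save/init_model.pbuf")
    tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)

    # Persist the weights again via the new Plan helper added by this patch
    plan.save_model_to_state_file(
        tensor_dict=tensor_dict,
        round_number=round_number,
        output_path="save/init_model_copy.pbuf",
    )

The deferred `from openfl.protocols import utils` inside save_model_to_state_file
mirrors the patch's own comment: importing at call time avoids a circular import
between the plan and protocol modules.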