diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py
index 13d446e145..a0fe99d1a9 100644
--- a/openfl/federated/plan/plan.py
+++ b/openfl/federated/plan/plan.py
@@ -777,3 +777,26 @@ def restore_object(self, filename):
             return None
         obj = serializer_plugin.restore_object(filename)
         return obj
+
+    def save_model_to_state_file(self, tensor_dict, round_number, output_path):
+        """Save model weights to a protobuf state file.
+
+        Args:
+            tensor_dict (dict): Tensor dictionary to serialize.
+            round_number (int): Round number recorded in the proto.
+            output_path (str): Destination path for the state file.
+        """
+        from openfl.protocols import utils  # Import here to avoid circular imports
+
+        # Get the tensor pipe to properly serialize the weights
+        tensor_pipe = self.get_tensor_pipe()
+
+        # Create and save the protobuf message
+        try:
+            model_proto = utils.construct_model_proto(
+                tensor_dict=tensor_dict, round_number=round_number, tensor_pipe=tensor_pipe
+            )
+            utils.dump_proto(model_proto=model_proto, fpath=output_path)
+        except Exception as e:
+            self.logger.error(f"Failed to create or save model proto: {e}")
+            raise
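For context, a sketch (not part of the diff) of how the new Plan.save_model_to_state_file helper might be exercised from inside a workspace. The plan path, the toy tensor dict, and the output path are illustrative assumptions, not values taken from this PR; in fx plan initialize the tensor dict comes from the task runner and the output path from the plan's init_state_path setting.

    from pathlib import Path

    import numpy as np

    from openfl.federated import Plan

    # Parse the workspace plan (default workspace layout assumed)
    plan = Plan.parse(plan_config_path=Path("plan/plan.yaml"))

    # Any {name: ndarray} mapping can be serialized; one toy layer here
    tensor_dict = {"fc.weight": np.zeros((10, 784), dtype=np.float32)}

    # Compresses each tensor through the plan's tensor pipe and writes the proto
    plan.save_model_to_state_file(
        tensor_dict=tensor_dict, round_number=0, output_path="save/init_state.pbuf"
    )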
""" for p in [plan_config, cols_config, data_config]: @@ -165,21 +174,20 @@ def initialize( ) init_state_path = plan.config["aggregator"]["settings"]["init_state_path"] - # This is needed to bypass data being locally available if input_shape is not None: logger.info( f"Attempting to generate initial model weights with custom input shape {input_shape}" ) - data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape) - - task_runner = plan.get_task_runner(data_loader) - tensor_pipe = plan.get_tensor_pipe() + # Initialize tensor dictionary using the extracted function + init_tensor_dict, task_runner, round_number = _initialize_tensor_dict( + plan, input_shape, init_model_path + ) tensor_dict, holdout_params = split_tensor_dict_for_holdouts( logger, - task_runner.get_tensor_dict(False), + init_tensor_dict, **task_runner.tensor_dict_split_fn_kwargs, ) @@ -189,13 +197,15 @@ def initialize( f" values: {list(holdout_params.keys())}" ) - model_snap = utils.construct_model_proto( - tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe - ) - - logger.info("Creating Initial Weights File 🠆 %s", init_state_path) - - utils.dump_proto(model_proto=model_snap, fpath=init_state_path) + # Save the model state + try: + logger.info(f"Saving model state to {init_state_path}") + plan.save_model_to_state_file( + tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path + ) + except Exception as e: + logger.error(f"Failed to save model state: {e}") + raise plan_origin = Plan.parse( plan_config_path=plan_config, @@ -223,6 +233,32 @@ def initialize( logger.info(f"{context.obj['plans']}") +def _initialize_tensor_dict(plan, input_shape, init_model_path): + """Initialize and return the tensor dictionary. + + Args: + plan: The federation plan object + input_shape: The input shape to the model + init_model_path: Path to initial model protobuf file + + Returns: + Tuple of (tensor_dict, task_runner, round_number) + """ + data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape) + task_runner = plan.get_task_runner(data_loader) + tensor_pipe = plan.get_tensor_pipe() + round_number = 0 + + if init_model_path and isfile(init_model_path): + logger.info(f"Loading initial model from {init_model_path}") + model_proto = utils.load_proto(init_model_path) + init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe) + else: + init_tensor_dict = task_runner.get_tensor_dict(False) + + return init_tensor_dict, task_runner, round_number + + # TODO: looks like Plan.method def freeze_plan(plan_config): """Dump the plan to YAML file. diff --git a/openfl/protocols/utils.py b/openfl/protocols/utils.py index e1d3da888a..3ccdd0ab05 100644 --- a/openfl/protocols/utils.py +++ b/openfl/protocols/utils.py @@ -4,9 +4,13 @@ """Proto utils.""" +import logging + from openfl.protocols import base_pb2 from openfl.utilities import TensorKey +logger = logging.getLogger(__name__) + def model_proto_to_bytes_and_metadata(model_proto): """Convert the model protobuf to bytes and metadata. @@ -356,3 +360,43 @@ def get_headers(context) -> dict: values are the corresponding header values. 
""" return {header[0]: header[1] for header in context.invocation_metadata()} + + +def construct_tensor_dict_from_proto(model_proto, tensor_pipe): + """Convert a model protobuf message to a tensor dictionary.""" + tensor_dict = {} + logger.info("\n=== Processing Proto Message ===") + logger.info(f"Number of tensors in proto: {len(model_proto.tensors)}") + + for tensor in model_proto.tensors: + logger.info(f"\nProcessing proto tensor: {tensor.name}") + logger.info("-" * 50) + try: + # Extract metadata from the tensor proto + transformer_metadata = [ + { + "int_to_float": proto.int_to_float, + "int_list": proto.int_list, + "bool_list": proto.bool_list, + } + for proto in tensor.transformer_metadata + ] + + # Decompress the tensor value using the compression pipeline + logger.info("Decompressing tensor...") + decompressed_tensor = tensor_pipe.backward( + data=tensor.data_bytes, transformer_metadata=transformer_metadata + ) + + # Store in dictionary + tensor_dict[tensor.name] = decompressed_tensor + + except Exception as e: + logger.error(f"Failed to process tensor {tensor.name}") + logger.error(f"Error: {str(e)}") + raise + + logger.info("\n=== Finished Processing Proto Message ===") + logger.info(f"Successfully processed {len(tensor_dict)} tensors") + + return tensor_dict diff --git a/torch_cnn_mnist_init.pbuf b/torch_cnn_mnist_init.pbuf new file mode 100644 index 0000000000..36a3bdb969 Binary files /dev/null and b/torch_cnn_mnist_init.pbuf differ