Commit
- extend plan initialize to have additional optional argument to take init model path (pbuf format)

- added new function to utils.py
- rebased 21.Jan.1
- reduce cyclomatic complexity of initialize function
- REMOVE - Enable draft checks
Signed-off-by: Shailesh Pant <[email protected]>
ishaileshpant committed Jan 21, 2025
1 parent cc3f12d commit 1470b3b
Showing 6 changed files with 131 additions and 38 deletions.
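In short, the plan initialize command gains an optional -i/--init_model_path flag; when it points at an existing .pbuf file (such as the torch_cnn_mnist_init.pbuf added below), initialization seeds the initial model weights from that file rather than taking them from a freshly constructed task runner.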
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -12,7 +12,7 @@ permissions:
 
 jobs:
   build:
-    if: github.event.pull_request.draft == false
+    if: github.event.pull_request.draft == true
 
     runs-on: ubuntu-latest
 
2 changes: 1 addition & 1 deletion .github/workflows/pytest_coverage.yml
@@ -17,7 +17,7 @@ env:
 
 jobs:
   build:
-    if: github.event.pull_request.draft == false
+    if: github.event.pull_request.draft == true
     runs-on: ubuntu-latest
     timeout-minutes: 15
 
17 changes: 17 additions & 0 deletions openfl/federated/plan/plan.py
@@ -777,3 +777,20 @@ def restore_object(self, filename):
             return None
         obj = serializer_plugin.restore_object(filename)
         return obj
+
+    def save_model_to_state_file(self, tensor_dict, round_number, output_path):
+        """Save model weights to a protobuf state file."""
+        from openfl.protocols import utils  # Import here to avoid circular imports
+
+        # Get tensor pipe to properly serialize the weights
+        tensor_pipe = self.get_tensor_pipe()
+
+        # Create and save the protobuf message
+        try:
+            model_proto = utils.construct_model_proto(
+                tensor_dict=tensor_dict, round_number=round_number, tensor_pipe=tensor_pipe
+            )
+            utils.dump_proto(model_proto=model_proto, fpath=output_path)
+        except Exception as e:
+            self.logger.error(f"Failed to create or save model proto: {e}")
+            raise
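
To illustrate the new Plan.save_model_to_state_file helper, here is a minimal Python sketch; the plan config path, the placeholder weights, and the output path are assumptions for illustration, not part of this commit:

from pathlib import Path

import numpy as np

from openfl.federated import Plan

# Parse a workspace plan (path is illustrative; other Plan.parse arguments keep their defaults).
plan = Plan.parse(plan_config_path=Path("plan/plan.yaml"))

# Placeholder weights; in `plan initialize` these come from the task runner.
tensor_dict = {"conv1.weight": np.zeros((8, 1, 3, 3), dtype=np.float32)}

# Serialize the weights to a protobuf state file via the new helper.
plan.save_model_to_state_file(
    tensor_dict=tensor_dict,
    round_number=0,
    output_path="save/torch_cnn_mnist_init.pbuf",
)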
104 changes: 68 additions & 36 deletions openfl/interface/plan.py
@@ -95,6 +95,13 @@ def plan(context):
     help="Install packages listed under 'requirements.txt'. True/False [Default: True]",
     default=True,
 )
+@option(
+    "-i",
+    "--init_model_path",
+    required=False,
+    help="Path to initial model protobuf file",
+    type=ClickPath(exists=True),
+)
 def initialize(
     context,
     plan_config,
@@ -104,6 +111,7 @@ def initialize(
     input_shape,
     gandlf_config,
     install_reqs,
+    init_model_path,
 ):
     """Initialize Data Science plan.
@@ -119,6 +127,7 @@
         feature_shape (str): The input shape to the model.
         gandlf_config (str): GaNDLF Configuration File Path.
         install_reqs (bool): Whether to install packages listed under 'requirements.txt'.
+        init_model_path (str): Optional path to initialization model protobuf file.
     """
 
     for p in [plan_config, cols_config, data_config]:
@@ -133,29 +142,8 @@ def initialize(
     gandlf_config = Path(gandlf_config).absolute()
 
     if install_reqs:
-        requirements_filename = "requirements.txt"
-        requirements_path = Path(requirements_filename).absolute()
-
-        if isfile(f"{str(requirements_path)}"):
-            check_call(
-                [
-                    sys.executable,
-                    "-m",
-                    "pip",
-                    "install",
-                    "-r",
-                    f"{str(requirements_path)}",
-                ],
-                shell=False,
-            )
-            echo(f"Successfully installed packages from {requirements_path}.")
-
-            # Required to restart the process for newly installed packages to be recognized
-            args_restart = [arg for arg in sys.argv if not arg.startswith("--install_reqs")]
-            args_restart.append("--install_reqs=False")
-            os.execv(args_restart[0], args_restart)
-        else:
-            echo("No additional requirements for workspace defined. Skipping...")
+        requirements_path = Path("requirements.txt").absolute()
+        _handle_requirements_install(requirements_path)

     plan = Plan.parse(
         plan_config_path=plan_config,
@@ -165,21 +153,20 @@ def initialize(
     )
 
     init_state_path = plan.config["aggregator"]["settings"]["init_state_path"]
 
     # This is needed to bypass data being locally available
     if input_shape is not None:
         logger.info(
             f"Attempting to generate initial model weights with custom input shape {input_shape}"
         )
 
-    data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)
-
-    task_runner = plan.get_task_runner(data_loader)
-    tensor_pipe = plan.get_tensor_pipe()
+    # Initialize tensor dictionary using the extracted function
+    init_tensor_dict, task_runner, round_number = _initialize_tensor_dict(
+        plan, input_shape, init_model_path
+    )
 
     tensor_dict, holdout_params = split_tensor_dict_for_holdouts(
         logger,
-        task_runner.get_tensor_dict(False),
+        init_tensor_dict,
         **task_runner.tensor_dict_split_fn_kwargs,
     )

@@ -189,13 +176,15 @@ def initialize(
         f" values: {list(holdout_params.keys())}"
     )
 
-    model_snap = utils.construct_model_proto(
-        tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe
-    )
-
-    logger.info("Creating Initial Weights File 🠆 %s", init_state_path)
-
-    utils.dump_proto(model_proto=model_snap, fpath=init_state_path)
+    # Save the model state
+    try:
+        logger.info(f"Saving model state to {init_state_path}")
+        plan.save_model_to_state_file(
+            tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path
+        )
+    except Exception as e:
+        logger.error(f"Failed to save model state: {e}")
+        raise
 
     plan_origin = Plan.parse(
         plan_config_path=plan_config,
@@ -223,6 +212,49 @@ def initialize(
     logger.info(f"{context.obj['plans']}")
 
 
+def _handle_requirements_install(requirements_path):
+    """Handle the installation of requirements and process restart if needed."""
+    if isfile(str(requirements_path)):
+        check_call(
+            [sys.executable, "-m", "pip", "install", "-r", str(requirements_path)],
+            shell=False,
+        )
+        echo(f"Successfully installed packages from {requirements_path}.")
+
+        # Required to restart the process for newly installed packages to be recognized
+        args_restart = [arg for arg in sys.argv if not arg.startswith("--install_reqs")]
+        args_restart.append("--install_reqs=False")
+        os.execv(args_restart[0], args_restart)
+    else:
+        echo("No additional requirements for workspace defined. Skipping...")
+
+
+def _initialize_tensor_dict(plan, input_shape, init_model_path):
+    """Initialize and return the tensor dictionary.
+
+    Args:
+        plan: The federation plan object
+        input_shape: The input shape to the model
+        init_model_path: Path to initial model protobuf file
+
+    Returns:
+        Tuple of (tensor_dict, task_runner, round_number)
+    """
+    data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)
+    task_runner = plan.get_task_runner(data_loader)
+    tensor_pipe = plan.get_tensor_pipe()
+    round_number = 0
+
+    if init_model_path and isfile(init_model_path):
+        logger.info(f"Loading initial model from {init_model_path}")
+        model_proto = utils.load_proto(init_model_path)
+        init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
+    else:
+        init_tensor_dict = task_runner.get_tensor_dict(False)
+
+    return init_tensor_dict, task_runner, round_number
+
+
 # TODO: looks like Plan.method
 def freeze_plan(plan_config):
     """Dump the plan to YAML file.
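
As a usage note, this is the load path _initialize_tensor_dict takes when --init_model_path is supplied; a hedged Python sketch using the proto utilities called above (file name and pipeline choice are assumptions):

from openfl.pipelines import NoCompressionPipeline
from openfl.protocols import utils

# Assumption: the .pbuf was written with a no-compression pipeline; in the CLI
# the pipe comes from plan.get_tensor_pipe() so it matches the plan's settings.
tensor_pipe = NoCompressionPipeline()

model_proto = utils.load_proto("torch_cnn_mnist_init.pbuf")
init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
print(f"Loaded {len(init_tensor_dict)} tensors saved at round {round_number}")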
44 changes: 44 additions & 0 deletions openfl/protocols/utils.py
@@ -4,9 +4,13 @@
 
 """Proto utils."""
 
+import logging
+
 from openfl.protocols import base_pb2
 from openfl.utilities import TensorKey
 
+logger = logging.getLogger(__name__)
+
 
 def model_proto_to_bytes_and_metadata(model_proto):
     """Convert the model protobuf to bytes and metadata.
@@ -356,3 +360,43 @@ def get_headers(context) -> dict:
         values are the corresponding header values.
     """
     return {header[0]: header[1] for header in context.invocation_metadata()}


+def construct_tensor_dict_from_proto(model_proto, tensor_pipe):
+    """Convert a model protobuf message to a tensor dictionary."""
+    tensor_dict = {}
+    logger.info("\n=== Processing Proto Message ===")
+    logger.info(f"Number of tensors in proto: {len(model_proto.tensors)}")
+
+    for tensor in model_proto.tensors:
+        logger.info(f"\nProcessing proto tensor: {tensor.name}")
+        logger.info("-" * 50)
+        try:
+            # Extract metadata from the tensor proto
+            transformer_metadata = [
+                {
+                    "int_to_float": proto.int_to_float,
+                    "int_list": proto.int_list,
+                    "bool_list": proto.bool_list,
+                }
+                for proto in tensor.transformer_metadata
+            ]
+
+            # Decompress the tensor value using the compression pipeline
+            logger.info("Decompressing tensor...")
+            decompressed_tensor = tensor_pipe.backward(
+                data=tensor.data_bytes, transformer_metadata=transformer_metadata
+            )
+
+            # Store in dictionary
+            tensor_dict[tensor.name] = decompressed_tensor
+
+        except Exception as e:
+            logger.error(f"Failed to process tensor {tensor.name}")
+            logger.error(f"Error: {str(e)}")
+            raise
+
+    logger.info("\n=== Finished Processing Proto Message ===")
+    logger.info(f"Successfully processed {len(tensor_dict)} tensors")
+
+    return tensor_dict
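
A hedged round-trip check for the new construct_tensor_dict_from_proto helper; the pipeline and the tensor contents are assumptions for illustration:

import numpy as np

from openfl.pipelines import NoCompressionPipeline
from openfl.protocols import utils

tensor_pipe = NoCompressionPipeline()
tensor_dict = {"fc.weight": np.random.rand(10, 4).astype(np.float32)}

# Serialize with the existing helper, then rebuild the dict with the new one.
proto = utils.construct_model_proto(
    tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe
)
restored = utils.construct_tensor_dict_from_proto(proto, tensor_pipe)
assert set(restored) == {"fc.weight"}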
Binary file added torch_cnn_mnist_init.pbuf