diff --git a/openfl/experimental/workflow/federated/plan/plan.py b/openfl/experimental/workflow/federated/plan/plan.py index dc03c2d0a9..b96b8de8b4 100644 --- a/openfl/experimental/workflow/federated/plan/plan.py +++ b/openfl/experimental/workflow/federated/plan/plan.py @@ -163,6 +163,7 @@ def parse( f"[blue]{plan_config_path}[/].", extra={"markup": True}, ) + Plan.dump(plan_config_path, plan.config) Plan.logger.info(dump(plan.config)) return plan @@ -404,7 +405,7 @@ def get_client( root_certificate=None, private_key=None, certificate=None, - tls=False, + tls=True, ) -> AggregatorGRPCClient: """Get gRPC client for the specified collaborator. @@ -423,8 +424,8 @@ def get_client( Returns: AggregatorGRPCClient: gRPC client for the specified collaborator. """ - common_name = collaborator_name - if not root_certificate or not private_key or not certificate: + if tls and not (root_certificate and private_key and certificate): + common_name = collaborator_name root_certificate = "cert/cert_chain.crt" certificate = f"cert/client/col_{common_name}.crt" private_key = f"cert/client/col_{common_name}.key" @@ -436,7 +437,6 @@ def get_client( client_args["root_certificate"] = root_certificate client_args["certificate"] = certificate client_args["private_key"] = private_key - client_args["tls"] = tls client_args["aggregator_uuid"] = aggregator_uuid client_args["federation_uuid"] = federation_uuid @@ -451,7 +451,7 @@ def get_server( root_certificate=None, private_key=None, certificate=None, - tls=False, + tls=True, director_config=None, **kwargs, ) -> AggregatorGRPCServer: @@ -472,9 +472,8 @@ def get_server( Returns: AggregatorGRPCServer: gRPC server of the aggregator instance. """ - common_name = self.config["network"][SETTINGS]["agg_addr"].lower() - - if not root_certificate or not private_key or not certificate: + if tls and not (root_certificate and private_key and certificate): + common_name = self.config["network"][SETTINGS]["agg_addr"].lower() root_certificate = "cert/cert_chain.crt" certificate = f"cert/server/agg_{common_name}.crt" private_key = f"cert/server/agg_{common_name}.key" diff --git a/openfl/experimental/workflow/runtime/federated_runtime.py b/openfl/experimental/workflow/runtime/federated_runtime.py index efa90e2a24..861c27e059 100644 --- a/openfl/experimental/workflow/runtime/federated_runtime.py +++ b/openfl/experimental/workflow/runtime/federated_runtime.py @@ -149,6 +149,8 @@ def prepare_workspace_archive(self) -> Tuple[Path, str]: archive_path, exp_name = WorkspaceExport.export_federated( notebook_path=self.notebook_path, output_workspace="./generated_workspace", + director_fqdn=self.director["director_node_fqdn"], + tls=self.tls, ) return archive_path, exp_name diff --git a/openfl/experimental/workflow/workspace_export/export.py b/openfl/experimental/workflow/workspace_export/export.py index 809be19bd5..0309392426 100644 --- a/openfl/experimental/workflow/workspace_export/export.py +++ b/openfl/experimental/workflow/workspace_export/export.py @@ -19,6 +19,7 @@ import yaml from nbdev.export import nb_export +from openfl.experimental.workflow.federated.plan import Plan from openfl.experimental.workflow.interface.cli.cli_helper import print_tree logger = getLogger(__name__) @@ -275,12 +276,16 @@ def __write_yaml(self, path, data) -> None: yaml.safe_dump(data, y) @classmethod - def export_federated(cls, notebook_path: str, output_workspace: str) -> Tuple[str, str]: + def export_federated( + cls, notebook_path: str, output_workspace: str, director_fqdn: str, tls: bool = False + ) -> Tuple[str, str]: """Exports workspace for FederatedRuntime. Args: notebook_path (str): Path to the Jupyter notebook. output_workspace (str): Path for the generated workspace directory. + director_fqdn (str): Fully qualified domain name of the director node. + tls (bool, optional): Whether to use TLS for the connection. Returns: Tuple[str, str]: A tuple containing: @@ -288,7 +293,7 @@ def export_federated(cls, notebook_path: str, output_workspace: str) -> Tuple[st """ instance = cls(notebook_path, output_workspace) instance.generate_requirements() - instance.generate_plan_yaml() + instance.generate_plan_yaml(director_fqdn, tls) instance._clean_generated_workspace() print_tree(output_workspace, level=2) return instance.generate_experiment_archive() @@ -373,9 +378,13 @@ def _clean_generated_workspace(self) -> None: if data_file.exists(): data_file.unlink() - def generate_plan_yaml(self) -> None: + def generate_plan_yaml(self, director_fqdn: str = None, tls: bool = False) -> None: """ Generates plan.yaml + + Args: + director_fqdn (str): Fully qualified domain name of the director node. + tls (bool, optional): Whether to use TLS for the connection. """ flspec = importlib.import_module("openfl.experimental.workflow.interface").FLSpec # Get flow classname @@ -414,6 +423,13 @@ def update_dictionary(args: dict, data: dict, dtype: str = "args"): kw_args = self.arguments_passed_to_initialize["kwargs"] update_dictionary(kw_args, data, dtype="kwargs") + # Updating the aggregator address with director's hostname and tls settings in plan.yaml + if director_fqdn: + network_settings = Plan.parse(plan).config["network"] + data["network"] = network_settings + data["network"]["settings"]["agg_addr"] = director_fqdn + data["network"]["settings"]["tls"] = tls + self.__write_yaml(plan, data) def generate_data_yaml(self) -> None: # noqa: C901