Skip to content

Commit

Permalink
Fix tls issue
Browse files Browse the repository at this point in the history
Signed-off-by: Ishant Thakare <[email protected]>
  • Loading branch information
ishant162 committed Jan 31, 2025
1 parent 1f43b93 commit 1546842
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
15 changes: 7 additions & 8 deletions openfl/experimental/workflow/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def parse(
f"[blue]{plan_config_path}[/].",
extra={"markup": True},
)
Plan.dump(plan_config_path, plan.config)
Plan.logger.info(dump(plan.config))

return plan
Expand Down Expand Up @@ -404,7 +405,7 @@ def get_client(
root_certificate=None,
private_key=None,
certificate=None,
tls=False,
tls=True,
) -> AggregatorGRPCClient:
"""Get gRPC client for the specified collaborator.
Expand All @@ -423,8 +424,8 @@ def get_client(
Returns:
AggregatorGRPCClient: gRPC client for the specified collaborator.
"""
common_name = collaborator_name
if not root_certificate or not private_key or not certificate:
if tls and not (root_certificate and private_key and certificate):
common_name = collaborator_name
root_certificate = "cert/cert_chain.crt"
certificate = f"cert/client/col_{common_name}.crt"
private_key = f"cert/client/col_{common_name}.key"
Expand All @@ -436,7 +437,6 @@ def get_client(
client_args["root_certificate"] = root_certificate
client_args["certificate"] = certificate
client_args["private_key"] = private_key
client_args["tls"] = tls

client_args["aggregator_uuid"] = aggregator_uuid
client_args["federation_uuid"] = federation_uuid
Expand All @@ -451,7 +451,7 @@ def get_server(
root_certificate=None,
private_key=None,
certificate=None,
tls=False,
tls=True,
director_config=None,
**kwargs,
) -> AggregatorGRPCServer:
Expand All @@ -472,9 +472,8 @@ def get_server(
Returns:
AggregatorGRPCServer: gRPC server of the aggregator instance.
"""
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()

if not root_certificate or not private_key or not certificate:
if tls and not (root_certificate and private_key and certificate):
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()
root_certificate = "cert/cert_chain.crt"
certificate = f"cert/server/agg_{common_name}.crt"
private_key = f"cert/server/agg_{common_name}.key"
Expand Down
2 changes: 2 additions & 0 deletions openfl/experimental/workflow/runtime/federated_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def prepare_workspace_archive(self) -> Tuple[Path, str]:
archive_path, exp_name = WorkspaceExport.export_federated(
notebook_path=self.notebook_path,
output_workspace="./generated_workspace",
director_fqdn=self.director["director_node_fqdn"],
tls=self.tls,
)
return archive_path, exp_name

Expand Down
22 changes: 19 additions & 3 deletions openfl/experimental/workflow/workspace_export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import yaml
from nbdev.export import nb_export

from openfl.experimental.workflow.federated.plan import Plan
from openfl.experimental.workflow.interface.cli.cli_helper import print_tree

logger = getLogger(__name__)
Expand Down Expand Up @@ -275,20 +276,24 @@ def __write_yaml(self, path, data) -> None:
yaml.safe_dump(data, y)

@classmethod
def export_federated(cls, notebook_path: str, output_workspace: str) -> Tuple[str, str]:
def export_federated(
cls, notebook_path: str, output_workspace: str, director_fqdn: str, tls: bool = False
) -> Tuple[str, str]:
"""Exports workspace for FederatedRuntime.
Args:
notebook_path (str): Path to the Jupyter notebook.
output_workspace (str): Path for the generated workspace directory.
director_fqdn (str): Fully qualified domain name of the director node.
tls (bool, optional): Whether to use TLS for the connection.
Returns:
Tuple[str, str]: A tuple containing:
(archive_path, flow_class_name).
"""
instance = cls(notebook_path, output_workspace)
instance.generate_requirements()
instance.generate_plan_yaml()
instance.generate_plan_yaml(director_fqdn, tls)
instance._clean_generated_workspace()
print_tree(output_workspace, level=2)
return instance.generate_experiment_archive()
Expand Down Expand Up @@ -373,9 +378,13 @@ def _clean_generated_workspace(self) -> None:
if data_file.exists():
data_file.unlink()

def generate_plan_yaml(self) -> None:
def generate_plan_yaml(self, director_fqdn: str = None, tls: bool = False) -> None:
"""
Generates plan.yaml
Args:
director_fqdn (str): Fully qualified domain name of the director node.
tls (bool, optional): Whether to use TLS for the connection.
"""
flspec = importlib.import_module("openfl.experimental.workflow.interface").FLSpec
# Get flow classname
Expand Down Expand Up @@ -414,6 +423,13 @@ def update_dictionary(args: dict, data: dict, dtype: str = "args"):
kw_args = self.arguments_passed_to_initialize["kwargs"]
update_dictionary(kw_args, data, dtype="kwargs")

# Updating the aggregator address with director's hostname and tls settings in plan.yaml
if director_fqdn:
network_settings = Plan.parse(plan).config["network"]
data["network"] = network_settings
data["network"]["settings"]["agg_addr"] = director_fqdn
data["network"]["settings"]["tls"] = tls

self.__write_yaml(plan, data)

def generate_data_yaml(self) -> None: # noqa: C901
Expand Down

0 comments on commit 1546842

Please sign in to comment.