From 5a593427faba31bbc29a1dbfa8e752b3e8431817 Mon Sep 17 00:00:00 2001 From: Karan Shah Date: Wed, 27 Nov 2024 17:47:29 +0530 Subject: [PATCH] Fix auth failure on one-way TLS (#1167) * Fix auth failure on one-way TLS Signed-off-by: Shah, Karan * Migrate logger.warn to logger.warning Signed-off-by: Shah, Karan * Rename `tls` to `use_tls` globally Signed-off-by: Shah, Karan * Rename `disable_client_auth` to `require_client_auth` with flipped default Signed-off-by: Shah, Karan * Rename `disable_client_auth` to `require_client_auth` with flipped default Signed-off-by: Shah, Karan * Address review comments Signed-off-by: Shah, Karan --------- Signed-off-by: Shah, Karan --- .../workspace/plan/defaults/network.yaml | 4 +- openfl/federated/plan/plan.py | 2 +- openfl/federated/task/runner_xgb.py | 2 +- .../interface/interactive_api/experiment.py | 2 +- openfl/interface/plan.py | 4 +- .../frameworks_adapters/keras_adapter.py | 2 +- openfl/transport/grpc/aggregator_client.py | 49 +++++++++---------- openfl/transport/grpc/aggregator_server.py | 48 +++++++++--------- openfl/utilities/split.py | 2 +- 9 files changed, 57 insertions(+), 58 deletions(-) diff --git a/openfl-workspace/workspace/plan/defaults/network.yaml b/openfl-workspace/workspace/plan/defaults/network.yaml index 07d2e3aeec..11e03c1890 100644 --- a/openfl-workspace/workspace/plan/defaults/network.yaml +++ b/openfl-workspace/workspace/plan/defaults/network.yaml @@ -3,7 +3,7 @@ settings: agg_addr : auto agg_port : auto hash_salt : auto - tls : True + use_tls : True client_reconnect_interval : 5 - disable_client_auth : False + require_client_auth : True cert_folder : cert diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 5f0575837d..455b0b1414 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -729,7 +729,7 @@ def interactive_api_get_server( server_args["root_certificate"] = root_certificate server_args["certificate"] = certificate server_args["private_key"] = private_key - server_args["tls"] = tls + server_args["use_tls"] = tls server_args["aggregator"] = self.get_aggregator(tensor_dict) diff --git a/openfl/federated/task/runner_xgb.py b/openfl/federated/task/runner_xgb.py index 222b8c613f..ae44210ce2 100644 --- a/openfl/federated/task/runner_xgb.py +++ b/openfl/federated/task/runner_xgb.py @@ -35,7 +35,7 @@ def check_precision_loss(logger, converted_data, original_data): # Compare the original and reconstructed data if original_data != reconstructed_data: - logger.warn("Precision loss detected during conversion.") + logger.warning("Precision loss detected during conversion.") class XGBoostTaskRunner(TaskRunner): diff --git a/openfl/interface/interactive_api/experiment.py b/openfl/interface/interactive_api/experiment.py index 6e09c5bbeb..ce970abaaf 100644 --- a/openfl/interface/interactive_api/experiment.py +++ b/openfl/interface/interactive_api/experiment.py @@ -560,7 +560,7 @@ def _prepare_plan( # We just choose a port randomly from plan hash director_fqdn = self.federation.director_node_fqdn.split(":")[0] # We drop the port self.plan.config["network"]["settings"]["agg_addr"] = director_fqdn - self.plan.config["network"]["settings"]["tls"] = self.federation.tls + self.plan.config["network"]["settings"]["use_tls"] = self.federation.tls # Aggregator part of the plan self.plan.config["aggregator"]["settings"]["rounds_to_train"] = rounds_to_train diff --git a/openfl/interface/plan.py b/openfl/interface/plan.py index f4c91faed0..04f5dd9da8 100644 --- a/openfl/interface/plan.py +++ b/openfl/interface/plan.py @@ -182,7 +182,7 @@ def initialize( **task_runner.tensor_dict_split_fn_kwargs, ) - logger.warn( + logger.warning( f"Following parameters omitted from global initial model, " f"local initialization will determine" f" values: {list(holdout_params.keys())}" @@ -205,7 +205,7 @@ def initialize( if plan_origin.config["network"]["settings"]["agg_addr"] == "auto" or aggregator_address: plan_origin.config["network"]["settings"]["agg_addr"] = aggregator_address or getfqdn_env() - logger.warn( + logger.warning( f"Patching Aggregator Addr in Plan" f" 🠆 {plan_origin.config['network']['settings']['agg_addr']}" ) diff --git a/openfl/plugins/frameworks_adapters/keras_adapter.py b/openfl/plugins/frameworks_adapters/keras_adapter.py index 46373c629b..971be7840e 100644 --- a/openfl/plugins/frameworks_adapters/keras_adapter.py +++ b/openfl/plugins/frameworks_adapters/keras_adapter.py @@ -55,7 +55,7 @@ def __reduce__(self): # NOQA:N807 # Run the function if version.parse(tf.__version__) <= version.parse("2.7.1"): - logger.warn( + logger.warning( "Applying hotfix for model serialization." "Please consider updating to tensorflow>=2.8 to silence this warning." ) diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index 14c324e8a3..6713107c2b 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -170,15 +170,12 @@ class AggregatorGRPCClient: Attributes: uri (str): The URI of the aggregator. - tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side - authentication. - root_certificate (str): The path to the root certificate for the TLS - connection. - certificate (str): The path to the client's certificate for the TLS - connection. - private_key (str): The path to the client's private key for the TLS - connection. + use_tls (bool): Whether to use TLS for the connection. + require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS. + Ignored if `use_tls=False`. + root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`. + certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`. + private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`. aggregator_uuid (str): The UUID of the aggregator. federation_uuid (str): The UUID of the federation. single_col_cert_common_name (str): The common name on the @@ -189,11 +186,11 @@ def __init__( self, agg_addr, agg_port, - disable_client_auth, root_certificate, certificate, private_key, - tls=True, + use_tls=True, + require_client_auth=True, aggregator_uuid=None, federation_uuid=None, single_col_cert_common_name=None, @@ -205,9 +202,9 @@ def __init__( Args: agg_addr (str): The address of the aggregator. agg_port (int): The port of the aggregator. - tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side - authentication. + use_tls (bool): Whether to use TLS for the connection. + require_client_auth (bool): Whether to enable client-side + authentication, i.e. mTLS. Ignored if `use_tls=False`. root_certificate (str): The path to the root certificate for the TLS connection. certificate (str): The path to the client's certificate for the @@ -221,22 +218,22 @@ def __init__( **kwargs: Additional keyword arguments. """ self.uri = f"{agg_addr}:{agg_port}" - self.tls = tls - self.disable_client_auth = disable_client_auth + self.use_tls = use_tls + self.require_client_auth = require_client_auth self.root_certificate = root_certificate self.certificate = certificate self.private_key = private_key self.logger = getLogger(__name__) - if not self.tls: - self.logger.warn("gRPC is running on insecure channel with TLS disabled.") + if not self.use_tls: + self.logger.warning("gRPC is running on insecure channel with TLS disabled.") self.channel = self.create_insecure_channel(self.uri) else: self.channel = self.create_tls_channel( self.uri, self.root_certificate, - self.disable_client_auth, + self.require_client_auth, self.certificate, self.private_key, ) @@ -278,7 +275,7 @@ def create_tls_channel( self, uri, root_certificate, - disable_client_auth, + require_client_auth, certificate, private_key, ): @@ -288,8 +285,8 @@ def create_tls_channel( Args: uri (str): The uniform resource identifier for the secure channel. root_certificate (str): The Certificate Authority filename. - disable_client_auth (bool): True disables client-side - authentication (not recommended, throws warning to user). + require_client_auth (bool): True enables client-side + authentication. certificate (str): The client certificate filename from the collaborator (signed by the certificate authority). private_key (str): The private key filename for the client @@ -301,8 +298,8 @@ def create_tls_channel( with open(root_certificate, "rb") as f: root_certificate_b = f.read() - if disable_client_auth: - self.logger.warn("Client-side authentication is disabled.") + if not require_client_auth: + self.logger.warning("Client-side authentication is disabled.") private_key_b = None certificate_b = None else: @@ -364,13 +361,13 @@ def reconnect(self): # issued previously self.disconnect() - if not self.tls: + if not self.use_tls: self.channel = self.create_insecure_channel(self.uri) else: self.channel = self.create_tls_channel( self.uri, self.root_certificate, - self.disable_client_auth, + self.require_client_auth, self.certificate, self.private_key, ) diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index b7c54813af..19d156338f 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -28,15 +28,12 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer): Attributes: aggregator (Aggregator): The aggregator that this server is serving. uri (str): The URI that the server is serving on. - tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side - authentication. - root_certificate (str): The path to the root certificate for the TLS - connection. - certificate (str): The path to the server's certificate for the TLS - connection. - private_key (str): The path to the server's private key for the TLS - connection. + use_tls (bool): Whether to use TLS for the connection. + require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS. + Ignored if `use_tls=False`. + root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`. + certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`. + private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`. server (grpc.Server): The gRPC server. server_credentials (grpc.ServerCredentials): The server's credentials. """ @@ -45,8 +42,8 @@ def __init__( self, aggregator, agg_port, - tls=True, - disable_client_auth=False, + use_tls=True, + require_client_auth=True, root_certificate=None, certificate=None, private_key=None, @@ -59,9 +56,9 @@ def __init__( aggregator (Aggregator): The aggregator that this server is serving. agg_port (int): The port that the server is serving on. - tls (bool): Whether to use TLS for the connection. - disable_client_auth (bool): Whether to disable client-side - authentication. + use_tls (bool): Whether to use TLS for the connection. + require_client_auth (bool): Whether to enable client-side + authentication, i.e. mTLS. Ignored if `use_tls=False`. root_certificate (str): The path to the root certificate for the TLS connection. certificate (str): The path to the server's certificate for the @@ -70,10 +67,11 @@ def __init__( TLS connection. **kwargs: Additional keyword arguments. """ + print(f"{use_tls=}") self.aggregator = aggregator self.uri = f"[::]:{agg_port}" - self.tls = tls - self.disable_client_auth = disable_client_auth + self.use_tls = use_tls + self.require_client_auth = require_client_auth self.root_certificate = root_certificate self.certificate = certificate self.private_key = private_key @@ -97,9 +95,13 @@ def validate_collaborator(self, request, context): grpc.RpcError: If the collaborator or collaborator certificate is not authorized. """ - if self.tls: - common_name = context.auth_context()["x509_common_name"][0].decode("utf-8") + if self.use_tls: collaborator_common_name = request.header.sender + if self.require_client_auth: + common_name = context.auth_context()["x509_common_name"][0].decode("utf-8") + else: + common_name = collaborator_common_name + if not self.aggregator.valid_collaborator_cn_and_id( common_name, collaborator_common_name ): @@ -306,8 +308,8 @@ def get_server(self): aggregator_pb2_grpc.add_AggregatorServicer_to_server(self, self.server) - if not self.tls: - self.logger.warn("gRPC is running on insecure channel with TLS disabled.") + if not self.use_tls: + self.logger.warning("gRPC is running on insecure channel with TLS disabled.") port = self.server.add_insecure_port(self.uri) self.logger.info("Insecure port: %s", port) @@ -319,13 +321,13 @@ def get_server(self): with open(self.root_certificate, "rb") as f: root_certificate_b = f.read() - if self.disable_client_auth: - self.logger.warn("Client-side authentication is disabled.") + if not self.require_client_auth: + self.logger.warning("Client-side authentication is disabled.") self.server_credentials = ssl_server_credentials( ((private_key_b, certificate_b),), root_certificates=root_certificate_b, - require_client_auth=not self.disable_client_auth, + require_client_auth=self.require_client_auth, ) self.server.add_secure_port(self.uri, self.server_credentials) diff --git a/openfl/utilities/split.py b/openfl/utilities/split.py index 8eb0117f6d..ee2e4654ac 100644 --- a/openfl/utilities/split.py +++ b/openfl/utilities/split.py @@ -91,7 +91,7 @@ def split_tensor_dict_for_holdouts( try: holdout_tensors[tensor_name] = tensors_to_send.pop(tensor_name) except KeyError: - logger.warn( + logger.warning( f"tried to remove tensor: {tensor_name} not present " f"in the tensor dict" ) continue