Skip to content

Commit

Permalink
Fix auth failure on one-way TLS (#1167)
Browse files Browse the repository at this point in the history
* Fix auth failure on one-way TLS

Signed-off-by: Shah, Karan <[email protected]>

* Migrate logger.warn to logger.warning

Signed-off-by: Shah, Karan <[email protected]>

* Rename `tls` to `use_tls` globally

Signed-off-by: Shah, Karan <[email protected]>

* Rename `disable_client_auth` to `require_client_auth` with flipped default

Signed-off-by: Shah, Karan <[email protected]>

* Rename `disable_client_auth` to `require_client_auth` with flipped default

Signed-off-by: Shah, Karan <[email protected]>

* Address review comments

Signed-off-by: Shah, Karan <[email protected]>

---------

Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista authored Nov 27, 2024
1 parent 6c0e881 commit 5a59342
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 58 deletions.
4 changes: 2 additions & 2 deletions openfl-workspace/workspace/plan/defaults/network.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ settings:
agg_addr : auto
agg_port : auto
hash_salt : auto
tls : True
use_tls : True
client_reconnect_interval : 5
disable_client_auth : False
require_client_auth : True
cert_folder : cert
2 changes: 1 addition & 1 deletion openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def interactive_api_get_server(
server_args["root_certificate"] = root_certificate
server_args["certificate"] = certificate
server_args["private_key"] = private_key
server_args["tls"] = tls
server_args["use_tls"] = tls

server_args["aggregator"] = self.get_aggregator(tensor_dict)

Expand Down
2 changes: 1 addition & 1 deletion openfl/federated/task/runner_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def check_precision_loss(logger, converted_data, original_data):

# Compare the original and reconstructed data
if original_data != reconstructed_data:
logger.warn("Precision loss detected during conversion.")
logger.warning("Precision loss detected during conversion.")


class XGBoostTaskRunner(TaskRunner):
Expand Down
2 changes: 1 addition & 1 deletion openfl/interface/interactive_api/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def _prepare_plan(
# We just choose a port randomly from plan hash
director_fqdn = self.federation.director_node_fqdn.split(":")[0] # We drop the port
self.plan.config["network"]["settings"]["agg_addr"] = director_fqdn
self.plan.config["network"]["settings"]["tls"] = self.federation.tls
self.plan.config["network"]["settings"]["use_tls"] = self.federation.tls

# Aggregator part of the plan
self.plan.config["aggregator"]["settings"]["rounds_to_train"] = rounds_to_train
Expand Down
4 changes: 2 additions & 2 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def initialize(
**task_runner.tensor_dict_split_fn_kwargs,
)

logger.warn(
logger.warning(
f"Following parameters omitted from global initial model, "
f"local initialization will determine"
f" values: {list(holdout_params.keys())}"
Expand All @@ -205,7 +205,7 @@ def initialize(
if plan_origin.config["network"]["settings"]["agg_addr"] == "auto" or aggregator_address:
plan_origin.config["network"]["settings"]["agg_addr"] = aggregator_address or getfqdn_env()

logger.warn(
logger.warning(
f"Patching Aggregator Addr in Plan"
f" 🠆 {plan_origin.config['network']['settings']['agg_addr']}"
)
Expand Down
2 changes: 1 addition & 1 deletion openfl/plugins/frameworks_adapters/keras_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __reduce__(self): # NOQA:N807

# Run the function
if version.parse(tf.__version__) <= version.parse("2.7.1"):
logger.warn(
logger.warning(
"Applying hotfix for model serialization."
"Please consider updating to tensorflow>=2.8 to silence this warning."
)
Expand Down
49 changes: 23 additions & 26 deletions openfl/transport/grpc/aggregator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,12 @@ class AggregatorGRPCClient:
Attributes:
uri (str): The URI of the aggregator.
tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
authentication.
root_certificate (str): The path to the root certificate for the TLS
connection.
certificate (str): The path to the client's certificate for the TLS
connection.
private_key (str): The path to the client's private key for the TLS
connection.
use_tls (bool): Whether to use TLS for the connection.
require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS.
Ignored if `use_tls=False`.
root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`.
certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`.
private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`.
aggregator_uuid (str): The UUID of the aggregator.
federation_uuid (str): The UUID of the federation.
single_col_cert_common_name (str): The common name on the
Expand All @@ -189,11 +186,11 @@ def __init__(
self,
agg_addr,
agg_port,
disable_client_auth,
root_certificate,
certificate,
private_key,
tls=True,
use_tls=True,
require_client_auth=True,
aggregator_uuid=None,
federation_uuid=None,
single_col_cert_common_name=None,
Expand All @@ -205,9 +202,9 @@ def __init__(
Args:
agg_addr (str): The address of the aggregator.
agg_port (int): The port of the aggregator.
tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
authentication.
use_tls (bool): Whether to use TLS for the connection.
require_client_auth (bool): Whether to enable client-side
authentication, i.e. mTLS. Ignored if `use_tls=False`.
root_certificate (str): The path to the root certificate for the
TLS connection.
certificate (str): The path to the client's certificate for the
Expand All @@ -221,22 +218,22 @@ def __init__(
**kwargs: Additional keyword arguments.
"""
self.uri = f"{agg_addr}:{agg_port}"
self.tls = tls
self.disable_client_auth = disable_client_auth
self.use_tls = use_tls
self.require_client_auth = require_client_auth
self.root_certificate = root_certificate
self.certificate = certificate
self.private_key = private_key

self.logger = getLogger(__name__)

if not self.tls:
self.logger.warn("gRPC is running on insecure channel with TLS disabled.")
if not self.use_tls:
self.logger.warning("gRPC is running on insecure channel with TLS disabled.")
self.channel = self.create_insecure_channel(self.uri)
else:
self.channel = self.create_tls_channel(
self.uri,
self.root_certificate,
self.disable_client_auth,
self.require_client_auth,
self.certificate,
self.private_key,
)
Expand Down Expand Up @@ -278,7 +275,7 @@ def create_tls_channel(
self,
uri,
root_certificate,
disable_client_auth,
require_client_auth,
certificate,
private_key,
):
Expand All @@ -288,8 +285,8 @@ def create_tls_channel(
Args:
uri (str): The uniform resource identifier for the secure channel.
root_certificate (str): The Certificate Authority filename.
disable_client_auth (bool): True disables client-side
authentication (not recommended, throws warning to user).
require_client_auth (bool): True enables client-side
authentication.
certificate (str): The client certificate filename from the
collaborator (signed by the certificate authority).
private_key (str): The private key filename for the client
Expand All @@ -301,8 +298,8 @@ def create_tls_channel(
with open(root_certificate, "rb") as f:
root_certificate_b = f.read()

if disable_client_auth:
self.logger.warn("Client-side authentication is disabled.")
if not require_client_auth:
self.logger.warning("Client-side authentication is disabled.")
private_key_b = None
certificate_b = None
else:
Expand Down Expand Up @@ -364,13 +361,13 @@ def reconnect(self):
# issued previously
self.disconnect()

if not self.tls:
if not self.use_tls:
self.channel = self.create_insecure_channel(self.uri)
else:
self.channel = self.create_tls_channel(
self.uri,
self.root_certificate,
self.disable_client_auth,
self.require_client_auth,
self.certificate,
self.private_key,
)
Expand Down
48 changes: 25 additions & 23 deletions openfl/transport/grpc/aggregator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer):
Attributes:
aggregator (Aggregator): The aggregator that this server is serving.
uri (str): The URI that the server is serving on.
tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
authentication.
root_certificate (str): The path to the root certificate for the TLS
connection.
certificate (str): The path to the server's certificate for the TLS
connection.
private_key (str): The path to the server's private key for the TLS
connection.
use_tls (bool): Whether to use TLS for the connection.
require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS.
Ignored if `use_tls=False`.
root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`.
certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`.
private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`.
server (grpc.Server): The gRPC server.
server_credentials (grpc.ServerCredentials): The server's credentials.
"""
Expand All @@ -45,8 +42,8 @@ def __init__(
self,
aggregator,
agg_port,
tls=True,
disable_client_auth=False,
use_tls=True,
require_client_auth=True,
root_certificate=None,
certificate=None,
private_key=None,
Expand All @@ -59,9 +56,9 @@ def __init__(
aggregator (Aggregator): The aggregator that this server is
serving.
agg_port (int): The port that the server is serving on.
tls (bool): Whether to use TLS for the connection.
disable_client_auth (bool): Whether to disable client-side
authentication.
use_tls (bool): Whether to use TLS for the connection.
require_client_auth (bool): Whether to enable client-side
authentication, i.e. mTLS. Ignored if `use_tls=False`.
root_certificate (str): The path to the root certificate for the
TLS connection.
certificate (str): The path to the server's certificate for the
Expand All @@ -70,10 +67,11 @@ def __init__(
TLS connection.
**kwargs: Additional keyword arguments.
"""
print(f"{use_tls=}")
self.aggregator = aggregator
self.uri = f"[::]:{agg_port}"
self.tls = tls
self.disable_client_auth = disable_client_auth
self.use_tls = use_tls
self.require_client_auth = require_client_auth
self.root_certificate = root_certificate
self.certificate = certificate
self.private_key = private_key
Expand All @@ -97,9 +95,13 @@ def validate_collaborator(self, request, context):
grpc.RpcError: If the collaborator or collaborator certificate is
not authorized.
"""
if self.tls:
common_name = context.auth_context()["x509_common_name"][0].decode("utf-8")
if self.use_tls:
collaborator_common_name = request.header.sender
if self.require_client_auth:
common_name = context.auth_context()["x509_common_name"][0].decode("utf-8")
else:
common_name = collaborator_common_name

if not self.aggregator.valid_collaborator_cn_and_id(
common_name, collaborator_common_name
):
Expand Down Expand Up @@ -306,8 +308,8 @@ def get_server(self):

aggregator_pb2_grpc.add_AggregatorServicer_to_server(self, self.server)

if not self.tls:
self.logger.warn("gRPC is running on insecure channel with TLS disabled.")
if not self.use_tls:
self.logger.warning("gRPC is running on insecure channel with TLS disabled.")
port = self.server.add_insecure_port(self.uri)
self.logger.info("Insecure port: %s", port)

Expand All @@ -319,13 +321,13 @@ def get_server(self):
with open(self.root_certificate, "rb") as f:
root_certificate_b = f.read()

if self.disable_client_auth:
self.logger.warn("Client-side authentication is disabled.")
if not self.require_client_auth:
self.logger.warning("Client-side authentication is disabled.")

self.server_credentials = ssl_server_credentials(
((private_key_b, certificate_b),),
root_certificates=root_certificate_b,
require_client_auth=not self.disable_client_auth,
require_client_auth=self.require_client_auth,
)

self.server.add_secure_port(self.uri, self.server_credentials)
Expand Down
2 changes: 1 addition & 1 deletion openfl/utilities/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def split_tensor_dict_for_holdouts(
try:
holdout_tensors[tensor_name] = tensors_to_send.pop(tensor_name)
except KeyError:
logger.warn(
logger.warning(
f"tried to remove tensor: {tensor_name} not present " f"in the tensor dict"
)
continue
Expand Down

0 comments on commit 5a59342

Please sign in to comment.