Skip to content

Commit

Permalink
Replace checks with inline assertions
Browse files Browse the repository at this point in the history
Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista committed Mar 7, 2025
1 parent 512e018 commit c359b43
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 140 deletions.
29 changes: 17 additions & 12 deletions openfl/experimental/workflow/transport/grpc/aggregator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from openfl.experimental.workflow.protocols import aggregator_pb2, aggregator_pb2_grpc
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import channel_options
from openfl.utilities import check_equal


class ConstantBackoff:
Expand Down Expand Up @@ -221,17 +220,23 @@ def _set_header(self, collaborator_name):

def validate_response(self, reply, collaborator_name):
"""Validate the aggregator response."""
# check that the message was intended to go to this collaborator
check_equal(reply.header.receiver, collaborator_name)
check_equal(reply.header.sender, self.aggregator_uuid)

# check that federation id matches
check_equal(reply.header.federation_uuid, self.federation_uuid)

# check that there is aggrement on the single_col_cert_common_name
check_equal(
reply.header.single_col_cert_common_name,
self.single_col_cert_common_name or "",
assert reply.header.receiver == collaborator_name, (
f"Receiver in response header does not match collaborator name. "
f"Expected: {collaborator_name}, Actual: {reply.header.receiver}"
)
assert reply.header.sender == self.aggregator_uuid, (
f"Sender in response header does not match aggregator UUID. "
f"Expected: {self.aggregator_uuid}, Actual: {reply.header.sender}"
)
assert reply.header.federation_uuid == self.federation_uuid, (
f"Federation UUID in response header does not match. "
f"Expected: {self.federation_uuid}, Actual: {reply.header.federation_uuid}"
)
assert reply.header.single_col_cert_common_name == (
self.single_col_cert_common_name or ""
), (
f"Single collaborator certificate common name in response header does not match. "
f"Expected: {self.single_col_cert_common_name or ''}, Actual: {reply.header.single_col_cert_common_name}" # noqa: E501
)

def disconnect(self):
Expand Down
30 changes: 17 additions & 13 deletions openfl/experimental/workflow/transport/grpc/aggregator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from openfl.experimental.workflow.protocols import aggregator_pb2, aggregator_pb2_grpc
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import channel_options
from openfl.utilities import check_equal, check_is_in

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,22 +105,27 @@ def check_request(self, request):
request : protobuf
Request sent from a collaborator that requires validation
"""
# TODO improve this check. the sender name could be spoofed
check_is_in(request.header.sender, self.aggregator.authorized_cols)
assert request.header.sender in self.aggregator.authorized_cols, (
f"Sender in request header is not authorized. "
f"Expected: one of {self.aggregator.authorized_cols}, Actual: {request.header.sender}"
)

# check that the message is for me
check_equal(request.header.receiver, self.aggregator.uuid)
assert request.header.receiver == self.aggregator.uuid, (
f"Receiver in request header does not match aggregator UUID. "
f"Expected: {self.aggregator.uuid}, Actual: {request.header.receiver}"
)

# check that the message is for my federation
check_equal(
request.header.federation_uuid,
self.aggregator.federation_uuid,
assert request.header.federation_uuid == self.aggregator.federation_uuid, (
f"Federation UUID in request header does not match. "
f"Expected: {self.aggregator.federation_uuid}, Actual: {request.header.federation_uuid}"
)

# check that we agree on the single cert common name
check_equal(
request.header.single_col_cert_common_name,
self.aggregator.single_col_cert_common_name,
assert (
request.header.single_col_cert_common_name
== self.aggregator.single_col_cert_common_name
), (
f"Single collaborator certificate common name in request header does not match. "
f"Expected: {self.aggregator.single_col_cert_common_name}, Actual: {request.header.single_col_cert_common_name}" # noqa: E501
)

def SendTaskResults(self, request, context): # NOQA:N802
Expand Down
24 changes: 19 additions & 5 deletions openfl/transport/grpc/aggregator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel
from openfl.utilities import check_equal

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -244,10 +243,22 @@ def __init__(

def validate_response(self, response, collaborator_name):
"""Validate the aggregator response."""
check_equal(response.header.receiver, collaborator_name)
check_equal(response.header.sender, self.aggregator_uuid)
check_equal(response.header.federation_uuid, self.federation_uuid)
check_equal(response.header.single_col_cert_common_name, self.single_col_cert_common_name)
assert response.header.receiver == collaborator_name, (
f"Receiver in response header does not match collaborator name. "
f"Expected: {collaborator_name}, Actual: {response.header.receiver}"
)
assert response.header.sender == self.aggregator_uuid, (
f"Sender in response header does not match aggregator UUID. "
f"Expected: {self.aggregator_uuid}, Actual: {response.header.sender}"
)
assert response.header.federation_uuid == self.federation_uuid, (
f"Federation UUID in response header does not match. "
f"Expected: {self.federation_uuid}, Actual: {response.header.federation_uuid}"
)
assert response.header.single_col_cert_common_name == self.single_col_cert_common_name, (
f"Single collaborator certificate common name in response header does not match. "
f"Expected: {self.single_col_cert_common_name}, Actual: {response.header.single_col_cert_common_name}" # noqa: E501
)

def disconnect(self):
"""Close the gRPC channel."""
Expand Down Expand Up @@ -347,6 +358,9 @@ def get_aggregated_tensor(
)
response = self.stub.GetAggregatedTensor(request)
self.validate_response(response, collaborator_name)

# Deserialize Tensor.

return response.tensor

@_resend_data_on_reconnection
Expand Down
30 changes: 17 additions & 13 deletions openfl/transport/grpc/aggregator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
from openfl.transport.grpc.common import create_grpc_server, create_header
from openfl.utilities import check_equal, check_is_in

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,22 +102,27 @@ def check_request(self, request):
Raises:
ValueError: If the request is not valid.
"""
# TODO improve this check. the sender name could be spoofed
check_is_in(request.header.sender, self.aggregator.authorized_cols)
assert request.header.sender in self.aggregator.authorized_cols, (
f"Sender in request header is not authorized. "
f"Expected: one of {self.aggregator.authorized_cols}, Actual: {request.header.sender}"
)

# check that the message is for me
check_equal(request.header.receiver, self.aggregator.uuid)
assert request.header.receiver == self.aggregator.uuid, (
f"Receiver in request header does not match aggregator UUID. "
f"Expected: {self.aggregator.uuid}, Actual: {request.header.receiver}"
)

# check that the message is for my federation
check_equal(
request.header.federation_uuid,
self.aggregator.federation_uuid,
assert request.header.federation_uuid == self.aggregator.federation_uuid, (
f"Federation UUID in request header does not match. "
f"Expected: {self.aggregator.federation_uuid}, Actual: {request.header.federation_uuid}"
)

# check that we agree on the single cert common name
check_equal(
request.header.single_col_cert_common_name,
self.aggregator.single_col_cert_common_name,
assert (
request.header.single_col_cert_common_name
== self.aggregator.single_col_cert_common_name
), (
f"Single collaborator certificate common name in request header does not match. "
f"Expected: {self.aggregator.single_col_cert_common_name}, Actual: {request.header.single_col_cert_common_name}" # noqa: E501
)

def GetTasks(self, request, context): # NOQA:N802
Expand Down
97 changes: 0 additions & 97 deletions openfl/utilities/checks.py

This file was deleted.

0 comments on commit c359b43

Please sign in to comment.