diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 041860957db7..e5b16009e563 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -15,67 +15,18 @@ """Flower client interceptor.""" -import base64 -import collections -from collections.abc import Sequence -from logging import WARNING -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import grpc from cryptography.hazmat.primitives.asymmetric import ec +from google.protobuf.message import Message as GrpcMessage -from flwr.common.logger import log +from flwr.common import now +from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_public_key, - compute_hmac, - generate_shared_key, public_key_to_bytes, + sign_message, ) -from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611 -from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 - CreateNodeRequest, - DeleteNodeRequest, - PingRequest, - PullTaskInsRequest, - PushTaskResRequest, -) -from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611 - -_PUBLIC_KEY_HEADER = "public-key" -_AUTH_TOKEN_HEADER = "auth-token" - -Request = Union[ - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, -] - - -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() - - return value - - -class _ClientCallDetails( - collections.namedtuple( - "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") - ), - grpc.ClientCallDetails, # type: ignore -): - """Details for each client call. - - The class will be passed on as the first argument in continuation function. - In our case, `AuthenticateClientInterceptor` adds new metadata to the construct. - """ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore @@ -87,84 +38,33 @@ def __init__( public_key: ec.EllipticCurvePublicKey, ): self.private_key = private_key - self.public_key = public_key - self.shared_secret: Optional[bytes] = None - self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None - self.encoded_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self.public_key) - ) + self.public_key_bytes = public_key_to_bytes(public_key) def intercept_unary_unary( self, continuation: Callable[[Any, Any], Any], client_call_details: grpc.ClientCallDetails, - request: Request, + request: GrpcMessage, ) -> grpc.Call: """Flower client interceptor. Intercept unary call from client and add necessary authentication header in the RPC metadata. """ - metadata = [] - postprocess = False - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - - # Always add the public key header - metadata.append( - ( - _PUBLIC_KEY_HEADER, - self.encoded_public_key, - ) - ) - - if isinstance(request, CreateNodeRequest): - postprocess = True - elif isinstance( - request, - ( - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, - ), - ): - if self.shared_secret is None: - raise RuntimeError("Failure to compute hmac") - - message_bytes = request.SerializeToString(deterministic=True) - metadata.append( - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode( - compute_hmac(self.shared_secret, message_bytes) - ), - ) - ) + metadata = list(client_call_details.metadata or []) - client_call_details = _ClientCallDetails( - client_call_details.method, - client_call_details.timeout, - metadata, - client_call_details.credentials, - ) + # Add the public key + metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes)) - response = continuation(client_call_details, request) - if postprocess: - server_public_key_bytes = base64.urlsafe_b64decode( - _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata()) - ) + # Add timestamp + timestamp = now().isoformat() + metadata.append((TIMESTAMP_HEADER, timestamp)) - if server_public_key_bytes != b"": - self.server_public_key = bytes_to_public_key(server_public_key_bytes) - else: - log(WARNING, "Can't get server public key, SuperLink may be offline") + # Sign and add the signature + signature = sign_message(self.private_key, timestamp.encode("ascii")) + metadata.append((SIGNATURE_HEADER, signature)) - if self.server_public_key is not None: - self.shared_secret = generate_shared_key( - self.private_key, self.server_public_key - ) + # Overwrite the metadata + details = client_call_details._replace(metadata=metadata) - return response + return continuation(details, request) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 9ea23e78c009..1017cf5dc154 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -115,6 +115,12 @@ CREDENTIALS_DIR = ".credentials" AUTH_TYPE = "auth_type" +# Constants for node authentication +PUBLIC_KEY_HEADER = "public-key-bin" # Must end with "-bin" for binary data +SIGNATURE_HEADER = "signature-bin" # Must end with "-bin" for binary data +TIMESTAMP_HEADER = "timestamp" +TIMESTAMP_TOLERANCE = 10 # Tolerance for timestamp verification + class MessageType: """Message type.""" diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index c07ee0788493..38ef0f829dc0 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -15,91 +15,54 @@ """Flower server interceptor.""" -import base64 -from collections.abc import Sequence -from logging import INFO, WARNING -from typing import Any, Callable, Optional, Union +import datetime +from typing import Any, Callable, Optional, cast import grpc -from cryptography.hazmat.primitives.asymmetric import ec - -from flwr.common.logger import log +from google.protobuf.message import Message as GrpcMessage + +from flwr.common import now +from flwr.common.constant import ( + PUBLIC_KEY_HEADER, + SIGNATURE_HEADER, + TIMESTAMP_HEADER, + TIMESTAMP_TOLERANCE, +) from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, bytes_to_public_key, - generate_shared_key, - verify_hmac, + verify_signature, ) -from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, - DeleteNodeRequest, - DeleteNodeResponse, - PingRequest, - PingResponse, - PullTaskInsRequest, - PullTaskInsResponse, - PushTaskResRequest, - PushTaskResResponse, ) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.linkstate import LinkStateFactory -_PUBLIC_KEY_HEADER = "public-key" -_AUTH_TOKEN_HEADER = "auth-token" - -Request = Union[ - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, -] - -Response = Union[ - CreateNodeResponse, - DeleteNodeResponse, - PullTaskInsResponse, - PushTaskResResponse, - GetRunResponse, - PingResponse, - GetFabResponse, -] - -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() +def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler: + def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage: + context.abort(grpc.StatusCode.UNAUTHENTICATED, message) + raise RuntimeError("Should not reach this point") # Make mypy happy - return value + return grpc.unary_unary_rpc_method_handler(terminate) class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore - """Server interceptor for node authentication.""" - - def __init__(self, state_factory: LinkStateFactory): + """Server interceptor for node authentication. + + Parameters + ---------- + state_factory : LinkStateFactory + A factory for creating new instances of LinkState. + auto_auth : bool + If True, automatically authenticates nodes without verifying their public keys. + If False, only nodes with pre-stored public keys in the LinkState can be + authenticated. + """ + + def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False): self.state_factory = state_factory - state = self.state_factory.state() - - self.node_public_keys = state.get_node_public_keys() - if len(self.node_public_keys) == 0: - log(WARNING, "Authentication enabled, but no known public keys configured") - - private_key = state.get_server_private_key() - public_key = state.get_server_public_key() - - if private_key is None or public_key is None: - raise ValueError("Error loading authentication keys") - - self.server_private_key = bytes_to_private_key(private_key) - self.encoded_server_public_key = base64.urlsafe_b64encode(public_key) + self.auto_auth = auto_auth def intercept_service( self, @@ -112,116 +75,78 @@ def intercept_service( metadata sent by the node. Continue RPC call if node is authenticated, else, terminate RPC call by setting context to abort. """ + state = self.state_factory.state() + metadata_dict = dict(handler_call_details.invocation_metadata) + + # Retrieve info from the metadata + try: + node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER]) + timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER]) + signature = cast(bytes, metadata_dict[SIGNATURE_HEADER]) + except KeyError: + return _unary_unary_rpc_terminator("Missing authentication metadata") + + if not self.auto_auth: + # Abort the RPC call if the node public key is not found + if node_pk_bytes not in state.get_node_public_keys(): + return _unary_unary_rpc_terminator("Public key not recognized") + + # Verify the signature + node_pk = bytes_to_public_key(node_pk_bytes) + if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature): + return _unary_unary_rpc_terminator("Invalid signature") + + # Verify the timestamp + current = now() + time_diff = current - datetime.datetime.fromisoformat(timestamp_iso) + # Abort the RPC call if the timestamp is too old or in the future + if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE: + return _unary_unary_rpc_terminator("Invalid timestamp") + + # Continue the RPC call + expected_node_id = state.get_node_id(node_pk_bytes) + if not handler_call_details.method.endswith("CreateNode"): + if expected_node_id is None: + return _unary_unary_rpc_terminator("Invalid node ID") # One of the method handlers in # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer` method_handler: grpc.RpcMethodHandler = continuation(handler_call_details) - return self._generic_auth_unary_method_handler(method_handler) + return self._wrap_method_handler( + method_handler, expected_node_id, node_pk_bytes + ) - def _generic_auth_unary_method_handler( - self, method_handler: grpc.RpcMethodHandler + def _wrap_method_handler( + self, + method_handler: grpc.RpcMethodHandler, + expected_node_id: Optional[int], + node_public_key: bytes, ) -> grpc.RpcMethodHandler: def _generic_method_handler( - request: Request, + request: GrpcMessage, context: grpc.ServicerContext, - ) -> Response: - node_public_key_bytes = base64.urlsafe_b64decode( - _get_value_from_tuples( - _PUBLIC_KEY_HEADER, context.invocation_metadata() - ) - ) - if node_public_key_bytes not in self.node_public_keys: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - if isinstance(request, CreateNodeRequest): - response = self._create_authenticated_node( - node_public_key_bytes, request, context - ) - log( - INFO, - "AuthenticateServerInterceptor: Created node_id=%s", - response.node.node_id, - ) - return response - - # Verify hmac value - hmac_value = base64.urlsafe_b64decode( - _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - ) - ) - public_key = bytes_to_public_key(node_public_key_bytes) - - if not self._verify_hmac(public_key, request, hmac_value): - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - # Verify node_id - node_id = self.state_factory.state().get_node_id(node_public_key_bytes) - - if not self._verify_node_id(node_id, request): - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - return method_handler.unary_unary(request, context) # type: ignore + ) -> GrpcMessage: + # Verify the node ID + if not isinstance(request, CreateNodeRequest): + try: + if request.node.node_id != expected_node_id: # type: ignore + raise ValueError + except (AttributeError, ValueError): + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID") + + response: GrpcMessage = method_handler.unary_unary(request, context) + + # Set the public key after a successful CreateNode request + if isinstance(response, CreateNodeResponse): + state = self.state_factory.state() + try: + state.set_node_public_key(response.node.node_id, node_public_key) + except ValueError as e: + context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e)) + + return response return grpc.unary_unary_rpc_method_handler( _generic_method_handler, request_deserializer=method_handler.request_deserializer, response_serializer=method_handler.response_serializer, ) - - def _verify_node_id( - self, - node_id: Optional[int], - request: Union[ - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, - ], - ) -> bool: - if node_id is None: - return False - if isinstance(request, PushTaskResRequest): - if len(request.task_res_list) == 0: - return False - return request.task_res_list[0].task.producer.node_id == node_id - if isinstance(request, GetRunRequest): - return node_id in self.state_factory.state().get_nodes(request.run_id) - return request.node.node_id == node_id - - def _verify_hmac( - self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes - ) -> bool: - shared_secret = generate_shared_key(self.server_private_key, public_key) - message_bytes = request.SerializeToString(deterministic=True) - return verify_hmac(shared_secret, message_bytes, hmac_value) - - def _create_authenticated_node( - self, - public_key_bytes: bytes, - request: CreateNodeRequest, - context: grpc.ServicerContext, - ) -> CreateNodeResponse: - context.send_initial_metadata( - ( - ( - _PUBLIC_KEY_HEADER, - self.encoded_server_public_key, - ), - ) - ) - state = self.state_factory.state() - node_id = state.get_node_id(public_key_bytes) - - # Handle `CreateNode` here instead of calling the default method handler - # Return previously assigned `node_id` for the provided `public_key` - if node_id is not None: - state.acknowledge_ping(node_id, request.ping_interval) - return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) - - # No `node_id` exists for the provided `public_key` - # Handle `CreateNode` here instead of calling the default method handler - # Note: the innermost `CreateNode` method will never be called - node_id = state.create_node(request.ping_interval, public_key_bytes) - return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index a0ff7a77304a..9984b93f3e84 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -161,9 +161,7 @@ def test_unsuccessful_create_node_with_metadata(self) -> None: def test_successful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = DeleteNodeRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -191,9 +189,7 @@ def test_successful_delete_node_with_metadata(self) -> None: def test_unsuccessful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = DeleteNodeRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -217,9 +213,7 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None: def test_successful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for pull task ins.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PullTaskInsRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -247,9 +241,7 @@ def test_successful_pull_task_ins_with_metadata(self) -> None: def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for pull task ins unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PullTaskInsRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -273,9 +265,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: def test_successful_push_task_res_with_metadata(self) -> None: """Test server interceptor for push task res.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -311,9 +301,7 @@ def test_successful_push_task_res_with_metadata(self) -> None: def test_unsuccessful_push_task_res_with_metadata(self) -> None: """Test server interceptor for push task res unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -344,9 +332,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: def test_successful_get_run_with_metadata(self) -> None: """Test server interceptor for get run.""" # Prepare - self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. GetRun is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -378,9 +364,7 @@ def test_successful_get_run_with_metadata(self) -> None: def test_unsuccessful_get_run_with_metadata(self) -> None: """Test server interceptor for get run unsuccessfully.""" # Prepare - self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) request = GetRunRequest(run_id=run_id) node_private_key, _ = generate_key_pairs() @@ -405,9 +389,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: def test_successful_ping_with_metadata(self) -> None: """Test server interceptor for ping.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PingRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -435,9 +417,7 @@ def test_successful_ping_with_metadata(self) -> None: def test_unsuccessful_ping_with_metadata(self) -> None: """Test server interceptor for ping unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PingRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -458,65 +438,8 @@ def test_unsuccessful_ping_with_metadata(self) -> None: ), ) - def test_successful_restore_node(self) -> None: - """Test server interceptor for restoring node.""" - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - node = response.node - node_node_id = node.node_id - - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - - request = DeleteNodeRequest(node=node) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - assert isinstance(response, DeleteNodeResponse) - assert grpc.StatusCode.OK == call.code() - - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - assert response.node.node_id == node_node_id + def _create_node_and_set_public_key(self) -> int: + node_id = self.state.create_node(ping_interval=30) + pk_bytes = public_key_to_bytes(self._node_public_key) + self.state.set_node_public_key(node_id, pk_bytes) + return node_id diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 19024c0f1948..ae26f0992e6b 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -62,6 +62,7 @@ def __init__(self) -> None: # Map node_id to (online_until, ping_interval) self.node_ids: dict[int, tuple[float, float]] = {} self.public_key_to_node_id: dict[bytes, int] = {} + self.node_id_to_public_key: dict[int, bytes] = {} # Map run_id to RunRecord self.run_ids: dict[int, RunRecord] = {} @@ -306,9 +307,7 @@ def num_task_res(self) -> int: """ return len(self.task_res_store) - def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random int64 as node_id node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -318,33 +317,18 @@ def create_node( log(ERROR, "Unexpected node registration failure.") return 0 - if public_key is not None: - if ( - public_key in self.public_key_to_node_id - or node_id in self.public_key_to_node_id.values() - ): - log(ERROR, "Unexpected node registration failure.") - return 0 - - self.public_key_to_node_id[public_key] = node_id - self.node_ids[node_id] = (time.time() + ping_interval, ping_interval) return node_id - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Delete a node.""" with self.lock: if node_id not in self.node_ids: raise ValueError(f"Node {node_id} not found") - if public_key is not None: - if ( - public_key not in self.public_key_to_node_id - or node_id not in self.public_key_to_node_id.values() - ): - raise ValueError("Public key or node_id not found") - - del self.public_key_to_node_id[public_key] + # Remove node ID <> public key mappings + if pk := self.node_id_to_public_key.pop(node_id, None): + del self.public_key_to_node_id[pk] del self.node_ids[node_id] @@ -366,6 +350,26 @@ def get_nodes(self, run_id: int) -> set[int]: if online_until > current_time } + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Set `public_key` for the specified `node_id`.""" + with self.lock: + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + + if public_key in self.public_key_to_node_id: + raise ValueError("Public key already in use") + + self.public_key_to_node_id[public_key] = node_id + self.node_id_to_public_key[node_id] = public_key + + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" + with self.lock: + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + + return self.node_id_to_public_key.get(node_id) + def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" return self.public_key_to_node_id.get(node_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 4f3c16a5460a..e1eccf2b8b2f 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -154,13 +154,11 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: """Get all TaskIns IDs for the given run_id.""" @abc.abstractmethod - def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" @abc.abstractmethod - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Remove `node_id` from the link state.""" @abc.abstractmethod @@ -173,6 +171,14 @@ def get_nodes(self, run_id: int) -> set[int]: an empty `Set` MUST be returned. """ + @abc.abstractmethod + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Set `public_key` for the specified `node_id`.""" + + @abc.abstractmethod + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" + @abc.abstractmethod def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 3edaf72ec20c..fd1051e1cbfc 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -588,7 +588,8 @@ def test_create_node_public_key(self) -> None: run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -602,15 +603,21 @@ def test_create_node_public_key_twice(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) # Execute - new_node_id = state.create_node(ping_interval=10, public_key=public_key) + new_node_id = state.create_node(ping_interval=10) + try: + state.set_node_public_key(new_node_id, public_key) + except ValueError: + state.delete_node(new_node_id) + else: + raise AssertionError("Should have raised ValueError") retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) # Assert - assert new_node_id == 0 assert len(retrieved_node_ids) == 1 assert retrieved_node_id == node_id @@ -639,10 +646,11 @@ def test_delete_node_public_key(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) # Execute - state.delete_node(node_id, public_key=public_key) + state.delete_node(node_id) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -650,43 +658,6 @@ def test_delete_node_public_key(self) -> None: assert len(retrieved_node_ids) == 0 assert retrieved_node_id is None - def test_delete_node_public_key_none(self) -> None: - """Test deleting a client node with public key.""" - # Prepare - state: LinkState = self.state_factory() - public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = 0 - - # Execute & Assert - with self.assertRaises(ValueError): - state.delete_node(node_id, public_key=public_key) - - retrieved_node_ids = state.get_nodes(run_id) - retrieved_node_id = state.get_node_id(public_key) - - assert len(retrieved_node_ids) == 0 - assert retrieved_node_id is None - - def test_delete_node_wrong_public_key(self) -> None: - """Test deleting a client node with wrong public key.""" - # Prepare - state: LinkState = self.state_factory() - public_key = b"mock" - wrong_public_key = b"mock_mock" - run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) - - # Execute & Assert - with self.assertRaises(ValueError): - state.delete_node(node_id, public_key=wrong_public_key) - - retrieved_node_ids = state.get_nodes(run_id) - retrieved_node_id = state.get_node_id(public_key) - - assert len(retrieved_node_ids) == 1 - assert retrieved_node_id == node_id - def test_get_node_id_wrong_public_key(self) -> None: """Test retrieving a client node with wrong public key.""" # Prepare @@ -696,7 +667,8 @@ def test_get_node_id_wrong_public_key(self) -> None: run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute - state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(wrong_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index e8311dfaac5e..cc773f7b93de 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -72,14 +72,14 @@ SQL_CREATE_TABLE_CREDENTIAL = """ CREATE TABLE IF NOT EXISTS credential( - private_key BLOB PRIMARY KEY, - public_key BLOB + private_key BLOB PRIMARY KEY, + public_key BLOB ); """ SQL_CREATE_TABLE_PUBLIC_KEY = """ CREATE TABLE IF NOT EXISTS public_key( - public_key BLOB UNIQUE + public_key BLOB PRIMARY KEY ); """ @@ -635,9 +635,7 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: return {UUID(row["task_id"]) for row in rows} - def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random uint64 as node_id uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -645,13 +643,6 @@ def create_node( # Convert the uint64 value to sint64 for SQLite sint64_node_id = convert_uint64_to_sint64(uint64_node_id) - query = "SELECT node_id FROM node WHERE public_key = :public_key;" - row = self.query(query, {"public_key": public_key}) - - if len(row) > 0: - log(ERROR, "Unexpected node registration failure.") - return 0 - query = ( "INSERT INTO node " "(node_id, online_until, ping_interval, public_key) " @@ -665,7 +656,7 @@ def create_node( sint64_node_id, time.time() + ping_interval, ping_interval, - public_key, + b"", # Initialize with an empty public key ), ) except sqlite3.IntegrityError: @@ -675,7 +666,7 @@ def create_node( # Note: we need to return the uint64 value of the node_id return uint64_node_id - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Delete a node.""" # Convert the uint64 value to sint64 for SQLite sint64_node_id = convert_uint64_to_sint64(node_id) @@ -683,10 +674,6 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: query = "DELETE FROM node WHERE node_id = ?" params = (sint64_node_id,) - if public_key is not None: - query += " AND public_key = ?" - params += (public_key,) # type: ignore - if self.conn is None: raise AttributeError("LinkState is not initialized.") @@ -694,7 +681,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: with self.conn: rows = self.conn.execute(query, params) if rows.rowcount < 1: - raise ValueError("Public key or node_id not found") + raise ValueError(f"Node {node_id} not found") except KeyError as exc: log(ERROR, {"query": query, "data": params, "exception": exc}) @@ -722,6 +709,41 @@ def get_nodes(self, run_id: int) -> set[int]: result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows} return result + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Set `public_key` for the specified `node_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_node_id = convert_uint64_to_sint64(node_id) + + # Check if the node exists in the `node` table + query = "SELECT 1 FROM node WHERE node_id = ?" + if not self.query(query, (sint64_node_id,)): + raise ValueError(f"Node {node_id} not found") + + # Check if the public key is already in use in the `node` table + query = "SELECT 1 FROM node WHERE public_key = ?" + if self.query(query, (public_key,)): + raise ValueError("Public key already in use") + + # Update the `node` table to set the public key for the given node ID + query = "UPDATE node SET public_key = ? WHERE node_id = ?" + self.query(query, (public_key, sint64_node_id)) + + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_node_id = convert_uint64_to_sint64(node_id) + + # Query the public key for the given node_id + query = "SELECT public_key FROM node WHERE node_id = ?" + rows = self.query(query, (sint64_node_id,)) + + # If no result is found, return None + if not rows: + raise ValueError(f"Node {node_id} not found") + + # Return the public key if it is not empty, otherwise return None + return rows[0]["public_key"] or None + def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" query = "SELECT node_id FROM node WHERE public_key = :public_key;"