From 2ad33700fec6694ea0f995b0c6c90d3747678940 Mon Sep 17 00:00:00 2001 From: miro Date: Tue, 7 Jan 2025 12:21:41 +0000 Subject: [PATCH] improve thread safety --- hivemind_http_protocol/__init__.py | 99 +++++++++++++++++------------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/hivemind_http_protocol/__init__.py b/hivemind_http_protocol/__init__.py index a151c63..b01175a 100644 --- a/hivemind_http_protocol/__init__.py +++ b/hivemind_http_protocol/__init__.py @@ -3,12 +3,13 @@ import os import os.path import random -import threading +from collections import defaultdict +from queue import Queue from os import makedirs from os.path import exists, join from socket import gethostname -from typing import Dict, Any, Optional, Tuple, List -from collections import defaultdict +from typing import Dict, Any, Optional, Tuple, Union + import pybase64 from OpenSSL import crypto from ovos_bus_client.session import Session @@ -28,11 +29,6 @@ from hivemind_plugin_manager.protocols import NetworkProtocol from poorman_handshake import PasswordHandShake -_LOCK = threading.RLock() -CLIENTS: Dict[str, HiveMindClientConnection] = {} -UNDELIVERED: Dict[str, List[str]] = defaultdict(list) # key: [messages] -UNDELIVERED_BIN: Dict[str, List[str]] = defaultdict(list) # key: [b64_messages] - @dataclasses.dataclass class HiveMindHttpProtocol(NetworkProtocol): @@ -46,7 +42,6 @@ class HiveMindHttpProtocol(NetworkProtocol): hm_protocol: Optional[HiveMindListenerProtocol] = None callbacks: ClientCallbacks = dataclasses.field(default_factory=ClientCallbacks) - def run(self): LOG.debug(f"HTTP server config: {self.config}") asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) @@ -70,10 +65,10 @@ def run(self): cert_file = f"{cert_dir}/{cert_name}.crt" key_file = f"{cert_dir}/{cert_name}.key" if not os.path.isfile(key_file): - LOG.info(f"generating self-signed SSL certificate") + LOG.info(f"Generating self-signed SSL certificate") cert_file, key_file = self.create_self_signed_cert(cert_dir, cert_name) - LOG.debug("using ssl key at " + key_file) - LOG.debug("using ssl certificate at " + cert_file) + LOG.debug("Using SSL key at " + key_file) + LOG.debug("Using SSL certificate at " + cert_file) ssl_options = {"certfile": cert_file, "keyfile": key_file} LOG.info(f"HTTPS listener started at port: {port}") application.listen(port, host, ssl_options=ssl_options) @@ -105,7 +100,7 @@ def create_self_signed_cert( makedirs(cert_dir, exist_ok=True) if not exists(join(cert_dir, cert_file)) or not exists(join(cert_dir, key_file)): - # create a key pair + # Create a key pair k = crypto.PKey() k.generate_key(crypto.TYPE_RSA, 2048) @@ -135,6 +130,11 @@ class HiveMindHttpHandler(web.RequestHandler): """Base handler for HTTP requests.""" hm_protocol = None + # Class-level properties for managing client state and message queues + clients: Dict[str, HiveMindClientConnection] = {} + undelivered: Dict[str, Queue] = defaultdict(Queue) # Non-binary messages + undelivered_bin: Dict[str, Queue] = defaultdict(Queue) # Binary messages + def decode_auth(self): auth = self.get_argument("authorization", "") if not auth: @@ -144,26 +144,22 @@ def decode_auth(self): userpass_decoded = pybase64.b64decode(userpass_encoded).decode("utf-8") return userpass_decoded.split(":") - def get_client(self, useragent, key, cache = True) -> Optional[HiveMindClientConnection]: - global CLIENTS, UNDELIVERED - - if cache and key in CLIENTS: - return CLIENTS[key] + def get_client(self, useragent, key, cache=True) -> Optional[HiveMindClientConnection]: + if cache and key in self.clients: + return self.clients[key] - def do_send(payload: str, is_bin: bool): - with _LOCK: - if is_bin: - payload = pybase64.b64encode(payload).decode("utf-8") - UNDELIVERED_BIN[key].append(payload) - else: - UNDELIVERED[key].append(payload) + def do_send(payload: Union[bytes, str], is_bin: bool): + if is_bin: + payload = pybase64.b64encode(payload).decode("utf-8") + self.undelivered_bin[key].put(payload) + else: + self.undelivered[key].put(payload) def do_disconnect(): - with _LOCK: - if key in UNDELIVERED: - UNDELIVERED.pop(key) - if key in CLIENTS: - CLIENTS.pop(key) + if key in self.undelivered: + self.undelivered.pop(key) + if key in self.clients: + self.clients.pop(key) client = HiveMindClientConnection( key=key, @@ -176,7 +172,7 @@ def do_disconnect(): self.hm_protocol.db.sync() user = self.hm_protocol.db.get_client_by_api_key(key) if not user: - LOG.error("Client provided an invalid api key") + LOG.error("Client provided an invalid Access key") self.hm_protocol.handle_invalid_key_connected(client) return None @@ -195,7 +191,7 @@ def do_disconnect(): client.node_type = HiveMindNodeType.NODE # TODO . placeholder if cache: - CLIENTS[key] = client + self.clients[key] = client return client @@ -232,18 +228,16 @@ async def post(self): class DisconnectHandler(HiveMindHttpHandler): async def post(self): - global CLIENTS try: useragent, key = self.decode_auth() if not key: self.write({"error": "Missing authorization"}) return - if key in CLIENTS: + if key in HiveMindHttpHandler.clients: client = self.get_client(useragent, key) LOG.info(f"disconnecting client: {client.peer}") self.hm_protocol.handle_client_disconnected(client) - CLIENTS.pop(key) self.write({"status": "Disconnected"}) else: self.write({"error": "Already Disconnected"}) @@ -261,7 +255,7 @@ async def post(self): self.write({"error": "Missing authorization"}) return # refuse if connect wasnt called first - if key not in CLIENTS: + if key not in HiveMindHttpHandler.clients: self.write({"error": "Client is not connected"}) return @@ -300,13 +294,21 @@ async def get(self): return # refuse if connect wasnt called first - if key not in CLIENTS: + if key not in HiveMindHttpHandler.clients: self.write({"error": "Client is not connected"}) return - # send non-binary payloads to the client - messages = UNDELIVERED[key] - UNDELIVERED[key] = [] + messages = [] + queue = HiveMindHttpHandler.undelivered[key] + + # Retrieve all messages from the queue + while not queue.empty(): + try: + message = queue.get_nowait() + messages.append(message) + except Exception as e: + # Handle unexpected errors (unlikely with get_nowait) + break self.write({"status": "messages retrieved", "messages": messages}) except Exception as e: LOG.error(f"Retrieving messages failed: {e}") @@ -324,13 +326,22 @@ async def get(self): return # refuse if connect wasnt called first - if key not in CLIENTS: + if key not in HiveMindHttpHandler.clients: self.write({"error": "Client is not connected"}) return - # send non-binary payloads to the client - messages = UNDELIVERED_BIN[key] - UNDELIVERED_BIN[key] = [] + messages = [] + queue = HiveMindHttpHandler.undelivered_bin[key] + + # Retrieve all messages from the queue + while not queue.empty(): + try: + message = queue.get_nowait() + messages.append(message) + except Exception as e: + # Handle unexpected errors (unlikely with get_nowait) + break + self.write({"status": "messages retrieved", "b64_messages": messages}) except Exception as e: LOG.error(f"Retrieving messages failed: {e}")