Skip to content

Commit

Permalink
improve thread safety
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Jan 7, 2025
1 parent ce902ba commit 2ad3370
Showing 1 changed file with 55 additions and 44 deletions.
99 changes: 55 additions & 44 deletions hivemind_http_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import os
import os.path
import random
import threading
from collections import defaultdict
from queue import Queue
from os import makedirs
from os.path import exists, join
from socket import gethostname
from typing import Dict, Any, Optional, Tuple, List
from collections import defaultdict
from typing import Dict, Any, Optional, Tuple, Union

import pybase64
from OpenSSL import crypto
from ovos_bus_client.session import Session
Expand All @@ -28,11 +29,6 @@
from hivemind_plugin_manager.protocols import NetworkProtocol
from poorman_handshake import PasswordHandShake

_LOCK = threading.RLock()
CLIENTS: Dict[str, HiveMindClientConnection] = {}
UNDELIVERED: Dict[str, List[str]] = defaultdict(list) # key: [messages]
UNDELIVERED_BIN: Dict[str, List[str]] = defaultdict(list) # key: [b64_messages]


@dataclasses.dataclass
class HiveMindHttpProtocol(NetworkProtocol):
Expand All @@ -46,7 +42,6 @@ class HiveMindHttpProtocol(NetworkProtocol):
hm_protocol: Optional[HiveMindListenerProtocol] = None
callbacks: ClientCallbacks = dataclasses.field(default_factory=ClientCallbacks)


def run(self):
LOG.debug(f"HTTP server config: {self.config}")
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
Expand All @@ -70,10 +65,10 @@ def run(self):
cert_file = f"{cert_dir}/{cert_name}.crt"
key_file = f"{cert_dir}/{cert_name}.key"
if not os.path.isfile(key_file):
LOG.info(f"generating self-signed SSL certificate")
LOG.info(f"Generating self-signed SSL certificate")
cert_file, key_file = self.create_self_signed_cert(cert_dir, cert_name)
LOG.debug("using ssl key at " + key_file)
LOG.debug("using ssl certificate at " + cert_file)
LOG.debug("Using SSL key at " + key_file)
LOG.debug("Using SSL certificate at " + cert_file)
ssl_options = {"certfile": cert_file, "keyfile": key_file}
LOG.info(f"HTTPS listener started at port: {port}")
application.listen(port, host, ssl_options=ssl_options)
Expand Down Expand Up @@ -105,7 +100,7 @@ def create_self_signed_cert(
makedirs(cert_dir, exist_ok=True)

if not exists(join(cert_dir, cert_file)) or not exists(join(cert_dir, key_file)):
# create a key pair
# Create a key pair
k = crypto.PKey()
k.generate_key(crypto.TYPE_RSA, 2048)

Expand Down Expand Up @@ -135,6 +130,11 @@ class HiveMindHttpHandler(web.RequestHandler):
"""Base handler for HTTP requests."""
hm_protocol = None

# Class-level properties for managing client state and message queues
clients: Dict[str, HiveMindClientConnection] = {}
undelivered: Dict[str, Queue] = defaultdict(Queue) # Non-binary messages
undelivered_bin: Dict[str, Queue] = defaultdict(Queue) # Binary messages

def decode_auth(self):
auth = self.get_argument("authorization", "")
if not auth:
Expand All @@ -144,26 +144,22 @@ def decode_auth(self):
userpass_decoded = pybase64.b64decode(userpass_encoded).decode("utf-8")
return userpass_decoded.split(":")

def get_client(self, useragent, key, cache = True) -> Optional[HiveMindClientConnection]:
global CLIENTS, UNDELIVERED

if cache and key in CLIENTS:
return CLIENTS[key]
def get_client(self, useragent, key, cache=True) -> Optional[HiveMindClientConnection]:
if cache and key in self.clients:
return self.clients[key]

def do_send(payload: str, is_bin: bool):
with _LOCK:
if is_bin:
payload = pybase64.b64encode(payload).decode("utf-8")
UNDELIVERED_BIN[key].append(payload)
else:
UNDELIVERED[key].append(payload)
def do_send(payload: Union[bytes, str], is_bin: bool):
if is_bin:
payload = pybase64.b64encode(payload).decode("utf-8")
self.undelivered_bin[key].put(payload)
else:
self.undelivered[key].put(payload)

def do_disconnect():
with _LOCK:
if key in UNDELIVERED:
UNDELIVERED.pop(key)
if key in CLIENTS:
CLIENTS.pop(key)
if key in self.undelivered:
self.undelivered.pop(key)
if key in self.clients:
self.clients.pop(key)

client = HiveMindClientConnection(
key=key,
Expand All @@ -176,7 +172,7 @@ def do_disconnect():
self.hm_protocol.db.sync()
user = self.hm_protocol.db.get_client_by_api_key(key)
if not user:
LOG.error("Client provided an invalid api key")
LOG.error("Client provided an invalid Access key")
self.hm_protocol.handle_invalid_key_connected(client)
return None

Expand All @@ -195,7 +191,7 @@ def do_disconnect():

client.node_type = HiveMindNodeType.NODE # TODO . placeholder
if cache:
CLIENTS[key] = client
self.clients[key] = client
return client


Expand Down Expand Up @@ -232,18 +228,16 @@ async def post(self):

class DisconnectHandler(HiveMindHttpHandler):
async def post(self):
global CLIENTS

try:
useragent, key = self.decode_auth()
if not key:
self.write({"error": "Missing authorization"})
return
if key in CLIENTS:
if key in HiveMindHttpHandler.clients:
client = self.get_client(useragent, key)
LOG.info(f"disconnecting client: {client.peer}")
self.hm_protocol.handle_client_disconnected(client)
CLIENTS.pop(key)
self.write({"status": "Disconnected"})
else:
self.write({"error": "Already Disconnected"})
Expand All @@ -261,7 +255,7 @@ async def post(self):
self.write({"error": "Missing authorization"})
return
# refuse if connect wasnt called first
if key not in CLIENTS:
if key not in HiveMindHttpHandler.clients:
self.write({"error": "Client is not connected"})
return

Expand Down Expand Up @@ -300,13 +294,21 @@ async def get(self):
return

# refuse if connect wasnt called first
if key not in CLIENTS:
if key not in HiveMindHttpHandler.clients:
self.write({"error": "Client is not connected"})
return

# send non-binary payloads to the client
messages = UNDELIVERED[key]
UNDELIVERED[key] = []
messages = []
queue = HiveMindHttpHandler.undelivered[key]

# Retrieve all messages from the queue
while not queue.empty():
try:
message = queue.get_nowait()
messages.append(message)
except Exception as e:
# Handle unexpected errors (unlikely with get_nowait)
break
self.write({"status": "messages retrieved", "messages": messages})
except Exception as e:
LOG.error(f"Retrieving messages failed: {e}")
Expand All @@ -324,13 +326,22 @@ async def get(self):
return

# refuse if connect wasnt called first
if key not in CLIENTS:
if key not in HiveMindHttpHandler.clients:
self.write({"error": "Client is not connected"})
return

# send non-binary payloads to the client
messages = UNDELIVERED_BIN[key]
UNDELIVERED_BIN[key] = []
messages = []
queue = HiveMindHttpHandler.undelivered_bin[key]

# Retrieve all messages from the queue
while not queue.empty():
try:
message = queue.get_nowait()
messages.append(message)
except Exception as e:
# Handle unexpected errors (unlikely with get_nowait)
break

self.write({"status": "messages retrieved", "b64_messages": messages})
except Exception as e:
LOG.error(f"Retrieving messages failed: {e}")
Expand Down

0 comments on commit 2ad3370

Please sign in to comment.