Skip to content
This repository has been archived by the owner on Dec 22, 2024. It is now read-only.

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Dec 20, 2024
1 parent b5331f4 commit d5ffc90
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
5 changes: 3 additions & 2 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,9 @@ def _update_blacklist(self, message: Message, client: HiveMindClientConnection):

# update blacklist from db, to account for changes without requiring a restart
user = self.db.get_client_by_api_key(client.key)
client.skill_blacklist = user.skill_blacklist
client.intent_blacklist = user.intent_blacklist
client.skill_blacklist = user.skill_blacklist or []
client.intent_blacklist = user.intent_blacklist or []
client.msg_blacklist = user.message_blacklist or []

# inject client specific blacklist into session
if "blacklisted_skills" not in message.context["session"]:
Expand Down
15 changes: 7 additions & 8 deletions hivemind_core/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def on_stopping():

class MessageBusEventHandler(WebSocketHandler):
protocol: Optional[HiveMindListenerProtocol] = None
db: Optional[ClientDatabase] = None

@staticmethod
def decode_auth(auth) -> Tuple[str, str]:
Expand Down Expand Up @@ -135,24 +134,25 @@ def open(self):
handshake=handshake,
loop=self.protocol.loop,
)
if self.db is None:
if self.protocol.db is None:
# should never happen, but double check!
LOG.error("Database connection not initialized. Please ensure database configuration is correct.")
LOG.exception(f"Client {self.client.peer} connection attempt failed due to missing database connection")
self.close()
raise RuntimeError("Database was not initialized!") # let it propagate, this is developer error most likely

user = self.db.get_client_by_api_key(key)
user = self.protocol.db.get_client_by_api_key(key)

if not user:
LOG.error("Client provided an invalid api key")
self.protocol.handle_invalid_key_connected(self.client)
self.close()
return

self.client.crypto_key = user.crypto_key
self.client.msg_blacklist = user.message_blacklist
self.client.skill_blacklist = user.skill_blacklist
self.client.intent_blacklist = user.intent_blacklist
self.client.msg_blacklist = user.message_blacklist or []
self.client.skill_blacklist = user.skill_blacklist or []
self.client.intent_blacklist = user.intent_blacklist or []
self.client.allowed_types = user.allowed_types
self.client.can_broadcast = user.can_broadcast
self.client.can_propagate = user.can_propagate
Expand Down Expand Up @@ -215,10 +215,9 @@ def __init__(
on_error=error_hook,
on_stopping=stopping_hook,
)
self.db = db
self.db = db or ClientDatabase()
self._proto = protocol
self._ws_handler = ws_handler
self._ws_handler.db = db
if bus:
self.bus = bus
else:
Expand Down

0 comments on commit d5ffc90

Please sign in to comment.