From 5e93efe9b217351d3aabd5fda864eae6f6a1d7f9 Mon Sep 17 00:00:00 2001 From: miro Date: Sun, 30 Jun 2024 22:18:36 +0100 Subject: [PATCH] feat!:configurable_database_backend allow several database backends Update test/unittests/test_db.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- README.md | 2 + hivemind_core/database.py | 670 +++++++++++++++++++++++++++----------- hivemind_core/protocol.py | 11 +- hivemind_core/scripts.py | 381 ++++++++++++++++------ hivemind_core/service.py | 80 +++-- setup.py | 2 +- test/unittests/test_db.py | 219 ++++++++++--- 7 files changed, 997 insertions(+), 368 deletions(-) diff --git a/README.md b/README.md index 0b18dfc..cf493da 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ Work in progress documentation can be found in the [docs](https://jarbashivemind You can also join the [Hivemind Matrix chat](https://matrix.to/#/#jarbashivemind:matrix.org) for general news, support and chit chat + + # Usage ``` diff --git a/hivemind_core/database.py b/hivemind_core/database.py index 02e1ce0..2b283dd 100644 --- a/hivemind_core/database.py +++ b/hivemind_core/database.py @@ -1,254 +1,546 @@ +import abc import json -from functools import wraps +from dataclasses import dataclass, field from typing import List, Dict, Union, Any, Optional, Iterable -from json_database import JsonDatabaseXDG +from json_database import JsonStorageXDG from ovos_utils.log import LOG +try: + import redis +except ImportError: + redis = None -def cast_to_client_obj(): - valid_kwargs: Iterable[str] = ( - "client_id", - "api_key", - "name", - "description", - "is_admin", - "last_seen", - "blacklist", - "allowed_types", - "crypto_key", - "password", - "can_broadcast", - "can_escalate", - "can_propagate", - ) - - def _handler(func): - def _cast(ret): - if ret is None or isinstance(ret, Client): - return ret - if isinstance(ret, list): - return [_cast(r) for r in ret] - if isinstance(ret, dict): - if not all((k in valid_kwargs for k in ret.keys())): - raise RuntimeError(f"{func} returned a dict with unknown keys") - return Client(**ret) - - raise TypeError( - "cast_to_client_obj decorator can only be used in functions that return None, dict, Client or a list of those types" - ) - - @wraps(func) - def call_function(*args, **kwargs): - ret = func(*args, **kwargs) - return _cast(ret) - - return call_function - - return _handler +ClientDict = Dict[str, Union[str, int, float, List[str]]] +ClientTypes = Union[None, 'Client', + str, # json + ClientDict, # dict + List[Union[str, ClientDict, 'Client']] # list of dicts/json/Client + ] + +def cast2client(ret: ClientTypes) -> Optional[Union['Client', List['Client']]]: + """ + Convert different input types (str, dict, list) to Client instances. + + Args: + ret: The object to be cast, can be a string, dictionary, or list. + + Returns: + A single Client instance or a list of Clients if ret is a list. + """ + if ret is None or isinstance(ret, Client): + return ret + if isinstance(ret, str) or isinstance(ret, dict): + return Client.deserialize(ret) + if isinstance(ret, list): + return [cast2client(r) for r in ret] + raise TypeError("not a client object") + + +@dataclass class Client: - def __init__( - self, - client_id: int, - api_key: str, - name: str = "", - description: str = "", - is_admin: bool = False, - last_seen: float = -1, - blacklist: Optional[Dict[str, List[str]]] = None, - allowed_types: Optional[List[str]] = None, - crypto_key: Optional[str] = None, - password: Optional[str] = None, - can_broadcast: bool = True, - can_escalate: bool = True, - can_propagate: bool = True, - ): - self.client_id = client_id - self.description = description - self.api_key = api_key - self.name = name - self.last_seen = last_seen - self.is_admin = is_admin - self.crypto_key = crypto_key - self.password = password - self.blacklist = blacklist or {"messages": [], "skills": [], "intents": []} - self.allowed_types = allowed_types or ["recognizer_loop:utterance", - "recognizer_loop:record_begin", - "recognizer_loop:record_end", - "recognizer_loop:audio_output_start", - "recognizer_loop:audio_output_end", - 'recognizer_loop:b64_transcribe', - 'speak:b64_audio', - "ovos.common_play.SEI.get.response"] + client_id: int + api_key: str + name: str = "" + description: str = "" + is_admin: bool = False + last_seen: float = -1 + intent_blacklist: List[str] = field(default_factory=list) + skill_blacklist: List[str] = field(default_factory=list) + message_blacklist: List[str] = field(default_factory=list) + allowed_types: List[str] = field(default_factory=list) + crypto_key: Optional[str] = None + password: Optional[str] = None + can_broadcast: bool = True + can_escalate: bool = True + can_propagate: bool = True + + def __post_init__(self): + """ + Initializes the allowed types for the Client instance if not provided. + """ + self.allowed_types = self.allowed_types or ["recognizer_loop:utterance", + "recognizer_loop:record_begin", + "recognizer_loop:record_end", + "recognizer_loop:audio_output_start", + "recognizer_loop:audio_output_end", + 'recognizer_loop:b64_transcribe', + 'speak:b64_audio', + "ovos.common_play.SEI.get.response"] if "recognizer_loop:utterance" not in self.allowed_types: self.allowed_types.append("recognizer_loop:utterance") - self.can_broadcast = can_broadcast - self.can_escalate = can_escalate - self.can_propagate = can_propagate + + def serialize(self) -> str: + """ + Serializes the Client instance into a JSON string. + + Returns: + A JSON string representing the client data. + """ + return json.dumps(self.__dict__, sort_keys=True, ensure_ascii=False) + + @staticmethod + def deserialize(client_data: Union[str, Dict]) -> 'Client': + """ + Deserialize a client from JSON string or dictionary into a Client instance. + + Args: + client_data: The data to be deserialized, either a string or dictionary. + + Returns: + A Client instance. + """ + if isinstance(client_data, str): + client_data = json.loads(client_data) + # TODO filter kwargs with inspect + return Client(**client_data) def __getitem__(self, item: str) -> Any: - return self.__dict__.get(item) + """ + Access attributes of the client via item access. + + Args: + item: The name of the attribute. + + Returns: + The value of the attribute. + + Raises: + KeyError: If the attribute does not exist. + """ + if hasattr(self, item): + return getattr(self, item) + raise KeyError(f"Unknown key: {item}") def __setitem__(self, key: str, value: Any): + """ + Set attributes of the client via item access. + + Args: + key: The name of the attribute. + value: The value to set. + + Raises: + ValueError: If the attribute does not exist. + """ if hasattr(self, key): setattr(self, key, value) else: - raise ValueError("unknown property") + raise ValueError(f"Unknown property: {key}") - def __eq__(self, other: Union[object, dict]) -> bool: - if not isinstance(other, dict): - other = other.__dict__ - if self.__dict__ == other: - return True + def __eq__(self, other: Any) -> bool: + """ + Compares two Client instances for equality based on their serialized data. + + Args: + other: The other Client or Client-compatible object to compare with. + + Returns: + True if the clients are equal, False otherwise. + """ + try: + other = cast2client(other) + except: + pass + if isinstance(other, Client): + return self.serialize() == other.serialize() return False def __repr__(self) -> str: - return str(self.__dict__) + """ + Returns a string representation of the Client instance. + Returns: + A string representing the client. + """ + return self.serialize() -class ClientDatabase(JsonDatabaseXDG): - def __init__(self): - super().__init__("clients", subfolder="hivemind") - def update_timestamp(self, key: str, timestamp: float) -> bool: - user = self.get_client_by_api_key(key) - if user is None: - return False - item_id = self.get_item_id(user) - user["last_seen"] = timestamp - self.update_item(item_id, user) +class AbstractDB(abc.ABC): + """ + Abstract base class for all database implementations. + + All database implementations should derive from this class and implement + the abstract methods. + """ + + @abc.abstractmethod + def add_item(self, client: Client) -> bool: + """ + Add a client to the database. + + Args: + client: The client to be added. + + Returns: + True if the addition was successful, False otherwise. + """ + pass + + @abc.abstractmethod + def delete_item(self, client: Client) -> bool: + """ + Delete a client from the database. + + Args: + client: The client to be deleted. + + Returns: + True if the deletion was successful, False otherwise. + """ + pass + + def update_item(self, client: Client) -> bool: + """ + Update an existing client in the database. + + Args: + client: The client to be updated. + + Returns: + True if the update was successful, False otherwise. + """ + return self.add_item(client) + + def replace_item(self, old_client: Client, new_client: Client) -> bool: + """ + Replace an old client with a new client. + + Args: + old_client: The old client to be replaced. + new_client: The new client to add. + + Returns: + True if the replacement was successful, False otherwise. + """ + self.delete_item(old_client) + return self.add_item(new_client) + + @abc.abstractmethod + def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]: + """ + Search for clients by a specific key-value pair. + + Args: + key: The key to search by. + val: The value to search for. + + Returns: + A list of clients that match the search criteria. + """ + pass + + @abc.abstractmethod + def __len__(self) -> int: + """ + Get the number of items in the database. + + Returns: + The number of items in the database. + """ + return 0 + + @abc.abstractmethod + def __iter__(self) -> Iterable['Client']: + """ + Iterate over all clients in the database. + + Returns: + An iterator over the clients in the database. + """ + pass + + def commit(self) -> bool: + """ + Commit changes to the database. + + Returns: + True if the commit was successful, False otherwise. + """ return True - def delete_client(self, key: str) -> bool: - user = self.get_client_by_api_key(key) - if user: - item_id = self.get_item_id(user) - self.update_item(item_id, Client(-1, api_key="revoked")) + +class JsonDB(AbstractDB): + """Database implementation using JSON files.""" + + def __init__(self, name="clients", subfolder="hivemind-core"): + self._db: Dict[int, ClientDict] = JsonStorageXDG(name, subfolder=subfolder) + + def add_item(self, client: Client) -> bool: + """ + Add a client to the JSON database. + + Args: + client: The client to be added. + + Returns: + True if the addition was successful, False otherwise. + """ + self._db[client.client_id] = client.__dict__ + return True + + def delete_item(self, client: Client) -> bool: + """ + Delete a client from the JSON database. + + Args: + client: The client to be deleted. + + Returns: + True if the deletion was successful, False otherwise. + """ + if client.client_id in self._db: + self._db.pop(client.client_id) return True return False - def change_key(self, old_key: str, new_key: str) -> bool: - user = self.get_client_by_api_key(old_key) - if user is None: + def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]: + """ + Search for clients by a specific key-value pair in the JSON database. + + Args: + key: The key to search by. + val: The value to search for. + + Returns: + A list of clients that match the search criteria. + """ + res = [] + if key == "client_id": + v = self._db.get(val) + if v: + res.append(cast2client(v)) + else: + for client in self._db.values(): + v = client.get(key) + if v == val: + res.append(cast2client(client)) + return res + + def __len__(self) -> int: + """ + Get the number of clients in the database. + + Returns: + The number of clients in the database. + """ + return len(self._db) + + def __iter__(self) -> Iterable['Client']: + """ + Iterate over all clients in the JSON database. + + Returns: + An iterator over the clients in the database. + """ + for item in self._db.values(): + yield Client.deserialize(item) + + def commit(self) -> bool: + """ + Commit changes to the JSON database. + + Returns: + True if the commit was successful, False otherwise. + """ + try: + self._db.store() + return True + except Exception as e: + LOG.error(f"Failed to save {self._db.path}") return False - item_id = self.get_item_id(user) - user["api_key"] = new_key - self.update_item(item_id, user) - return True - def change_crypto_key(self, api_key: str, new_key: str) -> bool: - user = self.get_client_by_api_key(api_key) - if user is None: + +class RedisDB(AbstractDB): + """Database implementation using Redis with RediSearch support.""" + + def __init__(self, host: str = "127.0.0.1", port: int = 6379, redis_db: int = 0): + """ + Initialize the RedisDB connection. + + Args: + host: Redis server host. + port: Redis server port. + redis_db: Redis database index. + """ + if redis is None: + raise ImportError("pip install redis") + self.redis = redis.StrictRedis(host=host, port=port, db=redis_db, decode_responses=True) + + def add_item(self, client: Client) -> bool: + """ + Add a client to Redis and RediSearch. + + Args: + client: The client to be added. + + Returns: + True if the addition was successful, False otherwise. + """ + item_key = f"client:{client.client_id}" + serialized_data: str = client.serialize() + + try: + # Store data in Redis + self.redis.set(item_key, serialized_data) + return True + except Exception as e: + LOG.error(f"Failed to add client to Redis/RediSearch: {e}") return False - item_id = self.get_item_id(user) - user["crypto_key"] = new_key - self.update_item(item_id, user) - return True - def get_crypto_key(self, api_key: str) -> Optional[str]: - user = self.get_client_by_api_key(api_key) - if user is None: - return None - return user["crypto_key"] + def delete_item(self, client: Client) -> bool: + """ + Delete a client from Redis and RediSearch. - def get_password(self, api_key: str) -> Optional[str]: - user = self.get_client_by_api_key(api_key) - if user is None: - return None - return user["password"] + Args: + client: The client to be deleted. - def change_name(self, new_name: str, key: str) -> bool: - user = self.get_client_by_api_key(key) - if user is None: + Returns: + True if the deletion was successful, False otherwise. + """ + item_key = f"client:{client.client_id}" + try: + self.redis.delete(item_key) + return True + except Exception as e: + LOG.error(f"Failed to delete client from Redis: {e}") return False - item_id = self.get_item_id(user) - user["name"] = new_name - self.update_item(item_id, user) - return True - def change_blacklist(self, blacklist: Union[str, Dict[str, Any]], key: str) -> bool: - if isinstance(blacklist, dict): - blacklist = json.dumps(blacklist) + def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]: + """ + Search for clients by a specific key-value pair in Redis. + + Args: + key: The key to search by. + val: The value to search for. + + Returns: + A list of clients that match the search criteria. + """ + res = [] + for client_id in self.redis.scan_iter(f"client:*"): + client_data = self.redis.get(client_id) + client = cast2client(client_data) + if hasattr(client, key) and getattr(client, key) == val: + res.append(client) + return res + + def __len__(self) -> int: + """ + Get the number of items in the Redis database. + + Returns: + The number of clients in the database. + """ + return len(self.redis.keys("client:*")) + + def __iter__(self) -> Iterable['Client']: + """ + Iterate over all clients in Redis. + + Returns: + An iterator over the clients in the database. + """ + for client_id in self.redis.scan_iter(f"client:*"): + yield cast2client(self.redis.get(client_id)) + + +class ClientDatabase: + valid_backends = ["json", "redis"] + + def __init__(self, backend="json", **backend_kwargs): + """ + Initialize the client database with the specified backend. + """ + backend_kwargs = backend_kwargs or {} + if backend not in self.valid_backends: + raise NotImplementedError(f"{backend} not supported, choose one of {self.valid_backends}") + + if backend == "json": + self.db = JsonDB() + elif backend == "redis": + self.db = RedisDB(**backend_kwargs) + else: + raise NotImplementedError(f"{backend} not supported, valid databases: {self.valid_backends}") + + def delete_client(self, key: str) -> bool: user = self.get_client_by_api_key(key) - if user is None: - return False - item_id = self.get_item_id(user) - user["blacklist"] = blacklist - self.update_item(item_id, user) - return True + if user: + return self.db.delete_item(user) + return False - def get_blacklist_by_api_key(self, api_key: str): - search = self.search_by_value("api_key", api_key) - if len(search): - return search[0]["blacklist"] - return None + def get_clients_by_name(self, name: str) -> List[Client]: + return self.db.search_by_value("name", name) - @cast_to_client_obj() def get_client_by_api_key(self, api_key: str) -> Optional[Client]: - search = self.search_by_value("api_key", api_key) + search: List[Client] = self.db.search_by_value("api_key", api_key) if len(search): return search[0] return None - @cast_to_client_obj() - def get_clients_by_name(self, name: str) -> List[Client]: - return self.search_by_value("name", name) - - @cast_to_client_obj() - def add_client( - self, - name: str, - key: str = "", - admin: bool = False, - blacklist: Optional[Dict[str, Any]] = None, - allowed_types: Optional[List[str]] = None, - crypto_key: Optional[str] = None, - password: Optional[str] = None, - ) -> Client: - user = self.get_client_by_api_key(key) - item_id = self.get_item_id(user) + def add_client(self, + name: str, + key: str = "", + admin: bool = False, + intent_blacklist: Optional[List[str]] = None, + skill_blacklist: Optional[List[str]] = None, + message_blacklist: Optional[List[str]] = None, + allowed_types: Optional[List[str]] = None, + crypto_key: Optional[str] = None, + password: Optional[str] = None) -> bool: if crypto_key is not None: crypto_key = crypto_key[:16] - if item_id >= 0: + + user = self.get_client_by_api_key(key) + if user: + # Update the existing client object directly if name: - user["name"] = name - if blacklist: - user["blacklist"] = blacklist + user.name = name + if intent_blacklist: + user.intent_blacklist = intent_blacklist + if skill_blacklist: + user.skill_blacklist = skill_blacklist + if message_blacklist: + user.message_blacklist = message_blacklist if allowed_types: - user["allowed_types"] = allowed_types + user.allowed_types = allowed_types if admin is not None: - user["is_admin"] = admin + user.is_admin = admin if crypto_key: - user["crypto_key"] = crypto_key + user.crypto_key = crypto_key if password: - user["password"] = password - self.update_item(item_id, user) - else: - user = Client( - api_key=key, - name=name, - blacklist=blacklist, - crypto_key=crypto_key, - client_id=self.total_clients() + 1, - is_admin=admin, - password=password, - allowed_types=allowed_types, - ) - self.add_item(user) - return user + user.password = password + return self.db.update_item(user) + + user = Client( + api_key=key, + name=name, + intent_blacklist=intent_blacklist, + skill_blacklist=skill_blacklist, + message_blacklist=message_blacklist, + crypto_key=crypto_key, + client_id=self.total_clients() + 1, + is_admin=admin, + password=password, + allowed_types=allowed_types, + ) + return self.db.add_item(user) def total_clients(self) -> int: - return len(self) + return len(self.db) def __enter__(self): """Context handler""" return self + def __iter__(self) -> Iterable[Client]: + yield from self.db + def __exit__(self, _type, value, traceback): """Commits changes and Closes the session""" try: - self.commit() + self.db.commit() except Exception as e: LOG.error(e) diff --git a/hivemind_core/protocol.py b/hivemind_core/protocol.py index 67c39ad..f1bc09a 100644 --- a/hivemind_core/protocol.py +++ b/hivemind_core/protocol.py @@ -253,6 +253,7 @@ class HiveMindListenerProtocol: require_crypto: bool = True # throw error if crypto key not available handshake_enabled: bool = True # generate a key per session if not pre-shared identity: Optional[NodeIdentity] = None + db: Optional[ClientDatabase] = None # below are optional callbacks to handle payloads # receives the payload + HiveMindClient that sent it escalate_callback = None # slave asked to escalate payload @@ -262,8 +263,9 @@ class HiveMindListenerProtocol: mycroft_bus_callback = None # slave asked to inject payload into mycroft bus shared_bus_callback = None # passive sharing of slave device bus (info) - def bind(self, websocket, bus, identity): + def bind(self, websocket, bus, identity, db: ClientDatabase): self.identity = identity + self.db = db websocket.protocol = self self.internal_protocol = HiveMindListenerInternalProtocol(bus) self.internal_protocol.register_bus_handlers() @@ -755,10 +757,9 @@ def _update_blacklist(self, message: Message, client: HiveMindClientConnection): message.context["session"] = client.sess.serialize() # update blacklist from db, to account for changes without requiring a restart - with ClientDatabase() as users: - user = users.get_client_by_api_key(client.key) - client.skill_blacklist = user.blacklist.get("skills", []) - client.intent_blacklist = user.blacklist.get("intents", []) + user = self.db.get_client_by_api_key(client.key) + client.skill_blacklist = user.skill_blacklist + client.intent_blacklist = user.intent_blacklist # inject client specific blacklist into session if "blacklisted_skills" not in message.context["session"]: diff --git a/hivemind_core/scripts.py b/hivemind_core/scripts.py index cf40345..9d307b9 100644 --- a/hivemind_core/scripts.py +++ b/hivemind_core/scripts.py @@ -19,7 +19,24 @@ def hmcore_cmds(): @click.option("--access-key", required=False, type=str) @click.option("--password", required=False, type=str) @click.option("--crypto-key", required=False, type=str) -def add_client(name, access_key, password, crypto_key): +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def add_client(name, access_key, password, crypto_key, + db_backend, redis_host, redis_port): key = crypto_key if key: print( @@ -36,18 +53,26 @@ def add_client(name, access_key, password, crypto_key): key = os.urandom(8).hex() password = password or os.urandom(16).hex() - access_key = access_key or os.urandom(16).hex() - with ClientDatabase() as db: + + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + with ClientDatabase(**kwargs) as db: name = name or f"HiveMind-Node-{db.total_clients()}" - db.add_client(name, access_key, crypto_key=key, password=password) + print(f"Database backend: {db.db.__class__.__name__}") + success = db.add_client(name, access_key, crypto_key=key, password=password) + if not success: + raise ValueError(f"Error adding User to database: {name}") # verify user = db.get_client_by_api_key(access_key) - node_id = db.get_item_id(user) + if user is None: + raise ValueError(f"User not found: {name}") print("Credentials added to database!\n") - print("Node ID:", node_id) + print("Node ID:", user.client_id) print("Friendly Name:", name) print("Access Key:", access_key) print("Password:", password) @@ -61,7 +86,29 @@ def add_client(name, access_key, password, crypto_key): @hmcore_cmds.command(help="allow message types sent from a client", name="allow-msg") @click.argument("msg_type", required=True, type=str) @click.argument("node_id", required=False, type=int) -def allow_msg(msg_type, node_id): +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def allow_msg(msg_type, node_id, + db_backend, redis_host, redis_port): + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + if not node_id: # list clients and prompt for id using rich table = Table(title="HiveMind Clients") @@ -69,14 +116,14 @@ def allow_msg(msg_type, node_id): table.add_column("Name", style="magenta") table.add_column("Allowed Msg Types", style="yellow") _choices = [] - for client in ClientDatabase(): - if client["client_id"] != -1: + for client in ClientDatabase(**kwargs): + if client.client_id != -1: table.add_row( - str(client["client_id"]), - client["name"], - str(client.get("allowed_types", [])), + str(client.client_id), + client.name, + str(client.allowed_types), ) - _choices.append(str(client["client_id"])) + _choices.append(str(client.client_id)) if not _choices: print("No clients found!") @@ -95,19 +142,15 @@ def allow_msg(msg_type, node_id): else: node_id = _choices[0] - with ClientDatabase() as db: + with ClientDatabase(**kwargs) as db: for client in db: - if client["client_id"] == int(node_id): - allowed_types = client.get("allowed_types", []) - if msg_type in allowed_types: - print(f"Client {client['name']} already allowed '{msg_type}'") + if client.client_id == int(node_id): + if msg_type in client.allowed_types: + print(f"Client {client.name} already allowed '{msg_type}'") exit() - - allowed_types.append(msg_type) - client["allowed_types"] = allowed_types - item_id = db.get_item_id(client) - db.update_item(item_id, client) - print(f"Allowed '{msg_type}' for {client['name']}") + client.allowed_types.append(msg_type) + db.update_item(client) + print(f"Allowed '{msg_type}' for {client.name}") break @@ -115,8 +158,29 @@ def allow_msg(msg_type, node_id): help="remove credentials for a client (numeric unique ID)", name="delete-client" ) @click.argument("node_id", required=True, type=int) -def delete_client(node_id): - with ClientDatabase() as db: +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def delete_client(node_id, + db_backend, redis_host, redis_port): + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + with ClientDatabase(**kwargs) as db: for x in db: if x["client_id"] == int(node_id): item_id = db.get_item_id(x) @@ -133,7 +197,23 @@ def delete_client(node_id): @hmcore_cmds.command(help="list clients and credentials", name="list-clients") -def list_clients(): +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def list_clients(db_backend, redis_host, redis_port): console = Console() table = Table(title="HiveMind Credentials:") table.add_column("ID", justify="center") @@ -142,7 +222,11 @@ def list_clients(): table.add_column("Password", justify="center") table.add_column("Crypto Key", justify="center") - with ClientDatabase() as db: + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + with ClientDatabase(**kwargs) as db: for x in db: if x["client_id"] != -1: table.add_row( @@ -186,6 +270,22 @@ def list_clients(): type=str, default="hivemind", ) +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) def listen( ovos_bus_address: str, ovos_bus_port: int, @@ -194,6 +294,7 @@ def listen( ssl: bool, cert_dir: str, cert_name: str, + db_backend, redis_host, redis_port ): from hivemind_core.service import HiveMindService @@ -210,8 +311,15 @@ def listen( "cert_name": cert_name, } + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + service = HiveMindService( - ovos_bus_config=ovos_bus_config, websocket_config=websocket_config + ovos_bus_config=ovos_bus_config, + websocket_config=websocket_config, + db=ClientDatabase(**kwargs) ) service.run() @@ -219,7 +327,29 @@ def listen( @hmcore_cmds.command(help="blacklist skills from being triggered by a client", name="blacklist-skill") @click.argument("skill_id", required=True, type=str) @click.argument("node_id", required=False, type=int) -def blacklist_skill(skill_id, node_id): +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def blacklist_skill(skill_id, node_id, + db_backend, redis_host, redis_port): + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + if not node_id: # list clients and prompt for id using rich table = Table(title="HiveMind Clients") @@ -227,14 +357,14 @@ def blacklist_skill(skill_id, node_id): table.add_column("Name", style="magenta") table.add_column("Allowed Msg Types", style="yellow") _choices = [] - for client in ClientDatabase(): - if client["client_id"] != -1: + for client in ClientDatabase(**kwargs): + if client.client_id != -1: table.add_row( - str(client["client_id"]), - client["name"], - str(client.get("allowed_types", [])), + str(client.client_id), + client.name, + str(client.allowed_types), ) - _choices.append(str(client["client_id"])) + _choices.append(str(client.client_id)) if not _choices: print("No clients found!") @@ -253,26 +383,45 @@ def blacklist_skill(skill_id, node_id): else: node_id = _choices[0] - with ClientDatabase() as db: + with ClientDatabase(**kwargs) as db: for client in db: - if client["client_id"] == int(node_id): - blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []}) - if skill_id in blacklist["skills"]: - print(f"Client {client['name']} already blacklisted '{skill_id}'") + if client.client_id == int(node_id): + if skill_id in client.skill_blacklist: + print(f"Client {client.name} already blacklisted '{skill_id}'") exit() - blacklist["skills"].append(skill_id) - client["blacklist"] = blacklist - item_id = db.get_item_id(client) - db.update_item(item_id, client) - print(f"Blacklisted '{skill_id}' for {client['name']}") + client.skill_blacklist.append(skill_id) + db.update_item(client) + print(f"Blacklisted '{skill_id}' for {client.name}") break @hmcore_cmds.command(help="remove skills from a client blacklist", name="unblacklist-skill") @click.argument("skill_id", required=True, type=str) @click.argument("node_id", required=False, type=int) -def unblacklist_skill(skill_id, node_id): +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def unblacklist_skill(skill_id, node_id, + db_backend, redis_host, redis_port): + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + if not node_id: # list clients and prompt for id using rich table = Table(title="HiveMind Clients") @@ -280,14 +429,14 @@ def unblacklist_skill(skill_id, node_id): table.add_column("Name", style="magenta") table.add_column("Allowed Msg Types", style="yellow") _choices = [] - for client in ClientDatabase(): - if client["client_id"] != -1: + for client in ClientDatabase(**kwargs): + if client.client_id != -1: table.add_row( - str(client["client_id"]), - client["name"], - str(client.get("allowed_types", [])), + str(client.client_id), + client.name, + str(client.allowed_types), ) - _choices.append(str(client["client_id"])) + _choices.append(str(client.client_id)) if not _choices: print("No clients found!") @@ -306,26 +455,44 @@ def unblacklist_skill(skill_id, node_id): else: node_id = _choices[0] - with ClientDatabase() as db: + with ClientDatabase(**kwargs) as db: for client in db: - if client["client_id"] == int(node_id): - blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []}) - if skill_id not in blacklist["skills"]: - print(f"'{skill_id}' is not blacklisted for client {client['name']}") + if client.client_id == int(node_id): + if skill_id not in client.skill_blacklist: + print(f"'{skill_id}' is not blacklisted for client {client.name}") exit() - - blacklist["skills"].pop(skill_id) - client["blacklist"] = blacklist - item_id = db.get_item_id(client) - db.update_item(item_id, client) - print(f"Blacklisted '{skill_id}' for {client['name']}") + client.skill_blacklist.remove(skill_id) + db.update_item(client) + print(f"Blacklisted '{skill_id}' for {client.name}") break @hmcore_cmds.command(help="blacklist intents from being triggered by a client", name="blacklist-intent") @click.argument("intent_id", required=True, type=str) @click.argument("node_id", required=False, type=int) -def blacklist_intent(intent_id, node_id): +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def blacklist_intent(intent_id, node_id, + db_backend, redis_host, redis_port): + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + if not node_id: # list clients and prompt for id using rich table = Table(title="HiveMind Clients") @@ -333,14 +500,14 @@ def blacklist_intent(intent_id, node_id): table.add_column("Name", style="magenta") table.add_column("Allowed Msg Types", style="yellow") _choices = [] - for client in ClientDatabase(): - if client["client_id"] != -1: + for client in ClientDatabase(**kwargs): + if client.client_id != -1: table.add_row( - str(client["client_id"]), - client["name"], - str(client.get("allowed_types", [])), + str(client.client_id), + client.name, + str(client.allowed_types), ) - _choices.append(str(client["client_id"])) + _choices.append(str(client.client_id)) if not _choices: print("No clients found!") @@ -359,26 +526,44 @@ def blacklist_intent(intent_id, node_id): else: node_id = _choices[0] - with ClientDatabase() as db: + with ClientDatabase(**kwargs) as db: for client in db: - if client["client_id"] == int(node_id): - blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []}) - if intent_id in blacklist["intents"]: - print(f"Client {client['name']} already blacklisted '{intent_id}'") + if client.client_id == int(node_id): + if intent_id in client.intent_blacklist: + print(f"Client {client.name} already blacklisted '{intent_id}'") exit() - - blacklist["intents"].append(intent_id) - client["blacklist"] = blacklist - item_id = db.get_item_id(client) - db.update_item(item_id, client) - print(f"Blacklisted '{intent_id}' for {client['name']}") + client.intent_blacklist.append(intent_id) + db.update_item(client) + print(f"Blacklisted '{intent_id}' for {client.name}") break @hmcore_cmds.command(help="remove intents from a client blacklist", name="unblacklist-intent") @click.argument("intent_id", required=True, type=str) @click.argument("node_id", required=False, type=int) -def unblacklist_intent(intent_id, node_id): +@click.option( + "--db-backend", + type=click.Choice(['redis', 'json'], case_sensitive=False), + default='json', + help="Select the database backend to use. Options: redis, json." +) +@click.option( + "--redis-host", + default="localhost", + help="Host for Redis (if selected). Default is localhost." +) +@click.option( + "--redis-port", + default=6379, + help="Port for Redis (if selected). Default is 6379." +) +def unblacklist_intent(intent_id, node_id, + db_backend, redis_host, redis_port): + kwargs = {"backend": db_backend} + if db_backend == "redis": + kwargs["host"] = redis_host + kwargs["port"] = redis_port + if not node_id: # list clients and prompt for id using rich table = Table(title="HiveMind Clients") @@ -386,14 +571,14 @@ def unblacklist_intent(intent_id, node_id): table.add_column("Name", style="magenta") table.add_column("Allowed Msg Types", style="yellow") _choices = [] - for client in ClientDatabase(): - if client["client_id"] != -1: + for client in ClientDatabase(**kwargs): + if client.client_id != -1: table.add_row( - str(client["client_id"]), - client["name"], - str(client.get("allowed_types", [])), + str(client.client_id), + client.name, + str(client.allowed_types), ) - _choices.append(str(client["client_id"])) + _choices.append(str(client.client_id)) if not _choices: print("No clients found!") @@ -412,19 +597,15 @@ def unblacklist_intent(intent_id, node_id): else: node_id = _choices[0] - with ClientDatabase() as db: + with ClientDatabase(**kwargs) as db: for client in db: - if client["client_id"] == int(node_id): - blacklist = client.get("blacklist", {"messages": [], "skills": [], "intents": []}) - if intent_id not in blacklist["intents"]: - print(f" '{intent_id}' not blacklisted for Client {client['name']} ") + if client.client_id == int(node_id): + if intent_id not in client.intent_blacklist: + print(f" '{intent_id}' not blacklisted for Client {client.name} ") exit() - - blacklist["intents"].pop(intent_id) - client["blacklist"] = blacklist - item_id = db.get_item_id(client) - db.update_item(item_id, client) - print(f"Blacklisted '{intent_id}' for {client['name']}") + client.intent_blacklist.remove(intent_id) + db.update_item(client) + print(f"Unblacklisted '{intent_id}' for {client.name}") break diff --git a/hivemind_core/service.py b/hivemind_core/service.py index b070952..09cdc74 100644 --- a/hivemind_core/service.py +++ b/hivemind_core/service.py @@ -99,6 +99,7 @@ def on_stopping(): class MessageBusEventHandler(WebSocketHandler): protocol: Optional[HiveMindListenerProtocol] = None + db: Optional[ClientDatabase] = None @staticmethod def decode_auth(auth) -> Tuple[str, str]: @@ -134,42 +135,46 @@ def open(self): handshake=handshake, loop=self.protocol.loop, ) + if self.db is None: + LOG.error("HiveMind database not initialized, can't validate connection") + self.protocol.handle_invalid_key_connected(self.client) + self.close() + return + + user = self.db.get_client_by_api_key(key) + if not user: + LOG.error("Client provided an invalid api key") + self.protocol.handle_invalid_key_connected(self.client) + self.close() + return + + self.client.crypto_key = user.crypto_key + self.client.msg_blacklist = user.message_blacklist + self.client.skill_blacklist = user.skill_blacklist + self.client.intent_blacklist = user.intent_blacklist + self.client.allowed_types = user.allowed_types + self.client.can_broadcast = user.can_broadcast + self.client.can_propagate = user.can_propagate + self.client.can_escalate = user.can_escalate + if user.password: + # pre-shared password to derive aes_key + self.client.pswd_handshake = PasswordHandShake(user.password) + + self.client.node_type = HiveMindNodeType.NODE # TODO . placeholder - with ClientDatabase() as users: - user = users.get_client_by_api_key(key) - if not user: - LOG.error("Client provided an invalid api key") - self.protocol.handle_invalid_key_connected(self.client) - self.close() - return - - self.client.crypto_key = user.crypto_key - self.client.msg_blacklist = user.blacklist.get("messages", []) - self.client.skill_blacklist = user.blacklist.get("skills", []) - self.client.intent_blacklist = user.blacklist.get("intents", []) - self.client.allowed_types = user.allowed_types - self.client.can_broadcast = user.can_broadcast - self.client.can_propagate = user.can_propagate - self.client.can_escalate = user.can_escalate - if user.password: - # pre-shared password to derive aes_key - self.client.pswd_handshake = PasswordHandShake(user.password) - - self.client.node_type = HiveMindNodeType.NODE # TODO . placeholder - - if ( - not self.client.crypto_key - and not self.protocol.handshake_enabled - and self.protocol.require_crypto - ): - LOG.error( - "No pre-shared crypto key for client and handshake disabled, " - "but configured to require crypto!" - ) - # clients requiring handshake support might fail here - self.protocol.handle_invalid_protocol_version(self.client) - self.close() - return + if ( + not self.client.crypto_key + and not self.protocol.handshake_enabled + and self.protocol.require_crypto + ): + LOG.error( + "No pre-shared crypto key for client and handshake disabled, " + "but configured to require crypto!" + ) + # clients requiring handshake support might fail here + self.protocol.handle_invalid_protocol_version(self.client) + self.close() + return self.protocol.handle_new_client(self.client) # self.write_message(Message("connected").serialize()) @@ -197,6 +202,7 @@ def __init__( protocol=HiveMindListenerProtocol, bus=None, ws_handler=MessageBusEventHandler, + db: ClientDatabase = None ): websocket_config = websocket_config or Configuration().get( "hivemind_websocket", {} @@ -208,8 +214,10 @@ def __init__( on_error=error_hook, on_stopping=stopping_hook, ) + self.db = db self._proto = protocol self._ws_handler = ws_handler + self._ws_handler.db = db if bus: self.bus = bus else: @@ -256,7 +264,7 @@ def run(self): loop = ioloop.IOLoop.current() self.protocol = self._proto(loop=loop) - self.protocol.bind(self._ws_handler, self.bus, self.identity) + self.protocol.bind(self._ws_handler, self.bus, self.identity, self.db) self.status.bind(self.bus) self.status.set_started() diff --git a/setup.py b/setup.py index 155a11e..607e5a6 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def required(requirements_file): setup( - name="jarbas_hive_mind", + name="hivemind-core", version=get_version(), packages=["hivemind_core"], include_package_data=True, diff --git a/test/unittests/test_db.py b/test/unittests/test_db.py index fd8c1d5..88412f1 100644 --- a/test/unittests/test_db.py +++ b/test/unittests/test_db.py @@ -1,37 +1,182 @@ -import os -from unittest import TestCase - -from hivemind_core.database import ClientDatabase, Client - - -class TestDB(TestCase): - def test_add_entry(self): - key = os.urandom(8).hex() - access_key = os.urandom(16).hex() - password = None - - with ClientDatabase() as db: - n = db.total_clients() - name = f"HiveMind-Node-{n}" - user = db.add_client(name, access_key, crypto_key=key, password=password) - # verify data - self.assertTrue(isinstance(user, Client)) - self.assertEqual(user.name, name) - self.assertEqual(user.api_key, access_key) - - # test search entry in db - node_id = db.get_item_id(user) - self.assertEqual(node_id, n) - - user2 = db.get_client_by_api_key(access_key) - self.assertEqual(user, user2) - - for u in db.get_clients_by_name(name): - self.assertEqual(user.name, u.name) - - # test delete entry - db.delete_client(access_key) - node_id = db.get_item_id(user) - self.assertEqual(node_id, -1) - user = db.get_client_by_api_key(access_key) - self.assertIsNone(user) +import json +import unittest +from unittest.mock import patch, MagicMock + +from hivemind_core.database import Client, JsonDB, RedisDB, cast2client, ClientDatabase + + +class TestClient(unittest.TestCase): + + def test_client_creation(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + self.assertEqual(client.client_id, 1) + self.assertEqual(client.api_key, "test_api_key") + self.assertEqual(client.name, "Test Client") + self.assertEqual(client.description, "A test client") + self.assertFalse(client.is_admin) + + def test_client_serialization(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + serialized_data = client.serialize() + self.assertIsInstance(serialized_data, str) + self.assertIn('"client_id": 1', serialized_data) + + def test_client_deserialization(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + serialized_data = json.dumps(client_data) + client = Client.deserialize(serialized_data) + self.assertEqual(client.client_id, 1) + self.assertEqual(client.api_key, "test_api_key") + + def test_cast2client(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + serialized_client = client.serialize() + deserialized_client = cast2client(serialized_client) + self.assertEqual(client, deserialized_client) + + client_list = [client, client] + deserialized_client_list = cast2client([serialized_client, serialized_client]) + self.assertEqual(client_list, deserialized_client_list) + + +class TestJsonDB(unittest.TestCase): + + def setUp(self): + self.db = JsonDB(name=".hivemind-test") + + def test_add_item(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + self.db.add_item(client) + self.assertTrue(client.client_id in self.db._db) + + def test_delete_item(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + self.db.add_item(client) + result = self.db.delete_item(client) + self.assertTrue(result) + + def test_search_by_value(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + self.db.add_item(client) + clients = self.db.search_by_value("name", "Test Client") + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0].name, "Test Client") + + +class TestRedisDB(unittest.TestCase): + + @patch('hivemind_core.database.redis.StrictRedis') + def setUp(self, MockRedis): + self.mock_redis = MagicMock() + MockRedis.return_value = self.mock_redis + self.db = RedisDB() + + def test_add_item(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + self.db.add_item(client) + self.mock_redis.set.assert_called_once() + + def test_delete_item(self): + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + self.db.add_item(client) + result = self.db.delete_item(client) + self.assertTrue(result) + + +class TestClientDatabase(unittest.TestCase): + + def test_delete_client(self): + db = MagicMock() + db.delete_item.return_value = True + client_db = ClientDatabase(backend="json") + client_db.db = db + client_db.get_client_by_api_key = MagicMock() + client_db.get_client_by_api_key.return_value = Client(1, "A") + + result = client_db.delete_client("test_api_key") + self.assertTrue(result) + db.delete_item.assert_called_once() + + def test_get_clients_by_name(self): + db = MagicMock() + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + client = Client(**client_data) + db.search_by_value.return_value = [client] + + client_db = ClientDatabase(backend="json") + client_db.db = db + clients = client_db.get_clients_by_name("Test Client") + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0].name, "Test Client") + + +if __name__ == '__main__': + unittest.main()