Skip to content

Commit

Permalink
feat: add plugin factory classes
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Dec 29, 2024
1 parent 1bc3374 commit 5a025d5
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 8 deletions.
159 changes: 159 additions & 0 deletions hivemind_plugin_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import enum
from typing import Optional, Dict, Any, Union

from ovos_utils.log import LOG

from hivemind_plugin_manager.database import AbstractDB, AbstractRemoteDB
from hivemind_plugin_manager.protocols import AgentProtocol, BinaryDataHandlerProtocol, NetworkProtocol


class HiveMindPluginTypes(str, enum.Enum):
DATABASE = "hivemind.database"
NETWORK_PROTOCOL = "hivemind.network.protocol"
AGENT_PROTOCOL = "hivemind.agent.protocol"
BINARY_PROTOCOL = "hivemind.binary.protocol"


def find_plugins(plug_type: HiveMindPluginTypes = None) -> dict:
"""
Finds all plugins matching specific entrypoint type.
Arguments:
plug_type (str): plugin entrypoint string to retrieve
Returns:
dict mapping plugin names to plugin entrypoints
"""
entrypoints = {}
if not plug_type:
plugs = list(HiveMindPluginTypes)
elif isinstance(plug_type, str):
plugs = [plug_type]
else:
plugs = plug_type
for plug in plugs:
for entry_point in _iter_entrypoints(plug):
try:
entrypoints[entry_point.name] = entry_point.load()
if entry_point.name not in entrypoints:
LOG.debug(f"Loaded plugin entry point {entry_point.name}")
except Exception as e:
if entry_point not in find_plugins._errored:
find_plugins._errored.append(entry_point)
# NOTE: this runs in a loop inside skills manager, this would endlessly spam logs
LOG.error(f"Failed to load plugin entry point {entry_point}: "
f"{e}")
return entrypoints


find_plugins._errored = []


def _iter_entrypoints(plug_type: Optional[str]):
"""
Return an iterator containing all entrypoints of the requested type
@param plug_type: entrypoint name to load
@return: iterator of all entrypoints
"""
try:
from importlib_metadata import entry_points
for entry_point in entry_points(group=plug_type):
yield entry_point
except ImportError:
import pkg_resources
for entry_point in pkg_resources.iter_entry_points(plug_type):
yield entry_point


def load_plugin(plug_name: str, plug_type: Optional[HiveMindPluginTypes] = None):
"""Load a specific plugin from a specific plugin type.
Arguments:
plug_type: (str) plugin type name. Ex. "hivemind.agent.protocol".
plug_name: (str) specific plugin name (else consider all plugin types)
Returns:
Loaded plugin Object or None if no matching object was found.
"""
plugins = find_plugins(plug_type)
if plug_name in plugins:
return plugins[plug_name]
plug_type = plug_type or "all plugin types"
LOG.warning(f'Could not find the plugin {plug_type}.{plug_name}')
return None


class DatabaseFactory:

@classmethod
def create(cls, plugin_name: str,
name: str = "clients",
subfolder: str = "hivemind-core",
password: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None) -> Union[AbstractRemoteDB, AbstractDB]:
plugins = find_plugins(HiveMindPluginTypes.DATABASE)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
if issubclass(plugins[plugin_name], AbstractRemoteDB):
return plugins[plugin_name](name=name, subfolder=subfolder,
password=password, host=host, port=port)
return plugins[plugin_name](name=name, subfolder=subfolder,
password=password)


class AgentProtocolFactory:

@classmethod
def create(cls, plugin_name: str,
config: Optional[Dict[str, Any]] = None,
bus: Optional[Union['FakeBus', 'MessageBusClient']] = None,
hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> AgentProtocol:
config = config or {}
plugins = find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name](config=config, bus=bus, hm_protocol=hm_protocol)


class NetworkProtocolFactory:

@classmethod
def create(cls, plugin_name: str,
config: Optional[Dict[str, Any]] = None,
hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> NetworkProtocol:
config = config or {}
plugins = find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name](config=config, hm_protocol=hm_protocol)


class BinaryDataHandlerProtocolFactory:

@classmethod
def create(cls, plugin_name: str,
config: Optional[Dict[str, Any]] = None,
hm_protocol: Optional['HiveMindListenerProtocol'] = None,
agent_protocol: Optional['AgentProtocol'] = None) -> BinaryDataHandlerProtocol:
config = config or {}
plugins = find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name](config=config,
hm_protocol=hm_protocol,
agent_protocol=agent_protocol)


if __name__ == "__main__":
print(find_plugins(HiveMindPluginTypes.DATABASE))
# {'hivemind-json-db-plugin': <class 'json_database.hpm.JsonDB'>,
# 'hivemind-sqlite-db-plugin': <class 'hivemind_sqlite_database.SQLiteDB'>,
# 'hivemind-redis-db-plugin': <class 'hivemind_redis_database.RedisDB'>}
print(find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL))
# {'hivemind-websocket-plugin': <class 'hivemind_websocket_protocol.HiveMindWebsocketProtocol'>}
print(find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL))
# {'hivemind-ovos-agent-plugin': <class 'ovos_bus_client.hpm.OVOSProtocol'>,
# 'hivemind-persona-agent-plugin': <class 'ovos_persona.hpm.PersonaProtocol'>}}
print(find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL))
# {'hivemind-audio-binary-protocol-plugin': <class 'hivemind_listener.protocol.AudioBinaryProtocol'>}
62 changes: 58 additions & 4 deletions hivemind_plugin_manager/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,17 @@ def __repr__(self) -> str:
return self.serialize()


@dataclass
class AbstractDB(abc.ABC):
"""
Abstract base class for all database implementations.
All database implementations should derive from this class and implement
the abstract methods.
"""
name: str = "clients"
subfolder: str = "hivemind-core"
password: Optional[str] = None

@abc.abstractmethod
def add_item(self, client: Client) -> bool:
Expand All @@ -173,7 +177,6 @@ def add_item(self, client: Client) -> bool:
Returns:
True if the addition was successful, False otherwise.
"""
pass

def delete_item(self, client: Client) -> bool:
"""
Expand Down Expand Up @@ -227,7 +230,6 @@ def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[C
Returns:
A list of clients that match the search criteria.
"""
pass

@abc.abstractmethod
def __len__(self) -> int:
Expand All @@ -237,7 +239,6 @@ def __len__(self) -> int:
Returns:
The number of items in the database.
"""
return 0

@abc.abstractmethod
def __iter__(self) -> Iterable['Client']:
Expand All @@ -247,7 +248,6 @@ def __iter__(self) -> Iterable['Client']:
Returns:
An iterator over the clients in the database.
"""
pass

def sync(self):
"""update db from disk if needed"""
Expand All @@ -262,3 +262,57 @@ def commit(self) -> bool:
"""
return True


@dataclass
class AbstractRemoteDB(AbstractDB):
"""
Abstract base class for remote database implementations.
"""
host: str = "127.0.0.1"
port: Optional[int] = None
name: str = "clients"
subfolder: str = "hivemind-core"
password: Optional[str] = None

@abc.abstractmethod
def add_item(self, client: Client) -> bool:
"""
Add a client to the database.
Args:
client: The client to be added.
Returns:
True if the addition was successful, False otherwise.
"""

@abc.abstractmethod
def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]:
"""
Search for clients by a specific key-value pair.
Args:
key: The key to search by.
val: The value to search for.
Returns:
A list of clients that match the search criteria.
"""

@abc.abstractmethod
def __len__(self) -> int:
"""
Get the number of items in the database.
Returns:
The number of items in the database.
"""

@abc.abstractmethod
def __iter__(self) -> Iterable['Client']:
"""
Iterate over all clients in the database.
Returns:
An iterator over the clients in the database.
"""
27 changes: 23 additions & 4 deletions hivemind_plugin_manager/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ class _SubProtocol:

@property
def identity(self) -> NodeIdentity:
if not self.hm_protocol:
return NodeIdentity()
return self.hm_protocol.identity

@property
def database(self) -> 'ClientDatabase':
def database(self) -> Optional['ClientDatabase']:
if not self.hm_protocol:
return None
return self.hm_protocol.db

@property
def clients(self) -> Dict[str, 'HiveMindClientConnection']:
if not self.hm_protocol:
return {}
return self.hm_protocol.clients


Expand All @@ -34,7 +40,8 @@ class AgentProtocol(_SubProtocol):
"""protocol to handle Message objects, the payload of HiveMessage objects"""
bus: Union[FakeBus, MessageBusClient] = dataclasses.field(default_factory=FakeBus)
config: Dict[str, Any] = dataclasses.field(default_factory=dict)
hm_protocol: Optional['HiveMindListenerProtocol'] = None
hm_protocol: Optional['HiveMindListenerProtocol'] = None # usually AgentProtocol is passed as kwarg to hm_protocol
# and only then assigned in hm_protocol.__post_init__


@dataclass
Expand All @@ -43,6 +50,12 @@ class NetworkProtocol(_SubProtocol):
config: Dict[str, Any] = dataclasses.field(default_factory=dict)
hm_protocol: Optional['HiveMindListenerProtocol'] = None

@property
def agent_protocol(self) -> Optional['AgentProtocol']:
if not self.hm_protocol:
return None
return self.hm_protocol.agent_protocol

@abc.abstractmethod
def run(self):
pass
Expand All @@ -52,8 +65,14 @@ def run(self):
class BinaryDataHandlerProtocol(_SubProtocol):
"""protocol to handle Binary data HiveMessage objects"""
config: Dict[str, Any] = dataclasses.field(default_factory=dict)
hm_protocol: Optional['HiveMindListenerProtocol'] = None
agent_protocol: Optional[AgentProtocol] = None
hm_protocol: Optional['HiveMindListenerProtocol'] = None # usually BinaryDataHandlerProtocol is passed as kwarg to hm_protocol
# and only then assigned in hm_protocol.__post_init__
agent_protocol: Optional['AgentProtocol'] = None

def __post_init__(self):
# NOTE: the most common scenario is having self.agent_protocol but not having self.hm_protocol yet
if not self.agent_protocol and self.hm_protocol:
self.agent_protocol = self.hm_protocol.agent_protocol

def handle_microphone_input(self, bin_data: bytes,
sample_rate: int,
Expand Down

0 comments on commit 5a025d5

Please sign in to comment.