Skip to content

Commit

Permalink
feat: get_class factory methods (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl authored Dec 29, 2024
1 parent cfc42d4 commit 92c5ea8
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions hivemind_plugin_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import enum
from typing import Optional, Dict, Any, Union
from typing import Optional, Dict, Any, Union, Type

from ovos_utils.log import LOG

Expand All @@ -15,6 +15,12 @@ class HiveMindPluginTypes(str, enum.Enum):


class DatabaseFactory:
@classmethod
def get_class(cls, plugin_name: str) -> Type[AbstractDB]:
plugins = find_plugins(HiveMindPluginTypes.DATABASE)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name]

@classmethod
def create(cls, plugin_name: str,
Expand All @@ -23,57 +29,64 @@ def create(cls, plugin_name: str,
password: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None) -> Union[AbstractRemoteDB, AbstractDB]:
plugins = find_plugins(HiveMindPluginTypes.DATABASE)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
if issubclass(plugins[plugin_name], AbstractRemoteDB):
return plugins[plugin_name](name=name, subfolder=subfolder,
password=password, host=host, port=port)
return plugins[plugin_name](name=name, subfolder=subfolder,
password=password)
plugin = cls.get_class(plugin_name)
if issubclass(plugin, AbstractRemoteDB):
return plugin(name=name, subfolder=subfolder, password=password, host=host, port=port)
return plugin(name=name, subfolder=subfolder, password=password)


class AgentProtocolFactory:
@classmethod
def get_class(cls, plugin_name: str) -> Type[AgentProtocol]:
plugins = find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name]

@classmethod
def create(cls, plugin_name: str,
config: Optional[Dict[str, Any]] = None,
bus: Optional[Union['FakeBus', 'MessageBusClient']] = None,
hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> AgentProtocol:
config = config or {}
plugins = find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name](config=config, bus=bus, hm_protocol=hm_protocol)
plugin = cls.get_class(plugin_name)
return plugin(config=config, bus=bus, hm_protocol=hm_protocol)


class NetworkProtocolFactory:
@classmethod
def get_class(cls, plugin_name: str) -> Type[NetworkProtocol]:
plugins = find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name]

@classmethod
def create(cls, plugin_name: str,
config: Optional[Dict[str, Any]] = None,
hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> NetworkProtocol:
config = config or {}
plugins = find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name](config=config, hm_protocol=hm_protocol)
plugin = cls.get_class(plugin_name)
return plugin(config=config, hm_protocol=hm_protocol)


class BinaryDataHandlerProtocolFactory:

@classmethod
def get_class(cls, plugin_name: str) -> Type[BinaryDataHandlerProtocol]:
plugins = find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name]

@classmethod
def create(cls, plugin_name: str,
config: Optional[Dict[str, Any]] = None,
hm_protocol: Optional['HiveMindListenerProtocol'] = None,
agent_protocol: Optional['AgentProtocol'] = None) -> BinaryDataHandlerProtocol:
config = config or {}
plugins = find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL)
if plugin_name not in plugins:
raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}")
return plugins[plugin_name](config=config,
hm_protocol=hm_protocol,
agent_protocol=agent_protocol)
plugin = cls.get_class(plugin_name)
return plugin(config=config, hm_protocol=hm_protocol, agent_protocol=agent_protocol)


def _iter_entrypoints(plug_type: Optional[str]):
Expand Down

0 comments on commit 92c5ea8

Please sign in to comment.