diff --git a/hivemind_plugin_manager/__init__.py b/hivemind_plugin_manager/__init__.py index cb0e053..e6421f7 100644 --- a/hivemind_plugin_manager/__init__.py +++ b/hivemind_plugin_manager/__init__.py @@ -1,5 +1,5 @@ import enum -from typing import Optional, Dict, Any, Union +from typing import Optional, Dict, Any, Union, Type from ovos_utils.log import LOG @@ -15,6 +15,12 @@ class HiveMindPluginTypes(str, enum.Enum): class DatabaseFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[AbstractDB]: + plugins = find_plugins(HiveMindPluginTypes.DATABASE) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] @classmethod def create(cls, plugin_name: str, @@ -23,17 +29,19 @@ def create(cls, plugin_name: str, password: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None) -> Union[AbstractRemoteDB, AbstractDB]: - plugins = find_plugins(HiveMindPluginTypes.DATABASE) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - if issubclass(plugins[plugin_name], AbstractRemoteDB): - return plugins[plugin_name](name=name, subfolder=subfolder, - password=password, host=host, port=port) - return plugins[plugin_name](name=name, subfolder=subfolder, - password=password) + plugin = cls.get_class(plugin_name) + if issubclass(plugin, AbstractRemoteDB): + return plugin(name=name, subfolder=subfolder, password=password, host=host, port=port) + return plugin(name=name, subfolder=subfolder, password=password) class AgentProtocolFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[AgentProtocol]: + plugins = find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] @classmethod def create(cls, plugin_name: str, @@ -41,39 +49,44 @@ def create(cls, plugin_name: str, bus: Optional[Union['FakeBus', 'MessageBusClient']] = None, hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> AgentProtocol: config = config or {} - plugins = find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - return plugins[plugin_name](config=config, bus=bus, hm_protocol=hm_protocol) + plugin = cls.get_class(plugin_name) + return plugin(config=config, bus=bus, hm_protocol=hm_protocol) class NetworkProtocolFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[NetworkProtocol]: + plugins = find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] @classmethod def create(cls, plugin_name: str, config: Optional[Dict[str, Any]] = None, hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> NetworkProtocol: config = config or {} - plugins = find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - return plugins[plugin_name](config=config, hm_protocol=hm_protocol) + plugin = cls.get_class(plugin_name) + return plugin(config=config, hm_protocol=hm_protocol) class BinaryDataHandlerProtocolFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[BinaryDataHandlerProtocol]: + plugins = find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] + @classmethod def create(cls, plugin_name: str, config: Optional[Dict[str, Any]] = None, hm_protocol: Optional['HiveMindListenerProtocol'] = None, agent_protocol: Optional['AgentProtocol'] = None) -> BinaryDataHandlerProtocol: config = config or {} - plugins = find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - return plugins[plugin_name](config=config, - hm_protocol=hm_protocol, - agent_protocol=agent_protocol) + plugin = cls.get_class(plugin_name) + return plugin(config=config, hm_protocol=hm_protocol, agent_protocol=agent_protocol) def _iter_entrypoints(plug_type: Optional[str]):