From 92c5ea88f00e0053fb36bf7ae62f29cff8418860 Mon Sep 17 00:00:00 2001 From: JarbasAI <33701864+JarbasAl@users.noreply.github.com> Date: Sun, 29 Dec 2024 03:57:30 +0000 Subject: [PATCH] feat: get_class factory methods (#7) --- hivemind_plugin_manager/__init__.py | 59 ++++++++++++++++++----------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/hivemind_plugin_manager/__init__.py b/hivemind_plugin_manager/__init__.py index cb0e053..e6421f7 100644 --- a/hivemind_plugin_manager/__init__.py +++ b/hivemind_plugin_manager/__init__.py @@ -1,5 +1,5 @@ import enum -from typing import Optional, Dict, Any, Union +from typing import Optional, Dict, Any, Union, Type from ovos_utils.log import LOG @@ -15,6 +15,12 @@ class HiveMindPluginTypes(str, enum.Enum): class DatabaseFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[AbstractDB]: + plugins = find_plugins(HiveMindPluginTypes.DATABASE) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] @classmethod def create(cls, plugin_name: str, @@ -23,17 +29,19 @@ def create(cls, plugin_name: str, password: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None) -> Union[AbstractRemoteDB, AbstractDB]: - plugins = find_plugins(HiveMindPluginTypes.DATABASE) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - if issubclass(plugins[plugin_name], AbstractRemoteDB): - return plugins[plugin_name](name=name, subfolder=subfolder, - password=password, host=host, port=port) - return plugins[plugin_name](name=name, subfolder=subfolder, - password=password) + plugin = cls.get_class(plugin_name) + if issubclass(plugin, AbstractRemoteDB): + return plugin(name=name, subfolder=subfolder, password=password, host=host, port=port) + return plugin(name=name, subfolder=subfolder, password=password) class AgentProtocolFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[AgentProtocol]: + plugins = find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] @classmethod def create(cls, plugin_name: str, @@ -41,39 +49,44 @@ def create(cls, plugin_name: str, bus: Optional[Union['FakeBus', 'MessageBusClient']] = None, hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> AgentProtocol: config = config or {} - plugins = find_plugins(HiveMindPluginTypes.AGENT_PROTOCOL) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - return plugins[plugin_name](config=config, bus=bus, hm_protocol=hm_protocol) + plugin = cls.get_class(plugin_name) + return plugin(config=config, bus=bus, hm_protocol=hm_protocol) class NetworkProtocolFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[NetworkProtocol]: + plugins = find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] @classmethod def create(cls, plugin_name: str, config: Optional[Dict[str, Any]] = None, hm_protocol: Optional['HiveMindListenerProtocol'] = None) -> NetworkProtocol: config = config or {} - plugins = find_plugins(HiveMindPluginTypes.NETWORK_PROTOCOL) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - return plugins[plugin_name](config=config, hm_protocol=hm_protocol) + plugin = cls.get_class(plugin_name) + return plugin(config=config, hm_protocol=hm_protocol) class BinaryDataHandlerProtocolFactory: + @classmethod + def get_class(cls, plugin_name: str) -> Type[BinaryDataHandlerProtocol]: + plugins = find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL) + if plugin_name not in plugins: + raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") + return plugins[plugin_name] + @classmethod def create(cls, plugin_name: str, config: Optional[Dict[str, Any]] = None, hm_protocol: Optional['HiveMindListenerProtocol'] = None, agent_protocol: Optional['AgentProtocol'] = None) -> BinaryDataHandlerProtocol: config = config or {} - plugins = find_plugins(HiveMindPluginTypes.BINARY_PROTOCOL) - if plugin_name not in plugins: - raise KeyError(f"'{plugin_name}' not found. Available plugins: {list(plugins.keys())}") - return plugins[plugin_name](config=config, - hm_protocol=hm_protocol, - agent_protocol=agent_protocol) + plugin = cls.get_class(plugin_name) + return plugin(config=config, hm_protocol=hm_protocol, agent_protocol=agent_protocol) def _iter_entrypoints(plug_type: Optional[str]):