diff --git a/appdaemon/__main__.py b/appdaemon/__main__.py index d2981981a..6d3006857 100644 --- a/appdaemon/__main__.py +++ b/appdaemon/__main__.py @@ -145,7 +145,12 @@ def run(self, ad_config_model: AppDaemonConfig, hadashboard, admin, aui, api, ht # Initialize Dashboard/API/admin - if http is not None and (hadashboard is not None or admin is not None or aui is not None or api is not False): + if http is not None and ( + hadashboard is not None or + admin is not None or + aui is not None or + api is not False + ): self.logger.info("Initializing HTTP") self.http_object = HTTP( self.AD, @@ -253,7 +258,11 @@ def main(self): # noqa: C901 pidfile = args.pidfile default_config_files = ["appdaemon.yaml", "appdaemon.toml"] - default_config_paths = [Path("~/.homeassistant").expanduser(), Path("/etc/appdaemon"), Path("/conf")] + default_config_paths = [ + Path("~/.homeassistant").expanduser(), + Path("/etc/appdaemon"), + Path("/conf") + ] if args.configfile is not None: config_file = Path(args.configfile).resolve() diff --git a/appdaemon/adapi.py b/appdaemon/adapi.py index 37ca3391e..8acc4bcb5 100644 --- a/appdaemon/adapi.py +++ b/appdaemon/adapi.py @@ -165,7 +165,15 @@ def _log( extra: Mapping[str, object] | None = None, ) -> None: ... - def _log(self, logger: Logger, msg: str, level: str | int = "INFO", *args, ascii_encode: bool = True, **kwargs) -> None: + def _log( + self, + logger: Logger, + msg: str, + level: str | int = "INFO", + *args, + ascii_encode: bool = True, + **kwargs + ) -> None: if ascii_encode: msg = str(msg).encode("utf-8", "replace").decode("ascii", "replace") @@ -191,7 +199,14 @@ def log( extra: Mapping[str, object] | None = None, ) -> None: ... - def log(self, msg: str, *args, level: str | int = "INFO", log: str | None = None, **kwargs) -> None: + def log( + self, + msg: str, + *args, + level: str | int = "INFO", + log: str | None = None, + **kwargs + ) -> None: """Logs a message to AppDaemon's main logfile. Args: @@ -246,10 +261,24 @@ def log(self, msg: str, *args, level: str | int = "INFO", log: str | None = None @overload def error( - self, msg: str, *args, level: str | int = "INFO", ascii_encode: bool = True, exc_info: bool = False, stack_info: bool = False, stacklevel: int = 1, extra: Mapping[str, object] | None = None + self, + msg: str, + *args, + level: str | int = "INFO", + ascii_encode: bool = True, + exc_info: bool = False, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None ) -> None: ... - def error(self, msg: str, *args, level: str | int = "INFO", **kwargs) -> None: + def error( + self, + msg: str, + *args, + level: str | int = "INFO", + **kwargs + ) -> None: """Logs a message to AppDaemon's error logfile. Args: @@ -278,10 +307,25 @@ def error(self, msg: str, *args, level: str | int = "INFO", **kwargs) -> None: self._log(self.err, msg, level, *args, **kwargs) @overload - async def listen_log(self, callback: Callable, level: str | int, namespace: str, log: str, pin: bool, pin_thread: int, **kwargs) -> str: ... + async def listen_log( + self, + callback: Callable, + level: str | int, + namespace: str, + log: str, + pin: bool, + pin_thread: int, + **kwargs + ) -> str: ... @utils.sync_decorator - async def listen_log(self, callback: Callable, level: str | int = "INFO", namespace: str = "admin", **kwargs) -> str: + async def listen_log( + self, + callback: Callable, + level: str | int = "INFO", + namespace: str = "admin", + **kwargs + ) -> str: """Registers the App to receive a callback every time an App logs a message. Args: @@ -546,7 +590,12 @@ def namespace_exists(self, namespace: str) -> bool: return self.AD.state.namespace_exists(namespace) @utils.sync_decorator - async def add_namespace(self, namespace: str, writeback: str = "safe", persist: bool = True) -> str | None: + async def add_namespace( + self, + namespace: str, + writeback: str = "safe", + persist: bool = True + ) -> str | None: """Used to add a user-defined namespaces from apps, which has a database file associated with it. This way, when AD restarts these entities will be reloaded into AD with its @@ -662,7 +711,11 @@ async def get_app(self, name: str) -> "ADAPI": def _check_entity(self, namespace: str, entity_id: str): """Ensures that the entity exists in the given namespace""" - if entity_id is not None and "." in entity_id and not self.AD.state.entity_exists(namespace, entity_id): + if ( + entity_id is not None and + "." in entity_id and + not self.AD.state.entity_exists(namespace, entity_id) + ): self.logger.warning("%s: Entity %s not found in namespace %s", self.name, entity_id, namespace) @staticmethod @@ -680,7 +733,13 @@ def get_ad_version() -> str: # @utils.sync_decorator - async def add_entity(self, entity_id: str, state: Any | None = None, attributes: dict | None = None, namespace: str | None = None) -> None: + async def add_entity( + self, + entity_id: str, + state: Any | None = None, + attributes: dict | None = None, + namespace: str | None = None + ) -> None: """Adds a non-existent entity, by creating it within a namespaces. If an entity doesn't exists and needs to be created, this function can be used to create it locally. @@ -867,7 +926,13 @@ def friendly_name(self, entity_id: str, namespace: str | None = None) -> str: namespace = namespace or self.namespace self._check_entity(namespace, entity_id) - return self.get_state(entity_id=entity_id, attribute="friendly_name", default=entity_id, namespace=namespace, copy=False) + return self.get_state( + entity_id=entity_id, + attribute="friendly_name", + default=entity_id, + namespace=namespace, + copy=False + ) # if entity_id in state: # if "friendly_name" in state[entity_id]["attributes"]: # return state[entity_id]["attributes"]["friendly_name"] @@ -1295,7 +1360,13 @@ async def listen_state( ) -> str | list[str]: ... @utils.sync_decorator - async def listen_state(self, callback: Callable, entity_id: str | Iterable[str] | None = None, namespace: str | None = None, **kwargs) -> str | list[str]: + async def listen_state( + self, + callback: Callable, + entity_id: str | Iterable[str] | None = None, + namespace: str | None = None, + **kwargs + ) -> str | list[str]: """Registers a callback to react to state changes. This function allows the user to register a callback for a wide variety of state changes. @@ -1439,7 +1510,10 @@ async def listen_state(self, callback: Callable, entity_id: str | Iterable[str] case Iterable(): for e in entity_id: self._check_entity(namespace, e) - return [await self.get_entity_api(namespace, e).listen_state(callback, **kwargs) for e in entity_id] + return [ + await self.get_entity_api(namespace, e).listen_state(callback, **kwargs) + for e in entity_id + ] @utils.sync_decorator async def cancel_listen_state(self, handle: str, silent: bool = False) -> bool: @@ -1559,10 +1633,25 @@ def get_state( return api.get_state(attribute=attribute, default=default, copy=copy) @overload - async def set_state(self, entity_id: str, state: Any | None, namespace: str | None, attributes: dict, replace: bool, **kwargs) -> dict: ... + async def set_state( + self, + entity_id: str, + state: Any | None, + namespace: str | None, + attributes: dict, + replace: bool, + **kwargs + ) -> dict: ... @utils.sync_decorator - async def set_state(self, entity_id: str, state: Any | None = None, namespace: str | None = None, check_existence: bool = True, **kwargs) -> dict: + async def set_state( + self, + entity_id: str, + state: Any | None = None, + namespace: str | None = None, + check_existence: bool = True, + **kwargs + ) -> dict: """Updates the state of the specified entity. Args: @@ -1615,7 +1704,13 @@ def _check_service(service: str) -> None: if service.find("/") == -1: raise ValueError(f"Invalid Service Name: {service}") - def register_service(self, service: str, cb: Callable, namespace: str | None = None, **kwargs) -> None: + def register_service( + self, + service: str, + cb: Callable, + namespace: str | None = None, + **kwargs + ) -> None: """Registers a service that can be called from other apps, the REST API and the Event Stream Using this function, an App can register a function to be available in the service registry. @@ -1648,7 +1743,14 @@ def register_service(self, service: str, cb: Callable, namespace: str | None = N self.logger.debug("register_service: %s, %s", service, kwargs) namespace = namespace or self.namespace - self.AD.services.register_service(namespace, *service.split("/"), cb, __async="auto", __name=self.name, **kwargs) + self.AD.services.register_service( + namespace, + *service.split("/"), + cb, + __async="auto", + __name=self.name, + **kwargs + ) def deregister_service(self, service: str, namespace: str | None = None) -> bool: """Deregisters a service that had been previously registered @@ -1877,10 +1979,26 @@ async def cancel_sequence(self, sequence: str | list[str] | Future) -> None: # @overload - async def listen_event(self, callback: Callable, event: str | list[str], namespace: str | None, timeout: int, oneshot: bool, pin: bool, pin_thread: int, **kwargs) -> str | list[str]: ... + async def listen_event( + self, + callback: Callable, + event: str | list[str], + namespace: str | None, + timeout: int, + oneshot: bool, + pin: bool, + pin_thread: int, + **kwargs + ) -> str | list[str]: ... @utils.sync_decorator - async def listen_event(self, callback: Callable, event: str | list[str] = None, namespace: str | None = None, **kwargs) -> str | list[str]: + async def listen_event( + self, + callback: Callable, + event: str | list[str] = None, + namespace: str | None = None, + **kwargs + ) -> str | list[str]: """Registers a callback for a specific event, or any event. Args: @@ -1953,7 +2071,10 @@ async def listen_event(self, callback: Callable, event: str | list[str] = None, case str() | None: return await self.AD.events.add_event_callback(self.name, namespace, callback, event, **kwargs) case Iterable(): - return [await self.AD.events.add_event_callback(self.name, namespace, callback, e, **kwargs) for e in event] + return [ + await self.AD.events.add_event_callback(self.name, namespace, callback, e, **kwargs) + for e in event + ] case _: self.logger.warning(f"Invalid event: {event}") @@ -2083,7 +2204,14 @@ async def sun_down(self) -> bool: return await self.AD.sched.sun_down() @overload - async def parse_time(self, time_str: str, name: str | None = None, aware: bool = False, today: bool = False, days_offset: int = 0) -> dt.time: ... + async def parse_time( + self, + time_str: str, + name: str | None = None, + aware: bool = False, + today: bool = False, + days_offset: int = 0 + ) -> dt.time: ... @utils.sync_decorator async def parse_time(self, time_str: str, name: str | None = None, *args, **kwargs) -> dt.time: @@ -2137,7 +2265,14 @@ async def parse_time(self, time_str: str, name: str | None = None, *args, **kwar return await self.AD.sched.parse_time(time_str, name, *args, **kwargs) @overload - async def parse_datetime(self, time_str: str, name: str | None = None, aware: bool = False, today: bool = False, days_offset: int = 0) -> dt.time: ... + async def parse_datetime( + self, + time_str: str, + name: str | None = None, + aware: bool = False, + today: bool = False, + days_offset: int = 0 + ) -> dt.time: ... @utils.sync_decorator async def parse_datetime(self, time_str: str, name: str | None = None, *args, **kwargs) -> dt.datetime: @@ -2436,7 +2571,17 @@ async def info_timer(self, handle: str) -> tuple[dt.datetime, int, dict] | None: return await self.AD.sched.info_timer(handle, self.name) @utils.sync_decorator - async def run_in(self, callback: Callable, delay: float, *args, random_start: int = None, random_end: int = None, pin: bool | None = None, pin_thread: int | None = None, **kwargs) -> str: + async def run_in( + self, + callback: Callable, + delay: float, + *args, + random_start: int = None, + random_end: int = None, + pin: bool | None = None, + pin_thread: int | None = None, + **kwargs + ) -> str: """Runs the callback in a defined number of seconds. This is used to add a delay, for instance, a 60 second delay before @@ -2494,7 +2639,15 @@ async def run_in(self, callback: Callable, delay: float, *args, random_start: in @utils.sync_decorator async def run_once( - self, callback: Callable, start: dt.time | str | None = None, *args, random_start: int = None, random_end: int = None, pin: bool | None = None, pin_thread: int | None = None, **kwargs + self, + callback: Callable, + start: dt.time | str | None = None, + *args, + random_start: int = None, + random_end: int = None, + pin: bool | None = None, + pin_thread: int | None = None, + **kwargs ) -> str: """Runs the callback once, at the specified time of day. @@ -2543,11 +2696,26 @@ async def run_once( >>> handle = self.run_once(self.run_once_c, "sunrise + 01:00:00") """ - return await self.run_at(callback, start, *args, random_start=random_start, random_end=random_end, pin=pin, pin_thread=pin_thread, **kwargs) + return await self.run_at( + callback, start, *args, + random_start=random_start, + random_end=random_end, + pin=pin, + pin_thread=pin_thread, + **kwargs + ) @utils.sync_decorator async def run_at( - self, callback: Callable, start: dt.time | str | None = None, *args, random_start: int = None, random_end: int = None, pin: bool | None = None, pin_thread: int | None = None, **kwargs + self, + callback: Callable, + start: dt.time | str | None = None, + *args, + random_start: int = None, + random_end: int = None, + pin: bool | None = None, + pin_thread: int | None = None, + **kwargs ) -> str: """Runs the callback once, at the specified time of day. @@ -2627,7 +2795,15 @@ async def run_at( @utils.sync_decorator async def run_daily( - self, callback: Callable, start: dt.time | str | None = None, *args, random_start: int = None, random_end: int = None, pin: bool | None = None, pin_thread: int | None = None, **cb_kwargs + self, + callback: Callable, + start: dt.time | str | None = None, + *args, + random_start: int = None, + random_end: int = None, + pin: bool | None = None, + pin_thread: int | None = None, + **cb_kwargs ) -> str: """Runs the callback at the same time every day. @@ -2699,11 +2875,32 @@ async def run_daily( match sun: case None: - return await self.run_every(callback, start, timedelta(days=1), *args, **ad_kwargs, **cb_kwargs) + return await self.run_every( + callback, + start, + timedelta(days=1), + *args, + **ad_kwargs, + **cb_kwargs + ) case "sunrise": - return await self.run_at_sunrise(callback, *args, repeat=True, offset=offset, **ad_kwargs, **cb_kwargs) + return await self.run_at_sunrise( + callback, + *args, + repeat=True, + offset=offset, + **ad_kwargs, + **cb_kwargs + ) case "sunset": - return await self.run_at_sunset(callback, *args, repeat=True, offset=offset, **ad_kwargs, **cb_kwargs) + return await self.run_at_sunset( + callback, + *args, + repeat=True, + offset=offset, + **ad_kwargs, + **cb_kwargs + ) @utils.sync_decorator async def run_hourly( @@ -2753,7 +2950,17 @@ async def run_hourly( >>> self.run_hourly(self.run_hourly_c, runtime) """ - return await self.run_every(callback, start, timedelta(hours=1), *args, random_start=random_start, random_end=random_end, pin=pin, pin_thread=pin_thread, **kwargs) + return await self.run_every( + callback, + start, + timedelta(hours=1), + *args, + random_start=random_start, + random_end=random_end, + pin=pin, + pin_thread=pin_thread, + **kwargs + ) @utils.sync_decorator async def run_minutely( @@ -2803,7 +3010,17 @@ async def run_minutely( >>> self.run_minutely(self.run_minutely_c, time) """ - return await self.run_every(callback, start, timedelta(minutes=1), *args, random_start=random_start, random_end=random_end, pin=pin, pin_thread=pin_thread, **kwargs) + return await self.run_every( + callback, + start, + timedelta(minutes=1), + *args, + random_start=random_start, + random_end=random_end, + pin=pin, + pin_thread=pin_thread, + **kwargs + ) @utils.sync_decorator async def run_every( @@ -3041,7 +3258,15 @@ async def run_at_sunrise( # Dashboard # - def dash_navigate(self, target: str, timeout: int = -1, ret: str | None = None, sticky: int = 0, deviceid: str | None = None, dashid: str | None = None) -> None: + def dash_navigate( + self, + target: str, + timeout: int = -1, + ret: str | None = None, + sticky: int = 0, + deviceid: str | None = None, + dashid: str | None = None + ) -> None: """Forces all connected Dashboards to navigate to a new URL. Args: diff --git a/appdaemon/app_management.py b/appdaemon/app_management.py index 4bd3ba712..9bb2bf853 100644 --- a/appdaemon/app_management.py +++ b/appdaemon/app_management.py @@ -121,8 +121,7 @@ def __init__(self, ad: "AppDaemon"): # Apply the profiler_decorator if the config option is enabled if self.AD.check_app_updates_profile: - self.check_app_updates = self.profiler_decorator( - self.check_app_updates) + self.check_app_updates = self.profiler_decorator(self.check_app_updates) @property def config_filecheck(self) -> FileCheck: @@ -656,13 +655,25 @@ async def check_app_config_files(self, update_actions: UpdateActions): files = await self.get_app_config_files() self.dependency_manager.app_deps.update(files) + # If there were config file changes if self.config_filecheck.there_were_changes: self.logger.debug(" Config file changes ".center(75, "=")) self.config_filecheck.log_changes(self.logger, self.AD.app_dir) + # Read any new/modified files into a fresh config model files_to_read = self.config_filecheck.new | self.config_filecheck.modified freshly_read_cfg = await self.read_all(files_to_read) + # TODO: Move this behavior to the model validation step eventually + # It has to be here for now because the files get read in multiple places + for gm in freshly_read_cfg.global_modules(): + rel_path = gm.config_path.relative_to(self.AD.app_dir) + self.logger.warning(f"Global modules are deprecated: '{gm.name}' defined in {rel_path}") + + if gm := freshly_read_cfg.root.get("global_modules"): + gm = ", ".join(f"'{g}'" for g in gm) + self.logger.warning(f"Global modules are deprecated: {gm}") + current_apps = self.valid_apps for name, cfg in freshly_read_cfg.app_definitions(): if isinstance(cfg, SequenceConfig): @@ -817,13 +828,13 @@ async def wrapper(*args, **kwargs): return wrapper # @utils.timeit - async def check_app_updates(self, plugin: str = None, mode: UpdateMode = UpdateMode.NORMAL): + async def check_app_updates(self, plugin_ns: str = None, mode: UpdateMode = UpdateMode.NORMAL): """Checks the states of the Python files that define the apps, reloading when necessary. Called as part of :meth:`.utility_loop.Utility.loop` Args: - plugin (str, optional): Plugin to restart, if necessary. Defaults to None. + plugin_ns (str, optional): Namespace of a plugin to restart, if necessary. Defaults to None. mode (UpdateMode, optional): Defaults to UpdateMode.NORMAL. """ if not self.AD.apps: @@ -859,7 +870,8 @@ async def check_app_updates(self, plugin: str = None, mode: UpdateMode = UpdateM # self._add_reload_apps(update_actions) # self._check_for_deleted_modules(update_actions) - await self._restart_plugin(plugin, update_actions) + if mode == UpdateMode.PLUGIN_RESTART: + await self._restart_plugin_apps(plugin_ns, update_actions) await self._import_modules(update_actions) @@ -994,29 +1006,23 @@ async def check_app_python_files(self, update_actions: UpdateActions): self.logger.info("Deletion affects apps %s", affected) update_actions.apps.term |= affected - async def _restart_plugin(self, plugin, update_actions: UpdateActions): - if plugin is not None: - self.logger.info("Processing restart for %s", plugin) - # This is a restart of one of the plugins so check which apps need to be restarted - for app in self.app_config: - reload = False - if app in self.non_apps: - continue - if "plugin" in self.app_config[app]: - for this_plugin in utils.single_or_list(self.app_config[app]["plugin"]): - if this_plugin == plugin: - # We got a match so do the reload - reload = True - break - elif plugin == "__ALL__": - reload = True - break - else: - # No plugin dependency specified, reload to error on the side of caution - reload = True + async def _restart_plugin_apps(self, plugin_ns: str | None, update_actions: UpdateActions): + """If a plugin ever re-connects after the initial startup, the apps that use it's plugin + all need to be restarted. The apps that belong to the plugin are determined by namespace. + """ + if plugin_ns is not None: + self.logger.info(f"Processing restart for plugin namespace '{plugin_ns}'") + + app_names = set( + app + for app, cfg in self.app_config.root.items() # For each config key + if isinstance(cfg, AppConfig) and # The config key is for an app + (mo := self.objects.get(app)) and # There's a valid ManagedObject + mo.object.namespace == plugin_ns # Its namespace matches the plugins + ) - if reload is True: - update_actions.apps.reload.add(app) + deps = self.dependency_manager.app_deps.get_dependents(app_names) + update_actions.apps.reload |= deps async def _stop_apps(self, update_actions: UpdateActions): """Terminate apps. Returns the set of app names that failed to properly terminate. diff --git a/appdaemon/http.py b/appdaemon/http.py index 916f86a99..28c85399c 100644 --- a/appdaemon/http.py +++ b/appdaemon/http.py @@ -95,7 +95,9 @@ async def wrapper(*args): return await myfunc(*args) elif "adcreds" in request.cookies: - match = await utils.run_in_executor(self, bcrypt.checkpw, str.encode(self.password), str.encode(request.cookies["adcreds"])) + match = await utils.run_in_executor( + self, bcrypt.checkpw, str.encode(self.password), str.encode(request.cookies["adcreds"]) + ) if match: return await myfunc(*args) @@ -369,7 +371,9 @@ def _process_http(self, http): self._process_arg("url", http) if not self.url: - self.logger.warning("'{arg}' is '{value}'. Please configure appdaemon.yaml".format(arg="url", value=self.url)) + self.logger.warning( + "'{arg}' is '{value}'. Please configure appdaemon.yaml".format(arg="url", value=self.url) + ) exit(0) self._process_arg("transport", http) diff --git a/appdaemon/models/config/app.py b/appdaemon/models/config/app.py index 45e1061a8..9347f744e 100644 --- a/appdaemon/models/config/app.py +++ b/appdaemon/models/config/app.py @@ -157,6 +157,12 @@ def app_definitions(self): if isinstance(cfg, (AppConfig, SequenceConfig)) ) + def global_modules(self) -> list[GlobalModule]: + return [ + cfg for cfg in self.root.values() + if isinstance(cfg, GlobalModule) + ] + def app_names(self) -> set[str]: """Returns all the app names for regular user apps and global module apps""" return set(app_name for app_name, cfg in self.root.items() if isinstance(cfg, BaseApp)) @@ -165,7 +171,12 @@ def apps_from_file(self, paths: Iterable[Path]): if not isinstance(paths, set): paths = set(paths) - return set(app_name for app_name, cfg in self.root.items() if isinstance(cfg, (AppConfig, GlobalModule)) and cfg.config_path in paths) + return set( + app_name + for app_name, cfg in self.root.items() + if isinstance(cfg, BaseApp) and + cfg.config_path in paths + ) @property def active_app_count(self) -> int: diff --git a/appdaemon/models/config/plugin.py b/appdaemon/models/config/plugin.py index c2954e9a2..658bfb61b 100644 --- a/appdaemon/models/config/plugin.py +++ b/appdaemon/models/config/plugin.py @@ -10,7 +10,7 @@ from typing_extensions import deprecated -class PluginConfig(BaseModel, extra="allow"): +class PluginConfig(BaseModel, extra="forbid"): type: str name: str """Gets set by a field_validator in the AppDaemonConfig""" @@ -59,12 +59,34 @@ def disabled(self) -> bool: return self.disable +class StartupState(BaseModel): + state: Any + attributes: dict[str, Any] | None = None + + +class StateStartupCondition(BaseModel): + entity: str + value: StartupState | None = None + + +class EventStartupCondition(BaseModel): + event_type: str + data: dict | None = None + + +class StartupConditions(BaseModel): + delay: int | float | None = None + state: StateStartupCondition | None = None + event: EventStartupCondition | None = None + + + class HASSConfig(PluginConfig): ha_url: str = "http://supervisor/core" token: SecretStr ha_key: Annotated[SecretStr, deprecated("'ha_key' is deprecated. Please use long lived tokens instead")] | None = None - appdaemon_startup_conditions: dict = Field(default_factory=dict) - plugin_startup_conditions: dict = Field(default_factory=dict) + appdaemon_startup_conditions: StartupConditions | None = None + plugin_startup_conditions: StartupConditions | None = None cert_path: Path | None = None cert_verify: bool | None = None commtype: str = "WS" diff --git a/appdaemon/models/internal/app_management.py b/appdaemon/models/internal/app_management.py index 11bcfc7e5..adc8620e4 100644 --- a/appdaemon/models/internal/app_management.py +++ b/appdaemon/models/internal/app_management.py @@ -1,7 +1,7 @@ import uuid from copy import copy from dataclasses import dataclass, field -from enum import Enum +from enum import Enum, auto from pathlib import Path from typing import Any, Literal, Optional @@ -20,9 +20,10 @@ class UpdateMode(Enum): Terminate all apps """ - INIT = 0 - NORMAL = 1 - TERMINATE = 2 + INIT = auto() + NORMAL = auto() + PLUGIN_RESTART = auto() + TERMINATE = auto() diff --git a/appdaemon/plugin_management.py b/appdaemon/plugin_management.py index a2e7de59a..b539a0c6c 100644 --- a/appdaemon/plugin_management.py +++ b/appdaemon/plugin_management.py @@ -37,6 +37,7 @@ class PluginBase(abc.ABC): updates_recv: int last_check_ts: float + connect_event: asyncio.Event ready_event: asyncio.Event constraints: list @@ -46,7 +47,8 @@ class PluginBase(abc.ABC): The first connection a plugin makes is handled a little differently because it'll be at startup and it'll be before any apps have been - loaded.""" + loaded. + """ stopping: bool = False """Flag that indicates whether AppDaemon is currently shutting down.""" @@ -57,6 +59,7 @@ def __init__(self, ad: "AppDaemon", name: str, config: PluginConfig): self.config = config self.logger = self.AD.logging.get_child(name) self.error = self.logger + self.connect_event = asyncio.Event() self.ready_event = asyncio.Event() self.constraints = [] self.stopping = False @@ -92,9 +95,7 @@ def namespaces(self, new: Iterable[str]): self.config.namespaces = new @property - def all_namespaces( - self, - ) -> list[str]: + def all_namespaces(self) -> list[str]: """A list of namespaces that includes the main namespace as well as any extra ones.""" return [self.namespace] + self.namespaces @@ -155,7 +156,7 @@ async def notify_plugin_started(self, meta: dict, state: dict): - sets the namespace state in self.AD.state - adds the plugin entity in self.AD.state - - sets the pluginobject to active + - sets the plugin object to active - fires a ``plugin_started`` event Arguments: @@ -173,7 +174,11 @@ async def notify_plugin_started(self, meta: dict, state: dict): event_coro = self.AD.events.process_event(ns, event) self.AD.loop.create_task(event_coro) self.AD.plugins.plugin_meta[ns] = meta - await self.AD.state.set_namespace_state(namespace=ns, state=state, persist=self.config.persist_entities) + await self.AD.state.set_namespace_state( + namespace=ns, + state=state, + persist=self.config.persist_entities + ) # This accounts for the case where there's not a plugin associated with the object if po := self.AD.plugins.plugin_objs.get(ns): @@ -196,7 +201,11 @@ async def notify_plugin_started(self, meta: dict, state: dict): ) if not self.first_time: - self.AD.loop.create_task(self.AD.app_management.check_app_updates(plugin=self.name, mode=UpdateMode.INIT)) + self.AD.loop.create_task( + self.AD.app_management.check_app_updates( + plugin_ns=self.namespace, + mode=UpdateMode.PLUGIN_RESTART + )) class PluginManagement: @@ -222,9 +231,9 @@ class PluginManagement: plugin_objs: Dict[str, PluginBase] """Dictionary storing the instantiated plugin objects. {: { - "object": , - "active": , - "name": + "object": , + "active": , + "name": }} """ required_meta = ["latitude", "longitude", "elevation", "time_zone"] @@ -407,17 +416,24 @@ def get_plugin_from_namespace(self, namespace: str) -> str: async def notify_plugin_stopped(self, name, namespace): self.plugin_objs[namespace]["active"] = False - await self.AD.events.process_event(namespace, {"event_type": "plugin_stopped", "data": {"name": name}}) + data = {"event_type": "plugin_stopped", "data": {"name": name}} + await self.AD.events.process_event(namespace, data) def get_plugin_meta(self, namespace: str) -> dict: return self.plugin_meta.get(namespace, {}) - async def wait_for_plugins(self): - self.logger.info("Waiting for plugins to be ready") - events: Iterable[asyncio.Event] = (plugin["object"].ready_event for plugin in self.plugin_objs.values()) - tasks = (self.AD.loop.create_task(e.wait()) for e in events) - await asyncio.wait(tasks) - self.logger.info("All plugins ready") + async def wait_for_plugins(self, timeout: float | None = None): + """Waits for the user-configured plugin startup conditions. + + Specifically, this waits for each of their ready events + """ + self.logger.info('Waiting for plugins to be ready') + events: Generator[asyncio.Event, None, None] = ( + plugin['object'].ready_event for plugin in self.plugin_objs.values() + ) + tasks = [self.AD.loop.create_task(e.wait()) for e in events] + await asyncio.wait(tasks, timeout=timeout) + self.logger.info('All plugins ready') def get_config_for_namespace(self, namespace: str) -> PluginConfig: plugin_name = self.get_plugin_from_namespace(namespace) @@ -442,7 +458,10 @@ async def update_plugin_state(self): if await self.time_since_plugin_update(plugin.name) > cfg.refresh_delay: self.logger.debug(f"Refreshing {plugin.name}[{cfg.type}] state") try: - state = await asyncio.wait_for(plugin.get_complete_state(), timeout=cfg.refresh_timeout) + state = await asyncio.wait_for( + plugin.get_complete_state(), + timeout=cfg.refresh_timeout + ) except asyncio.TimeoutError: self.logger.warning( "Timeout refreshing %s state - retrying in %s seconds", diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index f1a13d9ae..55caa5fbc 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -19,7 +19,7 @@ import appdaemon.utils as utils from appdaemon.appdaemon import AppDaemon -from appdaemon.models.config.plugin import HASSConfig +from appdaemon.models.config.plugin import HASSConfig, StartupConditions from appdaemon.plugin_management import PluginBase from .exceptions import HAEventsSubError @@ -74,6 +74,8 @@ class HassPlugin(PluginBase): _silent_results: dict[int, bool] startup_conditions: list[StartupWaitCondition] + start: float + first_time: bool = True stopping: bool = False @@ -91,7 +93,6 @@ def __init__(self, ad: "AppDaemon", name: str, config: HASSConfig): self.stopping = False self.logger.info("HASS Plugin initialization complete") - self.start = perf_counter() def stop(self): self.logger.debug("stop() called for %s", self.name) @@ -119,12 +120,14 @@ async def websocket_msg_factory(self): Handles creating the connection based on the HASSConfig and updates the performance counters """ + self.start = perf_counter() async with self.create_session() as self.session: async with self.session.ws_connect(self.config.websocket_url) as self.ws: self.id = 0 async for msg in self.ws: self.update_perf(bytes_recv=len(msg.data), updates_recv=1) yield msg + self.connect_event.clear() async def match_ws_msg(self, msg: aiohttp.WSMessage) -> dict: """Wraps a match/case statement for the ``msg.type``""" @@ -169,6 +172,7 @@ async def process_websocket_json(self, resp: dict): async def __post_conn__(self): """Initialization to do after getting connected to the Home Assistant websocket""" + self.connect_event.set() return await self.websocket_send_json(**self.config.auth_json) async def __post_auth__(self): @@ -191,8 +195,13 @@ async def __post_auth__(self): service_coro = looped_coro(self.get_hass_services, self.config.services_sleep_time) self.AD.loop.create_task(service_coro(self)) - await self.wait_for_start_conditions() - self.logger.info("All startup conditions met") + + if self.first_time: + conditions = self.config.appdaemon_startup_conditions + else: + conditions = self.config.plugin_startup_conditions + await self.wait_for_conditions(conditions) + self.logger.info("All plugin startup conditions met") self.ready_event.set() await self.notify_plugin_started( @@ -243,7 +252,10 @@ async def receive_event(self, event: dict): # check startup conditions if not self.is_ready: for condition in self.startup_conditions: - condition.check_received_event(event) + if not condition.conditions_met: + condition.check_received_event(event) + if condition.conditions_met: + self.logger.info(f'HASS startup condition met {condition}') match typ := event["event_type"]: # https://data.home-assistant.io/docs/events/#service_registered @@ -266,6 +278,8 @@ async def receive_event(self, event: dict): ... case "android.zone_entered": ... + case "component_loaded": + self.logger.debug(f'Loaded component: {event["data"]["component"]}') case _: if typ.startswith('recorder'): return @@ -388,49 +402,59 @@ async def http_method( raise NotImplementedError('Unhandled error: HTTP %s', resp.status) return resp - async def wait_for_start_conditions(self): - condition_tasks = [] - if delay := self.config.plugin_startup_conditions.get('delay'): - self.logger.info(f'Adding a {delay:.0f}s delay to the {self.name} startup') - condition_tasks.append( - self.AD.loop.create_task( - asyncio.sleep(delay) - ) - ) + async def wait_for_conditions(self, conditions: StartupConditions | None): + if conditions is None: + return + + self.startup_conditions = [] - if event := self.config.plugin_startup_conditions.get('event'): + if event := conditions.event: self.logger.info(f'Adding startup event condition: {event}') - condition = StartupWaitCondition(event) - self.startup_conditions.append(condition) - condition_tasks.append( + event_cond_data = event.model_dump(exclude_unset=True) + self.startup_conditions.append(StartupWaitCondition(event_cond_data)) + + if cond := conditions.state: + current_state = await self.check_for_entity(cond.entity) + if cond.value is None: + if current_state is False: + # Wait for entity to exist + self.startup_conditions.append( + StartupWaitCondition({ + 'event_type': 'state_changed', + 'data': {'entity_id': cond.entity} + })) + else: + self.logger.info(f'Startup state condition already met: {cond.entity} exists') + else: + data = cond.model_dump(exclude_unset=True) + if utils.deep_compare(data['value'], current_state): + self.logger.info(f'Startup state condition already met: {data}') + else: + self.logger.info(f'Adding startup state condition: {data}') + self.startup_conditions.append(StartupWaitCondition({ + 'event_type': 'state_changed', + 'data': { + 'entity_id': cond.entity, + 'new_state': data['value'] + } + })) + + tasks = [ + self.AD.loop.create_task(cond.event.wait()) + for cond in self.startup_conditions + ] + + if delay := conditions.delay: + self.logger.info(f'Adding a {delay:.0f}s delay to the {self.name} startup') + tasks.append( self.AD.loop.create_task( - condition.event.wait() + asyncio.sleep(delay) ) ) - if cond := self.config.plugin_startup_conditions.get('state'): - state = await self.get_plugin_state(cond['entity']) - if utils.deep_compare(cond['value'], state): - self.logger.info(f'Startup state condition already met: {cond}') - else: - self.logger.info(f'Adding startup state condition: {cond}') - condition = StartupWaitCondition({ - 'event_type': 'state_changed', - 'data': { - 'entity_id': cond['entity'], - 'new_state': cond['value'] - } - }) - self.startup_conditions.append(condition) - condition_tasks.append( - self.AD.loop.create_task( - condition.event.wait() - ) - ) - - self.logger.info(f'Waiting for {len(condition_tasks)} startup condition tasks after {self.time_str()}') - if condition_tasks: - await asyncio.wait(condition_tasks) + self.logger.info(f'Waiting for {len(tasks)} startup condition tasks after {self.time_str()}') + if tasks: + await asyncio.wait(tasks) async def get_updates(self): while not self.stopping: @@ -712,11 +736,13 @@ async def safe_set_state(self: 'HassPlugin'): async def get_plugin_state(self, entity_id: str, timeout: float | None = None): return await self.http_method('get', f'/api/states/{entity_id}', timeout) - async def check_for_entity(self, entity_id: str, timeout: float | None = None) -> bool: - """Tries to get the state of an entity ID to see if it exists""" + async def check_for_entity(self, entity_id: str, timeout: float | None = None) -> dict | Literal[False]: + """Tries to get the state of an entity ID to see if it exists. + + Returns a dict of the state if the entity exists. Otherwise returns False""" resp = await self.get_plugin_state(entity_id, timeout) if isinstance(resp, dict): - return True + return resp elif isinstance(resp, ClientResponse) and resp.status == 404: return False diff --git a/appdaemon/plugins/hass/utils.py b/appdaemon/plugins/hass/utils.py index f62d6e293..51b7a74b1 100644 --- a/appdaemon/plugins/hass/utils.py +++ b/appdaemon/plugins/hass/utils.py @@ -23,14 +23,14 @@ async def loop(self: "HassPlugin", *args, **kwargs): def hass_check(func): - """Essentially swallows the function call if the Home Assistant plugin isn't ready, in which case the function will return None. + """Essentially swallows the function call if the Home Assistant plugin isn't connected, in which case the function will return None. """ async def no_func(): pass @functools.wraps(func) def func_wrapper(self: "HassPlugin", *args, **kwargs): - if not self.is_ready: + if not self.connect_event.is_set(): self.logger.warning("Attempt to call Home Assistant while disconnected: %s", func.__name__) return no_func() else: diff --git a/appdaemon/state.py b/appdaemon/state.py index 638c498a0..dcb87cc23 100644 --- a/appdaemon/state.py +++ b/appdaemon/state.py @@ -156,7 +156,14 @@ def terminate(self): self.logger.info("Saving all namespaces") self.save_all_namespaces() - async def add_state_callback(self, name: str, namespace: str, entity: str, cb: Callable, kwargs: dict[str, Any]): # noqa: C901 + async def add_state_callback( + self, + name: str, + namespace: str, + entity: str, + cb: Callable, + kwargs: dict[str, Any] + ): # noqa: C901 # Filter none values, which might be present as defaults kwargs = {k: v for k, v in kwargs.items() if v is not None} @@ -226,7 +233,10 @@ async def add_state_callback(self, name: str, namespace: str, entity: str, cb: C if "new" in kwargs: if __attribute is None and self.state[namespace][entity].get("state") == kwargs["new"]: __new_state = kwargs["new"] - elif __attribute is not None and self.state[namespace][entity]["attributes"].get(__attribute) == kwargs["new"]: + elif ( + __attribute is not None + and self.state[namespace][entity]["attributes"].get(__attribute) == kwargs["new"] + ): __new_state = kwargs["new"] else: run = False @@ -295,7 +305,9 @@ async def cancel_state_callback(self, handle, name, silent=False): del self.AD.callbacks.callbacks[name] if not executed and not silent: - self.logger.warning("Invalid callback handle '{}' in cancel_state_callback() from app {}".format(handle, name)) + self.logger.warning( + f"Invalid callback handle '{handle}' in cancel_state_callback() from app {name}" + ) return executed @@ -325,7 +337,11 @@ async def process_state_callbacks(self, namespace, state): for name in self.AD.callbacks.callbacks.keys(): for uuid_ in self.AD.callbacks.callbacks[name]: callback = self.AD.callbacks.callbacks[name][uuid_] - if callback["type"] == "state" and (callback["namespace"] == namespace or callback["namespace"] == "global" or namespace == "global"): + if callback["type"] == "state" and ( + callback["namespace"] == namespace or + callback["namespace"] == "global" or + namespace == "global" + ): cdevice = None centity = None if callback["entity"] is not None: @@ -528,9 +544,21 @@ def maybe_copy(data): return maybe_copy(self.state[namespace]) domain = entity_id.split(".", 1)[0] - return {entity_id: maybe_copy(state) for entity_id, state in self.state[namespace].items() if entity_id.split(".", 1)[0] == domain} + return { + entity_id: maybe_copy(state) + for entity_id, state in self.state[namespace].items() + if entity_id.split(".", 1)[0] == domain + } - def parse_state(self, namespace: str, entity: str, state: Any | None = None, attributes: dict | None = None, replace: bool = False, **kwargs): + def parse_state( + self, + namespace: str, + entity: str, + state: Any | None = None, + attributes: dict | None = None, + replace: bool = False, + **kwargs + ): self.logger.debug(f"parse_state: {entity}, {kwargs}") if entity in self.state[namespace]: @@ -611,7 +639,17 @@ async def state_services(self, namespace, domain, service, kwargs): self.logger.warning("Unknown service in state service call: %s", kwargs) @overload - async def set_state(self, name: str, namespace: str, entity: str, _silent: bool, state: Any, attributes: dict, replace: bool, **kwargs) -> None: ... + async def set_state( + self, + name: str, + namespace: str, + entity: str, + _silent: bool, + state: Any, + attributes: dict, + replace: bool, + **kwargs + ) -> None: ... async def set_state(self, name: str, namespace: str, entity: str, _silent: bool = False, **kwargs): """Sets the internal state of an entity. Uses relevant plugin objects based on namespace. @@ -648,7 +686,12 @@ async def set_state(self, name: str, namespace: str, entity: str, _silent: bool # We assume that the state change will come back to us via the plugin self.logger.debug("sending event to plugin") - result = await set_plugin_state(namespace, entity, state=new_state["state"], attributes=new_state["attributes"]) + result = await set_plugin_state( + namespace, + entity, + state=new_state["state"], + attributes=new_state["attributes"] + ) if result is not None: if "entity_id" in result: result.pop("entity_id")