diff --git a/aiowebostv/webos_client.py b/aiowebostv/webos_client.py index 33f368d..acd7454 100644 --- a/aiowebostv/webos_client.py +++ b/aiowebostv/webos_client.py @@ -56,7 +56,6 @@ def __init__( self.connect_result: Future[bool] | None = None self.connection: ClientWebSocketResponse | None = None self.input_connection: ClientWebSocketResponse | None = None - self.callbacks: dict[int, Callable] = {} self.futures: dict[int, Future[dict[str, Any]]] = {} self._power_state: dict[str, Any] = {} self._current_app_id: str | None = None @@ -77,6 +76,9 @@ def __init__( self._volume_step_delay: timedelta | None = None self._loop = asyncio.get_running_loop() self._media_state: list[dict[str, Any]] = [] + self.callback_queues: dict[int, Queue[dict[str, Any]]] = {} + self.callback_tasks: dict[int, Task] = {} + self._rx_tasks: set[Task] = set() async def connect(self) -> bool: """Connect to webOS TV device.""" @@ -133,183 +135,218 @@ async def close_client_session(self) -> None: self.created_client_session = False self.client_session = None - async def connect_handler(self, res: Future) -> None: - """Handle connection for webOS TV.""" - handler_tasks: set[Task] = set() - main_ws: ClientWebSocketResponse | None = None - input_ws: ClientWebSocketResponse | None = None + async def _create_main_ws(self) -> ClientWebSocketResponse: + """Create main websocket connection. + Try using ws:// and fallback to wss:// if the TV rejects the connection. + """ try: - # Create a new client session if not provided - if self.client_session is None: - self.client_session = ClientSession() - self.created_client_session = True - - try: - uri = f"ws://{self.host}:{WS_PORT}" - main_ws = await self._ws_connect(uri) - # ClientConnectionError is raised when firmware reject WS_PORT - # WSServerHandshakeError is raised when firmware enforce using ssl - except (aiohttp.ClientConnectionError, aiohttp.WSServerHandshakeError): - uri = f"wss://{self.host}:{WSS_PORT}" - main_ws = await self._ws_connect(uri) - - # send hello - _LOGGER.debug("send(%s): hello", self.host) - await main_ws.send_json({"id": "hello", "type": "hello"}) - response = await main_ws.receive_json() - _LOGGER.debug("recv(%s): %s", self.host, response) - - if response["type"] == "hello": - self._hello_info = response["payload"] - else: - error = f"Invalid response type {response}" - raise WebOsTvCommandError(error) - - # send registration - _LOGGER.debug("send(%s): registration", self.host) - await main_ws.send_json(self.registration_msg()) - response = await main_ws.receive_json() - _LOGGER.debug("recv(%s): registration", self.host) - - if ( - response["type"] == "response" - and response["payload"]["pairingType"] == "PROMPT" - ): - response = await main_ws.receive_json() - _LOGGER.debug("recv(%s): pairing", self.host) - _LOGGER.debug( - "pairing(%s): type: %s, error: %s", - self.host, - response["type"], - response.get("error"), - ) - if response["type"] == "error": - raise WebOsTvPairError(response["error"]) - if response["type"] == "registered": - self.client_key = response["payload"]["client-key"] - - if not self.client_key: - error = "Client key not set, pairing failed." - raise WebOsTvPairError(error) - - self.callbacks = {} - self.futures = {} - - handler_tasks.add( - asyncio.create_task( - self.consumer_handler(main_ws, self.callbacks, self.futures) - ) - ) - self.connection = main_ws + uri = f"ws://{self.host}:{WS_PORT}" + return await self._ws_connect(uri) + # ClientConnectionError is raised when firmware reject WS_PORT + # WSServerHandshakeError is raised when firmware enforce using ssl + except (aiohttp.ClientConnectionError, aiohttp.WSServerHandshakeError): + uri = f"wss://{self.host}:{WSS_PORT}" + return await self._ws_connect(uri) + + def _ensure_client_session(self) -> None: + """Create a new client session if no client session provided.""" + if self.client_session is None: + self.client_session = ClientSession() + self.created_client_session = True + + async def _get_hello_info(self, ws: ClientWebSocketResponse) -> None: + """Get hello info.""" + _LOGGER.debug("send(%s): hello", self.host) + await ws.send_json({"id": "hello", "type": "hello"}) + response = await ws.receive_json() + _LOGGER.debug("recv(%s): %s", self.host, response) + + if response["type"] == "hello": + self._hello_info = response["payload"] + else: + error = f"Invalid response type {response}" + raise WebOsTvCommandError(error) - # open additional connection needed to send button commands - # the url is dynamically generated and returned from the ep.INPUT_SOCKET - # endpoint on the main connection - # create an empty consumer handler to keep ping/pong alive - sockres = await self.request(ep.INPUT_SOCKET) - inputsockpath = sockres["socketPath"] - input_ws = await self._ws_connect(inputsockpath) - handler_tasks.add( - asyncio.create_task(self.consumer_handler(input_ws, None, None)) + async def _check_registration(self, ws: ClientWebSocketResponse) -> None: + """Check if the client is registered with the tv.""" + _LOGGER.debug("send(%s): registration", self.host) + await ws.send_json(self.registration_msg()) + response = await ws.receive_json() + _LOGGER.debug("recv(%s): registration", self.host) + + if ( + response["type"] == "response" + and response["payload"]["pairingType"] == "PROMPT" + ): + response = await ws.receive_json() + _LOGGER.debug("recv(%s): pairing", self.host) + _LOGGER.debug( + "pairing(%s): type: %s, error: %s", + self.host, + response["type"], + response.get("error"), ) - self.input_connection = input_ws + if response["type"] == "error": + raise WebOsTvPairError(response["error"]) + if response["type"] == "registered": + self.client_key = response["payload"]["client-key"] - # set static state and subscribe to state updates - # avoid partial updates during initial subscription + if not self.client_key: + error = "Client key not set, pairing failed." + raise WebOsTvPairError(error) - self.do_state_update = False - self._system_info, self._software_info = await asyncio.gather( - self.get_system_info(), self.get_software_info() - ) - subscribe_state_updates = { - self.subscribe_power_state(self.set_power_state), - self.subscribe_current_app(self.set_current_app_state), - self.subscribe_muted(self.set_muted_state), - self.subscribe_volume(self.set_volume_state), - self.subscribe_apps(self.set_apps_state), - self.subscribe_inputs(self.set_inputs_state), - self.subscribe_sound_output(self.set_sound_output_state), - self.subscribe_media_foreground_app(self.set_media_state), - } - subscribe_tasks = set() - for state_update in subscribe_state_updates: - subscribe_tasks.add(asyncio.create_task(state_update)) - await asyncio.wait(subscribe_tasks) - for task in subscribe_tasks: - with suppress(WebOsTvServiceNotFoundError): - task.result() - # set placeholder power state if not available - if not self._power_state: - self._power_state = {"state": "Unknown"} - self.do_state_update = True - if self.state_update_callbacks: - await self.do_state_update_callbacks() + async def _create_input_ws(self) -> ClientWebSocketResponse: + """Create input websocket connection. - res.set_result(True) + Open additional connection needed to send button commands + The url is dynamically generated and returned from the ep.INPUT_SOCKET + endpoint on the main connection. + """ + sockres = await self.request(ep.INPUT_SOCKET) + inputsockpath = sockres["socketPath"] + return await self._ws_connect(inputsockpath) - await asyncio.wait(handler_tasks, return_when=asyncio.FIRST_COMPLETED) + async def _get_states_and_subscribe_state_updates(self) -> None: + """Get initial states and subscribe to state updates. - except Exception as ex: # pylint: disable=broad-except - if isinstance(ex, TimeoutError): - _LOGGER.debug("timeout(%s): connection", self.host) - else: - _LOGGER.debug("exception(%s): %r", self.host, ex, exc_info=True) - if not res.done(): - res.set_exception(ex) - finally: - for task in handler_tasks: - if not task.done(): - task.cancel() - - for future in self.futures.values(): - future.cancel() - - closeout = set() - closeout.update(handler_tasks) - - if main_ws is not None: - closeout.add(asyncio.create_task(main_ws.close())) - if input_ws is not None: - closeout.add(asyncio.create_task(input_ws.close())) - if self.created_client_session: - closeout.add(asyncio.create_task(self.close_client_session())) - - self.connection = None - self.input_connection = None - - self.do_state_update = False - - self._power_state = {} - self._current_app_id = None - self._muted = None - self._volume = None - self._current_channel = None - self._channel_info = None - self._channels = None - self._apps = {} - self._extinputs = {} - self._system_info = {} - self._software_info = {} - self._hello_info = {} - self._sound_output = None - self._media_state = [] + Avoid partial updates during initial subscription. + """ + self.do_state_update = False + self._system_info, self._software_info = await asyncio.gather( + self.get_system_info(), self.get_software_info() + ) + subscribe_state_updates = { + self.subscribe_power_state(self.set_power_state), + self.subscribe_current_app(self.set_current_app_state), + self.subscribe_muted(self.set_muted_state), + self.subscribe_volume(self.set_volume_state), + self.subscribe_apps(self.set_apps_state), + self.subscribe_inputs(self.set_inputs_state), + self.subscribe_sound_output(self.set_sound_output_state), + self.subscribe_media_foreground_app(self.set_media_state), + } + subscribe_tasks = set() + for state_update in subscribe_state_updates: + subscribe_tasks.add(asyncio.create_task(state_update)) + await asyncio.wait(subscribe_tasks) + for task in subscribe_tasks: + with suppress(WebOsTvServiceNotFoundError): + task.result() + # set placeholder power state if not available + if not self._power_state: + self._power_state = {"state": "Unknown"} + self.do_state_update = True + if self.state_update_callbacks: + await self.do_state_update_callbacks() + + def _clear_tv_states(self) -> None: + """Clear all TV states.""" + self._power_state = {} + self._current_app_id = None + self._muted = None + self._volume = None + self._current_channel = None + self._channel_info = None + self._channels = None + self._apps = {} + self._extinputs = {} + self._system_info = {} + self._software_info = {} + self._hello_info = {} + self._sound_output = None + self._media_state = [] + + def _cancel_tasks(self) -> None: + """Cancel all tasks.""" + for callback_task in self.callback_tasks.values(): + if not callback_task.done(): + callback_task.cancel() + + for task in self._rx_tasks: + if not task.done(): + task.cancel() + + for future in self.futures.values(): + future.cancel() + + async def _closeout_tasks( + self, + main_ws: ClientWebSocketResponse | None, + input_ws: ClientWebSocketResponse | None, + ) -> None: + """Cancel all tasks and close connections.""" + closeout = set() + + self._cancel_tasks() + + if callback_tasks := set(self.callback_tasks.values()): + closeout.update(callback_tasks) + + closeout.update(self._rx_tasks) + + if main_ws is not None: + closeout.add(asyncio.create_task(main_ws.close())) + if input_ws is not None: + closeout.add(asyncio.create_task(input_ws.close())) + if self.created_client_session: + closeout.add(asyncio.create_task(self.close_client_session())) + + self.connection = None + self.input_connection = None + self.do_state_update = False + self._clear_tv_states() + + for callback in self.state_update_callbacks: + closeout.add(asyncio.create_task(callback(self))) - for callback in self.state_update_callbacks: - closeout.add(asyncio.create_task(callback(self))) + if not closeout: + return - if closeout: - closeout_task = asyncio.create_task(asyncio.wait(closeout)) + closeout_task = asyncio.create_task(asyncio.wait(closeout)) + while not closeout_task.done(): + with suppress(asyncio.CancelledError): + await asyncio.shield(closeout_task) - while not closeout_task.done(): - with suppress(asyncio.CancelledError): - await asyncio.shield(closeout_task) + async def connect_handler(self, res: Future) -> None: + """Handle connection for webOS TV.""" + self._rx_tasks = set() + self.callback_queues = {} + self.callback_tasks = {} + self.futures = {} + main_ws: ClientWebSocketResponse | None = None + input_ws: ClientWebSocketResponse | None = None + self._ensure_client_session() + try: + main_ws = await self._create_main_ws() + await self._get_hello_info(main_ws) + await self._check_registration(main_ws) + self._rx_tasks.add(asyncio.create_task(self._rx_msgs_main_ws(main_ws))) + self.connection = main_ws + input_ws = await self._create_input_ws() + self._rx_tasks.add(asyncio.create_task(self._rx_msgs_input_ws(input_ws))) + self.input_connection = input_ws + await self._get_states_and_subscribe_state_updates() + res.set_result(True) + await asyncio.wait(self._rx_tasks, return_when=asyncio.FIRST_COMPLETED) + except TimeoutError: + _LOGGER.debug("timeout(%s): connection", self.host) + if not res.done(): + res.set_exception(asyncio.TimeoutError) + except Exception as ex: + _LOGGER.debug("exception(%s): %r", self.host, ex, exc_info=True) + if not res.done(): + res.set_exception(ex) + else: + raise + finally: + await self._closeout_tasks(main_ws, input_ws) @staticmethod async def callback_handler( queue: Queue[dict[str, Any]], callback: Callable, - future: Future[dict[str, Any]] | None, + future: Future[dict[str, Any]], ) -> None: """Handle callbacks.""" with suppress(asyncio.CancelledError): @@ -317,54 +354,39 @@ async def callback_handler( msg = await queue.get() payload = msg.get("payload") await callback(payload) - if future is not None and not future.done(): + if not future.done(): future.set_result(msg) - async def consumer_handler( - self, - web_socket: ClientWebSocketResponse, - callbacks: dict[int, Callable] | None, - futures: dict[int, Future] | None, - ) -> None: - """Callbacks consumer handler.""" - callback_queues: dict[int, Queue[dict[str, Any]]] = {} - callback_tasks: dict[int, Task] = {} - - try: - async for raw_msg in web_socket: - _LOGGER.debug("recv(%s): %s", self.host, raw_msg) - if raw_msg.type is not WSMsgType.TEXT: - break - - if callbacks or futures: - msg = json.loads(raw_msg.data) - uid = msg.get("id") - callback = self.callbacks.get(uid) - future = self.futures.get(uid) - if callback is not None: - if uid not in callback_tasks: - queue: Queue[dict[str, Any]] = asyncio.Queue() - callback_queues[uid] = queue - callback_tasks[uid] = asyncio.create_task( - self.callback_handler(queue, callback, future) - ) - callback_queues[uid].put_nowait(msg) - elif future is not None and not future.done(): - self.futures[uid].set_result(msg) - - finally: - for task in callback_tasks.values(): - if not task.done(): - task.cancel() - - tasks = set(callback_tasks.values()) - - if tasks: - closeout_task = asyncio.create_task(asyncio.wait(tasks)) - - while not closeout_task.done(): - with suppress(asyncio.CancelledError): - await asyncio.shield(closeout_task) + def _process_text_message(self, data: str) -> None: + """Process text message.""" + msg = json.loads(data) + uid = msg.get("id") + # if we have a callback for this message, put it in the queue + # let the callback handle the message and mark the future as done + if queue := self.callback_queues.get(uid): + queue.put_nowait(msg) + elif future := self.futures.get(uid): + future.set_result(msg) + + async def _rx_msgs_main_ws(self, web_socket: ClientWebSocketResponse) -> None: + """Receive messages from main websocket connection.""" + async for raw_msg in web_socket: + _LOGGER.debug("recv(%s): %s", self.host, raw_msg) + if raw_msg.type is not WSMsgType.TEXT: + break + + self._process_text_message(raw_msg.data) + + async def _rx_msgs_input_ws(self, web_socket: ClientWebSocketResponse) -> None: + """Receive messages from input websocket connection. + + We are not expecting any messages from the input connection. + This is just to keep the connection alive. + """ + async for raw_msg in web_socket: + _LOGGER.debug("input recv(%s): %s", self.host, raw_msg) + if raw_msg.type is not WSMsgType.TEXT: + break # manage state @property @@ -480,11 +502,7 @@ def clear_state_update_callbacks(self) -> None: async def do_state_update_callbacks(self) -> None: """Call user state update callback.""" - callbacks = set() - for callback in self.state_update_callbacks: - callbacks.add(callback(self)) - - if callbacks: + if callbacks := {callback(self) for callback in self.state_update_callbacks}: await asyncio.gather(*callbacks) async def set_power_state(self, payload: dict[str, bool | str]) -> None: @@ -643,8 +661,10 @@ async def request( if uid is None: uid = self.command_count self.command_count += 1 - res = self._loop.create_future() - self.futures[uid] = res + res = self._loop.create_future() + self.futures[uid] = res + else: + res = self.futures[uid] try: await self.command(cmd_type, uri, payload, uid) except (asyncio.CancelledError, WebOsTvCommandError): @@ -683,19 +703,46 @@ async def request( return payload + async def create_subscription_handler(self, uid: int, callback: Callable) -> None: + """Create a subscription handler for a given uid. + + Create a queue to store the messages, a task to handle the messages + and a future to signal first subscription update processed. + """ + self.futures[uid] = future = self._loop.create_future() + queue: Queue[dict[str, Any]] = asyncio.Queue() + self.callback_queues[uid] = queue + self.callback_tasks[uid] = asyncio.create_task( + self.callback_handler(queue, callback, future) + ) + + async def delete_subscription_handler(self, uid: int) -> None: + """Delete a subscription handler for a given uid.""" + task = self.callback_tasks.pop(uid) + if not task.done(): + task.cancel() + while not task.done(): + with suppress(asyncio.CancelledError): + await task + del self.callback_queues[uid] + async def subscribe( self, callback: Callable, uri: str, payload: dict[str, Any] | None = None ) -> dict[str, Any]: - """Subscribe to updates.""" + """Subscribe to updates. + + Subsciption use a fixed uid, pre-create a future and a handler. + """ uid = self.command_count self.command_count += 1 - self.callbacks[uid] = callback + await self.create_subscription_handler(uid, callback) + try: return await self.request( uri, payload=payload, cmd_type="subscribe", uid=uid ) except Exception: - del self.callbacks[uid] + await self.delete_subscription_handler(uid) raise async def input_command(self, message: str) -> None: diff --git a/ruff.toml b/ruff.toml index e50ba94..d047628 100644 --- a/ruff.toml +++ b/ruff.toml @@ -3,14 +3,9 @@ target-version = "py311" lint.select = ["ALL"] lint.ignore = [ - "BLE001", # Do not catch blind exception: `Exception` - "C901", # is too complex "COM812", # Trailing comma missing (conflicts with formatter) "D203", # 1 blank line required before class docstring (conflicts with `no-blank-line-before-class` (D211)) "D213", # Multi-line docstring summary should start at the second line (conflicts with multi-line-summary-first-line` (D212)) - "PLR0912", # Too many branches - "PLR0915", # Too many statements - "TRY301", # Abstract `raise` to an inner function ]