Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move message processing to its own function #417

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 78 additions & 61 deletions aiowebostv/webos_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(
self.connect_result: Future[bool] | None = None
self.connection: ClientWebSocketResponse | None = None
self.input_connection: ClientWebSocketResponse | None = None
self.callbacks: dict[int, Callable] = {}
self.futures: dict[int, Future[dict[str, Any]]] = {}
self._power_state: dict[str, Any] = {}
self._current_app_id: str | None = None
Expand All @@ -77,6 +76,8 @@ def __init__(
self._volume_step_delay: timedelta | None = None
self._loop = asyncio.get_running_loop()
self._media_state: list[dict[str, Any]] = []
self.callback_queues: dict[int, Queue[dict[str, Any]]] = {}
self.callback_tasks: dict[int, Task] = {}

async def connect(self) -> bool:
"""Connect to webOS TV device."""
Expand Down Expand Up @@ -193,14 +194,11 @@ async def connect_handler(self, res: Future) -> None:
error = "Client key not set, pairing failed."
raise WebOsTvPairError(error)

self.callbacks = {}
self.callback_queues = {}
self.callback_tasks = {}
self.futures = {}

handler_tasks.add(
asyncio.create_task(
self.consumer_handler(main_ws, self.callbacks, self.futures)
)
)
handler_tasks.add(asyncio.create_task(self.consumer_handler(main_ws)))
self.connection = main_ws

# open additional connection needed to send button commands
Expand All @@ -211,7 +209,7 @@ async def connect_handler(self, res: Future) -> None:
inputsockpath = sockres["socketPath"]
input_ws = await self._ws_connect(inputsockpath)
handler_tasks.add(
asyncio.create_task(self.consumer_handler(input_ws, None, None))
asyncio.create_task(self.input_consumer_handler(input_ws))
)
self.input_connection = input_ws

Expand Down Expand Up @@ -258,6 +256,10 @@ async def connect_handler(self, res: Future) -> None:
if not res.done():
res.set_exception(ex)
finally:
for callback_task in self.callback_tasks.values():
if not callback_task.done():
callback_task.cancel()

for task in handler_tasks:
if not task.done():
task.cancel()
Expand All @@ -266,6 +268,11 @@ async def connect_handler(self, res: Future) -> None:
future.cancel()

closeout = set()

callback_tasks = set(self.callback_tasks.values())
if callback_tasks:
closeout.update(callback_tasks)

closeout.update(handler_tasks)

if main_ws is not None:
Expand Down Expand Up @@ -309,62 +316,47 @@ async def connect_handler(self, res: Future) -> None:
async def callback_handler(
queue: Queue[dict[str, Any]],
callback: Callable,
future: Future[dict[str, Any]] | None,
future: Future[dict[str, Any]],
) -> None:
"""Handle callbacks."""
with suppress(asyncio.CancelledError):
while True:
msg = await queue.get()
payload = msg.get("payload")
await callback(payload)
if future is not None and not future.done():
if not future.done():
future.set_result(msg)

async def consumer_handler(
self,
web_socket: ClientWebSocketResponse,
callbacks: dict[int, Callable] | None,
futures: dict[int, Future] | None,
) -> None:
def _process_text_message(self, data: str) -> None:
"""Process text message."""
msg = json.loads(data)
uid = msg.get("id")
# if we have a callback for this message, put it in the queue
# let the callback handle the message and mark the future as done
if queue := self.callback_queues.get(uid):
queue.put_nowait(msg)
elif future := self.futures.get(uid):
future.set_result(msg)

async def consumer_handler(self, web_socket: ClientWebSocketResponse) -> None:
"""Callbacks consumer handler."""
callback_queues: dict[int, Queue[dict[str, Any]]] = {}
callback_tasks: dict[int, Task] = {}
async for raw_msg in web_socket:
_LOGGER.debug("recv(%s): %s", self.host, raw_msg)
if raw_msg.type is not WSMsgType.TEXT:
break

try:
async for raw_msg in web_socket:
_LOGGER.debug("recv(%s): %s", self.host, raw_msg)
if raw_msg.type is not WSMsgType.TEXT:
break

if callbacks or futures:
msg = json.loads(raw_msg.data)
uid = msg.get("id")
callback = self.callbacks.get(uid)
future = self.futures.get(uid)
if callback is not None:
if uid not in callback_tasks:
queue: Queue[dict[str, Any]] = asyncio.Queue()
callback_queues[uid] = queue
callback_tasks[uid] = asyncio.create_task(
self.callback_handler(queue, callback, future)
)
callback_queues[uid].put_nowait(msg)
elif future is not None and not future.done():
self.futures[uid].set_result(msg)
self._process_text_message(raw_msg.data)

finally:
for task in callback_tasks.values():
if not task.done():
task.cancel()
async def input_consumer_handler(self, web_socket: ClientWebSocketResponse) -> None:
"""Input consumer handler.

tasks = set(callback_tasks.values())

if tasks:
closeout_task = asyncio.create_task(asyncio.wait(tasks))

while not closeout_task.done():
with suppress(asyncio.CancelledError):
await asyncio.shield(closeout_task)
We are not expecting any messages from the input connection.
This is just to keep the connection alive.
"""
async for raw_msg in web_socket:
_LOGGER.debug("input recv(%s): %s", self.host, raw_msg)
if raw_msg.type is not WSMsgType.TEXT:
break

# manage state
@property
Expand Down Expand Up @@ -480,11 +472,7 @@ def clear_state_update_callbacks(self) -> None:

async def do_state_update_callbacks(self) -> None:
"""Call user state update callback."""
callbacks = set()
for callback in self.state_update_callbacks:
callbacks.add(callback(self))

if callbacks:
if callbacks := {callback(self) for callback in self.state_update_callbacks}:
await asyncio.gather(*callbacks)

async def set_power_state(self, payload: dict[str, bool | str]) -> None:
Expand Down Expand Up @@ -643,8 +631,10 @@ async def request(
if uid is None:
uid = self.command_count
self.command_count += 1
res = self._loop.create_future()
self.futures[uid] = res
res = self._loop.create_future()
self.futures[uid] = res
else:
res = self.futures[uid]
try:
await self.command(cmd_type, uri, payload, uid)
except (asyncio.CancelledError, WebOsTvCommandError):
Expand Down Expand Up @@ -683,19 +673,46 @@ async def request(

return payload

async def create_subscription_handler(self, uid: int, callback: Callable) -> None:
"""Create a subscription handler for a given uid.

Create a queue to store the messages, a task to handle the messages
and a future to signal first subscription update processed.
"""
self.futures[uid] = future = self._loop.create_future()
queue: Queue[dict[str, Any]] = asyncio.Queue()
self.callback_queues[uid] = queue
self.callback_tasks[uid] = asyncio.create_task(
self.callback_handler(queue, callback, future)
)

async def delete_subscription_handler(self, uid: int) -> None:
"""Delete a subscription handler for a given uid."""
task = self.callback_tasks.pop(uid)
if not task.done():
task.cancel()
while not task.done():
with suppress(asyncio.CancelledError):
await asyncio.shield(task)
del self.callback_queues[uid]

async def subscribe(
self, callback: Callable, uri: str, payload: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Subscribe to updates."""
"""Subscribe to updates.

Subsciption use a fixed uid, pre-create a future and a handler.
"""
uid = self.command_count
self.command_count += 1
self.callbacks[uid] = callback
await self.create_subscription_handler(uid, callback)

try:
return await self.request(
uri, payload=payload, cmd_type="subscribe", uid=uid
)
except Exception:
del self.callbacks[uid]
await self.delete_subscription_handler(uid)
raise

async def input_command(self, message: str) -> None:
Expand Down
Loading