Skip to content

Commit

Permalink
Cleanly shut down the serial port on disconnect (#633)
Browse files Browse the repository at this point in the history
* Cleanly handle connection loss

* Guard disconnect

* Clean up exception handling and reduce unnecessary resets

* Rename `application` to `api` in EZSP UART

* Ensure `enter_failed_state` passes through an exception object

* Fix unit tests

* Bump minimum zigpy version

* Fix CLI

* Drop accidental import

* 100% coverage
  • Loading branch information
puddly authored Oct 28, 2024
1 parent 7e1008e commit ecce1ba
Show file tree
Hide file tree
Showing 16 changed files with 295 additions and 371 deletions.
8 changes: 7 additions & 1 deletion bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ def __init__(self, code: t.NcpResetCode) -> None:
def __repr__(self) -> str:
    """Return a debug string of the form ``<ClassName(code=...)>``."""
    class_name = self.__class__.__name__
    return f"<{class_name}(code={self.code})>"

def __eq__(self, other: object) -> bool:
    """Compare two failures by their ``code``.

    Returns ``NotImplemented`` when ``other`` is not an ``NcpFailure`` so that
    Python can try the reflected comparison before falling back to identity.
    """
    # ``NotImplemented`` is a sentinel value, not a type, so it cannot be part
    # of the return annotation; ``-> bool`` is the conventional signature.
    if not isinstance(other, NcpFailure):
        return NotImplemented

    return self.code == other.code

def __hash__(self) -> int:
    """Hash by ``code`` so that instances comparing equal also hash equal."""
    # A class that defines ``__eq__`` without ``__hash__`` gets
    # ``__hash__ = None`` and becomes unhashable; define it explicitly so
    # instances can still be placed in sets or used as dict keys.
    return hash(self.code)


class AshFrame(abc.ABC, BaseDataclassMixin):
MASK: t.uint8_t
Expand Down Expand Up @@ -368,7 +374,7 @@ def connection_made(self, transport):
self._transport = transport
self._ezsp_protocol.connection_made(self)

def connection_lost(self, exc):
def connection_lost(self, exc: Exception | None) -> None:
self._transport = None
self._cancel_pending_data_frames()
self._ezsp_protocol.connection_lost(exc)
Expand Down
2 changes: 1 addition & 1 deletion bellows/cli/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def dump(ctx, channel, outfile):
finally:
if "ezsp" in ctx.obj:
loop.run_until_complete(ctx.obj["ezsp"].mfglibEnd())
ctx.obj["ezsp"].close()
loop.run_until_complete(ctx.obj["ezsp"].disconnect())


def ieee_15_4_fcs(data: bytes) -> bytes:
Expand Down
10 changes: 5 additions & 5 deletions bellows/cli/ncp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def config(ctx, config, all_):
if v[0] == t.EzspStatus.ERROR_INVALID_ID:
continue
click.echo(f"{config.name}={v[1]}")
s.close()
await s.disconnect()
return

if "=" in config:
Expand All @@ -54,7 +54,7 @@ async def config(ctx, config, all_):

v = await s.setConfigurationValue(config, value)
click.echo(v)
s.close()
await s.disconnect()
return

v = await s.getConfigurationValue(config)
Expand Down Expand Up @@ -86,7 +86,7 @@ async def info(ctx):
click.echo(f"Board name: {brd_name}")
click.echo(f"EmberZNet version: {version}")

s.close()
await s.disconnect()


@main.command()
Expand All @@ -105,7 +105,7 @@ async def bootloader(ctx):
version, plat, micro, phy = await ezsp.getStandaloneBootloaderVersionPlatMicroPhy()
if version == 0xFFFF:
click.echo("No boot loader installed")
ezsp.close()
await ezsp.disconnect()
return

click.echo(
Expand All @@ -118,4 +118,4 @@ async def bootloader(ctx):
click.echo(f"Couldn't launch bootloader: {res[0]}")
else:
click.echo("bootloader launched successfully")
ezsp.close()
await ezsp.disconnect()
6 changes: 3 additions & 3 deletions bellows/cli/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def cb(fut, frame_name, response):

s.remove_callback(cbid)

s.close()
await s.disconnect()


@main.command()
Expand All @@ -126,7 +126,7 @@ async def leave(ctx):
expected=t.EmberStatus.NETWORK_DOWN,
)

s.close()
await s.disconnect()


@main.command()
Expand Down Expand Up @@ -157,4 +157,4 @@ async def scan(ctx, channels, duration_ms, energy_scan):
for network in v:
click.echo(network)

s.close()
await s.disconnect()
2 changes: 1 addition & 1 deletion bellows/cli/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def stream(ctx, channel, power):
s = ctx.obj["ezsp"]
loop.run_until_complete(s.mfglibStopStream())
loop.run_until_complete(s.mfglibEnd())
s.close()
loop.run_until_complete(s.disconnect())


async def _stream(ctx, channel, power):
Expand Down
2 changes: 1 addition & 1 deletion bellows/cli/tone.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def tone(ctx, channel, power):
s = ctx.obj["ezsp"]
loop.run_until_complete(s.mfglibStopTone())
loop.run_until_complete(s.mfglibEnd())
s.close()
loop.run_until_complete(s.disconnect())


async def _tone(ctx, channel, power):
Expand Down
25 changes: 7 additions & 18 deletions bellows/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,17 @@ async def async_inner(ctx, *args, **kwargs):
if extra_config:
app_config.update(extra_config)
application = await setup_application(app_config, startup=app_startup)
ctx.obj["app"] = application
await f(ctx, *args, **kwargs)
await asyncio.sleep(0.5)
await application.shutdown()

def shutdown():
with contextlib.suppress(Exception):
application._ezsp.close()
try:
ctx.obj["app"] = application
await f(ctx, *args, **kwargs)
finally:
with contextlib.suppress(Exception):
await application.shutdown()

@functools.wraps(f)
def inner(*args, **kwargs):
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(async_inner(*args, **kwargs))
except: # noqa: E722
# It seems that often errors like a message send will try to send
# two messages, and not reading all of them will leave the NCP in
# a bad state. This seems to mitigate this somewhat. Better way?
loop.run_until_complete(asyncio.sleep(0.5))
raise
finally:
shutdown()
loop.run_until_complete(async_inner(*args, **kwargs))

return inner

Expand Down
52 changes: 18 additions & 34 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from typing import Any, Callable, Generator
import urllib.parse

from bellows.ash import NcpFailure

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout # pragma: no cover
else:
Expand Down Expand Up @@ -55,13 +57,14 @@ class EZSP:
v14.EZSPv14.VERSION: v14.EZSPv14,
}

def __init__(self, device_config: dict):
def __init__(self, device_config: dict, application: Any | None = None):
self._config = device_config
self._callbacks = {}
self._ezsp_event = asyncio.Event()
self._ezsp_version = v4.EZSPv4.VERSION
self._gw = None
self._protocol = None
self._application = application

self._stack_status_listeners: collections.defaultdict[
t.sl_Status, list[asyncio.Future]
Expand Down Expand Up @@ -122,25 +125,17 @@ async def startup_reset(self) -> None:

await self.version()

@classmethod
async def initialize(cls, zigpy_config: dict) -> EZSP:
"""Return initialized EZSP instance."""
ezsp = cls(zigpy_config[conf.CONF_DEVICE])
await ezsp.connect(use_thread=zigpy_config[conf.CONF_USE_THREAD])
async def connect(self, *, use_thread: bool = True) -> None:
assert self._gw is None
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)

try:
await ezsp.startup_reset()
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)
await self.startup_reset()
except Exception:
ezsp.close()
await self.disconnect()
raise

return ezsp

async def connect(self, *, use_thread: bool = True) -> None:
assert self._gw is None
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)

async def reset(self):
LOGGER.debug("Resetting EZSP")
self.stop_ezsp()
Expand Down Expand Up @@ -179,10 +174,10 @@ async def version(self):
ver,
)

def close(self):
async def disconnect(self):
self.stop_ezsp()
if self._gw:
self._gw.close()
await self._gw.disconnect()
self._gw = None

async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -264,23 +259,12 @@ async def leaveNetwork(self, timeout: float | int = NETWORK_OPS_TIMEOUT) -> None

def connection_lost(self, exc):
"""Lost serial connection."""
LOGGER.debug(
"%s connection lost unexpectedly: %s",
self._config[conf.CONF_DEVICE_PATH],
exc,
)
self.enter_failed_state(f"Serial connection loss: {exc!r}")

def enter_failed_state(self, error):
"""UART received error frame."""
if len(self._callbacks) > 1:
LOGGER.error("NCP entered failed state. Requesting APP controller restart")
self.close()
self.handle_callback("_reset_controller_application", (error,))
else:
LOGGER.info(
"NCP entered failed state. No application handler registered, ignoring..."
)
if self._application is not None:
self._application.connection_lost(exc)

def enter_failed_state(self, code: t.NcpResetCode) -> None:
    """Handle a reset code received by the UART.

    The code is wrapped in an ``NcpFailure`` and passed to
    ``connection_lost``.
    """
    failure = NcpFailure(code=code)
    self.connection_lost(failure)

def __getattr__(self, name: str) -> Callable:
if name not in self._protocol.COMMANDS:
Expand Down
56 changes: 19 additions & 37 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,29 @@
RESET_TIMEOUT = 5


class Gateway(asyncio.Protocol):
def __init__(self, application, connected_future=None, connection_done_future=None):
self._application = application
class Gateway(zigpy.serial.SerialProtocol):
def __init__(self, api, connection_done_future=None):
super().__init__()
self._api = api

self._reset_future = None
self._startup_reset_future = None
self._connected_future = connected_future
self._connection_done_future = connection_done_future

self._transport = None

def close(self):
self._transport.close()

def connection_made(self, transport):
"""Callback when the uart is connected"""
self._transport = transport
if self._connected_future is not None:
self._connected_future.set_result(True)

async def send_data(self, data: bytes) -> None:
await self._transport.send_data(data)

def data_received(self, data):
"""Callback when there is data received from the uart"""
self._application.frame_received(data)

# We intentionally do not call `SerialProtocol.data_received`
self._api.frame_received(data)

def reset_received(self, code: t.NcpResetCode) -> None:
"""Reset acknowledgement frame receive handler"""
# not a reset we've requested. Signal application reset
# not a reset we've requested. Signal api reset
if code is not t.NcpResetCode.RESET_SOFTWARE:
self._application.enter_failed_state(code)
self._api.enter_failed_state(code)
return

if self._reset_future and not self._reset_future.done():
Expand All @@ -61,7 +52,7 @@ def reset_received(self, code: t.NcpResetCode) -> None:

def error_received(self, code: t.NcpResetCode) -> None:
"""Error frame receive handler."""
self._application.enter_failed_state(code)
self._api.enter_failed_state(code)

async def wait_for_startup_reset(self) -> None:
"""Wait for the first reset frame on startup."""
Expand All @@ -77,12 +68,9 @@ def _reset_cleanup(self, future):
"""Delete reset future."""
self._reset_future = None

def eof_received(self):
"""Server gracefully closed its side of the connection."""
self.connection_lost(ConnectionResetError("Remote server closed connection"))

def connection_lost(self, exc):
"""Port was closed unexpectedly."""
super().connection_lost(exc)

LOGGER.debug("Connection lost: %r", exc)
reason = exc or ConnectionResetError("Remote server closed connection")
Expand All @@ -102,12 +90,7 @@ def connection_lost(self, exc):
self._reset_future.set_exception(reason)
self._reset_future = None

if exc is None:
LOGGER.debug("Closed serial connection")
return

LOGGER.error("Lost serial connection: %r", exc)
self._application.connection_lost(exc)
self._api.connection_lost(exc)

async def reset(self):
"""Send a reset frame and init internal state."""
Expand All @@ -126,13 +109,12 @@ async def reset(self):
return await self._reset_future


async def _connect(config, application):
async def _connect(config, api):
loop = asyncio.get_event_loop()

connection_future = loop.create_future()
connection_done_future = loop.create_future()

gateway = Gateway(application, connection_future, connection_done_future)
gateway = Gateway(api, connection_done_future)
protocol = AshProtocol(gateway)

if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None:
Expand All @@ -149,25 +131,25 @@ async def _connect(config, application):
rtscts=rtscts,
)

await connection_future
await gateway.wait_until_connected()

thread_safe_protocol = ThreadsafeProxy(gateway, loop)
return thread_safe_protocol, connection_done_future


async def connect(config, application, use_thread=True):
async def connect(config, api, use_thread=True):
if use_thread:
application = ThreadsafeProxy(application, asyncio.get_event_loop())
api = ThreadsafeProxy(api, asyncio.get_event_loop())
thread = EventLoopThread()
await thread.start()
try:
protocol, connection_done = await thread.run_coroutine_threadsafe(
_connect(config, application)
_connect(config, api)
)
except Exception:
thread.force_stop()
raise
connection_done.add_done_callback(lambda _: thread.force_stop())
else:
protocol, _ = await _connect(config, application)
protocol, _ = await _connect(config, api)
return protocol
Loading

0 comments on commit ecce1ba

Please sign in to comment.