diff --git a/bellows/ash.py b/bellows/ash.py index 1349e36d..ea6d6a0e 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -130,6 +130,12 @@ def __init__(self, code: t.NcpResetCode) -> None: def __repr__(self) -> str: return f"<{self.__class__.__name__}(code={self.code})>" + def __eq__(self, other: object) -> bool | NotImplemented: + if not isinstance(other, NcpFailure): + return NotImplemented + + return self.code == other.code + class AshFrame(abc.ABC, BaseDataclassMixin): MASK: t.uint8_t @@ -368,7 +374,7 @@ def connection_made(self, transport): self._transport = transport self._ezsp_protocol.connection_made(self) - def connection_lost(self, exc): + def connection_lost(self, exc: Exception | None) -> None: self._transport = None self._cancel_pending_data_frames() self._ezsp_protocol.connection_lost(exc) diff --git a/bellows/cli/dump.py b/bellows/cli/dump.py index e0d4acbb..17c598fc 100644 --- a/bellows/cli/dump.py +++ b/bellows/cli/dump.py @@ -37,7 +37,7 @@ def dump(ctx, channel, outfile): finally: if "ezsp" in ctx.obj: loop.run_until_complete(ctx.obj["ezsp"].mfglibEnd()) - ctx.obj["ezsp"].close() + loop.run_until_complete(ctx.obj["ezsp"].disconnect()) def ieee_15_4_fcs(data: bytes) -> bytes: diff --git a/bellows/cli/ncp.py b/bellows/cli/ncp.py index 6141d219..bbe0471a 100644 --- a/bellows/cli/ncp.py +++ b/bellows/cli/ncp.py @@ -30,7 +30,7 @@ async def config(ctx, config, all_): if v[0] == t.EzspStatus.ERROR_INVALID_ID: continue click.echo(f"{config.name}={v[1]}") - s.close() + await s.disconnect() return if "=" in config: @@ -54,7 +54,7 @@ async def config(ctx, config, all_): v = await s.setConfigurationValue(config, value) click.echo(v) - s.close() + await s.disconnect() return v = await s.getConfigurationValue(config) @@ -86,7 +86,7 @@ async def info(ctx): click.echo(f"Board name: {brd_name}") click.echo(f"EmberZNet version: {version}") - s.close() + await s.disconnect() @main.command() @@ -105,7 +105,7 @@ async def bootloader(ctx): version, plat, micro, phy = await ezsp.getStandaloneBootloaderVersionPlatMicroPhy() if version == 0xFFFF: click.echo("No boot loader installed") - ezsp.close() + await ezsp.disconnect() return click.echo( @@ -118,4 +118,4 @@ async def bootloader(ctx): click.echo(f"Couldn't launch bootloader: {res[0]}") else: click.echo("bootloader launched successfully") - ezsp.close() + await ezsp.disconnect() diff --git a/bellows/cli/network.py b/bellows/cli/network.py index 51266c74..84795de8 100644 --- a/bellows/cli/network.py +++ b/bellows/cli/network.py @@ -106,7 +106,7 @@ def cb(fut, frame_name, response): s.remove_callback(cbid) - s.close() + await s.disconnect() @main.command() @@ -126,7 +126,7 @@ async def leave(ctx): expected=t.EmberStatus.NETWORK_DOWN, ) - s.close() + await s.disconnect() @main.command() @@ -157,4 +157,4 @@ async def scan(ctx, channels, duration_ms, energy_scan): for network in v: click.echo(network) - s.close() + await s.disconnect() diff --git a/bellows/cli/stream.py b/bellows/cli/stream.py index fea70959..eb77dd00 100644 --- a/bellows/cli/stream.py +++ b/bellows/cli/stream.py @@ -35,7 +35,7 @@ def stream(ctx, channel, power): s = ctx.obj["ezsp"] loop.run_until_complete(s.mfglibStopStream()) loop.run_until_complete(s.mfglibEnd()) - s.close() + loop.run_until_complete(s.disconnect()) async def _stream(ctx, channel, power): diff --git a/bellows/cli/tone.py b/bellows/cli/tone.py index e31e3a0c..50a02e1e 100644 --- a/bellows/cli/tone.py +++ b/bellows/cli/tone.py @@ -35,7 +35,7 @@ def tone(ctx, channel, power): s = ctx.obj["ezsp"] loop.run_until_complete(s.mfglibStopTone()) loop.run_until_complete(s.mfglibEnd()) - s.close() + loop.run_until_complete(s.disconnect()) async def _tone(ctx, channel, power): diff --git a/bellows/cli/util.py b/bellows/cli/util.py index c8198147..76f83511 100644 --- a/bellows/cli/util.py +++ b/bellows/cli/util.py @@ -59,28 +59,17 @@ async def async_inner(ctx, *args, **kwargs): if extra_config: app_config.update(extra_config) application = await setup_application(app_config, startup=app_startup) - ctx.obj["app"] = application - await f(ctx, *args, **kwargs) - await asyncio.sleep(0.5) - await application.shutdown() - - def shutdown(): - with contextlib.suppress(Exception): - application._ezsp.close() + try: + ctx.obj["app"] = application + await f(ctx, *args, **kwargs) + finally: + with contextlib.suppress(Exception): + await application.shutdown() @functools.wraps(f) def inner(*args, **kwargs): loop = asyncio.get_event_loop() - try: - loop.run_until_complete(async_inner(*args, **kwargs)) - except: # noqa: E722 - # It seems that often errors like a message send will try to send - # two messages, and not reading all of them will leave the NCP in - # a bad state. This seems to mitigate this somewhat. Better way? - loop.run_until_complete(asyncio.sleep(0.5)) - raise - finally: - shutdown() + loop.run_until_complete(async_inner(*args, **kwargs)) return inner diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 1ff511d3..669fd0f2 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -12,6 +12,8 @@ from typing import Any, Callable, Generator import urllib.parse +from bellows.ash import NcpFailure + if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout # pragma: no cover else: @@ -55,13 +57,14 @@ class EZSP: v14.EZSPv14.VERSION: v14.EZSPv14, } - def __init__(self, device_config: dict): + def __init__(self, device_config: dict, application: Any | None = None): self._config = device_config self._callbacks = {} self._ezsp_event = asyncio.Event() self._ezsp_version = v4.EZSPv4.VERSION self._gw = None self._protocol = None + self._application = application self._stack_status_listeners: collections.defaultdict[ t.sl_Status, list[asyncio.Future] @@ -122,25 +125,17 @@ async def startup_reset(self) -> None: await self.version() - @classmethod - async def initialize(cls, zigpy_config: dict) -> EZSP: - """Return initialized EZSP instance.""" - ezsp = cls(zigpy_config[conf.CONF_DEVICE]) - await ezsp.connect(use_thread=zigpy_config[conf.CONF_USE_THREAD]) + async def connect(self, *, use_thread: bool = True) -> None: + assert self._gw is None + self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread) try: - await ezsp.startup_reset() + self._protocol = v4.EZSPv4(self.handle_callback, self._gw) + await self.startup_reset() except Exception: - ezsp.close() + await self.disconnect() raise - return ezsp - - async def connect(self, *, use_thread: bool = True) -> None: - assert self._gw is None - self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread) - self._protocol = v4.EZSPv4(self.handle_callback, self._gw) - async def reset(self): LOGGER.debug("Resetting EZSP") self.stop_ezsp() @@ -179,10 +174,10 @@ async def version(self): ver, ) - def close(self): + async def disconnect(self): self.stop_ezsp() if self._gw: - self._gw.close() + await self._gw.disconnect() self._gw = None async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any: @@ -264,23 +259,12 @@ async def leaveNetwork(self, timeout: float | int = NETWORK_OPS_TIMEOUT) -> None def connection_lost(self, exc): """Lost serial connection.""" - LOGGER.debug( - "%s connection lost unexpectedly: %s", - self._config[conf.CONF_DEVICE_PATH], - exc, - ) - self.enter_failed_state(f"Serial connection loss: {exc!r}") - - def enter_failed_state(self, error): - """UART received error frame.""" - if len(self._callbacks) > 1: - LOGGER.error("NCP entered failed state. Requesting APP controller restart") - self.close() - self.handle_callback("_reset_controller_application", (error,)) - else: - LOGGER.info( - "NCP entered failed state. No application handler registered, ignoring..." - ) + if self._application is not None: + self._application.connection_lost(exc) + + def enter_failed_state(self, code: t.NcpResetCode) -> None: + """UART received reset code.""" + self.connection_lost(NcpFailure(code=code)) def __getattr__(self, name: str) -> Callable: if name not in self._protocol.COMMANDS: diff --git a/bellows/uart.py b/bellows/uart.py index d48838d5..e2dd3095 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -18,38 +18,29 @@ RESET_TIMEOUT = 5 -class Gateway(asyncio.Protocol): - def __init__(self, application, connected_future=None, connection_done_future=None): - self._application = application +class Gateway(zigpy.serial.SerialProtocol): + def __init__(self, api, connection_done_future=None): + super().__init__() + self._api = api self._reset_future = None self._startup_reset_future = None - self._connected_future = connected_future self._connection_done_future = connection_done_future - self._transport = None - - def close(self): - self._transport.close() - - def connection_made(self, transport): - """Callback when the uart is connected""" - self._transport = transport - if self._connected_future is not None: - self._connected_future.set_result(True) - async def send_data(self, data: bytes) -> None: await self._transport.send_data(data) def data_received(self, data): """Callback when there is data received from the uart""" - self._application.frame_received(data) + + # We intentionally do not call `SerialProtocol.data_received` + self._api.frame_received(data) def reset_received(self, code: t.NcpResetCode) -> None: """Reset acknowledgement frame receive handler""" - # not a reset we've requested. Signal application reset + # not a reset we've requested. Signal api reset if code is not t.NcpResetCode.RESET_SOFTWARE: - self._application.enter_failed_state(code) + self._api.enter_failed_state(code) return if self._reset_future and not self._reset_future.done(): @@ -61,7 +52,7 @@ def reset_received(self, code: t.NcpResetCode) -> None: def error_received(self, code: t.NcpResetCode) -> None: """Error frame receive handler.""" - self._application.enter_failed_state(code) + self._api.enter_failed_state(code) async def wait_for_startup_reset(self) -> None: """Wait for the first reset frame on startup.""" @@ -77,12 +68,9 @@ def _reset_cleanup(self, future): """Delete reset future.""" self._reset_future = None - def eof_received(self): - """Server gracefully closed its side of the connection.""" - self.connection_lost(ConnectionResetError("Remote server closed connection")) - def connection_lost(self, exc): """Port was closed unexpectedly.""" + super().connection_lost(exc) LOGGER.debug("Connection lost: %r", exc) reason = exc or ConnectionResetError("Remote server closed connection") @@ -102,12 +90,7 @@ def connection_lost(self, exc): self._reset_future.set_exception(reason) self._reset_future = None - if exc is None: - LOGGER.debug("Closed serial connection") - return - - LOGGER.error("Lost serial connection: %r", exc) - self._application.connection_lost(exc) + self._api.connection_lost(exc) async def reset(self): """Send a reset frame and init internal state.""" @@ -126,13 +109,12 @@ async def reset(self): return await self._reset_future -async def _connect(config, application): +async def _connect(config, api): loop = asyncio.get_event_loop() - connection_future = loop.create_future() connection_done_future = loop.create_future() - gateway = Gateway(application, connection_future, connection_done_future) + gateway = Gateway(api, connection_done_future) protocol = AshProtocol(gateway) if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None: @@ -149,25 +131,25 @@ async def _connect(config, application): rtscts=rtscts, ) - await connection_future + await gateway.wait_until_connected() thread_safe_protocol = ThreadsafeProxy(gateway, loop) return thread_safe_protocol, connection_done_future -async def connect(config, application, use_thread=True): +async def connect(config, api, use_thread=True): if use_thread: - application = ThreadsafeProxy(application, asyncio.get_event_loop()) + api = ThreadsafeProxy(api, asyncio.get_event_loop()) thread = EventLoopThread() await thread.start() try: protocol, connection_done = await thread.run_coroutine_threadsafe( - _connect(config, application) + _connect(config, api) ) except Exception: thread.force_stop() raise connection_done.add_done_callback(lambda _: thread.force_stop()) else: - protocol, _ = await _connect(config, application) + protocol, _ = await _connect(config, api) return protocol diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 0680184e..229e1499 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -141,23 +141,22 @@ async def _get_board_info(self) -> tuple[str, str, str] | tuple[None, None, None return None, None, None async def connect(self) -> None: - ezsp = bellows.ezsp.EZSP(self.config[zigpy.config.CONF_DEVICE]) - await ezsp.connect(use_thread=self.config[CONF_USE_THREAD]) + self._ezsp = bellows.ezsp.EZSP(self.config[zigpy.config.CONF_DEVICE], self) try: - await ezsp.startup_reset() + await self._ezsp.connect(use_thread=self.config[CONF_USE_THREAD]) # Writing config is required here because network info can't be loaded - await ezsp.write_config(self.config[CONF_EZSP_CONFIG]) + await self._ezsp.write_config(self.config[CONF_EZSP_CONFIG]) + + self._created_device_endpoints.clear() + await self.register_endpoints() except Exception: - ezsp.close() + if self._ezsp is not None: + await self._ezsp.disconnect() + self._ezsp = None raise - self._ezsp = ezsp - - self._created_device_endpoints.clear() - await self.register_endpoints() - async def _ensure_network_running(self) -> bool: """Ensures the network is currently running and returns whether or not the network was started. @@ -436,7 +435,7 @@ async def disconnect(self): # TODO: how do you shut down the stack? self.controller_event.clear() if self._ezsp is not None: - self._ezsp.close() + await self._ezsp.disconnect() self._ezsp = None async def force_remove(self, dev): @@ -519,8 +518,6 @@ def ezsp_callback_handler(self, frame_name, args): status, nwk = args status = t.sl_Status.from_ember_status(status) self.handle_route_error(status, nwk) - elif frame_name == "_reset_controller_application": - self.connection_lost(args[0]) elif frame_name == "idConflictHandler": self._handle_id_conflict(*args) diff --git a/pyproject.toml b/pyproject.toml index b76bf157..3a9c9b1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "click-log>=0.2.1", "pure_pcapy3==1.0.1", "voluptuous", - "zigpy>=0.68.0", + "zigpy>=0.70.0", 'async-timeout; python_version<"3.11"', ] diff --git a/tests/test_application.py b/tests/test_application.py index e35adfc5..8581e52f 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -12,6 +12,7 @@ import zigpy.types as zigpy_t import zigpy.zdo.types as zdo_t +from bellows.ash import NcpFailure import bellows.config as config from bellows.exception import ControllerError, EzspError import bellows.ezsp as ezsp @@ -112,7 +113,7 @@ def _create_app_for_startup( ezsp_mock.start_ezsp() ezsp_mock.connect = AsyncMock() - ezsp_mock.close = AsyncMock(wraps=ezsp_mock.close) + ezsp_mock.disconnect = AsyncMock() ezsp_mock.startup_reset = AsyncMock() ezsp_mock.can_burn_userdata_custom_eui64 = AsyncMock(return_value=True) ezsp_mock.can_rewrite_custom_eui64 = AsyncMock(return_value=True) @@ -1122,12 +1123,6 @@ def test_is_controller_running(app): assert ezsp_running.call_count == 1 -def test_reset_frame(app): - app.connection_lost = MagicMock(spec_set=app.connection_lost) - app.ezsp_callback_handler("_reset_controller_application", (sentinel.error,)) - assert app.connection_lost.mock_calls == [call(sentinel.error)] - - @pytest.mark.parametrize("ezsp_version", (4, 7)) async def test_watchdog(make_app, monkeypatch, ezsp_version): from bellows.zigbee import application @@ -1287,9 +1282,9 @@ async def counters_mock(): async def test_shutdown(app): ezsp = app._ezsp - await app.shutdown() + await app.disconnect() assert app.controller_event.is_set() is False - assert ezsp.close.call_count == 1 + assert len(ezsp.disconnect.mock_calls) == 1 @pytest.fixture @@ -1731,7 +1726,7 @@ async def test_connect_failure(app: ControllerApplication) -> None: assert app._ezsp is None - assert len(ezsp.close.mock_calls) == 1 + assert len(ezsp.disconnect.mock_calls) == 1 async def test_repair_tclk_partner_ieee( diff --git a/tests/test_ash.py b/tests/test_ash.py index cb7c356a..3b71bef1 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -600,3 +600,12 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: with patch.object(ncp, "nak_state", True): with pytest.raises(ash.NotAcked): await host.send_data(b"ncp NAKing until failure") + + +def test_ncp_failure_comparison() -> None: + exc1 = ash.NcpFailure(code=t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT) + exc2 = ash.NcpFailure(code=t.NcpResetCode.RESET_POWER_ON) + + assert exc1 == exc1 + assert exc1 != exc2 + assert exc2 != t.NcpResetCode.RESET_POWER_ON diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index ee20af84..28fbf2c2 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -8,8 +8,10 @@ import pytest import zigpy.config -from bellows import config, ezsp, uart +from bellows import config, uart +from bellows.ash import NcpFailure from bellows.exception import EzspError, InvalidCommandError +from bellows.ezsp import EZSP, EZSP_LATEST import bellows.types as t if sys.version_info[:2] < (3, 11): @@ -17,7 +19,7 @@ else: from asyncio import timeout as asyncio_timeout # pragma: no cover -from unittest.mock import ANY, AsyncMock, MagicMock, call, patch, sentinel +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch from bellows.ezsp.v9.commands import GetTokenDataRsp @@ -27,49 +29,46 @@ } -@pytest.fixture -async def ezsp_f(): - api = ezsp.EZSP(DEVICE_CONFIG) - gw = MagicMock(spec_set=uart.Gateway) - with patch("bellows.uart.connect", new=AsyncMock(return_value=gw)): - await api.connect() - yield api - +def make_ezsp(config: dict = DEVICE_CONFIG, version: int = 4): + api = EZSP(config) -async def make_ezsp(version=4) -> ezsp.EZSP: - api = ezsp.EZSP(DEVICE_CONFIG) - gw = MagicMock(spec_set=uart.Gateway) + async def mock_command(command, *args, **kwargs): + if command in api._mock_commands: + return await api._mock_commands[command](*args, **kwargs) - with patch("bellows.uart.connect", new=AsyncMock(return_value=gw)): - await api.connect() + raise RuntimeError(f"Command has not been mocked: {command}({args}, {kwargs})") - assert api._ezsp_version == 4 + api._mock_commands = {} + api._mock_commands["version"] = AsyncMock(return_value=[version, 0, 0]) + api._command = AsyncMock(side_effect=mock_command) - with patch.object(api, "_command", new=AsyncMock(return_value=[version, 0, 0])): - await api.version() + return api - assert api._ezsp_version == version - return api +async def make_connected_ezsp(config: dict = DEVICE_CONFIG, version: int = 4): + with patch("bellows.uart.connect"): + ezsp = make_ezsp(config=config, version=version) + await ezsp.connect() + return ezsp -async def test_connect(ezsp_f, monkeypatch): - connected = False - async def mockconnect(*args, **kwargs): - nonlocal connected - connected = True +@pytest.fixture +async def ezsp_f() -> EZSP: + with patch("bellows.uart.connect"): + ezsp = make_ezsp(version=12) - monkeypatch.setattr(uart, "connect", mockconnect) - ezsp_f._gw = None + assert ezsp._ezsp_version == 4 + await ezsp.connect() + assert ezsp._ezsp_version == 12 - await ezsp_f.connect() - assert connected + yield ezsp async def test_reset(ezsp_f): ezsp_f.stop_ezsp = MagicMock() ezsp_f.start_ezsp = MagicMock() + reset_mock = AsyncMock() ezsp_f._gw.reset = MagicMock(side_effect=reset_mock) @@ -80,17 +79,10 @@ async def test_reset(ezsp_f): assert len(ezsp_f._callbacks) == 1 -def test_close(ezsp_f): - closed = False - - def close_mock(*args): - nonlocal closed - closed = True - - ezsp_f._gw.close = close_mock - ezsp_f.close() - assert closed is True - assert ezsp_f._gw is None +async def test_disconnect(ezsp_f): + gw_disconnect = ezsp_f._gw.disconnect + await ezsp_f.disconnect() + assert len(gw_disconnect.mock_calls) == 1 def test_attr(ezsp_f): @@ -105,91 +97,123 @@ async def test_non_existent_attr(ezsp_f): async def test_command(ezsp_f): - ezsp_f.start_ezsp() + # Un-mock it + ezsp_f._command = EZSP._command.__get__(ezsp_f, EZSP) + with patch.object(ezsp_f._protocol, "command") as cmd_mock: await ezsp_f.nop() assert cmd_mock.call_count == 1 async def test_command_ezsp_stopped(ezsp_f): + # Un-mock it + ezsp_f._command = EZSP._command.__get__(ezsp_f, EZSP) ezsp_f.stop_ezsp() with pytest.raises(EzspError): await ezsp_f._command("version") -async def _test_list_command(ezsp_f, mockcommand): - ezsp_f._command = mockcommand - return await ezsp_f._list_command( - "startScan", ["networkFoundHandler"], "scanCompleteHandler", 1 - ) - +async def test_list_command(): + ezsp = await make_connected_ezsp(version=4) -async def test_list_command(ezsp_f): async def mockcommand(name, *args, **kwargs): assert name == "startScan" - ezsp_f.frame_received(b"\x01\x00\x1b" + b"\x00" * 20) - ezsp_f.frame_received(b"\x02\x00\x1b" + b"\x00" * 20) - ezsp_f.frame_received(b"\x03\x00\x1c" + b"\x00" * 20) + ezsp.frame_received(b"\x01\x00\x1b" + b"\x00" * 20) + ezsp.frame_received(b"\x02\x00\x1b" + b"\x00" * 20) + ezsp.frame_received(b"\x03\x00\x1c" + b"\x00" * 20) return [t.EmberStatus.SUCCESS] - result = await _test_list_command(ezsp_f, mockcommand) + ezsp._command = mockcommand + + result = await ezsp._list_command( + "startScan", + ["networkFoundHandler"], + "scanCompleteHandler", + 1, + ) assert len(result) == 2 -async def test_list_command_initial_failure(ezsp_f): +async def test_list_command_initial_failure(): + ezsp = await make_connected_ezsp(version=4) + async def mockcommand(name, *args, **kwargs): assert name == "startScan" return [t.EmberStatus.FAILURE] + ezsp._command = mockcommand + with pytest.raises(Exception): - await _test_list_command(ezsp_f, mockcommand) + await ezsp._list_command( + "startScan", + ["networkFoundHandler"], + "scanCompleteHandler", + 1, + ) -async def test_list_command_later_failure(ezsp_f): +async def test_list_command_later_failure(): + ezsp = await make_connected_ezsp(version=4) + async def mockcommand(name, *args, **kwargs): assert name == "startScan" - ezsp_f.frame_received(b"\x01\x00\x1b" + b"\x00" * 20) - ezsp_f.frame_received(b"\x02\x00\x1b" + b"\x00" * 20) - ezsp_f.frame_received(b"\x03\x00\x1c\x01\x01") + ezsp.frame_received(b"\x01\x00\x1b" + b"\x00" * 20) + ezsp.frame_received(b"\x02\x00\x1b" + b"\x00" * 20) + ezsp.frame_received(b"\x03\x00\x1c\x01\x01") return [t.EmberStatus.SUCCESS] + ezsp._command = mockcommand + with pytest.raises(Exception): - await _test_list_command(ezsp_f, mockcommand) + await ezsp._list_command( + "startScan", + ["networkFoundHandler"], + "scanCompleteHandler", + 1, + ) -async def _test_form_network(ezsp_f, initial_result, final_result): +async def _test_form_network(ezsp, initial_result, final_result): async def mockcommand(name, *args, **kwargs): assert name == "formNetwork" - ezsp_f.frame_received(b"\x01\x00\x19" + final_result) + ezsp.frame_received(b"\x01\x00\x19" + final_result) return initial_result - ezsp_f._command = mockcommand + ezsp._command = mockcommand - await ezsp_f.formNetwork(MagicMock()) + await ezsp.formNetwork(MagicMock()) -async def test_form_network(ezsp_f): - await _test_form_network(ezsp_f, [t.EmberStatus.SUCCESS], b"\x90") +async def test_form_network(): + ezsp = await make_connected_ezsp(version=4) + await _test_form_network(ezsp, [t.EmberStatus.SUCCESS], b"\x90") + + +async def test_form_network_fail(): + ezsp = await make_connected_ezsp(version=4) -async def test_form_network_fail(ezsp_f): with pytest.raises(Exception): - await _test_form_network(ezsp_f, [t.EmberStatus.FAILURE], b"\x90") + await _test_form_network(ezsp, [t.EmberStatus.FAILURE], b"\x90") @patch("bellows.ezsp.NETWORK_OPS_TIMEOUT", 0.1) -async def test_form_network_fail_stack_status(ezsp_f): +async def test_form_network_fail_stack_status(): + ezsp = await make_connected_ezsp(version=4) + with pytest.raises(Exception): - await _test_form_network(ezsp_f, [t.EmberStatus.SUCCESS], b"\x00") + await _test_form_network(ezsp, [t.EmberStatus.SUCCESS], b"\x00") -def test_receive_new(ezsp_f): +async def test_receive_new(): + ezsp = await make_connected_ezsp(version=4) + callback = MagicMock() - ezsp_f.add_callback(callback) - ezsp_f.frame_received(b"\x00\xff\x00\x04\x05\x06\x00") + ezsp.add_callback(callback) + ezsp.frame_received(b"\x00\xff\x00\x04\x05\x06\x00") assert callback.call_count == 1 @@ -232,18 +256,18 @@ def test_callback_exc(ezsp_f): @pytest.mark.parametrize("version, call_count", ((4, 1), (5, 2), (6, 2), (99, 2))) -async def test_change_version(ezsp_f, version, call_count): - def mockcommand(name, *args, **kwargs): +async def test_change_version(version, call_count): + ezsp = await make_connected_ezsp(version=4) + + async def mockcommand(name, *args, **kwargs): assert name == "version" - ezsp_f.frame_received(b"\x01\x00\x00\x21\x22\x23\x24") - fut = asyncio.Future() - fut.set_result([version, 2, 2046]) - return fut + ezsp.frame_received(b"\x01\x00\x00\x21\x22\x23\x24") + return [version, 2, 2046] - ezsp_f._command = MagicMock(side_effect=mockcommand) - await ezsp_f.version() - assert ezsp_f.ezsp_version == version - assert ezsp_f._command.call_count == call_count + ezsp._command = AsyncMock(side_effect=mockcommand) + await ezsp.version() + assert ezsp.ezsp_version == version + assert ezsp._command.call_count == call_count def test_stop_ezsp(ezsp_f): @@ -258,59 +282,30 @@ def test_start_ezsp(ezsp_f): assert ezsp_f._ezsp_event.is_set() is True -def test_connection_lost(ezsp_f): - ezsp_f.enter_failed_state = MagicMock(spec_set=ezsp_f.enter_failed_state) - ezsp_f.connection_lost(sentinel.exc) - assert ezsp_f.enter_failed_state.call_count == 1 +def test_enter_failed_state(ezsp_f): + ezsp_f._application = MagicMock() + ezsp_f.enter_failed_state(t.NcpResetCode.RESET_SOFTWARE) - -async def test_enter_failed_state(ezsp_f): - ezsp_f.stop_ezsp = MagicMock(spec_set=ezsp_f.stop_ezsp) - cb = MagicMock(spec_set=ezsp_f.handle_callback) - ezsp_f.add_callback(cb) - ezsp_f.enter_failed_state(sentinel.error) - await asyncio.sleep(0) - assert ezsp_f.stop_ezsp.call_count == 1 - assert cb.call_count == 1 - assert cb.call_args[0][1][0] == sentinel.error - - -async def test_no_close_without_callback(ezsp_f): - ezsp_f.stop_ezsp = MagicMock(spec_set=ezsp_f.stop_ezsp) - ezsp_f.close = MagicMock(spec_set=ezsp_f.close) - ezsp_f.enter_failed_state(sentinel.error) - await asyncio.sleep(0) - assert ezsp_f.stop_ezsp.call_count == 0 - assert ezsp_f.close.call_count == 0 - - -@patch.object(ezsp.EZSP, "version", new_callable=AsyncMock) -@patch.object(ezsp.EZSP, "reset", new_callable=AsyncMock) -@patch("bellows.uart.connect", return_value=MagicMock(spec_set=uart.Gateway)) -async def test_ezsp_init(conn_mock, reset_mock, version_mock): - """Test initialize method.""" - zigpy_config = config.CONFIG_SCHEMA({"device": DEVICE_CONFIG}) - await ezsp.EZSP.initialize(zigpy_config) - assert conn_mock.await_count == 1 - assert reset_mock.await_count == 1 - assert version_mock.await_count == 1 + assert ezsp_f._application.connection_lost.mock_calls == [ + call(NcpFailure(code=t.NcpResetCode.RESET_SOFTWARE)) + ] -@patch.object(ezsp.EZSP, "version", side_effect=RuntimeError("Uh oh")) -@patch.object(ezsp.EZSP, "reset", new_callable=AsyncMock) -@patch.object(ezsp.EZSP, "close", new_callable=MagicMock) -@patch("bellows.uart.connect", return_value=MagicMock(spec_set=uart.Gateway)) -async def test_ezsp_init_failure(conn_mock, close_mock, reset_mock, version_mock): +@patch.object(EZSP, "version", side_effect=RuntimeError("Uh oh")) +@patch.object(EZSP, "reset", new_callable=AsyncMock) +@patch.object(EZSP, "disconnect", new_callable=AsyncMock) +async def test_ezsp_connect_failure(disconnect_mock, reset_mock, version_mock): """Test initialize method failing.""" - zigpy_config = config.CONFIG_SCHEMA({"device": DEVICE_CONFIG}) + with patch("bellows.uart.connect") as conn_mock: + ezsp = make_ezsp(version=4) - with pytest.raises(RuntimeError): - await ezsp.EZSP.initialize(zigpy_config) + with pytest.raises(RuntimeError): + await ezsp.connect() assert conn_mock.await_count == 1 assert reset_mock.await_count == 1 assert version_mock.await_count == 1 - assert close_mock.call_count == 1 + assert disconnect_mock.call_count == 1 async def test_ezsp_newer_version(ezsp_f): @@ -392,36 +387,9 @@ async def replacement(command_name, tokenId=None, valueId=None): assert (mfg, brd, ver) == expected -async def test_pre_permit(ezsp_f): - with patch("bellows.ezsp.v4.EZSPv4.pre_permit") as pre_mock: - await ezsp_f.pre_permit(sentinel.time) - assert pre_mock.call_count == 1 - assert pre_mock.await_count == 1 - - -async def test_update_policies(ezsp_f): - with patch("bellows.ezsp.v4.EZSPv4.update_policies") as pol_mock: - await ezsp_f.update_policies(sentinel.time) - assert pol_mock.call_count == 1 - assert pol_mock.await_count == 1 - - -async def test_set_source_routing_set_concentrator(ezsp_f): +async def test_set_source_routing(ezsp_f): """Test enabling source routing.""" - with patch.object(ezsp_f, "setConcentrator", new=AsyncMock()) as cnc_mock: - cnc_mock.return_value = (t.EmberStatus.SUCCESS,) - await ezsp_f.set_source_routing() - assert cnc_mock.await_count == 1 - - cnc_mock.return_value = (t.EmberStatus.ERR_FATAL,) - await ezsp_f.set_source_routing() - assert cnc_mock.await_count == 2 - - -async def test_set_source_routing_ezsp_v8(ezsp_f): - """Test enabling source routing on EZSPv8.""" - ezsp_f._ezsp_version = 8 ezsp_f.setConcentrator = AsyncMock(return_value=(t.EmberStatus.SUCCESS,)) ezsp_f.setSourceRouteDiscoveryMode = AsyncMock() @@ -515,6 +483,8 @@ def get_token_data(token, index): async def test_can_rewrite_custom_eui64_old_ezsp(ezsp_f): """Test detecting if a custom EUI64 can be rewritten in NV3, but with old EZSP.""" + ezsp_f._ezsp_version = 4 + ezsp_f.getTokenData = AsyncMock(side_effect=InvalidCommandError) assert await ezsp_f._get_nv3_restored_eui64_key() is None assert not await ezsp_f.can_rewrite_custom_eui64() @@ -616,23 +586,20 @@ async def test_write_custom_eui64_rcp(ezsp_f): ] -@patch.object(ezsp.EZSP, "version", new_callable=AsyncMock) -@patch.object(ezsp.EZSP, "reset", new_callable=AsyncMock) -@patch("bellows.uart.connect", return_value=MagicMock(spec_set=uart.Gateway)) -async def test_ezsp_init_zigbeed(conn_mock, reset_mock, version_mock): +@patch.object(EZSP, "version", new_callable=AsyncMock) +@patch.object(EZSP, "reset", new_callable=AsyncMock) +async def test_ezsp_init_zigbeed(reset_mock, version_mock): """Test initialize method with a received startup reset frame.""" - zigpy_config = config.CONFIG_SCHEMA( - { - "device": { - **DEVICE_CONFIG, - zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234", - } + ezsp = make_ezsp( + config={ + **DEVICE_CONFIG, + zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234", } ) - gw_wait_reset_mock = conn_mock.return_value.wait_for_startup_reset = AsyncMock() - - await ezsp.EZSP.initialize(zigpy_config) + with patch("bellows.uart.connect") as conn_mock: + gw_wait_reset_mock = conn_mock.return_value.wait_for_startup_reset = AsyncMock() + await ezsp.connect() assert conn_mock.await_count == 1 assert reset_mock.await_count == 0 # Reset is not called @@ -640,29 +607,26 @@ async def test_ezsp_init_zigbeed(conn_mock, reset_mock, version_mock): assert version_mock.await_count == 1 -@patch.object(ezsp.EZSP, "version", new_callable=AsyncMock) -@patch.object(ezsp.EZSP, "reset", new_callable=AsyncMock) -@patch("bellows.uart.connect", return_value=MagicMock(spec_set=uart.Gateway)) +@patch.object(EZSP, "version", new_callable=AsyncMock) +@patch.object(EZSP, "reset", new_callable=AsyncMock) @patch("bellows.ezsp.NETWORK_COORDINATOR_STARTUP_RESET_WAIT", 0.01) -async def test_ezsp_init_zigbeed_timeout(conn_mock, reset_mock, version_mock): +async def test_ezsp_init_zigbeed_timeout(reset_mock, version_mock): """Test initialize method with a received startup reset frame.""" - zigpy_config = config.CONFIG_SCHEMA( - { - "device": { - **DEVICE_CONFIG, - zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234", - } + ezsp = make_ezsp( + config={ + **DEVICE_CONFIG, + zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234", } ) async def wait_forever(*args, **kwargs): return await asyncio.get_running_loop().create_future() - gw_wait_reset_mock = conn_mock.return_value.wait_for_startup_reset = AsyncMock( - side_effect=wait_forever - ) - - await ezsp.EZSP.initialize(zigpy_config) + with patch("bellows.uart.connect") as conn_mock: + gw_wait_reset_mock = conn_mock.return_value.wait_for_startup_reset = AsyncMock( + side_effect=wait_forever + ) + await ezsp.connect() assert conn_mock.await_count == 1 assert reset_mock.await_count == 1 # Reset will be called @@ -694,20 +658,20 @@ async def test_wait_for_stack_status(ezsp_f): def test_ezsp_versions(ezsp_f): - for version in range(4, ezsp.EZSP_LATEST + 1): + for version in range(4, EZSP_LATEST + 1): assert version in ezsp_f._BY_VERSION assert ezsp_f._BY_VERSION[version].__name__ == f"EZSPv{version}" assert ezsp_f._BY_VERSION[version].VERSION == version -async def test_config_initialize_husbzb1(ezsp_f): +async def test_config_initialize_husbzb1(): """Test timeouts are properly set for HUSBZB-1.""" - ezsp_f._ezsp_version = 4 + ezsp = await make_connected_ezsp(version=4) - ezsp_f.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) - ezsp_f.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) - ezsp_f.networkState = AsyncMock(return_value=(t.EmberNetworkStatus.JOINED_NETWORK,)) + ezsp.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) + ezsp.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) + ezsp.networkState = AsyncMock(return_value=(t.EmberNetworkStatus.JOINED_NETWORK,)) expected_calls = [ call(configId=t.EzspConfigId.CONFIG_SOURCE_ROUTE_TABLE_SIZE, value=16), @@ -733,96 +697,98 @@ async def test_config_initialize_husbzb1(ezsp_f): call(configId=t.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT, value=255), ] - await ezsp_f.write_config({}) - assert ezsp_f.setConfigurationValue.mock_calls == expected_calls + await ezsp.write_config({}) + assert ezsp.setConfigurationValue.mock_calls == expected_calls -@pytest.mark.parametrize("version", ezsp.EZSP._BY_VERSION) -async def test_config_initialize(version: int, ezsp_f, caplog): +@pytest.mark.parametrize("version", EZSP._BY_VERSION) +async def test_config_initialize(version: int, caplog): """Test config initialization for all protocol versions.""" - assert ezsp_f.ezsp_version == 4 + ezsp = await make_connected_ezsp(version=version) - with patch.object(ezsp_f, "_command", AsyncMock(return_value=[version, 2, 2046])): - await ezsp_f.version() + with patch.object(ezsp, "_command", AsyncMock(return_value=[version, 2, 2046])): + await ezsp.version() - assert ezsp_f.ezsp_version == version + assert ezsp.ezsp_version == version - ezsp_f.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) - ezsp_f.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) - ezsp_f.networkState = AsyncMock(return_value=(t.EmberNetworkStatus.JOINED_NETWORK,)) + ezsp.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) + ezsp.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) + ezsp.networkState = AsyncMock(return_value=(t.EmberNetworkStatus.JOINED_NETWORK,)) - ezsp_f.setValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) - ezsp_f.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) + ezsp.setValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) + ezsp.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) - await ezsp_f.write_config({}) + await ezsp.write_config({}) with caplog.at_level(logging.DEBUG): - ezsp_f.setConfigurationValue.return_value = (t.EzspStatus.ERROR_OUT_OF_MEMORY,) - await ezsp_f.write_config({}) + ezsp.setConfigurationValue.return_value = (t.EzspStatus.ERROR_OUT_OF_MEMORY,) + await ezsp.write_config({}) assert "Could not set config" in caplog.text - ezsp_f.setConfigurationValue.return_value = (t.EzspStatus.SUCCESS,) + ezsp.setConfigurationValue.return_value = (t.EzspStatus.SUCCESS,) caplog.clear() # EZSPv6 does not set any values on startup if version < 7: return - ezsp_f.setValue.reset_mock() - ezsp_f.getValue.return_value = (t.EzspStatus.ERROR_INVALID_ID, b"") - await ezsp_f.write_config({}) - assert len(ezsp_f.setValue.mock_calls) == 1 + ezsp.setValue.reset_mock() + ezsp.getValue.return_value = (t.EzspStatus.ERROR_INVALID_ID, b"") + await ezsp.write_config({}) + assert len(ezsp.setValue.mock_calls) == 1 - ezsp_f.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) + ezsp.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) caplog.clear() with caplog.at_level(logging.DEBUG): - ezsp_f.setValue.return_value = (t.EzspStatus.ERROR_INVALID_ID,) - await ezsp_f.write_config({}) + ezsp.setValue.return_value = (t.EzspStatus.ERROR_INVALID_ID,) + await ezsp.write_config({}) assert "Could not set value" in caplog.text - ezsp_f.setValue.return_value = (t.EzspStatus.SUCCESS,) + ezsp.setValue.return_value = (t.EzspStatus.SUCCESS,) caplog.clear() -async def test_cfg_initialize_skip(ezsp_f): +async def test_cfg_initialize_skip(): """Test initialization.""" - ezsp_f.networkState = AsyncMock(return_value=(t.EmberNetworkStatus.JOINED_NETWORK,)) + ezsp = await make_connected_ezsp(version=4) + + ezsp.networkState = AsyncMock(return_value=(t.EmberNetworkStatus.JOINED_NETWORK,)) p1 = patch.object( - ezsp_f, + ezsp, "setConfigurationValue", new=AsyncMock(return_value=(t.EzspStatus.SUCCESS,)), ) p2 = patch.object( - ezsp_f, + ezsp, "getConfigurationValue", new=AsyncMock(return_value=(t.EzspStatus.SUCCESS, 22)), ) with p1, p2: - await ezsp_f.write_config({"CONFIG_END_DEVICE_POLL_TIMEOUT": None}) + await ezsp.write_config({"CONFIG_END_DEVICE_POLL_TIMEOUT": None}) # Config not set when it is explicitly disabled with pytest.raises(AssertionError): - ezsp_f.setConfigurationValue.assert_called_with( + ezsp.setConfigurationValue.assert_called_with( configId=t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, value=ANY ) with p1, p2: - await ezsp_f.write_config({"CONFIG_MULTICAST_TABLE_SIZE": 123}) + await ezsp.write_config({"CONFIG_MULTICAST_TABLE_SIZE": 123}) # Config is overridden - ezsp_f.setConfigurationValue.assert_any_call( + ezsp.setConfigurationValue.assert_any_call( configId=t.EzspConfigId.CONFIG_MULTICAST_TABLE_SIZE, value=123 ) with p1, p2: - await ezsp_f.write_config({}) + await ezsp.write_config({}) # Config is set by default - ezsp_f.setConfigurationValue.assert_any_call( + ezsp.setConfigurationValue.assert_any_call( configId=t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, value=ANY ) diff --git a/tests/test_uart.py b/tests/test_uart.py index fdd404fb..1fdc05f3 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -3,8 +3,8 @@ from unittest.mock import AsyncMock, MagicMock, call, patch, sentinel import pytest -import serial_asyncio import zigpy.config as conf +import zigpy.serial from bellows import uart import bellows.types as t @@ -20,7 +20,7 @@ async def mockconnect(loop, protocol_factory, **kwargs): loop.call_soon(protocol.connection_made, transport) return None, protocol - monkeypatch.setattr(serial_asyncio, "create_serial_connection", mockconnect) + monkeypatch.setattr(zigpy.serial, "create_serial_connection", mockconnect) gw = await uart.connect( conf.SCHEMA_DEVICE( { @@ -47,7 +47,7 @@ async def mockconnect(loop, protocol_factory, **kwargs): loop.call_soon(protocol.connection_made, transport) return None, protocol - monkeypatch.setattr(serial_asyncio, "create_serial_connection", mockconnect) + monkeypatch.setattr(zigpy.serial, "create_serial_connection", mockconnect) def on_transport_close(): gw.connection_lost(None) @@ -76,7 +76,7 @@ async def test_connect_threaded_failure(monkeypatch): mockconnect = AsyncMock() mockconnect.side_effect = OSError - monkeypatch.setattr(serial_asyncio, "create_serial_connection", mockconnect) + monkeypatch.setattr(zigpy.serial, "create_serial_connection", mockconnect) def on_transport_close(): gw.connection_lost(None) @@ -168,21 +168,9 @@ async def test_reset_old(gw): def test_connection_lost_exc(gw): gw.connection_lost(sentinel.exception) - conn_lost = gw._application.connection_lost + conn_lost = gw._api.connection_lost assert conn_lost.call_count == 1 - assert conn_lost.call_args[0][0] is sentinel.exception - - -def test_connection_closed(gw): - gw.connection_lost(None) - - assert gw._application.connection_lost.call_count == 0 - - -def test_eof_received(gw): - gw.eof_received() - - assert gw._application.connection_lost.call_count == 1 + assert conn_lost.mock_calls[0].args[0] is sentinel.exception async def test_connection_lost_reset_error_propagation(monkeypatch): @@ -195,7 +183,7 @@ async def mockconnect(loop, protocol_factory, **kwargs): loop.call_soon(protocol.connection_made, transport) return None, protocol - monkeypatch.setattr(serial_asyncio, "create_serial_connection", mockconnect) + monkeypatch.setattr(zigpy.serial, "create_serial_connection", mockconnect) def on_transport_close(): gw.connection_lost(None) @@ -243,9 +231,16 @@ async def test_wait_for_startup_reset_failure(gw): async def test_callbacks(gw): gw.data_received(b"some ezsp packet") - assert gw._application.frame_received.mock_calls == [call(b"some ezsp packet")] + assert gw._api.frame_received.mock_calls == [call(b"some ezsp packet")] gw.error_received(t.NcpResetCode.RESET_SOFTWARE) - assert gw._application.enter_failed_state.mock_calls == [ + assert gw._api.enter_failed_state.mock_calls == [ call(t.NcpResetCode.RESET_SOFTWARE) ] + + +def test_reset_propagation(gw): + gw.reset_received(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT) + assert gw._api.enter_failed_state.mock_calls == [ + call(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT) + ] diff --git a/tests/test_zigbee_repairs.py b/tests/test_zigbee_repairs.py index c6fd3f3d..90a4939f 100644 --- a/tests/test_zigbee_repairs.py +++ b/tests/test_zigbee_repairs.py @@ -11,7 +11,7 @@ import bellows.types as t from bellows.zigbee import repairs -from tests.test_ezsp import ezsp_f, make_ezsp +from tests.test_ezsp import ezsp_f, make_connected_ezsp @pytest.fixture @@ -116,7 +116,8 @@ async def test_fix_invalid_tclk_all_versions( ) -> None: """Test that the TCLK is fixed (or not) on all versions.""" - ezsp = await make_ezsp(version) + ezsp = await make_connected_ezsp(version=version) + fw_has_token_interface = hasattr(ezsp, "setTokenData") if fw_has_token_interface: