diff --git a/envr-default b/envr-default index 685c1dc..6ca80ab 100644 --- a/envr-default +++ b/envr-default @@ -8,4 +8,4 @@ PYTHON_VENV=.venv [ALIASES] lint=black --check . && isort --check-only . && flake8 . && mypy . -test=(coverage erase && pytest --cov --maxfail=3 -n auto) \ No newline at end of file +test=coverage erase && pytest --cov --maxfail=3 -n auto \ No newline at end of file diff --git a/smp/image_management.py b/smp/image_management.py index 4822346..2cd7b5e 100644 --- a/smp/image_management.py +++ b/smp/image_management.py @@ -75,20 +75,23 @@ class ImageUploadWriteRequest(message.WriteRequest): upgrade: bool | None = None # allowed when off == 0 -class ImageUploadProgressWriteResponse(message.WriteResponse): +class ImageUploadWriteResponse(message.WriteResponse): _GROUP_ID = header.GroupId.IMAGE_MANAGEMENT _COMMAND_ID = header.CommandId.ImageManagement.UPLOAD - rc: int | None = None off: int | None = None + """The portion of the upload that has been completed, in 8-bit bytes. - -class ImageUploadFinalWriteResponse(message.WriteResponse): - _GROUP_ID = header.GroupId.IMAGE_MANAGEMENT - _COMMAND_ID = header.CommandId.ImageManagement.UPLOAD - - off: int | None = None + This is the offset of the next byte to be written. If the offset is equal to + the length of the image, the upload is complete. + """ match: bool | None = None + """Indicates if the uploaded data successfully matches the provided SHA256. + + Only sent in the final packet if CONFIG_IMG_ENABLE_IMAGE_CHECK is enabled. + """ + rc: int | None = None + """Unspecified field used by MCUBoot's SMP Server implementation.""" class ImageEraseRequest(message.WriteRequest): diff --git a/tests/test_image_management.py b/tests/test_image_management.py index a05c233..f0c10ca 100644 --- a/tests/test_image_management.py +++ b/tests/test_image_management.py @@ -2,6 +2,7 @@ from __future__ import annotations +import sys from typing import cast import cbor2 @@ -179,3 +180,109 @@ def test_ImageEraseResponse() -> None: r = smpimg.ImageEraseResponse.load(cast(smpheader.Header, r.header), {}) assert_header(r) assert smpheader.Header.SIZE + 1 == len(r.BYTES) + + +def test_ImageUploadWriteRequest() -> None: + assert_header = make_assert_header( + smpheader.GroupId.IMAGE_MANAGEMENT, + smpheader.OP.WRITE, + smpheader.CommandId.ImageManagement.UPLOAD, + None, + ) + r = smpimg.ImageUploadWriteRequest( + off=0, + data=b"hello", + image=1, + len=5, + sha=b"world", + upgrade=True, + ) + + assert_header(r) + + r = smpimg.ImageUploadWriteRequest.loads(r.BYTES) + assert_header(r) + + assert_header = make_assert_header( + smpheader.GroupId.IMAGE_MANAGEMENT, + smpheader.OP.WRITE, + smpheader.CommandId.ImageManagement.UPLOAD, + None, + ) + r = smpimg.ImageUploadWriteRequest( + off=0, + data=b"hello", + image=1, + len=5, + sha=b"world", + upgrade=True, + ) + + assert_header(r) + assert r.off == 0 + assert r.data == b"hello" + assert r.image == 1 + assert r.len == 5 + assert r.sha == b"world" + assert r.upgrade is True + + r = smpimg.ImageUploadWriteRequest.loads(r.BYTES) + assert_header(r) + assert r.off == 0 + assert r.data == b"hello" + assert r.image == 1 + assert r.len == 5 + assert r.sha == b"world" + assert r.upgrade is True + + # when off != 0 do not send image, len, sha, or upgrade + r = smpimg.ImageUploadWriteRequest(off=10, data=b"hello") + assert_header(r) + assert r.off == 10 + assert r.data == b"hello" + + +@pytest.mark.parametrize("off", [None, 0, 1, 0xFFFF, 0xFFFFFFFF]) +@pytest.mark.parametrize("match", [None, True, False]) +@pytest.mark.parametrize("rc", [None, 0, 1, 10]) +def test_ImageUploadWriteResponse(off: int | None, match: bool | None, rc: int | None) -> None: + assert_header = make_assert_header( + smpheader.GroupId.IMAGE_MANAGEMENT, + smpheader.OP.WRITE_RSP, + smpheader.CommandId.ImageManagement.UPLOAD, + None, + ) + r = smpimg.ImageUploadWriteResponse(off=off, match=match, rc=rc) + + assert_header(r) + assert r.off == off + assert r.match == match + assert r.rc == rc + + r = smpimg.ImageUploadWriteResponse.loads(r.BYTES) + assert_header(r) + assert r.off == off + assert r.match == match + assert r.rc == rc + + if sys.version_info >= (3, 9): + cbor_dict = ( + {} + | ({"off": off} if off is not None else {}) + | ({"match": match} if match is not None else {}) + | ({"rc": rc} if rc is not None else {}) + ) + else: + cbor_dict = {} + if off is not None: + cbor_dict["off"] = off + if match is not None: + cbor_dict["match"] = match + if rc is not None: + cbor_dict["rc"] = rc + + r = smpimg.ImageUploadWriteResponse.load(cast(smpheader.Header, r.header), cbor_dict) + assert_header(r) + assert r.off == off + assert r.match == match + assert r.rc == rc diff --git a/tests/test_injected_header.py b/tests/test_injected_header.py index 5d0404e..e70a3ce 100644 --- a/tests/test_injected_header.py +++ b/tests/test_injected_header.py @@ -87,7 +87,7 @@ def test_ImageUploadWriteResponse_injected_header() -> None: command_id=smphdr.CommandId.ImageManagement.UPLOAD, ) - r = smpimg.ImageUploadProgressWriteResponse( + r = smpimg.ImageUploadWriteResponse( header=smphdr.Header( op=h.op, version=h.version,