Skip to content

Commit

Permalink
Improve tests and fix some small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Schamper committed Jan 10, 2024
1 parent 87f6471 commit eb7ea58
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
8 changes: 5 additions & 3 deletions dissect/util/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ class ZlibStream(AlignedStream):
size: The size the stream should be.
"""

def __init__(self, fh: BinaryIO, size: Optional[int] = None, **kwargs):
def __init__(self, fh: BinaryIO, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE, **kwargs):
self._fh = fh

self._zlib = None
Expand All @@ -575,7 +575,7 @@ def __init__(self, fh: BinaryIO, size: Optional[int] = None, **kwargs):
self._zlib_prepend_offset = None
self._rewind()

super().__init__(size)
super().__init__(size, align)

def _rewind(self) -> None:
self._fh.seek(0)
Expand All @@ -600,7 +600,7 @@ def _read_fh(self, length: int) -> bytes:
if self._zlib_prepend_offset + length <= len(self._zlib_prepend):
offset = self._zlib_prepend_offset
self._zlib_prepend_offset += length
return self._zlib_prepend_offset[offset : self._zlib_prepend_offset]
return self._zlib_prepend[offset : self._zlib_prepend_offset]
else:
offset = self._zlib_prepend_offset
self._zlib_prepend_offset = None
Expand Down Expand Up @@ -634,6 +634,8 @@ def _read(self, offset: int, length: int) -> bytes:
return self._read_zlib(length)

def readall(self) -> bytes:
self._seek_zlib(self.tell())

chunks = []
# sys.maxsize means the max length of output buffer is unlimited,
# so that the whole input buffer can be decompressed within one
Expand Down
30 changes: 21 additions & 9 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import zlib
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -174,20 +175,31 @@ def test_overlay_stream():


def test_zlib_stream():
buf = io.BytesIO(zlib.compress(b"\x01" * 512 + b"\x02" * 512 + b"\x03" * 512 + b"\x04" * 512))
fh = stream.ZlibStream(buf, size=512 * 4)
data = b"\x01" * 8192 + b"\x02" * 8192 + b"\x03" * 8192 + b"\x04" * 8192
fh = stream.ZlibStream(io.BytesIO(zlib.compress(data)), size=8192 * 4, align=512)

assert fh.read(512) == b"\x01" * 512
assert fh.read(512) == b"\x02" * 512
assert fh.read(512) == b"\x03" * 512
assert fh.read(512) == b"\x04" * 512
assert fh.read(8192) == b"\x01" * 8192
assert fh.read(8192) == b"\x02" * 8192
assert fh.read(8192) == b"\x03" * 8192
assert fh.read(8192) == b"\x04" * 8192
assert fh.read(1) == b""

fh.seek(0)
assert fh.read(512) == b"\x01" * 512
assert fh.read(8192) == b"\x01" * 8192

fh.seek(1024)
assert fh.read(512) == b"\x03" * 512
assert fh.read(8192) == b"\x01" * 7168 + b"\x02" * 1024

fh.seek(512)
assert fh.read(1024) == b"\x01" * 1024

fh.seek(0)
assert fh.readall() == data

fh.seek(512)
assert fh.read(1024) == b"\x02" * 512 + b"\x03" * 512
assert fh.read(1024) == b"\x01" * 1024
with patch("io.DEFAULT_BUFFER_SIZE", 8):
assert fh.read(1024) == b"\x01" * 1024

fh.seek(0)
assert fh.read() == data

0 comments on commit eb7ea58

Please sign in to comment.