Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type checking #69

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions dissect/util/compression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

from dissect.util.compression import lz4 as lz4_python
from dissect.util.compression import lzo as lzo_python

Expand All @@ -16,8 +18,8 @@
# Note that the pure Python implementation is not a full replacement of the
# native lz4 Python package: only the decompress() function is implemented.
try:
import lz4.block as lz4
import lz4.block as lz4_native
import lz4.block as lz4 # type: ignore
import lz4.block as lz4_native # type: ignore

Check warning on line 22 in dissect/util/compression/__init__.py

View check run for this annotation

Codecov / codecov/patch

dissect/util/compression/__init__.py#L22

Added line #L22 was not covered by tests
except ImportError:
lz4 = lz4_python
lz4_native = None
Expand All @@ -37,12 +39,19 @@
# Note that the pure Python implementation is not a full replacement of the
# native lzo Python package: only the decompress() function is implemented.
try:
import lzo
import lzo as lzo_native
import lzo # type: ignore
import lzo as lzo_native # type: ignore

Check warning on line 43 in dissect/util/compression/__init__.py

View check run for this annotation

Codecov / codecov/patch

dissect/util/compression/__init__.py#L43

Added line #L43 was not covered by tests
except ImportError:
lzo = lzo_python
lzo_native = None


from dissect.util.compression import lznt1, lzxpress, lzxpress_huffman, sevenbit

if TYPE_CHECKING:
lzo = lzo_python
lz4 = lz4_python

Check warning on line 54 in dissect/util/compression/__init__.py

View check run for this annotation

Codecov / codecov/patch

dissect/util/compression/__init__.py#L52-L54

Added lines #L52 - L54 were not covered by tests
__all__ = [
"lz4",
"lz4_native",
Expand Down
10 changes: 5 additions & 5 deletions dissect/util/compression/lz4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import io
import struct
from typing import BinaryIO
from typing import BinaryIO, cast

from dissect.util.exceptions import CorruptDataError

Expand All @@ -25,12 +25,12 @@ def _get_length(src: BinaryIO, length: int) -> int:


def decompress(
src: bytes | BinaryIO,
src: bytes | bytearray | memoryview | BinaryIO,
uncompressed_size: int = -1,
max_length: int = -1,
return_bytearray: bool = False,
return_bytes_read: bool = False,
) -> bytes | tuple[bytes, int]:
) -> bytes | bytearray | tuple[bytes | bytearray, int]:
"""LZ4 decompress from a file-like object or bytes up to a certain length. Assumes no header.

Args:
Expand All @@ -44,7 +44,7 @@ def decompress(
Returns:
The decompressed data or a tuple of the decompressed data and the amount of bytes read.
"""
if not hasattr(src, "read"):
if isinstance(src, (bytes, bytearray, memoryview)):
src = io.BytesIO(src)

dst = bytearray()
Expand Down Expand Up @@ -78,7 +78,7 @@ def decompress(
if len(read_buf) != 2:
raise CorruptDataError("Premature EOF")

(offset,) = struct.unpack("<H", read_buf)
(offset,) = cast(tuple[int], struct.unpack("<H", read_buf))

if offset == 0:
raise CorruptDataError("Offset can't be 0")
Expand Down
4 changes: 2 additions & 2 deletions dissect/util/compression/lznt1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _get_displacement(offset: int) -> int:
TAG_MASKS = [(1 << i) for i in range(8)]


def decompress(src: bytes | BinaryIO) -> bytes:
def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
"""LZNT1 decompress from a file-like object or bytes.

Args:
Expand All @@ -34,7 +34,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
Returns:
The decompressed data.
"""
if not hasattr(src, "read"):
if isinstance(src, (bytes, bytearray, memoryview)):
src = io.BytesIO(src)

offset = src.tell()
Expand Down
4 changes: 2 additions & 2 deletions dissect/util/compression/lzo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _read_length(src: BinaryIO, val: int, mask: int) -> int:
return length + mask + val


def decompress(src: bytes | BinaryIO, header: bool = True, buflen: int = -1) -> bytes:
def decompress(src: bytes | bytearray | memoryview | BinaryIO, header: bool = True, buflen: int = -1) -> bytes:
"""LZO decompress from a file-like object or bytes. Header handling is controlled by ``header`` (default ``True``).

Arguments are largely compatible with python-lzo API.
Expand All @@ -36,7 +36,7 @@ def decompress(src: bytes | BinaryIO, header: bool = True, buflen: int = -1) ->
Returns:
The decompressed data.
"""
if not hasattr(src, "read"):
if isinstance(src, (bytes, bytearray, memoryview)):
src = io.BytesIO(src)

dst = bytearray()
Expand Down
6 changes: 3 additions & 3 deletions dissect/util/compression/lzxpress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import BinaryIO


def decompress(src: bytes | BinaryIO) -> bytes:
def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
"""LZXPRESS decompress from a file-like object or bytes.

Args:
Expand All @@ -15,7 +15,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
Returns:
The decompressed data.
"""
if not hasattr(src, "read"):
if isinstance(src, (bytes, bytearray, memoryview)):
src = io.BytesIO(src)

offset = src.tell()
Expand All @@ -41,7 +41,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
if src.tell() - offset == size:
break

match = struct.unpack("<H", src.read(2))[0]
match: int = struct.unpack("<H", src.read(2))[0]
match_offset, match_length = divmod(match, 8)
match_offset += 1

Expand Down
20 changes: 12 additions & 8 deletions dissect/util/compression/lzxpress_huffman.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ def _read_16_bit(fh: BinaryIO) -> int:
class Node:
__slots__ = ("children", "is_leaf", "symbol")

def __init__(self, symbol: Symbol | None = None, is_leaf: bool = False):
def __init__(self, symbol: int = 0, is_leaf: bool = False):
self.symbol = symbol
self.is_leaf = is_leaf
self.children = [None, None]
self.children: list[Node | None] = [None, None]


def _add_leaf(nodes: list[Node], idx: int, mask: int, bits: int) -> int:
node = nodes[0]
i = idx + 1

while bits > 1:
while node and bits > 1:
bits -= 1
childidx = (mask >> bits) & 1
if node.children[childidx] is None:
Expand All @@ -38,6 +38,7 @@ def _add_leaf(nodes: list[Node], idx: int, mask: int, bits: int) -> int:
i += 1
node = node.children[childidx]

assert node
node.children[mask & 1] = nodes[idx]
return i

Expand Down Expand Up @@ -84,8 +85,9 @@ def _build_tree(buf: bytes) -> Node:


class BitString:
source: BinaryIO

def __init__(self):
self.source = None
self.mask = 0
self.bits = 0

Expand Down Expand Up @@ -114,16 +116,18 @@ def skip(self, n: int) -> None:
self.mask += _read_16_bit(self.source) << (16 - self.bits)
self.bits += 16

def decode(self, root: Node) -> Symbol:
def decode(self, root: Node) -> int:
node = root
while not node.is_leaf:
while node and not node.is_leaf:
bit = self.lookup(1)
self.skip(1)
node = node.children[bit]

assert node
return node.symbol


def decompress(src: bytes | BinaryIO) -> bytes:
def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
"""LZXPRESS Huffman decompress from a file-like object or bytes.

Decompresses until EOF of the input data.
Expand All @@ -134,7 +138,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
Returns:
The decompressed data.
"""
if not hasattr(src, "read"):
if isinstance(src, (bytes, bytearray, memoryview)):
src = io.BytesIO(src)

dst = bytearray()
Expand Down
14 changes: 7 additions & 7 deletions dissect/util/compression/sevenbit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from io import BytesIO
import io
from typing import BinaryIO


def compress(src: bytes | BinaryIO) -> bytes:
def compress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
"""Sevenbit compress from a file-like object or bytes.

Args:
Expand All @@ -13,8 +13,8 @@ def compress(src: bytes | BinaryIO) -> bytes:
Returns:
The compressed data.
"""
if not hasattr(src, "read"):
src = BytesIO(src)
if isinstance(src, (bytes, bytearray, memoryview)):
src = io.BytesIO(src)

dst = bytearray()

Expand All @@ -39,7 +39,7 @@ def compress(src: bytes | BinaryIO) -> bytes:
return bytes(dst)


def decompress(src: bytes | BinaryIO, wide: bool = False) -> bytes:
def decompress(src: bytes | bytearray | memoryview | BinaryIO, wide: bool = False) -> bytes:
"""Sevenbit decompress from a file-like object or bytes.

Args:
Expand All @@ -48,8 +48,8 @@ def decompress(src: bytes | BinaryIO, wide: bool = False) -> bytes:
Returns:
The decompressed data.
"""
if not hasattr(src, "read"):
src = BytesIO(src)
if isinstance(src, (bytes, bytearray, memoryview)):
src = io.BytesIO(src)

dst = bytearray()

Expand Down
4 changes: 2 additions & 2 deletions dissect/util/compression/xz.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
CRC_SIZE = 4


def repair_checksum(fh: BinaryIO) -> BinaryIO:
def repair_checksum(fh: BinaryIO) -> OverlayStream:
"""Repair CRC32 checksums for all headers in an XZ stream.

FortiOS XZ files have (on purpose) corrupt streams which they read using a modified ``xz`` binary.
Expand Down Expand Up @@ -55,7 +55,7 @@ def repair_checksum(fh: BinaryIO) -> BinaryIO:
# Parse the index
isize, num_records = _mbi(index[1:])
index = index[1 + isize : -4]
records = []
records: list[tuple[int, int]] = []
for _ in range(num_records):
if not index:
raise ValueError("Missing index size")
Expand Down
Loading
Loading