diff --git a/dissect/esedb/btree.py b/dissect/esedb/btree.py new file mode 100644 index 0000000..dd6fb64 --- /dev/null +++ b/dissect/esedb/btree.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dissect.esedb.exceptions import KeyNotFoundError, NoNeighbourPageError +from dissect.esedb.page import Node, Page + +if TYPE_CHECKING: + from dissect.esedb.esedb import EseDB + + +class BTree: + """A simple implementation for searching the ESE B+Trees. + + This is a stateful interactive class that moves an internal cursor to a position within the BTree. + + Args: + esedb: An instance of :class:`~dissect.esedb.esedb.EseDB`. + page: The page to open a BTree on. + """ + + def __init__(self, esedb: EseDB, root: int | Page): + self.esedb = esedb + + if isinstance(root, int): + page_num = root + root = esedb.page(page_num) + else: + page_num = root.num + + self.root = root + + self._page = root + self._page_num = page_num + self._node_num = 0 + + def reset(self) -> None: + """Reset the internal state to the root of the BTree.""" + self._page = self.root + self._page_num = self._page.num + self._node_num = 0 + + def node(self) -> Node: + """Return the node the BTree is currently on.""" + return self._page.node(self._node_num) + + def next(self) -> Node: + """Move the BTree to the next node and return it. + + Can move the BTree to the next page as a side effect. + """ + if self._node_num + 1 > self._page.node_count - 1: + self.next_page() + else: + self._node_num += 1 + + return self.node() + + def next_page(self) -> None: + """Move the BTree to the next page in the tree. + + Raises: + NoNeighbourPageError: If the current page has no next page. + """ + if self._page.next_page: + self._page = self.esedb.page(self._page.next_page) + self._node_num = 0 + else: + raise NoNeighbourPageError(f"{self._page} has no next page") + + def prev(self) -> Node: + """Move the BTree to the previous node and return it. + + Can move the BTree to the previous page as a side effect. + """ + if self._node_num - 1 < 0: + self.prev_page() + else: + self._node_num -= 1 + + return self.node() + + def prev_page(self) -> None: + """Move the BTree to the previous page in the tree. + + Raises: + NoNeighbourPageError: If the current page has no previous page. + """ + if self._page.previous_page: + self._page = self.esedb.page(self._page.previous_page) + self._node_num = self._page.node_count - 1 + else: + raise NoNeighbourPageError(f"{self._page} has no previous page") + + def search(self, key: bytes, exact: bool = True) -> Node: + """Search the tree for the given key. + + Moves the BTree to the matching node, or on the last node that is less than the requested key. + + Args: + key: The key to search for. + exact: Whether to only return successfully on an exact match. + + Raises: + KeyNotFoundError: If an ``exact`` match was requested but not found. + """ + page = self._page + while True: + node = find_node(page, key) + + if page.is_branch: + page = self.esedb.page(node.child) + else: + self._page = page + self._page_num = page.num + self._node_num = node.num + break + + if exact and key != node.key: + raise KeyNotFoundError(f"Can't find key: {key}") + + return self.node() + + +def find_node(page: Page, key: bytes) -> Node: + """Search a page for a node matching ``key``. + + Args: + page: The page to search. + key: The key to search. + """ + first_node_idx = 0 + last_node_idx = page.node_count - 1 + + node = None + while first_node_idx < last_node_idx: + node_idx = (first_node_idx + last_node_idx) // 2 + node = page.node(node_idx) + + # It turns out that the way BTree keys are compared matches 1:1 with how Python compares bytes + # First compare data, then length + if key < node.key: + last_node_idx = node_idx + elif key == node.key: + if page.is_branch: + # If there's an exact match on a key on a branch page, the actual leaf nodes are in the next branch + # Page keys for branch pages appear to be non-inclusive upper bounds + node_idx = min(node_idx + 1, page.node_count - 1) + node = page.node(node_idx) + + return node + else: + first_node_idx = node_idx + 1 + + # We're at the last node + return page.node(first_node_idx) diff --git a/dissect/esedb/c_esedb.py b/dissect/esedb/c_esedb.py index 13b1ffc..524cac2 100644 --- a/dissect/esedb/c_esedb.py +++ b/dissect/esedb/c_esedb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import struct import uuid @@ -425,6 +427,34 @@ DotNetGuid = 0x00040000, // index over GUID column according to .Net GUID sort order ImmutableStructure = 0x00080000, // Do not write to the input structures during a JetCreateIndexN call. }; + +flag IDBFLAG : uint16 { + Unique = 0x0001, // Duplicate keys not allowed + AllowAllNulls = 0x0002, // Make entries for NULL keys (all segments are null) + AllowFirstNull = 0x0004, // First index column NULL allowed in index + AllowSomeNulls = 0x0008, // Make entries for keys with some null segments + NoNullSeg = 0x0010, // Don't allow a NULL key segment + Primary = 0x0020, // Index is the primary index + LocaleSet = 0x0040, // Index locale information (locale name) is set (JET_bitIndexUnicode was specified). + Multivalued = 0x0080, // Has a multivalued segment + TemplateIndex = 0x0100, // Index of a template table + DerivedIndex = 0x0200, // Index derived from template table + // Note that this flag is persisted, but + // never used in an in-memory IDB, because + // we use the template index IDB instead. + LocalizedText = 0x0400, // Has a unicode text column? (code page is 1200) + SortNullsHigh = 0x0800, // NULL sorts after data + // Jan 2012: MSU is being removed. fidbUnicodeFixupOn should no longer be referenced. + UnicodeFixupOn_Deprecated = 0x1000, // Track entries with undefined Unicode codepoints + CrossProduct = 0x2000, // all combinations of multi-valued columns are indexed + DisallowTruncation = 0x4000, // fail update rather than allow key truncation + NestedTable = 0x8000, // combinations of multi-valued columns of same itagSequence are indexed +}; + +flag IDXFLAG : uint16 { + ExtendedColumns = 0x0001, // IDXSEGs are comprised of JET_COLUMNIDs, not FIDs + DotNetGuid = 0x0002, // GUIDs sort according to .Net rules +}; """ # noqa E501 c_esedb = cstruct().load(esedb_def) @@ -443,6 +473,8 @@ TAGFLD_HEADER = c_esedb.TAGFLD_HEADER CODEPAGE = c_esedb.CODEPAGE COMPRESSION_SCHEME = c_esedb.COMPRESSION_SCHEME +IDBFLAG = c_esedb.IDBFLAG +IDXFLAG = c_esedb.IDXFLAG CODEPAGE_MAP = { CODEPAGE.UNICODE: "utf-16-le", diff --git a/dissect/esedb/compression.py b/dissect/esedb/compression.py index 29a2536..a3b45a5 100644 --- a/dissect/esedb/compression.py +++ b/dissect/esedb/compression.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import struct -from typing import Optional from dissect.util.compression import lzxpress, sevenbit @@ -29,7 +30,7 @@ def decompress(buf: bytes) -> bytes: return buf -def decompress_size(buf: bytes) -> Optional[int]: +def decompress_size(buf: bytes) -> int | None: """Return the decompressed size of the given bytes according to the encoded compression scheme. Args: diff --git a/dissect/esedb/cursor.py b/dissect/esedb/cursor.py index f8eceaa..46b55e4 100644 --- a/dissect/esedb/cursor.py +++ b/dissect/esedb/cursor.py @@ -1,146 +1,172 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Iterator -from dissect.esedb.exceptions import KeyNotFoundError, NoNeighbourPageError -from dissect.esedb.page import Node, Page +from dissect.esedb.btree import BTree +from dissect.esedb.exceptions import NoNeighbourPageError +from dissect.esedb.record import Record if TYPE_CHECKING: - from dissect.esedb.esedb import EseDB + from dissect.esedb.index import Index + from dissect.esedb.page import Node class Cursor: - """A simple cursor implementation for searching the ESE B+Trees + """A simple cursor implementation for searching the ESE indexes. Args: - esedb: An instance of :class:`~dissect.esedb.esedb.EseDB`. - page: The page to open a cursor on. + index: The index to create the cursor for. """ - def __init__(self, esedb: EseDB, page: Union[int, Page]): - self.esedb = esedb + def __init__(self, index: Index): + self.index = index + self.table = index.table + self.esedb = index.esedb - if isinstance(page, int): - page_num = page - page = esedb.page(page_num) - else: - page_num = page.num + self._first = BTree(self.esedb, index.root) + self._secondary = None if index.is_primary else BTree(self.esedb, self.table.root) - self._page = page - self._page_num = page_num - self._node_num = 0 + def __iter__(self) -> Iterator[Record]: + while True: + yield self._record() - def node(self) -> Node: - """Return the node the cursor is currently on.""" - return self._page.node(self._node_num) + try: + self._first.next() + except NoNeighbourPageError: + break - def next(self) -> Node: - """Move the cursor to the next node and return it. + def _node(self) -> Node: + """Return the node the cursor is currently on. Resolves the secondary index if needed.""" + node = self._first.node() + if self._secondary: + self._secondary.reset() + node = self._secondary.search(node.data.tobytes(), exact=True) + return node - Can move the cursor to the next page as a side effect. - """ - if self._node_num + 1 > self._page.node_count - 1: - self.next_page() - else: - self._node_num += 1 + def _record(self) -> Record: + """Return the record the cursor is currently on.""" + return Record(self.table, self._node()) + + def reset(self) -> None: + """Reset the internal state.""" + self._first.reset() + if self._secondary: + self._secondary.reset() - return self.node() + def search(self, **kwargs) -> Record: + """Search the index for the requested values. - def next_page(self) -> None: - """Move the cursor to the next page in the tree. + Searching modifies the cursor state. Searching again will search from the current position. + Reset the cursor with :meth:`reset` to start from the beginning. - Raises: - NoNeighbourPageError: If the current page has no next page. + Args: + **kwargs: The columns and values to search for. """ - if self._page.next_page: - self._page = self.esedb.page(self._page.next_page) - self._node_num = 0 - else: - raise NoNeighbourPageError(f"{self._page} has no next page") + key = self.index.make_key(kwargs) + return self.search_key(key, exact=True) - def prev(self) -> Node: - """Move the cursor to the previous node and return it. + def search_key(self, key: bytes, exact: bool = True) -> Record: + """Search for a record with the given key. - Can move the cursor to the previous page as a side effect. + Args: + key: The key to search for. + exact: If ``True``, search for an exact match. If ``False``, sets the cursor on the + next record that is greater than or equal to the key. """ - if self._node_num - 1 < 0: - self.prev_page() - else: - self._node_num -= 1 + self._first.search(key, exact) + return self._record() - return self.node() + def seek(self, **kwargs) -> None: + """Seek to the record with the given values. - def prev_page(self) -> None: - """Move the cursor to the previous page in the tree. + Args: + **kwargs: The columns and values to seek to. + """ + key = self.index.make_key(kwargs) + self.search_key(key, exact=False) + + def seek_key(self, key: bytes) -> None: + """Seek to the record with the given key. - Raises: - NoNeighbourPageError: If the current page has no previous page. + Args: + key: The key to seek to. """ - if self._page.previous_page: - self._page = self.esedb.page(self._page.previous_page) - self._node_num = self._page.node_count - 1 - else: - raise NoNeighbourPageError(f"{self._page} has no previous page") + self._first.search(key, exact=False) - def search(self, key: bytes, exact: bool = True) -> Node: - """Search the tree for the given key. + def find(self, **kwargs) -> Record | None: + """Find a record in the index. - Moves the cursor to the matching node, or on the last node that is less than the requested key. + This differs from :meth:`search` in that it will allow additional filtering on non-indexed columns. Args: - key: The key to search for. - exact: Whether to only return successfully on an exact match. + **kwargs: The columns and values to search for. + """ + return next(self.find_all(**kwargs), None) + + def find_all(self, **kwargs) -> Iterator[Record]: + """Find all records in the index that match the given values. - Raises: - KeyNotFoundError: If an ``exact`` match was requested but not found. + This differs from :meth:`search` in that it will allows additional filtering on non-indexed columns. + If you only search on indexed columns, this will yield all records that match the indexed columns. + + Args: + **kwargs: The columns and values to search for. """ - page = self._page + indexed_columns = {c.name: kwargs.pop(c.name) for c in self.index.columns} + other_columns = kwargs + + # We need at least an exact match on the indexed columns + self.search(**indexed_columns) + + current_key = self._first.node().key + + # Check if we need to move the cursor back to find the first record while True: - node = find_node(page, key) - - if page.is_branch: - page = self.esedb.page(node.child) - else: - self._page = page - self._page_num = page.num - self._node_num = node.num + if current_key != self._first.node().key: + self._first.next() break - if exact and key != node.key: - raise KeyNotFoundError(f"Can't find key: {key}") + try: + self._first.prev() + except NoNeighbourPageError: + break - return self.node() + while True: + # Entries with the same indexed columns are guaranteed to be adjacent + if current_key != self._first.node().key: + break + record = self._record() + if all(record.get(k) == v for k, v in other_columns.items()): + yield record -def find_node(page: Page, key: bytes) -> Node: - """Search the tree, starting from the given ``page`` and search for ``key``. + try: + self._first.next() + except NoNeighbourPageError: + break - Args: - page: The page to start searching from. Should be a branch page. - key: The key to search. - """ - first_node_idx = 0 - last_node_idx = page.node_count - 1 - - node = None - while first_node_idx < last_node_idx: - node_idx = (first_node_idx + last_node_idx) // 2 - node = page.node(node_idx) - - # It turns out that the way BTree keys are compared matches 1:1 with how Python compares bytes - # First compare data, then length - if key < node.key: - last_node_idx = node_idx - elif key == node.key: - if page.is_branch: - # If there's an exact match on a key on a branch page, the actual leaf nodes are in the next branch - # Page keys for branch pages appear to be non-inclusive upper bounds - node_idx = min(node_idx + 1, page.node_count - 1) - node = page.node(node_idx) - - return node - else: - first_node_idx = node_idx + 1 - - # We're at the last node - return page.node(first_node_idx) + def record(self) -> Record: + """Return the record the cursor is currently on.""" + return self._record() + + def next(self) -> Record: + """Move the cursor to the next record and return it. + + Can move the cursor to the next page as a side effect. + """ + try: + self._first.next() + except NoNeighbourPageError: + raise IndexError("No next record") + return self._record() + + def prev(self) -> Record: + """Move the cursor to the previous node and return it. + + Can move the cursor to the previous page as a side effect. + """ + try: + self._first.prev() + except NoNeighbourPageError: + raise IndexError("No previous record") + return self._record() diff --git a/dissect/esedb/index.py b/dissect/esedb/index.py index 2ce22b8..28393aa 100644 --- a/dissect/esedb/index.py +++ b/dissect/esedb/index.py @@ -3,9 +3,9 @@ import struct import uuid from functools import cached_property -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING -from dissect.esedb.c_esedb import CODEPAGE, JET_bitIndex, JET_coltyp, RecordValue +from dissect.esedb.c_esedb import CODEPAGE, IDBFLAG, IDXFLAG, JET_coltyp, RecordValue from dissect.esedb.cursor import Cursor from dissect.esedb.lcmapstring import map_string from dissect.esedb.page import Node, Page @@ -29,16 +29,25 @@ class Index(object): record: The record in the catalog for this index. """ - def __init__(self, table: Table, record: Record = None): + def __init__(self, table: Table, record: Record | None = None): self.table = table self.record = record self.esedb = table.esedb self.name = record.get("Name") - self.flags = JET_bitIndex(record.get("Flags")) + flags = record.get("Flags") + self.idb_flags = IDBFLAG(flags & 0xFFFF) + self.idx_flags = IDXFLAG(flags >> 16) self._key_most = record.get("KeyMost") or JET_cbKeyMost_OLD self._var_seg_mac = record.get("VarSegMac") or self._key_most + def __repr__(self) -> str: + return f"" + + @property + def is_primary(self) -> bool: + return bool(self.idb_flags & IDBFLAG.Primary) + @cached_property def root(self) -> Page: """Return the root page of this index.""" @@ -60,14 +69,17 @@ def columns(self) -> list[Column]: """Return a list of all columns that are used in this index.""" return [self.table._column_id_map[cid] for cid in self.column_ids] + def cursor(self) -> Cursor: + """Create a new cursor for this index.""" + return Cursor(self) + def search(self, **kwargs) -> Record: """Search the index for the requested values. - Specify the column and value as a keyword argument. + Args: + **kwargs: The columns and values to search for. """ - key = self.make_key(kwargs) - node = self.search_key(key) - return Record(self.table, node) + return self.cursor().search(**kwargs) def search_key(self, key: bytes) -> Node: """Search the index for a specific key. @@ -75,8 +87,7 @@ def search_key(self, key: bytes) -> Node: Args: key: The key to search for. """ - cursor = Cursor(self.esedb, self.root) - return cursor.search(key) + return self.cursor().search_key(key) def key_from_record(self, record: Record) -> bytes: """Generate a key for this index from a record. @@ -112,9 +123,6 @@ def make_key(self, values: dict[str, RecordValue]) -> bytes: key = key[: self._key_most] return key - def __repr__(self) -> str: - return f"" - bPrefixNull = 0x00 bPrefixZeroLength = 0x40 @@ -266,7 +274,7 @@ def _encode_text(index: Index, column: Column, value: str, max_size: int) -> byt return bytes(key) -def _encode_guid(value: Union[str, uuid.UUID]) -> bytes: +def _encode_guid(value: str | uuid.UUID) -> bytes: if isinstance(value, str): value = uuid.UUID(value) guid_bytes = value.bytes_le diff --git a/dissect/esedb/lcmapstring.py b/dissect/esedb/lcmapstring.py index bba3195..0578e4e 100644 --- a/dissect/esedb/lcmapstring.py +++ b/dissect/esedb/lcmapstring.py @@ -212,7 +212,7 @@ def map_string(value: str, flags: MapFlags, locale: str) -> bytes: ) -def _filter_weights(weights): +def _filter_weights(weights: list[int]) -> list[int]: i = len(weights) while i > 0: if weights[i - 1] > 2: diff --git a/dissect/esedb/page.py b/dissect/esedb/page.py index f9f41db..e421048 100644 --- a/dissect/esedb/page.py +++ b/dissect/esedb/page.py @@ -2,7 +2,7 @@ import struct from functools import cached_property -from typing import TYPE_CHECKING, Iterator, Optional, Union +from typing import TYPE_CHECKING, Iterator from dissect.esedb.c_esedb import PAGE_FLAG, TAG_FLAG, c_esedb @@ -81,7 +81,7 @@ def is_branch(self) -> bool: return not self.is_leaf @cached_property - def key_prefix(self) -> Optional[bytes]: + def key_prefix(self) -> bytes | None: if not self.is_root: return bytes(self.tag(0).data) @@ -104,7 +104,7 @@ def tags(self) -> Iterator[Tag]: for i in range(1, self.tag_count): yield self.tag(i) - def node(self, num: int) -> Union[BranchNode, LeafNode]: + def node(self, num: int) -> BranchNode | LeafNode: """Retrieve a node by index. Nodes are just tags, but indexed from the first tag. @@ -123,7 +123,7 @@ def node(self, num: int) -> Union[BranchNode, LeafNode]: return self._node_cache[num] - def nodes(self) -> Iterator[Union[BranchNode, LeafNode]]: + def nodes(self) -> Iterator[BranchNode | LeafNode]: """Yield all nodes.""" for i in range(self.node_count): yield self.node(i) @@ -200,7 +200,7 @@ def __init__(self, page: Page, num: int): self.flags = TAG_FLAG(flags) def __repr__(self) -> str: - return f"" + return f"" class Node: diff --git a/dissect/esedb/record.py b/dissect/esedb/record.py index 838bc67..1621cd5 100644 --- a/dissect/esedb/record.py +++ b/dissect/esedb/record.py @@ -4,7 +4,7 @@ import struct from binascii import hexlify from functools import lru_cache -from typing import TYPE_CHECKING, Any, Iterator, Optional +from typing import TYPE_CHECKING, Any, Iterator from dissect.util.xmemoryview import xmemoryview @@ -16,7 +16,7 @@ from dissect.esedb.table import Column, Table -def noop(value: Any): +def noop(value: Any) -> Any: return value @@ -252,7 +252,7 @@ def _parse_value(self, column: Column, value: bytes, tag_field: TagField = None) return value - def _parse_multivalue(self, value: bytes, tag_field: TagField): + def _parse_multivalue(self, value: bytes, tag_field: TagField) -> list[bytes]: fSeparatedInstance = 0x8000 if tag_field.flags & TAGFLD_HEADER.TwoValues: @@ -277,6 +277,8 @@ def _parse_multivalue(self, value: bytes, tag_field: TagField): data = self.table.get_long_value(bytes(data)) values.append(data) value = values + else: + raise ValueError(f"Unknown flags for tag field: {tag_field}") if tag_field.flags & TAGFLD_HEADER.Compressed: # Only the first entry appears to be compressed @@ -284,7 +286,7 @@ def _parse_multivalue(self, value: bytes, tag_field: TagField): return value - def _get_fixed(self, column: Column) -> Optional[bytes]: + def _get_fixed(self, column: Column) -> bytes | None: """Parse a specific fixed column.""" if column.identifier <= self._last_fixed_id: # Check if it's not null @@ -303,7 +305,7 @@ def _get_fixed(self, column: Column) -> Optional[bytes]: return value - def _get_variable(self, column: Column) -> Optional[bytes]: + def _get_variable(self, column: Column) -> bytes | None: """Parse a specific variable column.""" if column.identifier <= self._last_variable_id: identifier_idx = column.identifier - 128 @@ -331,7 +333,7 @@ def _get_variable(self, column: Column) -> Optional[bytes]: return value - def _get_tagged(self, column: Column) -> Optional[bytes]: + def _get_tagged(self, column: Column) -> bytes | None: """Parse a specific tagged column.""" tag_field = None @@ -362,7 +364,7 @@ def _get_tag_field(self, idx: int) -> TagField: """Retrieve the :class:`TagField` at the given index in the ``TAGFLD`` array.""" return TagField(self, self._tagged_data_view[idx]) - def _find_tag_field_idx(self, identifier: int, is_derived: bool = False) -> Optional[TagField]: + def _find_tag_field_idx(self, identifier: int, is_derived: bool = False) -> TagField | None: """Find a tag field by identifier and optional derived flag. Performs a binary search in the tagged field array for the given identifier. The comparison algorithm used is diff --git a/dissect/esedb/table.py b/dissect/esedb/table.py index 4a07095..6d8f6ca 100644 --- a/dissect/esedb/table.py +++ b/dissect/esedb/table.py @@ -2,9 +2,10 @@ import struct from functools import cached_property -from typing import TYPE_CHECKING, Any, Iterator, Optional +from typing import TYPE_CHECKING, Any, Iterator from dissect.esedb import compression +from dissect.esedb.btree import BTree from dissect.esedb.c_esedb import ( CODEPAGE, COLUMN_TYPE_MAP, @@ -55,7 +56,7 @@ def __init__( self.name = name self.root_page = root_page self.columns: list[Column] = [] - self.indexes = [] + self.indexes: list[Index] = [] # Set by the catalog during parsing self._long_value_record: Record = None @@ -74,7 +75,7 @@ def __init__( self.record = record def __repr__(self) -> str: - return f"" + return f"
" @cached_property def root(self) -> Page: @@ -112,6 +113,19 @@ def column_names(self) -> list[str]: """Return a list of all the column names.""" return list(self._column_name_map.keys()) + @property + def primary_index(self) -> Index | None: + # It's generally the first index, but loop just in case + for index in self.indexes: + if index.is_primary: + return index + + def cursor(self) -> Cursor | None: + """Create a new cursor for this table.""" + primary_idx = self.primary_index + if primary_idx: + return primary_idx.cursor() + def index(self, name: str) -> Index: """Return the index with the given name. @@ -126,6 +140,39 @@ def index(self, name: str) -> Index: except KeyError: raise KeyError(f"No index with this name in table {self.name}: {name}") + def find_index(self, column_names: list[str]) -> Index | None: + """Find the most suitable index to search for the given columns. + + Args: + column_names: A list of column names to find the best index for. + """ + best_match = 0 + best_index = None + for index in self.indexes: + # We want to find the index that has the most matching columns in the order they are indexed + i = 0 + for column in index.columns: + if column.name not in column_names: + break + i += 1 + + if i > best_match: + best_index = index + best_match = i + + return best_index + + def search(self, **kwargs) -> Record | None: + """Search for a record in the table. + + Args: + **kwargs: The columns and values to search for. + + Returns: + The first record that matches the search criteria, or None if no record was found. + """ + return self.cursor().search(**kwargs) + def records(self) -> Iterator[Record]: """Return an iterator of all the records of the table.""" for node in self.root.iter_leaf_nodes(): @@ -138,8 +185,8 @@ def get_long_value(self, key: bytes) -> bytes: key: The lookup key for the long value. """ rkey = key[::-1] - cursor = Cursor(self.esedb, self.lv_page) - header = cursor.search(rkey) + btree = BTree(self.esedb, self.lv_page) + header = btree.search(rkey) _, size = struct.unpack("<2I", header.data) chunks = [] @@ -147,7 +194,7 @@ def get_long_value(self, key: bytes) -> bytes: while True: try: - node = cursor.next() + node = btree.next() if not node.key.startswith(rkey): break except NoNeighbourPageError: @@ -199,6 +246,9 @@ def __init__(self, identifier: int, name: str, type_: JET_coltyp, record: Record self.record = record + def __repr__(self) -> str: + return f"" + @property def offset(self) -> int: return self._offset @@ -231,13 +281,13 @@ def size(self) -> int: return self.ctype.size @cached_property - def default(self) -> Optional[Any]: + def default(self) -> Any | None: if self.record and self.record.get("DefaultValue"): return self.record.get("DefaultValue") return None @cached_property - def encoding(self) -> Optional[CODEPAGE]: + def encoding(self) -> CODEPAGE | None: if self.is_text: return CODEPAGE(self.record.get("PagesOrLocale")) if self.record else CODEPAGE.ASCII return None @@ -246,9 +296,6 @@ def encoding(self) -> Optional[CODEPAGE]: def ctype(self) -> ColumnType: return COLUMN_TYPE_MAP[self.type.value] - def __repr__(self) -> str: - return f"" - class Catalog: """Parse and interact with the catalog table. diff --git a/dissect/esedb/tools/sru.py b/dissect/esedb/tools/sru.py index 265e1e8..772c9a2 100644 --- a/dissect/esedb/tools/sru.py +++ b/dissect/esedb/tools/sru.py @@ -1,7 +1,7 @@ from __future__ import annotations import argparse -from typing import BinaryIO, Iterator, Optional +from typing import BinaryIO, Iterator from dissect.util.sid import read_sid from dissect.util.ts import oatimestamp, wintimestamp @@ -56,7 +56,7 @@ def __init__(self, fh: BinaryIO): id_map_table = self.esedb.table("SruDbIdMapTable") self.id_map = {r.get("IdIndex"): r for r in id_map_table.records()} - def get_table(self, table_name: str = None, table_guid: str = None) -> Optional[Table]: + def get_table(self, table_name: str = None, table_guid: str = None) -> Table | None: if all((table_name, table_guid)) or not any((table_name, table_guid)): raise ValueError("Either table_name or table_guid must be provided") @@ -85,7 +85,7 @@ def get_table_entries(self, table: Table = None, table_name: str = None, table_g for record in table.records(): yield Entry(self, table, record) - def resolve_id(self, value: int) -> Optional[str]: + def resolve_id(self, value: int) -> str | None: try: record = self.id_map[value] except KeyError: @@ -140,7 +140,7 @@ def __getattr__(self, attr: str) -> RecordValue: def __repr__(self) -> str: column_values = serialise_record_column_values(self.record) - return f"" + return f"" def main(): diff --git a/tests/conftest.py b/tests/conftest.py index 6a58ffb..38bc708 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,63 +1,64 @@ import gzip import os +from typing import IO, BinaryIO, Iterator import pytest -def absolute_path(filename): +def absolute_path(filename: str) -> str: return os.path.join(os.path.dirname(__file__), filename) -def open_file(name, mode="rb"): +def open_file(name: str, mode: str = "rb") -> Iterator[IO]: with open(absolute_path(name), mode) as f: yield f -def open_file_gz(name, mode="rb"): +def open_file_gz(name: str, mode: str = "rb") -> Iterator[IO]: with gzip.GzipFile(absolute_path(name), mode) as f: yield f @pytest.fixture -def basic_db(): +def basic_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/basic.edb.gz") @pytest.fixture -def binary_db(): +def binary_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/binary.edb.gz") @pytest.fixture -def text_db(): +def text_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/text.edb.gz") @pytest.fixture -def multi_db(): +def multi_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/multi.edb.gz") @pytest.fixture -def default_db(): +def default_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/default.edb.gz") @pytest.fixture -def index_db(): +def index_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/index.edb.gz") @pytest.fixture -def large_db(): +def large_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/large.edb.gz") @pytest.fixture -def sru_db(): +def sru_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/SRUDB.dat.gz") @pytest.fixture -def ual_db(): +def ual_db() -> Iterator[BinaryIO]: yield from open_file_gz("data/Current.mdb.gz") diff --git a/tests/test_cursor.py b/tests/test_cursor.py new file mode 100644 index 0000000..cf8e096 --- /dev/null +++ b/tests/test_cursor.py @@ -0,0 +1,59 @@ +from typing import BinaryIO + +from dissect.esedb.esedb import EseDB + + +def test_cursor(basic_db: BinaryIO) -> None: + db = EseDB(basic_db) + table = db.table("basic") + idx = table.index("IxId") + + cursor = idx.cursor() + record = cursor.search(Id=1) + assert record.Id == 1 + record = cursor.next() + assert record.Id == 2 + record = cursor.prev() + assert record.Id == 1 + assert record.Id == cursor.record().Id + + +def test_cursor_iterator(basic_db: BinaryIO) -> None: + db = EseDB(basic_db) + table = db.table("basic") + idx = table.index("IxId") + + cursor = idx.cursor() + records = list(cursor) + assert len(records) == 2 + assert records[0].Id == 1 + assert records[1].Id == 2 + + +def test_cursor_search(ual_db: BinaryIO) -> None: + db = EseDB(ual_db) + table = db.table("CLIENTS") + idx = table.index("Username_RoleGuid_TenantId_index") + + cursor = idx.cursor() + records = list( + cursor.find_all( + AuthenticatedUserName="blackclover\\administrator", + RoleGuid="ad495fc3-0eaa-413d-ba7d-8b13fa7ec598", + TenantId="2417e4c3-5467-40c5-809b-12b59a86c102", + ) + ) + + assert len(records) == 5 + + cursor.reset() + records = list( + cursor.find_all( + AuthenticatedUserName="blackclover\\administrator", + RoleGuid="ad495fc3-0eaa-413d-ba7d-8b13fa7ec598", + TenantId="2417e4c3-5467-40c5-809b-12b59a86c102", + Day204=4, + ) + ) + + assert len(records) == 1 diff --git a/tests/test_esedb.py b/tests/test_esedb.py index 91e1114..42f6221 100644 --- a/tests/test_esedb.py +++ b/tests/test_esedb.py @@ -1,4 +1,5 @@ import datetime +from typing import BinaryIO from dissect.util.ts import oatimestamp @@ -6,7 +7,7 @@ from dissect.esedb.esedb import EseDB -def test_basic_types(basic_db): +def test_basic_types(basic_db: BinaryIO) -> None: db = EseDB(basic_db) table = db.table("basic") @@ -54,7 +55,7 @@ def test_basic_types(basic_db): assert oatimestamp(records[1].DateTime) == datetime.datetime(1337, 6, 9, 0, 0, tzinfo=datetime.timezone.utc) -def test_binary_types(binary_db): +def test_binary_types(binary_db: BinaryIO) -> None: db = EseDB(binary_db) table = db.table("binary") @@ -90,7 +91,7 @@ def test_binary_types(binary_db): assert records[0].MaxLongCompressedBinary == b"test max long compressed binary data " + (b"a" * 900) -def test_text_types(text_db): +def test_text_types(text_db: BinaryIO) -> None: db = EseDB(text_db) table = db.table("text") @@ -158,7 +159,7 @@ def test_text_types(text_db): ) -def test_multivalue_types(multi_db): +def test_multivalue_types(multi_db: BinaryIO) -> None: db = EseDB(multi_db) table = db.table("multi") @@ -305,7 +306,7 @@ def test_multivalue_types(multi_db): assert records[1].UnsignedShort is None -def test_default_db(default_db): +def test_default_db(default_db: BinaryIO) -> None: db = EseDB(default_db) table = db.table("default") @@ -334,7 +335,7 @@ def test_default_db(default_db): assert records[0].LongUnicode == "Long default Unicode 🦊 " + ("a" * 64) -def test_large_db(large_db): +def test_large_db(large_db: BinaryIO) -> None: db = EseDB(large_db) table = db.table("large") diff --git a/tests/test_index.py b/tests/test_index.py index c517e11..67294d2 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1,7 +1,9 @@ +from typing import BinaryIO + from dissect.esedb.esedb import EseDB -def test_index(index_db): +def test_index(index_db: BinaryIO) -> None: db = EseDB(index_db) table = db.table("index") @@ -9,76 +11,95 @@ def test_index(index_db): assert table.indexes[0].name == "IxId" assert table.indexes[0].column_ids == [1] - assert table.indexes[0].search(Id=1) + record = table.indexes[0].search(Id=1) + assert record.Id == 1 assert table.indexes[1].name == "IxBit" assert table.indexes[1].column_ids == [2] - assert table.indexes[1].search(Bit=False) + record = table.indexes[1].search(Bit=False) + assert record.Bit is False assert table.indexes[2].name == "IxUnsignedByte" assert table.indexes[2].column_ids == [3] - assert table.indexes[2].search(UnsignedByte=213) + record = table.indexes[2].search(UnsignedByte=213) + assert record.UnsignedByte == 213 assert table.indexes[3].name == "IxShort" assert table.indexes[3].column_ids == [4] - assert table.indexes[3].search(Short=-1337) + record = table.indexes[3].search(Short=-1337) + assert record.Short == -1337 assert table.indexes[4].name == "IxLong" assert table.indexes[4].column_ids == [5] - assert table.indexes[4].search(Long=-13371337) + record = table.indexes[4].search(Long=-13371337) + assert record.Long == -13371337 assert table.indexes[5].name == "IxCurrenc" assert table.indexes[5].column_ids == [6] - assert table.indexes[5].search(Currency=1337133713371337) + record = table.indexes[5].search(Currency=1337133713371337) + assert record.Currency == 1337133713371337 assert table.indexes[6].name == "IxIEEESingle" assert table.indexes[6].column_ids == [7] - assert table.indexes[6].search(IEEESingle=1.0) + record = table.indexes[6].search(IEEESingle=1.0) + assert record.IEEESingle == 1.0 assert table.indexes[7].name == "IxIEEEDouble" assert table.indexes[7].column_ids == [8] - assert table.indexes[7].search(IEEEDouble=13371337.13371337) + record = table.indexes[7].search(IEEEDouble=13371337.13371337) + assert record.IEEEDouble == 13371337.13371337 assert table.indexes[8].name == "IxDateTime" assert table.indexes[8].column_ids == [9] - assert table.indexes[8].search(DateTime=4675210852477960192) + record = table.indexes[8].search(DateTime=4675210852477960192) + assert record.DateTime == 4675210852477960192 assert table.indexes[9].name == "IxUnsignedLong" assert table.indexes[9].column_ids == [10] - assert table.indexes[9].search(UnsignedLong=13371337) + record = table.indexes[9].search(UnsignedLong=13371337) + assert record.UnsignedLong == 13371337 assert table.indexes[10].name == "IxLongLong" assert table.indexes[10].column_ids == [11] - assert table.indexes[10].search(LongLong=-13371337) + record = table.indexes[10].search(LongLong=-13371337) + assert record.LongLong == -13371337 assert table.indexes[11].name == "IxGUID" assert table.indexes[11].column_ids == [12] - assert table.indexes[11].search(GUID="3f360af1-6766-46dc-9af2-0dacf295c2a1") + record = table.indexes[11].search(GUID="3f360af1-6766-46dc-9af2-0dacf295c2a1") + assert record.GUID == "3f360af1-6766-46dc-9af2-0dacf295c2a1" assert table.indexes[12].name == "IxUnsignedShort" assert table.indexes[12].column_ids == [13] - assert table.indexes[12].search(UnsignedShort=1337) + record = table.indexes[12].search(UnsignedShort=1337) + assert record.UnsignedShort == 1337 assert table.indexes[13].name == "IxBinary" assert table.indexes[13].column_ids == [128] - assert table.indexes[13].search(Binary=b"test binary data") + record = table.indexes[13].search(Binary=b"test binary data") + assert record.Binary == b"test binary data" assert table.indexes[14].name == "IxLongBinary" assert table.indexes[14].column_ids == [256] - assert table.indexes[14].search(LongBinary=b"test long binary data " + (b"a" * 1024)) + record = table.indexes[14].search(LongBinary=b"test long binary data " + (b"a" * 1000)) + assert record.LongBinary == b"test long binary data " + (b"a" * 1000) assert table.indexes[15].name == "IxASCII" assert table.indexes[15].column_ids == [129] - assert table.indexes[15].search(ASCII="Simple ASCII text") + record = table.indexes[15].search(ASCII="Simple ASCII text") + assert record.ASCII == "Simple ASCII text" assert table.indexes[16].name == "IxUnicode" assert table.indexes[16].column_ids == [130] - assert table.indexes[16].search(Unicode="Simple Unicode text 🦊") + record = table.indexes[16].search(Unicode="Simple Unicode text 🦊") + assert record.Unicode == "Simple Unicode text 🦊" assert table.indexes[17].name == "IxLongASCII" assert table.indexes[17].column_ids == [257] - assert table.indexes[17].search(LongASCII="Long ASCII text " + ("a" * 1024)) + record = table.indexes[17].search(LongASCII="Long ASCII text " + ("a" * 1024)) + assert record.LongASCII == "Long ASCII text " + ("a" * 1024) assert table.indexes[18].name == "IxLongUnicode" assert table.indexes[18].column_ids == [258] - assert table.indexes[18].search(LongUnicode="Long Unicode text 🦊 " + ("a" * 1024)) + record = table.indexes[18].search(LongUnicode="Long Unicode text 🦊 " + ("a" * 1024)) + assert record.LongUnicode == "Long Unicode text 🦊 " + ("a" * 1024) diff --git a/tests/test_record.py b/tests/test_record.py index afea747..86ba939 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -3,7 +3,7 @@ from dissect.esedb.esedb import EseDB -def test_as_dict(basic_db: BinaryIO): +def test_as_dict(basic_db: BinaryIO) -> None: db = EseDB(basic_db) table = db.table("basic") diff --git a/tests/test_sru.py b/tests/test_sru.py index da29950..915dfb5 100644 --- a/tests/test_sru.py +++ b/tests/test_sru.py @@ -1,7 +1,9 @@ +from typing import BinaryIO + from dissect.esedb.tools.sru import SRU -def test_sru(sru_db): +def test_sru(sru_db: BinaryIO) -> None: db = SRU(sru_db) records = list(db.entries()) diff --git a/tests/test_table.py b/tests/test_table.py new file mode 100644 index 0000000..5260240 --- /dev/null +++ b/tests/test_table.py @@ -0,0 +1,32 @@ +from unittest.mock import MagicMock + +from dissect.esedb.table import Table + + +def test_find_index() -> None: + mock_column_id = MagicMock() + mock_column_id.name = "Id" + mock_column_bit = MagicMock() + mock_column_bit.name = "Bit" + mock_column_unsigned_byte = MagicMock() + mock_column_unsigned_byte.name = "UnsignedByte" + + mock_idx_id = MagicMock(name="IxId") + mock_idx_id.is_primary = True + mock_idx_id.columns = [mock_column_id] + mock_idx_bit = MagicMock(name="IxBit") + mock_idx_bit.is_primary = False + mock_idx_bit.columns = [mock_column_bit] + mock_idx_multiple = MagicMock(name="IxMultiple") + mock_idx_multiple.is_primary = False + mock_idx_multiple.columns = [mock_column_bit, mock_column_unsigned_byte] + + table = Table(MagicMock(), 69, "index", indexes=[mock_idx_id, mock_idx_bit, mock_idx_multiple]) + + assert table.find_index(["Id"]) == mock_idx_id + assert table.find_index(["Bit"]) == mock_idx_bit + assert table.find_index(["Bit", "UnsignedByte"]) == mock_idx_multiple + assert table.find_index(["UnsignedByte", "Bit"]) == mock_idx_multiple + assert table.find_index(["UnsignedByte"]) is None + assert table.find_index(["Id", "Bit"]) == mock_idx_id + assert table.find_index(["Bit", "SomethingElse"]) == mock_idx_bit diff --git a/tests/test_ual.py b/tests/test_ual.py index c0e4562..6062f50 100644 --- a/tests/test_ual.py +++ b/tests/test_ual.py @@ -1,7 +1,9 @@ +from typing import BinaryIO + from dissect.esedb.tools.ual import UAL -def test_ual(ual_db): +def test_ual(ual_db: BinaryIO) -> None: db = UAL(ual_db) assert len(list(db.get_table_records("CLIENTS"))) == 19