diff --git a/dissect/target/loader.py b/dissect/target/loader.py index a61588e2a..9f9159f1e 100644 --- a/dissect/target/loader.py +++ b/dissect/target/loader.py @@ -77,7 +77,7 @@ def detect(path: Path) -> bool: raise NotImplementedError() @staticmethod - def find_all(path: Path) -> Iterator[Path]: + def find_all(path: Path, **kwargs) -> Iterator[Path]: """Finds all targets to load from ``path``. This can be used to open multiple targets from a target path that doesn't necessarily map to files on a disk. diff --git a/dissect/target/loaders/mqtt.py b/dissect/target/loaders/mqtt.py index ffec6483a..d9154871d 100644 --- a/dissect/target/loaders/mqtt.py +++ b/dissect/target/loaders/mqtt.py @@ -6,17 +6,15 @@ import urllib from dataclasses import dataclass from functools import lru_cache -from io import BytesIO from pathlib import Path from struct import pack, unpack_from -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Iterator, Optional, Union import paho.mqtt.client as mqtt from dissect.util.stream import AlignedStream from dissect.target.containers.raw import RawContainer from dissect.target.exceptions import LoaderError -from dissect.target.filesystem import VirtualFilesystem from dissect.target.loader import Loader from dissect.target.plugin import arg from dissect.target.target import Target @@ -74,7 +72,7 @@ def __init__(self, broker: Broker, host: str): self.info = lru_cache(128)(self.info) self.read = lru_cache(128)(self.read) - def topo(self, peers: int): + def topo(self, peers: int) -> list[str]: self.broker.topology(self.host) while len(self.broker.peers(self.host)) < peers: @@ -148,7 +146,7 @@ def read(self, host: str, disk_id: int, seek_address: int, read_length: int) -> def disk(self, host: str) -> DiskMessage: return self.diskinfo[host] - def peers(self, host: str) -> int: + def peers(self, host: str) -> list[str]: return self.topo[host] def _on_disk(self, hostname: str, payload: bytes) -> None: @@ -252,9 +250,6 @@ def connect(self) -> None: class MQTTLoader(Loader): """Load remote targets through a broker.""" - PATH = "/remote/data/hosts.txt" - FOLDER = "/remote/hosts" - connection = None broker = None peers = [] @@ -262,10 +257,15 @@ class MQTTLoader(Loader): def __init__(self, path: Union[Path, str], **kwargs): super().__init__(path) cls = MQTTLoader + self.broker = cls.broker + self.connection = MQTTConnection(self.broker, path) - if str(path).startswith("/remote/hosts/host"): - self.path = path.read_text() # update path to reflect the resolved host + @staticmethod + def detect(path: Path) -> bool: + return False + def find_all(path: Path, **kwargs) -> Iterator[str]: + cls = MQTTLoader num_peers = 1 if cls.broker is None: if (uri := kwargs.get("parsed_path")) is None: @@ -275,26 +275,10 @@ def __init__(self, path: Union[Path, str], **kwargs): cls.broker.connect() num_peers = int(options.get("peers", 1)) - self.broker = cls.broker - self.connection = MQTTConnection(self.broker, self.path) - self.peers = self.connection.topo(num_peers) + cls.connection = MQTTConnection(cls.broker, path) + cls.peers = cls.connection.topo(num_peers) + yield from cls.peers def map(self, target: Target) -> None: - if len(self.peers) == 1 and self.peers[0] == str(self.path): - target.path = Path(str(self.path)) - for disk in self.connection.info(): - target.disks.add(RawContainer(disk)) - else: - target.props["mqtt"] = True - - vfs = VirtualFilesystem() - vfs.map_file_fh(self.PATH, BytesIO("\n".join(self.peers).encode("utf-8"))) - for index, peer in enumerate(self.peers): - vfs.map_file_fh(f"{self.FOLDER}/host{index}-{peer}", BytesIO(peer.encode("utf-8"))) - - target.fs.mount("/data", vfs) - target.filesystems.add(vfs) - - @staticmethod - def detect(path: Path) -> bool: - return str(path).startswith("/remote/hosts/host") + for disk in self.connection.info(): + target.disks.add(RawContainer(disk)) diff --git a/dissect/target/plugins/child/mqtt.py b/dissect/target/plugins/child/mqtt.py deleted file mode 100644 index ffe57a83f..000000000 --- a/dissect/target/plugins/child/mqtt.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Iterator - -from flow.record.fieldtypes import posix_path - -from dissect.target.exceptions import UnsupportedPluginError -from dissect.target.helpers.record import ChildTargetRecord -from dissect.target.plugin import ChildTargetPlugin - -if TYPE_CHECKING: - from dissect.target.target import Target - - -class MQTT(ChildTargetPlugin): - """Child target plugin that yields from remote broker.""" - - __type__ = "mqtt" - - PATH = "/remote/data/hosts.txt" - FOLDER = "/remote/hosts" - - def __init__(self, target: Target): - super().__init__(target) - - def check_compatible(self) -> None: - if not self.target.props.get("mqtt") or not self.target.fs.path(self.PATH).exists(): - raise UnsupportedPluginError("No remote children.txt file found.") - - def list_children(self) -> Iterator[ChildTargetRecord]: - hosts = self.target.fs.path(self.PATH).read_text(encoding="utf-8").split("\n") - for index, host in enumerate(hosts): - yield ChildTargetRecord( - type=self.__type__, path=posix_path(f"{self.FOLDER}/host{index}-{host}"), _target=self.target - ) diff --git a/dissect/target/target.py b/dissect/target/target.py index 3864f8057..ac922310c 100644 --- a/dissect/target/target.py +++ b/dissect/target/target.py @@ -280,7 +280,7 @@ def _find(find_path: Path, parsed_path: Optional[urllib.parse.ParseResult]): continue getlogger(entry).debug("Attempting to use loader: %s", loader_cls) - for sub_entry in loader_cls.find_all(entry): + for sub_entry in loader_cls.find_all(entry, parsed_path=parsed_path): try: ldr = loader_cls(sub_entry, parsed_path=parsed_path) except Exception as e: diff --git a/tests/loaders/test_mqtt.py b/tests/loaders/test_mqtt.py index 9a6451616..9e9eefc79 100644 --- a/tests/loaders/test_mqtt.py +++ b/tests/loaders/test_mqtt.py @@ -26,7 +26,9 @@ def publish(self, topic: str, *args) -> None: if command == "TOPO": tokens[2] = "ID" response.topic = "/".join(tokens) - response.payload = self.hostname.encode("utf-8") + for host in self.hostnames: + response.payload = host.encode("utf-8") + self.on_message(self, None, response) elif tokens[2] == "INFO": tokens[2] = "DISKS" response.topic = "/".join(tokens) @@ -61,12 +63,13 @@ def mock_client(mock_paho: MagicMock) -> Iterator[MagicMock]: @pytest.mark.parametrize( - "alias, host, disks, disk, seek, read, expected", + "alias, hosts, disks, disk, seek, read, expected", [ - ("host1", "host1", [3], 0, 0, 3, b"\x00\x01\x02"), # basic - ("host2", "host2", [10], 0, 1, 3, b"\x01\x02\x03"), # + use offset - ("group1", "host3", [10], 0, 1, 3, b"\x01\x02\x03"), # + use alias - ("group2", "host4", [10, 10, 1], 1, 1, 3, b"\x01\x02\x03"), # + use disk 2 + ("host1", ["host1"], [3], 0, 0, 3, b"\x00\x01\x02"), # basic + ("host2", ["host2"], [10], 0, 1, 3, b"\x01\x02\x03"), # + use offset + ("group1", ["host3"], [10], 0, 1, 3, b"\x01\x02\x03"), # + use alias + ("group2", ["host4"], [10, 10, 1], 1, 1, 3, b"\x01\x02\x03"), # + use disk 2 + ("group3", ["host4", "host5"], [10, 10, 1], 1, 1, 3, b"\x01\x02\x03"), # + use disk 2 ], ) @patch.object(time, "sleep") # improve speed during test, no need to wait for peers @@ -74,7 +77,7 @@ def test_remote_loader_stream( time: MagicMock, mock_client: MagicMock, alias: str, - host: str, + hosts: list[str], disks: list[int], disk: int, seek: int, @@ -86,15 +89,15 @@ def test_remote_loader_stream( broker = Broker("0.0.0.0", "1884", "key", "crt", "ca", "case1") broker.connect() broker.mqtt_client.fill_disks(disks) - broker.mqtt_client.hostname = host + broker.mqtt_client.hostnames = hosts with patch("dissect.target.loaders.mqtt.MQTTLoader.broker", broker): targets = list( Target.open_all( [f"mqtt://{alias}?broker=0.0.0.0&port=1884&key=key&crt=crt&ca=ca&peers=1&case=case1"], - include_children=True, ) ) + assert len(targets) == len(hosts) target = targets[-1] target.disks[disk].seek(seek) data = target.disks[disk].read(read) diff --git a/tests/test_target.py b/tests/test_target.py index 1f74bf713..452be25ab 100644 --- a/tests/test_target.py +++ b/tests/test_target.py @@ -128,7 +128,7 @@ def detect(path: Path) -> bool: return path.is_dir() and path.joinpath("select.txt").exists() @staticmethod - def find_all(path: Path) -> Iterator[Path]: + def find_all(path: Path, **kwargs) -> Iterator[Path]: return [Path("/dir/raw1.img").as_posix(), Path("/dir/raw3.img").as_posix()] def map(self, target: Target):