Skip to content

Commit

Permalink
simplify MQTT-loader (fox-it#578)
Browse files Browse the repository at this point in the history
  • Loading branch information
cecinestpasunepipe authored Mar 26, 2024
1 parent 12a4e46 commit 42add54
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 78 deletions.
2 changes: 1 addition & 1 deletion dissect/target/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def detect(path: Path) -> bool:
raise NotImplementedError()

@staticmethod
def find_all(path: Path) -> Iterator[Path]:
def find_all(path: Path, **kwargs) -> Iterator[Path]:
"""Finds all targets to load from ``path``.
This can be used to open multiple targets from a target path that doesn't necessarily map to files on a disk.
Expand Down
46 changes: 15 additions & 31 deletions dissect/target/loaders/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@
import urllib
from dataclasses import dataclass
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from struct import pack, unpack_from
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Iterator, Optional, Union

import paho.mqtt.client as mqtt
from dissect.util.stream import AlignedStream

from dissect.target.containers.raw import RawContainer
from dissect.target.exceptions import LoaderError
from dissect.target.filesystem import VirtualFilesystem
from dissect.target.loader import Loader
from dissect.target.plugin import arg
from dissect.target.target import Target
Expand Down Expand Up @@ -74,7 +72,7 @@ def __init__(self, broker: Broker, host: str):
self.info = lru_cache(128)(self.info)
self.read = lru_cache(128)(self.read)

def topo(self, peers: int):
def topo(self, peers: int) -> list[str]:
self.broker.topology(self.host)

while len(self.broker.peers(self.host)) < peers:
Expand Down Expand Up @@ -148,7 +146,7 @@ def read(self, host: str, disk_id: int, seek_address: int, read_length: int) ->
def disk(self, host: str) -> DiskMessage:
return self.diskinfo[host]

def peers(self, host: str) -> int:
def peers(self, host: str) -> list[str]:
return self.topo[host]

def _on_disk(self, hostname: str, payload: bytes) -> None:
Expand Down Expand Up @@ -252,20 +250,22 @@ def connect(self) -> None:
class MQTTLoader(Loader):
"""Load remote targets through a broker."""

PATH = "/remote/data/hosts.txt"
FOLDER = "/remote/hosts"

connection = None
broker = None
peers = []

def __init__(self, path: Union[Path, str], **kwargs):
super().__init__(path)
cls = MQTTLoader
self.broker = cls.broker
self.connection = MQTTConnection(self.broker, path)

if str(path).startswith("/remote/hosts/host"):
self.path = path.read_text() # update path to reflect the resolved host
@staticmethod
def detect(path: Path) -> bool:
return False

def find_all(path: Path, **kwargs) -> Iterator[str]:
cls = MQTTLoader
num_peers = 1
if cls.broker is None:
if (uri := kwargs.get("parsed_path")) is None:
Expand All @@ -275,26 +275,10 @@ def __init__(self, path: Union[Path, str], **kwargs):
cls.broker.connect()
num_peers = int(options.get("peers", 1))

self.broker = cls.broker
self.connection = MQTTConnection(self.broker, self.path)
self.peers = self.connection.topo(num_peers)
cls.connection = MQTTConnection(cls.broker, path)
cls.peers = cls.connection.topo(num_peers)
yield from cls.peers

def map(self, target: Target) -> None:
if len(self.peers) == 1 and self.peers[0] == str(self.path):
target.path = Path(str(self.path))
for disk in self.connection.info():
target.disks.add(RawContainer(disk))
else:
target.props["mqtt"] = True

vfs = VirtualFilesystem()
vfs.map_file_fh(self.PATH, BytesIO("\n".join(self.peers).encode("utf-8")))
for index, peer in enumerate(self.peers):
vfs.map_file_fh(f"{self.FOLDER}/host{index}-{peer}", BytesIO(peer.encode("utf-8")))

target.fs.mount("/data", vfs)
target.filesystems.add(vfs)

@staticmethod
def detect(path: Path) -> bool:
return str(path).startswith("/remote/hosts/host")
for disk in self.connection.info():
target.disks.add(RawContainer(disk))
35 changes: 0 additions & 35 deletions dissect/target/plugins/child/mqtt.py

This file was deleted.

2 changes: 1 addition & 1 deletion dissect/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _find(find_path: Path, parsed_path: Optional[urllib.parse.ParseResult]):
continue

getlogger(entry).debug("Attempting to use loader: %s", loader_cls)
for sub_entry in loader_cls.find_all(entry):
for sub_entry in loader_cls.find_all(entry, parsed_path=parsed_path):
try:
ldr = loader_cls(sub_entry, parsed_path=parsed_path)
except Exception as e:
Expand Down
21 changes: 12 additions & 9 deletions tests/loaders/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def publish(self, topic: str, *args) -> None:
if command == "TOPO":
tokens[2] = "ID"
response.topic = "/".join(tokens)
response.payload = self.hostname.encode("utf-8")
for host in self.hostnames:
response.payload = host.encode("utf-8")
self.on_message(self, None, response)
elif tokens[2] == "INFO":
tokens[2] = "DISKS"
response.topic = "/".join(tokens)
Expand Down Expand Up @@ -61,20 +63,21 @@ def mock_client(mock_paho: MagicMock) -> Iterator[MagicMock]:


@pytest.mark.parametrize(
"alias, host, disks, disk, seek, read, expected",
"alias, hosts, disks, disk, seek, read, expected",
[
("host1", "host1", [3], 0, 0, 3, b"\x00\x01\x02"), # basic
("host2", "host2", [10], 0, 1, 3, b"\x01\x02\x03"), # + use offset
("group1", "host3", [10], 0, 1, 3, b"\x01\x02\x03"), # + use alias
("group2", "host4", [10, 10, 1], 1, 1, 3, b"\x01\x02\x03"), # + use disk 2
("host1", ["host1"], [3], 0, 0, 3, b"\x00\x01\x02"), # basic
("host2", ["host2"], [10], 0, 1, 3, b"\x01\x02\x03"), # + use offset
("group1", ["host3"], [10], 0, 1, 3, b"\x01\x02\x03"), # + use alias
("group2", ["host4"], [10, 10, 1], 1, 1, 3, b"\x01\x02\x03"), # + use disk 2
("group3", ["host4", "host5"], [10, 10, 1], 1, 1, 3, b"\x01\x02\x03"), # + multiple hosts behind one alias
],
)
@patch.object(time, "sleep") # improve speed during test, no need to wait for peers
def test_remote_loader_stream(
time: MagicMock,
mock_client: MagicMock,
alias: str,
host: str,
hosts: list[str],
disks: list[int],
disk: int,
seek: int,
Expand All @@ -86,15 +89,15 @@ def test_remote_loader_stream(
broker = Broker("0.0.0.0", "1884", "key", "crt", "ca", "case1")
broker.connect()
broker.mqtt_client.fill_disks(disks)
broker.mqtt_client.hostname = host
broker.mqtt_client.hostnames = hosts

with patch("dissect.target.loaders.mqtt.MQTTLoader.broker", broker):
targets = list(
Target.open_all(
[f"mqtt://{alias}?broker=0.0.0.0&port=1884&key=key&crt=crt&ca=ca&peers=1&case=case1"],
include_children=True,
)
)
assert len(targets) == len(hosts)
target = targets[-1]
target.disks[disk].seek(seek)
data = target.disks[disk].read(read)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def detect(path: Path) -> bool:
return path.is_dir() and path.joinpath("select.txt").exists()

@staticmethod
def find_all(path: Path) -> Iterator[Path]:
def find_all(path: Path, **kwargs) -> Iterator[Path]:
return [Path("/dir/raw1.img").as_posix(), Path("/dir/raw3.img").as_posix()]

def map(self, target: Target):
Expand Down

0 comments on commit 42add54

Please sign in to comment.