Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[omm] Test SignalTypeExchange with infinite fetching capabilities #1488

Merged
merged 2 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion open-media-match/.devcontainer/omm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# to make it easier to copy this

from OpenMediaMatch.storage.postgres.impl import DefaultOMMStore
from OpenMediaMatch.utils.fetch_benchmarking import InfiniteRandomExchange
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.content_type.photo import PhotoContent
Expand All @@ -29,5 +30,5 @@
STORAGE_IFACE_INSTANCE = DefaultOMMStore(
signal_types=[PdqSignal, VideoMD5Signal],
content_types=[PhotoContent, VideoContent],
exchange_types=[StaticSampleSignalExchangeAPI],
exchange_types=[StaticSampleSignalExchangeAPI, InfiniteRandomExchange],
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

logger = logging.getLogger(__name__)

COMMIT_TO_DB_MAX_SIZE = 10000
COMMIT_TO_DB_MAX_SIZE = 5000
COMMIT_TO_DB_MAX_SEC = 60

ONE_FETCH_MAX_SEC = 60 * 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
const reload = window.location
if (!data.get("api_json")) {
as_object["api_json"] = {}
} else {
as_object["api_json"] = JSON.parse(data.get("api_json"))
}

fetch(`/c/exchanges`, {
Expand Down
143 changes: 143 additions & 0 deletions open-media-match/src/OpenMediaMatch/utils/fetch_benchmarking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Helpers for fetch benchmarking and testing.

In the long term, these are generic enough that should probably live in
python-threatexchange, but for short-term needs we're just working on them
here.
"""
from dataclasses import dataclass
from collections import defaultdict
import time
import typing as t

from threatexchange.exchanges.fetch_state import (
FetchCheckpointBase,
FetchedSignalMetadata,
FetchDelta,
SignalOpinion,
)
from threatexchange.exchanges.collab_config import CollaborationConfigBase
from threatexchange.exchanges.signal_exchange_api import SignalExchangeAPI
from threatexchange.signal_type.signal_base import SignalType, CanGenerateRandomSignal


@dataclass
class InfiniteRandomExchangeCollabConfig(CollaborationConfigBase):
# How many items to return per fetch_iter()
new_items_per_fetch_iter: int = 500
# How many items to return before claiming no more records
# A limit of -1 means unlimited
new_items_per_fetch: int = -1
# How many items to return before no more records are returned
# A limit of -1 means unlimited
total_item_limit: int = 1000000
# Simulate IO time
sleep_per_iter: float = 0.0


@dataclass
class InfiniteRandomExchangeCheckpoint(FetchCheckpointBase):
total_fetched: int


class InfiniteRandomExchange(
SignalExchangeAPI[
InfiniteRandomExchangeCollabConfig,
InfiniteRandomExchangeCheckpoint,
FetchedSignalMetadata,
int,
t.Tuple[str, str],
]
):
"""
This SignalExchangeAPI generates random hashes for e2e testing.

By changing the config settings, it can even fetch an unbounded number
of signals.
"""

def __init__(self, config: InfiniteRandomExchangeCollabConfig) -> None:
self.config = config

@classmethod
def for_collab(
cls, collab: InfiniteRandomExchangeCollabConfig
) -> "SignalExchangeAPI":
return cls(collab)

@staticmethod
def get_checkpoint_cls() -> t.Type[InfiniteRandomExchangeCheckpoint]:
return InfiniteRandomExchangeCheckpoint

@staticmethod
def get_record_cls() -> t.Type[FetchedSignalMetadata]:
return FetchedSignalMetadata

@staticmethod
def get_config_cls() -> t.Type[InfiniteRandomExchangeCollabConfig]:
return InfiniteRandomExchangeCollabConfig

@classmethod
def naive_convert_to_signal_type(
cls,
signal_types: t.Sequence[t.Type[SignalType]],
collab: InfiniteRandomExchangeCollabConfig,
fetched: t.Mapping[int, t.Tuple[str, str]],
) -> t.Dict[t.Type[SignalType], t.Dict[str, FetchedSignalMetadata]]:
st_mapping = {st.get_name(): st for st in signal_types}
by_name: dict[str, dict[str, FetchedSignalMetadata]] = defaultdict(dict)
for st_name, st_val in fetched.values():
by_name[st_name][st_val] = FetchedSignalMetadata()
ret = {}
for st_name, signals in by_name.items():
st = st_mapping.get(st_name)
if st is None:
continue
ret[st] = signals
return ret

def fetch_iter(
self,
supported_signal_types: t.Sequence[t.Type[SignalType]],
# None if fetching for the first time,
# otherwise the previous FetchDelta returned
checkpoint: t.Optional[InfiniteRandomExchangeCheckpoint],
) -> t.Iterator[
FetchDelta[int, t.Tuple[str, str], InfiniteRandomExchangeCheckpoint]
]:
supported_sts = [
st
for st in supported_signal_types
if issubclass(st, CanGenerateRandomSignal)
]
total_fetched = 0 if checkpoint is None else checkpoint.total_fetched
inf = float("inf")
limit = inf
if self.config.new_items_per_fetch != -1:
limit = total_fetched + self.config.new_items_per_fetch
if self.config.total_item_limit != -1:
limit = min(limit, self.config.total_item_limit)
while total_fetched < limit:
new_tot = total_fetched + self.config.new_items_per_fetch_iter
if limit < inf:
new_tot = min(new_tot, int(limit))
ret = {
i: _get_signal(i, supported_sts) for i in range(total_fetched, new_tot)
}
if self.config.sleep_per_iter > 0:
time.sleep(self.config.sleep_per_iter)
yield FetchDelta(
ret, InfiniteRandomExchangeCheckpoint(total_fetched=new_tot)
)
total_fetched = new_tot


def _get_signal(
idx: int, supported: t.Sequence[t.Type[SignalType]]
) -> t.Tuple[str, str] | None:
if not supported:
return "", ""
st = supported[idx % len(supported)]
return st.get_name(), t.cast(CanGenerateRandomSignal, st).get_random_signal()
Loading