From 20bfb7c0932333e98996a37c87e846b6871fb25e Mon Sep 17 00:00:00 2001 From: David Callies Date: Thu, 16 Nov 2023 10:46:59 -0500 Subject: [PATCH] [omm][index] index uniqueness, hash insert api, idx checkpointing (#1457) --- .../src/OpenMediaMatch/blueprints/curation.py | 48 +++++++++++++--- .../src/OpenMediaMatch/blueprints/hashing.py | 4 +- .../src/OpenMediaMatch/blueprints/matching.py | 55 +++++++++++++++---- .../src/OpenMediaMatch/database.py | 3 +- .../src/OpenMediaMatch/storage/default.py | 30 ++++++---- .../src/OpenMediaMatch/storage/interface.py | 1 + .../src/OpenMediaMatch/tests/test_api.py | 8 +-- .../src/OpenMediaMatch/tests/test_database.py | 10 +++- 8 files changed, 122 insertions(+), 37 deletions(-) diff --git a/open-media-match/src/OpenMediaMatch/blueprints/curation.py b/open-media-match/src/OpenMediaMatch/blueprints/curation.py index 89bc49389..a37ac4b36 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/curation.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/curation.py @@ -3,8 +3,8 @@ from flask import Blueprint from flask import request, jsonify, abort - from sqlalchemy.exc import IntegrityError +from threatexchange.signal_type.signal_base import SignalType from OpenMediaMatch import database, persistence, utils from OpenMediaMatch.storage.interface import BankConfig @@ -125,18 +125,52 @@ def bank_add_file(bank_name: str): hashes = hashing.hash_media_post_impl() else: abort(400, "Neither `url` query param nor multipart file upload was received") + return _bank_add_signals(bank, hashes) + - signal_type_configs = storage.get_signal_type_configs() +def _bank_add_signals( + bank: BankConfig, signal_type_to_signal_str: dict[str, str] +) -> dict[str, t.Any]: + if not signal_type_to_signal_str: + abort(400, "No signals given") + storage = persistence.get_storage() + + signals: dict[type[SignalType], str] = {} + signal_type_cfgs = storage.get_signal_type_configs() + for name, val in signal_type_to_signal_str.items(): + st = signal_type_cfgs.get(name) + if st is None: + abort(400, f"No such signal type {name}") + try: + signals[st.signal_type] = st.signal_type.validate_signal_str(val) + except Exception as e: + abort(400, f"Invalid {name} signal: {str(e)}") content_id = storage.bank_add_content( bank.name, - { - signal_type_configs[name].signal_type: signal_value - for name, signal_value in hashes.items() - }, + signals, ) - return {"id": content_id, "signals": hashes} + return { + "id": content_id, + "signals": {st.get_name(): val for st, val in signals.items()}, + } + + +@bp.route("/bank//signal", methods=["POST"]) +@utils.abort_to_json +def bank_add_as_signals(bank_name: str): + """ + Add a signal/hash directly to the bank. + + Most of the time you want to add by file, since you'll be able + able to process the file in all of the techniques you have available. + """ + storage = persistence.get_storage() + bank = storage.get_bank(bank_name) + if not bank: + abort(404, f"bank '{bank_name}' not found") + return _bank_add_signals(bank, t.cast(dict[str, str], request.json)) def _get_collab(name: str): diff --git a/open-media-match/src/OpenMediaMatch/blueprints/hashing.py b/open-media-match/src/OpenMediaMatch/blueprints/hashing.py index d9b6917ef..f5c3c10fe 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/hashing.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/hashing.py @@ -25,7 +25,7 @@ @bp.route("/hash", methods=["GET"]) @abort_to_json -def hash_media(): +def hash_media() -> dict[str, str]: """ Fetch content and return its hash. @@ -49,7 +49,7 @@ def hash_media(): content_type = _parse_request_content_type(url_content_type) signal_types = _parse_request_signal_type(content_type) - ret = {} + ret: dict[str, str] = {} # For images, we may need to copy the file suffix (.png, jpeg, etc) for it to work with tempfile.NamedTemporaryFile("wb") as tmp: diff --git a/open-media-match/src/OpenMediaMatch/blueprints/matching.py b/open-media-match/src/OpenMediaMatch/blueprints/matching.py index bca21aa0a..b558a03c6 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/matching.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/matching.py @@ -152,17 +152,50 @@ def lookup(signal, signal_type_name): @bp.route("/index/status") +@abort_to_json def index_status(): """ - Input: - * Signal type (hash type) - Output: - * Time of last index build + Get the status of matching indices. + + You can limit to just a single type with the signal_type parameter. + + Example Output: + { + "pdq": { + "built_to": -1, + "present": false, + "size": 0 + }, + "video_md5": { + "built_to": 1700146048, + "present": true, + "size": 591 + } + } """ - abort(501) # Unimplemented - - -@bp.route("/index//status") -@abort_to_json -def index_status_by_type(signal_type_name: str): - abort(501) # Not yet implemented + storage = get_storage() + signal_types = storage.get_signal_type_configs() + + limit_to_type = request.args.get("signal_type") + if limit_to_type is not None: + if limit_to_type not in signal_types: + abort(400, f"No such signal type '{limit_to_type}'") + signal_types = {limit_to_type: signal_types[limit_to_type]} + + status_by_name = {} + for name, st in signal_types.items(): + checkpoint = storage.get_last_index_build_checkpoint(st.signal_type) + + status = { + "present": False, + "built_to": -1, + "size": 0, + } + if checkpoint is not None: + status = { + "present": True, + "built_to": checkpoint.last_item_timestamp, + "size": checkpoint.total_hash_count, + } + status_by_name[name] = status + return status_by_name diff --git a/open-media-match/src/OpenMediaMatch/database.py b/open-media-match/src/OpenMediaMatch/database.py index d38c07ed7..a692d8211 100644 --- a/open-media-match/src/OpenMediaMatch/database.py +++ b/open-media-match/src/OpenMediaMatch/database.py @@ -180,9 +180,10 @@ class SignalIndex(db.Model): # type: ignore[name-defined] """ id: Mapped[int] = mapped_column(primary_key=True) - signal_type: Mapped[str] + signal_type: Mapped[str] = mapped_column(String(255), unique=True, index=True) serialized_index: Mapped[bytes] = mapped_column(LargeBinary) signal_count: Mapped[int] + updated_to_id: Mapped[int] updated_to_ts: Mapped[int] updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=db.func.now() diff --git a/open-media-match/src/OpenMediaMatch/storage/default.py b/open-media-match/src/OpenMediaMatch/storage/default.py index b2923a256..78afcb31c 100644 --- a/open-media-match/src/OpenMediaMatch/storage/default.py +++ b/open-media-match/src/OpenMediaMatch/storage/default.py @@ -136,12 +136,15 @@ def store_signal_type_index( ).scalar_one_or_none() if db_record is not None: db_record.serialize_index(index) + # TODO - pass in real checkpoint + db_record.signal_count = signal_count else: database.db.session.add( database.SignalIndex( signal_type=signal_type.get_name(), # TODO - use real time checkpoint - updated_to_ts=int(time.time()), + updated_to_ts=-1, + updated_to_id=-1, signal_count=signal_count, ).serialize_index(index) ) @@ -151,15 +154,22 @@ def store_signal_type_index( def get_last_index_build_checkpoint( self, signal_type: t.Type[SignalType] ) -> t.Optional[interface.SignalTypeIndexBuildCheckpoint]: - updated_to = database.db.session.execute( - select(database.SignalIndex.updated_to_ts).where( - database.SignalIndex.signal_type == signal_type.get_name() - ) - ).scalar_one_or_none() - if updated_to is None: - return interface.SignalTypeIndexBuildCheckpoint.get_empty() - # TODO - return None + row = database.db.session.execute( + select( + database.SignalIndex.updated_to_ts, + database.SignalIndex.updated_to_id, + database.SignalIndex.signal_count, + ).where(database.SignalIndex.signal_type == signal_type.get_name()) + ).one_or_none() + + if row is None: + return None + updated_to_ts, updated_to_id, total_count = row._tuple() + return interface.SignalTypeIndexBuildCheckpoint( + last_item_timestamp=updated_to_ts, + last_item_id=updated_to_id, + total_hash_count=total_count, + ) # Collabs def get_collaborations(self) -> t.Dict[str, CollaborationConfigBase]: diff --git a/open-media-match/src/OpenMediaMatch/storage/interface.py b/open-media-match/src/OpenMediaMatch/storage/interface.py index 8288e8b96..b1347e42d 100644 --- a/open-media-match/src/OpenMediaMatch/storage/interface.py +++ b/open-media-match/src/OpenMediaMatch/storage/interface.py @@ -150,6 +150,7 @@ class SignalTypeIndexBuildCheckpoint: @classmethod def get_empty(cls): + """Represents a checkpoint for an empty index / no hashes.""" return cls(last_item_timestamp=-1, last_item_id=-1, total_hash_count=0) diff --git a/open-media-match/src/OpenMediaMatch/tests/test_api.py b/open-media-match/src/OpenMediaMatch/tests/test_api.py index 0ef0744a6..59322a6b8 100644 --- a/open-media-match/src/OpenMediaMatch/tests/test_api.py +++ b/open-media-match/src/OpenMediaMatch/tests/test_api.py @@ -132,10 +132,10 @@ def test_banks_add_hash(client: FlaskClient): image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" post_response = client.post( - "/c/bank/{}/content?url={}&content_type=photo".format(bank_name, image_url) + f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo" ) - assert post_response.status_code == 200 + assert post_response.status_code == 200, str(post_response.get_json()) assert post_response.json == { "id": 1, "signals": { @@ -169,14 +169,14 @@ def test_banks_add_hash_index(app: Flask, client: FlaskClient): # Test against first image post_response = client.get( - "/m/raw_lookup?signal_type=pdq&signal={}".format(IMAGE_URL_TO_PDQ[image_url]) + f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url]}" ) assert post_response.status_code == 200 assert post_response.json == {"matches": [1]} # Test against second image post_response = client.get( - "/m/raw_lookup?signal_type=pdq&signal={}".format(IMAGE_URL_TO_PDQ[image_url_2]) + f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url_2]}" ) assert post_response.status_code == 200 assert post_response.json == {"matches": [2]} diff --git a/open-media-match/src/OpenMediaMatch/tests/test_database.py b/open-media-match/src/OpenMediaMatch/tests/test_database.py index f4d567203..1b3a6247d 100644 --- a/open-media-match/src/OpenMediaMatch/tests/test_database.py +++ b/open-media-match/src/OpenMediaMatch/tests/test_database.py @@ -72,7 +72,10 @@ def test_store_index(app: Flask) -> None: database.db.session.add( database.SignalIndex( - signal_type="test", updated_to_ts=12345, signal_count=len(content) + signal_type="test", + updated_to_ts=12345, + updated_to_id=5678, + signal_count=len(content), ).serialize_index(index) ) database.db.session.commit() @@ -96,7 +99,10 @@ def test_store_index_updated_at(app: Flask) -> None: database.db.session.add( database.SignalIndex( - signal_type="test", updated_to_ts=1234, signal_count=len(content) + signal_type="test", + updated_to_ts=1234, + updated_to_id=5678, + signal_count=len(content), ).serialize_index(index) ) database.db.session.commit()