Skip to content

Commit

Permalink
[omm][index] index uniqueness, hash insert api, idx checkpointing (#1457
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Dcallies authored Nov 16, 2023
1 parent 7922acb commit 20bfb7c
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 37 deletions.
48 changes: 41 additions & 7 deletions open-media-match/src/OpenMediaMatch/blueprints/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from flask import Blueprint
from flask import request, jsonify, abort

from sqlalchemy.exc import IntegrityError
from threatexchange.signal_type.signal_base import SignalType

from OpenMediaMatch import database, persistence, utils
from OpenMediaMatch.storage.interface import BankConfig
Expand Down Expand Up @@ -125,18 +125,52 @@ def bank_add_file(bank_name: str):
hashes = hashing.hash_media_post_impl()
else:
abort(400, "Neither `url` query param nor multipart file upload was received")
return _bank_add_signals(bank, hashes)


signal_type_configs = storage.get_signal_type_configs()
def _bank_add_signals(
bank: BankConfig, signal_type_to_signal_str: dict[str, str]
) -> dict[str, t.Any]:
if not signal_type_to_signal_str:
abort(400, "No signals given")

storage = persistence.get_storage()

signals: dict[type[SignalType], str] = {}
signal_type_cfgs = storage.get_signal_type_configs()
for name, val in signal_type_to_signal_str.items():
st = signal_type_cfgs.get(name)
if st is None:
abort(400, f"No such signal type {name}")
try:
signals[st.signal_type] = st.signal_type.validate_signal_str(val)
except Exception as e:
abort(400, f"Invalid {name} signal: {str(e)}")
content_id = storage.bank_add_content(
bank.name,
{
signal_type_configs[name].signal_type: signal_value
for name, signal_value in hashes.items()
},
signals,
)

return {"id": content_id, "signals": hashes}
return {
"id": content_id,
"signals": {st.get_name(): val for st, val in signals.items()},
}


@bp.route("/bank/<bank_name>/signal", methods=["POST"])
@utils.abort_to_json
def bank_add_as_signals(bank_name: str):
"""
Add a signal/hash directly to the bank.
Most of the time you want to add by file, since you'll be able
able to process the file in all of the techniques you have available.
"""
storage = persistence.get_storage()
bank = storage.get_bank(bank_name)
if not bank:
abort(404, f"bank '{bank_name}' not found")
return _bank_add_signals(bank, t.cast(dict[str, str], request.json))


def _get_collab(name: str):
Expand Down
4 changes: 2 additions & 2 deletions open-media-match/src/OpenMediaMatch/blueprints/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

@bp.route("/hash", methods=["GET"])
@abort_to_json
def hash_media():
def hash_media() -> dict[str, str]:
"""
Fetch content and return its hash.
Expand All @@ -49,7 +49,7 @@ def hash_media():
content_type = _parse_request_content_type(url_content_type)
signal_types = _parse_request_signal_type(content_type)

ret = {}
ret: dict[str, str] = {}

# For images, we may need to copy the file suffix (.png, jpeg, etc) for it to work
with tempfile.NamedTemporaryFile("wb") as tmp:
Expand Down
55 changes: 44 additions & 11 deletions open-media-match/src/OpenMediaMatch/blueprints/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,50 @@ def lookup(signal, signal_type_name):


@bp.route("/index/status")
@abort_to_json
def index_status():
"""
Input:
* Signal type (hash type)
Output:
* Time of last index build
Get the status of matching indices.
You can limit to just a single type with the signal_type parameter.
Example Output:
{
"pdq": {
"built_to": -1,
"present": false,
"size": 0
},
"video_md5": {
"built_to": 1700146048,
"present": true,
"size": 591
}
}
"""
abort(501) # Unimplemented


@bp.route("/index/<signal_type_name>/status")
@abort_to_json
def index_status_by_type(signal_type_name: str):
abort(501) # Not yet implemented
storage = get_storage()
signal_types = storage.get_signal_type_configs()

limit_to_type = request.args.get("signal_type")
if limit_to_type is not None:
if limit_to_type not in signal_types:
abort(400, f"No such signal type '{limit_to_type}'")
signal_types = {limit_to_type: signal_types[limit_to_type]}

status_by_name = {}
for name, st in signal_types.items():
checkpoint = storage.get_last_index_build_checkpoint(st.signal_type)

status = {
"present": False,
"built_to": -1,
"size": 0,
}
if checkpoint is not None:
status = {
"present": True,
"built_to": checkpoint.last_item_timestamp,
"size": checkpoint.total_hash_count,
}
status_by_name[name] = status
return status_by_name
3 changes: 2 additions & 1 deletion open-media-match/src/OpenMediaMatch/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ class SignalIndex(db.Model): # type: ignore[name-defined]
"""

id: Mapped[int] = mapped_column(primary_key=True)
signal_type: Mapped[str]
signal_type: Mapped[str] = mapped_column(String(255), unique=True, index=True)
serialized_index: Mapped[bytes] = mapped_column(LargeBinary)
signal_count: Mapped[int]
updated_to_id: Mapped[int]
updated_to_ts: Mapped[int]
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=db.func.now()
Expand Down
30 changes: 20 additions & 10 deletions open-media-match/src/OpenMediaMatch/storage/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,15 @@ def store_signal_type_index(
).scalar_one_or_none()
if db_record is not None:
db_record.serialize_index(index)
# TODO - pass in real checkpoint
db_record.signal_count = signal_count
else:
database.db.session.add(
database.SignalIndex(
signal_type=signal_type.get_name(),
# TODO - use real time checkpoint
updated_to_ts=int(time.time()),
updated_to_ts=-1,
updated_to_id=-1,
signal_count=signal_count,
).serialize_index(index)
)
Expand All @@ -151,15 +154,22 @@ def store_signal_type_index(
def get_last_index_build_checkpoint(
self, signal_type: t.Type[SignalType]
) -> t.Optional[interface.SignalTypeIndexBuildCheckpoint]:
updated_to = database.db.session.execute(
select(database.SignalIndex.updated_to_ts).where(
database.SignalIndex.signal_type == signal_type.get_name()
)
).scalar_one_or_none()
if updated_to is None:
return interface.SignalTypeIndexBuildCheckpoint.get_empty()
# TODO
return None
row = database.db.session.execute(
select(
database.SignalIndex.updated_to_ts,
database.SignalIndex.updated_to_id,
database.SignalIndex.signal_count,
).where(database.SignalIndex.signal_type == signal_type.get_name())
).one_or_none()

if row is None:
return None
updated_to_ts, updated_to_id, total_count = row._tuple()
return interface.SignalTypeIndexBuildCheckpoint(
last_item_timestamp=updated_to_ts,
last_item_id=updated_to_id,
total_hash_count=total_count,
)

# Collabs
def get_collaborations(self) -> t.Dict[str, CollaborationConfigBase]:
Expand Down
1 change: 1 addition & 0 deletions open-media-match/src/OpenMediaMatch/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class SignalTypeIndexBuildCheckpoint:

@classmethod
def get_empty(cls):
"""Represents a checkpoint for an empty index / no hashes."""
return cls(last_item_timestamp=-1, last_item_id=-1, total_hash_count=0)


Expand Down
8 changes: 4 additions & 4 deletions open-media-match/src/OpenMediaMatch/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ def test_banks_add_hash(client: FlaskClient):
image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"

post_response = client.post(
"/c/bank/{}/content?url={}&content_type=photo".format(bank_name, image_url)
f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo"
)

assert post_response.status_code == 200
assert post_response.status_code == 200, str(post_response.get_json())
assert post_response.json == {
"id": 1,
"signals": {
Expand Down Expand Up @@ -169,14 +169,14 @@ def test_banks_add_hash_index(app: Flask, client: FlaskClient):

# Test against first image
post_response = client.get(
"/m/raw_lookup?signal_type=pdq&signal={}".format(IMAGE_URL_TO_PDQ[image_url])
f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url]}"
)
assert post_response.status_code == 200
assert post_response.json == {"matches": [1]}

# Test against second image
post_response = client.get(
"/m/raw_lookup?signal_type=pdq&signal={}".format(IMAGE_URL_TO_PDQ[image_url_2])
f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url_2]}"
)
assert post_response.status_code == 200
assert post_response.json == {"matches": [2]}
10 changes: 8 additions & 2 deletions open-media-match/src/OpenMediaMatch/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def test_store_index(app: Flask) -> None:

database.db.session.add(
database.SignalIndex(
signal_type="test", updated_to_ts=12345, signal_count=len(content)
signal_type="test",
updated_to_ts=12345,
updated_to_id=5678,
signal_count=len(content),
).serialize_index(index)
)
database.db.session.commit()
Expand All @@ -96,7 +99,10 @@ def test_store_index_updated_at(app: Flask) -> None:

database.db.session.add(
database.SignalIndex(
signal_type="test", updated_to_ts=1234, signal_count=len(content)
signal_type="test",
updated_to_ts=1234,
updated_to_id=5678,
signal_count=len(content),
).serialize_index(index)
)
database.db.session.commit()
Expand Down

0 comments on commit 20bfb7c

Please sign in to comment.