Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch database logic to use SQLModel, fix type issues #30

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"python.analysis.typeCheckingMode": "basic",
"python.analysis.typeCheckingMode": "off",
"python.analysis.autoImportCompletions": true
}
80 changes: 39 additions & 41 deletions app/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

from typing import Annotated
from typing import Annotated, Generator
import json
import logging
import os
Expand All @@ -14,18 +14,18 @@
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from sqlmodel import Session, select, func
import sentry_sdk

load_dotenv()

from app.search.helpers import get_default_index_name
from app.utils.exceptions import before_send
from app.models.database import SessionLocal, engine
from app.models.database import engine
from app.models.metadata import Metadata
from app.utils import storage
from app import worker
from app.models import call as call_model, base as base_model
from app.models import models

sentry_dsn = os.getenv("SENTRY_DSN")
if sentry_dsn:
Expand All @@ -44,9 +44,6 @@
before_send=before_send,
)

if os.getenv("POSTGRES_DB") is not None:
base_model.Base.metadata.create_all(bind=engine)

app = FastAPI()

logger = logging.getLogger()
Expand All @@ -59,17 +56,13 @@
logger.addHandler(stream_handler)


# Dependency
def get_db(): # type: ignore
db = SessionLocal()
try:
yield db
finally:
db.close()
def get_db() -> Generator[Session, None, None]:
    """FastAPI dependency: yield a SQLModel session bound to the app engine.

    The session is closed (and any uncommitted work rolled back) once the
    request that depended on it has finished.
    """
    session = Session(engine)
    try:
        yield session
    finally:
        # Equivalent to the context-manager form: Session.__exit__ calls close().
        session.close()


@app.middleware("http")
async def authenticate(request: Request, call_next) -> Response: # type: ignore
async def authenticate(request: Request, call_next) -> Response:
api_key = os.getenv("API_KEY", "")

if (
Expand Down Expand Up @@ -190,23 +183,25 @@ def create_call_from_sdrtrunk(
finally:
os.unlink(raw_audio.name)

call = call_model.CallCreateSchema(
raw_metadata=dict(metadata), raw_audio_url=audio_url
)
call = models.CallCreate(raw_metadata=metadata, raw_audio_url=audio_url)

db_call = call_model.create_call(db=db, call=call)
db_call = models.create_call(db=db, call=call)

if "digital" in metadata["audio_type"]:
from app.radio.digital import build_transcribe_options
elif metadata["audio_type"] == "analog":
from app.radio.analog import build_transcribe_options
else:
raise HTTPException(
status_code=400, detail=f"Audio type {metadata['audio_type']} not supported"
)

worker.queue_task(
audio_url,
metadata,
build_transcribe_options(metadata),
whisper_implementation=None,
id=db_call.id, # type: ignore
id=db_call.id,
)

return Response("Call imported successfully.", status_code=200)
Expand Down Expand Up @@ -251,12 +246,12 @@ def queue_for_transcription(
audio_url, metadata, build_transcribe_options(metadata), whisper_implementation
)

return JSONResponse({"task_id": task.id}, status_code=201) # type: ignore
return JSONResponse({"task_id": task.id}, status_code=201)


@app.get("/tasks/{task_id}")
def get_status(task_id: str) -> JSONResponse:
task_result = AsyncResult(task_id, app=worker.celery) # type: ignore
task_result: AsyncResult = AsyncResult(task_id, app=worker.celery)
result = {
"task_id": task_id,
"task_status": task_result.status,
Expand All @@ -269,17 +264,22 @@ def get_status(task_id: str) -> JSONResponse:
return JSONResponse(result)


@app.get("/calls/", response_model=list[call_model.CallSchema])
@app.get("/calls/", response_model=models.CallsPublic)
def read_calls(
skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
) -> list[call_model.Call]:
calls = call_model.get_calls(db, skip=skip, limit=limit)
return calls
) -> models.CallsPublic:
count_statement = select(func.count()).select_from(models.Call)
count = db.exec(count_statement).one()

statement = select(models.Call).offset(skip).limit(limit)
calls = db.exec(statement).all()

return models.CallsPublic(data=calls, count=count)

@app.get("/calls/{call_id}", response_model=call_model.CallSchema)
def read_call(call_id: int, db: Session = Depends(get_db)) -> call_model.Call:
db_call = call_model.get_call(db, call_id=call_id)

@app.get("/calls/{call_id}", response_model=models.CallPublic)
def read_call(call_id: int, db: Session = Depends(get_db)) -> models.Call:
    """Return the call with the given primary key, or 404 if it does not exist."""
    call_record = db.get(models.Call, call_id)
    if call_record is None:
        raise HTTPException(status_code=404, detail="Call not found")
    return call_record
Expand Down Expand Up @@ -328,40 +328,38 @@ def create_call(
else:
raise HTTPException(status_code=400, detail="No audio provided")

call = call_model.CallCreateSchema(raw_metadata=metadata, raw_audio_url=audio_url)
call = models.CallCreate(raw_metadata=metadata, raw_audio_url=audio_url)

db_call = call_model.create_call(db=db, call=call)
db_call = models.create_call(db=db, call=call)

task = worker.queue_task(
audio_url,
metadata,
build_transcribe_options(metadata),
whisper_implementation,
db_call.id, # type: ignore
db_call.id,
)

return JSONResponse(
{
"task_id": task.id # type: ignore
},
{"task_id": task.id},
status_code=201,
)


@app.patch("/calls/{call_id}", response_model=call_model.CallSchema)
@app.patch("/calls/{call_id}", response_model=models.CallPublic)
def update_call(
call_id: int, call: call_model.CallUpdateSchema, db: Session = Depends(get_db)
) -> call_model.Call:
db_call = call_model.get_call(db, call_id=call_id)
call_id: int, call: models.CallUpdate, db: Session = Depends(get_db)
) -> models.Call:
db_call = db.get(models.Call, call_id)
if db_call is None:
raise HTTPException(status_code=404, detail="Call not found")

return call_model.update_call(db=db, call=call, db_call=db_call)
return models.update_call(db=db, call=call, db_call=db_call)


@app.get("/talkgroups")
def talkgroups(db: Session = Depends(get_db)) -> JSONResponse:
tgs = call_model.get_talkgroups(db, get_default_index_name())
tgs = models.get_talkgroups(db, get_default_index_name())
return JSONResponse({"talkgroups": tgs})


Expand Down
2 changes: 1 addition & 1 deletion app/bin/autoscale-vast.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _update_pending_instances(self, instances: list[dict]):

def get_worker_status(self) -> list[dict]:
workers = []
result = worker.celery.control.inspect(timeout=10).stats() # type: ignore
result = worker.celery.control.inspect(timeout=10).stats()
if result:
for name, stats in result.items():
workers.append({"name": name, "stats": stats})
Expand Down
4 changes: 2 additions & 2 deletions app/bin/import-to-db.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def get_documents(
MeiliDocument(hit) for hit in results["hits"]
]
else:
results = index.get_documents(pagination) # type: ignore
return results.total, results.results # type: ignore
results = index.get_documents(pagination)
return results.total, results.results


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions app/bin/migrate-from-meilisearch-to-typesense.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# Load the .env file of our choice if specified before the regular .env can load
load_dotenv(os.getenv("ENV"))

from app.geocoding.geocoding import GeoResponse
from app.geocoding.types import GeoResponse
from app.models.metadata import Metadata
from app.models.transcript import Transcript
from app.search import helpers, adapters
Expand All @@ -26,7 +26,7 @@ def convert_document(document: MeiliDocument) -> helpers.Document:

if hasattr(document, "_geo") and hasattr(document, "geo_formatted_address"):
geo = GeoResponse(
geo=document._geo, # type: ignore
geo=document._geo,
geo_formatted_address=document.geo_formatted_address,
)
else:
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_documents(index: Index, pagination: dict) -> Tuple[int, list[MeiliDocume

# Create collection in typesense
typesense_adapter.upsert_index(index)
collection_docs = typesense_adapter.client.collections[index].documents # type: ignore
collection_docs = typesense_adapter.client.collections[index].documents

total, _ = get_documents(meili_index, {"limit": 1})
logging.info(f"Found {total} total documents")
Expand Down
21 changes: 11 additions & 10 deletions app/bin/reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# Load the .env file of our choice if specified before the regular .env can load
load_dotenv(os.getenv("ENV"))

from app.geocoding.geocoding import GeoResponse, lookup_geo
from app.geocoding.geocoding import lookup_geo
from app.geocoding.types import GeoResponse
from app.models.metadata import Metadata
from app.models.transcript import Transcript
from app.search import helpers
Expand Down Expand Up @@ -68,10 +69,10 @@ def update_document(
if TALKGROUPS.get(metadata["short_name"]):
try:
talkgroup = TALKGROUPS[metadata["short_name"]][metadata["talkgroup"]]
metadata["talkgroup_tag"] = talkgroup["Alpha Tag"].strip() # type: ignore
metadata["talkgroup_description"] = talkgroup["Description"].strip() # type: ignore
metadata["talkgroup_group"] = talkgroup["Category"].strip() # type: ignore
metadata["talkgroup_group_tag"] = talkgroup["Tag"].strip() # type: ignore
metadata["talkgroup_tag"] = talkgroup["Alpha Tag"].strip()
metadata["talkgroup_description"] = talkgroup["Description"].strip()
metadata["talkgroup_group"] = talkgroup["Category"].strip()
metadata["talkgroup_group_tag"] = talkgroup["Tag"].strip()
except KeyError:
logging.warning(
f"Could not find talkgroup {metadata['talkgroup']} in {metadata['short_name']} CSV file"
Expand All @@ -84,7 +85,7 @@ def update_document(
and document["geo_formatted_address"]
):
geo = GeoResponse(
geo=document["_geo"], # type: ignore
geo=document["_geo"],
geo_formatted_address=document["geo_formatted_address"],
)
elif should_lookup_geo:
Expand Down Expand Up @@ -115,7 +116,7 @@ def update_document(
# build_transcribe_options(metadata),
# id=doc["id"],
# index_name=index.uid,
# ) # type: ignore
# )


def load_csvs(
Expand All @@ -126,7 +127,7 @@ def load_csvs(
for system, file in unit_tags:
tags = []
with open(file, newline="") as csvfile:
unit_reader = csv.reader(csvfile, escapechar="\\") # type: ignore
unit_reader = csv.reader(csvfile, escapechar="\\")
for row in unit_reader:
tags.append(row)
UNIT_TAGS[system] = tags
Expand All @@ -137,8 +138,8 @@ def load_csvs(
tgs = {}
with open(file, newline="") as csvfile:
tg_reader = csv.DictReader(csvfile)
for row in tg_reader: # type: ignore
tgs[int(row["Decimal"])] = row # type: ignore
for row in tg_reader:
tgs[int(row["Decimal"])] = row
TALKGROUPS[system] = tgs

return UNIT_TAGS, TALKGROUPS
Expand Down
22 changes: 6 additions & 16 deletions app/geocoding/geocoding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import re
import os
from typing import Any, TypedDict
from typing import Any

import sentry_sdk
from geopy.location import Location
Expand All @@ -13,16 +13,7 @@
from app.models.metadata import Metadata
from app.models.transcript import Transcript
from app.geocoding import llm


class Geo(TypedDict):
lat: float
lng: float


class GeoResponse(TypedDict):
geo: Geo
geo_formatted_address: str
from app.geocoding.types import AddressParts, GeoResponse


def build_address_regex(include_intersections: bool = True) -> str:
Expand Down Expand Up @@ -95,7 +86,7 @@ def contains_address(transcript: str) -> bool:


def geocode(
address_parts: dict[str, str | None], geocoder: str | None = None
address_parts: AddressParts, geocoder: str | None = None
) -> GeoResponse | None: # pragma: no cover
query: dict[str, Any] = {
"query": f"{address_parts['address']}, {address_parts['city']}, {address_parts['state']}, {address_parts['country']}"
Expand Down Expand Up @@ -243,16 +234,15 @@ def lookup_geo(
if geocoding_systems == "*" or metadata["short_name"] in filter(
lambda name: len(name), geocoding_systems.split(",")
):
default_address_parts = {
default_address_parts: AddressParts = {
"city": os.getenv("GEOCODING_CITY"),
"state": os.getenv("GEOCODING_STATE"),
"country": os.getenv("GEOCODING_COUNTRY", "US"),
}
bounds_raw = os.getenv("GEOCODING_BOUNDS")
if bounds_raw:
default_address_parts["bounds"] = [
Point(bound)
for bound in bounds_raw.split("|") # type: ignore
Point(bound) for bound in bounds_raw.split("|")
]
else:
default_address_parts["bounds"] = None
Expand All @@ -261,7 +251,7 @@ def lookup_geo(

transcript_txt = transcript.txt_nosrc

address_parts = default_address_parts.copy()
address_parts: AddressParts = default_address_parts.copy()
# TODO: how can we extract the city and state from the metadata?
address_parts["address"] = extract_address(transcript_txt)
if address_parts["address"]:
Expand Down
Loading
Loading