Skip to content

Commit

Permalink
✨ Add sql query count metric
Browse files Browse the repository at this point in the history
  • Loading branch information
pajowu committed Nov 18, 2023
1 parent 8615778 commit 08b567d
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 17 deletions.
2 changes: 1 addition & 1 deletion backend/openapi-schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ components:
info:
title: FastAPI
version: 0.1.0
openapi: 3.0.2
openapi: 3.1.0
paths:
/:
get:
Expand Down
2 changes: 1 addition & 1 deletion backend/scripts/create_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
parser.add_argument("--user", required=True)
parser.add_argument("--pass", required=True)
args = parser.parse_args()
with SessionContextManager() as session:
with SessionContextManager(path="mangement_comment:create_user") as session:
try:
user = create_user(
session=session, username=args.user, password=getattr(args, "pass")
Expand Down
2 changes: 1 addition & 1 deletion backend/scripts/create_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
if args.token is None:
args.token = utils.get_random_string()

with SessionContextManager() as session:
with SessionContextManager(path="mangement_comment:create_worker") as session:
statement = select(Worker).where(Worker.token == args.token)
results = session.exec(statement)
existing_worker = results.one_or_none()
Expand Down
2 changes: 1 addition & 1 deletion backend/scripts/reset_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"--uuid", required=True, type=uuid.UUID, help="Task UUID or Document UUID"
)
args = parser.parse_args()
with SessionContextManager() as session:
with SessionContextManager(path="mangement_comment:reset_task") as session:
task = session.execute(
update(Task)
.where(
Expand Down
2 changes: 1 addition & 1 deletion backend/scripts/set_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
parser.add_argument("--pass", required=True)
args = parser.parse_args()

with SessionContextManager() as session:
with SessionContextManager(path="mangement_comment:set_password") as session:
try:
user = change_user_password(
session=session, username=args.user, new_password=getattr(args, "pass")
Expand Down
46 changes: 43 additions & 3 deletions backend/transcribee_backend/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Optional

from fastapi import Request
from prometheus_client import Histogram
from prometheus_fastapi_instrumentator import routing
from sqlalchemy import event
from sqlmodel import Session, create_engine
from starlette.websockets import WebSocket

DEFAULT_SOCKET_PATH = Path(__file__).parent.parent.parent / "db" / "sockets"

Expand All @@ -13,10 +19,44 @@

engine = create_engine(DATABASE_URL)

query_histogram = Histogram(
"sql_queries",
"Number of sql queries executed per db session",
["path"],
buckets=[1, 2, 4, 8, 16, 32, 128, 256, 512],
)


def get_session(request: Request):
handler = routing.get_route_name(request)
with Session(engine) as session, query_counter(session, path=handler):
yield session


def get_session():
with Session(engine) as session:
def get_session_ws(websocket: WebSocket):
# get_route_name is typed with a Request, but in reality a HttpConnection
# (which WebSocket is) is enough
handler = routing.get_route_name(websocket) # type: ignore
with Session(engine) as session, query_counter(session, path=handler):
yield session


SessionContextManager = contextmanager(get_session)
@contextmanager
def SessionContextManager(path: str):
with Session(engine) as session, query_counter(session, path=path):
yield session


@contextmanager
def query_counter(session: Session, path: Optional[str]):
engine = session.connection().engine
count = 0

def callback(*args, **kwargs):
nonlocal count
count += 1

event.listen(engine, "before_cursor_execute", callback)
yield
event.remove(engine, "before_cursor_execute", callback)
query_histogram.labels(path=path).observe(count)
4 changes: 2 additions & 2 deletions backend/transcribee_backend/helpers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def timeouted_tasks(session: Session) -> Iterable[Task]:

def timeout_attempts():
now = now_tz_aware()
with SessionContextManager() as session:
with SessionContextManager(path="repeating_task:timeout_attempts") as session:
for task in timeouted_tasks(session):
finish_current_attempt(
session=session, task=task, now=now, successful=False
Expand All @@ -72,7 +72,7 @@ def expired_tokens(session: Session) -> Iterable[UserToken]:


def remove_expired_tokens():
with SessionContextManager() as session:
with SessionContextManager(path="repeating_task:remove_expired_tokens") as session:
for user_token in expired_tokens(session):
session.delete(user_token)

Expand Down
2 changes: 1 addition & 1 deletion backend/transcribee_backend/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def store_file(file: BinaryIO) -> str:
return name


def force_bytes(v: bytes | str):
def force_bytes(v: bytes | str) -> bytes:
if isinstance(v, str):
return v.encode()
return v
Expand Down
2 changes: 1 addition & 1 deletion backend/transcribee_backend/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def refresh(self, session: Session):


def refresh_metrics():
with SessionContextManager() as session:
with SessionContextManager(path="repeating_task:refresh_metrics") as session:
for metric in METRICS:
metric.refresh(session)

Expand Down
14 changes: 9 additions & 5 deletions backend/transcribee_backend/routers/document.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import datetime
import enum
import pathlib
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Callable, List, Optional

import magic
Expand All @@ -13,6 +13,7 @@
Form,
Header,
HTTPException,
Path,
Query,
UploadFile,
WebSocket,
Expand All @@ -35,7 +36,10 @@
validate_worker_authorization,
)
from transcribee_backend.config import get_model_config, settings
from transcribee_backend.db import get_session
from transcribee_backend.db import (
get_session,
get_session_ws,
)
from transcribee_backend.helpers.sync import DocumentSyncConsumer
from transcribee_backend.helpers.time import now_tz_aware
from transcribee_backend.models.document import (
Expand Down Expand Up @@ -200,7 +204,7 @@ def func(
def auth_fn_to_ws(f: Callable):
def func(
document_id: uuid.UUID,
session: Session = Depends(get_session),
session: Session = Depends(get_session_ws),
authorization: Optional[str] = Query(default=None),
share_token: Optional[str] = Query(default=None, alias="share_token"),
):
Expand Down Expand Up @@ -430,7 +434,7 @@ def delete_document(
auth: AuthInfo = Depends(get_doc_full_auth),
session: Session = Depends(get_session),
) -> None:
paths_to_delete: List[Path] = []
paths_to_delete: List[pathlib.Path] = []
media_files = select(DocumentMediaFile).where(
DocumentMediaFile.document == auth.document
)
Expand Down Expand Up @@ -463,7 +467,7 @@ def get_document_tasks(
async def websocket_endpoint(
websocket: WebSocket,
auth: AuthInfo = Depends(ws_get_doc_min_readonly_or_worker_auth),
session: Session = Depends(get_session),
session: Session = Depends(get_session_ws),
):
connection = DocumentSyncConsumer(
document=auth.document,
Expand Down

0 comments on commit 08b567d

Please sign in to comment.