Skip to content

Commit

Permalink
feat(sequences): implement the frontend-specific routes (#412)
Browse files Browse the repository at this point in the history
* refactor(models): move label from detection to sequence

* refactor(sequences): moved label to sequence

* test(e2e): update test cases

* docs(datamodel): update datamodel documentation

* feat(client): update the client methods

* feat(sequences): unset det <--> seq link when deleting sequence

* test(backend): update test cases

* style(client): fix typing

* test(sequences): fix payload comparison

* fix(client): update URL prepend

* test(backend): fix test cases

* fix(client): update route mapping

* test(backend): update conftest

* test(client): fix client

* test(client): fix client

* fix(sequences): fix bucket resolution

* test(backend): fix test cases

* build(deps): fix lock file

* test(detections): remove legacy tests

* test(sequences): fix fixture picking

* fix(sequence): fix detection deletion

* test(sequences): update test cases

* test(detections): improve testing

* test(sequences): extend test cases

* test(detections): fix test case
  • Loading branch information
frgfm authored Jan 17, 2025
1 parent ccd95d2 commit 6c8b02f
Show file tree
Hide file tree
Showing 15 changed files with 521 additions and 296 deletions.
133 changes: 83 additions & 50 deletions client/pyroclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


from enum import Enum
from typing import Dict, List, Tuple
from urllib.parse import urljoin

Expand All @@ -14,30 +14,25 @@

__all__ = ["Client"]

ROUTES: Dict[str, str] = {
#################

class ClientRoute(str, Enum):
# LOGIN
#################
"login-validate": "/login/validate",
#################
LOGIN_VALIDATE = "login/validate"
# CAMERAS
#################
"cameras-heartbeat": "/cameras/heartbeat",
"cameras-image": "/cameras/image",
"cameras-fetch": "/cameras/",
#################
CAMERAS_HEARTBEAT = "cameras/heartbeat"
CAMERAS_IMAGE = "cameras/image"
CAMERAS_FETCH = "cameras/"
# DETECTIONS
#################
"detections-create": "/detections/",
"detections-label": "/detections/{det_id}/label",
"detections-fetch": "/detections",
"detections-fetch-unl": "/detections/unlabeled/fromdate",
"detections-url": "/detections/{det_id}/url",
#################
DETECTIONS_CREATE = "detections/"
DETECTIONS_FETCH = "detections"
DETECTIONS_URL = "detections/{det_id}/url"
# SEQUENCES
SEQUENCES_LABEL = "sequences/{seq_id}/label"
SEQUENCES_FETCH_DETECTIONS = "sequences/{seq_id}/detections"
SEQUENCES_FETCH_LATEST = "sequences/unlabeled/latest"
SEQUENCES_FETCH_FROMDATE = "sequences/all/fromdate"
# ORGS
#################
"organizations-fetch": "/organizations",
}
ORGS_FETCH = "organizations"


def _to_str(coord: float) -> str:
Expand Down Expand Up @@ -76,8 +71,6 @@ class Client:
kwargs: optional parameters of `requests.post`
"""

routes: Dict[str, str]

def __init__(
self,
token: str,
Expand All @@ -89,10 +82,13 @@ def __init__(
if requests.get(urljoin(host, "status"), timeout=timeout, **kwargs).status_code != 200:
raise ValueError(f"unable to reach host {host}")
# Prepend API url to each route
self.routes = {k: urljoin(host, f"api/v1{v}") for k, v in ROUTES.items()}
self._route_prefix = urljoin(host, "api/v1/")
# Check token
response = requests.get(
self.routes["login-validate"], headers={"Authorization": f"Bearer {token}"}, timeout=timeout, **kwargs
urljoin(self._route_prefix, ClientRoute.LOGIN_VALIDATE),
headers={"Authorization": f"Bearer {token}"},
timeout=timeout,
**kwargs,
)
if response.status_code != 200:
raise HTTPRequestError(response.status_code, response.text)
Expand All @@ -115,7 +111,7 @@ def fetch_cameras(self) -> Response:
HTTP response
"""
return requests.get(
self.routes["cameras-fetch"],
urljoin(self._route_prefix, ClientRoute.CAMERAS_FETCH),
headers=self.headers,
timeout=self.timeout,
)
Expand All @@ -130,7 +126,9 @@ def heartbeat(self) -> Response:
Returns:
HTTP response containing the update device info
"""
return requests.patch(self.routes["cameras-heartbeat"], headers=self.headers, timeout=self.timeout)
return requests.patch(
urljoin(self._route_prefix, ClientRoute.CAMERAS_HEARTBEAT), headers=self.headers, timeout=self.timeout
)

def update_last_image(self, media: bytes) -> Response:
"""Update the last image of the camera
Expand All @@ -144,7 +142,7 @@ def update_last_image(self, media: bytes) -> Response:
HTTP response containing the update device info
"""
return requests.patch(
self.routes["cameras-image"],
urljoin(self._route_prefix, ClientRoute.CAMERAS_IMAGE),
headers=self.headers,
files={"file": ("logo.png", media, "image/png")},
timeout=self.timeout,
Expand Down Expand Up @@ -175,7 +173,7 @@ def create_detection(
if not isinstance(bboxes, (list, tuple)) or len(bboxes) == 0 or len(bboxes) > 5:
raise ValueError("bboxes must be a non-empty list of tuples with a maximum of 5 boxes")
return requests.post(
self.routes["detections-create"],
urljoin(self._route_prefix, ClientRoute.DETECTIONS_CREATE),
headers=self.headers,
data={
"azimuth": azimuth,
Expand All @@ -185,77 +183,112 @@ def create_detection(
files={"file": ("logo.png", media, "image/png")},
)

def label_detection(self, detection_id: int, is_wildfire: bool) -> Response:
"""Update the label of a detection made by a camera
def get_detection_url(self, detection_id: int) -> Response:
"""Retrieve the URL of the media linked to a detection
>>> from pyroclient import client
>>> api_client = Client("MY_USER_TOKEN")
>>> response = api_client.label_detection(1, is_wildfire=True)
>>> response = api_client.get_detection_url(1)
Args:
detection_id: ID of the associated detection entry
is_wildfire: whether this detection is confirmed as a wildfire
Returns:
HTTP response
"""
return requests.get(
urljoin(self._route_prefix, ClientRoute.DETECTIONS_URL.format(det_id=detection_id)),
headers=self.headers,
timeout=self.timeout,
)

def fetch_detections(self) -> Response:
"""List the detections accessible to the authenticated user
>>> from pyroclient import client
>>> api_client = Client("MY_USER_TOKEN")
>>> response = api_client.fetch_detections()
Returns:
HTTP response
"""
return requests.get(
urljoin(self._route_prefix, ClientRoute.DETECTIONS_FETCH),
headers=self.headers,
timeout=self.timeout,
)

def label_sequence(self, sequence_id: int, is_wildfire: bool) -> Response:
"""Update the label of a sequence made by a camera
>>> from pyroclient import client
>>> api_client = Client("MY_USER_TOKEN")
>>> response = api_client.label_sequence(1, is_wildfire=True)
Args:
sequence_id: ID of the associated sequence entry
is_wildfire: whether this sequence is confirmed as a wildfire
Returns:
HTTP response
"""
return requests.patch(
self.routes["detections-label"].format(det_id=detection_id),
urljoin(self._route_prefix, ClientRoute.SEQUENCES_LABEL.format(seq_id=sequence_id)),
headers=self.headers,
json={"is_wildfire": is_wildfire},
timeout=self.timeout,
)

def get_detection_url(self, detection_id: int) -> Response:
"""Retrieve the URL of the media linked to a detection
def fetch_sequences_from_date(self, from_date: str) -> Response:
"""List the sequences accessible to the authenticated user for a specific date
>>> from pyroclient import client
>>> api_client = Client("MY_USER_TOKEN")
>>> response = api_client.get_detection_url(1)
>>> response = api_client.fetch_sequences_from_date("2023-07-04")
Args:
detection_id: ID of the associated detection entry
from_date: date of the sequences to fetch
Returns:
HTTP response
"""
params = {"from_date": from_date}
return requests.get(
self.routes["detections-url"].format(det_id=detection_id),
urljoin(self._route_prefix, ClientRoute.SEQUENCES_FETCH_FROMDATE),
headers=self.headers,
params=params,
timeout=self.timeout,
)

def fetch_detections(self) -> Response:
"""List the detections accessible to the authenticated user
def fetch_latest_sequences(self) -> Response:
"""List the latest sequences accessible to the authenticated user
>>> from pyroclient import client
>>> api_client = Client("MY_USER_TOKEN")
>>> response = api_client.fetch_detections()
>>> response = api_client.fetch_latest_sequences()
Returns:
HTTP response
"""
return requests.get(
self.routes["detections-fetch"],
urljoin(self._route_prefix, ClientRoute.SEQUENCES_FETCH_LATEST),
headers=self.headers,
timeout=self.timeout,
)

def fetch_unlabeled_detections(self, from_date: str) -> Response:
"""List the detections accessible to the authenticated user
def fetch_sequences_detections(self, sequence_id: int) -> Response:
"""List the detections of a sequence
>>> from pyroclient import client
>>> api_client = Client("MY_USER_TOKEN")
>>> response = api_client.fetch_unacknowledged_detections("2023-07-04T00:00:00")
>>> response = api_client.fetch_sequences_detections(1)
Returns:
HTTP response
"""
params = {"from_date": from_date}
return requests.get(
self.routes["detections-fetch-unl"],
urljoin(self._route_prefix, ClientRoute.SEQUENCES_FETCH_DETECTIONS.format(seq_id=sequence_id)),
headers=self.headers,
params=params,
timeout=self.timeout,
)

Expand All @@ -272,7 +305,7 @@ def fetch_organizations(self) -> Response:
HTTP response
"""
return requests.get(
self.routes["organizations-fetch"],
urljoin(self._route_prefix, ClientRoute.ORGS_FETCH),
headers=self.headers,
timeout=self.timeout,
)
19 changes: 17 additions & 2 deletions client/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pytest
from requests.exceptions import ConnectionError as ConnError
from requests.exceptions import ReadTimeout
Expand Down Expand Up @@ -42,13 +44,17 @@ def test_cam_workflow(cam_token, mock_img):
assert response.status_code == 201, response.__dict__
response = cam_client.create_detection(mock_img, 123.2, [(0, 0, 1.0, 0.9, 0.5), (0.2, 0.2, 0.7, 0.7, 0.8)])
assert response.status_code == 201, response.__dict__
response = cam_client.create_detection(mock_img, 123.2, [(0, 0, 1.0, 0.9, 0.5)])
assert response.status_code == 201, response.__dict__
return response.json()["id"]


def test_agent_workflow(test_cam_workflow, agent_token):
# Agent workflow
agent_client = Client(agent_token, "http://localhost:5050", timeout=10)
response = agent_client.label_detection(test_cam_workflow, True)
response = agent_client.fetch_latest_sequences().json()
assert len(response) == 1
response = agent_client.label_sequence(response[0]["id"], True)
assert response.status_code == 200, response.__dict__


Expand All @@ -59,5 +65,14 @@ def test_user_workflow(test_cam_workflow, user_token):
assert response.status_code == 200, response.__dict__
response = user_client.fetch_detections()
assert response.status_code == 200, response.__dict__
response = user_client.fetch_unlabeled_detections("2018-06-06T00:00:00")
response = user_client.fetch_sequences_from_date("2018-06-06")
assert len(response.json()) == 0
assert response.status_code == 200, response.__dict__
response = user_client.fetch_latest_sequences()
assert response.status_code == 200, response.__dict__
assert len(response.json()) == 0 # Sequence was labeled by agent
response = user_client.fetch_sequences_from_date(datetime.utcnow().date().isoformat())
assert len(response.json()) == 1
response = user_client.fetch_sequences_detections(response.json()[0]["id"])
assert response.status_code == 200, response.__dict__
assert len(response.json()) == 3
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 13 additions & 2 deletions scripts/dbdiagram.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,26 @@ Table "Camera" as C {
}
}

Table "Sequence" as S {
"id" int [not null]
"camera_id" int [ref: > C.id, not null]
"azimuth" float [not null]
"is_wildfire" bool
"started_at" timestamp [not null]
"last_seen_at" timestamp [not null]
Indexes {
(id) [pk]
}
}

Table "Detection" as D {
"id" int [not null]
"camera_id" int [ref: > C.id, not null]
"sequence_id" int [ref: > S.id]
"azimuth" float [not null]
"bucket_key" varchar [not null]
"bboxes" varchar [not null]
"is_wildfire" bool
"created_at" timestamp [not null]
"updated_at" timestamp [not null]
Indexes {
(id) [pk]
}
Expand Down
37 changes: 24 additions & 13 deletions scripts/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import argparse
import time
from datetime import datetime
from typing import Any, Dict, Optional

import requests
Expand Down Expand Up @@ -100,12 +101,7 @@ def main(args):
)
assert response.status_code == 201, response.text
detection_id = response.json()["id"]

# Fetch unlabeled detections
api_request("get", f"{args.endpoint}/detections/unlabeled/fromdate?from_date=2018-06-06T00:00:00", agent_auth)

# Acknowledge it
api_request("patch", f"{args.endpoint}/detections/{detection_id}/label", agent_auth, {"is_wildfire": True})
today = datetime.fromisoformat(response.json()["created_at"]).date()

# Fetch detections & their URLs
api_request("get", f"{args.endpoint}/detections", agent_auth)
Expand All @@ -127,18 +123,33 @@ def main(args):
timeout=5,
).json()["id"]
# Check that a sequence has been created
sequences = api_request("get", f"{args.endpoint}/sequences", superuser_auth)
assert len(sequences) == 1
assert sequences[0]["camera_id"] == cam_id
assert sequences[0]["started_at"] == response.json()["created_at"]
assert sequences[0]["last_seen_at"] > sequences[0]["started_at"]
assert sequences[0]["azimuth"] == response.json()["azimuth"]
sequence = api_request("get", f"{args.endpoint}/sequences/1", agent_auth)
assert sequence["camera_id"] == cam_id
assert sequence["started_at"] == response.json()["created_at"]
assert sequence["last_seen_at"] > sequence["started_at"]
assert sequence["azimuth"] == response.json()["azimuth"]
# Fetch the latest sequence
assert len(api_request("get", f"{args.endpoint}/sequences/unlabeled/latest", agent_auth)) == 1
# Fetch from date
assert len(api_request("get", f"{args.endpoint}/sequences/all/fromdate?from_date=2019-09-10", agent_auth)) == 0
assert (
len(api_request("get", f"{args.endpoint}/sequences/all/fromdate?from_date={today.isoformat()}", agent_auth))
== 1
)
# Label the sequence
api_request("patch", f"{args.endpoint}/sequences/{sequence['id']}/label", agent_auth, {"is_wildfire": True})
# Check the sequence's detections
dets = api_request("get", f"{args.endpoint}/sequences/{sequence['id']}/detections", agent_auth)
assert len(dets) == 3
assert dets[0]["id"] == det_id_3
assert dets[1]["id"] == det_id_2
assert dets[2]["id"] == detection_id

# Cleaning (order is important because of foreign key protection in existing tables)
api_request("delete", f"{args.endpoint}/sequences/{sequences[0]['id']}/", superuser_auth)
api_request("delete", f"{args.endpoint}/detections/{detection_id}/", superuser_auth)
api_request("delete", f"{args.endpoint}/detections/{det_id_2}/", superuser_auth)
api_request("delete", f"{args.endpoint}/detections/{det_id_3}/", superuser_auth)
api_request("delete", f"{args.endpoint}/sequences/{sequence['id']}/", superuser_auth)
api_request("delete", f"{args.endpoint}/cameras/{cam_id}/", superuser_auth)
api_request("delete", f"{args.endpoint}/users/{user_id}/", superuser_auth)
api_request("delete", f"{args.endpoint}/organizations/{org_id}/", superuser_auth)
Expand Down
Loading

0 comments on commit 6c8b02f

Please sign in to comment.