Skip to content

Commit

Permalink
Merge pull request #803 from roflcoopter/feature/cpai-letterbox-bug
Browse files Browse the repository at this point in the history
fix codeprojectai object detection when running with no image_size
  • Loading branch information
roflcoopter authored Aug 31, 2024
2 parents 94e014c + 527f33b commit cff837d
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 9 deletions.
16 changes: 14 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
"""Common mocks for Viseron tests."""
from __future__ import annotations

import datetime
from collections.abc import Callable, Generator
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
from unittest.mock import MagicMock

import pytest
from sqlalchemy import insert
from sqlalchemy.orm import Session

from viseron.components.storage.models import Files, FilesMeta, Recordings
from viseron.const import LOADED
from viseron.domains.camera.const import DOMAIN as CAMERA_DOMAIN
from viseron.helpers import utcnow

if TYPE_CHECKING:
from viseron import Viseron


class MockComponent:
"""Representation of a fake component."""

def __init__(self, component, setup_component=None):
def __init__(self, component, vis: Viseron | None = None, setup_component=None):
"""Initialize the mock component."""
self.__name__ = f"viseron.components.{component}"
self.__file__ = f"viseron/components/{component}"

self.name = component
if vis:
vis.data[LOADED][component] = self
if setup_component is not None:
self.setup_component = setup_component

Expand All @@ -30,6 +39,7 @@ class MockCamera(MagicMock):

def __init__( # pylint: disable=dangerous-default-value
self,
vis: Viseron | None = None,
identifier="test_camera_identifier",
resolution=(1920, 1080),
extension="mp4",
Expand All @@ -46,6 +56,8 @@ def __init__( # pylint: disable=dangerous-default-value
access_tokens=access_tokens,
**kwargs,
)
if vis:
vis.register_domain(CAMERA_DOMAIN, identifier, self)


def return_any(cls: type[Any]):
Expand Down
296 changes: 296 additions & 0 deletions tests/components/codeprojectai/test_object_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
"""CodeProjectAI object detector tests."""
from unittest.mock import Mock, patch

import numpy as np
import pytest

from viseron import Viseron
from viseron.components.codeprojectai import CONFIG_SCHEMA
from viseron.components.codeprojectai.const import COMPONENT
from viseron.components.codeprojectai.object_detector import (
DOMAIN as OBJECT_DETECTOR_DOMAIN,
ObjectDetector,
setup as cpai_setup,
)
from viseron.domains.object_detector.detected_object import DetectedObject

from tests.common import MockCamera, MockComponent
from tests.conftest import MockViseron

CAMERA_IDENTIFIER = "test_camera"


@pytest.fixture(name="mock_detected_object")
def fixture_mock_detected_object():
"""Fixture to provide a mocked DetectedObject class."""
with patch(
"viseron.components.codeprojectai.object_detector.DetectedObject"
) as mock:
yield mock


@pytest.fixture
def config():
"""
Fixture to provide a test configuration.
Returns:
dict: A dictionary containing the test configuration.
"""
return CONFIG_SCHEMA(
{
"codeprojectai": {
"host": "localhost",
"port": 32168,
"object_detector": {
"image_size": 640,
"cameras": {
CAMERA_IDENTIFIER: {
"labels": [
{
"label": "person",
"confidence": 0.8,
"trigger_recorder": True,
}
],
}
},
},
}
}
)


def test_setup(vis: Viseron, config):
"""
Test the setup function of the CodeProjectAI object detector.
Args:
vis (Viseron): The Viseron instance.
config (dict): The configuration dictionary.
"""
with patch(
"viseron.components.codeprojectai.object_detector.ObjectDetector"
) as mock_object_detector:
result = cpai_setup(vis, config, CAMERA_IDENTIFIER)
assert result is True
mock_object_detector.assert_called_once_with(vis, config, CAMERA_IDENTIFIER)


def test_object_detector_init(vis: MockViseron, config):
"""
Test the initialization of the ObjectDetector class.
Args:
vis (MockViseron): The mocked Viseron instance.
config (dict): The configuration dictionary.
"""
_ = MockComponent(COMPONENT, vis)
_ = MockCamera(vis, identifier=CAMERA_IDENTIFIER)
with patch("codeprojectai.core.CodeProjectAIObject"):
detector = ObjectDetector(vis, config["codeprojectai"], CAMERA_IDENTIFIER)
assert detector._image_resolution == ( # pylint: disable=protected-access
640,
640,
)
vis.mocked_register_domain.assert_called_with(
OBJECT_DETECTOR_DOMAIN, CAMERA_IDENTIFIER, detector
)


def test_preprocess(vis: Viseron, config):
"""
Test the preprocess method of the ObjectDetector class.
Args:
vis (Viseron): The Viseron instance.
config (dict): The configuration dictionary.
"""
_ = MockComponent(COMPONENT, vis)
_ = MockCamera(vis, identifier=CAMERA_IDENTIFIER)
with patch("codeprojectai.core.CodeProjectAIObject"):
detector = ObjectDetector(vis, config["codeprojectai"], CAMERA_IDENTIFIER)
frame = np.zeros((480, 640, 3), dtype=np.uint8)
processed = detector.preprocess(frame)
assert isinstance(processed, bytes)


def test_postprocess(vis: Viseron, config):
"""
Test the postprocess method of the ObjectDetector class.
Args:
vis (Viseron): The Viseron instance.
config (dict): The configuration dictionary.
"""
_ = MockComponent(COMPONENT, vis)
_ = MockCamera(vis, identifier=CAMERA_IDENTIFIER)
with patch("codeprojectai.core.CodeProjectAIObject"):
detector = ObjectDetector(vis, config["codeprojectai"], CAMERA_IDENTIFIER)
detections = [
{
"label": "person",
"confidence": 0.9,
"x_min": 100,
"y_min": 100,
"x_max": 200,
"y_max": 200,
}
]
objects = detector.postprocess(detections)
assert len(objects) == 1
assert isinstance(objects[0], DetectedObject)


@patch("codeprojectai.core.CodeProjectAIObject.detect")
def test_return_objects_success(mock_detect, vis: Viseron, config):
"""
Test the return_objects method of the ObjectDetector class for successful detection.
Args:
mock_detect (MagicMock): Mocked detect method.
vis (Viseron): The Viseron instance.
config (dict): The configuration dictionary.
"""
_ = MockComponent(COMPONENT, vis)
_ = MockCamera(vis, identifier=CAMERA_IDENTIFIER)
mock_detect.return_value = [
{
"label": "person",
"confidence": 0.9,
"x_min": 100,
"y_min": 100,
"x_max": 200,
"y_max": 200,
}
]
detector = ObjectDetector(vis, config["codeprojectai"], CAMERA_IDENTIFIER)
frame = np.zeros((480, 640, 3), dtype=np.uint8)
objects = detector.return_objects(frame)
assert len(objects) == 1
assert isinstance(objects[0], DetectedObject)


@patch("codeprojectai.core.CodeProjectAIObject.detect")
def test_return_objects_exception(mock_detect, vis: Viseron, config):
"""
Test the return_objects method of the ObjectDetector class when an exception occurs.
Args:
mock_detect (MagicMock): Mocked detect method.
vis (Viseron): The Viseron instance.
config (dict): The configuration dictionary.
"""
from codeprojectai.core import ( # pylint: disable=import-outside-toplevel
CodeProjectAIException,
)

_ = MockComponent(COMPONENT, vis)
_ = MockCamera(vis, identifier=CAMERA_IDENTIFIER)
mock_detect.side_effect = CodeProjectAIException("Test error")
detector = ObjectDetector(vis, config["codeprojectai"], CAMERA_IDENTIFIER)
frame = np.zeros((480, 640, 3), dtype=np.uint8)
objects = detector.return_objects(frame)
assert len(objects) == 0


def test_object_detector_init_no_image_size(vis: Viseron, config, mock_detected_object):
"""
Test the initialization of the ObjectDetector class when image_size is not set.
Args:
vis (Viseron): The Viseron instance.
config (dict): The configuration dictionary.
mock_detected_object (MagicMock): Mocked DetectedObject class.
"""
with patch("codeprojectai.core.CodeProjectAIObject"):
# Set non-square image resolution
config["codeprojectai"]["object_detector"]["image_size"] = None

# Mock camera with non-square resolution
_ = MockComponent(COMPONENT, vis)
_ = MockCamera(vis, identifier=CAMERA_IDENTIFIER, resolution=(1280, 720))

detector = ObjectDetector(vis, config["codeprojectai"], CAMERA_IDENTIFIER)

detections = [
{
"label": "person",
"confidence": 0.9,
"x_min": 100,
"y_min": 100,
"x_max": 200,
"y_max": 200,
}
]

objects = detector.postprocess(detections)

assert len(objects) == 1
assert isinstance(objects[0], Mock)

# Check if from_absolute was called instead of from_absolute_letterboxed
mock_detected_object.from_absolute.assert_called_once()
mock_detected_object.from_absolute_letterboxed.assert_not_called()

# Check the arguments passed to from_absolute
mock_detected_object.from_absolute.assert_called_with(
"person",
0.9,
100,
100,
200,
200,
frame_res=(1280, 720),
model_res=(1280, 720),
)


def test_postprocess_square_resolution(vis: Viseron, config, mock_detected_object):
"""
Test the postprocess method of the ObjectDetector class with a square resolution.
Args:
vis (Viseron): The Viseron instance.
config (dict): The configuration dictionary.
mock_detected_object (MagicMock): Mocked DetectedObject class.
"""
with patch("codeprojectai.core.CodeProjectAIObject"):
# Set square image resolution
config["codeprojectai"]["object_detector"]["image_size"] = 640

# Mock camera with square resolution
_ = MockComponent(COMPONENT, vis)
_ = MockCamera(vis, identifier=CAMERA_IDENTIFIER, resolution=(640, 640))

detector = ObjectDetector(vis, config["codeprojectai"], CAMERA_IDENTIFIER)

detections = [
{
"label": "person",
"confidence": 0.9,
"x_min": 100,
"y_min": 100,
"x_max": 200,
"y_max": 200,
}
]

objects = detector.postprocess(detections)

assert len(objects) == 1

# Check if from_absolute_letterboxed was called instead of from_absolute
mock_detected_object.from_absolute_letterboxed.assert_called_once()
mock_detected_object.from_absolute.assert_not_called()

mock_detected_object.from_absolute_letterboxed.assert_called_with(
"person",
0.9,
100,
100,
200,
200,
frame_res=(640, 640),
model_res=(640, 640),
)
18 changes: 15 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections.abc import Generator, Iterator
from typing import Any
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from pytest_postgresql import factories
Expand All @@ -16,19 +16,31 @@
from viseron.components.storage import COMPONENT as STORAGE, Storage
from viseron.components.storage.models import Base
from viseron.components.webserver import COMPONENT as WEBSERVER, Webserver
from viseron.const import LOADED

from tests.common import MockCamera

test_db = factories.postgresql_proc(port=None, dbname="test_db")


class MockViseron(Viseron):
"""Protocol for mocking Viseron."""

def __init__(self) -> None:
super().__init__()
self.register_domain = Mock(side_effect=self.register_domain) # type: ignore
self.mocked_register_domain = self.register_domain # type: ignore


@pytest.fixture
def vis() -> Viseron:
def vis() -> MockViseron:
"""Fixture to test Viseron instance."""
viseron = Viseron()
viseron = MockViseron()
viseron.data[DATA_STREAM] = MagicMock(spec=DataStream)
viseron.data[STORAGE] = MagicMock(spec=Storage)
viseron.data[WEBSERVER] = MagicMock(spec=Webserver)
viseron.data[LOADED] = {}

return viseron


Expand Down
Loading

0 comments on commit cff837d

Please sign in to comment.