Skip to content

Commit

Permalink
working tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dronakurl committed Feb 8, 2025
1 parent 1b67e5f commit ae96f0c
Show file tree
Hide file tree
Showing 28 changed files with 1,734 additions and 1,376 deletions.
46 changes: 24 additions & 22 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
# python-version: ["3.12"]
# python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
# opencv is needed, because otherwise the following error occurs
Expand All @@ -32,6 +32,17 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
enable-cache: true
# - name: Restore Ubuntu cache
# uses: actions/cache@v4
# with:
# path: .venv
# key: ${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml')}}
- name: Test with python ${{ matrix.python-version }}
run: |
uv sync --extra ci -U
# uv run --extra ci pytest
# uv run --extra ci --frozen python --version
uv run python -c 'from mmdet.apis.det_inferencer import DetInferencer'
- name: Cache apt packages
uses: actions/cache@v4
with:
Expand All @@ -41,32 +52,23 @@ jobs:
- name: Install libgl
# run: apt-get update && apt-get install -y python3-opencv
run: apt-get update && apt-get install -y libgl1
- name: Restore Ubuntu cache
uses: actions/cache@v4
with:
path: .venv
key: ${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml')}}
- name: Test with python ${{ matrix.python-version }}
run: |
uv sync --extra ci -U
# uv run --extra ci pytest
uv run --extra ci --frozen python --version
- name: Test SAHI CLI
run: |
source .venv/bin/activate
# help
uv run sahi --help
sahi --help
# predict mmdet
uv run sahi predict --source tests/data/ --novisual --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320
uv run sahi predict --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320
uv run sahi predict --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320
sahi predict --source tests/data/ --novisual --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320
sahi predict --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320
sahi predict --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320
# predict yolov5
uv run sahi predict --no_sliced_prediction --model_type yolov5 --source tests/data/coco_utils/terrain1.jpg --novisual --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
uv run sahi predict --model_type yolov5 --source tests/data/ --novisual --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
uv run sahi predict --model_type yolov5 --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
uv run sahi predict --model_type yolov5 --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
sahi predict --no_sliced_prediction --model_type yolov5 --source tests/data/coco_utils/terrain1.jpg --novisual --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
sahi predict --model_type yolov5 --source tests/data/ --novisual --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
sahi predict --model_type yolov5 --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
sahi predict --model_type yolov5 --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320
# coco yolov5
uv run sahi coco yolov5 --image_dir tests/data/coco_utils/ --dataset_json_path tests/data/coco_utils/combined_coco.json --train_split 0.9
sahi coco yolov5 --image_dir tests/data/coco_utils/ --dataset_json_path tests/data/coco_utils/combined_coco.json --train_split 0.9
# coco evaluate
uv run sahi coco evaluate --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json
sahi coco evaluate --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json
# coco analyse
uv run sahi coco analyse --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json --out_dir tests/data/coco_evaluate/
sahi coco analyse --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json --out_dir tests/data/coco_evaluate/
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,6 @@ cython_debug/
!.elasticbeanstalk/*.cfg.yml
!.elasticbeanstalk/*.global.yml
tests/data

.archive
.python-version
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ Find detailed info at [Interactive Result Visualization and Inspection](https://

Find detailed info on COCO utilities (yolov5 conversion, slicing, subsampling, filtering, merging, splitting) at [coco.md](docs/coco.md).

<!-- TODO: The mot link is missing -->
<!-- TODO: The 'mot' link is missing -->
<!-- Find detailed info on MOT utilities (ground truth dataset creation, exporting tracker metrics in mot challenge format) at [mot.md](docs/mot.md). -->

## <div align="center">Citation</div>
Expand Down
55 changes: 31 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,38 +38,42 @@ homepage = "https://github.com/obss/sahi"
[project.scripts]
sahi = "sahi.cli:app"

[project.optional-dependencies]
ci = [
# "torch>=1.13.1;python_version<'3.12'",
"torch>=1.13.0;python_version<='3.11'",
"torch==2.3.0;python_version>'3.11'",
"torchvision>=0.14.1",
"mmengine==0.7.3;python_version<'3.12'",
"mmcv==2.0.0;python_version<'3.12'",
"mmdet==3.0.0;python_version<'3.12'",
# "yolov5>=7.0.13",
"transformers>=4.35.0",
"pycocotools>=2.0.7",
"ultralytics>=8.3.50",
"scikit-image",
"fiftyone",
"onnx;python_version>='3.10'",
"onnxruntime;python_version>='3.10'",
# deepsparse depends on onnxruntime and is only available with 3.10 and 3.11
"deepsparse;python_version>='3.10' and python_version<'3.12'",
]

[tool.uv]
find-links = [
"https://download.openmmlab.com/mmcv/dist/cpu/torch1.13.0/index.html",
]
default-groups = ["dev", "ci"]

dev-dependencies = [
[dependency-groups]
dev = [
"pytest",
"ruff",
"pre-commit>=2.0",
"jupyterlab==3.0.14",
"matplotlib-stubs>=0.2.0",
"pipdeptree>=2.24.0",
]
ci = [
# pytorch should be present for all python versions
"torch==1.13.1;python_version<'3.11' and sys_platform=='linux' and platform_machine=='x86_64'",
"torch>1.13.1;python_version>='3.11' and sys_platform=='linux' and platform_machine=='x86_64'",
"torchvision>=0.14.1;sys_platform=='linux' and platform_machine=='x86_64'",
# mmcv is available for python < 3.11
"mmengine==0.7.3;python_version<'3.11'",
"mmcv==2.0.0;python_version<'3.11'",
"mmdet==3.0.0;python_version<'3.11'",
# onnx is available for python >=3.10
"onnx;python_version>='3.10'",
"onnxruntime;python_version>='3.10'",
# deepsparse depends on onnxruntime and is only available with 3.10 and 3.11
"deepsparse;python_version>='3.10' and python_version<'3.12'",
# These are available everywhere
"yolov5>=7.0.13;sys_platform=='linux' and platform_machine=='x86_64'",
"transformers>=4.35.0",
"pycocotools>=2.0.7",
"ultralytics>=8.3.50;sys_platform=='linux' and platform_machine=='x86_64'",
"scikit-image",
"fiftyone",
]

[[tool.uv.index]]
Expand All @@ -78,8 +82,8 @@ url = "https://download.pytorch.org/whl/cpu"
explicit = true

[tool.uv.sources]
torch = { index = "pytorch-cpu", extra = "ci" }
torchvision = { index = "pytorch-cpu", extra = "ci" }
torch = { index = "pytorch-cpu", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" }
torchvision = { index = "pytorch-cpu", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" }

[build-system]
requires = ["hatchling"]
Expand All @@ -93,3 +97,6 @@ exclude = ["**/__init__.py", ".git", "__pycache__", "*.ipynb"]
minversion = "6.0"
addopts = ["--import-mode=importlib", "--no-header"]
pythonpath = ["."]

[tool.typos.default]
extend-ignore-identifiers-re = ["fo"]
27 changes: 12 additions & 15 deletions sahi/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Code written by Fatih C Akyon, 2020.

import copy
import logging
from typing import Dict, List, Optional

import numpy as np
Expand All @@ -14,12 +15,15 @@
)
from sahi.utils.shapely import ShapelyAnnotation

logger = logging.getLogger(__name__)


class BoundingBox:
"""
Bounding box of the annotation.
"""

# TODO: Better use tuple not lists for data that has a defined length and should no mutate a lot
def __init__(self, box: List[float], shift_amount: List[int] = [0, 0]):
"""
Args:
Expand Down Expand Up @@ -176,20 +180,13 @@ def __init__(
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
# confirm full_shape is given
if full_shape is None:
raise ValueError("full_shape must be provided")
raise ValueError("full_shape must be provided") # type: ignore[reportUnreachable]

self.shift_x = shift_amount[0]
self.shift_y = shift_amount[1]

if full_shape:
self.full_shape_height = full_shape[0]
self.full_shape_width = full_shape[1]
else:
self.full_shape_height = None
self.full_shape_width = None

self.full_shape_height = full_shape[0]
self.full_shape_width = full_shape[1]
self.segmentation = segmentation

@classmethod
Expand All @@ -216,20 +213,20 @@ def from_bool_mask(
)

@property
def bool_mask(self):
def bool_mask(self) -> np.ndarray:
return get_bool_mask_from_coco_segmentation(
self.segmentation, width=self.full_shape[1], height=self.full_shape[0]
)

@property
def shape(self):
def shape(self) -> List[int]:
"""
Returns mask shape as [height, width]
"""
return [self.bool_mask.shape[0], self.bool_mask.shape[1]]

@property
def full_shape(self):
def full_shape(self) -> List[int]:
"""
Returns full mask shape after shifting as [height, width]
"""
Expand All @@ -242,7 +239,7 @@ def shift_amount(self):
"""
return [self.shift_x, self.shift_y]

def get_shifted_mask(self):
def get_shifted_mask(self) -> "Mask":
# Confirm full_shape is specified
if (self.full_shape_height is None) or (self.full_shape_width is None):
raise ValueError("full_shape is None")
Expand Down Expand Up @@ -382,7 +379,7 @@ def from_coco_annotation_dict(
cls,
annotation_dict: Dict,
full_shape: List[int],
category_name: str = None,
category_name: Optional[str] = None,
shift_amount: Optional[List[int]] = [0, 0],
):
"""
Expand Down
35 changes: 15 additions & 20 deletions sahi/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Code written by Fatih C Akyon, 2020.

import logging
from typing import Any, Dict, List, Optional, TypeVar, Union
from typing import Any, Dict, List, Optional

import numpy as np

Expand All @@ -12,17 +12,14 @@

logger = logging.getLogger(__name__)

T = TypeVar("T")
ListOrListofList = Union[List[List[T]], List[T]]


class DetectionModel:
def __init__(
self,
model_path: Optional[str] = None,
model: Optional[Any] = None,
config_path: Optional[str] = None,
device: Optional[str] = "cpu",
device: str = "cpu",
mask_threshold: float = 0.5,
confidence_threshold: float = 0.3,
category_mapping: Optional[Dict] = None,
Expand All @@ -37,8 +34,7 @@ def __init__(
Path for the instance segmentation model weight
config_path: str
Path for the mmdetection instance segmentation model config file
device: str
Torch device, "cpu" or "cuda"
device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
mask_threshold: float
Value to threshold mask pixels, should be between 0 and 1
confidence_threshold: float
Expand All @@ -55,15 +51,13 @@ def __init__(
self.model_path = model_path
self.config_path = config_path
self.model = None
self.device: Optional[str] = device
self.mask_threshold = mask_threshold
self.confidence_threshold = confidence_threshold
self.category_mapping = category_mapping
self.category_remapping = category_remapping
self.image_size = image_size
self._original_predictions = None
self._object_prediction_list_per_image = None

self.set_device()

# automatically load model if load_at_init is True
Expand Down Expand Up @@ -96,13 +90,14 @@ def set_model(self, model: Any, **kwargs):
"""
raise NotImplementedError()

def set_device(self):
"""
Sets the device for the model.
def set_device(self, device: str = "cpu"):
"""Sets the device pytorch should use for the model
Args:
device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
"""
if is_available("torch") and self.device is not None:
# TODO: This reassigns self.device not to be a string any more, but the
self.device = select_torch_device(self.device)
if is_available("torch"):
self.device = select_torch_device(device)
else:
raise NotImplementedError(f"Could not set device {self.device}")

Expand All @@ -128,8 +123,8 @@ def perform_inference(self, image: np.ndarray):

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount_list: Optional[ListOrListofList] = [[0, 0]],
full_shape_list: Optional[ListOrListofList] = None,
shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
full_shape_list: Optional[List[List[int]]] = None,
):
"""
This function should be implemented in a way that self._original_predictions should
Expand Down Expand Up @@ -166,8 +161,8 @@ def _apply_category_remapping(self):

def convert_original_predictions(
self,
shift_amount: Optional[ListOrListofList] = [0, 0],
full_shape: Optional[ListOrListofList] = None,
shift_amount: Optional[List[List[int]]] = [[0, 0]],
full_shape: Optional[List[List[int]]] = None,
):
"""
Converts original predictions of the detection model to a list of
Expand All @@ -186,7 +181,7 @@ def convert_original_predictions(
self._apply_category_remapping()

@property
def object_prediction_list(self) -> ListOrListofList[ObjectPrediction]:
def object_prediction_list(self) -> List[List[ObjectPrediction]]:
if self._object_prediction_list_per_image is None:
return []
if len(self._object_prediction_list_per_image) == 0:
Expand Down
6 changes: 3 additions & 3 deletions sahi/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
model: Optional[Any] = None,
processor: Optional[Any] = None,
config_path: Optional[str] = None,
device: Optional[str] = None,
device: str = "cpu",
mask_threshold: float = 0.5,
confidence_threshold: float = 0.3,
category_mapping: Optional[Dict] = None,
Expand Down Expand Up @@ -99,7 +99,7 @@ def perform_inference(self, image: Union[List, np.ndarray]):
import torch

# Confirm model is loaded
if self.model is None:
if self.model is None or self.processor is None:
raise RuntimeError("Model is not loaded, load it by calling .load_model()")

with torch.no_grad():
Expand Down Expand Up @@ -156,7 +156,7 @@ def _create_object_prediction_list_from_original_predictions(
"""
original_predictions = self._original_predictions

# compatilibty for sahi v0.8.15
# compatibility for sahi v0.8.15
shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)

Expand Down
Loading

0 comments on commit ae96f0c

Please sign in to comment.