Skip to content

Commit

Permalink
Added name attribute to HeadType (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
klemen1999 authored Oct 3, 2024
1 parent 9bda0c5 commit 46e2348
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion luxonis_ml/nn_archive/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .model import Model

CONFIG_VERSION = Literal["1.0"]
CONFIG_VERSION = Literal["1.0", "1.1"]


class Config(BaseModelExtraForbid):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
class Head(BaseModel, ABC):
"""Represents head of a model.
@type name: str | None
@ivar name: Optional name of the head.
@type parser: str
@ivar parser: Name of the parser responsible for processing the models output.
@type outputs: List[str] | None
Expand All @@ -27,6 +29,7 @@ class Head(BaseModel, ABC):
@ivar metadata: Metadata of the parser.
"""

name: Optional[str] = Field(None, description="Optional name of the head.")
parser: str = Field(
description="Name of the parser responsible for processing the models output."
)
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
onnx>=1.14.0,<=1.16.2
pre-commit>=3.2.1
pytest-cov>=4.1.0
pytest-dependency>=0.6.0
pytest-subtests>=0.12.1
pytest-md>=0.2.0
gdown>=4.7.1
coverage-badge>=1.1.0
coverage-badge>=1.1.0
8 changes: 8 additions & 0 deletions tests/test_nn_archive/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@

classification_head = dict(
Head(
name="ClassificationHead",
parser="Classification",
outputs=["output"],
metadata=head_classification_metadata,
Expand All @@ -164,6 +165,7 @@

ssd_object_detection_head = dict(
Head(
name="ObjectDetectionSSDHead",
parser="ObjectDetectionSSD",
outputs=["boxes"],
metadata=head_object_detection_ssd_metadata,
Expand All @@ -172,6 +174,7 @@

yolo_object_detection_head = dict(
Head(
name="YoloDetectionHead",
parser="YOLO",
outputs=["output"],
metadata=head_yolo_obb_det_metadata,
Expand All @@ -180,6 +183,7 @@

yolo_instance_segmentation_head = dict(
Head(
name="YoloInstanceSegHead",
parser="YOLO",
outputs=["output"],
metadata=head_yolo_instance_seg_metadata,
Expand All @@ -188,6 +192,7 @@

yolo_keypoint_detection_head = dict(
Head(
name="YoloKeypointDetectionHead",
parser="YOLO",
outputs=["output"],
metadata=head_yolo_keypoint_det_metadata,
Expand All @@ -196,6 +201,7 @@

yolo_obb_detection_head = dict(
Head(
name="YoloOBBHead",
parser="YOLO",
outputs=["output"],
metadata=head_yolo_obb_det_metadata,
Expand All @@ -204,6 +210,7 @@

yolo_instance_seg_kpts_head = dict(
Head(
name="YoloInstaceSegKptHead",
parser="YOLO",
outputs=["outputs"],
metadata=head_yolo_instance_seg_kpts_metadata,
Expand All @@ -212,6 +219,7 @@

custom_segmentation_head = dict(
Head(
name="SegmentationHead",
parser="Segmentation",
outputs=["output"],
metadata=head_segmentation_metadata,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nn_archive/test_nn_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_archive_generator(
archive_name=archive_name,
save_path="tests/data/test_nn_archive",
cfg_dict={
"config_version": "1.0",
"config_version": "1.1",
"model": {
"metadata": {
"name": "test_model",
Expand Down

0 comments on commit 46e2348

Please sign in to comment.