Skip to content

Commit

Permalink
fix(nn_archive): object detection inheritance for keypoint head (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
conorsim authored Mar 14, 2024
1 parent 10f9a43 commit c63b752
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions luxonis_ml/nn_archive/config_building_blocks/base_models/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,6 @@ class HeadObjectDetectionYOLO(HeadObjectDetection, ABC):
@ivar outputs: A configuration specifying which output names from the `outputs` block of the archive are fed into the head.
@type subtype: ObjectDetectionSubtypeYOLO
@ivar subtype: YOLO family decoding subtype (e.g. v5, v6, v7 etc.).
@type n_keypoints: int
@ivar n_keypoints: Number of keypoints per bbox if provided.
@type n_prototypes: int
@ivar n_prototypes: Number of prototypes per bbox if provided.
@type prototype_output_name: str
@ivar prototype_output_name: Output node containing prototype information.
"""

family: Literal["ObjectDetectionYOLO"] = Field(..., description="Decoding family.")
Expand All @@ -120,15 +114,6 @@ class HeadObjectDetectionYOLO(HeadObjectDetection, ABC):
subtype: ObjectDetectionSubtypeYOLO = Field(
description="YOLO family decoding subtype (e.g. v5, v6, v7 etc.)."
)
n_keypoints: Optional[int] = Field(
None, description="Number of keypoints per bbox if provided."
)
n_prototypes: Optional[int] = Field(
None, description="Number of prototypes per bbox if provided."
)
prototype_output_name: Optional[str] = Field(
None, description="Output node containing prototype information."
)

@field_validator("family")
def validate_label_type(
Expand Down Expand Up @@ -214,6 +199,8 @@ class HeadInstanceSegmentationYOLO(HeadObjectDetectionYOLO, HeadSegmentation, AB
@type postprocessor_path: str
@ivar postprocessor_path: Path to the secondary executable used in YOLO instance
segmentation.
@type n_prototypes: int
@ivar n_prototypes: Number of prototypes per bbox.
"""

family: Literal["InstanceSegmentationYOLO"] = Field(
Expand All @@ -226,6 +213,7 @@ class HeadInstanceSegmentationYOLO(HeadObjectDetectionYOLO, HeadSegmentation, AB
...,
description="Path to the secondary executable used in YOLO instance segmentation.",
)
n_prototypes: int = Field(description="Number of prototypes per bbox.")

@field_validator("family")
def validate_label_type(
Expand All @@ -237,13 +225,15 @@ def validate_label_type(
return value


class HeadKeypointDetectionYOLO(Head, ABC):
class HeadKeypointDetectionYOLO(HeadObjectDetectionYOLO, ABC):
"""Metadata for YOLO keypoint detection head.
@type family: str
@ivar family: Decoding family.
@type outputs: C{OutputsKeypointDetectionYOLO}
@ivar outputs: A configuration specifying which output names from the `outputs` block of the archive are fed into the head.
@type n_keypoints: int
@ivar n_keypoints: Number of keypoints per bbox.
"""

family: Literal["KeypointDetectionYOLO"] = Field(
Expand All @@ -252,6 +242,16 @@ class HeadKeypointDetectionYOLO(Head, ABC):
outputs: OutputsKeypointDetectionYOLO = Field(
description="A configuration specifying which output names from the `outputs` block of the archive are fed into the head."
)
n_keypoints: int = Field(description="Number of keypoints per bbox.")

@field_validator("family")
def validate_label_type(
cls,
value,
):
if value != "KeypointDetectionYOLO":
raise ValueError("Invalid family")
return value


HeadType = Union[
Expand Down

0 comments on commit c63b752

Please sign in to comment.