Skip to content

Commit

Permalink
Fixes for nested labels (#224)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored Jan 16, 2025
1 parent 2fbec83 commit c08bd5b
Show file tree
Hide file tree
Showing 12 changed files with 376 additions and 363 deletions.
57 changes: 32 additions & 25 deletions luxonis_ml/data/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import random
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

import cv2
import numpy as np
Expand All @@ -16,7 +16,6 @@

from luxonis_ml.data import LuxonisDataset, LuxonisLoader, LuxonisParser
from luxonis_ml.data.utils.constants import LDF_VERSION
from luxonis_ml.data.utils.task_utils import split_task, task_is_metadata
from luxonis_ml.data.utils.visualizations import visualize
from luxonis_ml.enums import DatasetType

Expand Down Expand Up @@ -47,39 +46,47 @@ def check_exists(name: str):
raise typer.Exit()


def get_dataset_info(name: str) -> Tuple[int, List[str], List[str]]:
    """Load the named dataset and summarize it.

    @type name: str
    @param name: Name of an existing L{LuxonisDataset}.
    @rtype: Tuple[int, List[str], List[str]]
    @return: Number of samples, class names, and task names.
    """
    dataset = LuxonisDataset(name)
    classes, _ = dataset.get_classes()
    return len(dataset), classes, dataset.get_task_names()
def get_dataset_info(dataset: LuxonisDataset) -> Tuple[Set[str], List[str]]:
    """Summarize a dataset: the union of class names and the task names.

    @type dataset: LuxonisDataset
    @param dataset: The dataset to inspect.
    @rtype: Tuple[Set[str], List[str]]
    @return: All class names across every task, and the task names.
    """
    all_classes: Set[str] = set()
    for class_names in dataset.get_classes().values():
        all_classes.update(class_names)
    return all_classes, dataset.get_task_names()


def print_info(name: str) -> None:
dataset = LuxonisDataset(name)
_, classes = dataset.get_classes()
classes = dataset.get_classes()
class_table = Table(
title="Classes", box=rich.box.ROUNDED, row_styles=["yellow", "cyan"]
)
class_table.add_column("Task Name", header_style="magenta i", max_width=30)
class_table.add_column("Classes", header_style="magenta i", max_width=50)
if len(classes) > 1 or next(iter(classes)):
class_table.add_column(
"Task Name", header_style="magenta i", max_width=30
)
class_table.add_column(
"Class Names", header_style="magenta i", max_width=50
)
for task_name, c in classes.items():
class_table.add_row(task_name, ", ".join(c))
if not task_name:
class_table.add_row(", ".join(c))
else:
class_table.add_row(task_name, ", ".join(c))

tasks = dataset.get_tasks()
tasks.sort(key=task_is_metadata)
task_table = Table(
title="Tasks", box=rich.box.ROUNDED, row_styles=["yellow", "cyan"]
)
task_table.add_column("Task Name", header_style="magenta i", max_width=30)
task_table.add_column("Task Type", header_style="magenta i", max_width=50)
separated = False
for task in tasks:
if task_is_metadata(task):
if not separated:
task_table.add_section()
separated = True
task_name, task_type = split_task(task)
task_table.add_row(task_name, task_type)
if len(tasks) > 1 or next(iter(tasks)):
task_table.add_column(
"Task Name", header_style="magenta i", max_width=30
)
task_table.add_column("Task Types", header_style="magenta i", max_width=50)
for task_name, task_types in tasks.items():
task_types.sort()
if not task_name:
task_table.add_row(", ".join(task_types))
else:
task_table.add_row(task_name, ", ".join(task_types))

splits = dataset.get_splits()

Expand Down Expand Up @@ -150,7 +157,7 @@ def ls(
size = -1
rows.append(str(size))
if full:
_, classes, tasks = get_dataset_info(name)
classes, tasks = get_dataset_info(dataset)
rows.extend(
[
", ".join(classes) if classes else "[red]<empty>[no red]",
Expand Down Expand Up @@ -250,15 +257,15 @@ def inspect(
if len(dataset) == 0:
raise ValueError(f"Dataset '{name}' is empty.")

class_names = dataset.get_classes()[1]
classes = dataset.get_classes()
for image, labels in loader:
image = image.astype(np.uint8)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

h, w, _ = image.shape
new_h, new_w = int(h * size_multiplier), int(w * size_multiplier)
image = cv2.resize(image, (new_w, new_h))
image = visualize(image, labels, class_names, blend_all=blend_all)
image = visualize(image, labels, classes, blend_all=blend_all)
cv2.imshow("image", image)
if cv2.waitKey() == ord("q"):
break
Expand Down
112 changes: 60 additions & 52 deletions luxonis_ml/data/datasets/annotation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union

Expand All @@ -14,7 +13,7 @@
from typeguard import check_type
from typing_extensions import Annotated, Self, TypeAlias, override

from luxonis_ml.data.utils.parquet import ParquetDetection, ParquetRecord
from luxonis_ml.data.utils.parquet import ParquetRecord
from luxonis_ml.utils import BaseModelExtraForbid

logger = logging.getLogger(__name__)
Expand All @@ -41,43 +40,6 @@ class Detection(BaseModelExtraForbid):

sub_detections: Dict[str, "Detection"] = {}

def to_parquet_rows(self) -> Iterable[ParquetDetection]:
    """Yield parquet rows for this detection, including sub-detections.

    @rtype: Iterable[ParquetDetection]
    @return: Parquet-ready row dictionaries.
    """
    yield from self._to_parquet_rows(prefix="")

def _to_parquet_rows(self, prefix: str = "") -> Iterable[ParquetDetection]:
    """Recursively yield parquet rows for this detection tree.

    @type prefix: str
    @param prefix: Task-type prefix accumulated from parent detections
        (one C{"<name>/"} segment per nesting level).
    @rtype: Iterable[ParquetDetection]
    @return: Parquet-ready row dictionaries.
    """

    def _row(task_type, annotation):
        # Every row of this detection shares the same class name and
        # instance id; only the task type and payload differ.
        return {
            "class_name": self.class_name,
            "instance_id": self.instance_id,
            "task_type": task_type,
            "annotation": annotation,
        }

    for task_type in (
        "boundingbox",
        "keypoints",
        "segmentation",
        "instance_segmentation",
        "array",
    ):
        label: Optional[Annotation] = getattr(self, task_type)
        if label is not None:
            yield _row(f"{prefix}{task_type}", label.model_dump_json())

    for key, data in self.metadata.items():
        yield _row(f"{prefix}metadata/{key}", json.dumps(data))

    if self.class_name is not None:
        # A named detection also implies a classification label.
        yield _row(f"{prefix}classification", "{}")

    for name, detection in self.sub_detections.items():
        yield from detection._to_parquet_rows(f"{prefix}{name}/")

@model_validator(mode="after")
def validate_names(self) -> Self:
for name in self.sub_detections:
Expand Down Expand Up @@ -533,28 +495,74 @@ def to_parquet_rows(self) -> Iterable[ParquetRecord]:
@rtype: L{ParquetDict}
@return: A dictionary of annotation data.
"""
timestamp = datetime.now(timezone.utc)
yield from self._to_parquet_rows(self.annotation, self.task)

def _to_parquet_rows(
    self, annotation: Optional[Detection], task_name: str
) -> Iterable[ParquetRecord]:
    """Converts an annotation to rows for writing to a parquet file.

    @type annotation: Optional[Detection]
    @param annotation: The detection to serialize, or C{None} for a
        record that carries no labels.
    @type task_name: str
    @param task_name: Task name; extended with C{"/<name>"} for each
        nested sub-detection.
    @rtype: Iterable[ParquetRecord]
    @return: Parquet-ready row dictionaries, one set per source file.
    """
    for source, file_path in self.files.items():
        if annotation is None:
            # No labels: emit a single empty row so the file itself is
            # still recorded.
            yield {
                "file": str(file_path),
                "source_name": source,
                "task_name": task_name,
                "class_name": None,
                "instance_id": None,
                "task_type": None,
                "annotation": None,
            }
        else:
            for task_type in [
                "boundingbox",
                "keypoints",
                "segmentation",
                "instance_segmentation",
                "array",
            ]:
                label: Optional[Annotation] = getattr(
                    annotation, task_type
                )

                if label is not None:
                    yield {
                        "file": str(file_path),
                        "source_name": source,
                        "task_name": task_name,
                        "class_name": annotation.class_name,
                        "instance_id": annotation.instance_id,
                        "task_type": task_type,
                        "annotation": label.model_dump_json(),
                    }
            for key, data in annotation.metadata.items():
                yield {
                    "file": str(file_path),
                    "source_name": source,
                    "task_name": task_name,
                    "class_name": annotation.class_name,
                    "instance_id": annotation.instance_id,
                    "task_type": f"metadata/{key}",
                    "annotation": json.dumps(data),
                }
            if annotation.class_name is not None:
                # A named detection also implies a classification label.
                yield {
                    "file": str(file_path),
                    "source_name": source,
                    "task_name": task_name,
                    "class_name": annotation.class_name,
                    "instance_id": annotation.instance_id,
                    "task_type": "classification",
                    "annotation": "{}",
                }
            # NOTE(review): the recursion sits inside the files loop, so
            # sub-detection rows repeat per source file — confirm intended.
            for name, detection in annotation.sub_detections.items():
                yield from self._to_parquet_rows(
                    detection, f"{task_name}/{name}"
                )


def check_valid_identifier(name: str, *, label: str) -> None:
Expand Down
22 changes: 10 additions & 12 deletions luxonis_ml/data/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from luxonis_ml.data.datasets.annotation import DatasetRecord
from luxonis_ml.data.datasets.source import LuxonisSource
from luxonis_ml.data.utils.task_utils import get_task_name
from luxonis_ml.typing import PathType
from luxonis_ml.utils import AutoRegisterMeta, Registry

Expand Down Expand Up @@ -41,11 +40,11 @@ def version(self) -> Version:
...

@abstractmethod
def get_tasks(self) -> Dict[str, str]:
    """Returns a dictionary mapping task names to task types.

    @rtype: Dict[str, str]
    @return: A dictionary mapping task names to task types.
    """
    ...

Expand Down Expand Up @@ -75,13 +74,12 @@ def set_classes(
...

@abstractmethod
def get_classes(self) -> Dict[str, List[str]]:
    """Get classes according to computer vision tasks.

    @rtype: Dict[str, List[str]]
    @return: A dictionary mapping tasks to the classes used in each
        task.
    """
    ...

def get_task_names(self) -> List[str]:
    """Returns the list of task names in the dataset.

    @rtype: List[str]
    @return: List of task names, derived from the keys of
        L{get_tasks}.
    """
    return list(self.get_tasks().keys())
Loading

0 comments on commit c08bd5b

Please sign in to comment.