Fixes for nested labels #224

Merged · 8 commits · Jan 16, 2025
57 changes: 32 additions & 25 deletions luxonis_ml/data/__main__.py
@@ -1,7 +1,7 @@
 import logging
 import random
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple
 
 import cv2
 import numpy as np
@@ -16,7 +16,6 @@
 
 from luxonis_ml.data import LuxonisDataset, LuxonisLoader, LuxonisParser
 from luxonis_ml.data.utils.constants import LDF_VERSION
-from luxonis_ml.data.utils.task_utils import split_task, task_is_metadata
 from luxonis_ml.data.utils.visualizations import visualize
 from luxonis_ml.enums import DatasetType
 
@@ -47,39 +46,47 @@ def check_exists(name: str):
     raise typer.Exit()
 
 
-def get_dataset_info(name: str) -> Tuple[int, List[str], List[str]]:
-    dataset = LuxonisDataset(name)
-    size = len(dataset)
-    classes, _ = dataset.get_classes()
-    return size, classes, dataset.get_task_names()
+def get_dataset_info(dataset: LuxonisDataset) -> Tuple[Set[str], List[str]]:
+    all_classes = {
+        c for classes in dataset.get_classes().values() for c in classes
+    }
+    return all_classes, dataset.get_task_names()
 
 
 def print_info(name: str) -> None:
     dataset = LuxonisDataset(name)
-    _, classes = dataset.get_classes()
+    classes = dataset.get_classes()
     class_table = Table(
         title="Classes", box=rich.box.ROUNDED, row_styles=["yellow", "cyan"]
     )
-    class_table.add_column("Task Name", header_style="magenta i", max_width=30)
-    class_table.add_column("Classes", header_style="magenta i", max_width=50)
+    if len(classes) > 1 or next(iter(classes)):
+        class_table.add_column(
+            "Task Name", header_style="magenta i", max_width=30
+        )
+    class_table.add_column(
+        "Class Names", header_style="magenta i", max_width=50
+    )
     for task_name, c in classes.items():
-        class_table.add_row(task_name, ", ".join(c))
+        if not task_name:
+            class_table.add_row(", ".join(c))
+        else:
+            class_table.add_row(task_name, ", ".join(c))
 
     tasks = dataset.get_tasks()
-    tasks.sort(key=task_is_metadata)
     task_table = Table(
         title="Tasks", box=rich.box.ROUNDED, row_styles=["yellow", "cyan"]
     )
-    task_table.add_column("Task Name", header_style="magenta i", max_width=30)
-    task_table.add_column("Task Type", header_style="magenta i", max_width=50)
-    separated = False
-    for task in tasks:
-        if task_is_metadata(task):
-            if not separated:
-                task_table.add_section()
-                separated = True
-        task_name, task_type = split_task(task)
-        task_table.add_row(task_name, task_type)
+    if len(tasks) > 1 or next(iter(tasks)):
+        task_table.add_column(
+            "Task Name", header_style="magenta i", max_width=30
+        )
+    task_table.add_column("Task Types", header_style="magenta i", max_width=50)
+    for task_name, task_types in tasks.items():
+        task_types.sort()
+        if not task_name:
+            task_table.add_row(", ".join(task_types))
+        else:
+            task_table.add_row(task_name, ", ".join(task_types))
 
     splits = dataset.get_splits()
 
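A minimal sketch (not part of the diff) of the shapes the rewritten print_info consumes, assuming get_classes() keys a single-task dataset by the empty string and nested tasks by slash-joined names, as the annotation changes below suggest:

    # Hypothetical values illustrating the new get_classes() contract.
    classes_single = {"": ["cat", "dog"]}  # one unnamed task
    classes_nested = {
        "vehicles": ["car", "truck"],
        "vehicles/plate": ["licence-plate"],  # nested sub-detection
    }

    # The CLI's new column logic: show "Task Name" only when there is
    # more than one task or the single task has a non-empty name.
    for classes in (classes_single, classes_nested):
        show_task_column = len(classes) > 1 or bool(next(iter(classes)))
        print(show_task_column)  # False, then True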
@@ -150,7 +157,7 @@ def ls(
            size = -1
        rows.append(str(size))
        if full:
-            _, classes, tasks = get_dataset_info(name)
+            classes, tasks = get_dataset_info(dataset)
            rows.extend(
                [
                    ", ".join(classes) if classes else "[red]<empty>[no red]",
@@ -250,15 +257,15 @@ def inspect(
     if len(dataset) == 0:
         raise ValueError(f"Dataset '{name}' is empty.")
 
-    class_names = dataset.get_classes()[1]
+    classes = dataset.get_classes()
     for image, labels in loader:
         image = image.astype(np.uint8)
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
         h, w, _ = image.shape
         new_h, new_w = int(h * size_multiplier), int(w * size_multiplier)
         image = cv2.resize(image, (new_w, new_h))
-        image = visualize(image, labels, class_names, blend_all=blend_all)
+        image = visualize(image, labels, classes, blend_all=blend_all)
         cv2.imshow("image", image)
         if cv2.waitKey() == ord("q"):
             break
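For reference, a hedged usage sketch of the updated helper (the dataset name is invented): get_dataset_info now takes the dataset object itself rather than its name, and returns the flattened set of all class names together with the task names.

    # Assumes a dataset registered under the hypothetical name "parking_lot".
    dataset = LuxonisDataset("parking_lot")
    classes, task_names = get_dataset_info(dataset)
    print(sorted(classes), task_names)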
112 changes: 60 additions & 52 deletions luxonis_ml/data/datasets/annotation.py
@@ -1,7 +1,6 @@
 import json
 import logging
 from abc import ABC, abstractmethod
-from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
 
@@ -14,7 +13,7 @@
 from typeguard import check_type
 from typing_extensions import Annotated, Self, TypeAlias, override
 
-from luxonis_ml.data.utils.parquet import ParquetDetection, ParquetRecord
+from luxonis_ml.data.utils.parquet import ParquetRecord
 from luxonis_ml.utils import BaseModelExtraForbid
 
 logger = logging.getLogger(__name__)
@@ -41,43 +40,6 @@ class Detection(BaseModelExtraForbid):
 
     sub_detections: Dict[str, "Detection"] = {}
 
-    def to_parquet_rows(self) -> Iterable[ParquetDetection]:
-        yield from self._to_parquet_rows()
-
-    def _to_parquet_rows(self, prefix: str = "") -> Iterable[ParquetDetection]:
-        for task_type in [
-            "boundingbox",
-            "keypoints",
-            "segmentation",
-            "instance_segmentation",
-            "array",
-        ]:
-            label: Optional[Annotation] = getattr(self, task_type)
-
-            if label is not None:
-                yield {
-                    "class_name": self.class_name,
-                    "instance_id": self.instance_id,
-                    "task_type": f"{prefix}{task_type}",
-                    "annotation": label.model_dump_json(),
-                }
-        for key, data in self.metadata.items():
-            yield {
-                "class_name": self.class_name,
-                "instance_id": self.instance_id,
-                "task_type": f"{prefix}metadata/{key}",
-                "annotation": json.dumps(data),
-            }
-        if self.class_name is not None:
-            yield {
-                "class_name": self.class_name,
-                "instance_id": self.instance_id,
-                "task_type": f"{prefix}classification",
-                "annotation": "{}",
-            }
-        for name, detection in self.sub_detections.items():
-            yield from detection._to_parquet_rows(f"{prefix}{name}/")
-
     @model_validator(mode="after")
     def validate_names(self) -> Self:
         for name in self.sub_detections:
@@ -533,28 +495,74 @@ def to_parquet_rows(self) -> Iterable[ParquetRecord]:
         @rtype: L{ParquetDict}
         @return: A dictionary of annotation data.
         """
-        timestamp = datetime.now(timezone.utc)
+        yield from self._to_parquet_rows(self.annotation, self.task)
+
+    def _to_parquet_rows(
+        self, annotation: Optional[Detection], task_name: str
+    ) -> Iterable[ParquetRecord]:
+        """Converts an annotation to a dictionary for writing to a
+        parquet file.
+
+        @rtype: L{ParquetDict}
+        @return: A dictionary of annotation data.
+        """
         for source, file_path in self.files.items():
-            if self.annotation is not None:
-                for detection in self.annotation.to_parquet_rows():
-                    yield {
-                        "file": str(file_path),
-                        "source_name": source,
-                        "task_name": self.task,
-                        "created_at": timestamp,
-                        **detection,
-                    }
-            else:
+            if annotation is None:
                 yield {
                     "file": str(file_path),
                     "source_name": source,
-                    "task_name": self.task,
-                    "created_at": timestamp,
+                    "task_name": task_name,
                     "class_name": None,
                     "instance_id": None,
                     "task_type": None,
                     "annotation": None,
                 }
+            else:
+                for task_type in [
+                    "boundingbox",
+                    "keypoints",
+                    "segmentation",
+                    "instance_segmentation",
+                    "array",
+                ]:
+                    label: Optional[Annotation] = getattr(
+                        annotation, task_type
+                    )
+
+                    if label is not None:
+                        yield {
+                            "file": str(file_path),
+                            "source_name": source,
+                            "task_name": task_name,
+                            "class_name": annotation.class_name,
+                            "instance_id": annotation.instance_id,
+                            "task_type": task_type,
+                            "annotation": label.model_dump_json(),
+                        }
+                for key, data in annotation.metadata.items():
+                    yield {
+                        "file": str(file_path),
+                        "source_name": source,
+                        "task_name": task_name,
+                        "class_name": annotation.class_name,
+                        "instance_id": annotation.instance_id,
+                        "task_type": f"metadata/{key}",
+                        "annotation": json.dumps(data),
+                    }
+                if annotation.class_name is not None:
+                    yield {
+                        "file": str(file_path),
+                        "source_name": source,
+                        "task_name": task_name,
+                        "class_name": annotation.class_name,
+                        "instance_id": annotation.instance_id,
+                        "task_type": "classification",
+                        "annotation": "{}",
+                    }
+                for name, detection in annotation.sub_detections.items():
+                    yield from self._to_parquet_rows(
+                        detection, f"{task_name}/{name}"
+                    )
 
 
 def check_valid_identifier(name: str, *, label: str) -> None:
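The net effect of this refactor: flattening nested detections into parquet rows now lives in DatasetRecord._to_parquet_rows instead of Detection, nesting is encoded in task_name (e.g. "vehicles/plate") rather than in a prefixed task_type, and the created_at timestamp column is gone. A sketch of the rows one might expect for a detection with one sub-detection, assuming the field names shown in the diff (all values invented):

    row_top = {
        "file": "img_0.png",
        "source_name": "image",
        "task_name": "vehicles",  # top-level task
        "class_name": "car",
        "instance_id": 0,
        "task_type": "boundingbox",
        "annotation": '{"x": 0.1, "y": 0.2, "w": 0.3, "h": 0.4}',
    }
    row_sub = {
        "file": "img_0.png",
        "source_name": "image",
        "task_name": "vehicles/plate",  # nesting moved into task_name
        "class_name": "licence-plate",
        "instance_id": 0,
        "task_type": "boundingbox",
        "annotation": '{"x": 0.12, "y": 0.25, "w": 0.08, "h": 0.04}',
    }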
22 changes: 10 additions & 12 deletions luxonis_ml/data/datasets/base_dataset.py
@@ -6,7 +6,6 @@
 
 from luxonis_ml.data.datasets.annotation import DatasetRecord
 from luxonis_ml.data.datasets.source import LuxonisSource
-from luxonis_ml.data.utils.task_utils import get_task_name
 from luxonis_ml.typing import PathType
 from luxonis_ml.utils import AutoRegisterMeta, Registry
 
@@ -41,11 +40,11 @@ def version(self) -> Version:
         ...
 
     @abstractmethod
-    def get_tasks(self) -> List[str]:
-        """Returns the list of tasks in the dataset.
+    def get_tasks(self) -> Dict[str, str]:
+        """Returns a dictionary mapping task names to task types.
 
-        @rtype: List[str]
-        @return: List of task names.
+        @rtype: Dict[str, str]
+        @return: A dictionary mapping task names to task types.
         """
         ...
 
@@ -75,13 +74,12 @@ def set_classes(
         ...
 
     @abstractmethod
-    def get_classes(self) -> Tuple[List[str], Dict[str, List[str]]]:
-        """Gets overall classes in the dataset and classes according to
-        computer vision task.
+    def get_classes(self) -> Dict[str, List[str]]:
+        """Get classes according to computer vision tasks.
 
-        @rtype: Tuple[List[str], Dict]
-        @return: A combined list of classes for all tasks and a
-            dictionary mapping tasks to the classes used in each task.
+        @rtype: Dict[str, List[str]]
+        @return: A dictionary mapping tasks to the classes used in each
+            task.
         """
         ...
 
@@ -202,4 +200,4 @@ def get_task_names(self) -> List[str]:
         @rtype: List[str]
         @return: List of task names.
         """
-        return [get_task_name(task) for task in self.get_tasks()]
+        return list(self.get_tasks().keys())
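Under the new contract, task names are simply the keys of get_tasks(). Note that the abstract signature above declares Dict[str, str], while the CLI in __main__.py sorts each value as a list, so a mapping from task names to lists of task types appears to be the intent. A hedged sketch (values invented):

    tasks = {"vehicles": ["boundingbox", "classification"]}
    task_names = list(tasks.keys())  # what get_task_names() now returns
    print(task_names)  # ['vehicles']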