Default empty task name #222

Merged · 13 commits · Jan 14, 2025
luxonis_ml/data/datasets/annotation.py (3 additions, 2 deletions)

@@ -506,7 +506,7 @@ def validate_path(cls, path: FilePath) -> FilePath:
 class DatasetRecord(BaseModelExtraForbid):
     files: Dict[str, FilePath]
     annotation: Optional[Detection] = None
-    task: str = "detection"
+    task: str = ""

     @property
     def file(self) -> FilePath:
@@ -564,7 +564,8 @@ def check_valid_identifier(name: str, *, label: str) -> None:
     Albumentations requires that the names of the targets
     passed as `additional_targets` are valid Python identifiers.
     """
-    if not name.replace("-", "_").isidentifier():
+    name = name.replace("-", "_")
+    if name and not name.isidentifier():
         raise ValueError(
             f"{label} can only contain alphanumeric characters, "
             "underscores, and dashes. Additionaly, the first character "
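
Note on the relaxed validation above: `check_valid_identifier` now normalizes dashes first and skips the check entirely for empty names, so the new default task name `""` passes validation. A minimal standalone sketch of the new logic (error message shortened, not the library's full implementation):

def check_valid_identifier(name: str, *, label: str) -> None:
    # Dashes are allowed by mapping them to underscores first.
    name = name.replace("-", "_")
    # An empty name is now considered valid (the new default task name).
    if name and not name.isidentifier():
        raise ValueError(f"{label} must be a valid identifier or empty")

check_valid_identifier("", label="Task name")         # passes under the new rule
check_valid_identifier("my-task", label="Task name")  # passes ("-" -> "_")
# check_valid_identifier("1task", label="Task name")  # would raise ValueError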
luxonis_ml/data/datasets/luxonis_dataset.py (3 additions, 2 deletions)

@@ -728,10 +728,10 @@ def delete_dataset(self, *, delete_remote: bool = False) -> None:
         """
         if not self.is_remote:
             shutil.rmtree(self.path)
-            logger.info(f"Deleted dataset {self.dataset_name}")
+            logger.info(f"Deleted dataset '{self.dataset_name}'")

         if self.is_remote and delete_remote:
-            logger.info(f"Deleting dataset {self.dataset_name} from cloud")
+            logger.info(f"Deleting dataset '{self.dataset_name}' from cloud")
             assert self.path
             assert self.dataset_name
             assert self.local_path
@@ -828,6 +828,7 @@ def _add_process_batch(
     def add(
         self, generator: DatasetIterator, batch_size: int = 1_000_000
     ) -> Self:
+        logger.info(f"Adding data to dataset '{self.dataset_name}'...")
         index = self._get_file_index(sync_from_cloud=True)
         new_index = {"uuid": [], "file": [], "original_filepath": []}
         processed_uuids = set()
luxonis_ml/data/loaders/luxonis_loader.py (2 deletions)

@@ -143,8 +143,6 @@ def __init__(

         self.class_mappings: Dict[str, Dict[str, int]] = {}
         for task in self.df["task_name"].unique():
-            if not task:
-                continue
             class_mapping = {
                 class_: i
                 for i, class_ in enumerate(
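
With the early `continue` removed, the loader builds a class mapping for every task name found in the dataframe, including the new empty default. Sketched in isolation below; `df` and `classes_for_task` are stand-ins for the loader's dataframe and its class lookup, not real API:

# Sketch: "" now gets its own mapping instead of being skipped.
class_mappings: dict = {}
for task in df["task_name"].unique():  # may now include ""
    class_mappings[task] = {
        class_: i for i, class_ in enumerate(classes_for_task(task))
    }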
luxonis_ml/data/parsers/base_parser.py (1 addition, 1 deletion)

@@ -240,7 +240,7 @@ def _add_task(self, generator: DatasetIterator) -> DatasetIterator:
         @return: Generator function with added task
         """

-        task_name = self.task_name or self.dataset_type.value
+        task_name = self.task_name or ""
         for item in generator:
             if isinstance(item, dict):
                 item["task"] = task_name
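
This is the parser-side counterpart of the new default: when no explicit `task_name` is given, records now fall under the unnamed task instead of inheriting the dataset type's value. The fallback in isolation (helper name is illustrative only):

from typing import Optional

def resolve_task_name(explicit: Optional[str], dataset_type_value: str) -> str:
    # Previously the fallback was the dataset type's value, e.g. "coco";
    # now an unset task name resolves to the empty string.
    return explicit or ""

assert resolve_task_name(None, "coco") == ""        # new default: unnamed task
assert resolve_task_name("coco", "coco") == "coco"  # explicit names still win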
luxonis_ml/data/utils/visualizations.py (5 additions, 4 deletions)

@@ -329,21 +329,22 @@ def create_mask(

     for task, arr in task_type_iterator(labels, "segmentation"):
         task_name = get_task_name(task)
-        images[task_name] = create_mask(
+        image_name = task_name if task_name and not blend_all else "labels"
+        images[image_name] = create_mask(
             image, arr, task_name, is_instance=False
         )

     for task, arr in task_type_iterator(labels, "instance_segmentation"):
         task_name = get_task_name(task)
-        image_name = task_name if not blend_all else "labels"
+        image_name = task_name if task_name and not blend_all else "labels"
         curr_image = images.get(image_name, image.copy())
         images[image_name] = create_mask(
             curr_image, arr, task_name, is_instance=True
         )

     for task, arr in task_type_iterator(labels, "boundingbox"):
         task_name = get_task_name(task)
-        image_name = task_name if not blend_all else "labels"
+        image_name = task_name if task_name and not blend_all else "labels"
         curr_image = images.get(image_name, image.copy())

         draw_function = cv2.rectangle
@@ -374,7 +375,7 @@ def create_mask(

     for task, arr in task_type_iterator(labels, "keypoints"):
         task_name = get_task_name(task)
-        image_name = task_name if not blend_all else "labels"
+        image_name = task_name if task_name and not blend_all else "labels"
         curr_image = images.get(image_name, image.copy())

         task_classes = class_names[task_name]
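
Each repeated guard above reduces to the same rule: draw onto a per-task canvas only when the task actually has a name and blending is off; otherwise everything lands on the shared "labels" canvas. Extracted as a hypothetical helper for clarity:

def resolve_image_name(task_name: str, blend_all: bool) -> str:
    # Empty task names (the new default) share the "labels" canvas.
    return task_name if task_name and not blend_all else "labels"

assert resolve_image_name("driver", blend_all=False) == "driver"
assert resolve_image_name("", blend_all=False) == "labels"      # new behavior
assert resolve_image_name("driver", blend_all=True) == "labels"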
tests/test_data/test_dataset.py (16 additions, 10 deletions)

@@ -53,6 +53,7 @@ def test_dataset(
         bucket_storage=bucket_storage,
         delete_existing=True,
         delete_remote=True,
+        task_name="coco",
     )
     parser.parse()
     dataset = LuxonisDataset(dataset_name, bucket_storage=bucket_storage)
@@ -173,6 +174,7 @@ def test_loader_iterator(storage_url: str, tempdir: Path):
         save_dir=tempdir,
         dataset_type=DatasetType.COCO,
         delete_existing=True,
+        task_name="coco",
     ).parse()
     loader = LuxonisLoader(dataset)
@@ -417,12 +419,12 @@ def generator():
     compare_loader_output(
         loader,
         {
-            "detection/classification",
-            "detection/boundingbox",
-            "detection/driver/boundingbox",
-            "detection/driver/keypoints",
-            "detection/license_plate/boundingbox",
-            "detection/license_plate/metadata/text",
+            "/classification",
+            "/boundingbox",
+            "/driver/boundingbox",
+            "/driver/keypoints",
+            "/license_plate/boundingbox",
+            "/license_plate/metadata/text",
         },
     )
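
The rewritten expectations here and in the next hunk follow from label keys being built as `<task_name>/<task_type>`: with the new empty default, the old `detection` prefix disappears and only the separating slash remains. A one-line illustration (key format inferred from the test diff):

task_name, task_type = "", "boundingbox"
assert f"{task_name}/{task_type}" == "/boundingbox"  # was "detection/boundingbox"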

@@ -482,10 +484,10 @@ def generator():
     compare_loader_output(
         loader,
         {
-            "detection/classification",
-            "detection/boundingbox",
-            "detection/keypoints",
-            "detection/segmentation",
+            "/classification",
+            "/boundingbox",
+            "/keypoints",
+            "/segmentation",
         },
     )
@@ -533,6 +535,8 @@ def generator1():

     df_cloned = cloned_dataset._load_df_offline()
     df_original = dataset._load_df_offline()
+    assert df_cloned is not None
+    assert df_original is not None
     assert df_cloned.equals(df_original)
@@ -620,4 +624,6 @@ def generator2():
     df_cloned_merged = dataset1.merge_with(
         dataset2, inplace=True
     )._load_df_offline()
+    assert df_merged is not None
+    assert df_cloned_merged is not None
     assert df_merged.equals(df_cloned_merged)