Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: german_traffic_sign #1849

Draft
wants to merge 5 commits into
base: tfdsv4
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions armory/datasets/cached_datasets.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
"url": null,
"version": "1.0.8"
},
"german_traffic_sign": {
"sha256": "f154d931293d40a96ace873ebad113f66e7580ade0ae6c495a0478d706b83315",
"size": 377068263,
"subdir": "german_traffic_sign/3.0.0",
"url": null,
"version": "3.0.0"
},
"mnist": {
"sha256": "fdc3408e29580367145e95ac7cb1d51e807105b174314cd52c16d27a13b98979",
"size": 16920751,
Expand Down
36 changes: 36 additions & 0 deletions armory/datasets/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


import tensorflow as tf
import tensorflow_addons as tfa


REGISTERED_PREPROCESSORS = {}
Expand Down Expand Up @@ -39,10 +40,16 @@ def supervised_image_classification(element):
return (image_to_canon(element["image"]), element["label"])


@register
def supervised_gtsrb_classification(element):
    """Map a GTSRB dataset element to a (preprocessed image, label) pair."""
    image = gtsrb_to_canon(element["image"])
    label = element["label"]
    return image, label


# Datasets whose elements use the standard supervised (image, label) layout
# share one generic preprocessor; GTSRB gets its own equalize/crop pipeline.
mnist = register(supervised_image_classification, "mnist")
cifar10 = register(supervised_image_classification, "cifar10")
cifar100 = register(supervised_image_classification, "cifar100")
resisc45 = register(supervised_image_classification, "resisc45")
# NOTE(review): supervised_gtsrb_classification is already registered under
# its own function name via the @register decorator; confirm this second
# registration under "german_traffic_sign" is intentional.
gtsrb = register(supervised_gtsrb_classification, "german_traffic_sign")


@register
Expand Down Expand Up @@ -91,6 +98,35 @@ def image_to_canon(image, resize=None, target_dtype=tf.float32, input_type="uint
return image


def gtsrb_to_canon(image, target_dtype=tf.float32, input_type="uint8", size=48):
    """
    Convert a GTSRB image to canonical form: histogram-equalize, scale to
    [0, 1], center-crop to the largest square, and resize to (size, size).

    TFDS Image feature uses (height, width, channels).

    Args:
        image: single image tensor of shape (height, width, channels).
        target_dtype: floating dtype of the returned image.
        input_type: dtype family of the input; only "uint8" is supported.
        size: side length of the square output image (default 48).

    Returns:
        A (size, size, channels) tensor of dtype target_dtype in [0, 1].

    Raises:
        NotImplementedError: if input_type is not "uint8".
    """
    if input_type == "uint8":
        scale = 255.0
    else:
        raise NotImplementedError(f"Currently only supports uint8, not {input_type}")
    image = tf.cast(image, target_dtype)
    image = tfa.image.equalize(image)
    image = image / scale

    # BUGFIX(review): the original used PIL-style `image.size` (which does not
    # exist on TF tensors) and called tf.image.crop_and_resize without the
    # required `box_indices` argument and with un-normalized pixel coordinates.
    # This is the TF-native equivalent: center-crop to a square, then resize.
    shape = tf.shape(image)
    height, width = shape[0], shape[1]
    min_side = tf.minimum(height, width)
    top = (height - min_side) // 2
    left = (width - min_side) // 2
    image = tf.image.crop_to_bounding_box(image, top, left, min_side, min_side)
    image = tf.image.resize(image, (size, size))

    return image


def audio_to_canon(audio, resample=None, target_dtype=tf.float32, input_type="int16"):
"""
Note: input_type is the scale of the actual data
Expand Down
3 changes: 3 additions & 0 deletions armory/datasets/standard/german_traffic_sign/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""german_traffic_sign dataset."""

from .german_traffic_sign import GermanTrafficSign
1 change: 1 addition & 0 deletions armory/datasets/standard/german_traffic_sign/checksums.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
https://armory-public-data.s3.us-east-2.amazonaws.com/german-traffic-sign/german_traffic_sign.tar.gz 367878784 0a39ee87e4cfd83b293eae21c0206dee304b2d09e17bcce369c97e4807c5ce3f german_traffic_sign.tar.gz
123 changes: 123 additions & 0 deletions armory/datasets/standard/german_traffic_sign/german_traffic_sign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
German traffic sign dataset with 43 classes and over 50,000 images.
"""

import csv
import os

# `import PIL` alone does not guarantee the Image submodule is loaded;
# importing PIL.Image explicitly makes PIL.Image.open reliable.
import PIL.Image
import tensorflow_datasets as tfds

# TODO(german_traffic_sign): Markdown description that will appear on the catalog page.
_DESCRIPTION = """
German traffic sign dataset with 43 classes and over 50,000 images.
"""

_HOMEPAGE = "http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset"

# GTSRB has 43 sign classes; TFDS ClassLabel names are the stringified
# integers "0".."42" (matching the numeric labels in the annotation CSVs).
_NUM_CLASSES = 43
_LABELS = [str(x) for x in range(_NUM_CLASSES)]

# Armory-hosted mirror of the GTSRB archive (see checksums.tsv).
_URL = "https://armory-public-data.s3.us-east-2.amazonaws.com/german-traffic-sign/german_traffic_sign.tar.gz"

# # TODO(german_traffic_sign): BibTeX citation
# _CITATION = """
# """


class GermanTrafficSign(tfds.core.GeneratorBasedBuilder):
    """DatasetBuilder for the german_traffic_sign (GTSRB) dataset.

    Images are stored in the archive as PPM; they are converted to BMP while
    generating examples because TFDS does not decode PPM.
    """

    VERSION = tfds.core.Version("3.0.0")

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict(
                {
                    # Variable-size RGB images (GTSRB signs are not uniform).
                    "image": tfds.features.Image(shape=(None, None, 3)),
                    # Labels are the stringified class indices "0".."42".
                    "label": tfds.features.ClassLabel(names=_LABELS),
                    "filename": tfds.features.Text(),
                }
            ),
            # (input, target) pair used when `as_supervised=True`.
            supervised_keys=("image", "label"),
            homepage=_HOMEPAGE,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Downloads/extracts the archive and defines the train/test splits."""
        path = os.path.join(dl_manager.download_and_extract(_URL), "GTSRB")
        return [
            tfds.core.SplitGenerator(
                name=split, gen_kwargs={"path": path, "split": split}
            )
            for split in (tfds.Split.TRAIN, tfds.Split.TEST)
        ]

    def _generate_examples(self, path, split):
        """Yields (key, example) tuples. Converts PPM files to BMP before yielding.

        Args:
            path: root directory of the extracted GTSRB archive.
            split: tfds.Split.TRAIN or tfds.Split.TEST.

        Raises:
            ValueError: if `split` is neither train nor test.
        """

        def _read_images(prefix, gt_file):
            # Annotation CSV rows are ';'-separated; the filename is column 0
            # and the numeric class label is column 7.
            with open(gt_file, newline="") as csv_file:
                reader = csv.reader(csv_file, delimiter=";")
                next(reader)  # skip header row
                for row in reader:
                    ppm_filename = row[0]
                    ppm_filepath = os.path.join(prefix, ppm_filename)
                    # Translate PPM to BMP (lossless) so TFDS can decode it.
                    base, _ = os.path.splitext(ppm_filename)
                    bmp_filename = base + ".bmp"
                    bmp_filepath = os.path.join(prefix, bmp_filename)
                    with PIL.Image.open(ppm_filepath) as image:
                        image.save(bmp_filepath, "BMP")

                    yield bmp_filepath, {
                        "image": bmp_filepath,
                        "label": row[7],
                        "filename": bmp_filename,
                    }

        # BUGFIX(review): compare splits with `==`, not `is` — tfds split
        # values are string-like and identity comparison on strings is
        # unreliable.
        if split == tfds.Split.TRAIN:
            for c in range(_NUM_CLASSES):
                # One subdirectory and one annotations file per class.
                prefix = os.path.join(path, "Final_Training", "Images", f"{c:05d}")
                gt_file = os.path.join(prefix, f"GT-{c:05d}.csv")
                yield from _read_images(prefix, gt_file)
        elif split == tfds.Split.TEST:
            # Test images live in one directory with a single annotations file.
            prefix = os.path.join(path, "Final_Test", "Images")
            gt_file = os.path.join(path, "GT-final_test.csv")
            yield from _read_images(prefix, gt_file)
        else:
            raise ValueError(f"split {split} not in ('train', 'test')")
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""german_traffic_sign dataset."""

import tensorflow_datasets as tfds
from . import german_traffic_sign


class GermanTrafficSignTest(tfds.testing.DatasetBuilderTestCase):
    """Tests for german_traffic_sign dataset.

    Runs the builder against fake example data laid out under the TFDS
    testing convention (dummy_data/ next to this file).
    """

    # Builder class under test.
    DATASET_CLASS = german_traffic_sign.GermanTrafficSign
    # Expected number of fake examples per split in the dummy data.
    SPLITS = {
        "train": 3,  # Number of fake train example
        "test": 1,  # Number of fake test example
    }

    # If you are calling `download/download_and_extract` with a dict, like:
    # dl_manager.download({'some_key': 'http://a.org/out.txt', ...})
    # then the tests needs to provide the fake output paths relative to the
    # fake data directory
    # DL_EXTRACT_RESULT = {'some_key': 'output_file1.txt', ...}


# Run the TFDS test harness when this file is executed directly.
if __name__ == "__main__":
    tfds.testing.test_main()