-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathcustom_dataset.py
132 lines (96 loc) · 4.01 KB
/
custom_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
########################################################################
#
# Functions for downloading a custom data-set from the internet
# and loading it into memory. Note that this only loads the file-names
# for the images in the data-set and does not load the actual images.
#
# Implemented in Python 3.5
#
########################################################################
#
# This file is part of the TensorFlow Tutorials
#
# Copyright 2016 by Magnus Erik Hvass Pedersen
# It has been modified for this repository
#
########################################################################
from dataset import load_cached
import download
import os
########################################################################
# Module-level settings. Assign these from your own program before
# calling maybe_download_and_extract() or load() below.
# Directory where the data-set is downloaded and saved. It is passed as
# download_dir to download.maybe_download_and_extract() and as in_dir to
# load_cached(); the cache-file is also stored inside this directory.
data_dir = "data/"
# URL for the data-set on the internet. Empty by default, so it
# presumably has to be set before downloading - confirm against the
# download module.
data_url = ""
# Dataset name. Used as the base-name of the cache-file, which load()
# places at <data_dir>/<dataset_name>.pkl
dataset_name = ""
########################################################################
# Constants describing the images and the label-space of the data-set.
# Import and use these in your own program instead of hard-coding them.

# Side length of each image; width and height are both this value.
img_size = 200

# Channels per image: one each for Red, Green and Blue.
num_channels = 3

# Number of elements in a single image once flattened to a 1-dim array.
img_size_flat = num_channels * img_size ** 2

# Count of distinct classes in the data-set.
num_classes = 3
########################################################################
# Public functions that you may call to download the data-set from
# the internet and load the data into memory.
def maybe_download_and_extract():
    """
    Fetch the data-set from data_url and unpack it into data_dir,
    unless it is already present there.

    This is a thin wrapper: the actual work is delegated to
    download.maybe_download_and_extract(), which receives the
    module-level data_url and data_dir. Set both of those variables
    before calling this function.
    """

    # Read the module-level settings at call-time so that values
    # assigned by the caller after import are honoured.
    source_url = data_url
    target_dir = data_dir

    download.maybe_download_and_extract(url=source_url,
                                        download_dir=target_dir)
def load():
    """
    Return the DataSet-object for the data-set stored in data_dir.

    The object is obtained through load_cached(), which is backed by
    the cache-file "<dataset_name>.pkl" located inside data_dir. If
    that cache-file already exists it is reloaded, otherwise the
    data-set is created and then saved to the cache-file.

    The cache-file ensures the files are ordered consistently every
    time the data-set is loaded, which matters when the data-set is
    used together with Transfer Learning as in Tutorial #09.

    :return:
        A DataSet-object for the data-set.
    """

    # File-name of the cache, derived from the configured data-set name.
    cache_file = dataset_name + ".pkl"

    # Reload the DataSet-object from the cache-file if present, else
    # build it from the contents of data_dir and cache it for next time.
    return load_cached(cache_path=os.path.join(data_dir, cache_file),
                       in_dir=data_dir)
########################################################################
if __name__ == '__main__':
    # Quick sanity-check when this file is run as a script: fetch the
    # data-set if necessary, load it, and print a few entries.
    maybe_download_and_extract()
    dataset = load()

    # Each call gives the image file-paths together with their
    # class-numbers and class-labels, first for the training-set
    # and then for the test-set.
    image_paths_train, cls_train, labels_train = dataset.get_training_set()
    image_paths_test, cls_test, labels_test = dataset.get_test_set()

    # Show the first five entries of the training-set, then of the
    # test-set, so they can be inspected by eye.
    for paths, cls, labels in (
            (image_paths_train, cls_train, labels_train),
            (image_paths_test, cls_test, labels_test)):
        # File-paths, one per line.
        for path in paths[0:5]:
            print(path)

        # The matching class-numbers.
        print(cls[0:5])

        # The same classes as one-hot encoded arrays.
        print(labels[0:5])
########################################################################