forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvgg7.py
248 lines (210 loc) · 9.05 KB
/
vgg7.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import sys
import random
import json
import numpy
from pathlib import Path
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD
from examples.vgg7_helpers.kinne import KinneDir
from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7
# Amount of context erased by the model, in pixels per edge.
# The usage text for `execute` states that the output image has 7 pixels
# removed on all edges, and `samplify` pads every input patch by this many
# pixels on each side so that input and output patches line up.
CONTEXT = 7
def get_sample_count(samples_dir):
    """Return the sample count recorded in ``samples_dir/sample_count.txt``.

    The first line of the file is parsed as an integer.

    Args:
        samples_dir: Directory path (string) holding ``sample_count.txt``.

    Returns:
        The recorded count, or 0 when the file is missing, unreadable, or its
        first line is not a valid integer (a directory with no count file is
        treated as an empty sample set).
    """
    try:
        # `with` guarantees the handle is closed even if reading fails.
        with open(samples_dir + "/sample_count.txt", "r") as count_file:
            return int(count_file.readline())
    except (OSError, ValueError):
        # OSError: file/directory missing or unreadable.
        # ValueError: empty or non-numeric first line (or undecodable bytes).
        return 0
def set_sample_count(samples_dir, sc):
    """Record ``sc`` as the sample count for ``samples_dir``.

    Overwrites ``samples_dir/sample_count.txt`` with the string form of ``sc``
    followed by a single newline.
    """
    count_path = samples_dir + "/sample_count.txt"
    with open(count_path, "w") as count_file:
        # print() emits str(sc) plus a trailing "\n".
        print(sc, file=count_file)
if len(sys.argv) < 2:
    # No sub-command given: print the usage summary for every command and
    # exit with a non-zero status.
    print("python3 -m examples.vgg7 import MODELJSON MODELDIR")
    print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
    print(" into a directory of float binaries along with a meta.txt file containing tensor sizes")
    print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
    print(" *this format is used by all other commands in this program*")
    print("python3 -m examples.vgg7 execute MODELDIR IMG_IN IMG_OUT")
    print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
    print(" output image has 7 pixels removed on all edges")
    print(" do not run on large images, will have *hilarious* RAM use")
    print("python3 -m examples.vgg7 execute_full MODELDIR IMG_IN IMG_OUT")
    print(" does the 'whole thing' (padding, tiling)")
    print(" safe for large images, etc.")
    print("python3 -m examples.vgg7 new MODELDIR")
    print(" creates a new model (experimental)")
    print("python3 -m examples.vgg7 train MODELDIR SAMPLES_DIR ROUNDS ROUNDS_SAVE")
    print(" trains a model (experimental)")
    print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
    print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
    print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png")
    print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png")
    print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,")
    print(" my_samples/0b.png is the first original image)")
    # The count file read and written by this script is sample_count.txt
    # (singular "sample"); the help text is kept in sync with that name.
    print(" in addition, SAMPLES_DIR/sample_count.txt indicates sample count")
    print(" won't pad or tile, so keep image sizes sane")
    print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
    print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
    print(" maintains/creates sample_count.txt automatically")
    print(" unlike training, IMG_A must be exactly half the size of IMG_B")
    sys.exit(1)
# First command-line argument: the sub-command name compared against by the
# if/elif dispatch chain further down the script.
cmd = sys.argv[1]
# Single module-level network instance; load_and_save and every command
# operate on this object.
vgg7 = Vgg7()
def nansbane(p):
    """Raise if the parameter ``p`` contains a NaN.

    ``p`` must expose ``numpy()`` returning an array. The array minimum is
    used as the probe because numpy's ``min`` propagates NaN, so one NaN
    anywhere makes the minimum NaN.

    Raises:
        Exception: when a NaN is found.
    """
    smallest = p.numpy().min()
    if not numpy.isnan(smallest):
        return
    raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
def load_and_save(path, save):
    """Transfer the global ``vgg7`` parameters to or from a model directory.

    Args:
        path: Model directory handed to ``KinneDir``.
        save: Passed through as ``KinneDir``'s second argument. Presumably
            True writes the parameters to ``path`` and False reads them from
            it (confirm against ``KinneDir``).

    Raises:
        Exception: via ``nansbane`` when a parameter contains a NaN. The check
            runs before the transfer when saving (so a damaged model is never
            written out) and after the transfer when loading (so a damaged
            model is never used).
    """
    if save:
        for v in vgg7.get_parameters():
            nansbane(v)
    # Use the `path` argument; the previous code read the global `model`
    # here, which only worked because every caller happened to pass it.
    kn = KinneDir(path, save)
    kn.parameters(vgg7.get_parameters())
    kn.close()
    if not save:
        for v in vgg7.get_parameters():
            nansbane(v)
# Command dispatch: each branch reads its own positional arguments from
# sys.argv and operates on the module-level `vgg7` network.
if cmd == "import":
    # Convert a waifu2x JSON model into this program's directory format.
    src = sys.argv[2]
    model = sys.argv[3]

    # Context manager so the JSON file handle is closed promptly.
    with open(src, "rb") as json_file:
        vgg7.load_waifu2x_json(json.load(json_file))

    Path(model).mkdir(exist_ok=True)
    load_and_save(model, True)
elif cmd == "execute":
    # Run the network once over the whole (already upscaled) input image.
    model = sys.argv[2]
    in_file = sys.argv[3]
    out_file = sys.argv[4]

    load_and_save(model, False)

    image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).numpy())
elif cmd == "execute_full":
    # Run the network through forward_tiled (padding + tiling per the usage text).
    model = sys.argv[2]
    in_file = sys.argv[3]
    out_file = sys.argv[4]

    load_and_save(model, False)

    image_save(out_file, vgg7.forward_tiled(image_load(in_file), 156))
elif cmd == "new":
    # Save the freshly constructed network as a new model directory.
    model = sys.argv[2]

    Path(model).mkdir(exist_ok=True)
    load_and_save(model, True)
elif cmd == "train":
    model = sys.argv[2]
    samples_base = sys.argv[3]
    samples_count = get_sample_count(samples_base)
    rounds = int(sys.argv[4])
    rounds_per_save = int(sys.argv[5])

    load_and_save(model, False)

    # Initialize sample probabilities.
    # This is used to try and get the network to focus on "interesting" samples,
    # which works nicely with the microsample system.
    sample_probs = None
    sample_probs_path = model + "/sample_probs.bin"
    try:
        # try to read...
        sample_probs = numpy.fromfile(sample_probs_path, "<f8")
        if sample_probs.shape[0] != samples_count:
            print("sample probs size != sample count - initializing")
            sample_probs = None
    except (OSError, ValueError):
        # it's fine - a missing or unreadable file just means a fresh start
        print("sample probs could not be loaded - initializing")

    if sample_probs is None:
        # This stupidly high amount is used to force an initial pass over all samples
        sample_probs = numpy.ones(samples_count) * 1000

    print("Training...")
    # Adam has a tendency to destroy the state of the network when restarted
    # Plus it's slower
    optim = SGD(vgg7.get_parameters())

    rnum = 0
    # The way the -1 option works is that rnum is never -1,
    # so a negative ROUNDS never terminates the loop.
    while rnum != rounds:
        try:
            sample_idx = numpy.random.choice(samples_count, p=sample_probs / sample_probs.sum())
        except ValueError:
            # numpy rejects the probability vector (e.g. it does not sum to 1
            # or contains NaN); fall back to a uniform pick.
            print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
            sample_idx = random.randint(0, samples_count - 1)

        x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
        y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")

        sample_x = Tensor(x_img, requires_grad=False)
        sample_y = Tensor(y_img, requires_grad=False)

        # magic code roughly from readme example
        # An explanation, in case anyone else has to go down this path:
        # This runs the actual network normally
        out = vgg7.forward(sample_x)
        # Subtraction determines error here (as this is an image, not classification).
        # *Abs is the important bit* - at least for me, anyway.
        # The training process seeks to minimize this 'loss' value.
        # Minimization of loss *tends towards negative infinity*, so without the abs,
        # or without an implicit abs (the mul in the README),
        # loss will always go haywire in one direction or another.
        # Mean determines how errors are treated.
        # Do not use Sum. I tried that. It worked while I was using 1x1 patches...
        # Then it went exponential.
        # Also, Mean goes *after* abs. I realize this should have been obvious to me.
        loss = sample_y.sub(out).abs().mean()
        # This is the bit where tinygrad works backward from the loss
        optim.zero_grad()
        loss.backward()
        # And this updates the parameters
        optim.step()

        # warning: used by sample probability adjuster
        loss_indicator = loss.max().numpy()[0]
        print("Round " + str(rnum) + " : " + str(loss_indicator))

        if (rnum % rounds_per_save) == 0:
            print("Saving")
            load_and_save(model, True)
            sample_probs.astype("<f8", "C").tofile(sample_probs_path)

        # Update round state
        # Number
        rnum = rnum + 1

        # Probability management
        # there must always be a probability, no matter how slim, even if loss goes to 0
        sample_probs[sample_idx] = max(loss_indicator, 1.e-10)

    # if we were told to save every round, we already saved
    if rounds_per_save != 1:
        print("Done with all rounds, saving")
        load_and_save(model, True)
        sample_probs.astype("<f8", "C").tofile(sample_probs_path)
elif cmd == "samplify":
    a_img = sys.argv[2]
    b_img = sys.argv[3]
    samples_base = sys.argv[4]
    sample_size = int(sys.argv[5])
    samples_count = get_sample_count(samples_base)

    # This bit is interesting because it actually does some work.
    # Not much, but some work.
    a_img = image_load(a_img)
    b_img = image_load(b_img)

    # as with the main library body,
    # Y X order is used here

    # assertion before pre-upscaling is performed
    assert a_img.shape[2] == (b_img.shape[2] // 2), "IMG_A height must be half of IMG_B height"
    assert a_img.shape[3] == (b_img.shape[3] // 2), "IMG_A width must be half of IMG_B width"

    # pre-upscaling - this matches the sizes (and coordinates)
    a_img = a_img.repeat(2, 2).repeat(2, 3)

    samples_added = 0

    # actual patch extraction
    for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
        for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
            # this is a viable patch location, add it
            # note the ranges here:
            #  + there are always CONTEXT pixels *before* the point
            #  + with no subtraction at the end, there'd already be a pixel *at* the point,
            #    as ranges are exclusive
            #  + additionally, there are sample_size - 1 additional sample pixels
            #  + additionally, there are CONTEXT additional pixels
            #  + therefore there are CONTEXT + sample_size pixels *at & after* the point
            patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
            patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]

            image_save(f"{samples_base}/{samples_count}a.png", patch_x)
            image_save(f"{samples_base}/{samples_count}b.png", patch_y)
            samples_count += 1
            samples_added += 1

    print(f"Added {samples_added} samples")
    set_sample_count(samples_base, samples_count)
else:
    print("unknown command")