-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualizer.js
138 lines (114 loc) · 7.66 KB
/
visualizer.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class Visualizer {
  /**
   * Text drawn on the output nodes of the last (top-most) level.
   * NOTE(review): presumably these arrows name the controls driven by the
   * network, in output order — confirm against the network/controls code.
   */
  static #CONTROLS_LABELS = Object.freeze(["🠉", "🠈", "🠊", "🠋"]);

  /**
   * Draws every level of a network onto a canvas, stacked vertically inside a
   * 50px margin: `network.levels[0]` occupies the bottom band and the last
   * level the top band. The last level's output nodes are labelled with
   * arrow glyphs.
   *
   * Connection lines are drawn dashed (7 on / 3 off). The caller's canvas
   * state is saved before drawing and restored afterwards, even if drawing a
   * level throws.
   *
   * @param {CanvasRenderingContext2D} ctx A canvas 2d context
   * @param {{ levels: any[] }} network A network instance with an ordered `levels` array
   */
  static drawNetwork(ctx, network) {
    const margin = 50;
    const left = margin;
    const top = margin;
    const width = ctx.canvas.width - margin * 2;
    const height = ctx.canvas.height - margin * 2;
    const { levels } = network;
    const levelHeight = height / levels.length;
    const lastIndex = levels.length - 1;

    ctx.save();
    try {
      // drawLevel restores its own state, so this dash pattern stays in
      // effect for the connection lines of every level.
      ctx.setLineDash([7, 3]);
      // Iterate from the top band (last level) downwards. Each lower level is
      // painted after the one above it, so its output nodes and bias rings end
      // up over the coincident input nodes of the level above instead of being
      // painted over by them.
      for (let i = lastIndex; i >= 0; i--) {
        const levelTop =
          top + lerp(height - levelHeight, 0, Visualizer.#ratio(i, lastIndex));
        Visualizer.drawLevel(
          ctx,
          levels[i],
          { left, top: levelTop, width, height: levelHeight },
          i === lastIndex ? Visualizer.#CONTROLS_LABELS : []
        );
      }
    } finally {
      ctx.restore();
    }
  }

  /**
   * Draws one network level inside the given rectangle: input nodes spread
   * along the bottom edge, output nodes along the top edge, and one line per
   * (input, output) pair between them.
   *
   * Colours come from `getRGBA`: connection lines from `weights[i][j]`, node
   * fills from `inputs` / `outputs`, and a dashed ring around each output node
   * from `biases`. Connection lines use whatever dash pattern `ctx` has on
   * entry. The canvas state is saved and restored around the drawing.
   *
   * @param {CanvasRenderingContext2D} ctx A canvas 2d context
   * @param {{ inputs: any[]; outputs: any[]; weights: any[][]; biases: any[] }} level
   *   A network level; `weights[i][j]` colours the line from input `i` to
   *   output `j`, and `biases[j]` colours the ring around output `j`.
   * @param {{ left: number; top: number; width: number; height: number; }} boundings
   * @param {string[]} [labels=[]] Text drawn on output node `j`, when truthy
   */
  static drawLevel(ctx, level, { left, top, width, height }, labels = []) {
    const right = left + width;
    const bottom = top + height;
    const { inputs, outputs, biases } = level;
    const nodeRadius = 18;

    ctx.save();
    try {
      Visualizer.#drawConnections(ctx, level, { left, right, top, bottom });

      for (let i = 0; i < inputs.length; i++) {
        const x = Visualizer.#getNodeX(inputs, i, left, right);
        Visualizer.#drawNode(ctx, x, bottom, nodeRadius, inputs[i]);
      }

      for (let j = 0; j < outputs.length; j++) {
        const x = Visualizer.#getNodeX(outputs, j, left, right);
        Visualizer.#drawNode(ctx, x, top, nodeRadius, outputs[j]);
        Visualizer.#drawBiasRing(ctx, x, top, nodeRadius, biases[j]);
        if (labels[j]) {
          Visualizer.#drawLabel(ctx, labels[j], x, top, nodeRadius);
        }
      }
    } finally {
      ctx.restore();
    }
  }

  /**
   * Strokes one straight 2px line from every input node (bottom edge) to every
   * output node (top edge), coloured by the corresponding weight.
   */
  static #drawConnections(ctx, { inputs, outputs, weights }, { left, right, top, bottom }) {
    ctx.lineWidth = 2;
    for (let i = 0; i < inputs.length; i++) {
      const fromX = Visualizer.#getNodeX(inputs, i, left, right);
      for (let j = 0; j < outputs.length; j++) {
        ctx.beginPath();
        ctx.moveTo(fromX, bottom);
        ctx.lineTo(Visualizer.#getNodeX(outputs, j, left, right), top);
        ctx.strokeStyle = getRGBA(weights[i][j]);
        ctx.stroke();
      }
    }
  }

  /** Fills a node: a black disc with an inner disc coloured by `value`. */
  static #drawNode(ctx, x, y, radius, value) {
    ctx.beginPath();
    ctx.arc(x, y, radius, 0, Math.PI * 2);
    ctx.fillStyle = "black";
    ctx.fill();
    ctx.beginPath();
    ctx.arc(x, y, radius * 0.6, 0, Math.PI * 2);
    ctx.fillStyle = getRGBA(value);
    ctx.fill();
  }

  /** Strokes a dashed 2px ring around a node, coloured by its bias. */
  static #drawBiasRing(ctx, x, y, radius, bias) {
    ctx.beginPath();
    ctx.lineWidth = 2;
    ctx.arc(x, y, radius * 0.8, 0, Math.PI * 2);
    ctx.strokeStyle = getRGBA(bias);
    ctx.setLineDash([3, 3]);
    ctx.stroke();
    // Back to solid lines so the label outline drawn next is not dashed.
    ctx.setLineDash([]);
  }

  /** Draws `text` centred on a node: black fill with a thin white outline. */
  static #drawLabel(ctx, text, x, y, radius) {
    // Offset slightly along +y from the node centre.
    const textY = y + radius * 0.1;
    ctx.textAlign = "center";
    ctx.textBaseline = "middle";
    ctx.fillStyle = "black";
    ctx.strokeStyle = "white";
    ctx.font = `${radius * 1.5}px Arial`;
    ctx.fillText(text, x, textY);
    ctx.lineWidth = 0.5;
    ctx.strokeText(text, x, textY);
  }

  /**
   * Fractional position of `index` in a row whose last index is `lastIndex`
   * (in [0, 1] for 0 ≤ index ≤ lastIndex); a single-item row is centred at 0.5.
   *
   * @param {number} index
   * @param {number} lastIndex
   */
  static #ratio(index, lastIndex) {
    return lastIndex === 0 ? 0.5 : index / lastIndex;
  }

  /**
   * Horizontal position of node `index`, spread evenly between `left` and
   * `right` (a lone node is centred).
   *
   * @param {any[]} nodes
   * @param {number} index
   * @param {number} left
   * @param {number} right
   */
  static #getNodeX(nodes, index, left, right) {
    return lerp(left, right, Visualizer.#ratio(index, nodes.length - 1));
  }
}