-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathImageClassifier.ts
86 lines (73 loc) · 3 KB
/
ImageClassifier.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import * as fsp from "fs/promises";
import * as mobilenet from "@tensorflow-models/mobilenet";
import * as knnClassifier from "@tensorflow-models/knn-classifier";
import Tensorset from "tensorset/lib/Tensorset";
import * as tf from "@tensorflow/tfjs-node";
/**
 * Image classifier that feeds MobileNet embeddings into a KNN classifier.
 *
 * Instances are obtained through the static factories {@link ImageClassifier.create}
 * (empty classifier) or {@link ImageClassifier.load} (classifier restored from a
 * dataset file written by {@link ImageClassifier.save}); the constructor is private.
 */
class ImageClassifier {
  /** Self-reference, so the class is also reachable as the `default` property of the export. */
  static default = ImageClassifier;

  /**
   * Creates a classifier with an empty KNN dataset.
   *
   * @returns A ready-to-use classifier with no examples.
   * @throws Propagates any error raised while loading MobileNet.
   */
  static async create() {
    const classifier = knnClassifier.create();
    return new ImageClassifier(await mobilenet.load(), classifier);
  }

  /**
   * Creates a classifier whose KNN dataset is restored from a file.
   *
   * @param datasetPath Path of a UTF-8 file holding a serialized Tensorset.
   * @returns A classifier pre-populated with the stored examples.
   * @throws Propagates errors from reading the file, from parsing an invalid
   *         dataset, and from loading MobileNet.
   */
  static async load(datasetPath: string) {
    const classifier = knnClassifier.create();
    try {
      const dataset = await fsp.readFile(datasetPath, { encoding: 'utf-8' });
      const tensorset = await Tensorset.parse(dataset);
      classifier.setClassifierDataset(tensorset);
      return new ImageClassifier(await mobilenet.load(), classifier);
    } catch (error) {
      // The half-built classifier is unreachable once we rethrow; release any
      // dataset tensors it already holds instead of leaking them.
      classifier.dispose();
      throw error;
    }
  }

  private readonly mobilenet: mobilenet.MobileNet;
  private readonly classifier: knnClassifier.KNNClassifier;

  private constructor(mobilenet: mobilenet.MobileNet, classifier: knnClassifier.KNNClassifier) {
    this.mobilenet = mobilenet;
    this.classifier = classifier;
  }

  /**
   * Serializes the current KNN dataset and writes it to a file.
   *
   * @param datasetDestination Path of the file to write.
   * @throws Propagates errors from serializing the dataset or writing the file.
   */
  async save(datasetDestination: string) {
    const dataset = this.classifier.getClassifierDataset();
    const data = await Tensorset.stringify(dataset);
    await fsp.writeFile(datasetDestination, data);
  }

  /**
   * Adds one labelled training example to the KNN classifier.
   *
   * @param label Class label to associate with the image.
   * @param image Either a path to an image file or a Buffer with its encoded bytes.
   * @throws Propagates errors from reading the file, decoding the image, or
   *         adding the example to the classifier.
   */
  async addExample(label: string, image: string | Buffer) {
    const embedding = await this.embed(image);
    try {
      this.classifier.addExample(embedding, label);
    } finally {
      // Tensors are not garbage collected. The KNN classifier is expected to keep
      // its own copy of the example (per the library's documented behaviour), so
      // the embedding created here must be released explicitly.
      embedding.dispose();
    }
  }

  /**
   * Removes the given class from the KNN classifier.
   *
   * @param label Label of the class to clear.
   */
  dropClassifier(label: string) {
    this.classifier.clearClass(label);
  }

  /**
   * Predicts the class of an image using the KNN classifier.
   *
   * @param image Either a path to an image file or a Buffer with its encoded bytes.
   * @returns The prediction produced by the KNN classifier's `predictClass`.
   * @throws Propagates errors from reading the file, decoding the image, or
   *         running the prediction.
   */
  async predict(image: string | Buffer) {
    const embedding = await this.embed(image);
    try {
      // `return await` (not a bare `return`) so that the `finally` below only
      // runs once the asynchronous prediction has finished with the tensor.
      return await this.classifier.predictClass(embedding);
    } finally {
      embedding.dispose();
    }
  }

  /**
   * Decodes an image (3 channels) and runs it through MobileNet with the
   * embedding flag set, returning the resulting tensor.
   *
   * The caller owns the returned tensor and must dispose it; the intermediate
   * decoded-image tensor is disposed here.
   */
  private async embed(image: string | Buffer) {
    const imageData = image instanceof Buffer ? image : await fsp.readFile(image);
    const pixels = tf.node.decodeImage(new Uint8Array(imageData), 3);
    try {
      return this.mobilenet.infer(pixels, true);
    } finally {
      pixels.dispose();
    }
  }
}
// Export assignment: the class itself is the module's export. The class's static
// `default` property points back at the class, so it is also reachable as `.default`.
export = ImageClassifier;