-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
investigating why the input to the model is incorrect resulting in er…
…ror on initialization.
- Loading branch information
1 parent
dc4bf39
commit ebfbfb5
Showing
10 changed files
with
340 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
@echo off | ||
setlocal enableextensions | ||
|
||
cd %~dp0 | ||
|
||
set TF_VERSION=2.5 | ||
set URL=https://github.com/am15h/tflite_flutter_plugin/releases/download/ | ||
set TAG=tf_%TF_VERSION% | ||
|
||
set ANDROID_DIR=android\app\src\main\jniLibs\ | ||
set ANDROID_LIB=libtensorflowlite_c.so | ||
|
||
set ARM_DELEGATE=libtensorflowlite_c_arm_delegate.so | ||
set ARM_64_DELEGATE=libtensorflowlite_c_arm64_delegate.so | ||
set ARM=libtensorflowlite_c_arm.so | ||
set ARM_64=libtensorflowlite_c_arm64.so | ||
set X86=libtensorflowlite_c_x86_delegate.so | ||
set X86_64=libtensorflowlite_c_x86_64_delegate.so | ||
|
||
SET /A d = 0 | ||
|
||
:GETOPT | ||
if /I "%1"=="-d" SET /A d = 1 | ||
|
||
SETLOCAL | ||
if %d%==1 CALL :Download %ARM_DELEGATE% armeabi-v7a | ||
if %d%==1 CALL :Download %ARM_64_DELEGATE% arm64-v8a | ||
if %d%==0 CALL :Download %ARM% armeabi-v7a | ||
if %d%==0 CALL :Download %ARM_64% arm64-v8a | ||
CALL :Download %X86% x86 | ||
CALL :Download %X86_64% x86_64 | ||
EXIT /B %ERRORLEVEL% | ||
|
||
:Download | ||
curl -L -o %~1 %URL%%TAG%/%~1 | ||
mkdir %ANDROID_DIR%%~2\ | ||
move /-Y %~1 %ANDROID_DIR%%~2\%ANDROID_LIB% | ||
EXIT /B 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#!/usr/bin/env bash | ||
|
||
cd "$(dirname "$(readlink -f "$0")")" | ||
|
||
# Available versions | ||
# 2.5, 2.4.1 | ||
|
||
TF_VERSION=2.5 | ||
URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/" | ||
TAG="tf_$TF_VERSION" | ||
|
||
ANDROID_DIR="android/app/src/main/jniLibs/" | ||
ANDROID_LIB="libtensorflowlite_c.so" | ||
|
||
ARM_DELEGATE="libtensorflowlite_c_arm_delegate.so" | ||
ARM_64_DELEGATE="libtensorflowlite_c_arm64_delegate.so" | ||
ARM="libtensorflowlite_c_arm.so" | ||
ARM_64="libtensorflowlite_c_arm64.so" | ||
X86="libtensorflowlite_c_x86_delegate.so" | ||
X86_64="libtensorflowlite_c_x86_64_delegate.so" | ||
|
||
delegate=0 | ||
|
||
while getopts "d" OPTION | ||
do | ||
case $OPTION in | ||
d) delegate=1;; | ||
esac | ||
done | ||
|
||
|
||
download () { | ||
wget "${URL}${TAG}/$1" -O "$1" | ||
mkdir -p "${ANDROID_DIR}$2/" | ||
mv $1 "${ANDROID_DIR}$2/${ANDROID_LIB}" | ||
} | ||
|
||
if [ ${delegate} -eq 1 ] | ||
then | ||
|
||
download ${ARM_DELEGATE} "armeabi-v7a" | ||
download ${ARM_64_DELEGATE} "arm64-v8a" | ||
|
||
else | ||
|
||
download ${ARM} "armeabi-v7a" | ||
download ${ARM_64} "arm64-v8a" | ||
|
||
fi | ||
|
||
download ${X86} "x86" | ||
download ${X86_64} "x86_64" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import 'dart:math'; | ||
import 'dart:ui'; | ||
|
||
import 'package:collection/collection.dart'; | ||
import 'package:image/image.dart' as image_lib; | ||
import 'package:tflite_flutter/tflite_flutter.dart'; | ||
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart'; | ||
|
||
import '../utils/logger.dart'; | ||
import '../utils/recognition.dart'; | ||
import '../utils/stats.dart'; | ||
|
||
/// Classifier | ||
class Classifier { | ||
static const String MODEL_FILE_NAME = "detect.tflite"; | ||
static const String LABEL_FILE_NAME = "labelmap.txt"; | ||
|
||
/// Input size of image (height = width = 300) | ||
static const int INPUT_SIZE = 224; | ||
|
||
/// Result score threshold | ||
static const double THRESHOLD = 0.5; | ||
|
||
/// [ImageProcessor] used to pre-process the image | ||
ImageProcessor? imageProcessor; | ||
|
||
/// Padding the image to transform into square | ||
// int padSize = 0; | ||
/// Instance of Interpreter | ||
late Interpreter _interpreter; | ||
|
||
late TensorBuffer _outputBuffer; | ||
late var _probabilityProcessor; | ||
|
||
/// Labels file loaded as list | ||
late List<String> _labels; | ||
|
||
/// Number of results to show | ||
static const int NUM_RESULTS = 10; | ||
|
||
Classifier({ | ||
Interpreter? interpreter, | ||
List<String>? labels, | ||
}) { | ||
loadModel(interpreter: interpreter); | ||
loadLabels(labels: labels); | ||
} | ||
|
||
/// Loads interpreter from asset | ||
void loadModel({Interpreter? interpreter}) async { | ||
try { | ||
_interpreter = interpreter ?? | ||
await Interpreter.fromAsset( | ||
MODEL_FILE_NAME, | ||
options: InterpreterOptions()..threads = 4, | ||
); | ||
var outputTensor = _interpreter.getOutputTensor(0); | ||
var outputShape = outputTensor.shape; | ||
var outputType = outputTensor.type; | ||
|
||
var inputTensor = _interpreter.getInputTensor(0); | ||
var intputShape = inputTensor.shape; | ||
var intputType = inputTensor.type; | ||
|
||
_outputBuffer = TensorBuffer.createFixedSize(outputShape, outputType); | ||
_probabilityProcessor = | ||
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build(); | ||
} catch (e) { | ||
logger.e("Error while creating interpreter: ", e); | ||
} | ||
} | ||
|
||
/// Loads labels from assets | ||
void loadLabels({List<String>? labels}) async { | ||
try { | ||
_labels = labels ?? await FileUtil.loadLabels("assets/labels.txt"); | ||
} catch (e) { | ||
logger.e("Error while loading labels: $e"); | ||
} | ||
} | ||
|
||
/// Pre-process the image | ||
TensorImage? getProcessedImage(TensorImage inputImage) { | ||
// padSize = max(inputImage.height, inputImage.width); | ||
imageProcessor ??= ImageProcessorBuilder() | ||
// .add(ResizeWithCropOrPadOp(padSize, padSize)) | ||
.add(ResizeOp(INPUT_SIZE, INPUT_SIZE, ResizeMethod.BILINEAR)) | ||
.add(NormalizeOp(127.5, 127.5)) | ||
.build(); | ||
return imageProcessor?.process(inputImage); | ||
} | ||
|
||
/// Runs object detection on the input image | ||
Map<String, dynamic>? predict(image_lib.Image image) { | ||
logger.i(labels); | ||
var predictStartTime = DateTime.now().millisecondsSinceEpoch; | ||
if (_interpreter == null) { | ||
logger.e("Interpreter not initialized"); | ||
return null; | ||
} | ||
var preProcessStart = DateTime.now().millisecondsSinceEpoch; | ||
// Create TensorImage from image | ||
// Pre-process TensorImage | ||
var procImage = getProcessedImage(TensorImage.fromImage(image)); | ||
|
||
var preProcessElapsedTime = | ||
DateTime.now().millisecondsSinceEpoch - preProcessStart; | ||
if (procImage != null) { | ||
var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch; | ||
// run inference | ||
var inferenceTimeElapsed = | ||
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart; | ||
|
||
logger.i("Sending image to ML"); | ||
|
||
logger.i(procImage.buffer.asFloat32List()); | ||
logger.i(procImage.width); | ||
logger.i(procImage.height); | ||
logger.i(procImage.tensorBuffer.shape); | ||
logger.i(procImage.tensorBuffer.isDynamic); | ||
_interpreter.run(procImage.buffer, _outputBuffer.getBuffer()); | ||
|
||
Map<String, double> labeledProb = TensorLabel.fromList( | ||
labels, _probabilityProcessor.process(_outputBuffer)) | ||
.getMapWithFloatValue(); | ||
final pred = getTopProbability(labeledProb); | ||
Recognition rec = Recognition(1, pred.key, pred.value); | ||
var predictElapsedTime = DateTime.now().millisecondsSinceEpoch - predictStartTime; | ||
return { | ||
"recognitions": rec, | ||
"stats": Stats(predictElapsedTime, predictElapsedTime, predictElapsedTime, predictElapsedTime), | ||
}; | ||
} else { | ||
return null; | ||
} | ||
} | ||
|
||
/// Gets the interpreter instance | ||
Interpreter get interpreter => _interpreter; | ||
|
||
/// Gets the loaded labels | ||
List<String> get labels => _labels; | ||
} | ||
|
||
MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) { | ||
var pq = PriorityQueue<MapEntry<String, double>>(compare); | ||
pq.addAll(labeledProb.entries); | ||
|
||
return pq.first; | ||
} | ||
|
||
int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) { | ||
if (e1.value > e2.value) { | ||
return -1; | ||
} else if (e1.value == e2.value) { | ||
return 0; | ||
} else { | ||
return 1; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import 'package:flutter/material.dart'; | ||
import 'package:tensordex_mobile/ui/poke_view.dart'; | ||
import 'package:tensordex_mobile/utils/recognition.dart'; | ||
|
||
import '../utils/logger.dart'; | ||
|
||
/// [CameraView] sends each frame for inference | ||
class ResultsView extends StatefulWidget { | ||
|
||
/// Constructor | ||
const ResultsView({Key? key}) : super(key: key); | ||
|
||
|
||
void setResults(Recognition results){ | ||
logger.i("RESULTS IN THE RESULT VIEW"); | ||
} | ||
|
||
@override | ||
State<ResultsView> createState() => _ResultsViewState(); | ||
} | ||
|
||
class _ResultsViewState extends State<ResultsView> { | ||
|
||
@override | ||
void initState() { | ||
super.initState(); | ||
} | ||
|
||
@override | ||
Widget build(BuildContext context) { | ||
return Text("data"); | ||
} | ||
} |
Oops, something went wrong.