
Commit

Investigating why the input to the model is incorrect, resulting in an error on initialization.
lucasoskorep committed Jun 22, 2022
1 parent dc4bf39 commit ebfbfb5
Showing 10 changed files with 340 additions and 17 deletions.
38 changes: 38 additions & 0 deletions install_tflite.bat
@@ -0,0 +1,38 @@
@echo off
setlocal enableextensions

cd %~dp0

set TF_VERSION=2.5
set URL=https://github.com/am15h/tflite_flutter_plugin/releases/download/
set TAG=tf_%TF_VERSION%

set ANDROID_DIR=android\app\src\main\jniLibs\
set ANDROID_LIB=libtensorflowlite_c.so

set ARM_DELEGATE=libtensorflowlite_c_arm_delegate.so
set ARM_64_DELEGATE=libtensorflowlite_c_arm64_delegate.so
set ARM=libtensorflowlite_c_arm.so
set ARM_64=libtensorflowlite_c_arm64.so
set X86=libtensorflowlite_c_x86_delegate.so
set X86_64=libtensorflowlite_c_x86_64_delegate.so

SET /A d = 0

:GETOPT
if /I "%1"=="-d" SET /A d = 1

SETLOCAL
if %d%==1 CALL :Download %ARM_DELEGATE% armeabi-v7a
if %d%==1 CALL :Download %ARM_64_DELEGATE% arm64-v8a
if %d%==0 CALL :Download %ARM% armeabi-v7a
if %d%==0 CALL :Download %ARM_64% arm64-v8a
CALL :Download %X86% x86
CALL :Download %X86_64% x86_64
EXIT /B %ERRORLEVEL%

:Download
curl -L -o %~1 %URL%%TAG%/%~1
mkdir %ANDROID_DIR%%~2\
move /-Y %~1 %ANDROID_DIR%%~2\%ANDROID_LIB%
EXIT /B 0
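
Run from the project root; a hypothetical invocation (per the variables above, the -d flag switches the ARM downloads to the *_delegate.so builds):

    install_tflite.bat -d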
52 changes: 52 additions & 0 deletions install_tflite.sh
@@ -0,0 +1,52 @@
#!/usr/bin/env bash

cd "$(dirname "$(readlink -f "$0")")"

# Available versions
# 2.5, 2.4.1

TF_VERSION=2.5
URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/"
TAG="tf_$TF_VERSION"

ANDROID_DIR="android/app/src/main/jniLibs/"
ANDROID_LIB="libtensorflowlite_c.so"

ARM_DELEGATE="libtensorflowlite_c_arm_delegate.so"
ARM_64_DELEGATE="libtensorflowlite_c_arm64_delegate.so"
ARM="libtensorflowlite_c_arm.so"
ARM_64="libtensorflowlite_c_arm64.so"
X86="libtensorflowlite_c_x86_delegate.so"
X86_64="libtensorflowlite_c_x86_64_delegate.so"

delegate=0

while getopts "d" OPTION
do
case $OPTION in
d) delegate=1;;
esac
done


download () {
wget "${URL}${TAG}/$1" -O "$1"
mkdir -p "${ANDROID_DIR}$2/"
mv "$1" "${ANDROID_DIR}$2/${ANDROID_LIB}"
}

if [ ${delegate} -eq 1 ]
then

download ${ARM_DELEGATE} "armeabi-v7a"
download ${ARM_64_DELEGATE} "arm64-v8a"

else

download ${ARM} "armeabi-v7a"
download ${ARM_64} "arm64-v8a"

fi

download ${X86} "x86"
download ${X86_64} "x86_64"
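
Likewise for the shell script; a minimal usage sketch, assuming it is run from the repository root:

    chmod +x install_tflite.sh
    ./install_tflite.sh        # plain builds
    ./install_tflite.sh -d     # *_delegate.so builds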
Empty file removed lib/classifier.dart
2 changes: 1 addition & 1 deletion lib/main.dart
@@ -1,5 +1,5 @@
import 'package:flutter/material.dart';
-import 'package:tensordex_mobile/ui/home.dart';
+import 'package:tensordex_mobile/ui/tensordex_home.dart';
import 'package:tensordex_mobile/utils/logger.dart';

Future<void> main() async {
160 changes: 160 additions & 0 deletions lib/tflite/classifier.dart
@@ -0,0 +1,160 @@
import 'dart:math';
import 'dart:ui';

import 'package:collection/collection.dart';
import 'package:image/image.dart' as image_lib;
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

import '../utils/logger.dart';
import '../utils/recognition.dart';
import '../utils/stats.dart';

/// Classifier
class Classifier {
static const String MODEL_FILE_NAME = "detect.tflite";
static const String LABEL_FILE_NAME = "labelmap.txt";

/// Input size of image (height = width = 224)
static const int INPUT_SIZE = 224;

/// Result score threshold
static const double THRESHOLD = 0.5;

/// [ImageProcessor] used to pre-process the image
ImageProcessor? imageProcessor;

/// Padding the image to transform into square
// int padSize = 0;
/// Instance of Interpreter
late Interpreter _interpreter;

late TensorBuffer _outputBuffer;
late TensorProcessor _probabilityProcessor;

/// Labels file loaded as list
late List<String> _labels;

/// Number of results to show
static const int NUM_RESULTS = 10;

Classifier({
Interpreter? interpreter,
List<String>? labels,
}) {
loadModel(interpreter: interpreter);
loadLabels(labels: labels);
}

/// Loads interpreter from asset
Future<void> loadModel({Interpreter? interpreter}) async {
try {
_interpreter = interpreter ??
await Interpreter.fromAsset(
MODEL_FILE_NAME,
options: InterpreterOptions()..threads = 4,
);
var outputTensor = _interpreter.getOutputTensor(0);
var outputShape = outputTensor.shape;
var outputType = outputTensor.type;

var inputTensor = _interpreter.getInputTensor(0);
var inputShape = inputTensor.shape;
var inputType = inputTensor.type;

_outputBuffer = TensorBuffer.createFixedSize(outputShape, outputType);
_probabilityProcessor =
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
} catch (e) {
logger.e("Error while creating interpreter: ", e);
}
}
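
// Editorial assumption: the constructor fires loadModel()/loadLabels()
// without awaiting them, so predict() can run before _interpreter,
// _outputBuffer, or _labels are assigned; the first use would then throw a
// LateInitializationError, one plausible source of the initialization error
// this commit is investigating. A sketch of an awaited alternative
// (Classifier._() is a hypothetical private constructor, not in this file):
//
//   static Future<Classifier> create({Interpreter? interpreter}) async {
//     final c = Classifier._();
//     await c.loadModel(interpreter: interpreter);
//     await c.loadLabels();
//     return c;
//   }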

/// Loads labels from assets
Future<void> loadLabels({List<String>? labels}) async {
try {
_labels = labels ?? await FileUtil.loadLabels("assets/labels.txt");
} catch (e) {
logger.e("Error while loading labels: $e");
}
}

/// Pre-process the image: resize to INPUT_SIZE x INPUT_SIZE and normalize
/// pixels from [0, 255] to [-1, 1] via (x - 127.5) / 127.5.
TensorImage? getProcessedImage(TensorImage inputImage) {
// padSize = max(inputImage.height, inputImage.width);
imageProcessor ??= ImageProcessorBuilder()
// .add(ResizeWithCropOrPadOp(padSize, padSize))
.add(ResizeOp(INPUT_SIZE, INPUT_SIZE, ResizeMethod.BILINEAR))
.add(NormalizeOp(127.5, 127.5))
.build();
return imageProcessor?.process(inputImage);
}

/// Runs object detection on the input image
Map<String, dynamic>? predict(image_lib.Image image) {
logger.i(labels);
var predictStartTime = DateTime.now().millisecondsSinceEpoch;
if (_interpreter == null) {
logger.e("Interpreter not initialized");
return null;
}
var preProcessStart = DateTime.now().millisecondsSinceEpoch;
// Create TensorImage from image
// Pre-process TensorImage
var procImage = getProcessedImage(TensorImage.fromImage(image));

var preProcessElapsedTime =
DateTime.now().millisecondsSinceEpoch - preProcessStart;
if (procImage != null) {
logger.i("Sending image to ML");
logger.i(procImage.buffer.asFloat32List());
logger.i(procImage.width);
logger.i(procImage.height);
logger.i(procImage.tensorBuffer.shape);
logger.i(procImage.tensorBuffer.isDynamic);

// Run inference and time just the interpreter call.
var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
_interpreter.run(procImage.buffer, _outputBuffer.getBuffer());
var inferenceTimeElapsed =
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;

Map<String, double> labeledProb = TensorLabel.fromList(
labels, _probabilityProcessor.process(_outputBuffer))
.getMapWithFloatValue();
final pred = getTopProbability(labeledProb);
Recognition rec = Recognition(1, pred.key, pred.value);
var predictElapsedTime = DateTime.now().millisecondsSinceEpoch - predictStartTime;
return {
"recognitions": rec,
// Assumed Stats field order: total predict time, total elapsed time,
// inference time, pre-processing time.
"stats": Stats(predictElapsedTime, predictElapsedTime, inferenceTimeElapsed, preProcessElapsedTime),
};
} else {
return null;
}
}

/// Gets the interpreter instance
Interpreter get interpreter => _interpreter;

/// Gets the loaded labels
List<String> get labels => _labels;
}

MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) {
var pq = PriorityQueue<MapEntry<String, double>>(compare);
pq.addAll(labeledProb.entries);

return pq.first;
}

int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {
if (e1.value > e2.value) {
return -1;
} else if (e1.value == e2.value) {
return 0;
} else {
return 1;
}
}
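
For context, a minimal call-site sketch for this class. Assumptions: the asset name matches what poke_view.dart below actually loads, predict() receives a frame already decoded with package:image, and Recognition is the ../utils type constructed as (id, label, score) above.

    final interpreter = await Interpreter.fromAsset('efficientnet_v2s.tflite');
    final classifier = Classifier(interpreter: interpreter);
    final output = classifier.predict(frame); // frame: image_lib.Image
    final top = output?['recognitions'] as Recognition?;

As an aside, getTopProbability builds a whole PriorityQueue only to read its first element; a plain reduce over the entries returns the same max-probability entry:

    MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) =>
        labeledProb.entries.reduce((a, b) => a.value >= b.value ? a : b);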
37 changes: 36 additions & 1 deletion lib/ui/poke_view.dart
@@ -2,6 +2,9 @@ import 'dart:isolate';

import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/tflite/classifier.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tensordex_mobile/utils/image_utils.dart';

import '../utils/logger.dart';
import '../utils/recognition.dart';
@@ -30,10 +33,13 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {

/// Controller
late CameraController cameraController;
Interpreter? interp;

/// true when inference is ongoing
bool predicting = false;

late Classifier classy;

// /// Instance of [Classifier]
// Classifier classifier;
//
@@ -56,9 +62,28 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
// Camera initialization
initializeCamera();

// final gpuDelegateV2 = GpuDelegateV2(
// options: GpuDelegateOptionsV2(
// isPrecisionLossAllowed: false,
// inferencePreference: TfLiteGpuInferenceUsage.fastSingleAnswer,
// inferencePriority1: TfLiteGpuInferencePriority.minLatency,
// inferencePriority2: TfLiteGpuInferencePriority.auto,
// inferencePriority3: TfLiteGpuInferencePriority.auto,
// ));


logger.e("CREATING THE INTERPRETOR");
var interpreterOptions = InterpreterOptions();//..addDelegate(gpuDelegateV2);
interp = await Interpreter.fromAsset('efficientnet_v2s.tflite',
options: interpreterOptions);
logger.e("CREATING THE INTERPRETOR");

classy = Classifier(interpreter: interp);
logger.i(interp?.getOutputTensors());
// Create an instance of classifier to load model and labels
// classifier = Classifier();


// Initially predicting = false
predicting = false;
}
@@ -94,7 +119,7 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
@override
Widget build(BuildContext context) {
// Return empty container while the camera is not initialized
-    if (!cameraController.value.isInitialized || cameraController == null) {
+    if (!cameraController.value.isInitialized) {
return Container();
}

@@ -114,6 +139,16 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
predicting = true;
});
logger.i("RECIEVED IMAGE");
logger.i(cameraImage.format.group);
logger.i(cameraImage);
var converted = ImageUtils.convertCameraImage(cameraImage);
if (converted != null) {
var result = classy.predict(converted);
logger.e("PREDICTED IMAGE");
logger.i(result);
}
// logger.i(cameraImage);
// logger.i(cameraImage.height);
// logger.i(cameraImage.width);
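
The hunk above sets predicting = true but its reset falls outside the excerpt; a hedged sketch of the full frame handler, assuming the flag is cleared after inference so later frames are processed (the early-return guard is an assumption, not shown in this diff):

    void onLatestImageAvailable(CameraImage cameraImage) {
      if (predicting) return; // drop frames while a prediction is in flight
      setState(() => predicting = true);
      var converted = ImageUtils.convertCameraImage(cameraImage);
      if (converted != null) {
        logger.i(classy.predict(converted));
      }
      setState(() => predicting = false);
    }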
33 changes: 33 additions & 0 deletions lib/ui/results_view.dart
@@ -0,0 +1,33 @@
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/ui/poke_view.dart';
import 'package:tensordex_mobile/utils/recognition.dart';

import '../utils/logger.dart';

/// Displays classification results; [CameraView] sends each frame for inference.
class ResultsView extends StatefulWidget {
/// Constructor
const ResultsView({Key? key}) : super(key: key);

void setResults(Recognition results) {
logger.i("RESULTS IN THE RESULT VIEW");
}

@override
State<ResultsView> createState() => _ResultsViewState();
}

class _ResultsViewState extends State<ResultsView> {

@override
void initState() {
super.initState();
}

@override
Widget build(BuildContext context) {
return Text("data");
}
}
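
setResults currently only logs; one way to surface results in the UI, sketched under the assumptions that Recognition exposes a label getter and that setResults moves onto the State so it can call setState (the _results field below is illustrative, not part of this commit):

    class _ResultsViewState extends State<ResultsView> {
      Recognition? _results;

      void setResults(Recognition results) {
        setState(() => _results = results);
      }

      @override
      Widget build(BuildContext context) => Text(_results?.label ?? 'data');
    }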
