This document walks through the code of a simple iOS mobile application that demonstrates image classification using the device camera.
The app is written entirely in Swift and uses the TensorFlow Lite Swift library for performing image classification.
Note: Objective-C developers should use the TensorFlow Lite Objective-C library.
We're now going to walk through the most important parts of the sample code.
The app's main view is represented by the ViewController class in ViewController.swift, which we extend with functionality from the CameraFeedManagerDelegate protocol to process frames from a camera feed. To run inference on a given frame, we implement the didOutput method, which is called whenever a frame is available from the camera.

Our implementation of didOutput includes a call to the runModel method of a ModelDataHandler instance. As we will see below, this class gives us access to the TensorFlow Lite Interpreter class for performing image classification.
extension ViewController: CameraFeedManagerDelegate {

  func didOutput(pixelBuffer: CVPixelBuffer) {
    let currentTimeMs = Date().timeIntervalSince1970 * 1000
    guard (currentTimeMs - previousInferenceTimeMs) >= delayBetweenInferencesMs else { return }
    previousInferenceTimeMs = currentTimeMs

    // Pass the pixel buffer to TensorFlow Lite to perform inference.
    result = modelDataHandler?.runModel(onFrame: pixelBuffer)

    // Display results by handing off to the InferenceViewController.
    DispatchQueue.main.async {
      let resolution = CGSize(width: CVPixelBufferGetWidth(pixelBuffer), height: CVPixelBufferGetHeight(pixelBuffer))
      self.inferenceViewController?.inferenceResult = self.result
      self.inferenceViewController?.resolution = resolution
      self.inferenceViewController?.tableView.reloadData()
    }
  }

  // ...
}
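The snippet above relies on a few supporting properties declared on ViewController that are not shown in the excerpt. The following is a minimal sketch of what they can look like; only the names come from the snippet, while the types, initial values, and the one-second throttling interval are assumptions:

  // Sketch of the supporting state assumed by the extension above. Only the names come
  // from the snippet; the types, initial values, and interval are assumptions.
  private var previousInferenceTimeMs: Double = 0
  private let delayBetweenInferencesMs: Double = 1000  // run inference at most roughly once per second

  // Handles preprocessing and inference (see ModelDataHandler below) and the latest result.
  private var modelDataHandler: ModelDataHandler?
  private var result: Result?  // the sample's own result type (assumed), not Swift's generic `Result`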
The Swift class ModelDataHandler, defined in ModelDataHandler.swift, handles all data preprocessing and makes calls to run inference on a given frame using the TensorFlow Lite Interpreter. It then formats the inferences obtained from invoking the Interpreter and returns the top N results for a successful inference.

The following sections show how this works.
The init method creates a new instance of the Interpreter and loads the specified model and labels files from the app's main bundle.
init?(modelFileInfo: FileInfo, labelsFileInfo: FileInfo, threadCount: Int = 1) {
  let modelFilename = modelFileInfo.name

  // Construct the path to the model file.
  guard let modelPath = Bundle.main.path(
    forResource: modelFilename,
    ofType: modelFileInfo.extension
  ) else {
    print("Failed to load the model file with name: \(modelFilename).")
    return nil
  }

  // Specify the options for the `Interpreter`.
  self.threadCount = threadCount
  var options = InterpreterOptions()
  options.threadCount = threadCount
  options.isErrorLoggingEnabled = true

  do {
    // Create the `Interpreter`.
    interpreter = try Interpreter(modelPath: modelPath, options: options)
  } catch let error {
    print("Failed to create the interpreter with error: \(error.localizedDescription)")
    return nil
  }

  // Load the classes listed in the labels file.
  loadLabels(fileInfo: labelsFileInfo)
}
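The loadLabels(fileInfo:) helper called at the end of init is not shown in the excerpt above. Here is a minimal sketch of what it can look like, assuming the labels file ships in the main bundle with one class name per line and that labels is a [String] property on ModelDataHandler:

// Sketch of a `loadLabels(fileInfo:)` helper: reads the labels file from the main
// bundle and stores one class name per line in the `labels` array.
private func loadLabels(fileInfo: FileInfo) {
  guard let path = Bundle.main.path(forResource: fileInfo.name, ofType: fileInfo.extension) else {
    print("Failed to find the labels file with name: \(fileInfo.name).")
    return
  }
  do {
    let contents = try String(contentsOfFile: path, encoding: .utf8)
    // Drop empty lines so trailing newlines do not produce blank labels.
    labels = contents.components(separatedBy: .newlines).filter { !$0.isEmpty }
  } catch {
    print("Failed to read the labels file with error: \(error.localizedDescription)")
  }
}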
The method runModel accepts a CVPixelBuffer of camera data, which can be obtained from the didOutput method defined in ViewController.

We crop the image to the size that the model was trained on, for example 224x224 for the MobileNet v1 model.
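The cropping helper itself is not part of the excerpts in this document; the cropped buffer is what appears as thumbnailPixelBuffer in the inference snippet further below. The following is a minimal sketch of one way to produce it with Core Image; the function name and approach are assumptions, not the sample's actual implementation:

import CoreImage
import CoreVideo

// Hypothetical helper: center-crops a camera frame to a square and scales it to the
// model's input size (e.g. 224x224). The name and the Core Image approach are assumptions.
func centerCroppedPixelBuffer(from buffer: CVPixelBuffer, to size: CGSize) -> CVPixelBuffer? {
  let sourceWidth = CGFloat(CVPixelBufferGetWidth(buffer))
  let sourceHeight = CGFloat(CVPixelBufferGetHeight(buffer))
  let side = min(sourceWidth, sourceHeight)

  // Crop the largest centered square from the frame, then move it to the origin.
  let cropRect = CGRect(x: (sourceWidth - side) / 2.0,
                        y: (sourceHeight - side) / 2.0,
                        width: side,
                        height: side)
  let cropped = CIImage(cvPixelBuffer: buffer)
    .cropped(to: cropRect)
    .transformed(by: CGAffineTransform(translationX: -cropRect.origin.x,
                                       y: -cropRect.origin.y))

  // Scale the square down to the model's input resolution.
  let scale = size.width / side
  let scaled = cropped.transformed(by: CGAffineTransform(scaleX: scale, y: scale))

  // Render the result into a new BGRA pixel buffer.
  // (In practice you would reuse a single CIContext rather than creating one per frame.)
  var output: CVPixelBuffer?
  CVPixelBufferCreate(kCFAllocatorDefault, Int(size.width), Int(size.height),
                      kCVPixelFormatType_32BGRA, nil, &output)
  guard let outputBuffer = output else { return nil }
  CIContext().render(scaled, to: outputBuffer)
  return outputBuffer
}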
The image buffer contains an encoded color for each pixel in BGRA format (where A represents alpha, or transparency). Our model expects the format to be RGB, so we use the following helper method to remove the alpha component from the image buffer to get the RGB data representation:
private let alphaComponent = (baseOffset: 4, moduloRemainder: 3)

private func rgbDataFromBuffer(
  _ buffer: CVPixelBuffer,
  byteCount: Int,
  isModelQuantized: Bool
) -> Data? {
  CVPixelBufferLockBaseAddress(buffer, .readOnly)
  defer { CVPixelBufferUnlockBaseAddress(buffer, .readOnly) }
  guard let mutableRawPointer = CVPixelBufferGetBaseAddress(buffer) else {
    return nil
  }

  let count = CVPixelBufferGetDataSize(buffer)
  let bufferData = Data(bytesNoCopy: mutableRawPointer, count: count, deallocator: .none)
  var rgbBytes = [UInt8](repeating: 0, count: byteCount)
  var index = 0

  for component in bufferData.enumerated() {
    let offset = component.offset
    let isAlphaComponent = (offset % alphaComponent.baseOffset) == alphaComponent.moduloRemainder
    guard !isAlphaComponent else { continue }
    rgbBytes[index] = component.element
    index += 1
  }

  if isModelQuantized { return Data(bytes: rgbBytes) }
  return Data(copyingBufferOf: rgbBytes.map { Float($0) / 255.0 })
}
Here's the code for getting the RGB data representation of the pixel buffer, copying that data to the input Tensor, and running inference by invoking the Interpreter:
let outputTensor: Tensor
do {
  // Allocate memory for the model's input `Tensor`s.
  try interpreter.allocateTensors()
  let inputTensor = try interpreter.input(at: 0)

  // Remove the alpha component from the image buffer to get the RGB data.
  guard let rgbData = rgbDataFromBuffer(
    thumbnailPixelBuffer,
    byteCount: batchSize * inputWidth * inputHeight * inputChannels,
    isModelQuantized: inputTensor.dataType == .uInt8
  ) else {
    print("Failed to convert the image buffer to RGB data.")
    return
  }

  // Copy the RGB data to the input `Tensor`.
  try interpreter.copy(rgbData, toInputAt: 0)

  // Run inference by invoking the `Interpreter`.
  try interpreter.invoke()

  // Get the output `Tensor` to process the inference results.
  outputTensor = try interpreter.output(at: 0)
} catch let error {
  print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
  return
}
If the model is quantized, the output Tensor contains one UInt8 value per class label. Dequantize the results so the values are floats ranging from 0.0 to 1.0, where each value represents the confidence that a label is present in the image:
guard let quantization = outputTensor.quantizationParameters else {
  print("No results returned because the quantization values for the output tensor are nil.")
  return
}

// Get the quantized results from the output tensor's `data` property.
let quantizedResults = [UInt8](outputTensor.data)

// Dequantize the results using the quantization values.
let results = quantizedResults.map {
  quantization.scale * Float(Int($0) - quantization.zeroPoint)
}
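For a float model, no dequantization step is needed: the output Tensor already holds Float32 confidence values. A hedged sketch of reading them directly (this branch is not shown in the excerpt above, and the sample may use its own helper for the conversion):

// Sketch for the float-model case: reinterpret the output tensor's bytes as an array
// of Float32 confidence values.
let results = outputTensor.data.withUnsafeBytes { rawBuffer in
  Array(rawBuffer.bindMemory(to: Float32.self))
}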
Next, the results are sorted to get the top N results (where N is resultCount):
// Create a zipped array of tuples [(labelIndex: Int, confidence: Float)].
let zippedResults = zip(labels.indices, results)

// Sort the zipped results by confidence value in descending order.
let sortedResults = zippedResults.sorted { $0.1 > $1.1 }.prefix(resultCount)

// Get the top N `Inference` results.
let topNInferences = sortedResults.map { result in Inference(confidence: result.1, label: labels[result.0]) }
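The Inference value created above is a simple container defined by the sample. A minimal sketch of its shape, assuming only the two fields that appear in the snippet:

// Sketch of the `Inference` value type, with just the fields used above.
struct Inference {
  let confidence: Float
  let label: String
}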
The file InferenceViewController.swift defines the app's UI. A UITableView is used to display the results.
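As a rough illustration (not the sample's actual implementation), the table view's data source can surface one row per inference. The inferences property on the result, the cell identifier, and the formatting below are assumptions; only inferenceResult comes from the earlier snippet:

import UIKit

// Rough sketch of a table view data source that shows one row per inference.
// `inferenceResult?.inferences`, the cell identifier, and the formatting are assumptions.
extension InferenceViewController: UITableViewDataSource {

  func tableView(_ tableView: UITableView, numberOfRowsInSection section: Int) -> Int {
    return inferenceResult?.inferences.count ?? 0
  }

  func tableView(_ tableView: UITableView, cellForRowAt indexPath: IndexPath) -> UITableViewCell {
    let cell = tableView.dequeueReusableCell(withIdentifier: "ResultCell", for: indexPath)
    if let inference = inferenceResult?.inferences[indexPath.row] {
      // Show the label and its confidence as a percentage.
      cell.textLabel?.text = inference.label
      cell.detailTextLabel?.text = String(format: "%.1f%%", inference.confidence * 100)
    }
    return cell
  }
}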