Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

173 support on device training #219

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
dfcc45a
makes metadata extraction work
treyfel May 13, 2022
bc82658
continue working on metadata
treyfel May 16, 2022
de9a555
Merge branch 'master' into 173-support-on-device-training
treyfel May 16, 2022
bd01e2f
continue working on metadata
treyfel May 17, 2022
b13207c
makes normalization work
treyfel May 18, 2022
5255867
makes model working with metadata
treyfel Jun 8, 2022
971ae6f
Show warning for invalid sensor tags
tfiedlerdev Jun 9, 2022
e757c93
Merge branch 'master' into 173-support-on-device-training
treyfel Jun 9, 2022
6ac6197
merge master into prediction screen
treyfel Jun 9, 2022
22d2d69
Merge branch '173-support-on-device-training' of https://github.com/S…
tfiedlerdev Jun 9, 2022
dfaf5ca
Fix bugs in InMemoryWindow.kt, start working on prediction
tfiedlerdev Jun 9, 2022
c64cb54
Fix bugs to make prediction work
tfiedlerdev Jun 10, 2022
7b19b87
Replace default model
tfiedlerdev Jun 13, 2022
748ef6f
updates gradle plugin and implements forwardFill
treyfel Jun 13, 2022
9798f39
Merge remote-tracking branch 'origin/173-support-on-device-training' …
treyfel Jun 13, 2022
d257c09
Replace default model
tfiedlerdev Jun 13, 2022
8647bff
Work on loading InMemoryWindow from Recording.kt instance
tfiedlerdev Jun 14, 2022
e6a6cf2
Implemented function to merge sensor recording files
tfiedlerdev Jun 15, 2022
f305283
Add class to windowize merged recording data files. Add tests for it …
tfiedlerdev Jun 16, 2022
a28eb53
Fix issues so that tests pass now
tfiedlerdev Jun 17, 2022
e1ae174
Partly satisfy linter + attempt to improve logging for instrumented CI
tfiedlerdev Jun 17, 2022
365c73c
Export test results
tfiedlerdev Jun 17, 2022
c3068c4
Update sdk version for instrumentation tests
tfiedlerdev Jun 17, 2022
929145b
Try fixing broken CI
tfiedlerdev Jun 17, 2022
ee13d97
Try fixing broken CI
tfiedlerdev Jun 17, 2022
ae21584
Try fixing broken CI
tfiedlerdev Jun 17, 2022
e1f625e
Work on training. (does not work yet)
tfiedlerdev Jun 21, 2022
cc40d23
Made training functional and added convenience classes
tfiedlerdev Jun 22, 2022
737ddd9
Improve convenience functions
tfiedlerdev Jun 23, 2022
5759151
Satisfy Linter
tfiedlerdev Jun 23, 2022
5a04405
Merge branch 'master' into 173-support-on-device-training
tfiedlerdev Jun 23, 2022
306d689
Fix merge conflicts
tfiedlerdev Jun 23, 2022
180c35d
implements in memory window test for forward fill
treyfel Jun 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ android {
configurations.all {
resolutionStrategy.force 'com.google.code.findbugs:jsr305:3.0.2'
}
sourceSets {
androidTest {
assets.srcDirs = ['src/main/assets', 'src/androidTest/assets', 'src/debug/assets/']
java.srcDirs = ['src/main/java', 'src/androidTest/java', 'src/debug/java']
}
}
}

dependencies {
Expand All @@ -83,11 +89,11 @@ dependencies {
implementation 'androidx.swiperefreshlayout:swiperefreshlayout:1.1.0'
implementation 'androidx.preference:preference-ktx:1.2.0'
implementation 'androidx.preference:preference-ktx:1.2.0'
implementation 'org.tensorflow:tensorflow-lite:2.7.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.3.1'
implementation 'org.tensorflow:tensorflow-lite-metadata:0.3.1'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.6.0'
implementation 'org.tensorflow:tensorflow-lite:2.8.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.4.0'
implementation 'org.tensorflow:tensorflow-lite-metadata:0.4.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.8.0'
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.8.0'

implementation 'androidx.camera:camera-lifecycle:1.0.2'
implementation 'com.google.mlkit:pose-detection-common:17.0.0'
Expand Down
33,998 changes: 33,998 additions & 0 deletions app/src/androidTest/assets/0_orhan_1652085453257.csv

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package sensors_in_paradise.sonar

import com.xsens.dot.android.sdk.events.XsensDotData
import org.junit.Assert
import org.junit.Test
import sensors_in_paradise.sonar.machine_learning.InMemoryWindow

class InMemoryWindowTest {
@Test
fun initializationTest() {
val features = arrayOf("Quat_Z_LF", "dq_W_LF", "dv[1]_LF").map { it.uppercase() }.toTypedArray()
val window = InMemoryWindow(features, 2)
assert(features.contentEquals(window.keys.toTypedArray()))
}
@Test
fun addDataTest() {
val features = arrayOf("Quat_Z_LF", "dq_W_RW", "dv[1]_LF")
val window = InMemoryWindow(features, 2)
for (feature in features) {
assert(window.needsFeature(feature))
}
val data = XsensDotData().apply {
quat = floatArrayOf(0f, 0f, 1f, 0f)
dq = doubleArrayOf(0.0, 0.0, 0.2, 23.0)
dv = doubleArrayOf(23.0, 0.0, 0.0)
}

window.appendSensorData("LF", data)
Assert.assertFalse(window.hasEnoughDataToCompileWindow())
window.appendSensorData("RW", data)
Assert.assertFalse(window.hasEnoughDataToCompileWindow())
data.sampleTimeFine += 1L

window.appendSensorData("LF", data)
Assert.assertFalse(window.hasEnoughDataToCompileWindow())
data.sampleTimeFine += 1L

window.appendSensorData("RW", data)
Assert.assertTrue(window.hasEnoughDataToCompileWindow())

window.compileWindow()

}
@Test
fun forwardFillDataTest(){
val features = arrayOf("Quat_W_LF", "Quat_X_LF", "Quat_Y_LF", "Quat_Z_LF")
val window = InMemoryWindow(features, 3)

val data0 = XsensDotData().apply {
sampleTimeFine = 0L
quat = floatArrayOf(0f, Float.NaN, 1f, 0f)
dq = doubleArrayOf(0.0, 0.0, 0.2, 23.0)
dv = doubleArrayOf(23.0, 0.0, Double.NaN)
quat = floatArrayOf(0f, Float.NaN, 1f, 1f)
}
val data1 = XsensDotData().apply {
sampleTimeFine = 1L
quat = floatArrayOf(0f, 1f, Float.NaN, 2f)
}
val data2 = XsensDotData().apply {
sampleTimeFine = 2L
quat = floatArrayOf(0f, Float.NaN, 0f, 0f)
}

window.appendSensorData("LF", data0)
Assert.assertFalse(windowHasNan(window))
window.appendSensorData("LF", data1)
Assert.assertFalse(windowHasNan(window))
window.appendSensorData("LF", data2)
Assert.assertFalse(windowHasNan(window))
Assert.assertTrue(window["QUAT_W_LF"]!! == arrayListOf(Pair(0L, 0f), Pair(1L, 0f), Pair(2L, 0f)))
Assert.assertTrue(window["QUAT_X_LF"]!! == arrayListOf(Pair(0L, 0f), Pair(1L, 1f), Pair(2L, 1f)))
Assert.assertTrue(window["QUAT_Y_LF"]!! == arrayListOf(Pair(0L, 1f), Pair(1L, 1f), Pair(2L, 0f)))
Assert.assertTrue(window["QUAT_Z_LF"]!! == arrayListOf(Pair(0L, 1f), Pair(1L, 2f), Pair(2L, 0f)))
}

private fun windowHasNan(window: InMemoryWindow): Boolean {
return window.values.any { arrayList -> arrayList.any { it.second.isNaN() } }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package sensors_in_paradise.sonar

import android.util.Log
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Before
import org.junit.Test
import sensors_in_paradise.sonar.machine_learning.DataSet
import sensors_in_paradise.sonar.machine_learning.TFLiteModel
import sensors_in_paradise.sonar.screen_recording.RecordingDataFile
import sensors_in_paradise.sonar.use_cases.UseCase
import java.io.File

class RecordingDataFileTest {
private val assetContext = InstrumentationRegistry.getInstrumentation().context
private val appContext = InstrumentationRegistry.getInstrumentation().targetContext
private val recordingFile = File(appContext.cacheDir, "recordingData.csv")
private val modelFile = File(appContext.cacheDir, "model_test.tflite")

@Before
fun init() {
UseCase.extractFileFromAssets(assetContext, "0_orhan_1652085453257.csv", recordingFile)
UseCase.extractFileFromAssets(assetContext, "resnet_model_22_06_2022.tflite", modelFile)
}

@Test
fun windowizeTest() {
val features =
arrayOf("Quat_Z_LF", "dq_W_LF", "dv[1]_LF").map { it.uppercase() }.toTypedArray()
val windowSize = 90
val data = RecordingDataFile(recordingFile)
val startIndexes = data.getWindowStartIndexes(windowSize)

assert(startIndexes.size > 0)
var i = 1
for (startIndex in startIndexes) {
Log.d(
"RecordingDataFileTest-windowizeTest",
"Working on window $i of ${startIndexes.size}"
)
val (window, activity) = data.getWindowAtIndex(startIndex, windowSize, features)
assert(window.size == features.size)

// Test if compiling the window into a float buffer runs through
window.compileWindow()
i++
}
}

@Test
fun predictionPipelineTest() {
val model = TFLiteModel(modelFile)

val features = model.getFeaturesToPredict().map { it.uppercase() }.toTypedArray()

val data = RecordingDataFile(recordingFile)
val startIndexes = data.getWindowStartIndexes(model.windowSize)
var correctPredictions = 0
assert(startIndexes.size > 0)

var i = 1
for (startIndex in startIndexes) {
val (window, activity) = data.getWindowAtIndex(startIndex, model.windowSize, features)
assert(window.size == features.size)

// Test if compiling the window into a float buffer runs through
val input = window.compileWindowToArray()
val prediction = model.infer(arrayOf(input))
val predictedLabel = model.convertPredictionToLabel(prediction[0])
if (predictedLabel == activity) {
correctPredictions++
}
Log.d(
"RecordingDataFileTest-predictionPipelineTest",
"Working on window $i of ${startIndexes.size}"
)
i++
}
val accuracy = (correctPredictions * 100 / startIndexes.size)
Log.d("RecordingDataFileTest", "Prediction accuracy on the example recording: $accuracy%")
}

@Test
fun trainingPipelineTest() {
val model = TFLiteModel(modelFile)
val data = RecordingDataFile(recordingFile)
val dataSet = DataSet().apply { add(data) }

val batches = dataSet.convertToBatches(7, model.windowSize, progressCallback = { progress ->
Log.d(
"RecordingDataFileTest-trainingPipelineTest",
"Batching dataset: $progress%"
)
})

val accuracyBeforeTraining = model.evaluate(batches) { batch, window ->
Log.d(
"RecordingDataFileTest",
"Evaluating model before training. Batch: $batch%, Window Progress: $window%"
)
}
val losses = model.train(batches, 5) { epoch, batch, window ->
Log.d(
"RecordingDataFileTest",
"Training model. Epoch $epoch%, batch $batch%, window $window%"
)
}
Log.d(
"RecordingDataFileTest",
"Losses during training epochs: ${losses.joinToString()}"
)
val accuracyAfterTraining = model.evaluate(batches) { batch, window ->
Log.d(
"RecordingDataFileTest",
"Evaluating model after training. Batch: $batch%, Window Progress: $window%"
)
}
Log.d(
"RecordingDataFileTest",
"Prediction accuracy on the example recording before training: " +
"$accuracyBeforeTraining and after training: $accuracyAfterTraining%"
)
}
}
46 changes: 40 additions & 6 deletions app/src/main/java/sensors_in_paradise/sonar/GlobalValues.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ package sensors_in_paradise.sonar
import android.Manifest
import android.content.Context
import android.os.Build
import android.util.Log
import com.xsens.dot.android.sdk.models.XsensDotPayload
import java.io.BufferedReader
import java.io.File
import java.io.FileReader
import java.net.NetworkInterface
import java.net.SocketException
import java.util.*
import kotlin.collections.ArrayList

Expand All @@ -17,10 +22,6 @@ class GlobalValues private constructor() {
const val METADATA_JSON_FILENAME = "metadata.json"
const val MEASUREMENT_MODE = XsensDotPayload.PAYLOAD_TYPE_CUSTOM_MODE_4

fun getSensorRecordingsBaseDir(context: Context): File {
return context.getExternalFilesDir(null) ?: context.dataDir
}

fun getSensorRecordingsTempDir(context: Context): File {
return context.dataDir.resolve("temp")
}
Expand Down Expand Up @@ -52,8 +53,6 @@ class GlobalValues private constructor() {
return result
}

val sensorTagPrefixes = listOf("LF", "LW", "ST", "RW", "RF")

fun formatTag(tagPrefix: String, deviceSetKey: String): String {
return "$tagPrefix-$deviceSetKey"
}
Expand All @@ -66,6 +65,32 @@ class GlobalValues private constructor() {

return minutes.toString().padStart(2, '0') + ":" + seconds.toString().padStart(2, '0')
}

fun getMacAddress(): String {
try {
val all = Collections.list(NetworkInterface.getNetworkInterfaces())
for (nif in all) {
if (!nif.name.equals("wlan0", ignoreCase=true)) continue

val macBytes = nif.hardwareAddress ?: return ""

val res1 = StringBuilder()
for (b in macBytes) {
res1.append(String.format(Locale.US, "%02X:", b))
}

if (res1.isNotEmpty()) {
res1.deleteCharAt(res1.length - 1)
}
return res1.toString()
}
} catch (ex: SocketException) {
ex.message?.let { Log.e("GlobalValues", it) }
}

return "02:00:00:00:00:00"
}

private val fileEmojiMap = mapOf(
"mp4" to "\uD83C\uDF9E️",
"json" to "\uD83D\uDCD8",
Expand All @@ -80,5 +105,14 @@ class GlobalValues private constructor() {
val extension = name.substring(name.lastIndexOf(".") + 1)
return fileEmojiMap[extension] ?: "\uD83D\uDCC4"
}
@Throws(NumberFormatException::class)
fun getCSVHeaderAwareFileReader(inputFile: File): BufferedReader {
val fileReader = BufferedReader(FileReader(inputFile))
var line = fileReader.readLine()
while (line != "") {
line = fileReader.readLine()
}
return fileReader
}
}
}
Loading