Skip to content

Commit

Permalink
Add Koltin and Java API for Kokoro TTS models (#1728)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jan 17, 2025
1 parent 3a1de0b commit 99cef41
Show file tree
Hide file tree
Showing 18 changed files with 549 additions and 40 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/run-java-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,12 @@ jobs:
run: |
cd ./java-api-examples
./run-non-streaming-tts-kokoro-en.sh
./run-non-streaming-tts-matcha-zh.sh
./run-non-streaming-tts-matcha-en.sh
ls -lh
rm -rf kokoro-en-*
rm -rf matcha-icefall-*
rm hifigan_v2.onnx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class MainActivity : AppCompatActivity() {
var modelName: String?
var acousticModelName: String?
var vocoder: String?
var voices: String?
var ruleFsts: String?
var ruleFars: String?
var lexicon: String?
Expand All @@ -205,6 +206,10 @@ class MainActivity : AppCompatActivity() {
vocoder = null
// Matcha -- end

// For Kokoro -- begin
voices = null
// For Kokoro -- end


modelDir = null
ruleFsts = null
Expand Down Expand Up @@ -269,6 +274,13 @@ class MainActivity : AppCompatActivity() {
// vocoder = "hifigan_v2.onnx"
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"

// Example 9
// kokoro-en-v0_19
// modelDir = "kokoro-en-v0_19"
// modelName = "model.onnx"
// voices = "voices.bin"
// dataDir = "kokoro-en-v0_19/espeak-ng-data"

if (dataDir != null) {
val newDir = copyDataDir(dataDir!!)
dataDir = "$newDir/$dataDir"
Expand All @@ -285,6 +297,7 @@ class MainActivity : AppCompatActivity() {
modelName = modelName ?: "",
acousticModelName = acousticModelName ?: "",
vocoder = vocoder ?: "",
voices = voices ?: "",
lexicon = lexicon ?: "",
dataDir = dataDir ?: "",
dictDir = dictDir ?: "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fun getSampleText(lang: String): String {
}

"eng" -> {
text = "This is a text-to-speech engine using next generation Kaldi"
text = "How are you doing today? This is a text-to-speech engine using next generation Kaldi"
}

"est" -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
package com.k2fsa.sherpa.onnx.tts.engine

import PreferenceHelper
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioManager
import android.media.AudioTrack
import android.media.MediaPlayer
import android.net.Uri
import android.os.Bundle
Expand Down Expand Up @@ -36,7 +40,13 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.unit.dp
import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import kotlin.time.TimeSource

const val TAG = "sherpa-onnx-tts-engine"

Expand All @@ -45,9 +55,26 @@ class MainActivity : ComponentActivity() {
private val ttsViewModel: TtsViewModel by viewModels()

private var mediaPlayer: MediaPlayer? = null

// see
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
private lateinit var track: AudioTrack

private var stopped: Boolean = false

private var samplesChannel = Channel<FloatArray>()

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)

Log.i(TAG, "Start to initialize TTS")
TtsEngine.createTts(this)
Log.i(TAG, "Finish initializing TTS")

Log.i(TAG, "Start to initialize AudioTrack")
initAudioTrack()
Log.i(TAG, "Finish initializing AudioTrack")

val preferenceHelper = PreferenceHelper(this)
setContent {
SherpaOnnxTtsEngineTheme {
Expand Down Expand Up @@ -77,6 +104,11 @@ class MainActivity : ComponentActivity() {
val testTextContent = getSampleText(TtsEngine.lang ?: "")

var testText by remember { mutableStateOf(testTextContent) }
var startEnabled by remember { mutableStateOf(true) }
var playEnabled by remember { mutableStateOf(false) }
var rtfText by remember {
mutableStateOf("")
}

val numSpeakers = TtsEngine.tts!!.numSpeakers()
if (numSpeakers > 1) {
Expand Down Expand Up @@ -119,52 +151,117 @@ class MainActivity : ComponentActivity() {

Row {
Button(
modifier = Modifier.padding(20.dp),
enabled = startEnabled,
modifier = Modifier.padding(5.dp),
onClick = {
Log.i(TAG, "Clicked, text: $testText")
if (testText.isBlank() || testText.isEmpty()) {
Toast.makeText(
applicationContext,
"Please input a test sentence",
"Please input some text to generate",
Toast.LENGTH_SHORT
).show()
} else {
val audio = TtsEngine.tts!!.generate(
text = testText,
sid = TtsEngine.speakerId,
speed = TtsEngine.speed,
)

val filename =
application.filesDir.absolutePath + "/generated.wav"
val ok =
audio.samples.isNotEmpty() && audio.save(
filename
)
startEnabled = false
playEnabled = false
stopped = false

if (ok) {
stopMediaPlayer()
mediaPlayer = MediaPlayer.create(
applicationContext,
Uri.fromFile(File(filename))
)
mediaPlayer?.start()
} else {
Log.i(TAG, "Failed to generate or save audio")
track.pause()
track.flush()
track.play()
rtfText = ""
Log.i(TAG, "Started with text $testText")

samplesChannel = Channel<FloatArray>()

CoroutineScope(Dispatchers.IO).launch {
for (samples in samplesChannel) {
track.write(
samples,
0,
samples.size,
AudioTrack.WRITE_BLOCKING
)
if (stopped) {
break
}
}
}

CoroutineScope(Dispatchers.Default).launch {
val timeSource = TimeSource.Monotonic
val startTime = timeSource.markNow()

val audio =
TtsEngine.tts!!.generateWithCallback(
text = testText,
sid = TtsEngine.speakerId,
speed = TtsEngine.speed,
callback = ::callback,
)

val elapsed =
startTime.elapsedNow().inWholeMilliseconds.toFloat() / 1000;
val audioDuration =
audio.samples.size / TtsEngine.tts!!.sampleRate()
.toFloat()
val RTF = String.format(
"Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f",
TtsEngine.tts!!.config.model.numThreads,
audioDuration,
elapsed,
elapsed,
audioDuration,
elapsed / audioDuration
)
samplesChannel.close()

val filename =
application.filesDir.absolutePath + "/generated.wav"


val ok =
audio.samples.isNotEmpty() && audio.save(
filename
)

if (ok) {
withContext(Dispatchers.Main) {
startEnabled = true
playEnabled = true
rtfText = RTF
}
}
}.start()
}
}) {
Text("Test")
Text("Start")
}

Button(
modifier = Modifier.padding(20.dp),
modifier = Modifier.padding(5.dp),
enabled = playEnabled,
onClick = {
TtsEngine.speakerId = 0
TtsEngine.speed = 1.0f
testText = ""
stopped = true
track.pause()
track.flush()
onClickPlay()
}) {
Text("Reset")
Text("Play")
}

Button(
modifier = Modifier.padding(5.dp),
onClick = {
onClickStop()
startEnabled = true
}) {
Text("Stop")
}
}
if (rtfText.isNotEmpty()) {
Row {
Text(rtfText)
}
}
}
Expand All @@ -185,4 +282,63 @@ class MainActivity : ComponentActivity() {
mediaPlayer?.release()
mediaPlayer = null
}

private fun onClickPlay() {
val filename = application.filesDir.absolutePath + "/generated.wav"
stopMediaPlayer()
mediaPlayer = MediaPlayer.create(
applicationContext,
Uri.fromFile(File(filename))
)
mediaPlayer?.start()
}

private fun onClickStop() {
stopped = true
track.pause()
track.flush()

stopMediaPlayer()
}

// this function is called from C++
private fun callback(samples: FloatArray): Int {
if (!stopped) {
val samplesCopy = samples.copyOf()
CoroutineScope(Dispatchers.IO).launch {
samplesChannel.send(samplesCopy)
}
return 1
} else {
track.stop()
Log.i(TAG, " return 0")
return 0
}
}

private fun initAudioTrack() {
val sampleRate = TtsEngine.tts!!.sampleRate()
val bufLength = AudioTrack.getMinBufferSize(
sampleRate,
AudioFormat.CHANNEL_OUT_MONO,
AudioFormat.ENCODING_PCM_FLOAT
)
Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")

val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
.build()

val format = AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.setSampleRate(sampleRate)
.build()

track = AudioTrack(
attr, format, bufLength, AudioTrack.MODE_STREAM,
AudioManager.AUDIO_SESSION_ID_GENERATE
)
track.play()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ object TtsEngine {

private var modelDir: String? = null
private var modelName: String? = null
private var acousticModelName: String? = null
private var vocoder: String? = null
private var acousticModelName: String? = null // for matcha tts
private var vocoder: String? = null // for matcha tts
private var voices: String? = null // for kokoro
private var ruleFsts: String? = null
private var ruleFars: String? = null
private var lexicon: String? = null
Expand All @@ -64,6 +65,10 @@ object TtsEngine {
vocoder = null
// For Matcha -- end

// For Kokoro -- begin
voices = null
// For Kokoro -- end

modelDir = null
ruleFsts = null
ruleFars = null
Expand Down Expand Up @@ -139,6 +144,14 @@ object TtsEngine {
// vocoder = "hifigan_v2.onnx"
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
// lang = "eng"

// Example 9
// kokoro-en-v0_19
// modelDir = "kokoro-en-v0_19"
// modelName = "model.onnx"
// voices = "voices.bin"
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
// lang = "eng"
}

fun createTts(context: Context) {
Expand Down Expand Up @@ -167,6 +180,7 @@ object TtsEngine {
modelName = modelName ?: "",
acousticModelName = acousticModelName ?: "",
vocoder = vocoder ?: "",
voices = voices ?: "",
lexicon = lexicon ?: "",
dataDir = dataDir ?: "",
dictDir = dictDir ?: "",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
<resources>
<string name="app_name">TTS Engine</string>
<string name="app_name">TTS Engine: Next-gen Kaldi</string>
</resources>
Loading

0 comments on commit 99cef41

Please sign in to comment.