diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt index 18a89f93c..6b1371450 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt @@ -32,6 +32,7 @@ import androidx.compose.material3.Slider import androidx.compose.material3.Surface import androidx.compose.material3.Text import androidx.compose.material3.TopAppBar +import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember @@ -46,6 +47,8 @@ import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import java.io.File +import kotlin.math.roundToInt +import kotlinx.coroutines.delay import kotlin.time.TimeSource const val TAG = "sherpa-onnx-tts-engine" @@ -76,6 +79,9 @@ class MainActivity : ComponentActivity() { Log.i(TAG, "Finish initializing AudioTrack") val preferenceHelper = PreferenceHelper(this) + + TtsEngine.cacheSize = preferenceHelper.getCacheSizeInMB() + setContent { SherpaOnnxTtsEngineTheme { // A surface container using the 'background' color from the theme @@ -88,6 +94,17 @@ class MainActivity : ComponentActivity() { }) { Box(modifier = Modifier.padding(it)) { Column(modifier = Modifier.padding(16.dp)) { + // Track used cache size in a mutable state + var usedCacheSizeMB by remember { mutableStateOf(0) } + + // LaunchedEffect to periodically update the used cache size + LaunchedEffect(Unit) { + while (true) { + usedCacheSizeMB = TtsEngine.tts?.getTotalUsedCacheSizeInMB() ?: 0 + delay(5000) // Update every 5 seconds + } + } + Column { Text("Speed " + String.format("%.1f", TtsEngine.speed)) Slider( @@ -95,10 +112,26 @@ class MainActivity : ComponentActivity() { onValueChange = { TtsEngine.speed = it preferenceHelper.setSpeed(it) + TtsEngine.tts?.clearCache() // Call the clearCache method + usedCacheSizeMB = 0 // Reset used cache size }, valueRange = 0.2F..3.0F, modifier = Modifier.fillMaxWidth() ) + + Text("Cache Size: ${TtsEngine.cacheSize}MB (${usedCacheSizeMB}MB used)") + Slider( + value = TtsEngine.cacheSizeState.value.toFloat(), + onValueChange = { newValue -> + // Round the value to the nearest multiple of 10 + val roundedValue = (newValue / 5).roundToInt() * 5 + TtsEngine.cacheSize = roundedValue + preferenceHelper.setCacheSizeInMB(roundedValue) + TtsEngine.tts?.setCacheSizeInMB(roundedValue) + }, + valueRange = 0f..100f, + modifier = Modifier.fillMaxWidth() + ) } val testTextContent = getSampleText(TtsEngine.lang ?: "") @@ -272,6 +305,13 @@ class MainActivity : ComponentActivity() { } } + override fun onResume() { + super.onResume() + // Update used cache size when the app is resumed + val usedCacheSizeMB = TtsEngine.tts?.getTotalUsedCacheSizeInMB() ?: 0 + Log.i(TAG, "App resumed. Used cache size: ${usedCacheSizeMB}MB") + } + override fun onDestroy() { stopMediaPlayer() super.onDestroy() diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt index 57a6e324c..314ffa87d 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt @@ -6,6 +6,7 @@ class PreferenceHelper(context: Context) { private val PREFS_NAME = "com.k2fsa.sherpa.onnx.tts.engine" private val SPEED_KEY = "speed" private val SID_KEY = "speaker_id" + private val CACHE_SIZE_KEY = "cache_size" private val sharedPreferences: SharedPreferences = context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE) @@ -29,4 +30,14 @@ class PreferenceHelper(context: Context) { fun getSid(): Int { return sharedPreferences.getInt(SID_KEY, 0) } -} \ No newline at end of file + + fun setCacheSizeInMB(value: Int) { + val editor = sharedPreferences.edit() + editor.putInt(CACHE_SIZE_KEY, value) + editor.apply() + } + + fun getCacheSizeInMB(): Int { + return sharedPreferences.getInt(CACHE_SIZE_KEY, 20) // Default cache size is 20MB + } +} diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt index 69792a560..91f6b2d12 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt @@ -25,6 +25,7 @@ object TtsEngine { val speedState: MutableState = mutableFloatStateOf(1.0F) + val cacheSizeState: MutableState = mutableIntStateOf(0) val speakerIdState: MutableState = mutableIntStateOf(0) var speed: Float @@ -33,6 +34,12 @@ object TtsEngine { speedState.value = value } + var cacheSize: Int + get() = cacheSizeState.value + set(value) { + cacheSizeState.value = value + } + var speakerId: Int get() = speakerIdState.value set(value) { @@ -190,8 +197,11 @@ object TtsEngine { speed = PreferenceHelper(context).getSpeed() speakerId = PreferenceHelper(context).getSid() + cacheSize = PreferenceHelper(context).getCacheSizeInMB() tts = OfflineTts(assetManager = assets, config = config) + + tts?.setCacheSizeInMB(cacheSize) } diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index d5303b757..5b0bfef35 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -155,6 +155,8 @@ if(SHERPA_ONNX_ENABLE_TTS) jieba-lexicon.cc lexicon.cc melo-tts-lexicon.cc + offline-tts-cache-mechanism-config.cc + offline-tts-cache-mechanism.cc offline-tts-character-frontend.cc offline-tts-frontend.cc offline-tts-impl.cc diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc new file mode 100644 index 000000000..bd06794fe --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc @@ -0,0 +1,35 @@ +// sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc +// +// Copyright (c) 2025 @mah92 From Iranian people to the community with love + +#include "sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineTtsCacheMechanismConfig::Register(ParseOptions *po) { + po->Register("tts-cache-dir", &cache_dir, + "Path to the directory containing dict for espeak-ng."); + po->Register("tts-cache-size", &cache_size, + "Cache size for wav files in bytes. After the cache size is filled, wav files are kept based on usage statstics."); +} + +bool OfflineTtsCacheMechanismConfig::Validate() const { + return true; +} + +std::string OfflineTtsCacheMechanismConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsCacheMechanismConfig("; + os << "cache_dir=\"" << cache_dir << "\", "; + os << "cache_size=" << cache_size << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h new file mode 100644 index 000000000..2f5d2baba --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h @@ -0,0 +1,35 @@ +// sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h +// +// Copyright (c) 2025 @mah92 From Iranian people to the community with love + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineTtsCacheMechanismConfig { + + std::string cache_dir; + + int32_t cache_size; + + OfflineTtsCacheMechanismConfig() = default; + + OfflineTtsCacheMechanismConfig(const std::string &cache_dir, + int32_t cache_size) + : cache_dir(cache_dir), + cache_size(cache_size) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc b/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc new file mode 100644 index 000000000..166430bc7 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc @@ -0,0 +1,311 @@ +// sherpa-onnx/csrc/offline-tts-cache-mechanism.cc +// +// Copyright (c) 2025 @mah92 From Iranian people to the community with love + +#include "sherpa-onnx/csrc/offline-tts-cache-mechanism.h" + +#include +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/wave-reader.h" +#include "sherpa-onnx/csrc/wave-writer.h" + +// Platform-specific time functions +#if defined(_WIN32) +#include +#else +#include +#include +#endif + +namespace sherpa_onnx { + +OfflineTtsCacheMechanism::OfflineTtsCacheMechanism(const std::string &cache_dir, + int32_t cache_size) + : cache_dir_(cache_dir), + cache_size_bytes_(cache_size), + used_cache_size_bytes_(0) { + + // Create the cache directory if it doesn't exist + if (!std::filesystem::exists(cache_dir_)) { + bool dir_created = std::filesystem::create_directory(cache_dir_); + if (!dir_created) { + SHERPA_ONNX_LOGE("Unable to create cache directory: %s", + cache_dir_.c_str()); + SHERPA_ONNX_LOGE("Cache mechanism disabled!"); + cache_mechanism_inited_ = false; + return; + } + } + + // Load the repeat counts + LoadRepeatCounts(); + + // Update the cache vector and calculate the total cache size + UpdateCacheVector(); + + // Initialize the last save time + last_save_time_ = std::chrono::steady_clock::now(); + + // Indicate that initialization has been successful + cache_mechanism_inited_ = true; +} + +OfflineTtsCacheMechanism::~OfflineTtsCacheMechanism() { + if (cache_mechanism_inited_ == false) return; + + // Save the repeat counts on destruction + SaveRepeatCounts(); +} + +void OfflineTtsCacheMechanism::AddWavFile( + const std::string &text_hash, + const std::vector &samples, + const int32_t sample_rate) { + std::lock_guard lock(mutex_); + + if (cache_mechanism_inited_ == false) return; + + std::string file_path = cache_dir_ + "/" + text_hash + ".wav"; + + // Check if the file physically exists in the cache directory + bool file_exists = std::filesystem::exists(file_path); + + if (!file_exists) { // If the file does not exist, add it to the cache + // Ensure the cache does not exceed its size limit + EnsureCacheLimit(); + + // Write the audio samples to a WAV file + bool success = WriteWave(file_path, + sample_rate, samples.data(), samples.size()); + if (success) { + // Calculate size of the new WAV file and add it to the total cache size + std::ifstream file(file_path, std::ios::binary | std::ios::ate); + if (file.is_open()) { + used_cache_size_bytes_ += file.tellg(); + } + } else { + SHERPA_ONNX_LOGE("Failed to write wav file: %s", file_path.c_str()); + } + } +} + +std::vector OfflineTtsCacheMechanism::GetWavFile( + const std::string &text_hash, + int32_t *sample_rate) { + std::lock_guard lock(mutex_); + + std::vector samples; + + if (cache_mechanism_inited_ == false) return samples; + + std::string file_path = cache_dir_ + "/" + text_hash + ".wav"; + + if (std::filesystem::exists(file_path)) { + bool is_ok = false; + samples = ReadWave(file_path, sample_rate, &is_ok); + + if (is_ok == false) { + SHERPA_ONNX_LOGE("Failed to read cached file: %s", file_path.c_str()); + } + } + + // Ensure the text_hash exists in the map before incrementing the count + if (repeat_counts_.find(text_hash) == repeat_counts_.end()) { + repeat_counts_[text_hash] = 1; // Initialize if it doesn't exist + } else { + repeat_counts_[text_hash]++; // Increment the repeat count + } + + // Save the repeat counts every 10 minutes + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast( + now - last_save_time_).count() >= 10 * 60) { + SaveRepeatCounts(); + last_save_time_ = now; + } + + return samples; +} + +int32_t OfflineTtsCacheMechanism::GetCacheSize() const { + if (cache_mechanism_inited_ == false) return 0; + + return cache_size_bytes_; +} + +void OfflineTtsCacheMechanism::SetCacheSize(int32_t cache_size) { + std::lock_guard lock(mutex_); + + if (cache_mechanism_inited_ == false) return; + + cache_size_bytes_ = cache_size; + + EnsureCacheLimit(); +} + +void OfflineTtsCacheMechanism::ClearCache() { + std::lock_guard lock(mutex_); + + if (cache_mechanism_inited_ == false) return; + + // Remove all WAV files in the cache directory + for (const auto &entry : std::filesystem::directory_iterator(cache_dir_)) { + if (entry.path().extension() == ".wav") { + std::filesystem::remove(entry.path()); + } + } + + // Reset the total cache size to 0 + used_cache_size_bytes_ = 0; + + // Clear the repeat counts and cache vector + repeat_counts_.clear(); + cache_vector_.clear(); + + // Remove repeat counts also in the repeat_counts.txt + SaveRepeatCounts(); +} + +int32_t OfflineTtsCacheMechanism::GetTotalUsedCacheSize() const { + std::lock_guard lock(mutex_); + + if (cache_mechanism_inited_ == false) return 0; + + return used_cache_size_bytes_; +} + +// Private functions /////////////////////////////////////////////////// + +void OfflineTtsCacheMechanism::LoadRepeatCounts() { + std::string repeat_count_file = cache_dir_ + "/repeat_counts.txt"; + + // Check if the file exists + if (!std::filesystem::exists(repeat_count_file)) { + return; // Skip loading if the file doesn't exist + } + + // Open the file for reading + std::ifstream ifs(repeat_count_file); + if (!ifs.is_open()) { + SHERPA_ONNX_LOGE("Failed to open repeat count file: %s", + repeat_count_file.c_str()); + return; // Skip loading if the file cannot be opened + } + + // Read the file line by line + std::string line; + while (std::getline(ifs, line)) { + size_t pos = line.find(' '); + if (pos != std::string::npos) { + std::string text_hash = line.substr(0, pos); + int32_t count = std::stoi(line.substr(pos + 1)); + repeat_counts_[text_hash] = count; + } + } +} + +void OfflineTtsCacheMechanism::SaveRepeatCounts() { + std::string repeat_count_file = cache_dir_ + "/repeat_counts.txt"; + + // Open the file for writing + std::ofstream ofs(repeat_count_file); + if (!ofs.is_open()) { + SHERPA_ONNX_LOGE("Failed to open repeat count file for writing: %s", + repeat_count_file.c_str()); + return; // Skip saving if the file cannot be opened + } + + // Write the repeat counts to the file + for (const auto &entry : repeat_counts_) { + ofs << entry.first << " " << entry.second; + if (!ofs) { + SHERPA_ONNX_LOGE("Failed to write repeat count for text hash: %s", + entry.first.c_str()); + return; // Stop writing if an error occurs + } + ofs << std::endl; + } +} + +void OfflineTtsCacheMechanism::RemoveWavFile(const std::string &text_hash) { + std::string file_path = cache_dir_ + "/" + text_hash + ".wav"; + if (std::filesystem::exists(file_path)) { + // Subtract the size of the removed WAV file from the total cache size + std::ifstream file(file_path, std::ios::binary | std::ios::ate); + if (file.is_open()) { + used_cache_size_bytes_ -= file.tellg(); + file.close(); + } + std::filesystem::remove(file_path); + } + + // Remove the entry from the repeat counts and cache vector + if (repeat_counts_.find(text_hash) != repeat_counts_.end()) { + repeat_counts_.erase(text_hash); + cache_vector_.erase( + std::remove(cache_vector_.begin(), cache_vector_.end(), text_hash), + cache_vector_.end()); + } +} + +void OfflineTtsCacheMechanism::UpdateCacheVector() { + used_cache_size_bytes_ = 0; // Reset total cache size before recalculating + + for (const auto &entry : std::filesystem::directory_iterator(cache_dir_)) { + if (entry.path().extension() == ".wav") { + std::string text_hash = entry.path().stem().string(); + if (repeat_counts_.find(text_hash) == repeat_counts_.end()) { + // Remove the file if it's not in the repeat count file (orphaned file) + std::filesystem::remove(entry.path()); + } else { + // Add the size of the WAV file to the total cache size + std::ifstream file(entry.path(), std::ios::binary | std::ios::ate); + if (file.is_open()) { + used_cache_size_bytes_ += file.tellg(); + } + cache_vector_.push_back(text_hash); + } + } + } +} + +void OfflineTtsCacheMechanism::EnsureCacheLimit() { + if (used_cache_size_bytes_ > cache_size_bytes_) { + auto target_cache_size + = std::max(static_cast (cache_size_bytes_*0.95), 0); + while (used_cache_size_bytes_> 0 + && used_cache_size_bytes_ > target_cache_size) { + // Cache is full, remove the least repeated file + std::string least_repeated_file = GetLeastRepeatedFile(); + RemoveWavFile(least_repeated_file); + } + } +} + +std::string OfflineTtsCacheMechanism::GetLeastRepeatedFile() { + std::string least_repeated_file; + int32_t min_count = std::numeric_limits::max(); + + for (const auto &entry : repeat_counts_) { + if (entry.second <= 1) { + least_repeated_file = entry.first; + return least_repeated_file; + } + + if (entry.second < min_count) { + min_count = entry.second; + least_repeated_file = entry.first; + } + } + + return least_repeated_file; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism.h b/sherpa-onnx/csrc/offline-tts-cache-mechanism.h new file mode 100644 index 000000000..48f94c5b5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism.h @@ -0,0 +1,91 @@ +// sherpa-onnx/csrc/offline-tts-cache-mechanism.h +// +// Copyright (c) 2025 @mah92 From Iranian people to the community with love + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_ + +#include +#include +#include +#include // NOLINT + +#include "sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h" + +namespace sherpa_onnx { + +class OfflineTtsCacheMechanism { + public: + OfflineTtsCacheMechanism(const std::string &cache_dir, int32_t cache_size); + ~OfflineTtsCacheMechanism(); + + // Add a new wav file to the cache + void AddWavFile( + const std::string &text_hash, + const std::vector &samples, + const int32_t sample_rate); + + // Get the cached wav file if it exists + std::vector GetWavFile( + const std::string &text_hash, + int32_t *sample_rate); + + // Get the current cache size in bytes + int32_t GetCacheSize() const; + + // Set the cache size in bytes + void SetCacheSize(int32_t cache_size); + + // Remove all the wav files in the cache + void ClearCache(); + + // To get total used cache size(for wav files) in bytes + int32_t GetTotalUsedCacheSize() const; + + private: + // Load the repeat count file + void LoadRepeatCounts(); + + // Save the repeat count file + void SaveRepeatCounts(); + + // Remove a wav file from the cache + void RemoveWavFile(const std::string &text_hash); + + // Update the cache vector with the actual files in the cache folder + void UpdateCacheVector(); + + // Reduce used cache size if needed + void EnsureCacheLimit(); + + // Get the least repeated file in the cache + std::string GetLeastRepeatedFile(); + + // Data directory where the cache folder is located + std::string cache_dir_; + + // Maximum number of bytes in the cache + int32_t cache_size_bytes_; + + // Total used cache size for wav files in bytes + int32_t used_cache_size_bytes_; + + // Map of text hash to repeat count + std::unordered_map repeat_counts_; + + // Vector of cached file names + std::vector cache_vector_; + + // Mutex for thread safety (recursive to avoid deadlocks) + mutable std::recursive_mutex mutex_; + + // Time of last save + std::chrono::steady_clock::time_point last_save_time_; + + // if cache mechanism is inited successfully + bool cache_mechanism_inited_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_ diff --git a/sherpa-onnx/csrc/offline-tts.cc b/sherpa-onnx/csrc/offline-tts.cc index ec2c69523..b1aff3894 100644 --- a/sherpa-onnx/csrc/offline-tts.cc +++ b/sherpa-onnx/csrc/offline-tts.cc @@ -18,6 +18,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-tts-cache-mechanism.h" #include "sherpa-onnx/csrc/offline-tts-impl.h" #include "sherpa-onnx/csrc/text-utils.h" @@ -85,22 +86,98 @@ std::string OfflineTtsConfig::ToString() const { } OfflineTts::OfflineTts(const OfflineTtsConfig &config) - : impl_(OfflineTtsImpl::Create(config)) {} + : config_(config), + impl_(OfflineTtsImpl::Create(config)), + cache_mechanism_(nullptr) {} template OfflineTts::OfflineTts(Manager *mgr, const OfflineTtsConfig &config) - : impl_(OfflineTtsImpl::Create(mgr, config)) {} + : config_(config), + impl_(OfflineTtsImpl::Create(mgr, config)), + cache_mechanism_(nullptr) {} OfflineTts::~OfflineTts() = default; GeneratedAudio OfflineTts::Generate( const std::string &text, int64_t sid /*=0*/, float speed /*= 1.0*/, GeneratedAudioCallback callback /*= nullptr*/) const { - return impl_->Generate(text, sid, speed, std::move(callback)); + // Generate a hash for the text + std::hash hasher; + std::string text_hash = std::to_string(hasher(text)); + // SHERPA_ONNX_LOGE("Generated text hash: %s", text_hash.c_str()); + + // Check if the cache mechanism is active and if the audio is already cached + if (cache_mechanism_) { + int32_t sample_rate; + std::vector samples + = cache_mechanism_->GetWavFile(text_hash, &sample_rate); + + if (!samples.empty()) { + SHERPA_ONNX_LOGE("Returning cached audio for hash:%s", text_hash.c_str()); + + // If a callback is provided, call it with the cached audio + if (callback) { + int32_t result + = callback(samples.data(), samples.size(), 1.0f /* progress */); + if (result == 0) { + // If the callback returns 0, stop further processing + SHERPA_ONNX_LOGE("Callback requested to stop processing."); + return {samples, sample_rate}; + } + } + + // Return the cached audio + return {samples, sample_rate}; + } + } + + // Generate the audio if not cached + GeneratedAudio audio = impl_->Generate(text, sid, speed, std::move(callback)); + + // Cache the generated audio if the cache mechanism is active + if (cache_mechanism_) { + cache_mechanism_->AddWavFile(text_hash, audio.samples, audio.sample_rate); + // SHERPA_ONNX_LOGE("Cached audio for text hash: %s", text_hash.c_str()); + } + + return audio; } int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); } +int32_t OfflineTts::CacheSize() const { + return cache_mechanism_ ? cache_mechanism_->GetCacheSize() : 0; +} + +void OfflineTts::SetCacheSize(const int32_t cache_size) { + if (cache_size > 0) { + if (!cache_mechanism_) { + // Initialize the cache mechanism if it hasn't been initialized yet + cache_mechanism_ = std::make_unique( + config_.cache_dir, cache_size); + } else { + // Update the cache size if the cache mechanism is already initialized + cache_mechanism_->SetCacheSize(cache_size); + } + } else if (cache_mechanism_) { + // If cache size is set to 0 or negative, destroy the cache mechanism + cache_mechanism_.reset(); + } +} + +void OfflineTts::ClearCache() { + if (cache_mechanism_) { + cache_mechanism_->ClearCache(); + } +} + +int32_t OfflineTts::GetTotalUsedCacheSize() { + if (cache_mechanism_) { + return cache_mechanism_->GetTotalUsedCacheSize(); + } + return -1; +} + int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-tts.h b/sherpa-onnx/csrc/offline-tts.h index 884173e7b..060751bee 100644 --- a/sherpa-onnx/csrc/offline-tts.h +++ b/sherpa-onnx/csrc/offline-tts.h @@ -10,6 +10,7 @@ #include #include +#include "sherpa-onnx/csrc/offline-tts-cache-mechanism.h" #include "sherpa-onnx/csrc/offline-tts-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -32,6 +33,9 @@ struct OfflineTtsConfig { // If you set it to -1, then we process all sentences in a single batch. int32_t max_num_sentences = 1; + // Path to cache_directory + std::string cache_dir; + OfflineTtsConfig() = default; OfflineTtsConfig(const OfflineTtsModelConfig &model, const std::string &rule_fsts, const std::string &rule_fars, @@ -87,12 +91,26 @@ class OfflineTts { // Return the sample rate of the generated audio int32_t SampleRate() const; + // Return the maximum number of cached audio files size + int32_t CacheSize() const; + + // Set the maximum number of cached audio files size + void SetCacheSize(const int32_t cache_size); + + // Remove all cache data + void ClearCache(); + + // To get total used cache size(for wav files) in bytes + int32_t GetTotalUsedCacheSize(); + // Number of supported speakers. // If it supports only a single speaker, then it return 0 or 1. int32_t NumSpeakers() const; private: + OfflineTtsConfig config_; std::unique_ptr impl_; + std::unique_ptr cache_mechanism_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/jni/offline-tts.cc b/sherpa-onnx/jni/offline-tts.cc index 6af10788a..693dccaa4 100644 --- a/sherpa-onnx/jni/offline-tts.cc +++ b/sherpa-onnx/jni/offline-tts.cc @@ -175,6 +175,24 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "maxNumSentences", "I"); ans.max_num_sentences = env->GetIntField(config, fid); + // Get data directory from config + jfieldID model_fid = env->GetFieldID(cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;"); + jobject model_config = env->GetObjectField(config, model_fid); + jclass model_cls = env->GetObjectClass(model_config); + + jfieldID vits_fid = env->GetFieldID(model_cls, "vits", "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); + jobject vits_config = env->GetObjectField(model_config, vits_fid); + + fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;"); + jstring data_dir = (jstring)env->GetObjectField(vits_config, fid); + const char *p_data_dir = env->GetStringUTFChars(data_dir, nullptr); + + // Convert data directory to cache directory + std::string cache_dir = std::string(p_data_dir) + "/../cache"; + ans.cache_dir = cache_dir; + + env->ReleaseStringUTFChars(data_dir, p_data_dir); + return ans; } @@ -223,6 +241,18 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete( delete reinterpret_cast(ptr); } +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_setCacheSizeImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jint cacheSize) { + reinterpret_cast(ptr)->SetCacheSize(static_cast(cacheSize)); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getCacheSizeImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + return reinterpret_cast(ptr)->CacheSize(); +} + SHERPA_ONNX_EXTERN_C JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate( JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { @@ -235,6 +265,23 @@ JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getNumSpeakers( return reinterpret_cast(ptr)->NumSpeakers(); } +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getTotalUsedCacheSizeImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + return reinterpret_cast(ptr)->GetTotalUsedCacheSize(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineTts_clearCacheImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + auto tts = reinterpret_cast(ptr); + if (tts) { + tts->ClearCache(); + SHERPA_ONNX_LOGE("Cache cleared from JNI."); + } +} + SHERPA_ONNX_EXTERN_C JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/, diff --git a/sherpa-onnx/kotlin-api/Tts.kt b/sherpa-onnx/kotlin-api/Tts.kt index ce85a04da..f53f56462 100644 --- a/sherpa-onnx/kotlin-api/Tts.kt +++ b/sherpa-onnx/kotlin-api/Tts.kt @@ -151,6 +151,22 @@ class OfflineTts( private external fun getSampleRate(ptr: Long): Int private external fun getNumSpeakers(ptr: Long): Int + fun getCacheSizeInMB(): Int { + return (getCacheSizeImpl(ptr) / (1024 * 1024)).toInt() // Convert bytes to MB + } + private external fun getCacheSizeImpl(ptr: Long): Int + + fun setCacheSizeInMB(cacheSize: Int) { + setCacheSizeImpl(ptr, cacheSize * (1024 * 1024)) + } + private external fun setCacheSizeImpl(ptr: Long, cacheSize: Int) + + fun getTotalUsedCacheSizeInMB(): Int { + return (getTotalUsedCacheSizeImpl(ptr) / (1024 * 1024)).toInt() // Convert bytes to MB + } + + private external fun getTotalUsedCacheSizeImpl(ptr: Long): Int + // The returned array has two entries: // - the first entry is an 1-D float array containing audio samples. // Each sample is normalized to the range [-1, 1] @@ -170,6 +186,12 @@ class OfflineTts( callback: (samples: FloatArray) -> Int ): Array + fun clearCache() { + clearCacheImpl(ptr) + } + + private external fun clearCacheImpl(ptr: Long) + companion object { init { System.loadLibrary("sherpa-onnx-jni")