diff --git a/samples/python/text_generation/multinomial_causal_lm.py b/samples/python/text_generation/multinomial_causal_lm.py index c915b89a2f..1c954e7224 100755 --- a/samples/python/text_generation/multinomial_causal_lm.py +++ b/samples/python/text_generation/multinomial_causal_lm.py @@ -63,7 +63,7 @@ def get_stop_flag(self): """ return False - def put_word(self, word: str): + def put_word(self, word: str | None): """ Puts a word into the text queue. @@ -72,17 +72,23 @@ def put_word(self, word: str): """ self.text_queue.put(word) - def put(self, token_id: int) -> bool: + def put(self, token: int | list[int]) -> bool: """ Processes a token and manages the decoding buffer. Adds decoded text to the queue. Args: - token_id (int): The token_id to process. + token (int | list[int]): The token(s) to process. Returns: bool: True if generation should be stopped, False otherwise. """ - self.tokens_cache.append(token_id) + + if type(token) is int: + self.tokens_cache.append(token) + elif type(token) is list: + self.tokens_cache += token + self.decoded_lengths += [-1 for _ in token[:-1]] + text = self.tokenizer.decode(self.tokens_cache) self.decoded_lengths.append(len(text)) @@ -132,12 +138,18 @@ def __init__(self, tokenizer, tokens_len): super().__init__(tokenizer) self.tokens_len = tokens_len - def put(self, token_id: int) -> bool: - if (len(self.tokens_cache) + 1) % self.tokens_len != 0: - self.tokens_cache.append(token_id) + def put(self, token: int | list[int]) -> bool: + if (len(self.tokens_cache) + 1) % self.tokens_len == 0: + return super().put(token) + + if type(token) is int: + self.tokens_cache.append(token) self.decoded_lengths.append(-1) - return False - return super().put(token_id) + elif type(token) is list: + self.tokens_cache += token + self.decoded_lengths += [-1 for _ in token] + + return False def main(): diff --git a/src/cpp/include/openvino/genai/streamer_base.hpp b/src/cpp/include/openvino/genai/streamer_base.hpp index f286e896e5..696af77737 100644 --- a/src/cpp/include/openvino/genai/streamer_base.hpp +++ b/src/cpp/include/openvino/genai/streamer_base.hpp @@ -18,6 +18,16 @@ class OPENVINO_GENAI_EXPORTS StreamerBase { /// @brief put is called every time new token is decoded, /// @return bool flag to indicate whether generation should be stopped, if return true generation stops virtual bool put(int64_t token) = 0; + /// @brief put is called every time new vector of tokens is decoded, in case of assisting or prompt lookup decoding + /// @return bool flag to indicate whether generation should be stopped, if return true generation stops + virtual bool put(const std::vector& tokens) { + for (const auto token : tokens) { + if (put(token)) { + return true; + } + } + return false; + }; /// @brief end is called at the end of generation. 
It can be used to flush cache if your own streamer has one virtual void end() = 0; diff --git a/src/cpp/include/openvino/genai/whisper_pipeline.hpp b/src/cpp/include/openvino/genai/whisper_pipeline.hpp index 8ba6a6a8e1..bab1203245 100644 --- a/src/cpp/include/openvino/genai/whisper_pipeline.hpp +++ b/src/cpp/include/openvino/genai/whisper_pipeline.hpp @@ -23,18 +23,23 @@ using RawSpeechInput = std::vector; * * @param m_tokenizer tokenizer */ -class OPENVINO_GENAI_EXPORTS ChunkStreamerBase : public StreamerBase { +class OPENVINO_GENAI_EXPORTS ChunkStreamerBase { public: + /// @brief put is called every time new token is decoded, + /// @return bool flag to indicate whether generation should be stopped, if return true generation stops + virtual bool put(int64_t token) = 0; + /// @brief put is called every time new token chunk is generated, /// @return bool flag to indicate whether generation should be stopped, if return true generation stops virtual bool put_chunk(std::vector tokens) = 0; -}; -// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop. -using ChunkStreamerVariant = - std::variant, std::shared_ptr, std::monostate>; + /// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one + virtual void end() = 0; + + virtual ~ChunkStreamerBase() = 0; +}; -struct OPENVINO_GENAI_EXPORTS WhisperRawPerfMetrics { +struct WhisperRawPerfMetrics { /** @brief Duration for each features extraction call */ std::vector features_extraction_durations; }; @@ -151,7 +156,13 @@ class OPENVINO_GENAI_EXPORTS WhisperPipeline { */ WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, OptionalWhisperGenerationConfig generation_config = std::nullopt, - ChunkStreamerVariant streamer = std::monostate()); + StreamerVariant streamer = std::monostate()); + + OPENVINO_DEPRECATED("ChunkStreamerBase is deprecated. " + "Use StreamerBase instead. Support will be removed in 2026.0") + WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, + WhisperGenerationConfig generation_config, + std::shared_ptr streamer); /** * @brief High level generate that receives raw speech as a vector of floats and returns decoded output. @@ -174,6 +185,9 @@ class OPENVINO_GENAI_EXPORTS WhisperPipeline { void set_generation_config(const WhisperGenerationConfig& config); }; -OPENVINO_GENAI_EXPORTS std::pair streamer(ChunkStreamerVariant func); +OPENVINO_DEPRECATED("ChunkStreamerBase is deprecated. " + "Use StreamerBase instead. 
Support will be removed in 2026.0") +OPENVINO_GENAI_EXPORTS std::pair streamer(std::shared_ptr func); + OPENVINO_GENAI_EXPORTS std::pair generation_config(const WhisperGenerationConfig& config); } // namespace ov::genai diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index be1eba04f9..9a5efb1ecd 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -11,6 +11,7 @@ #include "lora_helper.hpp" #include "cache_state_dumper.hpp" #include "utils.hpp" +#include "threaded_streamer.hpp" namespace { @@ -429,19 +430,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector& streamer_ptr = std::visit(overloaded{ - [](std::monostate) -> std::shared_ptr { - return nullptr; - }, - [](const std::shared_ptr& streamer) { - return streamer; - }, - [this](const std::function& streamer) -> std::shared_ptr { - return std::make_unique(m_tokenizer, streamer); - } - }, streamer); + const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 && + OPENVINO_ASSERT(!streamer_ptr->has_callback() || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); @@ -452,49 +443,12 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector has_active_requests = has_non_finished_requests(); GenerationHandle& generation = generations.at(0); - // create variables to make optimal thread-safe streaming - std::mutex mutex; - std::unique_lock lock(mutex); - std::condition_variable cv; - - // to define streaming thread - std::shared_ptr t_stream_ptr = nullptr; - if (streamer_ptr) { - // define stream token lambda to use in `t_stream_ptr` - auto stream_tokens = [this, &generation, &streamer_ptr, &has_active_requests, &cv, &lock]() { - while (has_active_requests || generation->can_read()) { - // waiting for any tokens or request finishing - cv.wait(lock, [&generation, &has_active_requests]{ - return generation->can_read() || !has_active_requests; - }); - - if (generation->can_read()) { - std::unordered_map generation_outputs = generation->read(); - OPENVINO_ASSERT(generation_outputs.size() <= 1); - if (!generation_outputs.empty()) { - for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) { - if (streamer_ptr->put(generated_token_id)) { - generation->drop(); - break; - } - } - } - } - }; - streamer_ptr->end(); - }; - - // to define streaming thread - t_stream_ptr = std::make_shared([&stream_tokens] { - stream_tokens(); - }); - } + streamer_ptr->start(); std::exception_ptr thrown_exception = nullptr; - while (has_active_requests) { + while (has_non_finished_requests()) { try { const auto infer_start = std::chrono::steady_clock::now(); step(); @@ -510,17 +464,14 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorend(); + std::rethrow_exception(thrown_exception); } } - // waiting for competion of streaming - if (t_stream_ptr && t_stream_ptr->joinable()) { - t_stream_ptr->join(); - } + streamer_ptr->end(); OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); diff --git a/src/cpp/src/icontinuous_batching.cpp 
b/src/cpp/src/icontinuous_batching.cpp index 5bdf00d51d..c7dafbe4f8 100644 --- a/src/cpp/src/icontinuous_batching.cpp +++ b/src/cpp/src/icontinuous_batching.cpp @@ -112,4 +112,28 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate( return decoded; } + +void ContinuousBatchingPipeline::IContinuousBatchingPipeline::stream_tokens( + const std::shared_ptr& streamer_ptr, + const GenerationHandle& handle +) { + if (!streamer_ptr->has_callback() || !handle->can_read()) { + return; + } + + if (streamer_ptr->is_dropped()) { + handle->drop(); + return; + } + + std::unordered_map generation_outputs = handle->read(); + OPENVINO_ASSERT(generation_outputs.size() <= 1); + if (generation_outputs.empty()) { + return; + } + + const auto tokens = generation_outputs.begin()->second.generated_ids; + streamer_ptr->put(tokens); +} + } diff --git a/src/cpp/src/icontinuous_batching.hpp b/src/cpp/src/icontinuous_batching.hpp index 11c9b67e69..02a28644da 100644 --- a/src/cpp/src/icontinuous_batching.hpp +++ b/src/cpp/src/icontinuous_batching.hpp @@ -9,6 +9,7 @@ #include "sampler.hpp" #include "model_runner.hpp" #include "scheduler.hpp" +#include "threaded_streamer.hpp" namespace ov::genai { @@ -46,6 +47,8 @@ class ContinuousBatchingPipeline::IContinuousBatchingPipeline { // to access m_load_time_ms friend class ContinuousBatchingPipeline; + void stream_tokens(const std::shared_ptr& streamer_ptr, const GenerationHandle& handle); + public: GenerationConfig get_config() const; PipelineMetrics get_metrics() const; diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index a84b83dd2f..4ec44bda3a 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -138,13 +138,19 @@ PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const { // Concatenate durations, batch_sizes first token times. 
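// [illustrative aside, not part of this patch] The lines below extend PerfMetrics::operator+
// so the newly collected raw inference timings (m_inference_durations, m_token_infer_durations)
// are merged alongside the durations, times-to-first-token and batch sizes that were already
// concatenated. The merge pattern is identical for every raw vector field; a minimal sketch of
// that pattern, assuming only the std::vector-based members visible in this hunk:
//
//     template <typename T>
//     void concat_raw(std::vector<T>& dst, const std::vector<T>& src) {
//         dst.insert(dst.end(), src.begin(), src.end());  // append right-hand metrics to the result
//     }
//
//     // hypothetical usage inside operator+, equivalent to the explicit inserts below:
//     // concat_raw(new_inference_durations, right_inference_durations);
//     // concat_raw(new_token_infer_durations, right_token_infer_durations);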
auto& new_durations = res.raw_metrics.m_durations; + auto& new_inference_durations = res.raw_metrics.m_inference_durations; + auto& new_token_infer_durations = res.raw_metrics.m_token_infer_durations; auto& new_batch_sizes = res.raw_metrics.m_batch_sizes; auto& new_times_to_first_token = res.raw_metrics.m_times_to_first_token; + auto& right_inference_durations = right.raw_metrics.m_inference_durations; + auto& right_token_infer_durations = right.raw_metrics.m_token_infer_durations; auto& right_durations = right.raw_metrics.m_durations; auto& right_batch_sizes = right.raw_metrics.m_batch_sizes; auto& right_times_to_first_token = right.raw_metrics.m_times_to_first_token; new_durations.insert(new_durations.end(), right_durations.begin(), right_durations.end()); + new_inference_durations.insert(new_inference_durations.end(), right_inference_durations.begin(), right_inference_durations.end()); + new_token_infer_durations.insert(new_token_infer_durations.end(), right_token_infer_durations.begin(), right_token_infer_durations.end()); new_times_to_first_token.insert(new_times_to_first_token.end(), right_times_to_first_token.begin(), right_times_to_first_token.end()); new_batch_sizes.insert(new_batch_sizes.end(), right_batch_sizes.begin(), right_batch_sizes.end()); diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp index 6e58662b33..a2f230904d 100644 --- a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp +++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp @@ -5,6 +5,7 @@ #include "prompt_lookup_impl.hpp" #include "text_callback_streamer.hpp" +#include "threaded_streamer.hpp" namespace ov::genai { template struct overloaded : Ts... {using Ts::operator()...;}; @@ -108,19 +109,9 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vectorset_adapters(sampling_params[0].adapters); - const std::shared_ptr& streamer_ptr = std::visit(overloaded{ - [](std::monostate) -> std::shared_ptr { - return nullptr; - }, - [](const std::shared_ptr& streamer) { - return streamer; - }, - [this](const std::function& streamer) -> std::shared_ptr { - return std::make_unique(m_tokenizer, streamer); - } - }, streamer); + const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), + OPENVINO_ASSERT(!streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); std::vector generations; @@ -131,66 +122,26 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vectorget_awaiting_requests(); - std::atomic has_active_requests = has_non_finished_requests(); auto& generation = generations.at(0); - // create variables to make optimal thread-safe streaming - std::mutex mutex; - std::unique_lock lock(mutex); - std::condition_variable cv; - - // to define streaming thread - std::shared_ptr t_stream_ptr = nullptr; - if (streamer_ptr) { - // define stream token lambda to use in `t_stream_ptr` - auto stream_tokens = [this, &generation, &streamer_ptr, &has_active_requests, &cv, &lock]() { - while (has_active_requests || generation->can_read()) { - // waiting for any tokens or request finishing - cv.wait(lock, [&generation, &has_active_requests]{ - return generation->can_read() || !has_active_requests; - }); - - if 
(generation->can_read()) { - std::unordered_map generation_outputs = generation->read(); - OPENVINO_ASSERT(generation_outputs.size() <= 1); - if (!generation_outputs.empty()) { - for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) { - if (streamer_ptr->put(generated_token_id)) { - generation->drop(); - break; - } - } - } - } - }; - streamer_ptr->end(); - }; - - // to define streaming thread - t_stream_ptr = std::make_shared([&stream_tokens] { - stream_tokens(); - }); - } + streamer_ptr->start(); std::exception_ptr thrown_exception = nullptr; - while (has_active_requests) { + while (has_non_finished_requests()) { try { step(); } catch (...) { drop_requests(); // remove all requests from pipeline state in case of exception thrown_exception = std::current_exception(); } - has_active_requests = has_non_finished_requests(); - cv.notify_one(); + stream_tokens(streamer_ptr, generation); if (thrown_exception) { - throw thrown_exception; + streamer_ptr->end(); + std::rethrow_exception(thrown_exception); } } - // waiting for competion of streaming - if (t_stream_ptr && t_stream_ptr->joinable()) { - t_stream_ptr->join(); - } + streamer_ptr->end(); OPENVINO_ASSERT(m_pipeline->is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 6fb4e8ac53..3c21113ccf 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -7,6 +7,7 @@ #include "speculative_decoding_impl.hpp" #include "paged_attention_transformations.hpp" #include "utils.hpp" +#include "threaded_streamer.hpp" namespace ov::genai { @@ -231,19 +232,9 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< m_main_pipeline->set_adapters(sampling_params[0].adapters); m_draft_pipeline->set_adapters(sampling_params[0].adapters); - const std::shared_ptr& streamer_ptr = std::visit(overloaded{ - [](std::monostate) -> std::shared_ptr { - return nullptr; - }, - [](const std::shared_ptr& streamer) { - return streamer; - }, - [this](const std::function& streamer) -> std::shared_ptr { - return std::make_unique(m_tokenizer, streamer); - } - }, streamer); + const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), + OPENVINO_ASSERT(!streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); std::vector main_generations; @@ -260,65 +251,26 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< } auto all_requests = get_awaiting_requests(); - std::atomic has_active_requests = has_non_finished_requests(); - GenerationHandle& generation = main_generations.at(0); - - // create variables to make optimal thread-safe streaming - std::mutex mutex; - std::unique_lock lock(mutex); - std::condition_variable cv; - - std::shared_ptr t_stream_ptr = nullptr; - if (streamer_ptr) { - // define stream token lambda to use in `t_stream_ptr` - auto stream_tokens = [this, &generation, &streamer_ptr, &has_active_requests, &cv, &lock]() { - while 
(has_active_requests || generation->can_read()) { - // waiting for any tokens or request finishing - cv.wait(lock, [&generation, &has_active_requests]{ - return generation->can_read() || !has_active_requests; - }); - - if (generation->can_read()) { - std::unordered_map generation_outputs = generation->read(); - OPENVINO_ASSERT(generation_outputs.size() <= 1); - if (!generation_outputs.empty()) { - for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) { - if (streamer_ptr->put(generated_token_id)) { - generation->drop(); - break; - } - } - } - } - }; - - streamer_ptr->end(); - }; + auto& generation = main_generations.at(0); - t_stream_ptr = std::make_shared([&stream_tokens] { - stream_tokens(); - }); - } + streamer_ptr->start(); std::exception_ptr thrown_exception = nullptr; - while (has_active_requests) { + while (has_non_finished_requests()) { try { step(); } catch (...) { drop_requests(); // remove all requests from pipeline state in case of exception thrown_exception = std::current_exception(); } - has_active_requests = has_non_finished_requests(); - cv.notify_one(); + stream_tokens(streamer_ptr, generation); if (thrown_exception) { - throw thrown_exception; + streamer_ptr->end(); + std::rethrow_exception(thrown_exception); } } - // waiting for competion of streaming - if (t_stream_ptr && t_stream_ptr->joinable()) { - t_stream_ptr->join(); - } + streamer_ptr->end(); OPENVINO_ASSERT(is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp index aee909dfb8..17cff558f0 100644 --- a/src/cpp/src/text_callback_streamer.cpp +++ b/src/cpp/src/text_callback_streamer.cpp @@ -8,7 +8,7 @@ namespace genai { TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function callback) { m_tokenizer = tokenizer; - on_finalized_subword_callback = callback; + m_on_finalized_subword_callback = callback; } bool TextCallbackStreamer::put(int64_t token) { @@ -16,28 +16,30 @@ bool TextCallbackStreamer::put(int64_t token) { m_tokens_cache.push_back(token); std::string text = m_tokenizer.decode(m_tokens_cache); m_decoded_lengths.push_back(text.length()); - + if (!text.empty() && '\n' == text.back() && text.size() > m_printed_len) { // Flush the cache after the new line symbol res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len}; m_tokens_cache.clear(); m_decoded_lengths.clear(); m_printed_len = 0; - return on_finalized_subword_callback(res.str()); + return m_on_finalized_subword_callback(res.str()); } constexpr size_t delay_n_tokens = 3; - // In some cases adding the next token can shorten the text, + // In some cases adding the next token can shorten the text, // e.g. when apostrophe removing regex had worked after adding new tokens. // Printing several last tokens is delayed. if (m_decoded_lengths.size() < delay_n_tokens) { - return on_finalized_subword_callback(res.str()); + return m_on_finalized_subword_callback(res.str()); } - constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error. + + // MSVC with /utf-8 fails to compile � directly with newline in string literal error. 
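// [illustrative aside, not part of this patch] The replacement[] check right below defers
// printing whenever the freshly decoded text ends in U+FFFD, i.e. the token cache currently
// ends mid multi-byte UTF-8 sequence. Factored into a helper it reads as follows (a sketch,
// not an existing library function):
//
//     bool ends_with_replacement_char(const std::string& text) {
//         constexpr char replacement[] = "\xef\xbf\xbd";  // UTF-8 bytes of U+FFFD
//         return text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0;
//     }
//
// When it fires, the last decoded length is set to -1 so the incomplete tail is skipped for
// now and printed later, once further tokens complete the sequence.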
+ constexpr char replacement[] = "\xef\xbf\xbd"; if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) { m_decoded_lengths[m_decoded_lengths.size() - 1] = -1; // Don't print incomplete text - return on_finalized_subword_callback(res.str()); + return m_on_finalized_subword_callback(res.str()); } auto print_until = m_decoded_lengths[m_decoded_lengths.size() - delay_n_tokens]; if (print_until != -1 && print_until > m_printed_len) { @@ -46,7 +48,7 @@ bool TextCallbackStreamer::put(int64_t token) { res << std::string_view{text.data() + m_printed_len, print_until - m_printed_len} << std::flush; m_printed_len = print_until; } - return on_finalized_subword_callback(res.str()); + return m_on_finalized_subword_callback(res.str()); } void TextCallbackStreamer::end() { @@ -58,7 +60,7 @@ void TextCallbackStreamer::end() { m_tokens_cache.clear(); m_decoded_lengths.clear(); m_printed_len = 0; - on_finalized_subword_callback(res.str()); + m_on_finalized_subword_callback(res.str()); return; } diff --git a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp index 2c5fab5700..b8675bd9d3 100644 --- a/src/cpp/src/text_callback_streamer.hpp +++ b/src/cpp/src/text_callback_streamer.hpp @@ -9,20 +9,23 @@ namespace ov { namespace genai { -class TextCallbackStreamer: public StreamerBase { +class TextCallbackStreamer : public StreamerBase { public: - bool put(int64_t token) override; - void end() override; - TextCallbackStreamer(const Tokenizer& tokenizer, std::function callback); - std::function on_finalized_subword_callback = [](std::string words)->bool { return false; }; + bool put(int64_t token) override; + void end() override; protected: Tokenizer m_tokenizer; std::vector m_tokens_cache; std::vector m_decoded_lengths; size_t m_printed_len = 0; + +private: + std::function m_on_finalized_subword_callback = [](std::string words) -> bool { + return false; + }; }; } // namespace genai diff --git a/src/cpp/src/threaded_streamer.hpp b/src/cpp/src/threaded_streamer.hpp new file mode 100644 index 0000000000..84c45877a1 --- /dev/null +++ b/src/cpp/src/threaded_streamer.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "openvino/genai/llm_pipeline.hpp" +#include "openvino/genai/tokenizer.hpp" +#include "synchronized_queue.hpp" +#include "text_callback_streamer.hpp" + +namespace ov { +namespace genai { + +class ThreadedStreamerWrapper { +public: + ThreadedStreamerWrapper(const StreamerVariant& streamer, Tokenizer& tokenizer) { + if (auto streamer_obj = std::get_if(&streamer)) { + m_streamer_ptr = nullptr; + } else if (auto streamer_obj = std::get_if>(&streamer)) { + m_streamer_ptr = *streamer_obj; + } else if (auto callback = std::get_if>(&streamer)) { + m_streamer_ptr = std::make_shared(tokenizer, *callback); + } + } + + void start() { + if (!m_streamer_ptr) { + return; + } + + m_worker_thread = std::make_shared(&ThreadedStreamerWrapper::_worker, this); + } + + void put(const std::vector& tokens) { + if (!m_streamer_ptr) { + return; + } + + m_squeue.push(tokens); + } + + void put(const int64_t token) { + if (!m_streamer_ptr) { + return; + } + + m_squeue.push(token); + } + + void end() { + if (!m_streamer_ptr) { + return; + } + + // push stop token to unblock squeue.pull + m_squeue.push(std::monostate()); + + if (m_worker_thread && m_worker_thread->joinable()) { + m_worker_thread->join(); + } + + m_streamer_ptr->end(); + } + + bool is_dropped() const { + if (!m_streamer_ptr) { + 
return false; + } + + return m_dropped; + } + + bool has_callback() const { + return static_cast(m_streamer_ptr); + } + +private: + std::shared_ptr m_streamer_ptr = nullptr; + std::shared_ptr m_worker_thread = nullptr; + SynchronizedQueue, std::monostate>> m_squeue; + + std::atomic m_dropped = false; + + void _worker() { + while (true) { + // wait for queue pull + std::variant, std::monostate> token_variant = m_squeue.pull(); + + // wait for streamer_ptr result + if (auto token = std::get_if(&token_variant)) { + m_dropped = m_streamer_ptr->put(*token); + } else if (auto tokens = std::get_if>(&token_variant)) { + m_dropped = m_streamer_ptr->put(*tokens); + } else if (auto stop_token = std::get_if(&token_variant)) { + break; + } else { + OPENVINO_THROW("Internal error: unsupported threaded streamer value"); + } + + if (m_dropped) { + break; + } + } + } +}; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/whisper/models/decoder.cpp b/src/cpp/src/whisper/models/decoder.cpp index c09a84ccdd..4622023ed7 100644 --- a/src/cpp/src/whisper/models/decoder.cpp +++ b/src/cpp/src/whisper/models/decoder.cpp @@ -30,7 +30,11 @@ std::pair WhisperDecoder::detect_language(const ov::Tensor& enco Tensor beam_idx_tensor{ov::element::i32, {1}}; beam_idx_tensor.data()[0] = 0; - auto [output_tensor, infer_ms] = decode(encoder_hidden_state, input_ids_tensor, beam_idx_tensor); + const auto infer_start = std::chrono::steady_clock::now(); + start_async(encoder_hidden_state, input_ids_tensor, beam_idx_tensor); + + auto output_tensor = wait(); + const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); int64_t output_token = ov::genai::utils::argmax(output_tensor, 0); diff --git a/src/cpp/src/whisper/models/decoder.hpp b/src/cpp/src/whisper/models/decoder.hpp index 6eeba2b387..50b085203c 100644 --- a/src/cpp/src/whisper/models/decoder.hpp +++ b/src/cpp/src/whisper/models/decoder.hpp @@ -17,9 +17,8 @@ class WhisperDecoder { std::pair detect_language(const Tensor& encoder_hidden_state, const int64_t decoder_start_token_id); - virtual std::pair decode(const Tensor& encoder_hidden_state, - const Tensor& input_ids, - const Tensor& beam_idx) = 0; + virtual void start_async(const Tensor& encoder_hidden_state, const Tensor& input_ids, const Tensor& beam_idx) = 0; + virtual Tensor wait() = 0; virtual void reset_state() = 0; diff --git a/src/cpp/src/whisper/models/statefull_decoder.cpp b/src/cpp/src/whisper/models/statefull_decoder.cpp index affdfffaf5..a47d54bfcd 100644 --- a/src/cpp/src/whisper/models/statefull_decoder.cpp +++ b/src/cpp/src/whisper/models/statefull_decoder.cpp @@ -22,9 +22,9 @@ WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& mo m_request = compiled_model.create_infer_request(); } -std::pair WhisperStatefullDecoder::decode(const Tensor& encoder_hidden_state, - const Tensor& input_ids, - const Tensor& beam_idx) { +void WhisperStatefullDecoder::start_async(const Tensor& encoder_hidden_state, + const Tensor& input_ids, + const Tensor& beam_idx) { const size_t batch_size = input_ids.get_shape().at(0); const size_t seq_len = input_ids.get_shape().at(1); @@ -34,15 +34,14 @@ std::pair WhisperStatefullDecoder::decode(const Tensor& encod m_request.set_tensor("input_ids", input_ids); m_request.set_tensor("beam_idx", beam_idx); - const auto infer_start = std::chrono::steady_clock::now(); - m_request.infer(); - const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); - - auto 
output_tensor = m_request.get_tensor("logits"); - - return {output_tensor, infer_ms}; + m_request.start_async(); }; +Tensor WhisperStatefullDecoder::wait() { + m_request.wait(); + return m_request.get_tensor("logits"); +} + void WhisperStatefullDecoder::_set_cache_position_tensor(const size_t seq_len) { ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position"); diff --git a/src/cpp/src/whisper/models/statefull_decoder.hpp b/src/cpp/src/whisper/models/statefull_decoder.hpp index c8c733e943..9834c65ef9 100644 --- a/src/cpp/src/whisper/models/statefull_decoder.hpp +++ b/src/cpp/src/whisper/models/statefull_decoder.hpp @@ -14,9 +14,9 @@ class WhisperStatefullDecoder : public WhisperDecoder { const std::string& device, const ov::AnyMap& properties); - std::pair decode(const Tensor& encoder_hidden_state, - const Tensor& input_ids, - const Tensor& beam_idx) override; + void start_async(const Tensor& encoder_hidden_state, const Tensor& input_ids, const Tensor& beam_idx) override; + + Tensor wait() override; void reset_state() override; diff --git a/src/cpp/src/whisper/models/with_past_decoder.cpp b/src/cpp/src/whisper/models/with_past_decoder.cpp index 1ade0dea6b..2ec3e63fa2 100644 --- a/src/cpp/src/whisper/models/with_past_decoder.cpp +++ b/src/cpp/src/whisper/models/with_past_decoder.cpp @@ -97,9 +97,9 @@ WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& mode m_request_decoder_with_past = compiled_model.create_infer_request(); } -std::pair WhisperWithPastDecoder::decode(const Tensor& encoder_hidden_state, - const Tensor& input_ids, - const Tensor& beam_idx) { +void WhisperWithPastDecoder::start_async(const Tensor& encoder_hidden_state, + const Tensor& input_ids, + const Tensor& beam_idx) { const bool is_initial_step = m_cache_position == 0; ov::InferRequest& request = is_initial_step ? m_request_decoder : m_request_decoder_with_past; @@ -117,15 +117,20 @@ std::pair WhisperWithPastDecoder::decode(const Tensor& encoder_hi _set_past_key_value(beam_idx); - const auto infer_start = std::chrono::steady_clock::now(); - request.infer(); - const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); + request.start_async(); +}; + +Tensor WhisperWithPastDecoder::wait() { + const bool is_initial_step = m_cache_position == 0; + ov::InferRequest& request = is_initial_step ? 
m_request_decoder : m_request_decoder_with_past; + + request.wait(); - auto output_tensor = request.get_tensor("logits"); + const size_t seq_length = request.get_tensor("input_ids").get_shape().at(1); m_cache_position += seq_length; - return {output_tensor, infer_ms}; + return request.get_tensor("logits"); } void WhisperWithPastDecoder::_set_past_key_value(const Tensor& beam_idx) { diff --git a/src/cpp/src/whisper/models/with_past_decoder.hpp b/src/cpp/src/whisper/models/with_past_decoder.hpp index 1610c60d4e..b268903802 100644 --- a/src/cpp/src/whisper/models/with_past_decoder.hpp +++ b/src/cpp/src/whisper/models/with_past_decoder.hpp @@ -14,9 +14,9 @@ class WhisperWithPastDecoder : public WhisperDecoder { const std::string& device, const ov::AnyMap& properties); - std::pair decode(const Tensor& encoder_hidden_state, - const Tensor& input_ids, - const Tensor& beam_idx) override; + void start_async(const Tensor& encoder_hidden_state, const Tensor& input_ids, const Tensor& beam_idx) override; + + Tensor wait() override; void reset_state() override; diff --git a/src/cpp/src/whisper/streamer.cpp b/src/cpp/src/whisper/streamer.cpp deleted file mode 100644 index cf84a0b9b2..0000000000 --- a/src/cpp/src/whisper/streamer.cpp +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (C) 2023-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include "streamer.hpp" - -#include "text_callback_streamer.hpp" - -namespace ov { -namespace genai { - -bool ChunkTextCallbackStreamer::put(int64_t token) { - return ov::genai::TextCallbackStreamer::put(token); -} - -bool ChunkTextCallbackStreamer::put_chunk(std::vector tokens) { - if (tokens.empty()) { - return false; - } - - if (tokens.size() > 1) { - m_tokens_cache.insert(m_tokens_cache.end(), tokens.begin(), tokens.end() - 1); - } - - return ov::genai::TextCallbackStreamer::put(tokens.back()); -} - -void ChunkTextCallbackStreamer::end() { - ov::genai::TextCallbackStreamer::end(); -} - -} // namespace genai -} // namespace ov diff --git a/src/cpp/src/whisper/streamer.hpp b/src/cpp/src/whisper/streamer.hpp index df81b03f20..712a05b02d 100644 --- a/src/cpp/src/whisper/streamer.hpp +++ b/src/cpp/src/whisper/streamer.hpp @@ -10,14 +10,24 @@ namespace ov { namespace genai { -class ChunkTextCallbackStreamer : private TextCallbackStreamer, public ChunkStreamerBase { +class ChunkToBaseStreamerAdapter : public StreamerBase { public: - bool put(int64_t token) override; - bool put_chunk(std::vector tokens) override; - void end() override; + ChunkToBaseStreamerAdapter(std::shared_ptr chunk_streamer) : m_chunk_streamer{chunk_streamer} {} - ChunkTextCallbackStreamer(const Tokenizer& tokenizer, std::function callback) - : TextCallbackStreamer(tokenizer, callback){}; + bool put(const std::vector& tokens) override { + return m_chunk_streamer->put_chunk(tokens); + } + + bool put(int64_t token) override { + return m_chunk_streamer->put(token); + } + + void end() override { + return m_chunk_streamer->end(); + } + +private: + std::shared_ptr m_chunk_streamer; }; } // namespace genai diff --git a/src/cpp/src/whisper/whisper.cpp b/src/cpp/src/whisper/whisper.cpp index f773debf77..75991b0608 100644 --- a/src/cpp/src/whisper/whisper.cpp +++ b/src/cpp/src/whisper/whisper.cpp @@ -50,7 +50,7 @@ void process_whisper_logits(ov::Tensor logits, std::pair decode(std::shared_ptr decoder, const std::vector& input_ids, const ov::Tensor& encoder_hidden_state, - const std::shared_ptr streamer_ptr, + const std::shared_ptr streamer, ov::genai::Sampler& sampler, ov::genai::SequenceGroup::Ptr 
sequence_group, const bool return_timestamps, @@ -59,14 +59,14 @@ std::pair decode(std::shared_ptr(sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters()); - auto stream_generated_tokens = [&streamer_ptr, &handle, &return_timestamps]() { - if (return_timestamps || !streamer_ptr || !handle->can_read()) { + auto on_generated_tokens = [&streamer, &handle, &return_timestamps]() { + if (!streamer || return_timestamps || !handle->can_read()) { return; } std::unordered_map token = handle->read(); for (const auto& gen_token : token.begin()->second.generated_ids) { - if (streamer_ptr->put(gen_token)) { + if (streamer->put(gen_token)) { handle->drop(); break; } @@ -80,9 +80,12 @@ std::pair decode(std::shared_ptrdecode(encoder_hidden_state, input_ids_tensor, beam_idx); + const auto infer_start = std::chrono::steady_clock::now(); + decoder->start_async(encoder_hidden_state, input_ids_tensor, beam_idx); + auto logits = decoder->wait(); const auto infer_end = std::chrono::steady_clock::now(); + const auto infer_ms = ov::genai::PerfMetrics::get_microsec(infer_end - infer_start); raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); raw_metrics.m_token_infer_durations.emplace_back(infer_ms); raw_metrics.m_new_token_times.emplace_back(infer_end); @@ -96,7 +99,7 @@ std::pair decode(std::shared_ptrset_output_seq_len(output_sequence_len); sampler.sample({sequence_group}, logits); - stream_generated_tokens(); + on_generated_tokens(); // "Generation" phase while (!sequence_group->has_finished() && !sequence_group->handle_dropped()) { @@ -138,11 +141,19 @@ std::pair decode(std::shared_ptrget_generated_ids(); } - auto [logits, infer_ms] = decoder->decode(encoder_hidden_state, - new_input_ids, - ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()}); + const auto infer_start = std::chrono::steady_clock::now(); + + decoder->start_async(encoder_hidden_state, + new_input_ids, + ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()}); + + // infer stats can be affected by heavy streamer callback + on_generated_tokens(); + + auto logits = decoder->wait(); const auto infer_end = std::chrono::steady_clock::now(); + const auto infer_ms = ov::genai::PerfMetrics::get_microsec(infer_end - infer_start); raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); raw_metrics.m_token_infer_durations.emplace_back(infer_ms); raw_metrics.m_new_token_times.emplace_back(infer_end); @@ -151,9 +162,10 @@ std::pair decode(std::shared_ptrget_sampling_parameters(); @@ -255,7 +267,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& ov::InferRequest& encoder, std::shared_ptr decoder, WhisperFeatureExtractor& feature_extractor, - const std::shared_ptr streamer, + const std::shared_ptr streamer, Sampler& sampler) { size_t max_new_tokens = config.get_max_new_tokens(); @@ -329,7 +341,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& extracted_segments.non_timestamp_tokens.begin(), extracted_segments.non_timestamp_tokens.end()); - if (streamer && streamer->put_chunk(extracted_segments.non_timestamp_tokens)) { + if (streamer && streamer->put(extracted_segments.non_timestamp_tokens)) { cancelled = true; break; } diff --git a/src/cpp/src/whisper/whisper.hpp b/src/cpp/src/whisper/whisper.hpp index 88957119e8..1ed1b15a8d 100644 --- a/src/cpp/src/whisper/whisper.hpp +++ b/src/cpp/src/whisper/whisper.hpp @@ -10,6 +10,7 @@ #include "openvino/genai/whisper_generation_config.hpp" #include 
"openvino/genai/whisper_pipeline.hpp" #include "sampler.hpp" +#include "streamer.hpp" #include "whisper_config.hpp" #include "whisper_feature_extractor.hpp" #include "whisper_models.hpp" @@ -36,7 +37,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& ov::InferRequest& encoder, std::shared_ptr decoder, WhisperFeatureExtractor& feature_extractor, - const std::shared_ptr streamer, + const std::shared_ptr streamer, Sampler& sampler); } // namespace genai diff --git a/src/cpp/src/whisper_pipeline.cpp b/src/cpp/src/whisper_pipeline.cpp index 19a1a4831d..06ef1f818c 100644 --- a/src/cpp/src/whisper_pipeline.cpp +++ b/src/cpp/src/whisper_pipeline.cpp @@ -28,19 +28,23 @@ ov::genai::OptionalWhisperGenerationConfig get_config_from_map(const ov::AnyMap& } } -ov::genai::ChunkStreamerVariant get_chunk_streamer_from_map(const ov::AnyMap& config_map) { - ov::genai::ChunkStreamerVariant streamer = std::monostate(); +ov::genai::StreamerVariant get_streamer_from_map(const ov::AnyMap& config_map) { + ov::genai::StreamerVariant streamer = std::monostate(); if (config_map.count(ov::genai::utils::STREAMER_ARG_NAME)) { auto any_val = config_map.at(ov::genai::utils::STREAMER_ARG_NAME); - if (any_val.is>()) { - streamer = any_val.as>(); + if (any_val.is>()) { + streamer = any_val.as>(); } else if (any_val.is>()) { streamer = any_val.as>(); + } else if (any_val.is>()) { + auto chunk_streamer = any_val.as>(); + streamer = std::make_shared(chunk_streamer); } } return streamer; } + } // namespace namespace ov { @@ -72,7 +76,7 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, OptionalWhisperGenerationConfig generation_config, - ChunkStreamerVariant streamer) override { + StreamerVariant streamer) override { auto start_time = std::chrono::steady_clock::now(); WhisperGenerationConfig config = (generation_config.has_value()) ? 
*generation_config : m_generation_config; @@ -81,13 +85,14 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi config.set_eos_token_id(m_generation_config.eos_token_id); config.validate(); - std::shared_ptr streamer_ptr; + std::shared_ptr streamer_ptr = nullptr; + if (auto streamer_obj = std::get_if(&streamer)) { streamer_ptr = nullptr; - } else if (auto streamer_obj = std::get_if>(&streamer)) { + } else if (auto streamer_obj = std::get_if>(&streamer)) { streamer_ptr = *streamer_obj; } else if (auto callback = std::get_if>(&streamer)) { - streamer_ptr = std::make_shared(m_tokenizer, *callback); + streamer_ptr = std::make_shared(m_tokenizer, *callback); } auto [context_tokens, tokenization_duration_microseconds] = prepare_context_tokens(config, m_tokenizer); @@ -101,6 +106,7 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi m_feature_extractor, streamer_ptr, m_sampler); + auto decode_start_time = std::chrono::steady_clock::now(); WhisperDecodedResults result{std::vector{m_tokenizer.decode(generate_result.output_tokens)}, std::vector{1.f}}; generate_result.perf_metrics.raw_metrics.detokenization_durations.emplace_back( @@ -142,13 +148,8 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi Sampler m_sampler; }; -std::pair streamer(ChunkStreamerVariant func) { - if (auto streamer_obj = std::get_if>(&func)) { - return {utils::STREAMER_ARG_NAME, Any::make>(*streamer_obj)}; - } else { - auto callback = std::get>(func); - return {utils::STREAMER_ARG_NAME, Any::make>(callback)}; - } +std::pair streamer(std::shared_ptr func) { + return {utils::STREAMER_ARG_NAME, Any::make>(func)}; } std::pair generation_config(const WhisperGenerationConfig& config) { @@ -173,17 +174,24 @@ ov::genai::WhisperPipeline::WhisperPipeline(const std::filesystem::path& models_ ov::genai::WhisperDecodedResults ov::genai::WhisperPipeline::generate(const RawSpeechInput& raw_speech_input, OptionalWhisperGenerationConfig generation_config, - ChunkStreamerVariant streamer) { + StreamerVariant streamer) { return m_impl->generate(raw_speech_input, generation_config, streamer); } +ov::genai::WhisperDecodedResults ov::genai::WhisperPipeline::generate(const RawSpeechInput& raw_speech_input, + WhisperGenerationConfig generation_config, + std::shared_ptr streamer) { + StreamerVariant _streamer = std::make_shared(streamer); + return m_impl->generate(raw_speech_input, generation_config, _streamer); +} + ov::genai::WhisperDecodedResults ov::genai::WhisperPipeline::generate(const RawSpeechInput& raw_speech_input, const ov::AnyMap& config_map) { auto config_arg = get_config_from_map(config_map); WhisperGenerationConfig config = (config_arg.has_value()) ? 
*config_arg : get_generation_config(); config.update_generation_config(config_map); - return m_impl->generate(raw_speech_input, config, get_chunk_streamer_from_map(config_map)); + return m_impl->generate(raw_speech_input, config, get_streamer_from_map(config_map)); } ov::genai::WhisperGenerationConfig ov::genai::WhisperPipeline::get_generation_config() const { @@ -205,3 +213,5 @@ void ov::genai::WhisperPipeline::set_generation_config(const WhisperGenerationCo } ov::genai::WhisperPipeline::~WhisperPipeline() = default; + +ov::genai::ChunkStreamerBase::~ChunkStreamerBase() = default; diff --git a/src/cpp/src/whisper_pipeline_base.hpp b/src/cpp/src/whisper_pipeline_base.hpp index 0aa4790cb8..bf8c78e467 100644 --- a/src/cpp/src/whisper_pipeline_base.hpp +++ b/src/cpp/src/whisper_pipeline_base.hpp @@ -4,12 +4,10 @@ #pragma once #include "openvino/genai/whisper_pipeline.hpp" +#include "utils.hpp" #include "whisper/whisper_config.hpp" #include "whisper/whisper_feature_extractor.hpp" -#include "utils.hpp" - - namespace ov { namespace genai { @@ -30,7 +28,7 @@ class WhisperPipeline::WhisperPipelineImplBase { virtual WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, OptionalWhisperGenerationConfig generation_config, - ChunkStreamerVariant streamer) = 0; + StreamerVariant streamer) = 0; virtual ~WhisperPipelineImplBase() = default; }; diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp index e49a25e2d2..3ff40c3899 100644 --- a/src/cpp/src/whisper_pipeline_static.cpp +++ b/src/cpp/src/whisper_pipeline_static.cpp @@ -295,7 +295,7 @@ std::pair> full_decode(ov::Tensor& encoder_hidden_sta const size_t max_new_tokens, const bool return_timestamps, ov::genai::RawPerfMetrics& raw_metrics, - const std::shared_ptr streamer) { + const std::shared_ptr streamer) { int64_t output_token = decode(encoder_hidden_state, models.decoder, init_ids, config, raw_metrics, true, return_timestamps); std::vector output_tokens{output_token}; @@ -564,7 +564,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate( const RawSpeechInput& raw_speech_input, OptionalWhisperGenerationConfig generation_config, - ChunkStreamerVariant streamer) { + StreamerVariant streamer) { auto start_time = std::chrono::steady_clock::now(); WhisperGenerationConfig config = (generation_config.has_value()) ? 
*generation_config : m_generation_config; config.validate(); @@ -572,13 +572,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate( OPENVINO_ASSERT(!config.initial_prompt.has_value(), "'initial_prompt' parameter is not supported on NPU device."); OPENVINO_ASSERT(!config.hotwords.has_value(), "'hotwords' parameter is not supported on NPU device."); - std::shared_ptr streamer_ptr; + std::shared_ptr streamer_ptr; if (auto streamer_obj = std::get_if(&streamer)) { streamer_ptr = nullptr; - } else if (auto streamer_obj = std::get_if>(&streamer)) { + } else if (auto streamer_obj = std::get_if>(&streamer)) { streamer_ptr = *streamer_obj; } else if (auto callback = std::get_if>(&streamer)) { - streamer_ptr = std::make_shared(m_tokenizer, *callback); + streamer_ptr = std::make_shared(m_tokenizer, *callback); } size_t max_new_tokens = config.get_max_new_tokens(); @@ -654,7 +654,7 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate( extracted_segments.non_timestamp_tokens.begin(), extracted_segments.non_timestamp_tokens.end()); - if (streamer_ptr && streamer_ptr->put_chunk(extracted_segments.non_timestamp_tokens)) { + if (streamer_ptr && streamer_ptr->put(extracted_segments.non_timestamp_tokens)) { cancelled = true; break; } diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp index 48425356b2..e389aeaf29 100644 --- a/src/cpp/src/whisper_pipeline_static.hpp +++ b/src/cpp/src/whisper_pipeline_static.hpp @@ -19,10 +19,11 @@ class DecoderCache { public: DecoderCache() = default; DecoderCache(std::shared_ptr model, ov::AnyMap properties) - : m_decoder_model(model) - , m_properties(properties) {} + : m_decoder_model(model), + m_properties(properties) {} ov::InferRequest get_model(uint8_t input_ids_size); + private: std::unordered_map m_cache; std::shared_ptr m_decoder_model; @@ -35,7 +36,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, OptionalWhisperGenerationConfig generation_config, - ChunkStreamerVariant streamer) override; + StreamerVariant streamer) override; private: WhisperInitializedModels m_models; diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 92e624e60e..7b4ae06acb 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1514,9 +1514,9 @@ class StreamerBase: """ End is called at the end of generation. It can be used to flush cache if your own streamer has one """ - def put(self, token: int) -> bool: + def put(self, token: int | list[int]) -> bool: """ - Put is called every time new token is decoded. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops + Put is called every time new token or vector of tokens is decoded. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops """ class T5EncoderModel: """ @@ -2125,7 +2125,142 @@ class WhisperPipeline: models_path (os.PathLike): Path to the model file. device (str): Device to run the model on (e.g., CPU, GPU). 
""" - def generate(self, raw_speech_input: list[float], generation_config: WhisperGenerationConfig | None = None, streamer: typing.Callable[[str], bool] | ChunkStreamerBase | None = None, **kwargs) -> WhisperDecodedResults: + @typing.overload + def generate(self, raw_speech_input: list[float], generation_config: WhisperGenerationConfig | None = None, streamer: typing.Callable[[str], bool] | StreamerBase | None = None, **kwargs) -> WhisperDecodedResults: + """ + High level generate that receives raw speech as a vector of floats and returns decoded output. + + :param raw_speech_input: inputs in the form of list of floats. Required to be normalized to near [-1, 1] range and have 16k Hz sampling rate. + :type raw_speech_input: List[float] + + :param generation_config: generation_config + :type generation_config: WhisperGenerationConfig or a Dict + + :param streamer: streamer either as a lambda with a boolean returning flag whether generation should be stopped. + Streamer supported for short-form audio (< 30 seconds) with `return_timestamps=False` only + :type : Callable[[str], bool], ov.genai.StreamerBase + + :param kwargs: arbitrary keyword arguments with keys corresponding to WhisperGenerationConfig fields. + :type : Dict + + :return: return results in decoded form + :rtype: WhisperDecodedResults + + + WhisperGenerationConfig + + Whisper specific parameters: + :param decoder_start_token_id: Corresponds to the ”<|startoftranscript|>” token. + :type decoder_start_token_id: int + + :param pad_token_id: Padding token id. + :type pad_token_id: int + + :param translate_token_id: Translate token id. + :type translate_token_id: int + + :param transcribe_token_id: Transcribe token id. + :type transcribe_token_id: int + + :param no_timestamps_token_id: No timestamps token id. + :type no_timestamps_token_id: int + + :param prev_sot_token_id: Corresponds to the ”<|startofprev|>” token. + :type prev_sot_token_id: int + + :param is_multilingual: + :type is_multilingual: bool + + :param begin_suppress_tokens: A list containing tokens that will be suppressed at the beginning of the sampling process. + :type begin_suppress_tokens: list[int] + + :param suppress_tokens: A list containing the non-speech tokens that will be suppressed during generation. + :type suppress_tokens: list[int] + + :param language: Language token to use for generation in the form of <|en|>. + You can find all the possible language tokens in the generation_config.json lang_to_id dictionary. + :type language: Optional[str] + + :param lang_to_id: Language token to token_id map. Initialized from the generation_config.json lang_to_id dictionary. + :type lang_to_id: Dict[str, int] + + :param task: Task to use for generation, either “translate” or “transcribe” + :type task: int + + :param return_timestamps: If `true` the pipeline will return timestamps along the text for *segments* of words in the text. + For instance, if you get + WhisperDecodedResultChunk + start_ts = 0.5 + end_ts = 1.5 + text = " Hi there!" + then it means the model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds. + Note that a segment of text refers to a sequence of one or more words, rather than individual words. + :type return_timestamps: bool + + :param initial_prompt: Initial prompt tokens passed as a previous transcription (after `<|startofprev|>` token) to the first processing + window. Can be used to steer the model to use particular spellings or styles. 
+ + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type initial_prompt: Optional[str] + + :param hotwords: Hotwords tokens passed as a previous transcription (after `<|startofprev|>` token) to the all processing windows. + Can be used to steer the model to use particular spellings or styles. + + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type hotwords: Optional[str] + + Generic parameters: + max_length: the maximum length the generated tokens can have. Corresponds to the length of the input prompt + + max_new_tokens. Its effect is overridden by `max_new_tokens`, if also set. + max_new_tokens: the maximum numbers of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. + min_new_tokens: set 0 probability for eos_token_id for the first eos_token_id generated tokens. + ignore_eos: if set to true, then generation will not stop even if token is met. + eos_token_id: token_id of (end of sentence) + stop_strings: a set of strings that will cause pipeline to stop generating further tokens. + include_stop_str_in_output: if set to true stop string that matched generation will be included in generation output (default: false) + stop_token_ids: a set of tokens that will cause pipeline to stop generating further tokens. + echo: if set to true, the model will echo the prompt in the output. + logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned. + Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0). + + repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty. + presence_penalty: reduces absolute log prob if the token was generated at least once. + frequency_penalty: reduces absolute log prob as many times as the token was generated. + + Beam search specific parameters: + num_beams: number of beams for beam search. 1 disables beam search. + num_beam_groups: number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + diversity_penalty: value is subtracted from a beam's score if it generates the same token as any beam from other group at a particular time. + length_penalty: exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to + the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log + likelihood of the sequence (i.e. negative), length_penalty > 0.0 promotes longer sequences, while + length_penalty < 0.0 encourages shorter sequences. + num_return_sequences: the number of sequences to return for grouped beam search decoding. + no_repeat_ngram_size: if set to int > 0, all ngrams of that size can only occur once. + stop_criteria: controls the stopping condition for grouped beam search. 
It accepts the following values: + "openvino_genai.StopCriteria.EARLY", where the generation stops as soon as there are `num_beams` complete candidates; + "openvino_genai.StopCriteria.HEURISTIC" is applied and the generation stops when is it very unlikely to find better candidates; + "openvino_genai.StopCriteria.NEVER", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). + + Random sampling parameters: + temperature: the value used to modulate token probabilities for random sampling. + top_p: if set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering. + do_sample: whether or not to use multinomial random sampling that add up to `top_p` or higher are kept. + num_return_sequences: the number of sequences to generate from a single prompt. + """ + @typing.overload + def generate(self, raw_speech_input: list[float], generation_config: WhisperGenerationConfig, streamer: ChunkStreamerBase, **kwargs) -> WhisperDecodedResults: """ High level generate that receives raw speech as a vector of floats and returns decoded output. diff --git a/src/python/py_openvino_genai.cpp b/src/python/py_openvino_genai.cpp index 8b8bd831b0..a307ebebcf 100644 --- a/src/python/py_openvino_genai.cpp +++ b/src/python/py_openvino_genai.cpp @@ -73,6 +73,14 @@ class ConstructableStreamer: public StreamerBase { token // Argument(s) ); } + bool put(const std::vector& tokens) override { + PYBIND11_OVERRIDE_PURE( + bool, // Return type + StreamerBase, // Parent class + put, // Name of function in C++ (must match Python name) + tokens // Argument(s) + ); + } void end() override { PYBIND11_OVERRIDE_PURE(void, StreamerBase, end); } @@ -115,7 +123,18 @@ PYBIND11_MODULE(py_openvino_genai, m) { py::class_>(m, "StreamerBase", streamer_base_docstring) // Change the holder form unique_ptr to shared_ptr .def(py::init<>()) - .def("put", &StreamerBase::put, "Put is called every time new token is decoded. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops", py::arg("token")) + .def("put", + [](StreamerBase& self, std::variant> token) { + if (auto _token = std::get_if(&token)) { + return self.put(*_token); + } else { + auto tokens = std::get>(token); + return self.put(tokens); + } + }, + "Put is called every time new token or vector of tokens is decoded. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops", + py::arg("token") + ) .def("end", &StreamerBase::end, "End is called at the end of generation. 
 
     init_tokenizer(m);
diff --git a/src/python/py_whisper_pipeline.cpp b/src/python/py_whisper_pipeline.cpp
index aac14c258a..78b3f813d9 100644
--- a/src/python/py_whisper_pipeline.cpp
+++ b/src/python/py_whisper_pipeline.cpp
@@ -12,10 +12,10 @@
 #include "openvino/genai/whisper_pipeline.hpp"
 #include "py_utils.hpp"
 #include "tokenizers_path.hpp"
+#include "whisper/streamer.hpp"
 
 namespace py = pybind11;
 
 using ov::genai::ChunkStreamerBase;
-using ov::genai::ChunkStreamerVariant;
 using ov::genai::DecodedResults;
 using ov::genai::GenerationConfig;
 using ov::genai::OptionalWhisperGenerationConfig;
@@ -30,8 +30,6 @@
 using ov::genai::WhisperGenerationConfig;
 using ov::genai::WhisperPerfMetrics;
 using ov::genai::WhisperPipeline;
 using ov::genai::WhisperRawPerfMetrics;
-using PyBindChunkStreamerVariant =
-    std::variant<std::function<bool(py::str)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;
 
 namespace pyutils = ov::genai::pybind::utils;
 
@@ -245,30 +243,10 @@
     }
 };
 
-ChunkStreamerVariant pystreamer_to_chunk_streamer(const PyBindChunkStreamerVariant& py_streamer) {
-    return std::visit(
-        pyutils::overloaded{[](const std::function<bool(py::str)>& py_callback) {
-                                // Wrap python streamer with manual utf-8 decoding. Do not rely
-                                // on pybind automatic decoding since it raises exceptions on incomplete
-                                // strings.
-                                return static_cast<ChunkStreamerVariant>([py_callback](std::string subword) -> bool {
-                                    auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
-                                    return py_callback(py::reinterpret_borrow<py::str>(py_str));
-                                });
-                            },
-                            [](std::shared_ptr<ChunkStreamerBase> streamer_cls) {
-                                return static_cast<ChunkStreamerVariant>(streamer_cls);
-                            },
-                            [](std::monostate none) {
-                                return static_cast<ChunkStreamerVariant>(none);
-                            }},
-        py_streamer);
-}
-
 py::object call_whisper_common_generate(WhisperPipeline& pipe,
                                         const RawSpeechInput& raw_speech_input,
                                         const OptionalWhisperGenerationConfig& config,
-                                        const PyBindChunkStreamerVariant& py_streamer,
+                                        const StreamerVariant& streamer,
                                         const py::kwargs& kwargs) {
     // whisper config should be initialized from generation_config.json in case only kwargs are provided;
     // otherwise it would be initialized with default values, which is unexpected for the kwargs use case
@@ -277,9 +255,13 @@
 
     auto updated_config = update_whisper_config_from_kwargs(base_config, kwargs);
 
-    ChunkStreamerVariant streamer = pystreamer_to_chunk_streamer(py_streamer);
+    ov::genai::WhisperDecodedResults results;
+    {
+        py::gil_scoped_release rel;
+        results = pipe.generate(raw_speech_input, updated_config, streamer);
+    }
 
-    return py::cast(pipe.generate(raw_speech_input, updated_config, streamer));
+    return py::cast(results);
 }
 
 }  // namespace
 
@@ -397,9 +379,10 @@
         [](WhisperPipeline& pipe,
            const RawSpeechInput& raw_speech_input,
            const OptionalWhisperGenerationConfig& generation_config,
-           const PyBindChunkStreamerVariant& streamer,
+           const pyutils::PyBindStreamerVariant& streamer,
           const py::kwargs& kwargs) -> py::typing::Union {
-            return call_whisper_common_generate(pipe, raw_speech_input, generation_config, streamer, kwargs);
+            StreamerVariant _streamer = pyutils::pystreamer_to_streamer(streamer);
+            return call_whisper_common_generate(pipe, raw_speech_input, generation_config, _streamer, kwargs);
         },
         py::arg("raw_speech_input"),
         "List of floats representing raw speech audio. "
@@ -410,6 +393,25 @@ void init_whisper_pipeline(py::module_& m) {
         "streamer",
         (whisper_generate_docstring + std::string(" \n ") + whisper_generation_config_docstring).c_str())
+        .def(
+            "generate",
+            [](WhisperPipeline& pipe,
+               const RawSpeechInput& raw_speech_input,
+               const WhisperGenerationConfig& generation_config,
+               const std::shared_ptr<ChunkStreamerBase>& streamer,
+               const py::kwargs& kwargs) -> py::typing::Union {
+                StreamerVariant _streamer = std::make_shared(streamer);
+                return call_whisper_common_generate(pipe, raw_speech_input, generation_config, _streamer, kwargs);
+            },
+            py::arg("raw_speech_input"),
+            "List of floats representing raw speech audio. "
+            "Required to be normalized to near [-1, 1] range and have 16k Hz sampling rate.",
+            py::arg("generation_config"),
+            "generation_config",
+            py::arg("streamer"),
+            "streamer",
+            (whisper_generate_docstring + std::string(" \n ") + whisper_generation_config_docstring).c_str())
+
         .def("get_tokenizer", &WhisperPipeline::get_tokenizer)
         .def("get_generation_config", &WhisperPipeline::get_generation_config, py::return_value_policy::copy)
         .def("set_generation_config", &WhisperPipeline::set_generation_config, py::arg("config"));
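With the bindings above, WhisperPipeline.generate accepts the same streamer kinds as the text pipelines (a plain callable or a StreamerBase object), while the ChunkStreamerBase overload is kept only for backward compatibility. A minimal sketch of the callable form follows; the model folder, audio file, and the use of librosa for loading 16 kHz samples are assumptions for illustration, not part of this change.

import librosa  # assumption: any loader producing normalized 16 kHz float samples works
import openvino_genai as ov_genai

pipe = ov_genai.WhisperPipeline("whisper-tiny-ov", "CPU")  # placeholder model folder
raw_speech, _ = librosa.load("sample.wav", sr=16000)       # float32 samples in roughly [-1, 1]

def on_subword(text: str) -> bool:
    print(text, end="", flush=True)
    return False  # True would stop generation early

result = pipe.generate(raw_speech.tolist(), streamer=on_subword, return_timestamps=True)
print()
print(result.texts[0])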
" @@ -410,6 +393,25 @@ void init_whisper_pipeline(py::module_& m) { "streamer", (whisper_generate_docstring + std::string(" \n ") + whisper_generation_config_docstring).c_str()) + .def( + "generate", + [](WhisperPipeline& pipe, + const RawSpeechInput& raw_speech_input, + const WhisperGenerationConfig& generation_config, + const std::shared_ptr& streamer, + const py::kwargs& kwargs) -> py::typing::Union { + StreamerVariant _streamer = std::make_shared(streamer); + return call_whisper_common_generate(pipe, raw_speech_input, generation_config, _streamer, kwargs); + }, + py::arg("raw_speech_input"), + "List of floats representing raw speech audio. " + "Required to be normalized to near [-1, 1] range and have 16k Hz sampling rate.", + py::arg("generation_config"), + "generation_config", + py::arg("streamer"), + "streamer", + (whisper_generate_docstring + std::string(" \n ") + whisper_generation_config_docstring).c_str()) + .def("get_tokenizer", &WhisperPipeline::get_tokenizer) .def("get_generation_config", &WhisperPipeline::get_generation_config, py::return_value_policy::copy) .def("set_generation_config", &WhisperPipeline::set_generation_config, py::arg("config")); diff --git a/tests/cpp/threaded_streamer.cpp b/tests/cpp/threaded_streamer.cpp new file mode 100644 index 0000000000..7af3900d72 --- /dev/null +++ b/tests/cpp/threaded_streamer.cpp @@ -0,0 +1,168 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "threaded_streamer.hpp" + +#include +#include + +using ov::genai::ThreadedStreamerWrapper; +using ::testing::An; + +class MockStreamerBase : public ov::genai::StreamerBase { +private: + std::chrono::milliseconds m_sleep_for{200}; + +public: + bool should_sleep = false; + bool should_drop = false; + + MockStreamerBase() { + ON_CALL(*this, put(An())).WillByDefault([this](int64_t token) { + if (should_sleep) { + std::this_thread::sleep_for(m_sleep_for); + } + return should_drop; + }); + + ON_CALL(*this, put(An&>())) + .WillByDefault([this](const std::vector& tokens) { + if (should_sleep) { + std::this_thread::sleep_for(m_sleep_for); + } + return should_drop; + }); + + ON_CALL(*this, end()).WillByDefault([this]() { + if (should_sleep) { + std::this_thread::sleep_for(m_sleep_for); + } + }); + } + + MOCK_METHOD(bool, put, (int64_t), (override)); + MOCK_METHOD(bool, put, (const std::vector&), (override)); + MOCK_METHOD(void, end, (), (override)); +}; + +TEST(TestThreadedStreamer, general_test) { + ov::genai::Tokenizer tokenizer{}; + const auto streamer = std::make_shared(); + + ThreadedStreamerWrapper threaded_streamer(streamer, tokenizer); + + threaded_streamer.start(); + + EXPECT_FALSE(threaded_streamer.is_dropped()); + EXPECT_CALL(*streamer, put(0)); + EXPECT_CALL(*streamer, put(1)); + + threaded_streamer.put(0); + threaded_streamer.put(1); + + EXPECT_FALSE(threaded_streamer.is_dropped()); + std::vector value{0, 1, 2}; + EXPECT_CALL(*streamer, put(value)); + threaded_streamer.put(value); + + EXPECT_FALSE(threaded_streamer.is_dropped()); + EXPECT_CALL(*streamer, end()); + threaded_streamer.end(); +} + +TEST(TestThreadedStreamer, heavy_callback_test) { + ov::genai::Tokenizer tokenizer{}; + const auto streamer = std::make_shared(); + streamer->should_sleep = true; + + ThreadedStreamerWrapper threaded_streamer(streamer, tokenizer); + + threaded_streamer.start(); + + EXPECT_FALSE(threaded_streamer.is_dropped()); + EXPECT_CALL(*streamer, put(0)).Times(3); + + std::vector value{0, 1, 2}; + EXPECT_CALL(*streamer, put(value)); + + EXPECT_CALL(*streamer, 
+
+    threaded_streamer.put(0);
+    threaded_streamer.put(0);
+    threaded_streamer.put(0);
+    EXPECT_FALSE(threaded_streamer.is_dropped());
+    threaded_streamer.put(value);
+    EXPECT_FALSE(threaded_streamer.is_dropped());
+
+    threaded_streamer.end();
+}
+
+TEST(TestThreadedStreamer, heavy_main_thread_test) {
+    ov::genai::Tokenizer tokenizer{};
+    const auto streamer = std::make_shared<MockStreamerBase>();
+
+    ThreadedStreamerWrapper threaded_streamer(streamer, tokenizer);
+
+    threaded_streamer.start();
+
+    EXPECT_CALL(*streamer, put(0));
+    EXPECT_CALL(*streamer, put(1));
+    threaded_streamer.put(0);
+    std::this_thread::sleep_for(std::chrono::milliseconds(200));
+    threaded_streamer.put(1);
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(200));
+
+    std::vector<int64_t> value{0, 1, 2};
+    EXPECT_CALL(*streamer, put(value));
+    threaded_streamer.put(value);
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(200));
+
+    EXPECT_CALL(*streamer, end());
+    threaded_streamer.end();
+}
+
+TEST(TestThreadedStreamer, put_end_test) {
+    ov::genai::Tokenizer tokenizer{};
+    const auto streamer = std::make_shared<MockStreamerBase>();
+
+    ThreadedStreamerWrapper threaded_streamer(streamer, tokenizer);
+
+    threaded_streamer.start();
+
+    EXPECT_FALSE(threaded_streamer.is_dropped());
+    EXPECT_CALL(*streamer, put(0));
+    EXPECT_CALL(*streamer, end());
+
+    threaded_streamer.put(0);
+    threaded_streamer.end();
+
+    EXPECT_FALSE(threaded_streamer.is_dropped());
+}
+
+TEST(TestThreadedStreamer, drop_test) {
+    ov::genai::Tokenizer tokenizer{};
+    const auto streamer = std::make_shared<MockStreamerBase>();
+
+    ThreadedStreamerWrapper threaded_streamer(streamer, tokenizer);
+
+    threaded_streamer.start();
+
+    EXPECT_FALSE(threaded_streamer.is_dropped());
+    EXPECT_CALL(*streamer, put(0));
+    threaded_streamer.put(0);
+
+    // wait to process prev token
+    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+    EXPECT_FALSE(threaded_streamer.is_dropped());
+    streamer->should_drop = true;
+    EXPECT_CALL(*streamer, put(1));
+    threaded_streamer.put(1);
+
+    // wait to process prev token
+    std::this_thread::sleep_for(std::chrono::milliseconds(4));
+    EXPECT_TRUE(threaded_streamer.is_dropped());
+    EXPECT_CALL(*streamer, end());
+    threaded_streamer.end();
+}
diff --git a/tests/python_tests/test_whisper_pipeline.py b/tests/python_tests/test_whisper_pipeline.py
index 683693857e..37382d43de 100644
--- a/tests/python_tests/test_whisper_pipeline.py
+++ b/tests/python_tests/test_whisper_pipeline.py
@@ -20,6 +20,7 @@
 from typing import Any, List, Dict
 
 
+
 @pytest.fixture(scope="class", autouse=True)
 def run_gc_after_test():
     """
@@ -158,8 +159,10 @@
     return pipeline.generate(sample, genai_config, streamer=streamer)
 
 
+
 MAX_DATASET_LENGTH = 30
 
+
 @functools.lru_cache(16)
 def get_whisper_dataset(language: str, long_form: bool) -> List:
     if not long_form:
@@ -184,6 +187,7 @@
     return [x["audio"]["array"] for x in ds]
 
 
+
 @pytest.fixture
 def sample_from_dataset(request):
     language = request.param.get("language", "en")
@@ -593,7 +597,100 @@ def test_random_sampling(model_descr, sample_from_dataset):
 
 @pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
-@pytest.mark.parametrize("sample_from_dataset", [{"language" : "en", "sample_id": 0}], indirect=True)
+@pytest.mark.parametrize(
+    "sample_from_dataset", [{"language": "en", "sample_id": 0}], indirect=True
+)
+@pytest.mark.precommit
+def test_sampler_interface(model_descr, sample_from_dataset):
+    _, _, _, genai_pipe = read_whisper_model(model_descr)
+
+    # callback streamer
+    streamer_result = []
+
+    def streamer(text: str):
+        streamer_result.append(text)
+        return False
+
+    result = genai_pipe.generate(sample_from_dataset, streamer=streamer)
+
+    assert "".join(streamer_result) == result.texts[0]
+
+    streamer_result = []
+
+    result = genai_pipe.generate(sample_from_dataset, return_timestamps=True, streamer=streamer)
+
+    assert "".join(streamer_result) == result.texts[0]
+
+    # streamer base
+    class Streamer(ov_genai.StreamerBase):
+        def __init__(self):
+            super().__init__()
+            self.tokens = []
+
+        def put(self, token: int | List[int]) -> bool:
+            if type(token) is int:
+                self.tokens.append(token)
+            else:
+                self.tokens += typing.cast(List[int], token)
+            return False
+
+        def end(self) -> None:
+            self.text = genai_pipe.get_tokenizer().decode(self.tokens)
+            self.tokens = []
+
+    streamer_instance = Streamer()
+    result = genai_pipe.generate(sample_from_dataset, streamer=streamer_instance)
+
+    assert streamer_instance.text == result.texts[0]
+
+    result = genai_pipe.generate(
+        sample_from_dataset,
+        streamer=streamer_instance,
+        return_timestamps=True,
+    )
+
+    assert streamer_instance.text == result.texts[0]
+
+    # chunk streamer base
+    class ChunkStreamer(ov_genai.ChunkStreamerBase):
+        def __init__(self):
+            super().__init__()
+            self.tokens = []
+            self.text = ""
+
+        def put(self, token: int) -> bool:
+            print("put: ", token)
+            self.tokens.append(token)
+            return False
+
+        def put_chunk(self, tokens: List[int]) -> bool:
+            self.tokens += tokens
+            return False
+
+        def end(self) -> None:
+            self.text = genai_pipe.get_tokenizer().decode(self.tokens)
+            self.tokens = []
+
+    chunk_streamer_instance = ChunkStreamer()
+    result = genai_pipe.generate(sample_from_dataset, streamer=chunk_streamer_instance)
+
+    assert chunk_streamer_instance.text == result.texts[0]
+
+    config = genai_pipe.get_generation_config()
+    config.return_timestamps = True
+    result = genai_pipe.generate(
+        sample_from_dataset,
+        config,
+        chunk_streamer_instance,
+    )
+
+    assert chunk_streamer_instance.text == result.texts[0]
+
+
+@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
+@pytest.mark.parametrize(
+    "sample_from_dataset", [{"language": "en", "sample_id": 0}], indirect=True
+)
 @pytest.mark.precommit
 def test_perf_metrics(model_descr, sample_from_dataset):
     model_id, path, hf_pipe, genai_pipe = read_whisper_model(model_descr)