Whisper pipeline: use parallel streamer #1642

Draft · wants to merge 21 commits into master
30 changes: 21 additions & 9 deletions samples/python/text_generation/multinomial_causal_lm.py
@@ -63,7 +63,7 @@ def get_stop_flag(self):
"""
return False

def put_word(self, word: str):
def put_word(self, word: str | None):
"""
Puts a word into the text queue.

@@ -72,17 +72,23 @@ def put_word(self, word: str):
"""
self.text_queue.put(word)

def put(self, token_id: int) -> bool:
def put(self, token: int | list[int]) -> bool:
"""
Processes a token and manages the decoding buffer. Adds decoded text to the queue.

Args:
token_id (int): The token_id to process.
token (int | list[int]): The token(s) to process.

Returns:
bool: True if generation should be stopped, False otherwise.
"""
self.tokens_cache.append(token_id)

if type(token) is int:
self.tokens_cache.append(token)
elif type(token) is list:
self.tokens_cache += token
self.decoded_lengths += [-1 for _ in token[:-1]]

text = self.tokenizer.decode(self.tokens_cache)
self.decoded_lengths.append(len(text))

@@ -132,12 +138,18 @@ def __init__(self, tokenizer, tokens_len):
super().__init__(tokenizer)
self.tokens_len = tokens_len

def put(self, token_id: int) -> bool:
if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
self.tokens_cache.append(token_id)
def put(self, token: int | list[int]) -> bool:
if (len(self.tokens_cache) + 1) % self.tokens_len == 0:
return super().put(token)

if type(token) is int:
self.tokens_cache.append(token)
self.decoded_lengths.append(-1)
return False
return super().put(token_id)
elif type(token) is list:
self.tokens_cache += token
self.decoded_lengths += [-1 for _ in token]

return False


def main():
10 changes: 10 additions & 0 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -18,6 +18,16 @@ class OPENVINO_GENAI_EXPORTS StreamerBase {
/// @brief put is called every time a new token is decoded
/// @return bool flag to indicate whether generation should be stopped; if true, generation stops
virtual bool put(int64_t token) = 0;
/// @brief put is called every time a new vector of tokens is decoded, e.g. for assisted generation or prompt lookup decoding
/// @return bool flag to indicate whether generation should be stopped; if true, generation stops
virtual bool put(const std::vector<int64_t>& tokens) {
for (const auto token : tokens) {
if (put(token)) {
return true;
}
}
return false;
};

/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;
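
A minimal sketch (not part of this diff) of how a custom streamer can rely on the new default overload: only the put(int64_t), end(), and put(const std::vector<int64_t>&) members shown above are assumed, and the TokenCountingStreamer name is hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

#include "openvino/genai/streamer_base.hpp"

// Hypothetical streamer that only counts tokens and never stops generation.
// It overrides just put(int64_t) and end(); when assisted or prompt lookup
// decoding pushes a whole vector of tokens, the new default
// put(const std::vector<int64_t>&) forwards each token here one by one.
class TokenCountingStreamer : public ov::genai::StreamerBase {
public:
    bool put(int64_t token) override {
        ++m_count;
        return false;  // false: continue generation
    }

    void end() override {
        std::cout << "generated " << m_count << " tokens" << std::endl;
    }

private:
    size_t m_count = 0;
};
```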
30 changes: 22 additions & 8 deletions src/cpp/include/openvino/genai/whisper_pipeline.hpp
@@ -23,18 +23,23 @@ using RawSpeechInput = std::vector<float>;
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS ChunkStreamerBase : public StreamerBase {
class OPENVINO_GENAI_EXPORTS ChunkStreamerBase {
public:
/// @brief put is called every time new token is decoded,
/// @return bool flag to indicate whether generation should be stopped, if return true generation stops
virtual bool put(int64_t token) = 0;

/// @brief put is called every time new token chunk is generated,
/// @return bool flag to indicate whether generation should be stopped, if return true generation stops
virtual bool put_chunk(std::vector<int64_t> tokens) = 0;
};

// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
using ChunkStreamerVariant =
std::variant<std::function<bool(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;
/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;

virtual ~ChunkStreamerBase() = 0;
};

struct OPENVINO_GENAI_EXPORTS WhisperRawPerfMetrics {
struct WhisperRawPerfMetrics {
/** @brief Duration for each features extraction call */
std::vector<MicroSeconds> features_extraction_durations;
};
@@ -151,7 +156,13 @@ class OPENVINO_GENAI_EXPORTS WhisperPipeline {
*/
WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input,
OptionalWhisperGenerationConfig generation_config = std::nullopt,
ChunkStreamerVariant streamer = std::monostate());
StreamerVariant streamer = std::monostate());

OPENVINO_DEPRECATED("ChunkStreamerBase is deprecated. "
"Use StreamerBase instead. Support will be removed in 2026.0")
WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input,
WhisperGenerationConfig generation_config,
std::shared_ptr<ChunkStreamerBase> streamer);

/**
* @brief High level generate that receives raw speech as a vector of floats and returns decoded output.
@@ -174,6 +185,9 @@ class OPENVINO_GENAI_EXPORTS WhisperPipeline {
void set_generation_config(const WhisperGenerationConfig& config);
};

OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> streamer(ChunkStreamerVariant func);
OPENVINO_DEPRECATED("ChunkStreamerBase is deprecated. "
"Use StreamerBase instead. Support will be removed in 2026.0")
OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> streamer(std::shared_ptr<ChunkStreamerBase> func);

OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> generation_config(const WhisperGenerationConfig& config);
} // namespace ov::genai
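
A hedged usage sketch of the new signature (not part of this diff): the Whisper pipeline now accepts the same StreamerVariant as the text pipelines, while the std::shared_ptr<ChunkStreamerBase> overload remains only as a deprecated path. The model directory and audio contents below are placeholders.

```cpp
#include <iostream>
#include <string>

#include "openvino/genai/whisper_pipeline.hpp"

int main() {
    // Placeholder model path and input: 5 seconds of 16 kHz silence.
    ov::genai::WhisperPipeline pipeline("whisper-base", "CPU");
    ov::genai::RawSpeechInput raw_speech(16000 * 5, 0.0f);

    // A plain bool(std::string) callback is now passed directly as a StreamerVariant.
    auto result = pipeline.generate(
        raw_speech,
        pipeline.get_generation_config(),
        [](std::string subword) {
            std::cout << subword << std::flush;
            return false;  // false: keep generating
        });

    std::cout << "\n" << result.texts[0] << std::endl;
    return 0;
}
```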
67 changes: 9 additions & 58 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -11,6 +11,7 @@
#include "lora_helper.hpp"
#include "cache_state_dumper.hpp"
#include "utils.hpp"
#include "threaded_streamer.hpp"

namespace {

@@ -429,19 +430,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
set_adapters(sampling_params[0].adapters);

const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
[](std::monostate) -> std::shared_ptr<StreamerBase> {
return nullptr;
},
[](const std::shared_ptr<StreamerBase>& streamer) {
return streamer;
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);
const auto streamer_ptr = std::make_shared<ThreadedStreamerWrapper>(streamer, m_tokenizer);

OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
OPENVINO_ASSERT(!streamer_ptr->has_callback() || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
(sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),
"Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding");

@@ -452,49 +443,12 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
auto all_requests = m_awaiting_requests; // we need to store all requests to get results from them once generation has finished

std::atomic<bool> has_active_requests = has_non_finished_requests();
GenerationHandle& generation = generations.at(0);

// create variables to make optimal thread-safe streaming
std::mutex mutex;
std::unique_lock lock(mutex);
std::condition_variable cv;

// to define streaming thread
std::shared_ptr<std::thread> t_stream_ptr = nullptr;
if (streamer_ptr) {
// define stream token lambda to use in `t_stream_ptr`
auto stream_tokens = [this, &generation, &streamer_ptr, &has_active_requests, &cv, &lock]() {
while (has_active_requests || generation->can_read()) {
// waiting for any tokens or request finishing
cv.wait(lock, [&generation, &has_active_requests]{
return generation->can_read() || !has_active_requests;
});

if (generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
OPENVINO_ASSERT(generation_outputs.size() <= 1);
if (!generation_outputs.empty()) {
for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) {
if (streamer_ptr->put(generated_token_id)) {
generation->drop();
break;
}
}
}
}
};
streamer_ptr->end();
};

// to define streaming thread
t_stream_ptr = std::make_shared<std::thread>([&stream_tokens] {
stream_tokens();
});
}
streamer_ptr->start();

std::exception_ptr thrown_exception = nullptr;
while (has_active_requests) {
while (has_non_finished_requests()) {
try {
const auto infer_start = std::chrono::steady_clock::now();
step();
@@ -510,17 +464,14 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
drop_requests(); // remove all requests from pipeline state in case of exception
thrown_exception = std::current_exception();
}
has_active_requests = has_non_finished_requests();
cv.notify_one();
stream_tokens(streamer_ptr, generation);
if (thrown_exception) {
throw thrown_exception;
streamer_ptr->end();
std::rethrow_exception(thrown_exception);
}
}

// waiting for completion of streaming
if (t_stream_ptr && t_stream_ptr->joinable()) {
t_stream_ptr->join();
}
streamer_ptr->end();

OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");

24 changes: 24 additions & 0 deletions src/cpp/src/icontinuous_batching.cpp
@@ -112,4 +112,28 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(

return decoded;
}

void ContinuousBatchingPipeline::IContinuousBatchingPipeline::stream_tokens(
const std::shared_ptr<ThreadedStreamerWrapper>& streamer_ptr,
const GenerationHandle& handle
) {
if (!streamer_ptr->has_callback() || !handle->can_read()) {
return;
}

if (streamer_ptr->is_dropped()) {
handle->drop();
return;
}

std::unordered_map<uint64_t, GenerationOutput> generation_outputs = handle->read();
OPENVINO_ASSERT(generation_outputs.size() <= 1);
if (generation_outputs.empty()) {
return;
}

const auto tokens = generation_outputs.begin()->second.generated_ids;
streamer_ptr->put(tokens);
}

}
3 changes: 3 additions & 0 deletions src/cpp/src/icontinuous_batching.hpp
@@ -9,6 +9,7 @@
#include "sampler.hpp"
#include "model_runner.hpp"
#include "scheduler.hpp"
#include "threaded_streamer.hpp"

namespace ov::genai {

@@ -46,6 +47,8 @@ class ContinuousBatchingPipeline::IContinuousBatchingPipeline {
// to access m_load_time_ms
friend class ContinuousBatchingPipeline;

void stream_tokens(const std::shared_ptr<ThreadedStreamerWrapper>& streamer_ptr, const GenerationHandle& handle);

public:
GenerationConfig get_config() const;
PipelineMetrics get_metrics() const;
6 changes: 6 additions & 0 deletions src/cpp/src/perf_metrics.cpp
@@ -138,13 +138,19 @@ PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const {

// Concatenate durations, batch sizes and first token times.
auto& new_durations = res.raw_metrics.m_durations;
auto& new_inference_durations = res.raw_metrics.m_inference_durations;
auto& new_token_infer_durations = res.raw_metrics.m_token_infer_durations;
auto& new_batch_sizes = res.raw_metrics.m_batch_sizes;
auto& new_times_to_first_token = res.raw_metrics.m_times_to_first_token;
auto& right_inference_durations = right.raw_metrics.m_inference_durations;
auto& right_token_infer_durations = right.raw_metrics.m_token_infer_durations;
auto& right_durations = right.raw_metrics.m_durations;
auto& right_batch_sizes = right.raw_metrics.m_batch_sizes;
auto& right_times_to_first_token = right.raw_metrics.m_times_to_first_token;

new_durations.insert(new_durations.end(), right_durations.begin(), right_durations.end());
new_inference_durations.insert(new_inference_durations.end(), right_inference_durations.begin(), right_inference_durations.end());
new_token_infer_durations.insert(new_token_infer_durations.end(), right_token_infer_durations.begin(), right_token_infer_durations.end());
new_times_to_first_token.insert(new_times_to_first_token.end(), right_times_to_first_token.begin(), right_times_to_first_token.end());
new_batch_sizes.insert(new_batch_sizes.end(), right_batch_sizes.begin(), right_batch_sizes.end());
