Whisper pipeline: use parallel streamer #1642

Draft · wants to merge 21 commits into base: master
Changes from 3 commits
6 changes: 6 additions & 0 deletions src/cpp/src/perf_metrics.cpp
@@ -138,13 +138,19 @@ PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const {

// Concatenate durations, batch_sizes and first token times.
auto& new_durations = res.raw_metrics.m_durations;
auto& new_inference_durations = res.raw_metrics.m_inference_durations;
auto& new_token_infer_durations = res.raw_metrics.m_token_infer_durations;
auto& new_batch_sizes = res.raw_metrics.m_batch_sizes;
auto& new_times_to_first_token = res.raw_metrics.m_times_to_first_token;
auto& right_inference_durations = right.raw_metrics.m_inference_durations;
auto& right_token_infer_durations = right.raw_metrics.m_token_infer_durations;
auto& right_durations = right.raw_metrics.m_durations;
auto& right_batch_sizes = right.raw_metrics.m_batch_sizes;
auto& right_times_to_first_token = right.raw_metrics.m_times_to_first_token;

new_durations.insert(new_durations.end(), right_durations.begin(), right_durations.end());
new_inference_durations.insert(new_inference_durations.end(), right_inference_durations.begin(), right_inference_durations.end());
new_token_infer_durations.insert(new_token_infer_durations.end(), right_token_infer_durations.begin(), right_token_infer_durations.end());
new_times_to_first_token.insert(new_times_to_first_token.end(), right_times_to_first_token.begin(), right_times_to_first_token.end());
new_batch_sizes.insert(new_batch_sizes.end(), right_batch_sizes.begin(), right_batch_sizes.end());

120 changes: 120 additions & 0 deletions src/cpp/src/whisper/streamer.hpp
@@ -3,6 +3,10 @@

#pragma once

#include <condition_variable>
#include <queue>
#include <thread>
Collaborator: These headers seem to be redundant

Contributor (author): removed


#include "openvino/genai/tokenizer.hpp"
#include "openvino/genai/whisper_pipeline.hpp"
#include "text_callback_streamer.hpp"
@@ -20,5 +24,121 @@ class ChunkTextCallbackStreamer : private TextCallbackStreamer, public ChunkStre
: TextCallbackStreamer(tokenizer, callback){};
};

class WhisperStreamer {
public:
WhisperStreamer(ChunkStreamerVariant& streamer, Tokenizer& tokenizer) {
if (auto streamer_obj = std::get_if<std::monostate>(&streamer)) {
m_streamer_ptr = nullptr;
} else if (auto streamer_obj = std::get_if<std::shared_ptr<ChunkStreamerBase>>(&streamer)) {
m_streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
m_streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(tokenizer, *callback);
}
};

void start() {
if (!m_streamer_ptr) {
return;
}

m_worker_thread = std::make_shared<std::thread>(&WhisperStreamer::_worker, this);
}

void put_chunk(const std::vector<int64_t>& tokens) {
if (!m_streamer_ptr) {
return;
}

std::lock_guard<std::mutex> lock(m_mutex);
m_queue.push(tokens);
m_cv.notify_one();
}

void put(const int64_t token) {
if (!m_streamer_ptr) {
return;
}

std::lock_guard<std::mutex> lock(m_mutex);
m_queue.push(token);
m_cv.notify_one();
}

void end() {
if (!m_streamer_ptr) {
return;
}

{
std::lock_guard<std::mutex> lock(m_mutex);
m_stopped = true;
}

m_cv.notify_one();

if (m_worker_thread && m_worker_thread->joinable()) {
m_worker_thread->join();
}

m_streamer_ptr->end();
}

bool is_dropped() {
std::lock_guard<std::mutex> lock(m_mutex);
return m_dropped;
}

private:
std::shared_ptr<ChunkStreamerBase> m_streamer_ptr = nullptr;
std::shared_ptr<std::thread> m_worker_thread = nullptr;
std::mutex m_mutex;
std::condition_variable m_cv;
std::queue<std::variant<int64_t, std::vector<int64_t>>> m_queue;

bool m_stopped = false;
bool m_dropped = false;

void _worker() {
while (true) {
std::variant<int64_t, std::vector<int64_t>> token_variant;
{
std::unique_lock<std::mutex> lock(m_mutex);

// wait for the next token in the queue or until the streamer is stopped
m_cv.wait(lock, [this] {
return m_stopped || !m_queue.empty();
});

// even after stop is requested, keep streaming until the queue is drained
if (m_stopped && m_queue.empty()) {
break;
}

token_variant = m_queue.front();
m_queue.pop();
}

// call the user streamer outside the lock and record whether it requested a drop
bool is_dropped = false;

if (auto token = std::get_if<int64_t>(&token_variant)) {
is_dropped = m_streamer_ptr->put(*token);
} else {
auto tokens = std::get_if<std::vector<int64_t>>(&token_variant);
is_dropped = m_streamer_ptr->put_chunk(*tokens);
}

{
std::lock_guard<std::mutex> lock(m_mutex);
m_dropped = is_dropped;

if (m_dropped) {
break;
}
}
}
}
};

} // namespace genai
} // namespace ov
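For orientation, here is a minimal usage sketch of the new internal `WhisperStreamer`. It is illustrative only: the `drive_streamer` helper and the token ids are made up, while the member functions (`start`, `put`, `put_chunk`, `is_dropped`, `end`) and the `ChunkStreamerVariant` wrapping come from the class above and from `whisper_pipeline.hpp`.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Assumes the internal "streamer.hpp" above and the public genai headers are on the include path.
void drive_streamer(ov::genai::Tokenizer& tokenizer) {
    // Wrap a plain text callback into the variant the constructor expects.
    ov::genai::ChunkStreamerVariant variant = std::function<bool(std::string)>(
        [](std::string subword) {
            std::cout << subword << std::flush;
            return false;  // returning true asks the pipeline to stop generating
        });

    ov::genai::WhisperStreamer streamer{variant, tokenizer};
    streamer.start();                        // spawn the worker thread
    streamer.put(50257);                     // enqueue a single token (id is illustrative)
    streamer.put_chunk({50258, 50259});      // enqueue a whole chunk at once
    bool cancelled = streamer.is_dropped();  // non-blocking poll from the producer side
    streamer.end();                          // drain the queue, join the worker, call the streamer's end()
    (void)cancelled;
}
```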
32 changes: 18 additions & 14 deletions src/cpp/src/whisper/whisper.cpp
@@ -50,26 +50,33 @@ void process_whisper_logits(ov::Tensor logits,
std::pair<ov::genai::EncodedResults, bool> decode(std::shared_ptr<ov::genai::WhisperDecoder> decoder,
const std::vector<int64_t>& input_ids,
const ov::Tensor& encoder_hidden_state,
const std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
const std::shared_ptr<ov::genai::WhisperStreamer> streamer,
ov::genai::Sampler& sampler,
ov::genai::SequenceGroup::Ptr sequence_group,
const bool return_timestamps,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::RawPerfMetrics& raw_metrics) {
const auto handle = std::make_shared<ov::genai::GenerationHandleImpl>(sequence_group->get_generation_stream(),
sequence_group->get_sampling_parameters());
auto on_generated_tokens = [&streamer, &handle, &return_timestamps]() {
// handle the return_timestamps case, where streamer->put_chunk is called once per chunk
if (streamer->is_dropped()) {
handle->drop();
return;
}

auto stream_generated_tokens = [&streamer_ptr, &handle, &return_timestamps]() {
if (return_timestamps || !streamer_ptr || !handle->can_read()) {
if (return_timestamps || !handle->can_read()) {
return;
}

std::unordered_map<uint64_t, ov::genai::GenerationOutput> token = handle->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
if (streamer->is_dropped()) {
handle->drop();
break;
}

streamer->put(gen_token);
}
};

@@ -96,7 +103,7 @@ std::pair<ov::genai::EncodedResults, bool> decode(std::shared_ptr<ov::genai::Whi
sequence_group->set_output_seq_len(output_sequence_len);

sampler.sample({sequence_group}, logits);
stream_generated_tokens();
on_generated_tokens();

// "Generation" phase
while (!sequence_group->has_finished() && !sequence_group->handle_dropped()) {
@@ -151,7 +158,7 @@ std::pair<ov::genai::EncodedResults, bool> decode(std::shared_ptr<ov::genai::Whi
process_whisper_logits(logits, config, return_timestamps, batch_to_generated_ids);

sampler.sample({sequence_group}, logits);
stream_generated_tokens();
on_generated_tokens();
}

ov::genai::EncodedResults results;
@@ -255,7 +262,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
ov::InferRequest& encoder,
std::shared_ptr<WhisperDecoder> decoder,
WhisperFeatureExtractor& feature_extractor,
const std::shared_ptr<ChunkStreamerBase> streamer,
const std::shared_ptr<WhisperStreamer> streamer,
Sampler& sampler) {
size_t max_new_tokens = config.get_max_new_tokens();

@@ -276,6 +283,8 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
// long-form audio processing requires timestamps to be enabled
const bool return_timestamps = config.return_timestamps || !is_shortform;

streamer->start();

std::vector<int64_t> init_tokens;
std::vector<int64_t>& output_tokens = result.output_tokens;
std::vector<Segment> segments;
@@ -329,10 +338,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
extracted_segments.non_timestamp_tokens.begin(),
extracted_segments.non_timestamp_tokens.end());

if (streamer && streamer->put_chunk(extracted_segments.non_timestamp_tokens)) {
cancelled = true;
break;
}
streamer->put_chunk(extracted_segments.non_timestamp_tokens);

segment_offset = extracted_segments.last_offset;
} else {
@@ -348,9 +354,7 @@ }
}
}

if (streamer) {
streamer->end();
}
streamer->end();

// if return_timestamps wasn't enabled by user
if (!config.return_timestamps) {
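To make the control-flow change above easier to follow, here is a self-contained toy sketch of the new cancellation flow. `FakeStreamer`, `FakeHandle` and `generation_loop` are stand-ins, not real pipeline types: the point is that the generation loop never blocks on the streamer — it enqueues tokens and only polls `is_dropped()`, dropping the generation handle once the callback has asked to stop.

```cpp
#include <cstdint>
#include <vector>

struct FakeHandle {
    bool dropped = false;
    void drop() { dropped = true; }          // stop producing new tokens
};

struct FakeStreamer {
    bool drop_requested = false;
    void put(int64_t) {}                     // enqueue for the worker thread; never blocks
    bool is_dropped() const { return drop_requested; }
};

void generation_loop(FakeStreamer& streamer, FakeHandle& handle,
                     const std::vector<int64_t>& tokens) {
    for (int64_t token : tokens) {
        if (streamer.is_dropped()) {         // poll: did the callback request cancellation?
            handle.drop();
            break;
        }
        streamer.put(token);                 // hand the token off without waiting for the callback
    }
}
```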
3 changes: 2 additions & 1 deletion src/cpp/src/whisper/whisper.hpp
@@ -10,6 +10,7 @@
#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/genai/whisper_pipeline.hpp"
#include "sampler.hpp"
#include "streamer.hpp"
#include "whisper_config.hpp"
#include "whisper_feature_extractor.hpp"
#include "whisper_models.hpp"
@@ -36,7 +37,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
ov::InferRequest& encoder,
std::shared_ptr<WhisperDecoder> decoder,
WhisperFeatureExtractor& feature_extractor,
const std::shared_ptr<ChunkStreamerBase> streamer,
const std::shared_ptr<WhisperStreamer> streamer,
Sampler& sampler);

} // namespace genai
10 changes: 2 additions & 8 deletions src/cpp/src/whisper_pipeline.cpp
@@ -81,14 +81,7 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
config.set_eos_token_id(m_generation_config.eos_token_id);
config.validate();

std::shared_ptr<ChunkStreamerBase> streamer_ptr;
if (auto streamer_obj = std::get_if<std::monostate>(&streamer)) {
streamer_ptr = nullptr;
} else if (auto streamer_obj = std::get_if<std::shared_ptr<ChunkStreamerBase>>(&streamer)) {
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(m_tokenizer, *callback);
}
const std::shared_ptr<WhisperStreamer> streamer_ptr = std::make_shared<WhisperStreamer>(streamer, m_tokenizer);

auto [context_tokens, tokenization_duration_microseconds] = prepare_context_tokens(config, m_tokenizer);

@@ -101,6 +94,7 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
m_feature_extractor,
streamer_ptr,
m_sampler);

auto decode_start_time = std::chrono::steady_clock::now();
WhisperDecodedResults result{std::vector{m_tokenizer.decode(generate_result.output_tokens)}, std::vector{1.f}};
generate_result.perf_metrics.raw_metrics.detokenization_durations.emplace_back(
9 changes: 8 additions & 1 deletion src/python/py_whisper_pipeline.cpp
@@ -252,6 +252,7 @@ ChunkStreamerVariant pystreamer_to_chunk_streamer(const PyBindChunkStreamerVaria
// on pybind automatic decoding since it raises exceptions on incomplete
// strings.
return static_cast<ChunkStreamerVariant>([py_callback](std::string subword) -> bool {
py::gil_scoped_acquire acquire;
auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
return py_callback(py::reinterpret_borrow<py::str>(py_str));
});
@@ -279,7 +280,13 @@ py::object call_whisper_common_generate(WhisperPipeline& pipe,

ChunkStreamerVariant streamer = pystreamer_to_chunk_streamer(py_streamer);

return py::cast(pipe.generate(raw_speech_input, updated_config, streamer));
ov::genai::WhisperDecodedResults results;
{
py::gil_scoped_release rel;
results = pipe.generate(raw_speech_input, updated_config, streamer);
}

return py::cast(results);
}

} // namespace
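The two edits above follow the usual pybind11 GIL discipline: the binding releases the GIL around the long-running C++ generate() call, and the text callback re-acquires it because it may now run on the streamer's worker thread rather than on the Python caller's thread. A standalone sketch of that pattern, with hypothetical helpers `call_with_released_gil` and `run_on_worker` (not part of the bindings):

```cpp
#include <pybind11/pybind11.h>
#include <functional>
#include <string>
#include <thread>

namespace py = pybind11;

// Stand-in for the C++ pipeline: invokes the callback from a worker thread.
void run_on_worker(const std::function<bool(std::string)>& callback) {
    std::thread worker([&callback] { callback("partial text"); });
    worker.join();
}

bool call_with_released_gil(py::function py_callback) {
    // Every invocation of the Python callable grabs the GIL first,
    // so it is safe to call it from any thread.
    std::function<bool(std::string)> callback = [py_callback](std::string subword) -> bool {
        py::gil_scoped_acquire acquire;
        return py_callback(subword).cast<bool>();
    };

    // Release the GIL while the C++ side (and its worker thread) runs.
    py::gil_scoped_release release;
    run_on_worker(callback);
    return true;
}
```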