Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper pipeline: support stateful decoder #1474

Merged
merged 16 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/cpp/src/whisper/models/decoder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "decoder.hpp"

#include <filesystem>

#include "statefull_decoder.hpp"
#include "utils.hpp"
#include "with_past_decoder.hpp"

namespace ov::genai {
std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml");

if (has_decoder_with_past) {
return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties);
}

return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
}

WhisperDecoder::~WhisperDecoder() = default;
} // namespace ov::genai
29 changes: 29 additions & 0 deletions src/cpp/src/whisper/models/decoder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <filesystem>

#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {
class WhisperDecoder {
public:
static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) = 0;

virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) = 0;

virtual void reset_state() = 0;

virtual ~WhisperDecoder();
};
} // namespace ov::genai
60 changes: 60 additions & 0 deletions src/cpp/src/whisper/models/statefull_decoder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "statefull_decoder.hpp"

#include "utils.hpp"

namespace ov::genai {
WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
ov::Core core = utils::singleton_core();

auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);

utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

reset_state();

return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) {
m_request.set_tensor("encoder_hidden_states", encoder_hidden_state);

ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
m_request.set_tensor("input_ids", input_ids_tensor);

ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position");
cache_position_tensor.set_shape({input_ids.size()});

auto cache_data = cache_position_tensor.data<int64_t>();
std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position);

m_request.get_tensor("beam_idx").set_shape({1});
m_request.get_tensor("beam_idx").data<int32_t>()[0] = 0;

const auto infer_start = std::chrono::steady_clock::now();
m_request.infer();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

auto output_tensor = m_request.get_tensor("logits");

return {output_tensor, infer_ms};
};

void WhisperStatefullDecoder::reset_state() {
m_request.reset_state();
}
} // namespace ov::genai
29 changes: 29 additions & 0 deletions src/cpp/src/whisper/models/statefull_decoder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperStatefullDecoder : public WhisperDecoder {
public:
WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) override;

std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) override;

void reset_state() override;

private:
ov::InferRequest m_request;
};
} // namespace ov::genai
102 changes: 102 additions & 0 deletions src/cpp/src/whisper/models/with_past_decoder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "with_past_decoder.hpp"

#include <regex>

#include "utils.hpp"

namespace {
void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
// source outputs:
// present.0.decoder.key
// present.0.decoder.value
// present.0.encoder.key
// present.0.encoder.value

// dest inputs:
// past_key_values.0.decoder.key
// past_key_values.0.decoder.value
// past_key_values.0.encoder.key
// past_key_values.0.encoder.value

for (auto& source_output : source.get_compiled_model().outputs()) {
std::string source_output_name = source_output.get_any_name();
if (source_output_name.find("logits") != std::string::npos) {
continue;
}

std::string with_past_input_name =
std::regex_replace(source_output_name, std::regex("present"), "past_key_values");

auto kv_tensor = source.get_tensor(source_output_name);
dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor});
}
}
} // namespace

namespace ov::genai {
WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& models_path,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ilya-lavrenov @Wovchena I want to add deprecation note for this ctor. I saw OPENVINO_DEPRECATED macros but I think it doesn't fit here as I want to warn user at runtime.
Can I add just std::cout << "[Warning] Whisper decoder with past deprecated ..." ? Does OV have logging utilities we can reuse? Or there is a better way?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't have a better way than a message in runtime

const std::string& device,
const ov::AnyMap& properties) {
ov::Core core = utils::singleton_core();

auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request_decoder = compiled_model.create_infer_request();

compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
m_request_decoder_with_past = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

reset_state();

return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) {
const bool initial_step = cache_position == 0;
ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past;

request.set_tensor("encoder_hidden_states", encoder_hidden_state);

const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
request.set_tensor("input_ids", input_ids_tensor);

if (!initial_step) {
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
cache_position_tensor.set_shape({1});
cache_position_tensor.data<int64_t>()[0] = cache_position;
}

const auto infer_start = std::chrono::steady_clock::now();
request.infer();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

auto output_tensor = request.get_tensor("logits");

if (initial_step) {
set_past_key_value(m_request_decoder, m_request_decoder_with_past);
} else if (!m_decoder_with_past_kv_value_set) {
set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past);
m_decoder_with_past_kv_value_set = true;
}

return {output_tensor, infer_ms};
}

void WhisperWithPastDecoder::reset_state() {
m_request_decoder_with_past.reset_state();
m_decoder_with_past_kv_value_set = false;
}
} // namespace ov::genai
32 changes: 32 additions & 0 deletions src/cpp/src/whisper/models/with_past_decoder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperWithPastDecoder : public WhisperDecoder {
public:
WhisperWithPastDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) override;

std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) override;

void reset_state() override;

private:
ov::InferRequest m_request_decoder;
ov::InferRequest m_request_decoder_with_past;
bool m_decoder_with_past_kv_value_set = false;
};

} // namespace ov::genai
Loading
Loading