Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Jan 17, 2025
1 parent 09f3020 commit c9f0df0
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace ov {
namespace genai {

// Return flag indicating whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<UintCallbackStreamerResult(std::string)>, std::function<GenerationStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<GenerationStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
Expand Down
6 changes: 1 addition & 5 deletions src/cpp/include/openvino/genai/streamer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@

namespace ov {
namespace genai {
// uint16_t for Python API here
struct UintCallbackStreamerResult {
uint16_t result;
};
using CallbackTypeVariant = std::variant<bool, UintCallbackStreamerResult, ov::genai::GenerationStatus, std::monostate>;
using CallbackTypeVariant = std::variant<bool, ov::genai::GenerationStatus, std::monostate>;

/**
 * @brief base class for streamers. In order to use it, inherit from this class and implement the put method
Expand Down
2 changes: 0 additions & 2 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ std::pair<ov::AnyMap, ov::genai::static_llm::ModelConfigDesc> split_model_descr(
std::pair<std::string, Any> streamer(StreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
} else if (auto streamer_obj = std::get_if<std::function<UintCallbackStreamerResult(std::string)>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::function<UintCallbackStreamerResult(std::string)>>(*streamer_obj)};
} else if (auto streamer_obj = std::get_if<std::function<GenerationStatus(std::string)>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::function<GenerationStatus(std::string)>>(*streamer_obj)};
} else {
Expand Down
10 changes: 0 additions & 10 deletions src/cpp/src/text_callback_streamer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,6 @@ bool TextCallbackStreamer::is_generation_complete(CallbackTypeVariant callback_s
if (auto status = std::get_if<GenerationStatus>(&callback_status)) {
m_streaming_finish_status = *status;
is_complete = (m_streaming_finish_status == GenerationStatus::STOP || m_streaming_finish_status == GenerationStatus::CANCEL);
} else if (auto status = std::get_if<UintCallbackStreamerResult>(&callback_status)) {
auto result = status->result;
is_complete = result > 0;
if (result == (uint16_t)GenerationStatus::RUNNING) {
m_streaming_finish_status = GenerationStatus::RUNNING;
} else if (result == (uint16_t)GenerationStatus::CANCEL) {
m_streaming_finish_status = GenerationStatus::CANCEL;
} else {
m_streaming_finish_status = GenerationStatus::STOP;
}
} else if (auto status = std::get_if<bool>(&callback_status)) {
is_complete = *status;
m_streaming_finish_status = *status ? GenerationStatus::STOP : GenerationStatus::RUNNING;
Expand Down
5 changes: 0 additions & 5 deletions src/cpp/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ ov::genai::StreamerVariant get_streamer_from_map(const ov::AnyMap& config_map) {
streamer = any_val.as<std::function<bool(std::string)>>();
} else if (any_val.is<std::function<GenerationStatus(std::string)>>()) {
streamer = any_val.as<std::function<GenerationStatus(std::string)>>();
} else if (any_val.is<std::function<UintCallbackStreamerResult(std::string)>>()) {
streamer = any_val.as<std::function<UintCallbackStreamerResult(std::string)>>();
}
}
return streamer;
Expand All @@ -197,9 +195,6 @@ std::shared_ptr<StreamerBase> create_streamer(StreamerVariant streamer, Tokenize
[&tokenizer = tokenizer](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(tokenizer, streamer);
},
[&tokenizer = tokenizer](const std::function<UintCallbackStreamerResult(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(tokenizer, streamer);
},
[&tokenizer = tokenizer](const std::function<ov::genai::GenerationStatus(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(tokenizer, streamer);
}
Expand Down
16 changes: 12 additions & 4 deletions src/python/py_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,20 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p
[&streamer](const std::function<std::optional<uint16_t>(py::str)>& py_callback){
// Wrap python streamer with manual utf-8 decoding. Do not rely
// on pybind automatic decoding since it raises exceptions on incomplete strings.
auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::UintCallbackStreamerResult {
auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::GenerationStatus {
auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
std::optional<uint16_t> callback_output = py_callback(py::reinterpret_borrow<py::str>(py_str));
ov::genai::UintCallbackStreamerResult result = {0};
if (callback_output.has_value())
result.result = *callback_output;
auto result = GenerationStatus::RUNNING;
if (callback_output.has_value()) {
if (*callback_output == (uint16_t)GenerationStatus::RUNNING) {
result = GenerationStatus::RUNNING;
} else if (*callback_output == (uint16_t)GenerationStatus::CANCEL) {
result = GenerationStatus::CANCEL;
} else {
result = GenerationStatus::STOP;
}
}

return result;
};
streamer = callback_wrapped;
Expand Down
2 changes: 1 addition & 1 deletion tests/python_tests/test_llm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_chat_scenario_callback_cancel(model_descr):
def callback(subword):
nonlocal current_iter
current_iter += 1
return ov_genai.GenerationStatus.CENCEL if current_iter == num_iters else ov_genai.GenerationStatus.RUNNING
return ov_genai.GenerationStatus.CANCEL if current_iter == num_iters else ov_genai.GenerationStatus.RUNNING

ov_pipe.start_chat()
for prompt in callback_questions:
Expand Down

0 comments on commit c9f0df0

Please sign in to comment.