[Tokenizers] add max_length parametrisation to encode (#1518)
Works in conjunction with the tokenizers changes in openvinotoolkit/openvino_tokenizers#362
Ticket: CVS-157356, CVS-159924
pavel-esir authored Feb 7, 2025
1 parent 1bdd4f9 commit be40053
Showing 12 changed files with 374 additions and 105 deletions.
97 changes: 97 additions & 0 deletions src/README.md
@@ -398,6 +398,103 @@ print(f'Median from token to token duration: {np.median(durations):.2f} ms')

For more examples of how metrics are used, please refer to the Python [benchmark_genai.py](../samples/python/text_generation/README.md) and C++ [benchmark_genai](../samples/cpp/text_generation/README.md) samples.

### Tokenization

OpenVINO™ GenAI provides a way to tokenize and detokenize text using the `ov::genai::Tokenizer` class. The `Tokenizer` is a high-level abstraction over the OpenVINO Tokenizers library.

It can be initialized from a path or an in-memory IR representation, or obtained from an `ov::genai::LLMPipeline` object.

```cpp
#include "openvino/genai/llm_pipeline.hpp"

// Initialize from a path.
auto tokenizer = ov::genai::Tokenizer(models_path);

// Or get the Tokenizer instance from an LLMPipeline.
auto pipe = ov::genai::LLMPipeline(models_path, "CPU");
auto pipe_tokenizer = pipe.get_tokenizer();
```
```python
import openvino_genai as ov_genai
tokenizer = ov_genai.Tokenizer(models_path)
# Or from LLMPipeline.
pipe = ov_genai.LLMPipeline(models_path, "CPU")
tokenizer = pipe.get_tokenizer()
```
`Tokenizer` has `encode` and `decode` methods which support the following arguments: `add_special_tokens`, `skip_special_tokens`, `pad_to_max_length`, and `max_length`.
To disable adding special tokens, do the following. In C++:
```cpp
auto tokens = tokenizer.encode("The Sun is yellow because", ov::genai::add_special_tokens(false));
```
In Python:
```python
tokens = tokenizer.encode("The Sun is yellow because", add_special_tokens=False)
```
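The `skip_special_tokens` flag works the same way for `decode`. A minimal C++ sketch, assuming a `decode` overload that accepts the token tensor together with detokenization properties (the `skip_special_tokens` property is declared in `tokenizer.hpp` below):
```cpp
// Sketch: keep special tokens in the detokenized text (assumed overload).
auto texts = tokenizer.decode(tokens.input_ids, ov::genai::skip_special_tokens(false));
```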
The `encode` method returns a `TokenizedInputs` object containing `input_ids` and `attention_mask`, both stored as `ov::Tensor`. Since `ov::Tensor` requires fixed-length sequences, padding is applied to match the longest sequence in a batch, ensuring a uniform shape. The resulting sequences are also truncated to `max_length`; if this value is not defined by the user, it is taken from the IR.
Both padding and `max_length` can be controlled by the user: if `pad_to_max_length` is set to true, sequences are padded to `max_length` instead of to the longest sequence in the batch.
Below are examples of how padding can be controlled. In C++:
```cpp
#include "openvino/genai/llm_pipeline.hpp"
auto tokenizer = ov::genai::Tokenizer(models_path);
std::vector<std::string> prompts = {"The Sun is yellow because", "The"};
// Since the prompts are definitely shorter than max_length (which is taken from the IR),
// it does not affect the shape; the resulting shape is defined by the longest token sequence.
// Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="longest", truncation=True)
auto tokens = tokenizer.encode(prompts);
// which is equivalent to
tokens = tokenizer.encode(prompts, ov::genai::pad_to_max_length(false));
// out_shape: [2, 6]

// The resulting tokens tensor will be padded to 1024; sequences exceeding this length are truncated.
// Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="max_length", truncation=True, max_length=1024)
tokens = tokenizer.encode({"The Sun is yellow because",
                           "The",
                           std::string(2000, 'n')}, ov::genai::pad_to_max_length(true), ov::genai::max_length(1024));
// out_shape: [3, 1024]

// For single-string prompts, truncation and padding are also applied.
tokens = tokenizer.encode("The Sun is yellow because", ov::genai::pad_to_max_length(true), ov::genai::max_length(128));
// out_shape: [1, 128]
```
In Python:
```python
import openvino_genai as ov_genai
tokenizer = ov_genai.Tokenizer(models_path)
prompts = ["The Sun is yellow because", "The"]
# Since the prompts are definitely shorter than max_length (which is taken from the IR),
# it does not affect the shape; the resulting shape is defined by the longest token sequence.
# Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="longest", truncation=True)
tokens = tokenizer.encode(prompts)
# which is equivalent to
tokens = tokenizer.encode(prompts, pad_to_max_length=False)
print(tokens.input_ids.shape)
# out_shape: [2, 6]
# The resulting tokens tensor will be padded to 1024; sequences exceeding this length are truncated.
# Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="max_length", truncation=True, max_length=1024)
tokens = tokenizer.encode(["The Sun is yellow because",
                           "The",
                           "The longest string ever" * 2000], pad_to_max_length=True, max_length=1024)
print(tokens.input_ids.shape)
# out_shape: [3, 1024]
# For single-string prompts, truncation and padding are also applied.
tokens = tokenizer.encode("The Sun is yellow because", pad_to_max_length=True, max_length=128)
print(tokens.input_ids.shape)
# out_shape: [1, 128]
```
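Because a padded batch has a single fixed shape, downstream code relies on `attention_mask` to tell real tokens from padding. A short C++ sketch of inspecting both tensors; the convention that padded positions hold 0 in the mask is an assumption here, not stated above:
```cpp
// Sketch: both tensors share the same [batch, max_length] shape after padding.
auto tokens = tokenizer.encode(prompts, ov::genai::pad_to_max_length(true), ov::genai::max_length(16));
std::cout << tokens.input_ids.get_shape() << " " << tokens.attention_mask.get_shape() << std::endl;
```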
## How It Works
For information on how OpenVINO™ GenAI works, refer to the [How It Works Section](./docs/HOW_IT_WORKS.md).
Expand Down
13 changes: 9 additions & 4 deletions src/cpp/include/openvino/genai/tokenizer.hpp
@@ -107,15 +107,15 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
/**
* @brief encode a single prompt
* @param prompt std::string with input prompt
* @param tokenization_params AnyMap with tokenization parameters, e.g. {"add_special_tokens", false}
* @param tokenization_params AnyMap with tokenization parameters, e.g. {{"add_special_tokens", false}, {"max_length", 128}}
* @return pair of [input_ids, attention_mask]
*/
TokenizedInputs encode(const std::string prompt, const ov::AnyMap& tokenization_params = {});

/**
* @brief encode batch of prompts. Left padding will be applied by default
* @param prompts vector storing batch of prompts
* @param tokenization_params AnyMap with tokenization parameters, e.g. {"add_special_tokens", false}
* @param tokenization_params AnyMap with tokenization parameters, e.g. {{"add_special_tokens", false}, {"max_length", 128}}
* @return pair of [input_ids, attention_mask]
*/
TokenizedInputs encode(std::vector<std::string>& prompt, const ov::AnyMap& tokenization_params = {});
@@ -125,7 +125,9 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
/**
* @brief encode a single prompt
* @param prompt std::string with input prompt
* @param properties tokenization properties, e.g. ov::genai::add_special_tokens(false)
* @param add_special_tokens whether to add special tokens
* @param max_length optional maximum length to which the output will be truncated and/or padded. If not defined, it is taken from the IR.
* @param pad_to_max_length if true, pad to max_length; otherwise pad to the longest sequence in the batch. Default is false.
* @return pair of [input_ids, attention_mask]
*/
template <typename... Properties>
@@ -136,7 +138,9 @@
/**
* @brief encode batch of prompts. Left padding will be applied by default
* @param prompts vector storing batch of prompts
* @param properties tokenization properties, e.g. ov::genai::add_special_tokens(false)
* @param add_special_tokens whether to add special tokens
* @param max_length optional maximum length to which the output will be truncated and/or padded. If not defined, it is taken from the IR.
* @param pad_to_max_length if true, pad to max_length; otherwise pad to the longest sequence in the batch. Default is false.
* @return pair of [input_ids, attention_mask]
*/
template <typename... Properties>
@@ -243,6 +247,7 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {

static constexpr ov::Property<bool> add_special_tokens{"add_special_tokens"};
static constexpr ov::Property<bool> skip_special_tokens{"skip_special_tokens"};
static constexpr ov::Property<bool> pad_to_max_length{"pad_to_max_length"};

} // namespace genai
} // namespace ov
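
For reference, the two `encode` flavors documented above can be called as follows; a minimal sketch, assuming `ov::genai::max_length` is defined as a property alongside the ones declared at the end of this header:
```cpp
#include "openvino/genai/tokenizer.hpp"

// AnyMap form: keys mirror the property names.
auto t1 = tokenizer.encode("The Sun is yellow because",
                           ov::AnyMap{{"add_special_tokens", false}, {"max_length", 128}});

// Property form: equivalent, but type-checked at compile time.
auto t2 = tokenizer.encode("The Sun is yellow because",
                           ov::genai::add_special_tokens(false),
                           ov::genai::max_length(128),
                           ov::genai::pad_to_max_length(true));
```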
122 changes: 121 additions & 1 deletion src/cpp/src/make_tokenizer_stateful.cpp
@@ -4,6 +4,10 @@
#include "make_tokenizer_stateful.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/minimum.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/read_value.hpp"
@@ -13,7 +17,7 @@
using namespace ov;
using namespace ov::op;

bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
bool ov::genai::MakeAddSpecialTokensSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
std::shared_ptr<ov::Node> combine_seg_node;
for (auto node: model->get_ordered_ops()) {
if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {
@@ -56,6 +60,7 @@ bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr
return true;
}


bool ov::genai::MakeVocabDecoderSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
std::shared_ptr<ov::Node> vocab_decoder_node;
for (auto node: model->get_ordered_ops()) {
@@ -97,3 +102,118 @@ bool ov::genai::MakeVocabDecoderSatateful::run_on_model(const std::shared_ptr<ov
model->add_variables({variable});
return true;
}


bool ov::genai::MakePaddingSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
std::shared_ptr<ov::Node> combine_seg_node;
for (auto node: model->get_ordered_ops()) {
if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {
combine_seg_node = node;
}
}
if (!combine_seg_node) { return false; }

size_t num_segments = (combine_seg_node->get_input_size() - 1) / 3;
size_t number_of_main_tokens_inputs = 0;
std::shared_ptr<Node> add_or_sub_node;
for (size_t i = 0; i < num_segments; i++) {
// Check all 'ends' inputs of the CombineSegments node.
// For special tokens they are Constant/Select;
// for the ends input with the main token sequence it's Add/Subtract.
// Add means right truncation, Subtract means left truncation.
auto tmp_node = combine_seg_node->input_value(3*i + 1).get_node_shared_ptr();
if (ov::as_type_ptr<v1::Add>(tmp_node) || ov::as_type_ptr<v1::Subtract>(tmp_node)) {
number_of_main_tokens_inputs += 1;
add_or_sub_node = tmp_node;
}
}

// Exit if couldn't find main input or there are several.
if (number_of_main_tokens_inputs != 1) { return false; }

// Minimum between max_length and the length of the token sequence.
auto min_node = ov::as_type_ptr<v1::Minimum>(add_or_sub_node->get_input_node_shared_ptr(1));
if (!min_node) { return false; }

// Constant containing the final max_length minus the number of added tokens at the end of the pipeline.
auto const_node = ov::as_type_ptr<v0::Constant>(min_node->get_input_node_shared_ptr(1));
if (!const_node) { return false; }

op::util::VariableInfo var_info{const_node->get_output_shape(0), const_node->get_output_element_type(0), MAX_LENGTH_VAR_ID};
auto variable_1 = std::make_shared<op::util::Variable>(var_info);

size_t num_added_tokens = num_segments - number_of_main_tokens_inputs;
// Constant which stores number of added_tokens.
auto num_added_tokens_const = std::make_shared<v0::Constant>(
const_node->get_output_element_type(0), const_node->get_output_shape(0), std::vector{num_added_tokens});

OPENVINO_ASSERT(const_node->get_element_type() == element::i32);
auto values = const_node->get_vector<int32_t>();
OPENVINO_ASSERT(values.size() == 1);
// Since const_node contains the value max_length - num_added_tokens, add them back to get the default max_length.
size_t default_max_length = values[0] + num_added_tokens;

auto default_max_length_const = std::make_shared<v0::Constant>(
const_node->get_output_element_type(0), const_node->get_output_shape(0), std::vector{default_max_length});

// Save targets before adding new target with ReadValue to avoid recursion.
auto target_inputs = const_node->output(0).get_target_inputs();
auto max_length_rv = std::make_shared<v6::ReadValue>(default_max_length_const, variable_1);
auto subtract_node = std::make_shared<v1::Subtract>(max_length_rv, num_added_tokens_const);

for (auto target_input : target_inputs) {
target_input.replace_source_output(subtract_node->output(0));
}

// We need to check whether the user requested not to add special tokens.
std::shared_ptr<v6::ReadValue> read_value_spec_tokens;
for (const auto& sink : model->get_sinks()) {
// Check if the sink accepts input from Assign, and if that's the case get the ReadValue node input.
if (auto read_value = ov::as_type_ptr<v6::ReadValue>(sink->get_input_node_shared_ptr(0))) {
if (read_value->get_variable()->get_info().variable_id == ADD_SPECIAL_TOKENS_VAR_ID) {
read_value_spec_tokens = read_value;
break;
}
}
}

// If the user requested not to add special tokens, then in order to correctly calculate
// truncation we need to force num_added_tokens to 0 regardless of the hardcoded Constant value.
if (read_value_spec_tokens && num_added_tokens_const) {
auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
auto select_node = std::make_shared<v1::Select>(read_value_spec_tokens, num_added_tokens_const, zero_constant);
subtract_node->input(1).replace_source_output(select_node->output(0));
}

model->add_sinks({std::make_shared<v6::Assign>(max_length_rv, variable_1)});
model->add_variables({variable_1});

std::shared_ptr<ov::Node> ragged_to_dense_node;
for (auto node: model->get_ordered_ops()) {
if (strcmp(node->get_type_info().name, "RaggedToDense") == 0) {
ragged_to_dense_node = node;
}
}

if (!ragged_to_dense_node || ragged_to_dense_node->input_value(3).get_element_type() != ov::element::i32) {
return true; // true since at this point we have already modified the graph.
}

auto variable_2 = std::make_shared<op::util::Variable>(op::util::VariableInfo{ov::Shape{1}, ov::element::boolean, PAD_TO_LONGEST_VAR_ID});

// By default do not pad to max_length
auto default_false_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{1}, std::vector{false});
auto pad_to_max_length_rv = std::make_shared<v6::ReadValue>(default_false_const, variable_2);

auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
auto select_node = std::make_shared<v1::Select>(pad_to_max_length_rv, max_length_rv, zero_constant);

auto max_op = std::make_shared<v1::Maximum>(ragged_to_dense_node->input_value(3), select_node);
ragged_to_dense_node->input(3).replace_source_output(max_op->output(0));

model->add_sinks({std::make_shared<v6::Assign>(pad_to_max_length_rv, variable_2)});
model->add_variables({variable_2});

return true;
}
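
Once this pass has run, the compiled tokenizer exposes `max_length` and `PAD_TO_LONGEST` as mutable variable states. Below is a minimal sketch of driving them through the standard OpenVINO state API; the state names follow the variable IDs declared in `make_tokenizer_stateful.hpp`, while the exact tensor shape of the `max_length` state is an assumption:
```cpp
#include "openvino/openvino.hpp"

// Sketch: override max_length and the pad-to-max-length behaviour on a
// compiled tokenizer model before running it.
void set_tokenizer_states(ov::InferRequest& request, int32_t max_length, bool pad_to_max) {
    for (auto& state : request.query_state()) {
        if (state.get_name() == "max_length") {
            ov::Tensor value(ov::element::i32, ov::Shape{1});
            value.data<int32_t>()[0] = max_length;
            state.set_state(value);
        } else if (state.get_name() == "PAD_TO_LONGEST") {
            // Despite the name, setting this state to true selects padding to
            // max_length (see the Select over max_length_rv above).
            ov::Tensor value(ov::element::boolean, ov::Shape{1});
            value.data<bool>()[0] = pad_to_max;
            state.set_state(value);
        }
    }
}
```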
21 changes: 17 additions & 4 deletions src/cpp/src/make_tokenizer_stateful.hpp
@@ -3,6 +3,7 @@

#include "openvino/op/constant.hpp"
#include "openvino/pass/pass.hpp"
#include "openvino/pass/matcher_pass.hpp"

namespace ov {
namespace genai {
@@ -32,9 +33,19 @@ namespace genai {
* | CombineSegments |
* +-------------------------+
**/
class MakeCombineSegmentsSatateful : public ov::pass::ModelPass {
class MakeAddSpecialTokensSatateful : public ov::pass::ModelPass {
public:
OPENVINO_MODEL_PASS_RTTI("MakeCombineSegmentsSatateful");
OPENVINO_MODEL_PASS_RTTI("MakeAddSpecialTokensSatateful");
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};

/**
* @brief This pass modifies the tokenizer ov::Model so that the inputs to the RaggedToDense and
* CombineSegments nodes become modifiable at runtime, allowing padding to be controlled.
*/
class MakePaddingSatateful : public ov::pass::ModelPass {
public:
OPENVINO_MODEL_PASS_RTTI("MakePaddingSatateful");
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};

@@ -74,8 +85,10 @@ class MakeVocabDecoderSatateful : public ov::pass::ModelPass {
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};

const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens";
const std::string SKIP_SPECIAL_TOKENS_VAR_ID = "skip_special_tokens";
inline const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens";
inline const std::string SKIP_SPECIAL_TOKENS_VAR_ID = "skip_special_tokens";
inline const std::string MAX_LENGTH_VAR_ID = "max_length";
inline const std::string PAD_TO_LONGEST_VAR_ID = "PAD_TO_LONGEST";

} // namespace genai
} // namespace ov
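
The passes above follow the standard `ov::pass::ModelPass` contract, so they can be applied with a pass manager before the tokenizer model is compiled. A minimal sketch; which passes belong to the tokenizer versus the detokenizer model is an assumption here (`MakeVocabDecoderSatateful` targets the detokenizer's VocabDecoder node):
```cpp
#include "openvino/pass/manager.hpp"
#include "make_tokenizer_stateful.hpp"

// Sketch: make a converted tokenizer ov::Model stateful before compilation.
void make_stateful(const std::shared_ptr<ov::Model>& tokenizer_model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::genai::MakeAddSpecialTokensSatateful>();
    manager.register_pass<ov::genai::MakePaddingSatateful>();
    manager.run_passes(tokenizer_model);
}
```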