Fix Qwen2VL generation without images (#1645)
Ticket: CVS-156940, CVS-161487

---------

Co-authored-by: Vladimir Zlobin <[email protected]>
yatarkan and Wovchena authored Jan 30, 2025
1 parent b10ebcf commit 2c5064b
Showing 4 changed files with 26 additions and 31 deletions.
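
The user-visible symptom behind this fix is that a Qwen2VL VLMPipeline could not generate correctly from a text-only prompt. A minimal reproduction sketch, assuming openvino_genai's Python VLMPipeline API; the local model directory name is hypothetical:

# Sketch of the scenario this commit fixes: calling generate() with no
# images. Assumes the openvino_genai Python bindings; the directory
# "tiny-random-qwen2vl-ov" is a hypothetical local OpenVINO export.
import openvino_genai

pipe = openvino_genai.VLMPipeline("tiny-random-qwen2vl-ov", "CPU")
result = pipe.generate("Describe a cat.", max_new_tokens=30)  # no images passed
print(result)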
51 changes: 22 additions & 29 deletions src/cpp/src/visual_language/inputs_embedder.cpp
@@ -1428,15 +1428,15 @@ std::vector<ov::Tensor> split_tokenize(const std::string& text, ov::genai::Token
     return tokenized;
 }
 
-ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, size_t tokens_per_image) {
+ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, const std::vector<size_t>& tokens_per_images) {
     size_t merged_length = 0;
     for (const ov::Tensor& chunk : chunks) {
         merged_length += chunk.get_shape().at(1);
     }
-    merged_length += chunks.empty() ? 0 : (chunks.size() - 1) * tokens_per_image;
+    merged_length += std::accumulate(tokens_per_images.begin(), tokens_per_images.end(), 0);
     ov::Tensor merged{ov::element::i64, {1, merged_length}};
     size_t offset = 0;
-    int64_t image_id = -1;
+    int64_t image_id = 0;
     for (const ov::Tensor& chunk : chunks) {
         size_t length = chunk.get_shape().at(1);
         std::copy_n(
@@ -1448,11 +1448,11 @@ ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, size
         if (offset < merged_length) {
             std::fill_n(
                 merged.data<int64_t>() + offset,
-                tokens_per_image,
-                image_id
+                tokens_per_images.at(image_id),
+                -image_id - 1 // It could be just -image_id. -1 is for consistency with the original implementation.
             );
-            offset += tokens_per_image;
-            --image_id;
+            offset += tokens_per_images.at(image_id);
+            ++image_id;
         }
     }
     return merged;
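
With the new signature, insert_image_placeholders no longer assumes every image projects to the same number of tokens: image i now contributes tokens_per_images.at(i) placeholder slots filled with the sentinel value -i - 1. A simplified Python sketch of the same bookkeeping (chunk contents and token counts are made-up values):

# Simplified sketch of the placeholder scheme above, not the C++ code:
# text chunks are interleaved with per-image runs of the sentinel -(i + 1).
def insert_image_placeholders(chunks: list[list[int]], tokens_per_images: list[int]) -> list[int]:
    merged: list[int] = []
    for image_id, chunk in enumerate(chunks):
        merged.extend(chunk)
        if image_id < len(tokens_per_images):
            # Image i occupies tokens_per_images[i] slots marked -(i + 1).
            merged.extend([-image_id - 1] * tokens_per_images[image_id])
    return merged

# Two text chunks around one image that projects to 4 tokens:
print(insert_image_placeholders([[101, 7, 8], [9, 102]], [4]))
# -> [101, 7, 8, -1, -1, -1, -1, 9, 102]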
@@ -1481,17 +1481,15 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
 public:
     ov::InferRequest m_hd_feature_transformer;
     ov::InferRequest m_vision_projection;
-    // Used to insert <|image_i|>\n per image (not a slice).
-    size_t m_image_id = 1;
-    size_t m_tokens_per_image = 0;
+    std::vector<size_t> m_tokens_per_images;
 
     InputsEmbedderPhi3V(
         const VLMConfig& vlm_config,
         const std::filesystem::path& model_dir,
         const std::string& device,
         const ov::AnyMap device_config
     ):
-        IInputsEmbedder(vlm_config, model_dir, device, device_config), m_image_id{0},
+        IInputsEmbedder(vlm_config, model_dir, device, device_config),
         m_hd_feature_transformer{phi3_v::create_hd_feature_transformer()},
         m_vision_projection{utils::singleton_core().compile_model(model_dir / "openvino_vision_projection_model.xml", device, {}).create_infer_request()} {}
@@ -1502,17 +1500,16 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         for (const ov::Tensor& image : to_single_image_tensors(images)) {
             EncodedImage encoded_image = m_vision_encoder.encode(image);
             images_features_proj.push_back(phi3_v::hd_feature_transform(encoded_image, m_hd_feature_transformer, m_vlm_config.sub_GN, m_vlm_config.glb_GN, m_vision_projection));
-            images_prompt << "<|image_" << m_image_id << "|>\n";
-            ++m_image_id;
+            m_tokens_per_images.push_back(images_features_proj.back().get_shape().at(1));
+            images_prompt << "<|image_" << m_tokens_per_images.size() << "|>\n";
         }
         images_prompt << prompt;
         std::vector<ov::Tensor> new_chat_tokens;
         std::vector<ov::Tensor> prev_chat_tokens;
         if (m_is_chat_conversation) {
             m_history.push_back({{"role", "user"}, {"content", images_prompt.str()}});
             constexpr bool add_generation_prompt = true;
-            std::string new_templated_chat_history;
-            new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
+            std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
             auto start_tokenizer_time = std::chrono::steady_clock::now();
             new_chat_tokens = phi3_v::split_tokenize(new_templated_chat_history, m_tokenizer);
             prev_chat_tokens = phi3_v::split_tokenize(m_templated_chat_history, m_tokenizer);
@@ -1525,19 +1522,15 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
             auto end_tokenizer_time = std::chrono::steady_clock::now();
             metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
         }
-        if (0 == m_tokens_per_image && !images_features_proj.empty()) {
-            m_tokens_per_image = images_features_proj.at(0).get_shape().at(1);
-        }
-        ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_image);
-        ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_image);
+        ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_images);
+        ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_images);
         ov::Tensor new_tokens = update_history(new_merged_tokens, prev_merged_tokens);
         std::vector<ov::Tensor> tokens = phi3_v::drop_image_placeholders(new_tokens);
         OPENVINO_ASSERT(tokens.size() == images_features_proj.size() + 1);
         size_t features_length = 0;
         for (size_t im_id = 0; im_id < images_features_proj.size(); ++im_id) {
             size_t text_length = tokens.at(im_id).get_shape().at(1);
             size_t im_length = images_features_proj.at(im_id).get_shape().at(1);
-            OPENVINO_ASSERT(im_length == m_tokens_per_image);
             features_length += text_length + im_length;
         }
         features_length += tokens.back().get_shape().at(1);
@@ -1570,20 +1563,20 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         );
 
         if (!m_is_chat_conversation) {
-            m_image_id = 0;
+            m_tokens_per_images.clear();
         }
 
         return inputs_embeds;
     }
 
     virtual void start_chat(const std::string& system_message) override {
         IInputsEmbedder::start_chat(system_message);
-        m_image_id = 0;
+        m_tokens_per_images.clear();
     }
 
     virtual void finish_chat() override {
         IInputsEmbedder::finish_chat();
-        m_image_id = 0;
+        m_tokens_per_images.clear();
     }
 };

@@ -1662,10 +1655,6 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
         ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, metrics, chat_template_fallback);
         ov::Tensor text_embeds = m_embedding.infer(input_ids);
 
-        if (images.empty()) {
-            return text_embeds;
-        }
-
         auto start_tokenizer_time = std::chrono::steady_clock::now();
         ov::Tensor encoded_vision_start_token = m_tokenizer.encode(m_vlm_config.vision_start_token, ov::genai::add_special_tokens(false)).input_ids;
         ov::Tensor encoded_image_pad_token = m_tokenizer.encode(m_vlm_config.image_pad_token, ov::genai::add_special_tokens(false)).input_ids;
@@ -1680,6 +1669,10 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
         int64_t position_ids_max_element = *std::max_element(m_position_ids.data<int64_t>(), m_position_ids.data<int64_t>() + m_position_ids.get_size());
         m_rope_delta = position_ids_max_element + 1 - static_cast<int64_t>(input_ids.get_shape().at(1));
 
+        if (images.empty()) {
+            return text_embeds;
+        }
+
         return merge_text_and_image_embeddings_qwen2vl(input_ids, text_embeds, image_embeds, images_grid_thw, image_pad_token_id);
     }
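
This relocation is the heart of the Qwen2VL fix: the images.empty() early return used to fire before m_position_ids and m_rope_delta were updated, so text-only prompts ran with rotary-position state that was never refreshed. A reduced Python sketch of the control-flow change (illustrative shape only, not the actual implementation):

# Illustrative shape of the fix: per-call state must be refreshed before
# the image-free early return, not after it. All names are stand-ins.
def get_inputs_embeds(state: dict, input_len: int, has_images: bool) -> str:
    text_embeds = f"<text embeddings for {input_len} tokens>"  # stand-in value

    # Previously skipped for text-only prompts, because the early return
    # below used to sit above this block, leaving rope_delta stale.
    position_ids = list(range(input_len))
    state["rope_delta"] = max(position_ids) + 1 - input_len

    if not has_images:  # the early return now comes after the state update
        return text_embeds
    return text_embeds + " merged with image embeddings"

state = {}
print(get_inputs_embeds(state, 5, has_images=False), state)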

@@ -1874,7 +1867,7 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
         }
 
         // Calculate rotary embeddings for max_grid_size
-        const size_t dim = 1280 / 16 / 2; // config.vision_config.embed_dim / self.config.vision_config.num_heads / 2
+        const size_t dim = m_vision_embeddings_merger.get_tensor("rotary_pos_emb").get_shape().at(1);
         const float theta = 10000.0f;
 
         std::vector<float> inv_freq(dim / 2);
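
Reading dim from the rotary_pos_emb tensor shape removes the hard-coded 1280 / 16 / 2 = 40, so vision encoders with a different embed_dim or head count keep working. For reference, a numpy sketch of the standard inverse-frequency table that the surrounding code fills into inv_freq (theta = 10000 as in the visible context; the 8-position example is arbitrary):

import numpy as np

# Standard rotary inverse-frequency table: dim/2 entries of 1 / theta^(2i/dim),
# matching the C++ `std::vector<float> inv_freq(dim / 2)` above.
def rotary_inv_freq(dim: int, theta: float = 10000.0) -> np.ndarray:
    return 1.0 / theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim)

dim = 40  # the value the old hard-coded 1280 / 16 / 2 produced
inv_freq = rotary_inv_freq(dim)            # shape (20,)
angles = np.outer(np.arange(8), inv_freq)  # rotary angles for 8 positions
print(inv_freq.shape, angles.shape)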
2 changes: 1 addition & 1 deletion src/cpp/src/visual_language/vision_encoder.cpp
@@ -843,7 +843,7 @@ std::tuple<ov::Tensor, ImageSize> get_pixel_values_phi3_v(const ov::Tensor& imag

 ImageSize smart_resize_qwen2vl(size_t height, size_t width, size_t factor, size_t min_pixels, size_t max_pixels) {
     if (height < factor || width < factor) {
-        OPENVINO_THROW("Height or width must be larger than factor");
+        OPENVINO_THROW("Height (" + std::to_string(height) + ") and width (" + std::to_string(width) + ") must be greater than factor (" + std::to_string(factor) + ")");
     }
     if (std::max(height, width) / std::min(height, width) > 200) {
         OPENVINO_THROW("Absolute aspect ratio must be smaller than 200");
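
For context on the function whose error message was improved: smart_resize_qwen2vl follows the Qwen2-VL preprocessing recipe, rounding height and width to multiples of factor and rescaling so the total pixel count lands inside [min_pixels, max_pixels]. A Python sketch of that logic, ported from the Hugging Face reference preprocessor as an assumption rather than from the OpenVINO source:

import math

# Sketch of the smart-resize recipe (assumed to mirror the HF Qwen2-VL
# reference); only the two validation errors are confirmed by the diff above.
def smart_resize(height: int, width: int, factor: int, min_pixels: int, max_pixels: int):
    if height < factor or width < factor:
        raise ValueError(f"Height ({height}) and width ({width}) must be greater than factor ({factor})")
    if max(height, width) / min(height, width) > 200:
        raise ValueError("Absolute aspect ratio must be smaller than 200")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:  # too many pixels: scale down
        beta = math.sqrt(height * width / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:  # too few pixels: scale up
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar

print(smart_resize(1080, 1920, factor=28, min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280))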
2 changes: 1 addition & 1 deletion tests/python_tests/common.py
@@ -535,7 +535,7 @@ def get_image_by_link(link):
     image = Image.open(requests.get(link, stream=True).raw)
     if image.mode != 'RGB':
         image = image.convert('RGB')
-    image_data = np.array((np.array(image.getdata()) - 128).astype(np.byte)).reshape(1, 3, image.size[1], image.size[0])
+    image_data = np.array((np.array(image.getdata()) - 128).astype(np.byte)).reshape(1, image.size[1], image.size[0], 3)
     return Tensor(image_data)


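The helper change above swaps the image tensor layout from planar (1, 3, H, W) to interleaved (1, H, W, 3), which matches how PIL's getdata() actually orders pixels: H * W entries of 3 channel values each. A numpy sketch of the equivalent construction; the 640x480 all-black image is synthetic:

import numpy as np
from PIL import Image

# PIL stores pixels interleaved, so np.asarray(image) is already (H, W, 3);
# adding a batch axis yields the NHWC shape the fixed helper now produces.
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
image_data = (np.asarray(image).astype(np.int16) - 128).astype(np.int8)
image_data = image_data[np.newaxis, ...]  # (1, H, W, 3)
assert image_data.shape == (1, 480, 640, 3)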
2 changes: 2 additions & 0 deletions tests/python_tests/test_vlm_pipeline.py
@@ -47,6 +47,8 @@ def get_ov_model(model_id, cache):
 @pytest.mark.parametrize("model_id", [
     "katuni4ka/tiny-random-minicpmv-2_6",
     "katuni4ka/tiny-random-phi3-vision",
+    "katuni4ka/tiny-random-llava",
+    "katuni4ka/tiny-random-qwen2vl",
 ])
 def test_vlm_pipeline(model_id, cache):
     def streamer(word: str) -> bool:
