Improve qwen vl impl #2943
Conversation
This PR improves the performance and response from qwen2-vl based models. Small reproducible examples can be run with the startup commands and script below.

Expected output for: text-generation-launcher --model-id bytedance-research/UI-TARS-7B-DPO

{
"generated_text": "The image depicts the Statue of Liberty, a renowned landmark located on Liberty Island in New York Bay."
}
{
"generated_text": "The image features the logo of Flash Attention, a state-of-the-art attention mechanism designed for transformers,"
}
{
"generated_text": "The image features a stylized illustration of a rabbit, rendered in a minimalist and abstract design. The"
}

Expected output for: text-generation-launcher --model-id Qwen/Qwen2-VL-2B-Instruct

{
"generated_text": "The image depicts the iconic Statue of Liberty in New York City, with the city's skyline in the"
}
{
"generated_text": "The image compares two different implementations of attention mechanisms in neural networks:\n\n### Standard Attention Implementation"
}
{
"generated_text": "The image depicts a rabbit in an astronaut's suit standing on a rocky, red-brown planet with"
}

Expected output for: text-generation-launcher --model-id Qwen/Qwen2-VL-7B-Instruct --num-shard 2

{
"generated_text": "The image depicts the iconic Statue of Liberty, a colossal neoclassical sculpture on Liberty Island in"
}
{
"generated_text": "The image compares the standard attention implementation with Flash Attention in the context of memory and computation operations.\n\n###"
}
{
"generated_text": "The image depicts an astronaut in a futuristic space suit standing on a rocky surface with a reddish-orange"
}

Script for testing a couple of images with the models:

import requests
import json
url = "http://127.0.0.1:3000/generate"
headers = {"Content-Type": "application/json"}
image_urls = [
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/flash-attn.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/rabbit.png",
]
for image in image_urls:
    query = "Describe the image"
    payload = {
        # the image is passed inline via markdown image syntax, followed by the text query
        "inputs": f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n![]({image}){query}<|im_end|>\n<|im_start|>assistant\n",
        "parameters": {"max_new_tokens": 20},
    }
    response = requests.post(url, headers=headers, json=payload)
    # print the response
    print(json.dumps(response.json(), indent=4))
@@ -1248,7 +1248,7 @@ def get_model(
         revision=revision,
         quantize=quantize,
         speculator=speculator,
-        dtype=dtype,
+        dtype=torch.bfloat16,
Should be passed through with default_dtype, I think?
Ah, good catch! Updating to use default_dtype in the latest commit.
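For context, here is a minimal sketch of the suggested pattern; the helper name and signature are illustrative, not the actual TGI code:

```python
import torch

def resolve_dtype(dtype=None, default_dtype=torch.bfloat16):
    # hypothetical helper: prefer the caller-supplied dtype and only fall back
    # to the model-family default (bfloat16 here) when nothing was requested
    return dtype if dtype is not None else default_dtype

print(resolve_dtype())               # torch.bfloat16
print(resolve_dtype(torch.float16))  # torch.float16
```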
# only apply mrope if sections are provided and the rope type is mrope or default
if mrope_section is not None and (
    rope_type == "mrope" or rope_type == "default"
Why are we modifying the default rope?
Having an mrope section is not enough. default really means default, not "use mrope under X condition".
Yeah, makes sense. I've updated to only match if rope_type == "mrope" and to ensure the rope_type is set to mrope within the qwen2-vl model.
Note: currently transformers AutoModel returns rope_type "default" for all qwen-vl models.
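As a rough sketch of the agreed behaviour (illustrative names, not the exact TGI code): mrope is only selected when the rope type explicitly says so, and Qwen2-VL sets that type in its own config:

```python
def select_rotary_impl(rope_type, mrope_section):
    # "default" stays default even if sections happen to be present;
    # only an explicit "mrope" selects the multimodal-section rotary
    if mrope_section is not None and rope_type == "mrope":
        return "multimodal_sections"
    return "standard"

assert select_rotary_impl("default", [16, 24, 24]) == "standard"
assert select_rotary_impl("mrope", [16, 24, 24]) == "multimodal_sections"
```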
launcher/src/main.rs (outdated)
let text_flops = layer_flops * num_layers;

tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop"));

// text-only case
if self.vision_config.is_none() {
    return Some(text_flops);
}

let vision_config = self.vision_config.as_ref().unwrap();

// estimate vision flops for specific model types
match self.model_type.as_deref() {
    Some("qwen2_vl") => {
        let in_chans = vision_config.in_chans? as u64;
        let patch_size = vision_config.patch_size? as u64;
        let embed_dim = vision_config.embed_dim? as u64;
        let vision_depth = vision_config.depth? as u64;
        let mlp_ratio = vision_config.mlp_ratio? as u64;
        let temporal_patch_size = vision_config.temporal_patch_size? as u64;
        // 1. patch embedding:
        //    - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2
        //      where the 2 accounts for multiply-add
        let patch_flops =
            2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans;
        // 2. self-attention + mlp:
        //    - qkv projections: 3 * d_model * d_model * 2
        //    - attention: d_model * d_model * 2
        //    - mlp: 2 * d_model * (mlp_ratio * d_model) * 2
        //    simplified to: 2 * d_model * (4 + mlp_ratio * d_model)
        let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim);
        // 3. add with layer norm flops for total vision layer flops
        let layer_flops = patch_flops + attn_flops + 2 * embed_dim;
        let vision_flops = layer_flops * vision_depth;
        tracing::debug!(
            "Vision flops: {}",
            human_size(vision_flops as usize, "flop")
        );
        Some(text_flops + vision_flops)
    }
    // model has a vision config but is not supported for flops calculation
    // we return None to avoid overestimating the memory requirements
    _ => None,
}
Why? Let's not add specific models here.
If you want to patch it, patch it in a consistent way where we can actually add every VLM model.
But I think this belongs in a subsequent PR where we can add at least a few VLMs (they should all be mostly CLIP, so it should be easy to add).
The code here is on purpose not model-dependent so that it is more general (since, to a first approximation, all models are exactly the same).
Makes sense, will follow up with another PR that avoids the model-specific logic.
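For readers following the estimate, here is a small Python transcription of the same per-layer arithmetic as the Rust snippet above; the vision-config values are placeholders for illustration, not authoritative Qwen2-VL numbers:

```python
# placeholder vision-config values, for illustration only
in_chans, patch_size, embed_dim = 3, 14, 1280
depth, mlp_ratio, temporal_patch_size = 32, 4, 2

# 1. patch embedding (conv3d, x2 for multiply-add)
patch_flops = 2 * temporal_patch_size * patch_size**2 * embed_dim * in_chans
# 2. self-attention + mlp, simplified as in the snippet above
attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim)
# 3. plus layer norm terms
layer_flops = patch_flops + attn_flops + 2 * embed_dim
vision_flops = layer_flops * depth
print(f"estimated vision flops: {vision_flops:,}")
```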
split_cos, split_sin = [
    torch.split(t, self.sections, dim=-1) for t in (cos, sin)
]
cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(
    1
)
sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze(
    1
)
# prepare input tensors
This is data movement that should be done ahead of time in get_cos_sin.
Agreed, and moved into get_cos_sin in the latest commit.
q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
rotary_dim = cos.shape[-1]
q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1)

rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True)
This is extremely confusing code.
Do we have any reference, at least, to check against?
It also feels like this data movement could be avoided by having the correct Q, K in the first place (probably we need to load Q, K differently on the GPU).
Good point. I've simplified the code in the latest commits to avoid the transpose and simply rotate q and k and pass them to rotary_emb.apply_rotary.
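For reference, the rotation being discussed can be written in plain PyTorch; this is a generic rotate-half sketch of what a fused rotary kernel computes, not the exact semantics of the rotary_emb.apply_rotary call above:

```python
import torch

def rotate_half(x):
    # move the second half of the rotary dims in front of the first, negated
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def apply_rotary_reference(q, cos, sin):
    # GPT-NeoX style rotary embedding: q * cos + rotate_half(q) * sin
    return q * cos + rotate_half(q) * sin

q = torch.randn(5, 4, 64)    # (seq_len, num_heads, head_dim)
cos = torch.randn(5, 1, 64)  # broadcast over heads
sin = torch.randn(5, 1, 64)
print(apply_rotary_reference(q, cos, sin).shape)  # torch.Size([5, 4, 64])
```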
@@ -1427,7 +1431,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
             "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
         )
         input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
-        position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
+        position_ids = self.cuda_graphs[max_bs]["position_ids"][..., :bs]
I'm very wary about this line.
position_ids should have BS as the first dimension; we can probably change the code around to put the multimodal dimension second.
The issue with this code is that we have no way of knowing what shape position_ids has, and therefore whether the slicing is actually valid.
As a CUDA graph, the values SHOULD always be zero anyway; we should never initialize any values.
Yeah, agreed. I've refactored the position ids to always have the batch in the first dim and fixed the indexing to slice correctly (which also fixes all issues with CUDA graphs).
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
if hasattr(self.model, "get_position_ids"):
    # use model specific position ids for initialization
    position_ids = self.model.get_position_ids(input_ids)
Why can't we put this INTO qwen2-vl, so that it only applies to Qwen2-VL?
Changing the position ids shape + indirection is quite a high cost in my book.
Currently the position ids are not mutated by the model before they are placed on the batch, which is why this logic lives outside of the model.
I've improved the logic a bit to expand position_ids if rope_scaling exists and the type is mrope; in this case the position ids are expanded to the size of the sections.
This isn't a huge change from the original code but feels a bit cleaner. Please let me know if there's a better solution 🙏
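A small sketch of the expansion described here, with illustrative names and shapes (not the exact TGI code): a regular (bs,) position tensor is only widened to one column per mrope section when rope_scaling asks for mrope, keeping batch as the first dimension:

```python
import torch

def expand_position_ids(position_ids, rope_scaling=None):
    # keep batch first; widen to (bs, n_sections) only for mrope
    if rope_scaling is not None and rope_scaling.get("rope_type") == "mrope":
        n_sections = len(rope_scaling["mrope_section"])
        return position_ids.unsqueeze(-1).expand(-1, n_sections).contiguous()
    return position_ids

pos = torch.arange(4, dtype=torch.int32)
print(expand_position_ids(pos).shape)  # torch.Size([4])
print(expand_position_ids(pos, {"rope_type": "mrope", "mrope_section": [16, 24, 24]}).shape)  # torch.Size([4, 3])
```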
…ump vlm default token limit (force-pushed from dca2f12 to 79550f8)
cos_c = torch.stack(
    [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0
)
I wonder if the looping/stacking is necessary? I think we can do this with a single gather?
Also, is this reliably 3? Or should it be the length of mrope_section?
Agreed! I reworked the logic to avoid the loop/stacking and now cache a tensor of the indices that's used to gather in get_cos_sin. This approach removed a lot of the extra reshaping and limits the initialization. Thanks!
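A rough illustration of that reworked selection (shapes and names are illustrative, not the final TGI implementation): a single advanced-index lookup plus a cached per-column section index replaces the Python loop and stack:

```python
import torch

max_pos, dim = 32768, 64
sections = [16, 24, 24]
cos_cached = torch.randn(max_pos, dim)
position_ids = torch.randint(0, max_pos, (10, 3))  # one column per mrope section

# cached once at init: which section owns each rotary column
section_idx = torch.repeat_interleave(torch.arange(len(sections)), torch.tensor(sections))  # (dim,)

# single lookup + gather instead of looping and stacking per section
looked_up = cos_cached[position_ids]                       # (seq_len, 3, dim)
cos = torch.gather(
    looked_up, 1, section_idx.view(1, 1, -1).expand(looked_up.size(0), 1, -1)
).squeeze(1)                                               # (seq_len, dim)
print(cos.shape)  # torch.Size([10, 64])
```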
split_cos = torch.split(cos_c, self.sections, dim=-1)
split_sin = torch.split(sin_c, self.sections, dim=-1)
How do the sections here relate to the sections above?
These sections relate to the temporal, height, and width embeddings for images and are effectively the number of each index to select from cos/sin.
I.e., if cos is torch.Size([10000, 3, 64]) and sections=[16, 24, 24], the values would be cos[:, 0, :16], cos[:, 1, 16:16+24], ... concatenated to torch.Size([10000, 64]).
However, as noted above, the logic has been updated to simply initialize a tensor and then use gather to select the correct elements.
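Spelled out as a runnable toy version of the split/concatenate selection described above (dimensions taken from the example, not from a real config):

```python
import torch

cos = torch.randn(10000, 3, 64)
sections = [16, 24, 24]

# take columns 0:16 from the temporal row, 16:40 from the height row and
# 40:64 from the width row, then concatenate back to (10000, 64)
parts = [
    split[:, i, :] for i, split in enumerate(torch.split(cos, sections, dim=-1))
]
merged = torch.cat(parts, dim=-1)
print(merged.shape)  # torch.Size([10000, 64])
```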
server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py (outdated, resolved)
server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py (outdated, resolved)
server/text_generation_server/models/custom_modeling/qwen2_vl.py (outdated, resolved)
Thanks for all the changes, looks good to me! Since this is a fairly large change, it would be good if someone else can do a final sanity check (@Narsil ?).
@@ -378,8 +377,12 @@ def __init__(self, prefix, config, weights):
         self.config = config
         config.vision_config.quantize = None
         config.vision_config.speculator = config.speculator
+        # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
+        # returns rope_scaling.type == "default" for Qwen2-VL model at the moment
+        config.rope_scaling.update({"rope_type": "mrope"})
Shouldn't we do that only if the rope_type is actually default? Otherwise this looks better.
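A minimal sketch of the guarded override being suggested (illustrative, not the final patch):

```python
# only patch the config when AutoConfig reported the (incorrect) "default" type,
# so an explicitly configured rope type is left untouched
rope_scaling = {"rope_type": "default", "mrope_section": [16, 24, 24]}

if rope_scaling.get("rope_type", "default") == "default":
    rope_scaling["rope_type"] = "mrope"

print(rope_scaling["rope_type"])  # mrope
```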
@@ -527,6 +520,7 @@ def forward(

         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
+            pixel_values = pixel_values.to(inputs_embeds.dtype)
Why is this necessary?
LGTM. This is much better!
This PR improves qwen2-vl in the following ways:
- bumps the VLM default token limit (max_tokens)
- adds RotaryPositionEmbeddingMultimodalSections in rotary.py