From 77ef54306128555c68e497f75f2e4e6607c3a2d4 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 21 Jan 2025 22:31:22 +0000 Subject: [PATCH 01/18] feat: refactor model, improve startup and re enable tests --- .../test_flash_qwen2_vl_simple.json | 12 +- .../test_flash_qwen2_vl_simple_streaming.json | 4 +- .../models/test_flash_qwen2_vl.py | 159 +++++++++--------- launcher/src/main.rs | 68 ++++++-- .../text_generation_server/layers/rotary.py | 77 ++++++++- .../text_generation_server/models/__init__.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 45 ++--- .../models/custom_modeling/qwen2_vl.py | 12 +- .../models/flash_causal_lm.py | 4 + 9 files changed, 244 insertions(+), 139 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json index 131631e65c8..5162833fe92 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.", + "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance.", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1730164250, + "created": 1737498164, "id": "", "model": "Qwen/Qwen2-VL-7B-Instruct", "object": "chat.completion", - "system_fingerprint": "2.4.2-dev0-native", + "system_fingerprint": "3.0.2-dev0-native", "usage": { - "completion_tokens": 58, - "prompt_tokens": 349, - "total_tokens": 407 + "completion_tokens": 68, + "prompt_tokens": 1364, + "total_tokens": 1432 } } diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json index 3e2faca714d..b3b14539bca 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json @@ -11,10 +11,10 @@ "logprobs": null } ], - "created": 1730416361, + "created": 1737498227, "id": "", "model": "Qwen/Qwen2-VL-7B-Instruct", "object": "chat.completion.chunk", - "system_fingerprint": "2.4.2-dev0-native", + "system_fingerprint": "3.0.2-dev0-native", "usage": null } diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index 97a533fc5d4..7d51d20d3af 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -1,81 +1,78 @@ -# Disabled because it's broken. -# import pytest -# -# -# @pytest.fixture(scope="module") -# def flash_qwen2_vl_handle(launcher): -# with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: -# yield handle -# -# -# @pytest.fixture(scope="module") -# async def flash_qwen2(flash_qwen2_vl_handle): -# await flash_qwen2_vl_handle.health(300) -# return flash_qwen2_vl_handle.client -# -# -# @pytest.mark.private -# async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): -# response = await flash_qwen2.chat( -# max_tokens=100, -# seed=42, -# messages=[ -# { -# "role": "user", -# "content": [ -# { -# "type": "image_url", -# "image_url": { -# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" -# }, -# }, -# {"type": "text", "text": "Describe this image."}, -# ], -# }, -# ], -# ) -# -# assert ( -# response.choices[0].message.content -# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." -# ) -# -# assert response == response_snapshot -# -# -# @pytest.mark.private -# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot): -# responses = await flash_qwen2.chat( -# max_tokens=100, -# seed=42, -# messages=[ -# { -# "role": "user", -# "content": [ -# { -# "type": "image_url", -# "image_url": { -# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" -# }, -# }, -# {"type": "text", "text": "Describe this image."}, -# ], -# }, -# ], -# stream=True, -# ) -# -# count = 0 -# generated = "" -# last_response = None -# async for response in responses: -# count += 1 -# generated += response.choices[0].delta.content -# last_response = response -# -# assert ( -# generated -# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." -# ) -# assert count == 58 -# assert last_response == response_snapshot +import pytest + + +@pytest.fixture(scope="module") +def flash_qwen2_vl_handle(launcher): + with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_qwen2(flash_qwen2_vl_handle): + await flash_qwen2_vl_handle.health(300) + return flash_qwen2_vl_handle.client + + +@pytest.mark.private +async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): + response = await flash_qwen2.chat( + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + ], + ) + + assert ( + response.choices[0].message.content + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance." + ) + + assert response == response_snapshot + + +@pytest.mark.private +async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot): + responses = await flash_qwen2.chat( + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + ], + stream=True, + ) + + count = 0 + generated = "" + last_response = None + async for response in responses: + count += 1 + generated += response.choices[0].delta.content + last_response = response + + assert ( + generated + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance." + ) + assert count == 68 + assert last_response == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 05ed0202518..8e93b1b2d2d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -230,7 +230,14 @@ struct QuantizationConfig { } #[derive(Debug, Deserialize)] -struct VisionConfig {} +struct VisionConfig { + depth: usize, + embed_dim: usize, + mlp_ratio: usize, + in_chans: usize, + patch_size: usize, + temporal_patch_size: usize, +} #[derive(Debug, Deserialize)] struct Config { @@ -253,11 +260,6 @@ struct Config { impl Config { fn flop(&self) -> Option { - if self.vision_config.is_some() { - // VLM are much harder to predict and VRAM requirements - // Are more complex. - return None; - } let num_heads = self.num_heads? as u64; let num_kv_heads = self.num_kv_heads? as u64; let head_dim = self.head_dim? as u64; @@ -277,8 +279,38 @@ impl Config { let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size; let layer_flops = attn_layer_flops + gate_up_down_flops; - let total = layer_flops * num_layers; - Some(total) + let text_flops = layer_flops * num_layers; + + tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop")); + + if let Some(vision_config) = self.vision_config.as_ref() { + let in_chans = vision_config.in_chans as u64; + let patch_size = vision_config.patch_size as u64; + let embed_dim = vision_config.embed_dim as u64; + let vision_depth = vision_config.depth as u64; + let mlp_ratio = vision_config.mlp_ratio as u64; + let temporal_patch_size = vision_config.temporal_patch_size as u64; + // 1. patch embedding: + // - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2 + // where the 2 accounts for multiply-add + let patch_flops = 2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans; + // 2. self-attention + mlp: + // - qkv projections: 3 * d_model * d_model * 2 + // - attention: d_model * d_model * 2 + // - mlp: 2 * d_model * (mlp_ratio * d_model) * 2 + // simplified to: 2 * d_model * (4 + mlp_ratio * d_model) + let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim); + // 3. add with layer norm flops for total vision layer flops + let layer_flops = patch_flops + attn_flops + 2 * embed_dim; + let vision_flops = layer_flops * vision_depth; + tracing::debug!( + "Vision flops: {}", + human_size(vision_flops as usize, "flop") + ); + Some(text_flops + vision_flops) + } else { + Some(text_flops) + } } fn kv_vram_per_tok(&self) -> Option { @@ -2012,6 +2044,10 @@ fn main() -> Result<(), LauncherError> { let config: Option = get_config(&args.model_id, &args.revision).ok(); let quantize = config.as_ref().and_then(|c| c.quantize); + let model_type = config + .as_ref() + .and_then(|c| c.model_type.as_deref()) + .map(|s| s.to_owned()); // Quantization usually means you're even more RAM constrained. let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters); @@ -2100,8 +2136,20 @@ fn main() -> Result<(), LauncherError> { vec![] } _ => { - let cuda_graphs = vec![1, 2, 4, 8, 16, 32]; - tracing::info!("Using default cuda graphs {cuda_graphs:?}"); + let default_cuda_graphs = vec![1, 2, 4, 8, 16, 32]; + tracing::info!("Using default CUDA graphs: {:?}", default_cuda_graphs); + let cuda_graphs = match model_type.as_deref() { + Some("qwen2_vl") => { + tracing::warn!( + "Qwen VL model detected - restricting CUDA graphs to values >= 3" + ); + default_cuda_graphs + .into_iter() + .filter(|&c| c >= 3) + .collect() + } + _ => default_cuda_graphs, + }; cuda_graphs } }; diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index e346d0f8946..f3ec1f628b4 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -90,7 +90,11 @@ def static(cls, config, dim, base, device): if rope_type == "linear": pass elif rope_type == "default": - pass + if rope_scaling.get("mrope_section", False): + mrope_section = rope_scaling.get("mrope_section") + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, scaling_factor, mrope_section + ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -548,3 +552,74 @@ def apply_llama3_scaling( new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): + def __init__(self, inv_freq, scaling_factor, sections): + super().__init__(inv_freq, scaling_factor) + self.sections = sections * 2 + self._cos_cached = None + self._sin_cached = None + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + mrope_section = self.sections + unsqueeze_dim = 1 + + split_cos = cos.split(mrope_section, dim=-1) + split_sin = sin.split(mrope_section, dim=-1) + + cos = [] + for i, m in enumerate(split_cos): + cos.append(m[i % 3]) + + cos = torch.cat(cos, dim=-1).unsqueeze(unsqueeze_dim) + + sin = [] + for i, m in enumerate(split_sin): + sin.append(m[i % 3]) + + sin = torch.cat(sin, dim=-1).unsqueeze(unsqueeze_dim) + + q = query.transpose(0, 1).unsqueeze(0) + k = key.transpose(0, 1).unsqueeze(0) + + rotary_dim = cos.shape[-1] + q1 = q[..., :rotary_dim] + q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1) + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True) + + k1 = k[..., :rotary_dim] + k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True) + + def get_cos_sin( + self, + position_ids: torch.Tensor, + max_s: int, + dtype: torch.dtype, + ): + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + + position_ids_expanded = position_ids[ + :, :, None, : + ].float() # shape (3, bs, 1, positions) + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( + 2, 3 + ) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype), sin.to(dtype) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 205030e95ac..85c98bfde75 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1362,7 +1362,7 @@ def get_model( revision=revision, quantize=quantize, speculator=speculator, - dtype=dtype, + dtype=torch.bfloat16, kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index cc4039b1cbc..78ae3020cb8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -61,11 +61,6 @@ def __init__( config.sliding_window if config.sliding_window is not None else -1 ) self.num_heads = config.num_attention_heads - self.mrope_section = ( - config.rope_scaling.get("mrope_section", None) - if config.rope_scaling is not None - else None - ) self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -127,17 +122,6 @@ def forward( query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - if self.mrope_section is not None: - # if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order - cos = torch.cat( - [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], - dim=-1, - ) - sin = torch.cat( - [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], - dim=-1, - ) - self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) if prefill_cache_indices is not None: @@ -251,7 +235,8 @@ def forward( max_s, prefill_cache_indices, ): - normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + residual = hidden_states + normed_hidden_states, _ = self.input_layernorm(hidden_states) # Self Attention attn_output = self.self_attn( @@ -266,15 +251,14 @@ def forward( max_s, prefill_cache_indices, ) + hidden_states = attn_output + residual # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) - - mlp_output = self.mlp(normed_attn_res_output) - - return mlp_output, attn_res + residual = hidden_states + hidden_states, _ = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(hidden_states) + hidden_states = mlp_output + residual + return hidden_states class Qwen2Model(torch.nn.Module): @@ -322,18 +306,15 @@ def forward( ) -> torch.Tensor: hidden_states = inputs_embeds - # flatten position ids from 2D to 1D cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids.flatten(), true_max_s, hidden_states.dtype + position_ids, + true_max_s, + hidden_states.dtype, ) - # reshape back to 2D if the position_ids were 2D - if position_ids.size(0) != cos.size(0): - cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) - sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) residual = None for i, layer in enumerate(self.layers): - hidden_states, residual = layer( + hidden_states = layer( hidden_states, residual, cos, @@ -347,7 +328,7 @@ def forward( prefill_cache_indices, ) - hidden_states, _ = self.norm(hidden_states, residual) + hidden_states, _ = self.norm(hidden_states) return hidden_states diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index a8e1e8c1593..e0ae19df766 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -222,12 +222,11 @@ def __init__(self, prefix, config, weights): def forward( self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen ) -> torch.Tensor: - hidden_states_post_norm1, res = self.norm1(hidden_states) - hidden_states = hidden_states + self.attn( - hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen - ) - hidden_states_post_norm2, res = self.norm2(hidden_states) - hidden_states = hidden_states + self.mlp(hidden_states_post_norm2) + norm1_out, _ = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = hidden_states + attn_out + norm2_out, _ = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(norm2_out) return hidden_states @@ -527,6 +526,7 @@ def forward( # apply the visual model to the pixel values if they are provided if pixel_values is not None and len(pixel_values) > 0: + pixel_values = pixel_values.to(inputs_embeds.dtype) if pixel_values is not None: image_embeds = self.visual( pixel_values, grid_thw=image_grid_thw diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1073f4f9c4a..6bc3c2caf2b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1486,6 +1486,10 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): state=state, cache_lengths_tensor=cache_lengths_tensor, ): + # in the case of N dimensional position ids we need to slice the + # position ids to match the input_ids size for cuda graphs warmup + position_ids = position_ids[..., : input_ids.shape[0]] + seqlen = Seqlen( input_lengths=input_lengths_tensor, cache_lengths=cache_lengths_tensor, From d12e075966d5b1f22871a7fe5b82ef8ebbd8677f Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 22 Jan 2025 16:43:53 +0000 Subject: [PATCH 02/18] fix: improve multimodal rotary embed caching --- .../text_generation_server/layers/rotary.py | 82 ++++++++++--------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index f3ec1f628b4..061bf024ba0 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -557,6 +557,8 @@ def apply_llama3_scaling( class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): def __init__(self, inv_freq, scaling_factor, sections): super().__init__(inv_freq, scaling_factor) + # expand the inv_freq for the 3 sections + self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1) self.sections = sections * 2 self._cos_cached = None self._sin_cached = None @@ -568,36 +570,41 @@ def forward( cos: torch.Tensor, sin: torch.Tensor, ): - mrope_section = self.sections - unsqueeze_dim = 1 - - split_cos = cos.split(mrope_section, dim=-1) - split_sin = sin.split(mrope_section, dim=-1) - - cos = [] - for i, m in enumerate(split_cos): - cos.append(m[i % 3]) - - cos = torch.cat(cos, dim=-1).unsqueeze(unsqueeze_dim) - - sin = [] - for i, m in enumerate(split_sin): - sin.append(m[i % 3]) - - sin = torch.cat(sin, dim=-1).unsqueeze(unsqueeze_dim) - - q = query.transpose(0, 1).unsqueeze(0) - k = key.transpose(0, 1).unsqueeze(0) - + # process multi-modal rotary embeddings + split_cos, split_sin = [ + torch.split(t, self.sections, dim=-1) for t in (cos, sin) + ] + cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze( + 1 + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze( + 1 + ) + # prepare input tensors + q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)] rotary_dim = cos.shape[-1] - q1 = q[..., :rotary_dim] + q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim] q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1) - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True) - - k1 = k[..., :rotary_dim] k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1) + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True) rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True) + def _update_cos_sin_cache(self, dtype, device, seqlen): + # always cache the cos/sin for the full sequence length to avoid + # recomputing if the sequence length is smaller than the cached one + if ( + seqlen > self._seq_len_cached + or self._cos_cached_exp.device != device + or self._cos_cached_exp.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + freqs = freqs.expand(3, -1, -1) + self._cos_cached_exp = freqs.cos().to(dtype) + self._sin_cached_exp = freqs.sin().to(dtype) + def get_cos_sin( self, position_ids: torch.Tensor, @@ -605,21 +612,16 @@ def get_cos_sin( dtype: torch.dtype, ): self._update_cos_sin_cache(dtype, position_ids.device, max_s) - - inv_freq_expanded = ( - self.inv_freq[None, None, :, None] - .float() - .expand(3, position_ids.shape[1], -1, 1) + # expand the position_ids to match the shape of the cached cos/sin + indices = ( + position_ids.squeeze(1) + .unsqueeze(-1) + .expand(-1, -1, self._cos_cached_exp.shape[-1]) ) + cos_c = torch.gather(self._cos_cached_exp, 1, indices) + cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1) - position_ids_expanded = position_ids[ - :, :, None, : - ].float() # shape (3, bs, 1, positions) + sin_c = torch.gather(self._sin_cached_exp, 1, indices) + sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1) - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( - 2, 3 - ) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype), sin.to(dtype) + return cos_c, sin_c From a0ab962b6d7a804ad466298af69d3402e7536ed9 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 22 Jan 2025 18:30:03 +0000 Subject: [PATCH 03/18] fix: limit vision flop calc to qwen2 vl models and update config typing --- launcher/src/main.rs | 78 +++++++++++-------- .../text_generation_server/layers/rotary.py | 16 ++-- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8e93b1b2d2d..6cbdb1d64cd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -231,12 +231,12 @@ struct QuantizationConfig { #[derive(Debug, Deserialize)] struct VisionConfig { - depth: usize, - embed_dim: usize, - mlp_ratio: usize, - in_chans: usize, - patch_size: usize, - temporal_patch_size: usize, + depth: Option, + embed_dim: Option, + mlp_ratio: Option, + in_chans: Option, + patch_size: Option, + temporal_patch_size: Option, } #[derive(Debug, Deserialize)] @@ -283,33 +283,45 @@ impl Config { tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop")); - if let Some(vision_config) = self.vision_config.as_ref() { - let in_chans = vision_config.in_chans as u64; - let patch_size = vision_config.patch_size as u64; - let embed_dim = vision_config.embed_dim as u64; - let vision_depth = vision_config.depth as u64; - let mlp_ratio = vision_config.mlp_ratio as u64; - let temporal_patch_size = vision_config.temporal_patch_size as u64; - // 1. patch embedding: - // - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2 - // where the 2 accounts for multiply-add - let patch_flops = 2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans; - // 2. self-attention + mlp: - // - qkv projections: 3 * d_model * d_model * 2 - // - attention: d_model * d_model * 2 - // - mlp: 2 * d_model * (mlp_ratio * d_model) * 2 - // simplified to: 2 * d_model * (4 + mlp_ratio * d_model) - let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim); - // 3. add with layer norm flops for total vision layer flops - let layer_flops = patch_flops + attn_flops + 2 * embed_dim; - let vision_flops = layer_flops * vision_depth; - tracing::debug!( - "Vision flops: {}", - human_size(vision_flops as usize, "flop") - ); - Some(text_flops + vision_flops) - } else { - Some(text_flops) + // text-only case + if self.vision_config.is_none() { + return Some(text_flops); + } + + let vision_config = self.vision_config.as_ref().unwrap(); + + // estimate vision flops for specific model types + match self.model_type.as_deref() { + Some("qwen2_vl") => { + let in_chans = vision_config.in_chans? as u64; + let patch_size = vision_config.patch_size? as u64; + let embed_dim = vision_config.embed_dim? as u64; + let vision_depth = vision_config.depth? as u64; + let mlp_ratio = vision_config.mlp_ratio? as u64; + let temporal_patch_size = vision_config.temporal_patch_size? as u64; + // 1. patch embedding: + // - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2 + // where the 2 accounts for multiply-add + let patch_flops = + 2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans; + // 2. self-attention + mlp: + // - qkv projections: 3 * d_model * d_model * 2 + // - attention: d_model * d_model * 2 + // - mlp: 2 * d_model * (mlp_ratio * d_model) * 2 + // simplified to: 2 * d_model * (4 + mlp_ratio * d_model) + let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim); + // 3. add with layer norm flops for total vision layer flops + let layer_flops = patch_flops + attn_flops + 2 * embed_dim; + let vision_flops = layer_flops * vision_depth; + tracing::debug!( + "Vision flops: {}", + human_size(vision_flops as usize, "flop") + ); + Some(text_flops + vision_flops) + } + // model has a vision config but is not supported for flops calculation + // we return None to avoid overestimating the memory requirements + _ => return None, } } diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 061bf024ba0..9f1770ff6b0 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -86,15 +86,21 @@ def static(cls, config, dim, base, device): # `rope_type` is now standard in transformers, but some existing models # have `type` instead. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) + mrope_section = rope_scaling.get("mrope_section", None) + + # only apply mrope if sections are provided and the rope type is mrope or default + if mrope_section is not None and ( + rope_type == "mrope" or rope_type == "default" + ): + mrope_section = rope_scaling.get("mrope_section") + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, scaling_factor, mrope_section + ) if rope_type == "linear": pass elif rope_type == "default": - if rope_scaling.get("mrope_section", False): - mrope_section = rope_scaling.get("mrope_section") - return RotaryPositionEmbeddingMultimodalSections( - inv_freq, scaling_factor, mrope_section - ) + pass elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( From cf5c66043e43bc886d2b94d52591d033da7e751e Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 22 Jan 2025 18:38:07 +0000 Subject: [PATCH 04/18] fix: include clippy lint --- launcher/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 6cbdb1d64cd..7326fd95f2c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -321,7 +321,7 @@ impl Config { } // model has a vision config but is not supported for flops calculation // we return None to avoid overestimating the memory requirements - _ => return None, + _ => None, } } From 7ab99bc6b3ae44362658de6f4eaa41c8861f4c8b Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 22 Jan 2025 20:51:20 +0000 Subject: [PATCH 05/18] feat: refactor position ids in warmup and bump tests --- ...essed_tensors_w8a8_int_dynamic_weight.json | 422 +++++++++++++++++- ...rs_w8a8_int_dynamic_weight_all_params.json | 50 +-- ..._tensors_w8a8_int_dynamic_weight_load.json | 80 ++-- ...pressed_tensors_w8a8_int_dynamic_weight.py | 9 +- .../models/flash_causal_lm.py | 20 +- 5 files changed, 485 insertions(+), 96 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json index 2525f72cd14..7dbfc627c20 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json @@ -1,73 +1,469 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 76, "prefill": [], "seed": null, "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.5195312, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.06817627, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.13122559, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.13415527, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.8769531, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0011396408, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16027832, + "logprob": -0.16442871, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0026416779, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.48754883, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2294922, "special": false, "text": " uses" + }, + { + "id": 29728, + "logprob": -0.66503906, + "special": false, + "text": " neural" + }, + { + "id": 14155, + "logprob": -0.02960205, + "special": false, + "text": " networks" + }, + { + "id": 311, + "logprob": -0.7236328, + "special": false, + "text": " to" + }, + { + "id": 3960, + "logprob": -1.1914062, + "special": false, + "text": " learn" + }, + { + "id": 504, + "logprob": -0.7089844, + "special": false, + "text": " from" + }, + { + "id": 821, + "logprob": -0.7729492, + "special": false, + "text": " data" + }, + { + "id": 13, + "logprob": -0.7836914, + "special": false, + "text": "." + }, + { + "id": 1084, + "logprob": -0.9941406, + "special": false, + "text": " It" + }, + { + "id": 374, + "logprob": -0.52441406, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.9511719, + "special": false, + "text": " a" + }, + { + "id": 943, + "logprob": -0.8642578, + "special": false, + "text": " type" + }, + { + "id": 315, + "logprob": -0.00030231476, + "special": false, + "text": " of" + }, + { + "id": 20443, + "logprob": -0.14416504, + "special": false, + "text": " artificial" + }, + { + "id": 11229, + "logprob": -0.013824463, + "special": false, + "text": " intelligence" + }, + { + "id": 429, + "logprob": -0.18762207, + "special": false, + "text": " that" + }, + { + "id": 646, + "logprob": -1.0087891, + "special": false, + "text": " can" + }, + { + "id": 3960, + "logprob": -0.90234375, + "special": false, + "text": " learn" + }, + { + "id": 504, + "logprob": -0.54345703, + "special": false, + "text": " from" + }, + { + "id": 323, + "logprob": -1.0400391, + "special": false, + "text": " and" + }, + { + "id": 1281, + "logprob": -0.072509766, + "special": false, + "text": " make" + }, + { + "id": 19898, + "logprob": -0.16516113, + "special": false, + "text": " predictions" + }, + { + "id": 389, + "logprob": -0.4416504, + "special": false, + "text": " on" + }, + { + "id": 3460, + "logprob": -0.5385742, + "special": false, + "text": " large" + }, + { + "id": 14713, + "logprob": -0.4387207, + "special": false, + "text": " amounts" + }, + { + "id": 315, + "logprob": -0.00015091896, + "special": false, + "text": " of" + }, + { + "id": 821, + "logprob": -0.061431885, + "special": false, + "text": " data" + }, + { + "id": 13, + "logprob": -0.71875, + "special": false, + "text": "." + }, + { + "id": 18183, + "logprob": -0.23632812, + "special": false, + "text": " Deep" + }, + { + "id": 6832, + "logprob": -0.0017204285, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -1.1738281, + "special": false, + "text": " is" + }, + { + "id": 1483, + "logprob": -0.61083984, + "special": false, + "text": " used" + }, + { + "id": 304, + "logprob": -0.035003662, + "special": false, + "text": " in" + }, + { + "id": 264, + "logprob": -0.118652344, + "special": false, + "text": " a" + }, + { + "id": 8045, + "logprob": -0.42016602, + "special": false, + "text": " variety" + }, + { + "id": 315, + "logprob": -1.6212463e-05, + "special": false, + "text": " of" + }, + { + "id": 8357, + "logprob": -0.1315918, + "special": false, + "text": " applications" + }, + { + "id": 11, + "logprob": -0.12915039, + "special": false, + "text": "," + }, + { + "id": 2670, + "logprob": -0.12463379, + "special": false, + "text": " including" + }, + { + "id": 2168, + "logprob": -0.37402344, + "special": false, + "text": " image" + }, + { + "id": 323, + "logprob": -0.1451416, + "special": false, + "text": " and" + }, + { + "id": 8806, + "logprob": -0.028869629, + "special": false, + "text": " speech" + }, + { + "id": 17843, + "logprob": -0.00024068356, + "special": false, + "text": " recognition" + }, + { + "id": 11, + "logprob": -0.00031018257, + "special": false, + "text": "," + }, + { + "id": 5810, + "logprob": -0.019821167, + "special": false, + "text": " natural" + }, + { + "id": 4128, + "logprob": -0.00012528896, + "special": false, + "text": " language" + }, + { + "id": 8692, + "logprob": -0.00089263916, + "special": false, + "text": " processing" + }, + { + "id": 11, + "logprob": -0.00073862076, + "special": false, + "text": "," + }, + { + "id": 323, + "logprob": -0.040161133, + "special": false, + "text": " and" + }, + { + "id": 38193, + "logprob": -0.4519043, + "special": false, + "text": " autonomous" + }, + { + "id": 11474, + "logprob": -0.39941406, + "special": false, + "text": " vehicles" + }, + { + "id": 13, + "logprob": -0.21166992, + "special": false, + "text": "." + }, + { + "id": 1084, + "logprob": -0.9082031, + "special": false, + "text": " It" + }, + { + "id": 374, + "logprob": -0.44213867, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -1.2177734, + "special": false, + "text": " a" + }, + { + "id": 18512, + "logprob": -0.5205078, + "special": false, + "text": " rapidly" + }, + { + "id": 7826, + "logprob": -0.15332031, + "special": false, + "text": " growing" + }, + { + "id": 2070, + "logprob": -0.0039978027, + "special": false, + "text": " field" + }, + { + "id": 448, + "logprob": -0.9091797, + "special": false, + "text": " with" + }, + { + "id": 1657, + "logprob": -0.17114258, + "special": false, + "text": " many" + }, + { + "id": 4650, + "logprob": -0.70703125, + "special": false, + "text": " potential" + }, + { + "id": 8357, + "logprob": -0.025131226, + "special": false, + "text": " applications" + }, + { + "id": 304, + "logprob": -0.6699219, + "special": false, + "text": " in" + }, + { + "id": 279, + "logprob": -0.35205078, + "special": false, + "text": " the" + }, + { + "id": 3853, + "logprob": -0.049194336, + "special": false, + "text": " future" + }, + { + "id": 13, + "logprob": -0.21972656, + "special": false, + "text": "." + }, + { + "id": 151643, + "logprob": -2.0019531, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that uses" + "generated_text": " Deep learning is a subset of machine learning that uses neural networks to learn from data. It is a type of artificial intelligence that can learn from and make predictions on large amounts of data. Deep learning is used in a variety of applications, including image and speech recognition, natural language processing, and autonomous vehicles. It is a rapidly growing field with many potential applications in the future." } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json index 6b3f5092917..2c840e67124 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json @@ -7,67 +7,67 @@ "seed": 0, "tokens": [ { - "id": 1939, - "logprob": -2.2460938, + "id": 5267, + "logprob": -1.1464844, "special": false, - "text": "?\n\n" + "text": "?\n" }, { "id": 33464, - "logprob": 0.0, + "logprob": -0.83203125, "special": false, "text": "Deep" }, { "id": 20909, - "logprob": -0.48608398, + "logprob": -0.5625, "special": false, "text": " Learning" }, { - "id": 4102, - "logprob": -2.265625, + "id": 320, + "logprob": -2.1464844, "special": false, - "text": " " + "text": " (" }, { - "id": 285, + "id": 16524, "logprob": 0.0, "special": false, - "text": "is" + "text": "DL" }, { - "id": 458, - "logprob": -0.6328125, + "id": 701, + "logprob": -2.2089844, "special": false, - "text": " an" + "text": ")," }, { - "id": 20443, - "logprob": -0.1796875, + "id": 476, + "logprob": -0.27368164, "special": false, - "text": " artificial" + "text": " or" }, { - "id": 11229, - "logprob": 0.0, + "id": 20443, + "logprob": -0.09442139, "special": false, - "text": " intelligence" + "text": " artificial" }, { - "id": 320, - "logprob": -0.37695312, + "id": 29728, + "logprob": 0.0, "special": false, - "text": " (" + "text": " neural" }, { - "id": 15469, + "id": 14155, "logprob": 0.0, "special": false, - "text": "AI" + "text": " networks" } ], "top_tokens": null }, - "generated_text": "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" + "generated_text": "What is deep learning?\nDeep Learning (DL), or artificial neural networks" } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json index 1fa4e33aa05..aee5698b474 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json @@ -9,61 +9,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.4912109, + "logprob": -1.5195312, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.075683594, + "logprob": -0.06817627, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.12408447, + "logprob": -0.13122559, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.12768555, + "logprob": -0.13415527, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.82128906, + "logprob": -0.87353516, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0012636185, + "logprob": -0.0011396408, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.12878418, + "logprob": -0.16442871, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0015888214, + "logprob": -0.0026416779, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.49194336, + "logprob": -0.48754883, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2626953, + "logprob": -1.2294922, "special": false, "text": " uses" } @@ -82,61 +82,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.4912109, + "logprob": -1.5195312, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.075683594, + "logprob": -0.06817627, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.12408447, + "logprob": -0.13122559, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.12768555, + "logprob": -0.13415527, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.82128906, + "logprob": -0.87353516, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0012636185, + "logprob": -0.0011396408, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.12878418, + "logprob": -0.16442871, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0015888214, + "logprob": -0.0026416779, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.49194336, + "logprob": -0.48754883, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2626953, + "logprob": -1.2294922, "special": false, "text": " uses" } @@ -155,61 +155,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.4912109, + "logprob": -1.5195312, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.075683594, + "logprob": -0.06817627, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.12408447, + "logprob": -0.13122559, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.12768555, + "logprob": -0.13415527, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.82128906, + "logprob": -0.87353516, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0012636185, + "logprob": -0.0011396408, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.12878418, + "logprob": -0.16442871, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0015888214, + "logprob": -0.0026416779, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.49194336, + "logprob": -0.48754883, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2626953, + "logprob": -1.2294922, "special": false, "text": " uses" } @@ -228,61 +228,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.4912109, + "logprob": -1.5195312, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.075683594, + "logprob": -0.06817627, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.12408447, + "logprob": -0.13122559, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.12768555, + "logprob": -0.13415527, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.82128906, + "logprob": -0.87353516, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0012636185, + "logprob": -0.0011396408, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.12878418, + "logprob": -0.16442871, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0015888214, + "logprob": -0.0026416779, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.49194336, + "logprob": -0.48754883, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2626953, + "logprob": -1.2294922, "special": false, "text": " uses" } diff --git a/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py index a0b0416b861..17e12c221c2 100644 --- a/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py +++ b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py @@ -27,15 +27,16 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight( ): response = await compressed_tensors_w8a8_int_dynamic_weight.generate( "What is deep learning?", - max_new_tokens=10, + # prefer a longer response than the default, allow the llm to end generation + max_new_tokens=1000, decoder_input_details=True, ) assert ( response.generated_text - == " Deep learning is a subset of machine learning that uses" + == " Deep learning is a subset of machine learning that uses neural networks to learn from data. It is a type of artificial intelligence that can learn from and make predictions on large amounts of data. Deep learning is used in a variety of applications, including image and speech recognition, natural language processing, and autonomous vehicles. It is a rapidly growing field with many potential applications in the future." ) - assert response.details.generated_tokens == 10 + assert response.details.generated_tokens == 76 assert response == response_snapshot @@ -64,7 +65,7 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params( assert response.details.generated_tokens == 10 assert ( response.generated_text - == "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" + == "What is deep learning?\nDeep Learning (DL), or artificial neural networks" ) assert response == response_snapshot diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6bc3c2caf2b..a7d7f7112ba 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1400,7 +1400,11 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): cache_lengths = [0] * bs if max_bs is None: input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + if hasattr(self.model, "get_position_ids"): + # use model specific position ids for initialization + position_ids = self.model.get_position_ids(input_ids) + else: + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) input_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, device=self.device) * max_s @@ -1427,7 +1431,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" ) input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] - position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][..., :bs] if ATTENTION == "flashinfer": block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] else: @@ -1456,14 +1460,6 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): else: state = None - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "model_type") - and self.model.config.model_type == "qwen2_vl" - ): - if position_ids.dim() == 1: - position_ids = self.model.get_position_ids(input_ids) - graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs] = { "input_ids": input_ids, @@ -1486,10 +1482,6 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): state=state, cache_lengths_tensor=cache_lengths_tensor, ): - # in the case of N dimensional position ids we need to slice the - # position ids to match the input_ids size for cuda graphs warmup - position_ids = position_ids[..., : input_ids.shape[0]] - seqlen = Seqlen( input_lengths=input_lengths_tensor, cache_lengths=cache_lengths_tensor, From eef3c7bdf2e5c6018988ae362ad59b11405a6c1e Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 23 Jan 2025 15:07:19 +0000 Subject: [PATCH 06/18] fix: prefer default dtype --- server/text_generation_server/models/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 85c98bfde75..f8150b5e67c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1362,7 +1362,8 @@ def get_model( revision=revision, quantize=quantize, speculator=speculator, - dtype=torch.bfloat16, + dtype=dtype, + default_dtype=torch.bfloat16, kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, From 5f416f6e28c95af643f079a14bea6f25f1dbf360 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 23 Jan 2025 15:32:46 +0000 Subject: [PATCH 07/18] fix: enable all cuda graphs and bump snapshots --- .../test_flash_qwen2_vl_simple.json | 8 ++++---- .../test_flash_qwen2_vl_simple_streaming.json | 2 +- .../models/test_flash_qwen2_vl.py | 6 +++--- launcher/src/main.rs | 20 ++----------------- 4 files changed, 10 insertions(+), 26 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json index 5162833fe92..49f332252bc 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance.", + "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1737498164, + "created": 1737645979, "id": "", "model": "Qwen/Qwen2-VL-7B-Instruct", "object": "chat.completion", "system_fingerprint": "3.0.2-dev0-native", "usage": { - "completion_tokens": 68, + "completion_tokens": 58, "prompt_tokens": 1364, - "total_tokens": 1432 + "total_tokens": 1422 } } diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json index b3b14539bca..3dc8fc6d6d5 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json @@ -11,7 +11,7 @@ "logprobs": null } ], - "created": 1737498227, + "created": 1737646031, "id": "", "model": "Qwen/Qwen2-VL-7B-Instruct", "object": "chat.completion.chunk", diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index 7d51d20d3af..dacd92a87b3 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -35,7 +35,7 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): assert ( response.choices[0].message.content - == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance." + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." ) assert response == response_snapshot @@ -72,7 +72,7 @@ async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot): assert ( generated - == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The hitch is in a dynamic pose, with its hands on its hips and legs slightly apart, giving it an imposing stance." + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." ) - assert count == 68 + assert count == 58 assert last_response == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 7326fd95f2c..a09ceb31d75 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -2056,10 +2056,6 @@ fn main() -> Result<(), LauncherError> { let config: Option = get_config(&args.model_id, &args.revision).ok(); let quantize = config.as_ref().and_then(|c| c.quantize); - let model_type = config - .as_ref() - .and_then(|c| c.model_type.as_deref()) - .map(|s| s.to_owned()); // Quantization usually means you're even more RAM constrained. let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters); @@ -2148,20 +2144,8 @@ fn main() -> Result<(), LauncherError> { vec![] } _ => { - let default_cuda_graphs = vec![1, 2, 4, 8, 16, 32]; - tracing::info!("Using default CUDA graphs: {:?}", default_cuda_graphs); - let cuda_graphs = match model_type.as_deref() { - Some("qwen2_vl") => { - tracing::warn!( - "Qwen VL model detected - restricting CUDA graphs to values >= 3" - ); - default_cuda_graphs - .into_iter() - .filter(|&c| c >= 3) - .collect() - } - _ => default_cuda_graphs, - }; + let cuda_graphs = vec![1, 2, 4, 8, 16, 32]; + tracing::info!("Using default cuda graphs {cuda_graphs:?}"); cuda_graphs } }; From 6893eb3834898a658b1a81526c1a6b18ab10e79b Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 27 Jan 2025 16:02:51 +0000 Subject: [PATCH 08/18] fix: adjust rotaty init path --- launcher/src/main.rs | 51 +++---------------- .../text_generation_server/layers/rotary.py | 22 ++++---- .../models/custom_modeling/qwen2_vl.py | 3 ++ 3 files changed, 22 insertions(+), 54 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a09ceb31d75..6391f9eb931 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -260,6 +260,11 @@ struct Config { impl Config { fn flop(&self) -> Option { + if self.vision_config.is_some() { + // VLM are much harder to predict and VRAM requirements + // Are more complex. + return None; + } let num_heads = self.num_heads? as u64; let num_kv_heads = self.num_kv_heads? as u64; let head_dim = self.head_dim? as u64; @@ -279,50 +284,8 @@ impl Config { let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size; let layer_flops = attn_layer_flops + gate_up_down_flops; - let text_flops = layer_flops * num_layers; - - tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop")); - - // text-only case - if self.vision_config.is_none() { - return Some(text_flops); - } - - let vision_config = self.vision_config.as_ref().unwrap(); - - // estimate vision flops for specific model types - match self.model_type.as_deref() { - Some("qwen2_vl") => { - let in_chans = vision_config.in_chans? as u64; - let patch_size = vision_config.patch_size? as u64; - let embed_dim = vision_config.embed_dim? as u64; - let vision_depth = vision_config.depth? as u64; - let mlp_ratio = vision_config.mlp_ratio? as u64; - let temporal_patch_size = vision_config.temporal_patch_size? as u64; - // 1. patch embedding: - // - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2 - // where the 2 accounts for multiply-add - let patch_flops = - 2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans; - // 2. self-attention + mlp: - // - qkv projections: 3 * d_model * d_model * 2 - // - attention: d_model * d_model * 2 - // - mlp: 2 * d_model * (mlp_ratio * d_model) * 2 - // simplified to: 2 * d_model * (4 + mlp_ratio * d_model) - let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim); - // 3. add with layer norm flops for total vision layer flops - let layer_flops = patch_flops + attn_flops + 2 * embed_dim; - let vision_flops = layer_flops * vision_depth; - tracing::debug!( - "Vision flops: {}", - human_size(vision_flops as usize, "flop") - ); - Some(text_flops + vision_flops) - } - // model has a vision config but is not supported for flops calculation - // we return None to avoid overestimating the memory requirements - _ => None, - } + let total = layer_flops * num_layers; + Some(total) } fn kv_vram_per_tok(&self) -> Option { diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 9f1770ff6b0..7b3500e380e 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -101,6 +101,11 @@ def static(cls, config, dim, base, device): pass elif rope_type == "default": pass + elif rope_type == "mrope": + mrope_section = rope_scaling["mrope_section"] + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, scaling_factor, mrope_section + ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -576,16 +581,6 @@ def forward( cos: torch.Tensor, sin: torch.Tensor, ): - # process multi-modal rotary embeddings - split_cos, split_sin = [ - torch.split(t, self.sections, dim=-1) for t in (cos, sin) - ] - cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze( - 1 - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze( - 1 - ) # prepare input tensors q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)] rotary_dim = cos.shape[-1] @@ -624,10 +619,17 @@ def get_cos_sin( .unsqueeze(-1) .expand(-1, -1, self._cos_cached_exp.shape[-1]) ) + indices = indices.to(dtype=torch.int64) cos_c = torch.gather(self._cos_cached_exp, 1, indices) cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1) + split_cos = torch.split(cos_c, self.sections, dim=-1) + cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1) + cos_c = cos_c.unsqueeze(1) sin_c = torch.gather(self._sin_cached_exp, 1, indices) sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1) + split_sin = torch.split(sin_c, self.sections, dim=-1) + sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1) + sin_c = sin_c.unsqueeze(1) return cos_c, sin_c diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index e0ae19df766..7e296b42f2f 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -377,6 +377,9 @@ def __init__(self, prefix, config, weights): self.config = config config.vision_config.quantize = None config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2-VL model at the moment + config.rope_scaling.update({"rope_type": "mrope"}) self.hidden_size = config.hidden_size self.vision_start_token_id = config.vision_start_token_id self.image_token_id = config.image_token_id From 68e3ee8e79e86154b1363d76c906093fca48befb Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 28 Jan 2025 15:40:05 +0000 Subject: [PATCH 09/18] fix: simplify get position ids and remove usused vision config --- launcher/src/main.rs | 9 +- .../models/custom_modeling/qwen2_vl.py | 187 ++++++++++-------- 2 files changed, 102 insertions(+), 94 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 6391f9eb931..05ed0202518 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -230,14 +230,7 @@ struct QuantizationConfig { } #[derive(Debug, Deserialize)] -struct VisionConfig { - depth: Option, - embed_dim: Option, - mlp_ratio: Option, - in_chans: Option, - patch_size: Option, - temporal_patch_size: Option, -} +struct VisionConfig {} #[derive(Debug, Deserialize)] struct Config { diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 7e296b42f2f..fdc426bc284 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -382,6 +382,7 @@ def __init__(self, prefix, config, weights): config.rope_scaling.update({"rope_type": "mrope"}) self.hidden_size = config.hidden_size self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id self.image_token_id = config.image_token_id self.video_token_id = config.video_token_id self.spatial_merge_size = config.vision_config.spatial_merge_size @@ -411,98 +412,112 @@ def __init__(self, prefix, config, weights): def get_position_ids( self, - batch_input_ids: torch.Tensor, - image_grid_thw: Optional[torch.LongTensor] = None, - # video_grid_thw is not implemented yet as we do not accept video inputs at the moment - ) -> Tuple[torch.Tensor, torch.Tensor]: - if batch_input_ids.dim() == 1: - batch_input_ids = batch_input_ids.unsqueeze(0) + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + ) -> torch.Tensor: - position_ids = torch.ones( - 3, - batch_input_ids.shape[0], - batch_input_ids.shape[1], - dtype=batch_input_ids.dtype, - device=batch_input_ids.device, - ) - d = batch_input_ids.device + # TODO: avoid the early return and extra work in a more efficient way if image_grid_thw is not None: - image_index = 0 - llm_pos_ids_list = [] - - for i, input_ids in enumerate(batch_input_ids): - vision_start_indices = torch.argwhere( - input_ids == self.vision_start_token_id - ).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - # only copy the sum of the image tokens GPU<->CPU - image_count = (vision_tokens == self.image_token_id).sum().item() - - current_pos = 0 - for _ in range(image_count): - # copy the value position of the next image token from GPU<->CPU - next_image_pos = ( - (input_ids[current_pos:] == self.image_token_id) - .nonzero()[0] - .item() - ) - # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop - time_steps, height, width = image_grid_thw[image_index].clone() - height //= self.spatial_merge_size - width //= self.spatial_merge_size - - # calculate the length of the text and image tokens - text_length = next_image_pos - start_idx = ( - llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 - ) - - # text position ids - text_pos_ids = torch.arange(text_length, device=d) - text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx - llm_pos_ids_list.append(text_pos_ids) - - # image position ids - t_indices = torch.arange(time_steps, device=d).repeat_interleave( - height * width - ) - h_indices = ( - torch.arange(height, device=d) - .repeat_interleave(width) - .repeat(time_steps) - ) - w_indices = torch.arange(width, device=d).repeat( - height * time_steps - ) - - image_pos_ids = ( - torch.stack([t_indices, h_indices, w_indices]) - + text_length - + start_idx - ) - llm_pos_ids_list.append(image_pos_ids) - - current_pos += next_image_pos + time_steps * height * width - image_index += 1 - - if current_pos < batch_input_ids.size(1): - st_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) - text_len = batch_input_ids.size(1) - current_pos - llm_pos_ids_list.append( - torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx - ) - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[:, i, :] = llm_positions.to(position_ids.device) - else: + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + + position_ids = torch.ones( + 3, + 1, + input_ids.shape[0], + dtype=input_ids.dtype, + device=input_ids.device, + ) position_ids = ( - torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device) + torch.arange(input_ids.shape[1], device=input_ids.device) .view(1, 1, -1) - .repeat(3, batch_input_ids.shape[0], 1) + .repeat(3, input_ids.shape[0], 1) + ) + return position_ids + + # if image grid provided than we need to calculate the position ids + + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + position_ids = torch.ones( + 3, + input_ids_len, + dtype=dtype, + device=device, + ) + + # capture vision segments + starts = torch.where(input_ids == vision_start_token_id)[0] + ends = torch.where(input_ids == vision_end_token_id)[0] + # ie. [[ 14, 2181], [2212, 4379]] + vision_segments = torch.stack((starts, ends), dim=1) + # capture text lengths as the space between vision segments + + prev_end = torch.cat( # shift to the left to get the previous end + [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]] + ) # ie. [0, 2181] + + # text is the space between the end of one vision segment and the start of the next + text_lengths = vision_segments[:, 0] - prev_end + 1 # ie. [15, 32] + + # calculate the max id from the image width for each segment + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + total_segment_lengths = vision_widths_max + text_lengths + total_segment_lengths = total_segment_lengths.cumsum(dim=0) + text_diff = total_segment_lengths - text_lengths + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, ) - return position_ids + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + total_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_diff[i] + for i, seq_len in enumerate(text_lengths) + ] # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]] + + # combine by alternating text and vision segments (text, vision, text, vision, ...) + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + + # the final segment is the difference between the last vision segment and the end of the input + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + # combine all the segments and reshape to (3, input_ids_len) + llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., :] = llm_positions.to(position_ids.device) + # TODO: avoid the extra dimension when updating the consumer of this function + return position_ids.unsqueeze(1) def forward( self, From c75c01e9b910095c35621b3c909d1d68f7b21d7b Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 28 Jan 2025 19:25:23 +0000 Subject: [PATCH 10/18] fix: update position ids so first dim is batch, simplify rotary and bump vlm default token limit --- launcher/src/main.rs | 11 ++++- .../text_generation_server/layers/rotary.py | 48 +++++++++---------- .../models/custom_modeling/qwen2_vl.py | 44 +++++------------ .../models/flash_causal_lm.py | 4 +- 4 files changed, 47 insertions(+), 60 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 05ed0202518..3c9ee850053 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -2049,7 +2049,16 @@ fn main() -> Result<(), LauncherError> { None => { let compute_type = compute_type(num_shard); let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref()); - let default = compute_optimal.unwrap_or(4096); + // TODO: remove this when we correctly esimate the flops for VLMs + // this is a short term temporary fix to enable vlms to avoid rejecting images + let default_optimal = match config { + Some(ref config) => match config.model_type.as_deref() { + Some("qwen2_vl") => 10_000, + _ => 4096, + }, + None => 4096, + }; + let default = compute_optimal.unwrap_or(default_optimal); let vram_maximum = vram_maximum( config.as_ref(), compute_type.as_ref(), diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 7b3500e380e..c0baaf597d4 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -568,9 +568,7 @@ def apply_llama3_scaling( class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): def __init__(self, inv_freq, scaling_factor, sections): super().__init__(inv_freq, scaling_factor) - # expand the inv_freq for the 3 sections - self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1) - self.sections = sections * 2 + self.sections = sections self._cos_cached = None self._sin_cached = None @@ -582,7 +580,7 @@ def forward( sin: torch.Tensor, ): # prepare input tensors - q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)] + q, k = [x.transpose(0, 1) for x in (query, key)] rotary_dim = cos.shape[-1] q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim] q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1) @@ -596,15 +594,14 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): # recomputing if the sequence length is smaller than the cached one if ( seqlen > self._seq_len_cached - or self._cos_cached_exp.device != device - or self._cos_cached_exp.dtype != dtype + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype ): self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - freqs = freqs.expand(3, -1, -1) - self._cos_cached_exp = freqs.cos().to(dtype) - self._sin_cached_exp = freqs.sin().to(dtype) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin( self, @@ -613,23 +610,24 @@ def get_cos_sin( dtype: torch.dtype, ): self._update_cos_sin_cache(dtype, position_ids.device, max_s) - # expand the position_ids to match the shape of the cached cos/sin - indices = ( - position_ids.squeeze(1) - .unsqueeze(-1) - .expand(-1, -1, self._cos_cached_exp.shape[-1]) + + # access freqs for each of the 3 sections and stack them + cos_c = torch.stack( + [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0 + ) + sin_c = torch.stack( + [self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0 ) - indices = indices.to(dtype=torch.int64) - cos_c = torch.gather(self._cos_cached_exp, 1, indices) - cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1) - split_cos = torch.split(cos_c, self.sections, dim=-1) - cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1) - cos_c = cos_c.unsqueeze(1) - sin_c = torch.gather(self._sin_cached_exp, 1, indices) - sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1) + # chunk based on sections + split_cos = torch.split(cos_c, self.sections, dim=-1) split_sin = torch.split(sin_c, self.sections, dim=-1) - sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1) - sin_c = sin_c.unsqueeze(1) - return cos_c, sin_c + # for each section, select the corresponding cos/sin (0, 1, 2, ...) + cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1) + sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1) + + # double the size and add a batch dimension + cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0) + sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0) + return cos, sin diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index fdc426bc284..65d1896305d 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -413,31 +413,17 @@ def __init__(self, prefix, config, weights): def get_position_ids( self, input_ids: torch.Tensor, - image_grid_thw: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # TODO: avoid the early return and extra work in a more efficient way - if image_grid_thw is not None: - - if input_ids.dim() == 1: - input_ids = input_ids.unsqueeze(0) - - position_ids = torch.ones( - 3, - 1, - input_ids.shape[0], - dtype=input_ids.dtype, - device=input_ids.device, - ) - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .repeat(3, input_ids.shape[0], 1) + if image_grid_thw is None: + # (batch_size, 3) + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) ) - return position_ids # if image grid provided than we need to calculate the position ids - spatial_merge_size = self.spatial_merge_size vision_start_token_id = self.vision_start_token_id vision_end_token_id = self.vision_end_token_id @@ -445,12 +431,6 @@ def get_position_ids( device = input_ids.device dtype = input_ids.dtype input_ids_len = input_ids.shape[0] - position_ids = torch.ones( - 3, - input_ids_len, - dtype=dtype, - device=device, - ) # capture vision segments starts = torch.where(input_ids == vision_start_token_id)[0] @@ -513,11 +493,11 @@ def get_position_ids( m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) full_llm_pos_ids_list.append(m + max_s) - # combine all the segments and reshape to (3, input_ids_len) - llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., :] = llm_positions.to(position_ids.device) - # TODO: avoid the extra dimension when updating the consumer of this function - return position_ids.unsqueeze(1) + # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids def forward( self, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a7d7f7112ba..c5d80bc59a6 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1431,7 +1431,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" ) input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] - position_ids = self.cuda_graphs[max_bs]["position_ids"][..., :bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] if ATTENTION == "flashinfer": block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] else: @@ -2046,7 +2046,7 @@ def generate_token( # instantly become of shape [BATCH_SIZE] if prefill and finished_prefilling: indices = batch.cu_seqlen_prefill[1:] - 1 - batch.position_ids = batch.position_ids[(..., indices)] + batch.position_ids = batch.position_ids[indices] batch.slot_indices = batch.slot_indices[indices] batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ indices From 79a2c956dec6a8126630e817aa529389f2c75471 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 28 Jan 2025 21:08:58 +0000 Subject: [PATCH 11/18] fix: improve position id init during cuda warmup for mrope and simplfy rotary forward --- .../text_generation_server/layers/rotary.py | 19 +++++++++---------- .../models/flash_causal_lm.py | 10 +++++----- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index c0baaf597d4..8132478c901 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -579,15 +579,14 @@ def forward( cos: torch.Tensor, sin: torch.Tensor, ): - # prepare input tensors - q, k = [x.transpose(0, 1) for x in (query, key)] - rotary_dim = cos.shape[-1] - q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim] - q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1) - k2 = torch.cat((-k[..., rotary_dim // 2 :], k[..., : rotary_dim // 2]), dim=-1) + # rotate half the sequence length + rot = cos.shape[-1] // 2 + q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1) + k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1) - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, True) - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, True) + # apply the rotation + rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True) + rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True) def _update_cos_sin_cache(self, dtype, device, seqlen): # always cache the cos/sin for the full sequence length to avoid @@ -628,6 +627,6 @@ def get_cos_sin( sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1) # double the size and add a batch dimension - cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0) - sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0) + cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1) + sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1) return cos, sin diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c5d80bc59a6..600ed716ed9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1400,11 +1400,11 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): cache_lengths = [0] * bs if max_bs is None: input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - if hasattr(self.model, "get_position_ids"): - # use model specific position ids for initialization - position_ids = self.model.get_position_ids(input_ids) - else: - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + # mrope have position_ids per section, if so repeat n times + if self.model.config.rope_scaling["rope_type"] == "mrope": + n_sections = len(self.model.config.rope_scaling["mrope_section"]) + position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) slots = torch.arange(bs, dtype=torch.int64, device=self.device) input_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, device=self.device) * max_s From d0e2332d174ac8d94acbb60e4b5b797b2acdf16b Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 28 Jan 2025 22:54:34 +0000 Subject: [PATCH 12/18] fix: check existance before accessing rope type in cuda warmup --- server/text_generation_server/models/flash_causal_lm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 600ed716ed9..47d372adf1e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1401,8 +1401,11 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): if max_bs is None: input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - # mrope have position_ids per section, if so repeat n times - if self.model.config.rope_scaling["rope_type"] == "mrope": + if ( # mrope have position_ids per section, if so repeat n times + hasattr(self.model, "config") + and hasattr(self.model.config, "rope_scaling") + and self.model.config.rope_scaling["rope_type"] == "mrope" + ): n_sections = len(self.model.config.rope_scaling["mrope_section"]) position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) slots = torch.arange(bs, dtype=torch.int64, device=self.device) From 585e270ac375af5a995d2d710358620797cd56a0 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 29 Jan 2025 00:10:43 +0000 Subject: [PATCH 13/18] fix: check key before access --- server/text_generation_server/models/flash_causal_lm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 47d372adf1e..579c7dc2812 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1404,6 +1404,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): if ( # mrope have position_ids per section, if so repeat n times hasattr(self.model, "config") and hasattr(self.model.config, "rope_scaling") + and "rope_type" in self.model.config.rope_scaling and self.model.config.rope_scaling["rope_type"] == "mrope" ): n_sections = len(self.model.config.rope_scaling["mrope_section"]) From cb7ec9cb60b1daa1c7e6478c7b79f3d048960c17 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 29 Jan 2025 13:03:36 +0000 Subject: [PATCH 14/18] fix: improve mrope check in cuda graph warmup --- server/text_generation_server/models/flash_causal_lm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 579c7dc2812..f268e499584 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1401,11 +1401,10 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): if max_bs is None: input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + config = getattr(self.model, "config", None) + rope_scaling = getattr(config, "rope_scaling", None) if config else None if ( # mrope have position_ids per section, if so repeat n times - hasattr(self.model, "config") - and hasattr(self.model.config, "rope_scaling") - and "rope_type" in self.model.config.rope_scaling - and self.model.config.rope_scaling["rope_type"] == "mrope" + isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope" ): n_sections = len(self.model.config.rope_scaling["mrope_section"]) position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) From 79550f8b47b710c6ae66bb374b406d25a862a3b0 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 29 Jan 2025 16:10:17 +0000 Subject: [PATCH 15/18] fix: remove check for default rope type --- server/text_generation_server/layers/rotary.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 8132478c901..b40d413ad08 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -88,10 +88,8 @@ def static(cls, config, dim, base, device): rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) mrope_section = rope_scaling.get("mrope_section", None) - # only apply mrope if sections are provided and the rope type is mrope or default - if mrope_section is not None and ( - rope_type == "mrope" or rope_type == "default" - ): + # only apply mrope if sections are provided and the rope type is mrope and a section is provided + if mrope_section is not None and rope_type == "mrope": mrope_section = rope_scaling.get("mrope_section") return RotaryPositionEmbeddingMultimodalSections( inv_freq, scaling_factor, mrope_section From 9eaa163239c556592c2a6197060ee703a0a05921 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 31 Jan 2025 18:30:32 +0000 Subject: [PATCH 16/18] fix: add more test and improve model generation --- .../test_flash_qwen2_vl_bay.json | 26 ++++++++++ .../test_flash_qwen2_vl_inpaint.json | 26 ++++++++++ .../test_flash_qwen2_vl_simple.json | 10 ++-- .../models/test_flash_qwen2_vl.py | 50 +++++++++++++++++-- .../models/custom_modeling/qwen2_vl.py | 1 - 5 files changed, 104 insertions(+), 9 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_bay.json create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_inpaint.json diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_bay.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_bay.json new file mode 100644 index 00000000000..25a1abc7a1e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_bay.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The image showcases a stunning cityscape, featuring the iconic Statue of Liberty in the foreground. The image displays Lady Liberty's imposing presence, with her towering base standing beside her. Behind the statue, the city's skyline extends across the horizon, adorned with numerous tall buildings, including the Empire State Building and other notable skyscrapers. The water reflecting the sun's rays creates a serene and picturesque scene, emphasizing the beauty and resilience of this global landmark. The sky is a clear, pale blue, adding to the overall tranquility of the scene.", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1738348090, + "id": "", + "model": "Qwen/Qwen2-VL-7B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.1.1-dev0-native", + "usage": { + "completion_tokens": 110, + "prompt_tokens": 8736, + "total_tokens": 8846 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_inpaint.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_inpaint.json new file mode 100644 index 00000000000..325e658f1b0 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_inpaint.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The image shows a stylized scene set in what appears to be a diner or restaurant. In the foreground, there is a table with various food items, including a burger with lettuce and tomato, a bowl of fries, and a drink in a cup with a straw. On the right side of the table, there is an owl sitting alertly, looking directly at the camera. Behind the owl and the table, there is a large, green, dinosaur-like creature resembling Godzilla, with its mouth open and tongue visible. In the background, the diner's decor includes various signs and posters, with a green sign reading \"Basta\" and another sign that says \"Tabasco.\" The setting has a retro or vintage feel, with fluorescent lighting overhead and clean, polished surfaces.", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1738348100, + "id": "", + "model": "Qwen/Qwen2-VL-7B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.1.1-dev0-native", + "usage": { + "completion_tokens": 156, + "prompt_tokens": 5375, + "total_tokens": 5531 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json index 49f332252bc..6b6017c9428 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.", + "content": "The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character.", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1737645979, + "created": 1738347908, "id": "", "model": "Qwen/Qwen2-VL-7B-Instruct", "object": "chat.completion", - "system_fingerprint": "3.0.2-dev0-native", + "system_fingerprint": "3.1.1-dev0-native", "usage": { - "completion_tokens": 58, + "completion_tokens": 89, "prompt_tokens": 1364, - "total_tokens": 1422 + "total_tokens": 1453 } } diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index dacd92a87b3..5a12eba831c 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -35,7 +35,7 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): assert ( response.choices[0].message.content - == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." + == "The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character." ) assert response == response_snapshot @@ -72,7 +72,51 @@ async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot): assert ( generated - == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." + == "The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character." ) - assert count == 58 + assert count == 89 assert last_response == response_snapshot + + +@pytest.mark.private +async def test_flash_qwen2_vl_bay(flash_qwen2, response_snapshot): + response = await flash_qwen2.chat( + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + }, + }, + {"type": "text", "text": "Describe the image"}, + ], + }, + ], + ) + assert response == response_snapshot + + +@pytest.mark.private +async def test_flash_qwen2_vl_inpaint(flash_qwen2, response_snapshot): + response = await flash_qwen2.chat( + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png" + }, + }, + {"type": "text", "text": "Describe the image"}, + ], + }, + ], + ) + assert response == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 65d1896305d..4031fe8f214 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -543,7 +543,6 @@ def forward( true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, ) - hidden_states, _ = self.norm(hidden_states) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) From 6cb0cb68b455c379959490ed41cf26303c544da7 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 4 Feb 2025 00:25:59 +0000 Subject: [PATCH 17/18] fix: improve and simplify get_cos_sin, refactors and cleanup get_position_ids --- .../text_generation_server/layers/rotary.py | 44 +++++-------- .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../models/custom_modeling/qwen2_vl.py | 62 +++++++++---------- 3 files changed, 47 insertions(+), 65 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index b40d413ad08..576aeb52b6e 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -88,22 +88,16 @@ def static(cls, config, dim, base, device): rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) mrope_section = rope_scaling.get("mrope_section", None) - # only apply mrope if sections are provided and the rope type is mrope and a section is provided - if mrope_section is not None and rope_type == "mrope": - mrope_section = rope_scaling.get("mrope_section") - return RotaryPositionEmbeddingMultimodalSections( - inv_freq, scaling_factor, mrope_section - ) - if rope_type == "linear": pass elif rope_type == "default": pass elif rope_type == "mrope": mrope_section = rope_scaling["mrope_section"] - return RotaryPositionEmbeddingMultimodalSections( - inv_freq, scaling_factor, mrope_section - ) + if mrope_section is not None: + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, scaling_factor, mrope_section + ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -569,6 +563,12 @@ def __init__(self, inv_freq, scaling_factor, sections): self.sections = sections self._cos_cached = None self._sin_cached = None + self.section_indices = ( + torch.arange(len(self.sections)) + .repeat_interleave(torch.tensor(self.sections)) + .view(1, 1, -1) + .to(inv_freq.device) + ) def forward( self, @@ -599,6 +599,7 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): freqs = torch.outer(t, self.inv_freq.to(device=t.device)) self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) + self._sections = self.section_indices.expand(seqlen, -1, -1) def get_cos_sin( self, @@ -607,24 +608,11 @@ def get_cos_sin( dtype: torch.dtype, ): self._update_cos_sin_cache(dtype, position_ids.device, max_s) + slen = position_ids.shape[0] - # access freqs for each of the 3 sections and stack them - cos_c = torch.stack( - [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0 - ) - sin_c = torch.stack( - [self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0 - ) - - # chunk based on sections - split_cos = torch.split(cos_c, self.sections, dim=-1) - split_sin = torch.split(sin_c, self.sections, dim=-1) - - # for each section, select the corresponding cos/sin (0, 1, 2, ...) - cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1) - sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1) + cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) + sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) - # double the size and add a batch dimension - cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(1) - sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(1) + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([sin, sin], dim=-1) return cos, sin diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 78ae3020cb8..d6569a1db35 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -235,8 +235,7 @@ def forward( max_s, prefill_cache_indices, ): - residual = hidden_states - normed_hidden_states, _ = self.input_layernorm(hidden_states) + normed_hidden_states, residual = self.input_layernorm(hidden_states) # Self Attention attn_output = self.self_attn( @@ -254,8 +253,7 @@ def forward( hidden_states = attn_output + residual # faster post attention rms norm - residual = hidden_states - hidden_states, _ = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states) mlp_output = self.mlp(hidden_states) hidden_states = mlp_output + residual return hidden_states diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 4031fe8f214..2d017e382ca 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -222,10 +222,10 @@ def __init__(self, prefix, config, weights): def forward( self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen ) -> torch.Tensor: - norm1_out, _ = self.norm1(hidden_states) + norm1_out, residual = self.norm1(hidden_states) attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) - hidden_states = hidden_states + attn_out - norm2_out, _ = self.norm2(hidden_states) + hidden_states = attn_out + residual + norm2_out, residual = self.norm2(hidden_states) hidden_states = hidden_states + self.mlp(norm2_out) return hidden_states @@ -410,52 +410,52 @@ def __init__(self, prefix, config, weights): ) self.device = weights.device + # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) def get_position_ids( self, input_ids: torch.Tensor, image_grid_thw: Optional[torch.Tensor] = None, ) -> torch.Tensor: if image_grid_thw is None: - # (batch_size, 3) return ( torch.arange(input_ids.shape[0], device=input_ids.device) .unsqueeze(1) .repeat(1, 3) ) - # if image grid provided than we need to calculate the position ids spatial_merge_size = self.spatial_merge_size vision_start_token_id = self.vision_start_token_id vision_end_token_id = self.vision_end_token_id - device = input_ids.device dtype = input_ids.dtype input_ids_len = input_ids.shape[0] - # capture vision segments - starts = torch.where(input_ids == vision_start_token_id)[0] - ends = torch.where(input_ids == vision_end_token_id)[0] - # ie. [[ 14, 2181], [2212, 4379]] - vision_segments = torch.stack((starts, ends), dim=1) - # capture text lengths as the space between vision segments - - prev_end = torch.cat( # shift to the left to get the previous end - [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]] - ) # ie. [0, 2181] - - # text is the space between the end of one vision segment and the start of the next - text_lengths = vision_segments[:, 0] - prev_end + 1 # ie. [15, 32] - - # calculate the max id from the image width for each segment + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 vision_widths_max = torch.cat( [ torch.zeros(1, device=image_grid_thw.device, dtype=dtype), image_grid_thw[:-1, 2] // spatial_merge_size, ] ) - total_segment_lengths = vision_widths_max + text_lengths - total_segment_lengths = total_segment_lengths.cumsum(dim=0) - text_diff = total_segment_lengths - text_lengths + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision # create position ids for each vision segment based on the image grid llm_pos_ids_list = [] @@ -471,29 +471,25 @@ def get_position_ids( image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) # offset by the position of the last vision segment - im = image_position_ids + total_segment_lengths[i] + im = image_position_ids + vision_segment_lengths[i] llm_pos_ids_list.append(im) # create position ids for each text segment text_ranges = [ torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) - + text_diff[i] - for i, seq_len in enumerate(text_lengths) - ] # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]] + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] - # combine by alternating text and vision segments (text, vision, text, vision, ...) full_llm_pos_ids_list = [ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist ] - - # the final segment is the difference between the last vision segment and the end of the input max_s = full_llm_pos_ids_list[-1].max() + 1 - final_text_len = input_ids_len - ends[-1] + final_text_len = input_ids_len - vision_ends[-1] if final_text_len > 0: m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) full_llm_pos_ids_list.append(m + max_s) - # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) position_ids = ( torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) ) From 58f5f2ee27ecb24b3adb373d6f4980efe6064f90 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 4 Feb 2025 00:30:47 +0000 Subject: [PATCH 18/18] fix: adjust signatures with types --- server/text_generation_server/layers/rotary.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 576aeb52b6e..f38f685970e 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -558,7 +558,7 @@ def apply_llama3_scaling( class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): - def __init__(self, inv_freq, scaling_factor, sections): + def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list): super().__init__(inv_freq, scaling_factor) self.sections = sections self._cos_cached = None @@ -586,7 +586,9 @@ def forward( rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True) rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True) - def _update_cos_sin_cache(self, dtype, device, seqlen): + def _update_cos_sin_cache( + self, dtype: torch.dtype, device: torch.device, seqlen: int + ): # always cache the cos/sin for the full sequence length to avoid # recomputing if the sequence length is smaller than the cached one if (