From df845b2b46c3e30f5bd3e3be286285ed148323fc Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 19 Aug 2024 09:29:31 -0700
Subject: [PATCH] [Misc] Remove Gemma RoPE (#7638)

---
 vllm/model_executor/layers/rotary_embedding.py | 15 ---------------
 vllm/model_executor/models/gemma.py            |  8 +++-----
 vllm/model_executor/models/gemma2.py           | 10 ++++------
 3 files changed, 7 insertions(+), 26 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index e6ee2b967c8da..0562b71aa7493 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -93,11 +93,6 @@ def __init__(
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
-        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
@@ -724,16 +719,6 @@ def forward(
         return query, key
 
 
-class GemmaRotaryEmbedding(RotaryEmbedding):
-
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
-            self.rotary_dim))
-        return inv_freq
-
-
 class Llama3RotaryEmbedding(RotaryEmbedding):
 
     def __init__(
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 14d1578863e5e..7a9ee3d9477ca 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -33,7 +33,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -148,14 +148,12 @@ def __init__(self,
             quant_config=quant_config,
         )
 
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position_embeddings=max_position_embeddings,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index aa9cff02283c0..ff547c2c3b8ab 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -32,7 +32,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -130,14 +130,12 @@ def __init__(self,
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
-            self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
 
         # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
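As a sanity check on the removal, here is a minimal standalone sketch (PyTorch only; the rotary_dim and base values below are illustrative assumptions, not taken from the patch) comparing the inverse-frequency formula of the removed GemmaRotaryEmbedding with the default RotaryEmbedding formula that get_rope falls back to when no rope scaling is configured:

import torch


def gemma_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
    # Formula from the removed GemmaRotaryEmbedding._compute_inv_freq:
    # an int64 arange cast to float, mirroring the HF Gemma implementation.
    return 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim))


def default_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
    # Formula from the base RotaryEmbedding._compute_inv_freq:
    # the arange is created directly in float.
    return 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))


if __name__ == "__main__":
    rotary_dim, base = 256, 10000.0  # illustrative Gemma-style values (assumed)
    a = gemma_inv_freq(rotary_dim, base)
    b = default_inv_freq(rotary_dim, base)
    print("max abs diff:", (a - b).abs().max().item())
    assert torch.allclose(a, b)

Since the two formulas agree, the get_rope(..., is_neox_style=True) calls introduced above reproduce the same frequencies as the removed class, so the dedicated Gemma rotary embedding is no longer needed.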