[Misc] Remove Gemma RoPE (vllm-project#7638)
WoosukKwon authored Aug 19, 2024
1 parent 1a36287 commit df845b2
Showing 3 changed files with 7 additions and 26 deletions.
vllm/model_executor/layers/rotary_embedding.py: 15 changes (0 additions & 15 deletions)
@@ -93,11 +93,6 @@ def __init__(

     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
-        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
@@ -724,16 +719,6 @@ def forward(
         return query, key


-class GemmaRotaryEmbedding(RotaryEmbedding):
-
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
-            self.rotary_dim))
-        return inv_freq
-
-
 class Llama3RotaryEmbedding(RotaryEmbedding):

     def __init__(
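The deletion suggests the shared RotaryEmbedding already covers Gemma: the only formula difference in the removed _compute_inv_freq was building the arange in int64 and casting to float (mirroring the linked HF modeling_gemma.py), whereas the base class uses a float arange directly, as the deleted NOTE above describes. A minimal standalone check of that equivalence, in plain PyTorch rather than code from this commit, with illustrative Gemma-like rotary_dim and base values:

import torch

rotary_dim = 256  # illustrative rotary/head dim
base = 10000      # illustrative rope_theta

# Base RotaryEmbedding style: float arange directly.
inv_freq_float = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

# Removed GemmaRotaryEmbedding style: int64 arange, then cast to float.
inv_freq_int64 = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim))

# The small even integers 0, 2, ..., rotary_dim - 2 are exactly representable
# in float32, so the two variants should produce identical tensors here.
assert torch.equal(inv_freq_float, inv_freq_int64)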
vllm/model_executor/models/gemma.py: 8 changes (3 additions & 5 deletions)
@@ -33,7 +33,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -148,14 +148,12 @@ def __init__(self,
             quant_config=quant_config,
         )

-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position_embeddings=max_position_embeddings,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
vllm/model_executor/models/gemma2.py: 10 changes (4 additions & 6 deletions)
@@ -32,7 +32,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -130,14 +130,12 @@ def __init__(self,
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
-            self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )

         # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
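With this change, both Gemma attention modules construct their rotary embedding through the shared get_rope helper instead of a model-specific class. A minimal sketch of the resulting call shape, mirroring the added lines in the diffs above; the head_dim, max_position_embeddings, and rope_theta values are illustrative rather than taken from any Gemma config, and the import assumes vllm at this commit is installed:

from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim = 256                  # illustrative
max_position_embeddings = 8192  # illustrative
rope_theta = 10000              # illustrative

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    is_neox_style=True,
)
# With no rope scaling involved, this is expected to return the shared
# RotaryEmbedding implementation rather than a Gemma-specific subclass.
print(type(rotary_emb).__name__)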
