Skip to content

Commit

Permalink
[TPU] Optimize RoPE forward_native2 (vllm-project#7636)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Aug 18, 2024
1 parent 0c2fa50 commit ab7165f
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,23 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:

def _apply_rotary_emb(
x: torch.Tensor,
freqs_cis: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
"""
orig_dtype = x.dtype
x = x.float()
x1, x2 = torch.chunk(x, 2, dim=-1)
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
return torch.cat((o1, o2), dim=-1).to(orig_dtype)


class RotaryEmbedding(CustomOp):
Expand All @@ -78,14 +86,10 @@ def __init__(
self.dtype = dtype

cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)

self.use_native2 = current_platform.is_tpu() and is_neox_style
if not self.use_native2:
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
else:
cos, sin = cache.chunk(2, dim=-1)
freqs_cis = cos + 1j * sin
self.register_buffer("freqs_cis", freqs_cis, persistent=False)

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
Expand Down Expand Up @@ -173,28 +177,25 @@ def forward_native2(
This method might perform better than `forward_native()` when compiled.
"""
if positions.dim() == 1:
batch_size = 1
seq_len = positions.shape[0]
else:
batch_size, seq_len = positions.shape
if offsets is not None:
positions = positions + offsets
freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)

query_shape = query.shape
query = query.view(batch_size, seq_len, -1, self.head_size)
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, freqs_cis)
query_rot = _apply_rotary_emb(query_rot, cos, sin)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

key_shape = key.shape
key = key.view(batch_size, seq_len, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, freqs_cis)
key_rot = _apply_rotary_emb(key_rot, cos, sin)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

Expand Down

0 comments on commit ab7165f

Please sign in to comment.