[Inference PagedAttention] Integrate initial paged attention implementation into maxengine (2/N)

This change is based on a branch from Pate and Rupeng, with additional code refactoring and modifications.

What:
* This PR integrates the initial paged attention components into maxengine, guarded behind
  the `attention=paged` config setting.

Impact of this change:
* This PR is a no-op by default. Paged attention is not enabled unless `attention=paged` is set
  in the config; the default `attention=autoselected` will NOT trigger paged attention.

Key changes:
* MaxText/layers/attentions.py: Use the paged attention op when `attention=paged` for all
  model modes other than MODEL_MODE_TRAIN
* MaxText/layers/models.py: Initialize paged attention components when `attention=paged`

Why:
* Paged attention should improve inference performance by storing the KV cache in fixed-size
  pages instead of one contiguous per-sequence buffer (sketched below).
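
  For intuition, a minimal sketch of the page-based KV-cache addressing involved (illustrative
  only, not code from this PR; the array layout follows the k_pages/page_indices arguments of the
  Pallas kernel used in paged_attention.py, while head count, head_dim, and the page assignment
  are made-up values):

      import jax.numpy as jnp

      # Shared physical KV pool: fixed-size pages instead of one contiguous buffer per sequence.
      # Layout mirrors the Pallas kernel: [num_kv_heads, num_pages, tokens_per_page, head_dim].
      num_pages, tokens_per_page, num_kv_heads, head_dim = 64, 8, 4, 128
      k_pages = jnp.zeros((num_kv_heads, num_pages, tokens_per_page, head_dim))

      # Per-slot page table: logical page i of a request lives in physical page page_indices[i];
      # pages are handed out on demand by the PageManager (indices here are hypothetical).
      page_indices = jnp.array([3, 17, 42, 5])

      # Gather the request's keys back into sequence order for attention.
      k_seq = k_pages[:, page_indices, :, :]             # (num_kv_heads, 4, tokens_per_page, head_dim)
      k_seq = k_seq.reshape(num_kv_heads, -1, head_dim)  # (num_kv_heads, 32, head_dim)

  The Pallas kernel performs this gather blockwise, pagedattn_pages_per_compute_block pages at a
  time, rather than materializing the full sequence as above.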

Testing:
* python -m unittest tests/inference/paged_attention_test.py
* python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2  \
    load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_  \
    max_prefill_predict_length=16 max_target_length=32 model_name=llama2-7b   \
    ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 \
    scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1   \
    checkpoint_is_quantized=true quantization=int8 \
    attention=paged pagedattn_num_pages=64 pagedattn_tokens_per_page=8 pagedattn_pages_per_compute_block=4
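
  As a sanity check on the paged-attention flags above, the page sizing works out as follows
  (divisibility requirements and derived values come from MaxText/layers/models.py and
  attentions.py in this diff; a single chip, i.e. device_count=1, is assumed for this test run):

      # Values from the decode.py test command above.
      max_target_length = 32
      max_prefill_predict_length = 16
      tokens_per_page = 8
      num_pages = 64
      per_device_batch_size = 1
      device_count = 1  # assumption for the single-chip test run

      # Divisibility enforced by Transformer._create_page_manager.
      assert max_target_length % tokens_per_page == 0
      assert max_prefill_predict_length % tokens_per_page == 0

      max_pages_per_slot = max_target_length // tokens_per_page              # 4
      max_pages_per_prefill = max_prefill_predict_length // tokens_per_page  # 2
      slots = int(per_device_batch_size * device_count)                      # 1
      assert slots * max_pages_per_slot <= num_pages                         # 4 <= 64
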
wyzhang committed Mar 5, 2025
1 parent 4797997 commit 570215d
Showing 13 changed files with 163 additions and 30 deletions.
14 changes: 10 additions & 4 deletions MaxText/decode.py
@@ -47,15 +47,20 @@ def main(argv: Sequence[str]) -> None:
assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet"

# Split RNG before calling prefill
rng, rng_prefill = jax.random.split(rng)
prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
slot = 0
slot = 0 # Always use decode batch slot 0.

# Prefill
rng, rng_prefill = jax.random.split(rng) # Split RNG before calling prefill
prefill_result, first_token = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill, slot=slot
)

# Insert
rng, rng_init_decode = jax.random.split(rng)
decode_state = engine.init_decode_state(rng_init_decode)
decode_state = engine.insert(prefill_result, decode_state, slot=slot)

# Generate
steps = range(config.max_prefill_predict_length, config.max_target_length)
sampled_tokens_list = []
sampled_tokens_list.append(first_token)
@@ -64,6 +69,7 @@ def main(argv: Sequence[str]) -> None:
decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)
sampled_tokens_list.append(sampled_tokens)

# Get results
results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list]
output = tokenizer_model.decode(results)
print(f"Input `{text}` -> `{output}`")
9 changes: 6 additions & 3 deletions MaxText/inference/paged_attention.py
@@ -24,11 +24,13 @@
import jax.numpy as jnp
from flax import linen as nn
from jax.experimental import shard_map
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention_kernel
from jax.sharding import PartitionSpec as P

from inference import page_manager

# pytype: disable=attribute-error

Mesh = common_types.Mesh

Array = common_types.Array
@@ -166,7 +168,7 @@ def paged_attention(
)
def wrap_paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_compute_block):
q = jnp.squeeze(q, axis=1)
result = paged_attention(
result = paged_attention_kernel.paged_attention(
q=q,
k_pages=k_pages,
v_pages=v_pages,
@@ -193,7 +195,8 @@ def __call__(
value: Array,
decoder_segment_ids: Array,
model_mode: str,
page_state: page_manager.PageState,
previous_chunk=None,
page_state: Optional[page_manager.PageState] = None,
) -> Array:
"""Apply paged attention mechanism.
41 changes: 36 additions & 5 deletions MaxText/layers/attentions.py
@@ -32,12 +32,12 @@
import common_types
from kernels.ragged_attention import ragged_gqa
from kernels.ragged_attention import ragged_mha
from inference import page_manager, paged_attention
from layers import embeddings
from layers import initializers
from layers import linears
from layers import quantizations


# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error

@@ -279,6 +279,7 @@ def apply_attention(
self.attention_kernel == "dot_product"
or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE)
or (self.attention_kernel == "autoselected" and length < 128)
or (self.attention_kernel == "paged")
):
return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected":
@@ -1188,7 +1189,16 @@ def normalize_attention(self, local_outs, local_maxes, local_sums):
return attn_out

@nn.compact
def __call__(self, query, key, value, decoder_segment_ids, model_mode, previous_chunk=None):
def __call__(
self,
query,
key,
value,
decoder_segment_ids,
model_mode,
previous_chunk=None,
page_state: Optional[page_manager.PageState] = None,
):
prefill_kv_cache, ar_kv_cache = self.kv_cache(
key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention
)
@@ -1318,6 +1328,21 @@ def setup(self):
use_ragged_attention=self.use_ragged_attention,
ragged_block_size=self.ragged_block_size,
)
# When paged attention is enabled, paged attention op is used for all model modes except TRAIN,
# which uses default attention op.
if self.config.attention == "paged":
self.paged_attention_op = paged_attention.PagedAttentionOp(
mesh=self.mesh,
num_pages=self.config.pagedattn_num_pages,
tokens_per_page=self.config.pagedattn_tokens_per_page,
max_pages_per_slot=self.config.max_target_length // self.config.pagedattn_tokens_per_page,
max_pages_per_prefill=self.config.max_prefill_predict_length // self.config.pagedattn_tokens_per_page,
pages_per_compute_block=self.config.pagedattn_pages_per_compute_block,
num_kv_heads=self.num_kv_heads,
kv_head_dim_size=self.head_dim,
dtype=self.dtype,
attn_logits_soft_cap=self.attn_logits_soft_cap,
)

def query_projection(self, inputs_q: Array) -> Array:
"""Query projection."""
@@ -1468,6 +1493,7 @@ def __call__(
model_mode: str = common_types.MODEL_MODE_TRAIN,
deterministic: bool = False,
previous_chunk: Any = None,
page_state: Optional[page_manager.PageState] = None,
):
"""Applies Attention on the input data.
@@ -1520,11 +1546,15 @@

assert not self.config.quantize_kvcache or self.kv_quant

out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
unnormalized_out, _, exp_sum = self.paged_attention_op(
query, key, value, decoder_segment_ids, model_mode, previous_chunk, page_state=page_state
)
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
else:
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)

out = nn.with_logical_constraint(out, self.out_axis_names)

# apply output projection, output dim is set to the input dim.
out = self.out_projection(inputs_q.shape[-1], out)
out = checkpoint_name(out, "out_proj")
return out
@@ -1691,6 +1721,7 @@ def __call__(
model_mode: str = common_types.MODEL_MODE_TRAIN,
deterministic: bool = False,
previous_chunk: Any = None,
page_state: Optional[page_manager.PageState] = None,
) -> Array:
"""Forward pass for MLA, reusing `AttentionOp` for the actual attention.
2 changes: 2 additions & 0 deletions MaxText/layers/deepseek.py
@@ -149,6 +149,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=None,
):
cfg = self.config
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
@@ -196,6 +197,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=None,
):
cfg = self.config
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
1 change: 1 addition & 0 deletions MaxText/layers/gemma.py
@@ -66,6 +66,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=None,
):
cfg = self.config
mesh = self.mesh
1 change: 1 addition & 0 deletions MaxText/layers/gemma2.py
@@ -66,6 +66,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=None,
):
cfg = self.config
mesh = self.mesh
1 change: 1 addition & 0 deletions MaxText/layers/gpt3.py
@@ -278,6 +278,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=None,
):
cfg = self.config
mesh = self.mesh
3 changes: 3 additions & 0 deletions MaxText/layers/llama2.py
@@ -31,6 +31,7 @@
from layers import quantizations

import common_types
from inference import page_manager
from typing import Optional

Array = common_types.Array
@@ -74,6 +75,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state: Optional[page_manager.PageState] = None,
):
cfg = self.config
mesh = self.mesh
@@ -124,6 +126,7 @@ def __call__(
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
page_state=page_state,
)

attention_lnx = nn.with_logical_constraint(
2 changes: 2 additions & 0 deletions MaxText/layers/mistral.py
@@ -34,6 +34,7 @@
from layers import models
import common_types
import max_logging
from inference import page_manager

Array = common_types.Array
Config = common_types.Config
@@ -67,6 +68,7 @@ def __call__(
deterministic,
model_mode,
previous_chunk=None,
page_state=None,
):
cfg = self.config
mesh = self.mesh
60 changes: 57 additions & 3 deletions MaxText/layers/models.py
@@ -25,6 +25,7 @@
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
import common_types
from inference import page_manager
from layers import attentions
from layers import embeddings
from layers import linears
@@ -63,6 +64,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state: Optional[page_manager.PageState] = None,
):
cfg = self.config
mesh = self.mesh
@@ -160,14 +162,17 @@ class SequentialBlockDecoderLayers(nn.Module):
quant: Quant

@nn.compact
def __call__(self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, deterministic, model_mode) -> jnp.ndarray:
def __call__(
self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, deterministic, model_mode, page_state=None
) -> jnp.ndarray:
for lyr in range(self.num_decoder_layers):
inputs = self.decoder_layer(config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant)(
inputs,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
return inputs

@@ -380,6 +385,7 @@ def __call__(
deterministic=False,
model_mode=common_types.MODEL_MODE_TRAIN,
previous_chunk=None,
page_state: Optional[page_manager.PageState] = None,
):
cfg = self.config
mesh = self.mesh
@@ -428,6 +434,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers, "moe_layers", mesh)(
@@ -436,6 +443,7 @@
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
else:
RemattedBlockLayer = RemattedBlockLayers[0]
@@ -445,6 +453,7 @@
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
else:
if cfg.decoder_block == "deepseek":
@@ -463,6 +472,7 @@
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
else:
for lyr in range(cfg.num_decoder_layers):
@@ -473,8 +483,8 @@
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)

y = self.get_norm_layer()(
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
@@ -539,6 +549,41 @@ def setup(self):

self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant)

if cfg.attention == "paged":
self.page_manager = self._create_page_manager(cfg)

def _create_page_manager(self, config) -> Optional[page_manager.PageManager]:
"""Creates page manager for managing pages in the paged attention mechanism"""
assert config.max_target_length % config.pagedattn_tokens_per_page == 0
assert config.max_prefill_predict_length % config.pagedattn_tokens_per_page == 0
return page_manager.PageManager(
num_pages=self.config.pagedattn_num_pages,
tokens_per_page=self.config.pagedattn_tokens_per_page,
slots=int(self.config.per_device_batch_size * jax.device_count()),
max_target_length=self.config.max_target_length,
max_prefill_predict_length=self.config.max_prefill_predict_length,
max_pages_per_slot=self.config.max_target_length // self.config.pagedattn_tokens_per_page,
)

def _create_page_state(
self, model_mode: str, true_length: Optional[int] = None, slot: Optional[int] = None
) -> Optional[page_manager.PageState]:
"""Creates page state for tracking page status in the paged attention mechanism."""
if self.config.attention != "paged" or model_mode == common_types.MODEL_MODE_TRAIN:
return None
page_state = None
if model_mode == common_types.MODEL_MODE_PREFILL:
page_state = self.page_manager(
model_mode=model_mode,
slot=slot,
true_length=true_length,
)
elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
page_state = self.page_manager(model_mode)
else:
raise ValueError(f"Unsupported model_mode {model_mode} by paged attention")
return page_state

def __call__(
self,
decoder_input_tokens,
@@ -547,8 +592,16 @@ def __call__(
enable_dropout=True,
model_mode=common_types.MODEL_MODE_TRAIN,
previous_chunk=None,
true_length: Optional[int] = None,
slot: Optional[int] = None,
):
"""Applies Transformer decoder-branch on encoded-input and target."""
"""Applies Transformer decoder-branch on encoded-input and target.
Args:
true_length: (Optional) Prompt length before padding
slot: (Optional) An integer representing the decode batch index selected
for this request.
"""

if decoder_segment_ids is not None and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError(
@@ -563,5 +616,6 @@
deterministic=not enable_dropout,
model_mode=model_mode,
previous_chunk=previous_chunk,
page_state=self._create_page_state(model_mode=model_mode, true_length=true_length, slot=slot),
)
return logits