[Inference PagedAttention] Integrate initial paged attention implementation into maxengine (2/N)

This change is based on a branch from Pate and Rupeng, with additional code refactoring and modifications.

What:
* This PR integrates the initial paged attention components into maxengine, guarded behind
  the `attention=paged` config setting.

Impact of this change:
* This PR is a no-op by default. Paged attention is not enabled unless `attention=paged` is set
  in the config; the default `attention=autoselected` will NOT trigger paged attention.

Key changes:
* MaxText/layers/attentions.py: Use the paged attention op when `attention=paged` for all
  model modes other than MODEL_MODE_TRAIN (see the toy guard sketch after this list)
* MaxText/layers/models.py: Initialize paged attention components when `attention=paged`
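
A toy illustration of the guard described above. It is self-contained and not an excerpt from the
diff; the mode constants are placeholders standing in for the values in common_types:

```python
# Toy guard: paged attention is used only when explicitly requested and only
# outside of training. Constant names mirror common_types; the string values
# here are placeholders.
MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE = "train", "prefill", "autoregressive"

def uses_paged_attention(attention: str, model_mode: str) -> bool:
  return attention == "paged" and model_mode != MODEL_MODE_TRAIN

assert not uses_paged_attention("autoselected", MODEL_MODE_AUTOREGRESSIVE)  # default config: never paged
assert not uses_paged_attention("paged", MODEL_MODE_TRAIN)                  # training keeps the default op
assert uses_paged_attention("paged", MODEL_MODE_PREFILL)
assert uses_paged_attention("paged", MODEL_MODE_AUTOREGRESSIVE)
```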

Why:
* Paged attention should help improve inference performance.

Testing:
* python -m unittest tests/inference/paged_attention_test.py
* python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2  \
    load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_  \
    max_prefill_predict_length=16 max_target_length=32 model_name=llama2-7b   \
    ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 \
    scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1   \
    checkpoint_is_quantized=true quantization=int8 \
    attention=paged pagedattn_num_pages=64 pagedattn_tokens_per_page=8 pagedattn_pages_per_compute_block=4
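
For the flags above, the paged attention bookkeeping works out as follows (a quick sketch of the
divisions and assertions added in the diff; the numbers are specific to this test command):

```python
# Derived paged-attention sizes for the decode.py command above.
max_target_length = 32
max_prefill_predict_length = 16
pagedattn_tokens_per_page = 8
pagedattn_num_pages = 64

# Both lengths must be divisible by the page size (asserted in models.py).
assert max_target_length % pagedattn_tokens_per_page == 0
assert max_prefill_predict_length % pagedattn_tokens_per_page == 0

max_pages_per_slot = max_target_length // pagedattn_tokens_per_page              # 4
max_pages_per_prefill = max_prefill_predict_length // pagedattn_tokens_per_page  # 2
print(max_pages_per_slot, max_pages_per_prefill, pagedattn_num_pages)            # 4 2 64
```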
wyzhang committed Mar 4, 2025
1 parent 61b5875 commit 0e620e5
Showing 6 changed files with 138 additions and 26 deletions.
14 changes: 10 additions & 4 deletions MaxText/decode.py
@@ -47,15 +47,20 @@ def main(argv: Sequence[str]) -> None:
assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet"

# Split RNG before calling prefill
rng, rng_prefill = jax.random.split(rng)
prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
slot = 0
slot = 0 # Always use decode batch slot 0.

# Prefill
rng, rng_prefill = jax.random.split(rng) # Split RNG before calling prefill
prefill_result, first_token = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill, slot=slot
)

# Insert
rng, rng_init_decode = jax.random.split(rng)
decode_state = engine.init_decode_state(rng_init_decode)
decode_state = engine.insert(prefill_result, decode_state, slot=slot)

# Generate
steps = range(config.max_prefill_predict_length, config.max_target_length)
sampled_tokens_list = []
sampled_tokens_list.append(first_token)
@@ -64,6 +69,7 @@ def main(argv: Sequence[str]) -> None:
decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)
sampled_tokens_list.append(sampled_tokens)

# Get results
results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list]
output = tokenizer_model.decode(results)
print(f"Input `{text}` -> `{output}`")
33 changes: 28 additions & 5 deletions MaxText/layers/attentions.py
Expand Up @@ -32,12 +32,12 @@
import common_types
from kernels.ragged_attention import ragged_gqa
from kernels.ragged_attention import ragged_mha
from inference import page_manager, paged_attention
from layers import embeddings
from layers import initializers
from layers import linears
from layers import quantizations


# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error

@@ -241,6 +241,7 @@ def apply_attention(
self.attention_kernel == "dot_product"
or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE)
or (self.attention_kernel == "autoselected" and length < 128)
or (self.attention_kernel == "paged")
):
return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected":
@@ -1070,7 +1071,9 @@ def normalize_attention(self, local_outs, local_maxes, local_sums):
return attn_out

@nn.compact
def __call__(self, query, key, value, decoder_segment_ids, model_mode):
def __call__(
self, query, key, value, decoder_segment_ids, model_mode, page_state: Optional[page_manager.PageState] = None
):
prefill_kv_cache, ar_kv_cache = self.kv_cache(
key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention
)
@@ -1199,6 +1202,21 @@ def setup(self):
use_ragged_attention=self.use_ragged_attention,
ragged_block_size=self.ragged_block_size,
)
# When paged attention is enabled, paged attention op is used for all model modes except TRAIN,
# which uses default attention op.
if self.config.attention == "paged":
self.paged_attention_op = paged_attention.PagedAttentionOp(
mesh=self.mesh,
num_pages=self.config.pagedattn_num_pages,
tokens_per_page=self.config.pagedattn_tokens_per_page,
max_pages_per_slot=self.config.max_target_length // self.config.pagedattn_tokens_per_page,
max_pages_per_prefill=self.config.max_prefill_predict_length // self.config.pagedattn_tokens_per_page,
pages_per_compute_block=self.config.pagedattn_pages_per_compute_block,
num_kv_heads=self.num_kv_heads,
kv_head_dim_size=self.head_dim,
dtype=self.dtype,
attn_logits_soft_cap=self.attn_logits_soft_cap,
)

def query_projection(self, inputs_q: Array) -> Array:
"""Query projection."""
@@ -1348,6 +1366,7 @@ def __call__(
*,
model_mode: str = common_types.MODEL_MODE_TRAIN,
deterministic: bool = False,
page_state: Optional[page_manager.PageState] = None,
):
"""Applies Attention on the input data.
@@ -1400,11 +1419,15 @@ def __call__(

assert not self.config.quantize_kvcache or self.kv_quant

out = self.attention_op(query, key, value, decoder_segment_ids, model_mode)
if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
unnormalized_out, _, exp_sum = self.paged_attention_op(
query, key, value, decoder_segment_ids, model_mode, page_state=page_state
)
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
else:
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode)

out = nn.with_logical_constraint(out, self.out_axis_names)

# apply output projection, output dim is set to the input dim.
out = self.out_projection(inputs_q.shape[-1], out)
out = checkpoint_name(out, "out_proj")
return out
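
A note on the `out = unnormalized_out / (exp_sum + 1e-9)` line above: the paged attention op
returns an unnormalized weighted sum together with the softmax denominator, and the caller
normalizes at the end. A minimal sketch of why this deferred normalization matches ordinary
softmax attention, using toy shapes and values (it ignores the max-subtraction real kernels use
for numerical stability):

```python
import jax
import jax.numpy as jnp

# Toy single-head attention: one query over four keys/values of dim 2.
q = jnp.array([[0.1, 0.2]])                             # (1, 2)
k = jax.random.normal(jax.random.PRNGKey(0), (4, 2))
v = jax.random.normal(jax.random.PRNGKey(1), (4, 2))

scores = q @ k.T                                        # (1, 4)
reference = jax.nn.softmax(scores, axis=-1) @ v         # standard attention output

unnormalized = jnp.exp(scores) @ v                      # what the kernel accumulates
exp_sum = jnp.exp(scores).sum(axis=-1, keepdims=True)   # softmax denominator
deferred = unnormalized / (exp_sum + 1e-9)

assert jnp.allclose(reference, deferred, atol=1e-5)
```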
3 changes: 3 additions & 0 deletions MaxText/layers/llama2.py
@@ -31,6 +31,7 @@
from layers import quantizations

import common_types
from inference import page_manager
from typing import Optional

Array = common_types.Array
@@ -74,6 +75,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state: Optional[page_manager.PageState] = None,
):
cfg = self.config
mesh = self.mesh
@@ -124,6 +126,7 @@ def __call__(
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
page_state=page_state,
)

attention_lnx = nn.with_logical_constraint(
56 changes: 53 additions & 3 deletions MaxText/layers/models.py
@@ -25,6 +25,8 @@
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
import common_types
from inference import page_manager
import pyconfig
from layers import attentions
from layers import embeddings
from layers import linears
@@ -379,6 +381,7 @@ def __call__(
decoder_segment_ids=None,
deterministic=False,
model_mode=common_types.MODEL_MODE_TRAIN,
page_state: Optional[page_manager.PageState] = None,
):
cfg = self.config
mesh = self.mesh
@@ -427,6 +430,7 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers, "moe_layers", mesh)(
@@ -435,6 +439,7 @@
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
else:
RemattedBlockLayer = RemattedBlockLayers[0]
@@ -444,6 +449,7 @@
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)
else:
if cfg.decoder_block == "deepseek":
@@ -472,8 +478,8 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state=page_state,
)

y = self.get_norm_layer()(
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
@@ -538,15 +544,58 @@ def setup(self):

self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant)

if cfg.attention == "paged":
self.page_manager = self._create_page_manager(cfg)

def _create_page_manager(self, config: pyconfig.HyperParameters) -> Optional[page_manager.PageManager]:
"""Creates page manager for managing pages in the paged attention mechanism"""
assert config.max_target_length % config.pagedattn_tokens_per_page == 0
assert config.max_prefill_predict_length % config.pagedattn_tokens_per_page == 0
return page_manager.PageManager(
num_pages=self.config.pagedattn_num_pages,
tokens_per_page=self.config.pagedattn_tokens_per_page,
slots=int(self.config.per_device_batch_size * jax.device_count()),
max_target_length=self.config.max_target_length,
max_prefill_predict_length=self.config.max_prefill_predict_length,
max_pages_per_slot=self.config.max_target_length // self.config.pagedattn_tokens_per_page,
)

def _create_page_state(
self, model_mode: str, true_length: Optional[int] = None, slot: Optional[int] = None
) -> Optional[page_manager.PageState]:
"""Creates page state for tracking page status in the paged attention mechanism."""
if self.config.attention != "paged" or model_mode == common_types.MODEL_MODE_TRAIN:
return None
page_state = None
if model_mode == common_types.MODEL_MODE_PREFILL:
page_state = self.page_manager(
model_mode=model_mode,
slot=slot,
true_length=true_length,
)
elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
page_state = self.page_manager(model_mode)
else:
raise ValueError(f"Unsupported model_mode {model_mode} by paged attention")
return page_state

def __call__(
self,
decoder_input_tokens,
decoder_positions,
decoder_segment_ids=None,
enable_dropout=True,
model_mode=common_types.MODEL_MODE_TRAIN,
model_mode: str = common_types.MODEL_MODE_TRAIN,
true_length: Optional[int] = None,
slot: Optional[int] = None,
):
"""Applies Transformer decoder-branch on encoded-input and target."""
"""Applies Transformer decoder-branch on encoded-input and target.
Args:
true_length: (Optional) Prompt length before padding
slot: (Optional) An integer representing the decode batch index selected
for this request.
"""

if decoder_segment_ids is not None and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError(
@@ -560,5 +609,6 @@ def __call__(
decoder_segment_ids=decoder_segment_ids,
deterministic=not enable_dropout,
model_mode=model_mode,
page_state=self._create_page_state(model_mode=model_mode, true_length=true_length, slot=slot),
)
return logits
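
The `page_manager`/`page_state` plumbing above relies on per-slot page bookkeeping: a `page_map`
from a slot's logical pages to global page indices, plus a `num_pages_used` counter (the same
fields the maxengine.py insert path reads below). A simplified, hypothetical illustration of that
bookkeeping follows; the real PageManager lives in MaxText/inference and may differ:

```python
import numpy as np

num_pages, tokens_per_page = 64, 8
slots, max_pages_per_slot = 4, 4

page_map = np.zeros((slots, max_pages_per_slot), dtype=np.int32)  # slot -> global page ids
num_pages_used = np.zeros((slots,), dtype=np.int32)
free_pages = list(range(num_pages))

def allocate_for_prefill(slot: int, true_length: int) -> None:
  """Toy allocation: grab enough free pages to hold `true_length` prompt tokens."""
  pages_needed = -(-true_length // tokens_per_page)  # ceiling division
  for i in range(pages_needed):
    page_map[slot, i] = free_pages.pop(0)
  num_pages_used[slot] = pages_needed

allocate_for_prefill(slot=0, true_length=11)  # 11 tokens -> 2 pages of 8
print(num_pages_used[0], page_map[0, :2])     # 2 [0 1]
```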
56 changes: 44 additions & 12 deletions MaxText/maxengine.py
@@ -401,6 +401,7 @@ def prefill(
true_length: int,
sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument
rng: Optional[PRNGKeyType] = None,
slot: Optional[int] = None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
"""Computes a kv-cache for a new generate request.
@@ -410,7 +411,9 @@
processed by the underlying model.
padded_tokens: Logically appended tokens to any existing prefix, this is
what we compute prefill on.
true_length: The real length of the tokens, pre-pad.
true_length: Prompt length before padding.
slot: (Optional) An integer representing the decode batch index selected
for this request.
Returns:
kv_cache: For the resulting text.
"""
@@ -439,6 +442,8 @@
model_mode=common_types.MODEL_MODE_PREFILL,
rngs={"params": new_rng},
mutable=["cache"],
true_length=true_length,
slot=slot,
)

next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32)
@@ -476,7 +481,6 @@

cache = new_vars["cache"]
cache = self._maybe_stack_prefill_result_cache(cache)

return {
"logits": selected_logits,
"cache": cache,
@@ -618,7 +622,6 @@ def generate(
rng = jax.random.PRNGKey(0)

previous_token = decode_state["tokens"]

rng, new_rng = jax.random.split(rng)
# run one step generation
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -631,10 +634,8 @@
rngs={"params": new_rng},
mutable=["cache"],
)

out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding)
new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings)

# sampling tokens
new_token = inference_utils.sampling(
out_logits,
@@ -644,7 +645,6 @@
nucleus_topp=self.config.decode_sampling_nucleus_p,
temperature=self.config.decode_sampling_temperature,
)

all_valid = jnp.ones(new_token.shape, dtype=jnp.int8)
result = engine_api.ResultTokens(
data=jnp.concatenate((new_token, all_valid, decode_state["generated_tokens"]), axis=1),
@@ -840,12 +840,44 @@ def copy(path, partial_cache, full_cache, annotations):
else:
raise ValueError(f"We don't have a strategy for inserting {path_key}")

inserted_cache = jax.tree_util.tree_map_with_path(
copy,
unboxed_prefix["cache"],
decode_state["cache"],
self.kv_cache_annotations_named,
)
if self.config.attention == "paged":

def _copy_paged(path, prefix_cache, decode_state_cache):
if path[-2].key == "page_manager":
return prefix_cache
path_key = path[-1].key
if path_key in ["key_pages", "value_pages"]:

def _update_pages(prefix_page_idx, state):
decode_state_pages, prefix_pages, page_map = state
prefix_page = jax.lax.dynamic_index_in_dim(prefix_pages, prefix_page_idx, axis=1)
decode_state_pages = jax.lax.dynamic_update_slice_in_dim(
decode_state_pages, prefix_page, page_map[prefix_page_idx], axis=1
)
return decode_state_pages, prefix_pages, page_map

decode_state_cache, _, _ = jax.lax.fori_loop(
0,
prefix["cache"]["page_manager"]["num_pages_used"].value[slot],
_update_pages,
(decode_state_cache, prefix_cache, prefix["cache"]["page_manager"]["page_map"].value[slot]),
)
return decode_state_cache
else:
raise ValueError(f"We don't have a strategy for inserting {path_key} for paged attention.")

inserted_cache = jax.tree_util.tree_map_with_path(
_copy_paged,
unboxed_prefix["cache"],
decode_state["cache"],
)
else:
inserted_cache = jax.tree_util.tree_map_with_path(
copy,
unboxed_prefix["cache"],
decode_state["cache"],
self.kv_cache_annotations_named,
)
inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0)
inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
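
The `_copy_paged` path above moves a finished prefill's KV pages into the decode state by walking
`num_pages_used` pages and scattering each one to the global page index recorded in `page_map`.
A self-contained toy of that loop with illustrative shapes (kv_heads=1, head_dim=2; the real
arrays are the `key_pages`/`value_pages` cache entries):

```python
import jax
import jax.numpy as jnp

tokens_per_page, head_dim = 4, 2
# (kv_heads, pages, tokens_per_page, head_dim)
prefix_pages = jnp.arange(2 * tokens_per_page * head_dim, dtype=jnp.float32).reshape(1, 2, tokens_per_page, head_dim)
decode_pages = jnp.zeros((1, 8, tokens_per_page, head_dim))  # 8 global pages
page_map = jnp.array([5, 3], dtype=jnp.int32)                # prefill page i -> global page page_map[i]

def _update_pages(prefix_page_idx, state):
  decode_state_pages, prefix_pages, page_map = state
  prefix_page = jax.lax.dynamic_index_in_dim(prefix_pages, prefix_page_idx, axis=1)
  decode_state_pages = jax.lax.dynamic_update_slice_in_dim(
      decode_state_pages, prefix_page, page_map[prefix_page_idx], axis=1
  )
  return decode_state_pages, prefix_pages, page_map

num_pages_used = 2
decode_pages, _, _ = jax.lax.fori_loop(0, num_pages_used, _update_pages, (decode_pages, prefix_pages, page_map))
assert jnp.allclose(decode_pages[:, 5], prefix_pages[:, 0])
assert jnp.allclose(decode_pages[:, 3], prefix_pages[:, 1])
```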
2 changes: 0 additions & 2 deletions MaxText/tests/inference/paged_attention_test.py
@@ -40,12 +40,10 @@ def setUp(self):
self._max_prefill_predict_length = 512
self._max_target_length = 1024
self._dtype = jnp.float32

# PagedAttention settings
self._num_pages = 64
self._tokens_per_page = 32
self._pages_per_compute_block = 16

self.rng = jax.random.PRNGKey(42)
devices = jax.devices()
if len(devices) > 1:
