From 570215da208562cfab6c877119bed9a3139bafb0 Mon Sep 17 00:00:00 2001
From: Wangyuan Zhang
Date: Mon, 3 Mar 2025 21:48:49 -0800
Subject: [PATCH] [Inference PagedAttention] Integrate initial paged attention
 implementation into maxengine (2/N)

This change is based on a branch from Pate and Rupeng, with code refactoring
and modifications.

What:
* This PR integrates the initial paged attention components into maxengine,
  guarded behind the `attention=paged` config setting.

Impact of this change:
* This PR is a NOOP. Paged attention is not enabled unless `attention=paged`
  is set in the config. The default `attention=autoselected` will NOT trigger
  paged attention.

Key changes:
* MaxText/layers/attentions.py: Use the paged attention op when
  `attention=paged` for all model modes other than MODEL_MODE_TRAIN.
* MaxText/layers/models.py: Initialize paged attention components when
  `attention=paged`.

Why:
* Paged attention should improve inference performance.

Testing:
* python -m unittest tests/inference/paged_attention_test.py
* python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 \
    load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_ \
    max_prefill_predict_length=16 max_target_length=32 model_name=llama2-7b \
    ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 \
    scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 \
    checkpoint_is_quantized=true quantization=int8 \
    attention=paged pagedattn_num_pages=64 pagedattn_tokens_per_page=8 pagedattn_pages_per_compute_block=4
---
 MaxText/decode.py                           | 14 +++--
 MaxText/inference/paged_attention.py        |  9 ++-
 MaxText/layers/attentions.py                | 41 +++++++++++--
 MaxText/layers/deepseek.py                  |  2 +
 MaxText/layers/gemma.py                     |  1 +
 MaxText/layers/gemma2.py                    |  1 +
 MaxText/layers/gpt3.py                      |  1 +
 MaxText/layers/llama2.py                    |  3 +
 MaxText/layers/mistral.py                   |  2 +
 MaxText/layers/models.py                    | 60 ++++++++++++++++++-
 MaxText/layers/simple_layer.py              |  4 +-
 MaxText/maxengine.py                        | 53 ++++++++++++----
 .../tests/inference/paged_attention_test.py |  2 -
 13 files changed, 163 insertions(+), 30 deletions(-)

diff --git a/MaxText/decode.py b/MaxText/decode.py
index 8a59f0dcf..826558629 100644
--- a/MaxText/decode.py
+++ b/MaxText/decode.py
@@ -47,15 +47,20 @@ def main(argv: Sequence[str]) -> None:
   assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
   assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet"

-  # Split RNG before calling prefill
-  rng, rng_prefill = jax.random.split(rng)
-  prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
-  slot = 0
+  slot = 0  # Always use decode batch slot 0.
+
+  # Prefill
+  rng, rng_prefill = jax.random.split(rng)  # Split RNG before calling prefill
+  prefill_result, first_token = engine.prefill(
+      params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill, slot=slot
+  )
+
+  # Insert
   rng, rng_init_decode = jax.random.split(rng)
   decode_state = engine.init_decode_state(rng_init_decode)
   decode_state = engine.insert(prefill_result, decode_state, slot=slot)

+  # Generate
   steps = range(config.max_prefill_predict_length, config.max_target_length)
   sampled_tokens_list = []
   sampled_tokens_list.append(first_token)
@@ -64,6 +69,7 @@ def main(argv: Sequence[str]) -> None:
     decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)
     sampled_tokens_list.append(sampled_tokens)

+  # Get results
   results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list]
   output = tokenizer_model.decode(results)
   print(f"Input `{text}` -> `{output}`")
diff --git a/MaxText/inference/paged_attention.py b/MaxText/inference/paged_attention.py
index fcd88fd25..5baaabe15 100644
--- a/MaxText/inference/paged_attention.py
+++ b/MaxText/inference/paged_attention.py
@@ -24,11 +24,13 @@ import jax.numpy as jnp
 from flax import linen as nn
 from jax.experimental import shard_map
-from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
+from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention_kernel
 from jax.sharding import PartitionSpec as P

 from inference import page_manager

+# pytype: disable=attribute-error
+
 Mesh = common_types.Mesh
 Array = common_types.Array
@@ -166,7 +168,7 @@ def paged_attention(
     )

     def wrap_paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_compute_block):
       q = jnp.squeeze(q, axis=1)
-      result = paged_attention(
+      result = paged_attention_kernel.paged_attention(
           q=q,
           k_pages=k_pages,
           v_pages=v_pages,
           lengths=lengths,
           page_indices=page_indices,
           pages_per_compute_block=pages_per_compute_block,
@@ -193,7 +195,8 @@ def __call__(
       value: Array,
       decoder_segment_ids: Array,
       model_mode: str,
-      page_state: page_manager.PageState,
+      previous_chunk=None,
+      page_state: Optional[page_manager.PageState] = None,
   ) -> Array:
     """Apply paged attention mechanism.

diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index a63f356cb..fee603dfd 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -32,12 +32,12 @@ import common_types
 from kernels.ragged_attention import ragged_gqa
 from kernels.ragged_attention import ragged_mha
+from inference import page_manager, paged_attention
 from layers import embeddings
 from layers import initializers
 from layers import linears
 from layers import quantizations
-
 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
 # pytype: disable=attribute-error
@@ -279,6 +279,7 @@ def apply_attention(
         self.attention_kernel == "dot_product"
         or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE)
         or (self.attention_kernel == "autoselected" and length < 128)
+        or (self.attention_kernel == "paged")
     ):
       return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
     elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected":
@@ -1188,7 +1189,16 @@ def normalize_attention(self, local_outs, local_maxes, local_sums):
     return attn_out

   @nn.compact
-  def __call__(self, query, key, value, decoder_segment_ids, model_mode, previous_chunk=None):
+  def __call__(
+      self,
+      query,
+      key,
+      value,
+      decoder_segment_ids,
+      model_mode,
+      previous_chunk=None,
+      page_state: Optional[page_manager.PageState] = None,
+  ):
     prefill_kv_cache, ar_kv_cache = self.kv_cache(
         key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention
     )
@@ -1318,6 +1328,21 @@ def setup(self):
         use_ragged_attention=self.use_ragged_attention,
         ragged_block_size=self.ragged_block_size,
     )
+    # When paged attention is enabled, paged attention op is used for all model modes except TRAIN,
+    # which uses default attention op.
+    if self.config.attention == "paged":
+      self.paged_attention_op = paged_attention.PagedAttentionOp(
+          mesh=self.mesh,
+          num_pages=self.config.pagedattn_num_pages,
+          tokens_per_page=self.config.pagedattn_tokens_per_page,
+          max_pages_per_slot=self.config.max_target_length // self.config.pagedattn_tokens_per_page,
+          max_pages_per_prefill=self.config.max_prefill_predict_length // self.config.pagedattn_tokens_per_page,
+          pages_per_compute_block=self.config.pagedattn_pages_per_compute_block,
+          num_kv_heads=self.num_kv_heads,
+          kv_head_dim_size=self.head_dim,
+          dtype=self.dtype,
+          attn_logits_soft_cap=self.attn_logits_soft_cap,
+      )

   def query_projection(self, inputs_q: Array) -> Array:
     """Query projection."""
@@ -1468,6 +1493,7 @@ def __call__(
       model_mode: str = common_types.MODEL_MODE_TRAIN,
       deterministic: bool = False,
       previous_chunk: Any = None,
+      page_state: Optional[page_manager.PageState] = None,
   ):
     """Applies Attention on the input data.

@@ -1520,11 +1546,15 @@ def __call__(

     assert not self.config.quantize_kvcache or self.kv_quant

-    out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
+    if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
+      unnormalized_out, _, exp_sum = self.paged_attention_op(
+          query, key, value, decoder_segment_ids, model_mode, previous_chunk, page_state=page_state
+      )
+      out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+    else:
+      out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)

     out = nn.with_logical_constraint(out, self.out_axis_names)
-
-    # apply output projection, output dim is set to the input dim.
     out = self.out_projection(inputs_q.shape[-1], out)
     out = checkpoint_name(out, "out_proj")
     return out
@@ -1691,6 +1721,7 @@ def __call__(
       model_mode: str = common_types.MODEL_MODE_TRAIN,
       deterministic: bool = False,
       previous_chunk: Any = None,
+      page_state: Optional[page_manager.PageState] = None,
   ) -> Array:
     """Forward pass for MLA, reusing `AttentionOp` for the actual attention.

diff --git a/MaxText/layers/deepseek.py b/MaxText/layers/deepseek.py
index f03276e83..8ca53a92d 100644
--- a/MaxText/layers/deepseek.py
+++ b/MaxText/layers/deepseek.py
@@ -149,6 +149,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      page_state=None,
   ):
     cfg = self.config
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
@@ -196,6 +197,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      page_state=None,
   ):
     cfg = self.config
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
diff --git a/MaxText/layers/gemma.py b/MaxText/layers/gemma.py
index 051bdf91e..587c40502 100644
--- a/MaxText/layers/gemma.py
+++ b/MaxText/layers/gemma.py
@@ -66,6 +66,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      page_state=None,
   ):
     cfg = self.config
     mesh = self.mesh
diff --git a/MaxText/layers/gemma2.py b/MaxText/layers/gemma2.py
index 1acf8cd67..3d6424bdc 100644
--- a/MaxText/layers/gemma2.py
+++ b/MaxText/layers/gemma2.py
@@ -66,6 +66,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      page_state=None,
   ):
     cfg = self.config
     mesh = self.mesh
diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py
index cdf07f23c..873a5bdc2 100644
--- a/MaxText/layers/gpt3.py
+++ b/MaxText/layers/gpt3.py
@@ -278,6 +278,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      page_state=None,
   ):
     cfg = self.config
     mesh = self.mesh
diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py
index 9769edace..e76ac2973 100644
--- a/MaxText/layers/llama2.py
+++ b/MaxText/layers/llama2.py
@@ -31,6 +31,7 @@ from layers import quantizations
 import common_types
+from inference import page_manager
 from typing import Optional

 Array = common_types.Array
@@ -74,6 +75,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      page_state: Optional[page_manager.PageState] = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -124,6 +126,7 @@ def __call__(
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        page_state=page_state,
     )

     attention_lnx = nn.with_logical_constraint(
diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py
index d37b9d80d..667d6c387 100644
--- a/MaxText/layers/mistral.py
+++ b/MaxText/layers/mistral.py
@@ -34,6 +34,7 @@ from layers import models
 import common_types
 import max_logging
+from inference import page_manager

 Array = common_types.Array
 Config = common_types.Config
@@ -67,6 +68,7 @@ def __call__(
       deterministic,
       model_mode,
       previous_chunk=None,
+      page_state=None,
   ):
     cfg = self.config
     mesh = self.mesh
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 6b985b94b..e48395123 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -25,6 +25,7 @@ import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
 import common_types
+from inference import page_manager
 from layers import attentions
 from layers import embeddings
 from layers import linears
@@ -63,6 +64,7 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      page_state: Optional[page_manager.PageState] = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -160,7 +162,9 @@ class SequentialBlockDecoderLayers(nn.Module):
   quant: Quant

   @nn.compact
-  def __call__(self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, deterministic, model_mode) -> jnp.ndarray:
+  def __call__(
+      self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, deterministic, model_mode, page_state=None
+  ) -> jnp.ndarray:
     for lyr in range(self.num_decoder_layers):
       inputs = self.decoder_layer(config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant)(
           inputs,
@@ -168,6 +172,7 @@ def __call__(self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions,
           decoder_positions,
           deterministic,
           model_mode,
+          page_state=page_state,
       )
     return inputs
@@ -380,6 +385,7 @@ def __call__(
       deterministic=False,
       model_mode=common_types.MODEL_MODE_TRAIN,
       previous_chunk=None,
+      page_state: Optional[page_manager.PageState] = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -428,6 +434,7 @@ def __call__(
             decoder_positions,
             deterministic,
             model_mode,
+            page_state=page_state,
         )
         num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
         y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers, "moe_layers", mesh)(
@@ -436,6 +443,7 @@ def __call__(
             decoder_positions,
             deterministic,
             model_mode,
+            page_state=page_state,
         )
       else:
         RemattedBlockLayer = RemattedBlockLayers[0]
@@ -445,6 +453,7 @@ def __call__(
             decoder_positions,
             deterministic,
             model_mode,
+            page_state=page_state,
         )
     else:
       if cfg.decoder_block == "deepseek":
@@ -463,6 +472,7 @@ def __call__(
               decoder_positions,
               deterministic,
               model_mode,
+              page_state=page_state,
           )
       else:
         for lyr in range(cfg.num_decoder_layers):
@@ -473,8 +483,8 @@ def __call__(
               decoder_positions,
               deterministic,
               model_mode,
+              page_state=page_state,
           )
-
     y = self.get_norm_layer()(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
@@ -539,6 +549,41 @@ def setup(self):

     self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant)
+    if cfg.attention == "paged":
+      self.page_manager = self._create_page_manager(cfg)
+
+  def _create_page_manager(self, config) -> Optional[page_manager.PageManager]:
+    """Creates a page manager for managing pages in the paged attention mechanism."""
+    assert config.max_target_length % config.pagedattn_tokens_per_page == 0
+    assert config.max_prefill_predict_length % config.pagedattn_tokens_per_page == 0
+    return page_manager.PageManager(
+        num_pages=self.config.pagedattn_num_pages,
+        tokens_per_page=self.config.pagedattn_tokens_per_page,
+        slots=int(self.config.per_device_batch_size * jax.device_count()),
+        max_target_length=self.config.max_target_length,
+        max_prefill_predict_length=self.config.max_prefill_predict_length,
+        max_pages_per_slot=self.config.max_target_length // self.config.pagedattn_tokens_per_page,
+    )
+
+  def _create_page_state(
+      self, model_mode: str, true_length: Optional[int] = None, slot: Optional[int] = None
+  ) -> Optional[page_manager.PageState]:
+    """Creates page state for tracking page status in the paged attention mechanism."""
+    if self.config.attention != "paged" or model_mode == common_types.MODEL_MODE_TRAIN:
+      return None
+    page_state = None
+    if model_mode == common_types.MODEL_MODE_PREFILL:
+      page_state = self.page_manager(
+          model_mode=model_mode,
+          slot=slot,
+          true_length=true_length,
+      )
+    elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
+      page_state = self.page_manager(model_mode)
+    else:
+      raise ValueError(f"Unsupported model_mode {model_mode} by paged attention")
+    return page_state
+
   def __call__(
       self,
       decoder_input_tokens,
@@ -547,8 +592,16 @@ def __call__(
       enable_dropout=True,
       model_mode=common_types.MODEL_MODE_TRAIN,
       previous_chunk=None,
+      true_length: Optional[int] = None,
+      slot: Optional[int] = None,
   ):
-    """Applies Transformer decoder-branch on encoded-input and target."""
+    """Applies Transformer decoder-branch on encoded-input and target.
+
+    Args:
+      true_length: (Optional) Prompt length before padding.
+      slot: (Optional) An integer representing the decode batch index selected
+        for this request.
+    """

     if decoder_segment_ids is not None and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
       raise ValueError(
@@ -563,5 +616,6 @@ def __call__(
         deterministic=not enable_dropout,
         model_mode=model_mode,
         previous_chunk=previous_chunk,
+        page_state=self._create_page_state(model_mode=model_mode, true_length=true_length, slot=slot),
     )
     return logits
diff --git a/MaxText/layers/simple_layer.py b/MaxText/layers/simple_layer.py
index a5d24ab04..3787f9b24 100644
--- a/MaxText/layers/simple_layer.py
+++ b/MaxText/layers/simple_layer.py
@@ -37,7 +37,7 @@ def setup(self):
         (self.config.emb_dim, self.config.emb_dim),
     )

-  def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode):
+  def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode, page_state=None):
     if self.config.scan_layers:
       return inputs @ self.weight_mat.astype(inputs.dtype), None
     else:
@@ -63,7 +63,7 @@ def setup(self):
         (self.config.mlp_dim, self.config.emb_dim),
     )

-  def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode):
+  def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode, page_state=None):
     intermediate = inputs @ self.ff_1.astype(inputs.dtype)
     output = intermediate @ self.ff_2.astype(inputs.dtype)
     if self.config.scan_layers:
diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py
index 70d6b08ea..4ebe9a755 100644
--- a/MaxText/maxengine.py
+++ b/MaxText/maxengine.py
@@ -418,6 +418,7 @@ def prefill(
       complete_padded_prompt: Optional[jax.Array] = None,
       positions: Optional[jax.Array] = None,
       previous_chunk: Optional[Any] = None,
+      slot: Optional[int] = None,
   ) -> Tuple[Prefix, engine_api.ResultTokens]:
     """Computes a kv-cache for a new generate request.

@@ -462,6 +463,7 @@ def prefill(
       complete_padded_prompt = None
       positions = None
      previous_chunk = None
+
    Returns:
      kv_cache: For the resulting text.
""" @@ -498,6 +500,8 @@ def prefill( rngs={"params": new_rng}, mutable=["cache"], previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, ) generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32) selected_logits = jax.lax.dynamic_slice( @@ -533,7 +537,6 @@ def prefill( cache = new_vars["cache"] cache = self._maybe_stack_prefill_result_cache(cache) - return { "logits": selected_logits, "cache": cache, @@ -675,7 +678,6 @@ def generate( rng = jax.random.PRNGKey(0) previous_token = decode_state["tokens"] - rng, new_rng = jax.random.split(rng) # run one step generation with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -688,10 +690,8 @@ def generate( rngs={"params": new_rng}, mutable=["cache"], ) - out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) - # sampling tokens new_token = inference_utils.sampling( out_logits, @@ -701,7 +701,6 @@ def generate( nucleus_topp=self.config.decode_sampling_nucleus_p, temperature=self.config.decode_sampling_temperature, ) - all_valid = jnp.ones(new_token.shape, dtype=jnp.int8) result = engine_api.ResultTokens( data=jnp.concatenate((new_token, all_valid, decode_state["generated_tokens"]), axis=1), @@ -897,12 +896,44 @@ def copy(path, partial_cache, full_cache, annotations): else: raise ValueError(f"We don't have a strategy for inserting {path_key}") - inserted_cache = jax.tree_util.tree_map_with_path( - copy, - unboxed_prefix["cache"], - decode_state["cache"], - self.kv_cache_annotations_named, - ) + if self.config.attention == "paged": + + def _copy_paged(path, prefix_cache, decode_state_cache): + if path[-2].key == "page_manager": + return prefix_cache + path_key = path[-1].key + if path_key in ["key_pages", "value_pages"]: + + def _update_pages(prefix_page_idx, state): + decode_state_pages, prefix_pages, page_map = state + prefix_page = jax.lax.dynamic_index_in_dim(prefix_pages, prefix_page_idx, axis=1) + decode_state_pages = jax.lax.dynamic_update_slice_in_dim( + decode_state_pages, prefix_page, page_map[prefix_page_idx], axis=1 + ) + return decode_state_pages, prefix_pages, page_map + + decode_state_cache, _, _ = jax.lax.fori_loop( + 0, + prefix["cache"]["page_manager"]["num_pages_used"].value[slot], + _update_pages, + (decode_state_cache, prefix_cache, prefix["cache"]["page_manager"]["page_map"].value[slot]), + ) + return decode_state_cache + else: + raise ValueError(f"We don't have a strategy for inserting {path_key} for paged attention.") + + inserted_cache = jax.tree_util.tree_map_with_path( + _copy_paged, + unboxed_prefix["cache"], + decode_state["cache"], + ) + else: + inserted_cache = jax.tree_util.tree_map_with_path( + copy, + unboxed_prefix["cache"], + decode_state["cache"], + self.kv_cache_annotations_named, + ) inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0) inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0) inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim( diff --git a/MaxText/tests/inference/paged_attention_test.py b/MaxText/tests/inference/paged_attention_test.py index c76ae5ab4..8849762c0 100644 --- a/MaxText/tests/inference/paged_attention_test.py +++ b/MaxText/tests/inference/paged_attention_test.py @@ -40,12 +40,10 @@ def setUp(self): self._max_prefill_predict_length = 512 self._max_target_length = 1024 self._dtype = jnp.float32 - # 
     # PagedAttention settings
     self._num_pages = 64
     self._tokens_per_page = 32
     self._pages_per_compute_block = 16
-
     self.rng = jax.random.PRNGKey(42)

     devices = jax.devices()
     if len(devices) > 1:
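
Note on the paged attention config (illustrative sketch, not part of the patch): the page bookkeeping added
in models.py and attentions.py is derived entirely from existing config keys. Using the values from the
decode.py command in the Testing section above, the arithmetic works out as follows:

    # Values taken from the Testing command above.
    max_target_length = 32
    max_prefill_predict_length = 16
    pagedattn_tokens_per_page = 8

    # models.py asserts both lengths divide evenly into pages.
    assert max_target_length % pagedattn_tokens_per_page == 0
    assert max_prefill_predict_length % pagedattn_tokens_per_page == 0

    # attentions.py sizes PagedAttentionOp from the same config.
    max_pages_per_slot = max_target_length // pagedattn_tokens_per_page              # 4
    max_pages_per_prefill = max_prefill_predict_length // pagedattn_tokens_per_page  # 2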