[Inference PagedAttention] Integrate initial paged attention implementation into maxengine (2/N) #1336

Open
wants to merge 1 commit into main from wyzhang/page/2-n-0228

Conversation

wyzhang
Collaborator

@wyzhang wyzhang commented Mar 2, 2025

This change is based on a branch from Pate and Rupeng with code refactoring and modifications.

What:

  • This PR integrates the initial paged attention components into maxengine, guarded behind the `attention=paged` config setting (a toy sketch of the idea follows below).
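For readers unfamiliar with the technique, here is a toy sketch of the paged KV-cache idea. All names, shapes, and the ToyPagedKVCache class are illustrative assumptions, not MaxText code; the real paged attention op consumes page-organized KV buffers and page tables rather than a class like this. The point is that keys and values live in fixed-size pages, and a per-sequence page table maps logical token positions to physical pages, so memory is allocated on demand instead of being reserved for the full target length up front.

import numpy as np

# Toy illustration only (assumed names, not MaxText's API).
class ToyPagedKVCache:
  def __init__(self, num_pages=64, tokens_per_page=8, num_heads=4, head_dim=16):
    self.tokens_per_page = tokens_per_page
    # Physical storage: fixed-size pages instead of one contiguous buffer
    # sized for the full max_target_length.
    self.k_pages = np.zeros((num_pages, tokens_per_page, num_heads, head_dim), np.float32)
    self.v_pages = np.zeros_like(self.k_pages)
    self.free_pages = list(range(num_pages))
    self.page_table = []  # logical page index -> physical page id (one sequence)

  def append(self, pos, k, v):
    """Store key/value for token position `pos`, allocating a page on demand."""
    page_idx, slot = divmod(pos, self.tokens_per_page)
    if page_idx == len(self.page_table):  # first token landing on a new logical page
      self.page_table.append(self.free_pages.pop())
    phys = self.page_table[page_idx]
    self.k_pages[phys, slot] = k
    self.v_pages[phys, slot] = v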

Impact of this change:

  • This PR is a no-op by default: paged attention is enabled only when `attention=paged` is set in the config. The default `attention=autoselected` will NOT trigger paged attention.

Key changes:

  • MaxText/layers/attentions.py: Use the paged attention op when `attention=paged` for all model modes other than MODEL_MODE_TRAIN (see the sketch after this list)
  • MaxText/layers/models.py: Initialize paged attention components when `attention=paged`
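As an orientation aid, the gating in attentions.py looks roughly like this. This is a sketch consistent with the diff excerpt shown later on this page, not the exact committed code; the paged_attention_op attribute and its call signature are assumptions, only the condition is taken verbatim from the diff.

# Sketch of the attention=paged gating in MaxText/layers/attentions.py (assumed method body).
if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
  out = self.paged_attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
else:
  out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)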

Why:

  • Paged attention should improve inference performance: the KV cache is managed in fixed-size pages allocated on demand, avoiding memory reserved up front for the full target length.

Testing:

  • python -m unittest tests/inference/paged_attention_test.py
  • python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 \
      load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_ \
      max_prefill_predict_length=16 max_target_length=32 model_name=llama2-7b \
      ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 \
      scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 \
      checkpoint_is_quantized=true quantization=int8 \
      attention=paged pagedattn_num_pages=64 pagedattn_tokens_per_page=8 pagedattn_pages_per_compute_block=4
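    As a rough sanity check on these flags (assuming total KV-cache token capacity is pagedattn_num_pages × pagedattn_tokens_per_page): 64 pages × 8 tokens per page = 512 token slots, which comfortably covers max_target_length=32 for this smoke test, and pagedattn_pages_per_compute_block=4 would have the kernel work over 4 pages (32 tokens) per compute block.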

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

@vipannalla vipannalla left a comment

LGTM

@vipannalla
Collaborator

@richjames0, can you also take a look and LGTM?

@wyzhang wyzhang force-pushed the wyzhang/page/1-n-0228 branch 2 times, most recently from f93f36a to 6814eaa on March 4, 2025 03:48
Base automatically changed from wyzhang/page/1-n-0228 to main on March 4, 2025 04:58
@wyzhang wyzhang force-pushed the wyzhang/page/2-n-0228 branch 2 times, most recently from 18c9b2b to 0e620e5 on March 4, 2025 05:49
@wyzhang wyzhang force-pushed the wyzhang/page/2-n-0228 branch from 0e620e5 to c131d50 on March 4, 2025 21:26
@wyzhang wyzhang force-pushed the wyzhang/page/2-n-0228 branch 3 times, most recently from ace9d03 to 887e944 on March 5, 2025 05:11
@@ -1520,11 +1546,15 @@ def __call__(

assert not self.config.quantize_kvcache or self.kv_quant

out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
Collaborator

it is better to do:

and (model_mode == common_types.MODEL_MODE_PREFILL or model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE)
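Spelled out, the reviewer's suggestion would make the guard read roughly as follows (a sketch of the suggestion, not committed code):

# Enumerate the inference modes explicitly instead of excluding MODEL_MODE_TRAIN.
if self.config.attention == "paged" and (
    model_mode == common_types.MODEL_MODE_PREFILL
    or model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE
):
  ...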

decode_state["cache"],
self.kv_cache_annotations_named,
)
if self.config.attention == "paged":
Collaborator

have you run tests for this chunk of code?

@xy12181 xy12181 self-requested a review March 5, 2025 05:42
@wyzhang wyzhang force-pushed the wyzhang/page/2-n-0228 branch 2 times, most recently from 83ee89a to 570215d on March 5, 2025 06:23
@wyzhang wyzhang force-pushed the wyzhang/page/2-n-0228 branch 4 times, most recently from ea4e683 to f6afd83 on March 5, 2025 17:26
@wyzhang wyzhang force-pushed the wyzhang/page/2-n-0228 branch from f6afd83 to 6fd47d0 on March 5, 2025 18:11
copybara-service bot pushed a commit that referenced this pull request Mar 5, 2025
Lumosis pushed a commit that referenced this pull request Mar 6, 2025