[Inference PagedAttention] Integrate initial paged attention implementation into maxengine (2/N) #1336
base: main
Conversation
LGTM
@richjames0, can you also take a look and LGTM?
Force-pushed f93f36a to 6814eaa
Force-pushed 18c9b2b to 0e620e5
Force-pushed 0e620e5 to c131d50
Force-pushed ace9d03 to 887e944
@@ -1520,11 +1546,15 @@ def __call__(

    assert not self.config.quantize_kvcache or self.kv_quant

    out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
    if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN:
It is better to do:

`and (model_mode == common_types.MODEL_MODE_PREFILL or model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE)`
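A minimal sketch of the suggested guard, assuming the `common_types` constants shown in the diff; `paged_attention_op` is a hypothetical name for the paged code path, used here only to make the dispatch concrete:

```python
# Route to the paged path only for the two inference modes named explicitly;
# everything else (including MODEL_MODE_TRAIN) keeps the dense attention op.
use_paged = self.config.attention == "paged" and (
    model_mode == common_types.MODEL_MODE_PREFILL
    or model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE
)
if use_paged:
  out = self.paged_attention_op(query, key, value, decoder_segment_ids, model_mode)  # hypothetical op
else:
  out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, previous_chunk)
```

The explicit allow-list fails closed: a newly added model mode would fall back to the dense op, whereas `model_mode != common_types.MODEL_MODE_TRAIN` would silently route it through the paged path.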
    decode_state["cache"],
    self.kv_cache_annotations_named,
)
if self.config.attention == "paged":
Have you run tests for this chunk of code?
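For context, a hedged reconstruction of how this chunk might sit in maxengine; the surrounding `tree_map` call and the empty paged branch are assumptions for illustration, not the verbatim source:

```python
import jax

if self.config.attention == "paged":
  # Assumed behavior: the paged KV cache lives in fixed-size pages owned by
  # the paged attention components, so the dense kv-cache sharding
  # annotations are not applied to it here.
  pass
else:
  # Dense path: constrain every kv-cache array to its named sharding.
  decode_state["cache"] = jax.tree_util.tree_map(
      jax.lax.with_sharding_constraint,
      decode_state["cache"],
      self.kv_cache_annotations_named,
  )
```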
Force-pushed 83ee89a to 570215d
Force-pushed ea4e683 to f6afd83
Force-pushed f6afd83 to 6fd47d0
-- 887e944 by Wangyuan Zhang <[email protected]>:

[Inference PagedAttention] Integrate initial paged attention implementation into maxengine (2/N)

COPYBARA_INTEGRATE_REVIEW=#1336 from AI-Hypercomputer:wyzhang/page/2-n-0228 887e944
PiperOrigin-RevId: 733799896
This change is based on a branch from Pate and Rupeng with code refactoring and modifications.

What:

* This PR integrates initial paged attention components into maxengine, guarded behind the `attention=paged` config setting.

Impact of this change:

* This PR is a NOOP. Paged attention is not enabled unless `attention=paged` is set in the config. The default `attention=autoselected` will NOT trigger paged attention.

Key changes:

* MaxText/layers/attentions.py: Use the paged attention op when `attention=paged` for all model modes other than MODEL_MODE_TRAIN.
* MaxText/layers/models.py: Initialize paged attention components when `attention=paged`.

Why:

* Paged attention should enhance inference performance.

Testing (a sanity check of the `pagedattn_*` flags follows the commands):

* python -m unittest tests/inference/paged_attention_test.py
* python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 \
  load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_ \
  max_prefill_predict_length=16 max_target_length=32 model_name=llama2-7b \
  ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 \
  scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 \
  checkpoint_is_quantized=true quantization=int8 \
  attention=paged pagedattn_num_pages=64 pagedattn_tokens_per_page=8 pagedattn_pages_per_compute_block=4
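A back-of-envelope check of the `pagedattn_*` flags in the decode command above; the capacity formula is an assumption about how the flags compose, not taken from the MaxText source:

```python
# Values copied from the decode.py smoke-test command above.
pagedattn_num_pages = 64
pagedattn_tokens_per_page = 8
pagedattn_pages_per_compute_block = 4
max_target_length = 32

# Assumed: total KV-cache token capacity is pages times tokens per page.
total_token_capacity = pagedattn_num_pages * pagedattn_tokens_per_page  # 512 tokens
# Pages needed to hold one sequence at the configured max length (ceil division).
pages_per_sequence = -(-max_target_length // pagedattn_tokens_per_page)  # 4 pages

assert total_token_capacity >= max_target_length
assert pages_per_sequence % pagedattn_pages_per_compute_block == 0  # 4 % 4 == 0
print(total_token_capacity, pages_per_sequence)
```

Under these assumptions, the configured 64 pages comfortably cover the 32-token smoke test: one sequence consumes 4 pages, exactly one compute block.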
Before submitting this PR, please make sure (put X in square brackets):