Commit ee14ae6

Merge pull request #1278 from AI-Hypercomputer:yuyan-prefix-cache-dev

PiperOrigin-RevId: 733516196

maxtext authors committed Mar 5, 2025
2 parents 0492608 + c82c472 commit ee14ae6
Showing 4 changed files with 1,005 additions and 0 deletions.
2 changes: 2 additions & 0 deletions MaxText/configs/base.yml
@@ -530,6 +530,8 @@ vertex_tensorboard_region: ""
max_checkify: False

# Inference
inference_microbenchmark_prefix_cache_entries_num: 100
inference_microbenchmark_prefix_cache_common_prefix_proportion: 0.5
inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
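For orientation, here is a minimal sketch (not part of the diff) of how these two new settings are consumed by the benchmark added below: the cache capacity is sized to hold exactly inference_microbenchmark_prefix_cache_entries_num prefixes, and the shared-prefix length is that proportion of the padded prefill length. The prefix_size_bytes figure is an assumed placeholder; the real value comes from the prefill result.

# Illustrative arithmetic only -- mirrors prefix_cache_benchmark in the diff below.
entries_num = 100               # inference_microbenchmark_prefix_cache_entries_num
common_prefix_proportion = 0.5  # inference_microbenchmark_prefix_cache_common_prefix_proportion
prefill_length = 1024           # one entry of inference_microbenchmark_prefill_lengths
prefix_size_bytes = 512 * 1024 * 1024  # assumed size of one cached prefix (model-dependent)

cache_capacity_bytes = entries_num * prefix_size_bytes       # capacity passed to PrefixCache(...)
common_len = int(prefill_length * common_prefix_proportion)  # length of the common prefix used by the timed keys
remain_len = prefill_length - common_len                     # unique suffix length per key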
114 changes: 114 additions & 0 deletions MaxText/inference_microbenchmark.py
@@ -27,6 +27,7 @@
import max_utils
import maxengine
import maxtext_utils
import prefix_cache
import profiler
import pyconfig

@@ -39,6 +40,103 @@
# pylint: disable=too-many-positional-arguments


def prefix_cache_benchmark(
    prefix, prefill_length: int, true_length: int, common_prefix_proportion: float, prefix_cache_entries_num: int, iters: int
):
  """Runs the prefix cache benchmark and prints the results.

  Creates distinct keys that share a common prefix (common_prefix_proportion of
  prefill_length) and inserts them into the cache. The cached value is not
  relevant for now; the same prefix is simply copied for every cache entry.
  1. Fill the prefix cache to full capacity.
  2. Benchmark saving to the prefix cache (including eviction), averaged over iters.
  3. Benchmark fetch_longest_common_prefix_key, averaged over iters.
  4. Benchmark loading from the prefix cache, averaged over iters.

  Args:
    prefix: prefix returned from the prefill function
    prefill_length: prefill token length after padding
    true_length: true prefill token length
    common_prefix_proportion: proportion in [0., 1.] of prefill_length used as the common prefix
    prefix_cache_entries_num: number of prefix cache entries inserted into PrefixCache
    iters: number of repetitions for the save, fetch_longest_common_prefix_key, and load measurements
  """

print(f"Prefix Cache benchmark results for prefill length {prefill_length}:\n")

value = prefix_cache.Value(
prefix=prefix,
true_length=true_length,
padded_length=prefill_length,
tokens=tuple(i for i in range(prefill_length)),
)
prefix_size_bytes_gb = value.prefix_size_bytes / 1024 / 1024 / 1024
prefix_cache_inst = prefix_cache.PrefixCache(prefix_cache_entries_num * value.prefix_size_bytes)
common_len = int(prefill_length * common_prefix_proportion)
remain_len = prefill_length - common_len
common_prefix_key = tuple(i for i in range(common_len))

# Fill the prefix caching
new_value_list = []
for c_idx in range(prefix_cache_entries_num):
# Add 100 to make sure filled prefix caching will not share the common_prefix_key.
# The later save prefix part will evict all of them.
key = tuple(100 + i + c_idx * prefill_length for i in range(prefill_length))
new_value = value.clone()
prefix_cache_inst.save(key, new_value)
new_value_list.append(new_value)
jax.block_until_ready(new_value_list)
del new_value_list

  # Save prefix
  new_value = None
  save_sec = 0
  for c_idx in range(iters):
    key = common_prefix_key + tuple(i + c_idx * remain_len for i in range(remain_len))
    # Values are not relevant to the cache for now; clone the same tokens and values for the test.
    new_value = value.clone()
    jax.block_until_ready(new_value)
    start = datetime.datetime.now()
    prefix_cache_inst.save(key, new_value)
    end = datetime.datetime.now()
    save_sec += (end - start).total_seconds()
  del new_value
  save_avg_ms = save_sec * 1000 / iters

  # Fetch longest prefix key
  key_load = common_prefix_key + tuple(i + prefix_cache_entries_num * remain_len for i in range(remain_len))
  matched_key = None
  fetch_sec = 0
  for _ in range(iters):
    start = datetime.datetime.now()
    matched_key = prefix_cache_inst.fetch_longest_common_prefix_key(key_load)
    end = datetime.datetime.now()
    fetch_sec += (end - start).total_seconds()
  fetch_avg_ms = fetch_sec * 1000 / iters

  assert matched_key is not None

  # Load prefix
  load_sec = 0
  value_load = None
  for _ in range(iters):
    start = datetime.datetime.now()
    value_load = prefix_cache_inst.load(matched_key)
    jax.block_until_ready(value_load)
    end = datetime.datetime.now()
    load_sec += (end - start).total_seconds()
  del value_load
  load_avg_ms = load_sec * 1000 / iters

  print(
      f"PrefixCaching results:\n"
      f"\tPer prefix size: {prefix_size_bytes_gb:.3f} GB\n"
      f"\tAverage save cache time: {save_avg_ms:.3f} ms\n"
      f"\tAverage fetch longest prefix time: {fetch_avg_ms:.3f} ms\n"
      f"\tAverage load cache time: {load_avg_ms:.3f} ms\n\n\n"
  )
  del prefix_cache_inst
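Condensed, the cache API exercised above is a save / fetch / load round trip. The following is a minimal sketch under assumptions, not part of the change: it relies only on the prefix_cache calls visible in this diff (Value, PrefixCache, save, fetch_longest_common_prefix_key, load, clone, prefix_size_bytes) plus the module-level jax and prefix_cache imports above, and the prefix argument is a stand-in for the result of a prefill call.

def prefix_cache_round_trip(prefix, prefill_length: int, true_length: int):
  """Illustrative round trip through the prefix cache; not part of the benchmark."""
  value = prefix_cache.Value(
      prefix=prefix,
      true_length=true_length,
      padded_length=prefill_length,
      tokens=tuple(range(prefill_length)),
  )
  # Capacity for exactly one entry, so any further save would evict this one.
  cache = prefix_cache.PrefixCache(value.prefix_size_bytes)

  key = tuple(range(prefill_length))
  cache.save(key, value.clone())

  # A query sharing only the first half of the key should still resolve to it.
  half = prefill_length // 2
  query = key[:half] + tuple(-1 for _ in range(prefill_length - half))
  matched_key = cache.fetch_longest_common_prefix_key(query)
  if matched_key is not None:
    loaded = cache.load(matched_key)
    jax.block_until_ready(loaded)
  return matched_key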


def prefill_benchmark_loop(engine_prefill, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
start = datetime.datetime.now()
@@ -305,6 +403,22 @@ def run_benchmarks(config):
        prefill_executable[prefill_length], params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
    )

if "prefix_cache" in stages_to_benchmark:
for prefill_length in prefill_lengths:
rng_cache = jax.random.PRNGKey(1234)
prefill_result, _ = prefill_executable[prefill_length](
params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length], rng_cache
)
prefix_cache_benchmark(
prefill_result,
prefill_length,
prefill_true_lengths[prefill_length],
config.inference_microbenchmark_prefix_cache_common_prefix_proportion,
config.inference_microbenchmark_prefix_cache_entries_num,
benchmark_loop_iters,
)
del prefill_result

  for prefill_length in prefill_lengths:
    benchmark_results["prefill"][prefill_length] = prefill_benchmark(
        config,
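The new stage is opt-in: run_benchmarks only calls prefix_cache_benchmark when "prefix_cache" appears in the comma-separated inference_microbenchmark_stages string, whose default in base.yml above remains "prefill,generate". A small hedged illustration of the gating follows; the exact parsing of the config string is assumed, not shown in this diff.

# Illustrative only: enabling the stage via the stages string.
inference_microbenchmark_stages = "prefill,prefix_cache,generate"  # hypothetical override of the base.yml default
stages_to_benchmark = inference_microbenchmark_stages.split(",")   # assumed parsing, not shown in this diff
assert "prefix_cache" in stages_to_benchmark  # the new block above then runs for each prefill length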