[Model] Migrate Mistral to use PagedKVCache (#1967)
This PR migrates the Mistral model to the PagedKVCache interface,
which supports sliding window attention with a paged attention kernel
written in TensorIR.

We thereby introduce a `support_sliding_window` mode for the KV cache,
which leaves room to support sliding window attention for any model
at runtime.

This PR tests Mistral with both chat and serve.
The chat performance of Mistral 7B improves over the previous
implementation, benefiting from the paged attention kernel.
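
In short, the KV cache is now created with an extra support_sliding_window flag, and sliding window attention is enabled per sequence at runtime. Below is a condensed C++ sketch assembled from the cpp/llm_chat.cc hunks in this diff (not standalone code; it reuses the class members shown there):

  // Create the paged KV cache; every argument is an IntTuple, and the new
  // support_sliding_window flag tells the cache to reserve sliding-window bookkeeping.
  IntTuple max_num_sequence{1};
  IntTuple max_total_sequence_length{max_total_seq_length};
  IntTuple prefill_chunk_size{this->prefill_chunk_size_};
  IntTuple page_size{16};
  IntTuple support_sliding_window{sliding_window_size_ != -1};
  this->kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length,
                                              prefill_chunk_size, page_size,
                                              support_sliding_window);

  // After registering a sequence, opt it into sliding window attention
  // (with an optional attention sink prefix) at runtime.
  ft_.kv_cache_add_sequence_func_(kv_cache_, 0);
  if (sliding_window_size_ != -1) {
    ft_.kv_cache_enable_sliding_window_for_seq_(kv_cache_, 0, sliding_window_size_,
                                                attention_sink_size);
  }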
MasterJH5574 authored Mar 16, 2024
1 parent d6b86d1 commit edffce4
Showing 35 changed files with 627 additions and 918 deletions.
122 changes: 40 additions & 82 deletions cpp/llm_chat.cc
@@ -259,6 +259,8 @@ struct FunctionTable {
this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear");
this->kv_cache_add_sequence_func_ = get_global_func("vm.builtin.kv_state_add_sequence");
this->kv_cache_remove_sequence_func_ = get_global_func("vm.builtin.kv_state_remove_sequence");
this->kv_cache_enable_sliding_window_for_seq_ =
get_global_func("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq");
this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward");
this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward");
this->fkvcache_array_popn_ = get_global_func("vm.builtin.kv_state_popn");
@@ -345,6 +347,7 @@ struct FunctionTable {
PackedFunc reset_kv_cache_func_;
PackedFunc kv_cache_add_sequence_func_;
PackedFunc kv_cache_remove_sequence_func_;
PackedFunc kv_cache_enable_sliding_window_for_seq_;
PackedFunc kv_cache_begin_forward_func_;
PackedFunc kv_cache_end_forward_func_;
bool support_backtracking_kv_;
@@ -663,12 +666,17 @@ class LLMChat {
this->params_ = ft_.LoadParams(model_path, device_, use_presharded_weights_);
// Step 6. KV cache creation.
if (ft_.use_kv_state == FunctionTable::KVStateKind::kAttention) {
int max_total_seq_length =
this->max_window_size_ == -1 ? this->sliding_window_size_ : this->max_window_size_;
ICHECK_GT(max_total_seq_length, 0);
IntTuple max_num_sequence{1};
IntTuple max_total_sequence_length{this->max_window_size_};
IntTuple max_total_sequence_length{max_total_seq_length};
IntTuple prefill_chunk_size{this->prefill_chunk_size_};
IntTuple page_size{16};
this->kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length,
prefill_chunk_size, page_size);
IntTuple support_sliding_window{sliding_window_size_ != -1};
this->kv_cache_ =
ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length, prefill_chunk_size,
page_size, support_sliding_window);
} else if (ft_.use_kv_state == FunctionTable::KVStateKind::kRNNState) {
IntTuple max_num_sequence{1};
IntTuple max_history_length{1};
@@ -697,8 +705,6 @@ class LLMChat {
this->ResetRuntimeStats();
this->ResetKVCache();
this->total_seq_len_ = 0;
this->sliding_window_cache_offset_ = 0;
this->sink_triggered_ = false;
}

/*! \brief reset the runtime stats. */
@@ -984,19 +990,6 @@ class LLMChat {
std::vector<int32_t>(prompt_tokens.begin() + begin, prompt_tokens.begin() + end);
new_seq_len += static_cast<int64_t>(chunk.size());
logits_on_device = this->ForwardTokens(chunk, new_seq_len);

// update window cache offset (prefill)
if (this->sliding_window_size_ != -1) {
if (sink_triggered_) {
sliding_window_cache_offset_ =
std::max((sliding_window_cache_offset_ + static_cast<int64_t>(chunk.size())) %
sliding_window_size_,
attention_sink_size_);
} else {
sliding_window_cache_offset_ += static_cast<int64_t>(chunk.size());
sink_triggered_ = sliding_window_cache_offset_ >= attention_sink_size_;
}
}
}
ICHECK_EQ(new_seq_len, total_seq_len_ + token_len) << "Expect chunking process all tokens";
} else {
@@ -1035,18 +1028,6 @@ class LLMChat {

NDArray logits_on_device = this->ForwardTokens({last_token}, total_seq_len_ + 1);
total_seq_len_ += 1;

// update window cache offset (decoding)
if (this->sliding_window_size_ != -1) {
if (sink_triggered_) {
sliding_window_cache_offset_ = std::max(
(sliding_window_cache_offset_ + 1) % sliding_window_size_, attention_sink_size_);
} else {
sliding_window_cache_offset_ += 1;
sink_triggered_ = sliding_window_cache_offset_ >= attention_sink_size_;
}
}

int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);

auto tend = std::chrono::high_resolution_clock::now();
@@ -1372,32 +1353,20 @@ class LLMChat {
ObjectRef ret{nullptr};
if (input_tokens.size() > 1 && ft_.prefill_func_.defined()) {
ObjectRef input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray(input_tokens));
if (sliding_window_size_ == -1) {
if (ft_.use_kv_state) {
int input_len = input_tokens.size();
IntTuple seq_ids_tuple({0});
ShapeTuple input_len_shape{input_len};
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, input_len_shape);
input_data = ft_.nd_view_func_(input_data, input_len_shape);
auto embed = ft_.embed_func_(input_data, params_);
ShapeTuple embedding_shape = {1, input_len, GetHiddenSizeFromEmbedding(embed)};
embed = ft_.nd_view_func_(embed, embedding_shape);
ret = ft_.prefill_func_(embed, kv_cache_, params_);
ft_.kv_cache_end_forward_func_(kv_cache_);
} else {
ShapeTuple cur_pos_shape = ShapeTuple({cur_pos});
ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_);
}
if (ft_.use_kv_state) {
int input_len = input_tokens.size();
IntTuple seq_ids_tuple({0});
ShapeTuple input_len_shape{input_len};
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, input_len_shape);
input_data = ft_.nd_view_func_(input_data, input_len_shape);
auto embed = ft_.embed_func_(input_data, params_);
ShapeTuple embedding_shape = {1, input_len, GetHiddenSizeFromEmbedding(embed)};
embed = ft_.nd_view_func_(embed, embedding_shape);
ret = ft_.prefill_func_(embed, kv_cache_, params_);
ft_.kv_cache_end_forward_func_(kv_cache_);
} else {
// Sliding window attention needs extra shape parameters
int64_t seq_len = static_cast<int64_t>(input_tokens.size());
// Number of elements in the cache
int64_t cache_len = std::min(this->sliding_window_size_, cur_pos - seq_len);
ShapeTuple cache_len_shape = ShapeTuple({cache_len});
ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len});
ShapeTuple cache_offset_shape = ShapeTuple({sliding_window_cache_offset_});
ret = ft_.prefill_func_(input_data, cache_len_shape, kv_seq_len_shape, cache_offset_shape,
kv_cache_, params_);
ShapeTuple cur_pos_shape = ShapeTuple({cur_pos});
ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_);
}
} else {
// running decode function when prefill is not available
@@ -1412,30 +1381,18 @@
}
int64_t pos = cur_pos + i + 1 - input_tokens.size();
ShapeTuple pos_shape = ShapeTuple({pos});
if (sliding_window_size_ == -1) {
if (ft_.use_kv_state) {
IntTuple seq_ids_tuple({0});
IntTuple append_length({1});
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, append_length);
input_data = ft_.nd_view_func_(input_data, append_length);
auto embed = ft_.embed_func_(input_data, params_);
ShapeTuple embedding_shape = {1, 1, GetHiddenSizeFromEmbedding(embed)};
embed = ft_.nd_view_func_(embed, embedding_shape);
ret = ft_.decode_func_(embed, kv_cache_, params_);
ft_.kv_cache_end_forward_func_(kv_cache_);
} else {
ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_);
}
if (ft_.use_kv_state) {
IntTuple seq_ids_tuple({0});
IntTuple append_length({1});
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, append_length);
input_data = ft_.nd_view_func_(input_data, append_length);
auto embed = ft_.embed_func_(input_data, params_);
ShapeTuple embedding_shape = {1, 1, GetHiddenSizeFromEmbedding(embed)};
embed = ft_.nd_view_func_(embed, embedding_shape);
ret = ft_.decode_func_(embed, kv_cache_, params_);
ft_.kv_cache_end_forward_func_(kv_cache_);
} else {
// Sliding window attention needs extra shape parameters
int64_t seq_len = static_cast<int64_t>(input_tokens.size());
// Number of elements in the cache
int64_t cache_len = std::min(this->sliding_window_size_, pos - seq_len);
ShapeTuple cache_len_shape = ShapeTuple({cache_len});
ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len});
ShapeTuple cache_offset_shape = ShapeTuple({sliding_window_cache_offset_});
ret = ft_.decode_func_(input_data, cache_len_shape, kv_seq_len_shape, cache_offset_shape,
kv_cache_, params_);
ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_);
}
}
}
@@ -1553,6 +1510,11 @@ class LLMChat {
ft_.reset_kv_cache_func_(kv_cache_);
if (ft_.use_kv_state) {
ft_.kv_cache_add_sequence_func_(kv_cache_, 0);
if (sliding_window_size_ != -1) {
int attention_sink_size = std::max(static_cast<int>(attention_sink_size_), 0);
ft_.kv_cache_enable_sliding_window_for_seq_(kv_cache_, 0, sliding_window_size_,
attention_sink_size);
}
}
}

@@ -1624,10 +1586,6 @@ class LLMChat {
std::string output_message_;
// Whether encounter stop str
bool stop_triggered_{false};
// Whether sink is in action
bool sink_triggered_{false};
// sliding window cache offset
int64_t sliding_window_cache_offset_{0};
//----------------------------
// Model configurations
//----------------------------
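With the manual sliding_window_cache_offset_ / sink_triggered_ bookkeeping removed, prefill and decode no longer branch on sliding_window_size_: both go through the same KV-state calls, and the paged attention kernel applies the window internally. Roughly, condensed from the prefill hunk above (decode is identical with an append length of 1; input_len and hidden_size stand in for the values computed there):

  IntTuple seq_ids_tuple({0});
  ShapeTuple input_len_shape{input_len};
  ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, input_len_shape);  // announce appended tokens
  input_data = ft_.nd_view_func_(input_data, input_len_shape);
  auto embed = ft_.embed_func_(input_data, params_);
  embed = ft_.nd_view_func_(embed, ShapeTuple({1, input_len, hidden_size}));
  ret = ft_.prefill_func_(embed, kv_cache_, params_);  // sliding window handled inside the kernel
  ft_.kv_cache_end_forward_func_(kv_cache_);
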
7 changes: 3 additions & 4 deletions cpp/serve/config.cc
@@ -252,10 +252,9 @@ KVCacheConfig::KVCacheConfig(const std::string& config_str, int max_single_seque
if (config.count("max_num_sequence")) {
CHECK(config["max_num_sequence"].is<int64_t>());
max_num_sequence = config["max_num_sequence"].get<int64_t>();
}

if (max_num_sequence == -1) {
max_num_sequence = max_total_sequence_length / max_single_sequence_length;
CHECK_GT(max_num_sequence, 0) << "Max number of sequence should be positive.";
} else {
LOG(FATAL) << "Key \"max_num_sequence\" not found.";
}

ObjectPtr<KVCacheConfigNode> n = make_object<KVCacheConfigNode>();
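Note that max_num_sequence is no longer derived from the sequence lengths: if the key is missing from the KV cache config JSON, the constructor now aborts with a fatal error. A hypothetical usage sketch (only "max_num_sequence" is confirmed by this hunk; the other keys and values are illustrative assumptions):

  KVCacheConfig kv_cache_config(
      R"({"max_num_sequence": 4, "max_total_sequence_length": 16384,
          "prefill_chunk_size": 4096, "page_size": 16})",
      /*max_single_sequence_length=*/4096);
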
28 changes: 16 additions & 12 deletions cpp/serve/engine.cc
@@ -51,7 +51,10 @@ class EngineImpl : public Engine {
CHECK_GE(model_infos.size(), 1) << "ValueError: No model is provided in the engine.";
// Step 1. Initialize metadata and singleton states inside the engine
this->estate_->Reset();
this->max_single_sequence_length_ = max_single_sequence_length;
// Being "-1" means there is no limit on single sequence length.
this->max_single_sequence_length_ = max_single_sequence_length != -1
? max_single_sequence_length
: std::numeric_limits<int>::max();
this->kv_cache_config_ = KVCacheConfig(kv_cache_config_json_str, max_single_sequence_length);
this->engine_mode_ = EngineMode(engine_mode_json_str);
this->request_stream_callback_ = std::move(request_stream_callback);
@@ -140,6 +143,17 @@ class EngineImpl : public Engine {
// Get a request copy where all text inputs are tokenized.
request = Request::FromUntokenized(request, tokenizer_);
ICHECK_NE(request->input_total_length, -1);

if (request->input_total_length >= kv_cache_config_->prefill_chunk_size) {
// If the request input length exceeds the prefill chunk size,
// invoke callback and do not process the request.
// Todo(mlc-team): Use "maximum single sequence length" after impl input chunking.
Array<RequestStreamOutput> output{RequestStreamOutput(
request->id, {}, Optional<Array<Array<String>>>(), {String("length")})};
request_stream_callback_.value()(std::move(output));
return;
}

// Append to the waiting queue and create the request state.
estate_->waiting_queue.push_back(request);

@@ -189,21 +203,11 @@
// The request to abort is in running queue
estate_->running_queue.erase(it_running);

// Reduce the input length.
estate_->stats.current_total_seq_len -= request->input_total_length;
// Reduce the generated length.
for (int i = 0; i < static_cast<int>(rstate->entries.size()); ++i) {
for (int i = static_cast<int>(rstate->entries.size()) - 1; i >= 0; --i) {
if (rstate->entries[i]->status != RequestStateStatus::kAlive) {
continue;
}
estate_->stats.current_total_seq_len -=
rstate->entries[i]->mstates[0]->committed_tokens.size();
RemoveRequestFromModel(estate_, rstate->entries[i]->mstates[0]->internal_id, models_);
if (rstate->entries[i]->child_indices.empty()) {
// For each running leaf state, length 1 is over reduced since the last
// token is not added into KV cache. So we add the length back.
++estate_->stats.current_total_seq_len;
}
}
}
if (it_waiting != estate_->waiting_queue.end()) {
19 changes: 0 additions & 19 deletions cpp/serve/engine_actions/action_commons.cc
@@ -28,8 +28,6 @@ void ProcessFinishedRequestStateEntries(std::vector<RequestStateEntry> finished_
// Remove the request state entry from all the models.
RemoveRequestFromModel(estate, rsentry->mstates[0]->internal_id, models);
estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id);
estate->stats.current_total_seq_len -=
static_cast<int>(rsentry->mstates[0]->committed_tokens.size()) - 1;

RequestState rstate = estate->GetRequestState(rsentry->request);
int parent_idx = rsentry->parent_idx;
@@ -51,16 +49,11 @@ void ProcessFinishedRequestStateEntries(std::vector<RequestStateEntry> finished_
// Remove the request state entry from all the models.
RemoveRequestFromModel(estate, rstate->entries[parent_idx]->mstates[0]->internal_id, models);
estate->id_manager.RecycleId(rstate->entries[parent_idx]->mstates[0]->internal_id);
estate->stats.current_total_seq_len -=
static_cast<int>(rstate->entries[parent_idx]->mstates[0]->committed_tokens.size());
// Climb up to the parent.
parent_idx = rstate->entries[parent_idx]->parent_idx;
}

if (parent_idx == -1) {
// All request state entries of the request have been removed.
// Reduce the total input length from the engine stats.
estate->stats.current_total_seq_len -= rsentry->request->input_total_length;
// Remove from running queue and engine state.
auto it =
std::find(estate->running_queue.begin(), estate->running_queue.end(), rsentry->request);
@@ -163,18 +156,6 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate,
// - Update `inputs` for future prefill.
RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt");
rsentry->status = RequestStateStatus::kPending;
estate->stats.current_total_seq_len -= rsentry->mstates[0]->committed_tokens.size();
if (rsentry->child_indices.empty()) {
// The length was overly decreased by 1 when the entry has no child.
++estate->stats.current_total_seq_len;
}
if (rsentry->parent_idx == -1) {
// Subtract the input length from the total length when the
// current entry is the root entry of the request.
estate->stats.current_total_seq_len -= request->input_total_length;
}
estate->stats.current_total_seq_len -=
request->input_total_length + rsentry->mstates[0]->committed_tokens.size() - 1;
for (RequestModelState mstate : rsentry->mstates) {
mstate->RemoveAllDraftTokens();
ICHECK(mstate->inputs.empty());
1 change: 0 additions & 1 deletion cpp/serve/engine_actions/batch_decode.cc
@@ -53,7 +53,6 @@ class BatchDecodeActionObj : public EngineActionObj {

// NOTE: Right now we only support decode all the running request states at a time.
int num_rsentries = running_rsentries.size();
estate->stats.current_total_seq_len += num_rsentries;
// Collect
// - the last committed token,
// - the request id,
1 change: 0 additions & 1 deletion cpp/serve/engine_actions/batch_verify.cc
@@ -127,7 +127,6 @@ class BatchVerifyActionObj : public EngineActionObj {
rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result);
rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result);
}
estate->stats.current_total_seq_len += accept_length;
estate->stats.total_accepted_length += accept_length;
// - Minus one because the last draft token has no kv cache entry
// - Take max with 0 in case of all accepted.