feat - support mla kvcache store #888
base: main
Conversation
Hi @baowendin, thanks for the contribution, the kernel looks good to me overall.
Would you mind trying Triton instead? I expect we can get similar performance to CUDA. In the future, we hope all elementwise / data-movement kernels can be written in Triton, to save maintenance overhead.
(force-pushed from db20a81 to 61fb997, then from 61fb997 to f7bd89f)
Hi, I have formatted the code with pre-commit, but since I'm not familiar with Triton, I can't rewrite it in Triton this time. Maybe next time?
It would be great to add a benchmark like: https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_append_paged_kv_cache.py
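For an append kernel like this, the number such a benchmark reports is effective bandwidth. As a rough sketch (the helper below is hypothetical; only the dimensions ckv_dim=512 and kpe_dim=64 come from this PR, everything else is an illustrative assumption, not flashinfer's API):

```python
# Hypothetical helper: estimate effective bandwidth of an MLA KV-cache
# append, in the spirit of bench_append_paged_kv_cache.py. The dims
# (ckv_dim=512, kpe_dim=64) come from this PR; the function itself is
# an illustrative sketch, not part of flashinfer.

def append_bandwidth_gbps(nnz: int, elapsed_ms: float,
                          ckv_dim: int = 512, kpe_dim: int = 64,
                          dtype_bytes: int = 2) -> float:
    """GB/s for appending nnz tokens of compressed-KV + positional keys.

    Each token is read once from the input and written once to the
    paged cache, hence the factor of 2.
    """
    bytes_moved = 2 * nnz * (ckv_dim + kpe_dim) * dtype_bytes
    return bytes_moved / (elapsed_ms * 1e-3) / 1e9

# e.g. 8192 tokens appended in 0.05 ms
print(round(append_bandwidth_gbps(8192, 0.05), 1))  # → 377.5
```

Comparing this figure against the device's peak memory bandwidth shows how close the kernel is to being purely memory-bound.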
```diff
@@ -16,6 +16,7 @@
 #ifndef FLASHINFER_PAGE_CUH_
 #define FLASHINFER_PAGE_CUH_

+#include <assert.h>
```
Would you mind changing them to FLASHINFER_CHECK? (defined in flashinfer/include/flashinfer/exception.h, line 41 in 341ae09: `#define FLASHINFER_CHECK(condition, message) \`)
`assert` would only work when you compile the program in debug mode, not release mode.
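The same pitfall exists in Python: `assert` statements are stripped under `python -O`, just as C's `assert()` is compiled out when NDEBUG is defined in release builds. A small sketch of the difference between `assert` and an always-on check (the `check` helper below is a hypothetical stand-in for what a FLASHINFER_CHECK-style macro does in C++):

```python
# Sketch of why an explicit check beats assert: Python's `assert`
# vanishes under `python -O`, like C's assert() under NDEBUG.
# `check` is a hypothetical analogue of FLASHINFER_CHECK: it runs
# regardless of optimization level.

def check(condition: bool, message: str) -> None:
    """Raise on failure even when assertions are disabled."""
    if not condition:
        raise RuntimeError(message)

def validate_append_len(nnz: int) -> bool:
    # assert nnz > 0                       # silently skipped under -O
    check(nnz > 0, "nnz must be positive")  # always enforced
    return True

print(validate_append_len(8))   # → True
try:
    validate_append_len(0)
except RuntimeError as e:
    print(e)                    # → nnz must be positive
```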
```python
import flashinfer


def test_append_mla_paged_kv_cache():
```
Can you design some stronger test cases? I apologize that the previous test_page.py is very weak (most of the tests around the append-page APIs are performed in C++ unit tests, and in the future we should move them to Python).
By "stronger" I mean more cases (nnz / append length), data types, page_size, etc.
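The sweep the reviewer describes could be enumerated like this (a sketch; the parameter values are illustrative choices, and the actual flashinfer append call is deliberately left out):

```python
import itertools

# Illustrative parameter grid for stronger append-kv-cache tests:
# sweep append lengths (nnz), data types, and page sizes, as suggested.
NNZS = [1, 7, 128, 4096]
DTYPES = ["float16", "bfloat16"]
PAGE_SIZES = [1, 16, 64]

def append_test_cases():
    """Return every (nnz, dtype, page_size) combination to exercise."""
    return list(itertools.product(NNZS, DTYPES, PAGE_SIZES))

print(len(append_test_cases()))  # → 24 combinations (4 * 2 * 3)
```

In a pytest suite the same grid is usually expressed with stacked `@pytest.mark.parametrize` decorators, which report each combination as its own test case.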
Okay, I'll fix these problems and add more tests later this week.
Summary
Related to #877.
This PR implements the MLA KV-cache store. It passes a correctness test for ckv_dim=512 and kpe_dim=64, but no performance testing has been done yet. I sincerely hope somebody familiar with CUDA can help improve performance.