
feat - support mla kvcache store #888

Open · wants to merge 2 commits into main from feature/support_mla_store_kvcache
Conversation

baowendin (Contributor)

Summary

Related to #877.
This PR implements the MLA cache store and passes a correctness test for ckv_dim=512 and kpe_dim=64, but no performance testing has been done yet. I sincerely hope someone familiar with CUDA can help improve performance.
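To make the layout concrete, here is a minimal, hedged sketch (plain NumPy, not flashinfer's actual API) of what an MLA paged-KV append does: each token contributes a compressed-KV row (ckv_dim=512) and a rotary-key row (kpe_dim=64), written into the page and slot that the sequence's page table assigns to that position. All names and shapes here are illustrative assumptions.

```python
# Hedged sketch of an MLA paged-KV append; names/shapes are illustrative,
# not flashinfer's real API.
import numpy as np

CKV_DIM, KPE_DIM, PAGE_SIZE = 512, 64, 16

def append_mla_paged(ckv_cache, kpe_cache, page_table, seq_len, ckv, kpe):
    """Append new tokens' ckv/kpe rows into the pages owned by one sequence.

    ckv_cache: [num_pages, PAGE_SIZE, CKV_DIM]
    kpe_cache: [num_pages, PAGE_SIZE, KPE_DIM]
    page_table: page ids owned by this sequence, in logical order
    seq_len:   number of tokens already stored for this sequence
    ckv: [n_new, CKV_DIM], kpe: [n_new, KPE_DIM]
    """
    for i in range(ckv.shape[0]):
        pos = seq_len + i
        page = page_table[pos // PAGE_SIZE]  # which physical page
        slot = pos % PAGE_SIZE               # which row within the page
        ckv_cache[page, slot] = ckv[i]
        kpe_cache[page, slot] = kpe[i]
    return seq_len + ckv.shape[0]

# Usage: 20 new tokens into a 4-page cache with a scrambled page table.
num_pages = 4
ckv_cache = np.zeros((num_pages, PAGE_SIZE, CKV_DIM), dtype=np.float16)
kpe_cache = np.zeros((num_pages, PAGE_SIZE, KPE_DIM), dtype=np.float16)
ckv = np.random.randn(20, CKV_DIM).astype(np.float16)
kpe = np.random.randn(20, KPE_DIM).astype(np.float16)
new_len = append_mla_paged(ckv_cache, kpe_cache, [2, 0, 3, 1], 0, ckv, kpe)
assert new_len == 20
# Token 17 lands in the second table entry (page id 0), slot 1.
assert np.array_equal(ckv_cache[0, 1], ckv[17])
```

The real CUDA kernel parallelizes exactly this per-token, per-element copy; the loop above is only a reference for the indexing arithmetic.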

@baowendin changed the title from "feat - support mla kvache store" to "feat - support mla kvcache store" on Feb 23, 2025
@yzh119 (Collaborator) left a comment:

Hi @baowendin, thanks for the contribution; the kernel looks good to me overall.

Would you mind trying Triton instead? I expect we can get performance similar to CUDA. In the future, we hope all elementwise / data-movement kernels can be written in Triton, to save maintenance overhead.

@baowendin force-pushed the feature/support_mla_store_kvcache branch from db20a81 to 61fb997 on February 24, 2025 01:55
@baowendin force-pushed the feature/support_mla_store_kvcache branch from 61fb997 to f7bd89f on February 24, 2025 02:00
@baowendin (Contributor, Author) commented Feb 24, 2025:

Hi, I have formatted the code with pre-commit, but since I'm not familiar with Triton, I can't rewrite the kernel in Triton this time; maybe in a follow-up?

@yzh119 (Collaborator) left a review comment on the diff:

@@ -16,6 +16,7 @@
#ifndef FLASHINFER_PAGE_CUH_
#define FLASHINFER_PAGE_CUH_

#include <assert.h>
@yzh119 (Collaborator):

Would you mind changing them to FLASHINFER_CHECK? (defined in the repo as #define FLASHINFER_CHECK(condition, message) ...)

assert only fires when the program is compiled in debug mode; in release builds (where NDEBUG is defined) it compiles to a no-op, so the check silently disappears.

import flashinfer


def test_append_mla_paged_kv_cache():
@yzh119 (Collaborator):

Can you design some stronger test cases? I apologize that the previous test_page.py is very weak (most of the tests around the append-page APIs are performed in C++ unit tests, and in the future we should move them to Python).

By "stronger" I mean more cases: nnz (append length), data types, page_size, etc.

@baowendin (Contributor, Author):

Okay, I'll fix these problems and add more tests later this week.
