diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu
index b3133593c..3e22e8122 100644
--- a/csrc/flashinfer_ops.cu
+++ b/csrc/flashinfer_ops.cu
@@ -15,23 +15,6 @@
  */
 #include "aot_extension_utils.h"
 
-//========== activation ==========
-
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-
-//========== cascade ==========
-
-void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b,
-                 at::Tensor v_merged, at::Tensor s_merged, int64_t cuda_stream);
-
-void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other,
-                          std::optional<at::Tensor> mask, int64_t cuda_stream);
-
-void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged,
-                  int64_t cuda_stream);
-
 //========== decode ==========
 
 void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
@@ -55,45 +38,6 @@ void BatchDecodeWithPagedKVCacheRun(
     unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
     float rope_scale, float rope_theta, std::optional<at::Tensor> maybe_lse, int64_t cuda_stream);
 
-//========== gemm ==========
-
-void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
-             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
-
-void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
-                        at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
-                        at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
-                        int64_t cuda_stream);
-
-//========== norm ==========
-
-void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps,
-             int64_t cuda_stream);
-
-void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
-                       int64_t cuda_stream);
-
-void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps,
-                   int64_t cuda_stream);
-
-void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
-                             double eps, int64_t cuda_stream);
-
-//========== page ==========
-
-void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices,
-                           at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache,
-                           at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len,
-                           unsigned int layout, int64_t cuda_stream);
-
-void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices,
-                                                   at::Tensor block_sparse_indptr,
-                                                   at::Tensor vector_sparse_offsets,
-                                                   at::Tensor vector_sparse_indptr,
-                                                   at::Tensor kv_len_arr, unsigned int stride_block,
-                                                   unsigned int stride_n, unsigned int batch_size,
-                                                   unsigned int block_size, int64_t cuda_stream);
-
 //========== prefill ==========
 
 void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at::Tensor k,
@@ -128,148 +72,19 @@ void BatchPrefillWithPagedKVCacheRun(
     int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
     std::optional<at::Tensor> maybe_lse, int64_t cuda_stream);
 
-//========== quantization ==========
-
-void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y, int64_t cuda_stream);
-
-void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr,
-                      const std::string& bitorder, at::Tensor y, int64_t cuda_stream);
-
-//========== rope ==========
-
-void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr,
-                at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale,
-                float rope_theta, int64_t cuda_stream);
-
-void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
-                        at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim,
-                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
-                        float high_freq_factor, float old_context_length, int64_t cuda_stream);
-
-void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
-                        at::Tensor pos_ids, unsigned int rotary_dim, bool interleave,
-                        float rope_scale, float rope_theta, int64_t cuda_stream);
-
-void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
-                                at::Tensor pos_ids, unsigned int rotary_dim, bool interleave,
-                                float rope_scale, float rope_theta, float low_freq_factor,
-                                float high_freq_factor, float old_context_length,
-                                int64_t cuda_stream);
-
-void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope,
-                                      at::Tensor k_rope, at::Tensor cos_cache, at::Tensor sin_cache,
-                                      at::Tensor pos_ids, bool interleave, int64_t cuda_stream);
-
-//========== sampling ==========
-
-void sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
-                         bool deterministic, int64_t cuda_stream);
-
-void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
-                               at::Tensor success, std::optional<at::Tensor> maybe_top_p_arr,
-                               double top_p_val, bool deterministic, int64_t cuda_stream);
-
-void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
-                               at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr,
-                               unsigned int top_k_val, bool deterministic, int64_t cuda_stream);
-
-void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
-                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
-                               bool deterministic, int64_t cuda_stream);
-
-void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples,
-                                     at::Tensor samples, at::Tensor success,
-                                     std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
-                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
-                                     bool deterministic, int64_t cuda_stream);
-
-void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
-                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
-                        int64_t cuda_stream);
-
-void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
-                        std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
-                        int64_t cuda_stream);
-
-void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
-                       std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
-                       int64_t cuda_stream);
-
-void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
-                                at::Tensor uniform_samples, at::Tensor target_probs,
-                                at::Tensor output_token_ids, at::Tensor output_accepted_token_num,
-                                at::Tensor output_emitted_token_num, bool deterministic,
-                                int64_t cuda_stream);
-
 //========== pybind11 ==========
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  // activation
-  m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
-  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
-  m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
-
-  // cascade
-  m.def("merge_state", &merge_state, "Merge two self-attention states");
-  m.def("merge_state_in_place", &merge_state_in_place,
-        "Merge another self-attention state in-place.");
-  m.def("merge_states", &merge_states, "Merge multiple self-attention states");
-
   // decode
   m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
         "Single-request decode with KV-Cache operator");
   m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan);
   m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun);
 
-  // gemm
-  m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
-  m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator");
-
-  // norm
-  m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
-  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
-  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
-  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm,
-        "Gemma Fused add root mean square normalization");
-
-  // page
-  m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator");
-  m.def("block_sparse_indices_to_vector_sparse_offsets",
-        &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets");
-
   // prefill
   m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
         "Single-request prefill attention with KV-Cache operator");
   m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan);
   m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun);
   m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun);
-
-  // quantization
-  m.def("packbits", &packbits, "GPU packbits operator");
-  m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
-
-  // rope
-  m.def("apply_rope", &apply_rope, "Apply RoPE");
-  m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
-  m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
-  m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
-        "Apply Llama 3.1 style RoPE with positional ids");
-  m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
-        "Apply RoPE with positional ids and cosine/sine cache");
-
-  // sampling
-  m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
-  m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs,
-        "Top-k sampling from probabilities");
-  m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs,
-        "Min-p sampling from probabilities");
-  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
-        "Top-p sampling from probabilities");
-  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,
-        "Top-k and top-p sampling from probabilities");
-  m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask");
-  m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask");
-  m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask");
-  m.def("chain_speculative_sampling", &chain_speculative_sampling,
-        "Speculative sampling from sequence of probabilities");
 }
diff --git a/csrc/flashinfer_ops_aux.cu b/csrc/flashinfer_ops_aux.cu
new file mode 100644
index 000000000..ae16d4128
--- /dev/null
+++ b/csrc/flashinfer_ops_aux.cu
@@ -0,0 +1,205 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "aot_extension_utils.h"
+
+//========== activation ==========
+
+void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
+//========== cascade ==========
+
+void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b,
+                 at::Tensor v_merged, at::Tensor s_merged, int64_t cuda_stream);
+
+void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other,
+                          std::optional<at::Tensor> mask, int64_t cuda_stream);
+
+void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged,
+                  int64_t cuda_stream);
+
+//========== gemm ==========
+
+void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
+             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
+
+void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
+                        at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
+                        at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
+                        int64_t cuda_stream);
+
+//========== norm ==========
+
+void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps,
+             int64_t cuda_stream);
+
+void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
+                       int64_t cuda_stream);
+
+void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps,
+                   int64_t cuda_stream);
+
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
+                             double eps, int64_t cuda_stream);
+
+//========== page ==========
+
+void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices,
+                           at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache,
+                           at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len,
+                           unsigned int layout, int64_t cuda_stream);
+
+void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices,
+                                                   at::Tensor block_sparse_indptr,
+                                                   at::Tensor vector_sparse_offsets,
+                                                   at::Tensor vector_sparse_indptr,
+                                                   at::Tensor kv_len_arr, unsigned int stride_block,
+                                                   unsigned int stride_n, unsigned int batch_size,
+                                                   unsigned int block_size, int64_t cuda_stream);
+
+//========== quantization ==========
+
+void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y, int64_t cuda_stream);
+
+void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr,
+                      const std::string& bitorder, at::Tensor y, int64_t cuda_stream);
+
+//========== rope ==========
+
+void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr,
+                at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale,
+                float rope_theta, int64_t cuda_stream);
+
+void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
+                        at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length, int64_t cuda_stream);
+
+void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
+                        at::Tensor pos_ids, unsigned int rotary_dim, bool interleave,
+                        float rope_scale, float rope_theta, int64_t cuda_stream);
+
+void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
+                                at::Tensor pos_ids, unsigned int rotary_dim, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length,
+                                int64_t cuda_stream);
+
+void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope,
+                                      at::Tensor k_rope, at::Tensor cos_cache, at::Tensor sin_cache,
+                                      at::Tensor pos_ids, bool interleave, int64_t cuda_stream);
+
+//========== sampling ==========
+
+void sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+                         bool deterministic, int64_t cuda_stream);
+
+void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+                               at::Tensor success, std::optional<at::Tensor> maybe_top_p_arr,
+                               double top_p_val, bool deterministic, int64_t cuda_stream);
+
+void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+                               at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr,
+                               unsigned int top_k_val, bool deterministic, int64_t cuda_stream);
+
+void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
+                               bool deterministic, int64_t cuda_stream);
+
+void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples,
+                                     at::Tensor samples, at::Tensor success,
+                                     std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
+                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                                     bool deterministic, int64_t cuda_stream);
+
+void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                        int64_t cuda_stream);
+
+void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
+                        int64_t cuda_stream);
+
+void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
+                       std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
+                       int64_t cuda_stream);
+
+void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
+                                at::Tensor uniform_samples, at::Tensor target_probs,
+                                at::Tensor output_token_ids, at::Tensor output_accepted_token_num,
+                                at::Tensor output_emitted_token_num, bool deterministic,
+                                int64_t cuda_stream);
+
+//========== pybind11 ==========
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // activation
+  m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
+  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
+  m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
+
+  // cascade
+  m.def("merge_state", &merge_state, "Merge two self-attention states");
+  m.def("merge_state_in_place", &merge_state_in_place,
+        "Merge another self-attention state in-place.");
+  m.def("merge_states", &merge_states, "Merge multiple self-attention states");
+
+  // gemm
+  m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
+  m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator");
+
+  // norm
+  m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
+  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
+  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
+  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm,
+        "Gemma Fused add root mean square normalization");
+
+  // page
+  m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator");
+  m.def("block_sparse_indices_to_vector_sparse_offsets",
+        &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets");
+
+  // quantization
+  m.def("packbits", &packbits, "GPU packbits operator");
+  m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
+
+  // rope
+  m.def("apply_rope", &apply_rope, "Apply RoPE");
+  m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
+  m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
+  m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
+        "Apply Llama 3.1 style RoPE with positional ids");
+  m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
+        "Apply RoPE with positional ids and cosine/sine cache");
+
+  // sampling
+  m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
+  m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs,
+        "Top-k sampling from probabilities");
+  m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs,
+        "Min-p sampling from probabilities");
+  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
+        "Top-p sampling from probabilities");
+  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,
+        "Top-k and top-p sampling from probabilities");
+  m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask");
+  m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask");
+  m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask");
+  m.def("chain_speculative_sampling", &chain_speculative_sampling,
+        "Speculative sampling from sequence of probabilities");
+}
diff --git a/flashinfer/activation.py b/flashinfer/activation.py
index 5c78918fa..4d38624e7 100644
--- a/flashinfer/activation.py
+++ b/flashinfer/activation.py
@@ -18,7 +18,7 @@
 
 import torch
 
-from .jit import gen_act_and_mul_module, has_prebuilt_ops, load_cuda_ops
+from .jit import gen_act_and_mul_module, has_prebuilt_kernels_aux, load_cuda_ops
 from .utils import get_cuda_stream, register_custom_op, register_fake_op
 
 silu_def_cu_str = r"""
@@ -55,10 +55,10 @@
 def get_act_and_mul_module(act_func_name: str):
     global _jit_modules
     if act_func_name not in _jit_modules:
-        if has_prebuilt_ops:
-            from . import _kernels  # type: ignore[attr-defined]
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux  # type: ignore[attr-defined]
 
-            module = _kernels
+            module = _kernels_aux
         else:
             module = gen_act_and_mul_module(
                 act_func_name, act_func_def_str[act_func_name]
diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py
index 20b912599..162920822 100644
--- a/flashinfer/cascade.py
+++ b/flashinfer/cascade.py
@@ -19,7 +19,7 @@
 import torch
 
 from .decode import BatchDecodeWithPagedKVCacheWrapper
-from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_kernels_aux, load_cuda_ops
 from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache
 from .utils import get_cuda_stream, register_custom_op, register_fake_op
 
@@ -29,10 +29,10 @@
 def get_cascade_module():
     global _cascade_module
     if _cascade_module is None:
-        if has_prebuilt_ops:
-            from . import _kernels
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux
 
-            _cascade_module = _kernels
+            _cascade_module = _kernels_aux
         else:
             _cascade_module = load_cuda_ops(
                 "cascade",
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index e073b4be9..7400b8d06 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -28,7 +28,7 @@
     get_batch_decode_mla_uri,
     get_batch_decode_uri,
     get_single_decode_uri,
-    has_prebuilt_ops,
+    has_prebuilt_kernels,
     prebuilt_ops_uri,
 )
 from .prefill import get_batch_prefill_module, get_single_prefill_module
@@ -58,7 +58,7 @@ def get_single_decode_module(*args):
     global _single_decode_modules
     if args not in _single_decode_modules:
         uri = get_single_decode_uri(*args)
-        if has_prebuilt_ops and uri in prebuilt_ops_uri:
+        if has_prebuilt_kernels and uri in prebuilt_ops_uri:
             from . import _kernels
 
             run_func = _kernels.single_decode_with_kv_cache
@@ -126,7 +126,7 @@ def get_batch_decode_module(*args):
     global _batch_decode_modules
    if args not in _batch_decode_modules:
         uri = get_batch_decode_uri(*args)
-        if has_prebuilt_ops and uri in prebuilt_ops_uri:
+        if has_prebuilt_kernels and uri in prebuilt_ops_uri:
             from . import _kernels
 
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py
index 2069ccd49..09e4aaef1 100644
--- a/flashinfer/gemm.py
+++ b/flashinfer/gemm.py
@@ -21,7 +21,7 @@
 import triton
 import triton.language as tl
 
-from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_kernels_aux, has_prebuilt_kernels_sm90, load_cuda_ops
 from .utils import (
     _get_cache_buf,
     determine_gemm_backend,
@@ -38,10 +38,10 @@
 def get_gemm_module():
     global _gemm_module
     if _gemm_module is None:
-        if has_prebuilt_ops:
-            from . import _kernels
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux
 
-            module = _kernels
+            module = _kernels_aux
         else:
             module = load_cuda_ops(
                 "gemm",
@@ -148,7 +148,7 @@ def _fake_cutlass_segment_gemm(
 def get_gemm_sm90_module():
     global _gemm_module_sm90
     if _gemm_module_sm90 is None:
-        if has_prebuilt_ops:
+        if has_prebuilt_kernels_sm90:
             from . import _kernels_sm90
 
             module = _kernels_sm90
diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py
index 6c418366a..ad1b4ad4b 100644
--- a/flashinfer/jit/__init__.py
+++ b/flashinfer/jit/__init__.py
@@ -43,8 +43,23 @@
 from .utils import parallel_load_modules as parallel_load_modules
 
 try:
-    from .. import _kernels, _kernels_sm90
+    from .. import _kernels
 
-    has_prebuilt_ops = True
-except ImportError:
-    has_prebuilt_ops = False
+    has_prebuilt_kernels = True
+except ImportError as exn:
+    print(type(exn))
+    has_prebuilt_kernels = False
+
+try:
+    from .. import _kernels_sm90
+
+    has_prebuilt_kernels_sm90 = True
+except ImportError as exn:
+    has_prebuilt_kernels_sm90 = False
+
+try:
+    from .. import _kernels_aux
+
+    has_prebuilt_kernels_aux = True
+except ImportError as exn:
+    has_prebuilt_kernels_aux = False
diff --git a/flashinfer/norm.py b/flashinfer/norm.py
index 1919296fb..6438572aa 100644
--- a/flashinfer/norm.py
+++ b/flashinfer/norm.py
@@ -18,7 +18,7 @@
 
 import torch
 
-from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_kernels_aux, load_cuda_ops
 from .utils import get_cuda_stream, register_custom_op, register_fake_op
 
 _norm_module = None
@@ -27,10 +27,10 @@
 def get_norm_module():
     global _norm_module
     if _norm_module is None:
-        if has_prebuilt_ops:
-            from . import _kernels
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux
 
-            _norm_module = _kernels
+            _norm_module = _kernels_aux
         else:
             _norm_module = load_cuda_ops(
                 "norm",
diff --git a/flashinfer/page.py b/flashinfer/page.py
index b0f80903d..adf96b9a5 100644
--- a/flashinfer/page.py
+++ b/flashinfer/page.py
@@ -20,7 +20,7 @@
 import triton
 import triton.language as tl
 
-from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_kernels_aux, load_cuda_ops
 from .utils import (
     TensorLayout,
     _check_kv_layout,
@@ -36,10 +36,10 @@
 def get_page_module():
     global _page_module
     if _page_module is None:
-        if has_prebuilt_ops:
-            from . import _kernels
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux
 
-            _page_module = _kernels
+            _page_module = _kernels_aux
         else:
             _page_module = load_cuda_ops(
                 "page",
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index b997ab0e3..784d85d3b 100644
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -31,7 +31,8 @@
     get_batch_prefill_uri,
     get_single_prefill_sm90_uri,
     get_single_prefill_uri,
-    has_prebuilt_ops,
+    has_prebuilt_kernels,
+    has_prebuilt_kernels_sm90,
     load_cuda_ops,
     prebuilt_ops_uri,
 )
@@ -66,7 +67,7 @@ def get_single_prefill_sm90_module(*args):
     global _single_prefill_sm90_modules
     if args not in _single_prefill_sm90_modules:
         uri = get_single_prefill_sm90_uri(*args)
-        if has_prebuilt_ops and uri in prebuilt_ops_uri:
+        if has_prebuilt_kernels and uri in prebuilt_ops_uri:
             from . import _kernels_sm90
 
             run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90
@@ -143,7 +144,7 @@ def get_single_prefill_module(*args):
     global _single_prefill_modules
     if args not in _single_prefill_modules:
         uri = get_single_prefill_uri(*args)
-        if has_prebuilt_ops and uri in prebuilt_ops_uri:
+        if has_prebuilt_kernels and uri in prebuilt_ops_uri:
             from . import _kernels
 
             run_func = _kernels.single_prefill_with_kv_cache
@@ -221,7 +222,7 @@ def get_batch_prefill_sm90_module(*args):
     if args not in _batch_prefill_sm90_modules:
         uri = get_batch_prefill_sm90_uri(*args)
 
-        if has_prebuilt_ops and uri in prebuilt_ops_uri:
+        if has_prebuilt_kernels and uri in prebuilt_ops_uri:
             from . import _kernels_sm90
 
             head_dim = args[4]
@@ -427,7 +428,7 @@ def get_batch_prefill_module(*args):
     global _batch_prefill_modules
     if args not in _batch_prefill_modules:
         uri = get_batch_prefill_uri(*args)
-        if has_prebuilt_ops and uri in prebuilt_ops_uri:
+        if has_prebuilt_kernels and uri in prebuilt_ops_uri:
             from . import _kernels
 
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
diff --git a/flashinfer/quantization.py b/flashinfer/quantization.py
index f5c00340b..91b7ab44f 100644
--- a/flashinfer/quantization.py
+++ b/flashinfer/quantization.py
@@ -18,7 +18,7 @@
 
 import torch
 
-from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_kernels_aux, load_cuda_ops
 from .utils import get_cuda_stream, register_custom_op, register_fake_op
 
 _quantization_module = None
@@ -27,10 +27,10 @@
 def get_quantization_module():
     global _quantization_module
     if _quantization_module is None:
-        if has_prebuilt_ops:
-            from . import _kernels
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux
 
-            _quantization_module = _kernels
+            _quantization_module = _kernels_aux
         else:
             _quantization_module = load_cuda_ops(
                 "quantization",
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index 8c781b29f..e00e8ed8b 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -18,7 +18,7 @@
 
 import torch
 
-from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_kernels_aux, load_cuda_ops
 from .utils import get_cuda_stream, register_custom_op, register_fake_op
 
 _rope_module = None
@@ -27,10 +27,10 @@
 def get_rope_module():
     global _rope_module
     if _rope_module is None:
-        if has_prebuilt_ops:
-            from . import _kernels
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux
 
-            _rope_module = _kernels
+            _rope_module = _kernels_aux
         else:
             _rope_module = load_cuda_ops(
                 "rope",
diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py
index 344283dd8..f22598ffd 100644
--- a/flashinfer/sampling.py
+++ b/flashinfer/sampling.py
@@ -19,7 +19,7 @@
 
 import torch
 
-from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_kernels_aux, load_cuda_ops
 from .utils import get_cuda_stream, register_custom_op, register_fake_op
 
 _sampling_module = None
@@ -28,10 +28,10 @@
 def get_sampling_module():
     global _sampling_module
     if _sampling_module is None:
-        if has_prebuilt_ops:
-            from . import _kernels
+        if has_prebuilt_kernels_aux:
+            from . import _kernels_aux
 
-            module = _kernels
+            module = _kernels_aux
         else:
             module = load_cuda_ops(
                 "sampling",
diff --git a/setup.py b/setup.py
index 33d2d4e99..9d2ffc1c0 100644
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@
 enable_fp8_e4m3 = os.environ.get("FLASHINFER_ENABLE_FP8_E4M3", "1" if enable_fp8 else "0") == "1"
 enable_fp8_e5m2 = os.environ.get("FLASHINFER_ENABLE_FP8_E5M2", "1" if enable_fp8 else "0") == "1"
 enable_sm90 = os.environ.get("FLASHINFER_ENABLE_SM90", "1") == "1"
+enable_aux = os.environ.get("FLASHINFER_ENABLE_AUX", "1") == "1"
 
 
 def get_version():
@@ -191,6 +192,13 @@ def __init__(self, *args, **kwargs) -> None:
         ]
         sm90a_flags = "-gencode arch=compute_90a,code=sm_90a".split()
         kernel_sources = [
+            "csrc/batch_decode.cu",
+            "csrc/batch_prefill.cu",
+            "csrc/single_decode.cu",
+            "csrc/single_prefill.cu",
+            "csrc/flashinfer_ops.cu",
+        ]
+        kernel_aux_sources = [
             "csrc/bmm_fp8.cu",
             "csrc/cascade.cu",
             "csrc/group_gemm.cu",
@@ -201,11 +209,7 @@ def __init__(self, *args, **kwargs) -> None:
             "csrc/sampling.cu",
             "csrc/renorm.cu",
             "csrc/activation.cu",
-            "csrc/batch_decode.cu",
-            "csrc/batch_prefill.cu",
-            "csrc/single_decode.cu",
-            "csrc/single_prefill.cu",
-            "csrc/flashinfer_ops.cu",
+            "csrc/flashinfer_ops_aux.cu",
         ]
         kernel_sm90_sources = [
             "csrc/group_gemm_sm90.cu",
@@ -241,6 +245,18 @@ def __init__(self, *args, **kwargs) -> None:
                 },
             ),
         ]
+        if enable_aux:
+            ext_modules += [
+                torch_cpp_ext.CUDAExtension(
+                    name="flashinfer._kernels_aux",
+                    sources=kernel_aux_sources,
+                    include_dirs=include_dirs,
+                    extra_compile_args={
+                        "cxx": cxx_flags,
+                        "nvcc": nvcc_flags,
+                    },
+                ),
+            ]
 
 setuptools.setup(
     version=get_version(),
diff --git a/tests/test_alibi.py b/tests/test_alibi.py
index 0042c073d..7a8837ad4 100644
--- a/tests/test_alibi.py
+++ b/tests/test_alibi.py
@@ -25,7 +25,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py
index a23e4d21f..3c3432b94 100644
--- a/tests/test_batch_decode_kernels.py
+++ b/tests/test_batch_decode_kernels.py
@@ -23,7 +23,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py
index 81bdeeaae..02820221f 100644
--- a/tests/test_batch_prefill_kernels.py
+++ b/tests/test_batch_prefill_kernels.py
@@ -23,7 +23,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_block_sparse.py b/tests/test_block_sparse.py
index da37593b6..d9b79d59b 100644
--- a/tests/test_block_sparse.py
+++ b/tests/test_block_sparse.py
@@ -25,7 +25,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_logits_cap.py b/tests/test_logits_cap.py
index 5ee1df077..da707de1d 100644
--- a/tests/test_logits_cap.py
+++ b/tests/test_logits_cap.py
@@ -25,7 +25,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_non_contiguous_decode.py b/tests/test_non_contiguous_decode.py
index 8d1321b58..e0b4d8ba3 100644
--- a/tests/test_non_contiguous_decode.py
+++ b/tests/test_non_contiguous_decode.py
@@ -7,7 +7,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_non_contiguous_prefill.py b/tests/test_non_contiguous_prefill.py
index 5f7594941..86a296d66 100644
--- a/tests/test_non_contiguous_prefill.py
+++ b/tests/test_non_contiguous_prefill.py
@@ -23,7 +23,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py
index 097a3c86c..f3a6a5a4e 100644
--- a/tests/test_shared_prefix_kernels.py
+++ b/tests/test_shared_prefix_kernels.py
@@ -23,7 +23,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_sliding_window.py b/tests/test_sliding_window.py
index bdac80b6b..e8efba6e5 100644
--- a/tests/test_sliding_window.py
+++ b/tests/test_sliding_window.py
@@ -23,7 +23,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try:
diff --git a/tests/test_tensor_cores_decode.py b/tests/test_tensor_cores_decode.py
index 98160c7ef..9204c51f8 100644
--- a/tests/test_tensor_cores_decode.py
+++ b/tests/test_tensor_cores_decode.py
@@ -23,7 +23,7 @@
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
-    if flashinfer.jit.has_prebuilt_ops:
+    if flashinfer.jit.has_prebuilt_kernels:
         yield
     else:
         try: