diff --git a/c_api/CMakeLists.txt b/c_api/CMakeLists.txt index 01789a3214..bb2b07ba9c 100644 --- a/c_api/CMakeLists.txt +++ b/c_api/CMakeLists.txt @@ -54,6 +54,10 @@ faiss_install_headers("${FAISS_C_API_HEADERS}" c_api) add_executable(example_c EXCLUDE_FROM_ALL example_c.c) target_link_libraries(example_c PRIVATE faiss_c) +if(FAISS_ENABLE_DNNL) + add_compile_definitions(ENABLE_DNNL) +endif() + if(FAISS_ENABLE_GPU) add_subdirectory(gpu) endif() diff --git a/c_api/utils/distances_c.cpp b/c_api/utils/distances_c.cpp index cd46cc1dc1..d9cc137454 100644 --- a/c_api/utils/distances_c.cpp +++ b/c_api/utils/distances_c.cpp @@ -101,10 +101,20 @@ int faiss_get_distance_compute_min_k_reservoir() { return faiss::distance_compute_min_k_reservoir; } +#ifdef ENABLE_DNNL void faiss_set_distance_compute_dnnl_query_bs(int value) { faiss::distance_compute_dnnl_query_bs = value; } int faiss_get_distance_compute_dnnl_query_bs() { return faiss::distance_compute_dnnl_query_bs; -} \ No newline at end of file +} + +void faiss_set_distance_compute_dnnl_database_bs(int value) { + faiss::distance_compute_dnnl_database_bs = value; +} + +int faiss_get_distance_compute_dnnl_database_bs() { + return faiss::distance_compute_dnnl_database_bs; +} +#endif diff --git a/c_api/utils/distances_c.h b/c_api/utils/distances_c.h index 0f1eed91dd..350c11157a 100644 --- a/c_api/utils/distances_c.h +++ b/c_api/utils/distances_c.h @@ -103,6 +103,7 @@ void faiss_set_distance_compute_min_k_reservoir(int value); /// rather than a heap int faiss_get_distance_compute_min_k_reservoir(); +#ifdef ENABLE_DNNL /// Setter of block sizes value for oneDNN/AMX distance computations void faiss_set_distance_compute_dnnl_query_bs(int value); @@ -114,6 +115,7 @@ void faiss_set_distance_compute_dnnl_database_bs(int value); /// Getter of block sizes value for oneDNN/AMX distance computations int faiss_get_distance_compute_dnnl_database_bs(); +#endif #ifdef __cplusplus } diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 7e2a55740c..df52d381ab 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -228,7 +228,7 @@ if(NOT WIN32) endif() if(FAISS_ENABLE_DNNL) - list(APPEND FAISS_HEADERS utils/onednn/onednn_utils.h) + list(APPEND FAISS_HEADERS cppcontrib/amx/onednn_utils.h) endif() if(FAISS_ENABLE_DNNL) @@ -310,7 +310,6 @@ target_compile_definitions(faiss_avx512 PRIVATE FINTEGER=int) if(FAISS_ENABLE_DNNL) find_library(RT_LIB rt) find_library(DNNL_LIB dnnl) - message(DNNL_LIB=${DNNL_LIB}) target_link_libraries(faiss PRIVATE ${RT_LIB} ${DNNL_LIB}) target_link_libraries(faiss_avx2 PRIVATE ${RT_LIB} ${DNNL_LIB}) target_link_libraries(faiss_avx512 PRIVATE ${RT_LIB} ${DNNL_LIB}) diff --git a/faiss/utils/onednn/onednn_utils.h b/faiss/cppcontrib/amx/onednn_utils.h similarity index 99% rename from faiss/utils/onednn/onednn_utils.h rename to faiss/cppcontrib/amx/onednn_utils.h index 2e1b8b6002..955bcc4d80 100644 --- a/faiss/utils/onednn/onednn_utils.h +++ b/faiss/cppcontrib/amx/onednn_utils.h @@ -16,6 +16,7 @@ #include "oneapi/dnnl/dnnl.hpp" namespace faiss { + static dnnl::engine cpu_engine; static dnnl::stream engine_stream; static bool is_onednn_init = false; diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index e00020e205..87f0465443 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -21,7 +21,7 @@ #endif #ifdef ENABLE_DNNL -#include +#include #endif #include @@ -133,34 +133,33 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) { namespace { -/* Find the nearest neighbors for nx queries in a set of ny vectors */ -template -void exhaustive_inner_product_seq( +#ifdef ENABLE_DNNL +/* Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */ +template +void exhaustive_inner_product_seq_dnnl( const float* x, const float* y, size_t d, size_t nx, size_t ny, BlockResultHandler& res, - const IDSelector* sel = nullptr) { + const IDSelector* sel = nullptr) { using SingleResultHandler = typename BlockResultHandler::SingleResultHandler; [[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads()); FAISS_ASSERT(use_sel == (sel != nullptr)); -#ifdef ENABLE_DNNL - // use AMX to accelerate if available - if (is_amxbf16_supported()) { - float* res_arr = (float*)malloc(nx * ny * sizeof(float)); - comput_f32bf16f32_inner_product( - nx, - d, - ny, - d, - const_cast(x), - const_cast(y), - res_arr); + float* res_arr = (float*)malloc(nx * ny * sizeof(float)); + + comput_f32bf16f32_inner_product( + nx, + d, + ny, + d, + const_cast(x), + const_cast(y), + res_arr); #pragma omp parallel num_threads(nt) { @@ -175,34 +174,46 @@ void exhaustive_inner_product_seq( resi.end(); } } - delete[] res_arr; - } else { + free(res_arr); +} #endif +/* Find the nearest neighbors for nx queries in a set of ny vectors */ +template +void exhaustive_inner_product_seq( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + BlockResultHandler& res, + const IDSelector* sel = nullptr) { + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; + int nt = std::min(int(nx), omp_get_max_threads()); + + FAISS_ASSERT(use_sel == (sel != nullptr)); + #pragma omp parallel num_threads(nt) - { - SingleResultHandler resi(res); + { + SingleResultHandler resi(res); #pragma omp for - for (int64_t i = 0; i < nx; i++) { - const float* x_i = x + i * d; - const float* y_j = y; + for (int64_t i = 0; i < nx; i++) { + const float* x_i = x + i * d; + const float* y_j = y; - resi.begin(i); + resi.begin(i); - for (size_t j = 0; j < ny; j++, y_j += d) { - if (use_sel && !sel->is_member(j)) { - continue; - } - float ip = fvec_inner_product(x_i, y_j, d); - resi.add_result(ip, j); + for (size_t j = 0; j < ny; j++, y_j += d) { + if (use_sel && !sel->is_member(j)) { + continue; } - resi.end(); + float ip = fvec_inner_product(x_i, y_j, d); + resi.add_result(ip, j); } + resi.end(); } - -#ifdef ENABLE_DNNL } -#endif } template @@ -240,6 +251,53 @@ void exhaustive_L2sqr_seq( } } + +#ifdef ENABLE_DNNL +/** Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */ +template +void exhaustive_inner_product_blas_dnnl( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + BlockResultHandler& res) { + /* block sizes */ + const size_t bs_x = distance_compute_dnnl_query_bs; + const size_t bs_y = distance_compute_dnnl_database_bs; + std::unique_ptr ip_block(new float[bs_x * bs_y]); + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if (i1 > nx) + i1 = nx; + + res.begin_multiple(i0, i1); + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) + j1 = ny; + /* compute the actual dot products */ + FINTEGER nyi = j1 - j0, nxi = i1 - i0; + comput_f32bf16f32_inner_product( + nxi, + d, + nyi, + d, + const_cast(x + i0 * d), + const_cast(y + j0 * d), + ip_block.get()); + + + res.add_results(j0, j1, ip_block.get()); + } + res.end_multiple(); + InterruptCallback::check(); + } +} +#endif + /** Find the nearest neighbors for nx queries in a set of ny vectors */ template void exhaustive_inner_product_blas( @@ -254,16 +312,8 @@ void exhaustive_inner_product_blas( return; /* block sizes */ - size_t prov_bs_x = distance_compute_blas_query_bs; - size_t prov_bs_y = distance_compute_blas_database_bs; -#ifdef ENABLE_DNNL - if (is_amxbf16_supported()) { - prov_bs_x = distance_compute_dnnl_query_bs; - prov_bs_y = distance_compute_dnnl_database_bs; - } -#endif - const size_t bs_x = prov_bs_x; - const size_t bs_y = prov_bs_y; + const size_t bs_x = distance_compute_blas_query_bs; + const size_t bs_y = distance_compute_blas_database_bs; std::unique_ptr ip_block(new float[bs_x * bs_y]); for (size_t i0 = 0; i0 < nx; i0 += bs_x) { @@ -277,20 +327,7 @@ void exhaustive_inner_product_blas( size_t j1 = j0 + bs_y; if (j1 > ny) j1 = ny; -/* compute the actual dot products */ -#ifdef ENABLE_DNNL - if (is_amxbf16_supported()) { - FINTEGER nyi = j1 - j0, nxi = i1 - i0; - comput_f32bf16f32_inner_product( - nxi, - d, - nyi, - d, - const_cast(x + i0 * d), - const_cast(y + j0 * d), - ip_block.get()); - } else -#endif + /* compute the actual dot products */ { float one = 1, zero = 0; FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; @@ -693,8 +730,20 @@ void knn_inner_product_select( exhaustive_inner_product_seq( x, y, d, nx, ny, res, sel); } else if (nx < distance_compute_blas_threshold) { +#ifdef ENABLE_DNNL + if(is_amxbf16_supported()){ + exhaustive_inner_product_seq_dnnl(x, y, d, nx, ny, res); + return; + } +#endif exhaustive_inner_product_seq(x, y, d, nx, ny, res); } else { +#ifdef ENABLE_DNNL + if(is_amxbf16_supported()){ + exhaustive_inner_product_blas_dnnl(x, y, d, nx, ny, res); + return; + } +#endif exhaustive_inner_product_blas(x, y, d, nx, ny, res); } } @@ -712,6 +761,7 @@ int distance_compute_min_k_reservoir = 100; int distance_compute_dnnl_query_bs = 10240; int distance_compute_dnnl_database_bs = 10240; + void knn_inner_product( const float* x, const float* y,