Skip to content

Commit

Permalink
Restructure the AMX integration with faiss
Browse files Browse the repository at this point in the history
  • Loading branch information
guangzegu authored and xtangxtang committed Sep 24, 2024
1 parent a15c5cc commit 116fc01
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 62 deletions.
4 changes: 4 additions & 0 deletions c_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ faiss_install_headers("${FAISS_C_API_HEADERS}" c_api)
add_executable(example_c EXCLUDE_FROM_ALL example_c.c)
target_link_libraries(example_c PRIVATE faiss_c)

# When oneDNN support is requested, define ENABLE_DNNL so that the C API
# sections guarded by `#ifdef ENABLE_DNNL` are compiled.
if(FAISS_ENABLE_DNNL)
add_compile_definitions(ENABLE_DNNL)
endif()

if(FAISS_ENABLE_GPU)
add_subdirectory(gpu)
endif()
12 changes: 11 additions & 1 deletion c_api/utils/distances_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,20 @@ int faiss_get_distance_compute_min_k_reservoir() {
return faiss::distance_compute_min_k_reservoir;
}

#ifdef ENABLE_DNNL
/// Store `value` as the query block size for oneDNN/AMX distance computations.
void faiss_set_distance_compute_dnnl_query_bs(int value) {
    auto& query_block_size = faiss::distance_compute_dnnl_query_bs;
    query_block_size = value;
}

/// Return the current query block size for oneDNN/AMX distance computations.
int faiss_get_distance_compute_dnnl_query_bs() {
    const int query_block_size = faiss::distance_compute_dnnl_query_bs;
    return query_block_size;
}
}

/// Store `value` as the database block size for oneDNN/AMX distance
/// computations.
void faiss_set_distance_compute_dnnl_database_bs(int value) {
    auto& database_block_size = faiss::distance_compute_dnnl_database_bs;
    database_block_size = value;
}

/// Return the current database block size for oneDNN/AMX distance
/// computations.
int faiss_get_distance_compute_dnnl_database_bs() {
    const int database_block_size = faiss::distance_compute_dnnl_database_bs;
    return database_block_size;
}
#endif
2 changes: 2 additions & 0 deletions c_api/utils/distances_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ void faiss_set_distance_compute_min_k_reservoir(int value);
/// rather than a heap
int faiss_get_distance_compute_min_k_reservoir();

#ifdef ENABLE_DNNL
/// Setter of the query block size for oneDNN/AMX distance computations
void faiss_set_distance_compute_dnnl_query_bs(int value);

Expand All @@ -114,6 +115,7 @@ void faiss_set_distance_compute_dnnl_database_bs(int value);

/// Getter of the database block size for oneDNN/AMX distance computations
int faiss_get_distance_compute_dnnl_database_bs();
#endif

#ifdef __cplusplus
}
Expand Down
3 changes: 1 addition & 2 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ if(NOT WIN32)
endif()

if(FAISS_ENABLE_DNNL)
list(APPEND FAISS_HEADERS utils/onednn/onednn_utils.h)
list(APPEND FAISS_HEADERS cppcontrib/amx/onednn_utils.h)
endif()

if(FAISS_ENABLE_DNNL)
Expand Down Expand Up @@ -310,7 +310,6 @@ target_compile_definitions(faiss_avx512 PRIVATE FINTEGER=int)
if(FAISS_ENABLE_DNNL)
find_library(RT_LIB rt)
find_library(DNNL_LIB dnnl)
message(DNNL_LIB=${DNNL_LIB})
target_link_libraries(faiss PRIVATE ${RT_LIB} ${DNNL_LIB})
target_link_libraries(faiss_avx2 PRIVATE ${RT_LIB} ${DNNL_LIB})
target_link_libraries(faiss_avx512 PRIVATE ${RT_LIB} ${DNNL_LIB})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "oneapi/dnnl/dnnl.hpp"

namespace faiss {

static dnnl::engine cpu_engine;
static dnnl::stream engine_stream;
static bool is_onednn_init = false;
Expand Down
168 changes: 109 additions & 59 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#endif

#ifdef ENABLE_DNNL
#include <faiss/utils/onednn/onednn_utils.h>
#include <faiss/cppcontrib/amx/onednn_utils.h>
#endif

#include <faiss/impl/AuxIndexStructures.h>
Expand Down Expand Up @@ -133,34 +133,33 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {

namespace {

/* Find the nearest neighbors for nx queries in a set of ny vectors */
template <class BlockResultHandler, bool use_sel = false>
void exhaustive_inner_product_seq(
#ifdef ENABLE_DNNL
/* Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */
template <class BlockResultHandler, bool use_sel = false>
void exhaustive_inner_product_seq_dnnl(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res,
const IDSelector* sel = nullptr) {
const IDSelector* sel = nullptr) {
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());

FAISS_ASSERT(use_sel == (sel != nullptr));

#ifdef ENABLE_DNNL
// use AMX to accelerate if available
if (is_amxbf16_supported()) {
float* res_arr = (float*)malloc(nx * ny * sizeof(float));
comput_f32bf16f32_inner_product(
nx,
d,
ny,
d,
const_cast<float*>(x),
const_cast<float*>(y),
res_arr);
float* res_arr = (float*)malloc(nx * ny * sizeof(float));

comput_f32bf16f32_inner_product(
nx,
d,
ny,
d,
const_cast<float*>(x),
const_cast<float*>(y),
res_arr);

#pragma omp parallel num_threads(nt)
{
Expand All @@ -175,34 +174,46 @@ void exhaustive_inner_product_seq(
resi.end();
}
}
delete[] res_arr;
} else {
free(res_arr);
}
#endif

/* Find the nearest neighbors for nx queries in a set of ny vectors */
template <class BlockResultHandler, bool use_sel = false>
void exhaustive_inner_product_seq(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res,
const IDSelector* sel = nullptr) {
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
int nt = std::min(int(nx), omp_get_max_threads());

FAISS_ASSERT(use_sel == (sel != nullptr));

#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
{
SingleResultHandler resi(res);
#pragma omp for
for (int64_t i = 0; i < nx; i++) {
const float* x_i = x + i * d;
const float* y_j = y;
for (int64_t i = 0; i < nx; i++) {
const float* x_i = x + i * d;
const float* y_j = y;

resi.begin(i);
resi.begin(i);

for (size_t j = 0; j < ny; j++, y_j += d) {
if (use_sel && !sel->is_member(j)) {
continue;
}
float ip = fvec_inner_product(x_i, y_j, d);
resi.add_result(ip, j);
for (size_t j = 0; j < ny; j++, y_j += d) {
if (use_sel && !sel->is_member(j)) {
continue;
}
resi.end();
float ip = fvec_inner_product(x_i, y_j, d);
resi.add_result(ip, j);
}
resi.end();
}

#ifdef ENABLE_DNNL
}
#endif
}

template <class BlockResultHandler, bool use_sel = false>
Expand Down Expand Up @@ -240,6 +251,53 @@ void exhaustive_L2sqr_seq(
}
}


#ifdef ENABLE_DNNL
/**
 * Find the nearest neighbors for nx queries in a set of ny vectors using
 * oneDNN/AMX.
 *
 * The inner products are computed tile by tile. The tile extents come from
 * distance_compute_dnnl_query_bs (queries) and
 * distance_compute_dnnl_database_bs (database vectors); each tile is produced
 * by comput_f32bf16f32_inner_product and forwarded to the result handler.
 *
 * @param x    query vectors, read as nx consecutive rows of d floats
 * @param y    database vectors, read as ny consecutive rows of d floats
 * @param d    number of floats per vector
 * @param nx   number of query vectors
 * @param ny   number of database vectors
 * @param res  block result handler that receives every tile of inner products
 */
template <class BlockResultHandler>
void exhaustive_inner_product_blas_dnnl(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        BlockResultHandler& res) {
    // Without queries the outer loop never runs, so skip the allocation too.
    if (nx == 0) {
        return;
    }

    /* block sizes */
    const size_t cfg_bs_x = distance_compute_dnnl_query_bs;
    const size_t cfg_bs_y = distance_compute_dnnl_database_bs;
    // A block size of 0 would never advance the loops below, so use at least
    // 1. Clamping to the problem size keeps the scratch buffer proportional
    // to the work actually done instead of always allocating a full
    // cfg_bs_x * cfg_bs_y tile; the tiling is unchanged because a block never
    // extends past nx / ny anyway.
    const size_t bs_x = std::min(std::max<size_t>(cfg_bs_x, 1), nx);
    const size_t bs_y = std::min(std::max<size_t>(cfg_bs_y, 1), ny);
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        const size_t i1 = std::min(i0 + bs_x, nx);

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            const size_t j1 = std::min(j0 + bs_y, ny);
            /* compute the actual dot products */
            FINTEGER nyi = j1 - j0, nxi = i1 - i0;
            comput_f32bf16f32_inner_product(
                    nxi,
                    d,
                    nyi,
                    d,
                    const_cast<float*>(x + i0 * d),
                    const_cast<float*>(y + j0 * d),
                    ip_block.get());

            res.add_results(j0, j1, ip_block.get());
        }
        res.end_multiple();
        // Give the caller a chance to interrupt between query blocks.
        InterruptCallback::check();
    }
}
#endif

/** Find the nearest neighbors for nx queries in a set of ny vectors */
template <class BlockResultHandler>
void exhaustive_inner_product_blas(
Expand All @@ -254,16 +312,8 @@ void exhaustive_inner_product_blas(
return;

/* block sizes */
size_t prov_bs_x = distance_compute_blas_query_bs;
size_t prov_bs_y = distance_compute_blas_database_bs;
#ifdef ENABLE_DNNL
if (is_amxbf16_supported()) {
prov_bs_x = distance_compute_dnnl_query_bs;
prov_bs_y = distance_compute_dnnl_database_bs;
}
#endif
const size_t bs_x = prov_bs_x;
const size_t bs_y = prov_bs_y;
const size_t bs_x = distance_compute_blas_query_bs;
const size_t bs_y = distance_compute_blas_database_bs;
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
Expand All @@ -277,20 +327,7 @@ void exhaustive_inner_product_blas(
size_t j1 = j0 + bs_y;
if (j1 > ny)
j1 = ny;
/* compute the actual dot products */
#ifdef ENABLE_DNNL
if (is_amxbf16_supported()) {
FINTEGER nyi = j1 - j0, nxi = i1 - i0;
comput_f32bf16f32_inner_product(
nxi,
d,
nyi,
d,
const_cast<float*>(x + i0 * d),
const_cast<float*>(y + j0 * d),
ip_block.get());
} else
#endif
/* compute the actual dot products */
{
float one = 1, zero = 0;
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
Expand Down Expand Up @@ -693,8 +730,20 @@ void knn_inner_product_select(
exhaustive_inner_product_seq<BlockResultHandler, true>(
x, y, d, nx, ny, res, sel);
} else if (nx < distance_compute_blas_threshold) {
#ifdef ENABLE_DNNL
if(is_amxbf16_supported()){
exhaustive_inner_product_seq_dnnl(x, y, d, nx, ny, res);
return;
}
#endif
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
} else {
#ifdef ENABLE_DNNL
if(is_amxbf16_supported()){
exhaustive_inner_product_blas_dnnl(x, y, d, nx, ny, res);
return;
}
#endif
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
}
}
Expand All @@ -712,6 +761,7 @@ int distance_compute_min_k_reservoir = 100;
// Default block sizes for the oneDNN/AMX inner-product path: number of query
// vectors and number of database vectors processed per block.
int distance_compute_dnnl_query_bs = 10240;
int distance_compute_dnnl_database_bs = 10240;


void knn_inner_product(
const float* x,
const float* y,
Expand Down

0 comments on commit 116fc01

Please sign in to comment.