Skip to content

Commit

Permalink
Format distances_dnnl.h
Browse files Browse the repository at this point in the history
  • Loading branch information
guangzegu committed Oct 22, 2024
1 parent 9e34323 commit f556407
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions faiss/cppcontrib/amx/distances_dnnl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace faiss {
// block sizes for oneDNN/AMX distance computations
FAISS_API int distance_compute_dnnl_query_bs = 10240;
FAISS_API int distance_compute_dnnl_database_bs = 10240;

/* Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */
template <class BlockResultHandler>
void exhaustive_inner_product_seq_dnnl(
Expand All @@ -33,11 +33,11 @@ void exhaustive_inner_product_seq_dnnl(
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res) {
BlockResultHandler& res) {
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());

std::unique_ptr<float[]> res_arr(new float[nx * ny]);

comput_f32bf16f32_inner_product(
Expand All @@ -50,30 +50,30 @@ void exhaustive_inner_product_seq_dnnl(
res_arr.get());

#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
{
SingleResultHandler resi(res);
#pragma omp for
for (size_t i = 0; i < nx; i++) {
resi.begin(i);
for (size_t j = 0; j < ny; j++) {
float ip = res_arr[i * ny + j];
resi.add_result(ip, j);
}
resi.end();
for (size_t i = 0; i < nx; i++) {
resi.begin(i);
for (size_t j = 0; j < ny; j++) {
float ip = res_arr[i * ny + j];
resi.add_result(ip, j);
}
resi.end();
}
}
}

/** Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */
/* Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */
template <class BlockResultHandler>
void exhaustive_inner_product_blas_dnnl(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res) {
/* block sizes */
BlockResultHandler& res) {
/* block sizes */
const size_t bs_x = distance_compute_dnnl_query_bs;
const size_t bs_y = distance_compute_dnnl_database_bs;
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
Expand All @@ -100,12 +100,11 @@ void exhaustive_inner_product_blas_dnnl(
const_cast<float*>(y + j0 * d),
ip_block.get());


res.add_results(j0, j1, ip_block.get());
}
res.end_multiple();
InterruptCallback::check();
}
}

}// namespace faiss
} // namespace faiss

0 comments on commit f556407

Please sign in to comment.