From 061b68b43a9eb53866b8890dd88904b08de50f08 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 30 Apr 2021 00:01:44 -0700 Subject: [PATCH] Fix performance regression in ResultHandler (#1840) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1840 This diff is related to https://github.com/facebookresearch/faiss/issues/1762 The ResultHandler introduced for FlatL2 and FlatIP was not multithreaded. This diff attempts to fix that. Whether it is indeed faster remains to be verified. Reviewed By: wickedfoo Differential Revision: D27939173 fbshipit-source-id: c85f01a97d4249fe0c6bfb04396b68a7a9fe643d --- faiss/impl/ResultHandler.h | 13 ++++++++----- faiss/utils/distances.cpp | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index ec03f7f656..65790d6011 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -92,13 +92,14 @@ struct HeapResultHandler { /// add results for query i0..i1 and j0..j1 void add_results(size_t j0, size_t j1, const T* dis_tab) { - // maybe parallel for - for (size_t i = i0; i < i1; i++) { +#pragma omp parallel for + for (int64_t i = i0; i < i1; i++) { T* heap_dis = heap_dis_tab + i * k; TI* heap_ids = heap_ids_tab + i * k; + const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; T thresh = heap_dis[0]; for (size_t j = j0; j < j1; j++) { - T dis = *dis_tab++; + T dis = dis_tab_i[j]; if (C::cmp(thresh, dis)) { heap_replace_top(k, heap_dis, heap_ids, dis, j); thresh = heap_dis[0]; @@ -281,10 +282,12 @@ struct ReservoirResultHandler { /// add results for query i0..i1 and j0..j1 void add_results(size_t j0, size_t j1, const T* dis_tab) { // maybe parallel for - for (size_t i = i0; i < i1; i++) { +#pragma omp parallel for + for (int64_t i = i0; i < i1; i++) { ReservoirTopN& reservoir = reservoirs[i - i0]; + const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; for (size_t j = j0; j < j1; j++) { - T dis = *dis_tab++; + T dis = 
dis_tab_i[j]; reservoir.add(dis, j); } } diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 60c50ac1e9..e39b55e801 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -286,7 +286,7 @@ void exhaustive_L2sqr_blas( ip_block.get(), &nyi); } - +#pragma omp parallel for for (int64_t i = i0; i < i1; i++) { float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);