diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index ed7b7390b..0dc72c62c 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -209,6 +209,17 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin std::vector &node_list) { #endif + if (num_nodes_to_cache >= this->num_points) + { + // for small num_points and big num_nodes_to_cache, use below way to get the node_list quickly + node_list.resize(this->num_points); + for (uint32_t i = 0; i < this->num_points; ++i) + { + node_list[i] = i; + } + return; + } + this->count_visited_nodes = true; this->node_visit_counter.clear(); this->node_visit_counter.resize(this->num_points); @@ -244,8 +255,8 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) for (int64_t i = 0; i < (int64_t)sample_num; i++) { - cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + (i * 1), - tmp_result_dists.data() + (i * 1), beamwidth); + cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, + tmp_result_dists.data() + i, beamwidth); } std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), @@ -254,6 +265,7 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin }); node_list.clear(); node_list.shrink_to_fit(); + num_nodes_to_cache = std::min(num_nodes_to_cache, this->node_visit_counter.size()); node_list.reserve(num_nodes_to_cache); for (uint64_t i = 0; i < num_nodes_to_cache; i++) {