diff --git a/include/index.h b/include/index.h index 73b832641..9e8ab645a 100644 --- a/include/index.h +++ b/include/index.h @@ -104,6 +104,9 @@ template clas DISKANN_DLLEXPORT size_t get_num_points(); DISKANN_DLLEXPORT size_t get_max_points(); + DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels); + // Batch build from a file. Optionally pass tags vector. DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, const IndexWriteParameters ¶meters, diff --git a/src/index.cpp b/src/index.cpp index c82249b5e..0acb19aaf 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -837,6 +837,38 @@ template std::vector Inde return init_ids; } +// Returns true if point_id's labels share a label with incoming_labels, also honoring the universal label if enabled +template +bool Index::detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels) +{ + auto &curr_node_labels = _pts_to_labels[point_id]; + std::vector common_filters; + std::set_intersection(incoming_labels.begin(), incoming_labels.end(), curr_node_labels.begin(), + curr_node_labels.end(), std::back_inserter(common_filters)); + if (common_filters.size() > 0) + { + // Early return avoids the std::find calls below. 
If common_filters is non-empty, we don't need to check for + // the universal label + return true; + } + if (_use_universal_label) + { + if (!search_invocation) + { + if (std::find(incoming_labels.begin(), incoming_labels.end(), _universal_label) != incoming_labels.end() || + std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.push_back(_universal_label); + } + else + { + if (std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.push_back(_universal_label); + } + } + return (common_filters.size() > 0); +} + template std::pair Index::iterate_to_fixed_point( const T *query, const uint32_t Lsize, const std::vector &init_ids, InMemQueryScratch *scratch, @@ -933,18 +965,7 @@ std::pair Index::iterate_to_fixed_point( if (use_filter) { - std::vector common_filters; - auto &x = _pts_to_labels[id]; - std::set_intersection(filter_label.begin(), filter_label.end(), x.begin(), x.end(), - std::back_inserter(common_filters)); - if (_use_universal_label) - { - if (std::find(filter_label.begin(), filter_label.end(), _universal_label) != filter_label.end() || - std::find(x.begin(), x.end(), _universal_label) != x.end()) - common_filters.emplace_back(_universal_label); - } - - if (common_filters.size() == 0) + if (!detect_common_filters(id, search_invocation, filter_label)) continue; } @@ -1012,19 +1033,7 @@ std::pair Index::iterate_to_fixed_point( if (use_filter) { // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. 
- std::vector common_filters; - auto &x = _pts_to_labels[id]; - std::set_intersection(filter_label.begin(), filter_label.end(), x.begin(), x.end(), - std::back_inserter(common_filters)); - if (_use_universal_label) - { - if (std::find(filter_label.begin(), filter_label.end(), _universal_label) != - filter_label.end() || - std::find(x.begin(), x.end(), _universal_label) != x.end()) - common_filters.emplace_back(_universal_label); - } - - if (common_filters.size() == 0) + if (!detect_common_filters(id, search_invocation, filter_label)) continue; }