From 34146093df82ed1ad53a26bfaed8460cef5aaadd Mon Sep 17 00:00:00 2001 From: Rishabh Maurya Date: Wed, 14 Feb 2024 11:40:43 -0800 Subject: [PATCH 01/17] Cardinality aggregation dynamic pruning changes Signed-off-by: bowenlan-amzn --- .../metrics/CardinalityAggregator.java | 10 +- .../DisjunctionWithDynamicPruningScorer.java | 264 ++++++++++++++++++ .../DynamicPruningCollectorWrapper.java | 106 +++++++ .../metrics/CardinalityAggregatorTests.java | 58 ++++ 4 files changed, 435 insertions(+), 3 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java index 99c4eaac4b777..91887e2e4a202 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java @@ -53,6 +53,7 @@ import org.opensearch.search.aggregations.Aggregator; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.support.FieldContext; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; @@ -71,6 +72,8 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue private final int precision; private final ValuesSource valuesSource; + private final FieldContext fieldContext; + // Expensive to initialize, so we only initialize it when we have an actual value source @Nullable private HyperLogLogPlusPlus counts; @@ -95,6 +98,7 @@ public CardinalityAggregator( // TODO: Stop using nulls here this.valuesSource = valuesSourceConfig.hasValues() ? valuesSourceConfig.getValuesSource() : null; this.precision = precision; + this.fieldContext = valuesSourceConfig.fieldContext(); this.counts = valuesSource == null ? 
null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1); } @@ -132,11 +136,11 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException { // only use ordinals if they don't increase memory usage by more than 25% if (ordinalsMemoryUsage < countsMemoryUsage / 4) { ordinalsCollectorsUsed++; - return new OrdinalsCollector(counts, ordinalValues, context.bigArrays()); + return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), + context, ctx, fieldContext, source); } ordinalsCollectorsOverheadTooHigh++; } - stringHashingCollectorsUsed++; return new DirectCollector(counts, MurmurHash3Values.hash(valuesSource.bytesValues(ctx))); } @@ -206,7 +210,7 @@ public void collectDebugInfo(BiConsumer add) { * * @opensearch.internal */ - private abstract static class Collector extends LeafBucketCollector implements Releasable { + abstract static class Collector extends LeafBucketCollector implements Releasable { public abstract void postCollect() throws IOException; diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java new file mode 100644 index 0000000000000..6a7e66e8be2f0 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java @@ -0,0 +1,264 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.metrics; + +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DisjunctionDISIApproximation; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.PriorityQueue; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + + +/** + * Clone of {@link org.apache.lucene.search} {@code DisjunctionScorer.java} in lucene with following modifications - + * 1. {@link #removeAllDISIsOnCurrentDoc()} - it removes all the DISIs for subscorer pointing to current doc. This is + * helpful in dynamic pruning for Cardinality aggregation, where once a term is found, it becomes irrelevant for + * rest of the search space, so this term's subscorer DISI can be safely removed from list of subscorer to process. + *

+ * 2. {@link #removeAllDISIsOnCurrentDoc()} breaks the invariant of Conjunction DISI i.e. the docIDs of all sub-scorers should be
+ * less than or equal to current docID iterator is pointing to. When we remove elements from priority, it results in heapify action, which modifies
+ * the top of the priority queue, which represents the current docID for subscorers here. To address this, we are wrapping the
+ * iterator with {@link SlowDocIdPropagatorDISI} which keeps the iterator pointing to last docID before {@link #removeAllDISIsOnCurrentDoc()}
+ * is called and updates this docID only when next() or advance() is called.
+ */
+public class DisjunctionWithDynamicPruningScorer extends Scorer {
+
+    private final boolean needsScores;
+    private final DisiPriorityQueue subScorers;
+    private final DocIdSetIterator approximation;
+    private final TwoPhase twoPhase;
+
+    private Integer docID;
+
+    public DisjunctionWithDynamicPruningScorer(Weight weight, List subScorers)
+        throws IOException {
+        super(weight);
+        if (subScorers.size() <= 1) {
+            throw new IllegalArgumentException("There must be at least 2 subScorers");
+        }
+        this.subScorers = new DisiPriorityQueue(subScorers.size());
+        for (Scorer scorer : subScorers) {
+            final DisiWrapper w = new DisiWrapper(scorer);
+            this.subScorers.add(w);
+        }
+        this.needsScores = false;
+        this.approximation = new DisjunctionDISIApproximation(this.subScorers);
+
+        boolean hasApproximation = false;
+        float sumMatchCost = 0;
+        long sumApproxCost = 0;
+        // Compute matchCost as the average over the matchCost of the subScorers.
+        // This is weighted by the cost, which is an expected number of matching documents.
+        for (DisiWrapper w : this.subScorers) {
+            long costWeight = (w.cost <= 1) ? 1 : w.cost;
+            sumApproxCost += costWeight;
+            if (w.twoPhaseView != null) {
+                hasApproximation = true;
+                sumMatchCost += w.matchCost * costWeight;
+            }
+        }
+
+        if (hasApproximation == false) { // no sub scorer supports approximations
+            twoPhase = null;
+        } else {
+            final float matchCost = sumMatchCost / sumApproxCost;
+            twoPhase = new TwoPhase(approximation, matchCost);
+        }
+    }
+
+    public void removeAllDISIsOnCurrentDoc() {
+        docID = this.docID();
+        while (subScorers.size() > 0 && subScorers.top().doc == docID) {
+            subScorers.pop();
+        }
+    }
+
+    @Override
+    public DocIdSetIterator iterator() {
+        DocIdSetIterator disi = getIterator();
+        docID = disi.docID();
+        return new SlowDocIdPropagatorDISI(getIterator(), docID);
+    }
+
+    private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
+        DocIdSetIterator disi;
+
+        Integer curDocId;
+
+        SlowDocIdPropagatorDISI(DocIdSetIterator disi, Integer curDocId) {
+            this.disi = disi;
+            this.curDocId = curDocId;
+        }
+
+        @Override
+        public int docID() {
+            assert curDocId <= disi.docID();
+            return curDocId;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            return advance(curDocId + 1);
+        }
+
+        @Override
+        public int advance(int i) throws IOException {
+            if (i <= disi.docID()) {
+                // since we are slow propagating docIDs, it may happen the disi is already advanced to a higher docID than i
+                // in such scenarios we can simply return the docID where disi is pointing to and update the curDocId
+                curDocId = disi.docID();
+                return disi.docID();
+            }
+            curDocId = disi.advance(i);
+            return curDocId;
+        }
+
+        @Override
+        public long cost() {
+            return disi.cost();
+        }
+    }
+
+    private DocIdSetIterator getIterator() {
+        if (twoPhase != null) {
+            return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
+        } else {
+            return approximation;
+        }
+    }
+
+
@Override + public TwoPhaseIterator twoPhaseIterator() { + return twoPhase; + } + + @Override + public float getMaxScore(int i) throws IOException { + return 0; + } + + private class TwoPhase extends TwoPhaseIterator { + + private final float matchCost; + // list of verified matches on the current doc + DisiWrapper verifiedMatches; + // priority queue of approximations on the current doc that have not been verified yet + final PriorityQueue unverifiedMatches; + + private TwoPhase(DocIdSetIterator approximation, float matchCost) { + super(approximation); + this.matchCost = matchCost; + unverifiedMatches = + new PriorityQueue(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) { + @Override + protected boolean lessThan(DisiWrapper a, DisiWrapper b) { + return a.matchCost < b.matchCost; + } + }; + } + + DisiWrapper getSubMatches() throws IOException { + // iteration order does not matter + for (DisiWrapper w : unverifiedMatches) { + if (w.twoPhaseView.matches()) { + w.next = verifiedMatches; + verifiedMatches = w; + } + } + unverifiedMatches.clear(); + return verifiedMatches; + } + + @Override + public boolean matches() throws IOException { + verifiedMatches = null; + unverifiedMatches.clear(); + + for (DisiWrapper w = subScorers.topList(); w != null; ) { + DisiWrapper next = w.next; + + if (w.twoPhaseView == null) { + // implicitly verified, move it to verifiedMatches + w.next = verifiedMatches; + verifiedMatches = w; + + if (needsScores == false) { + // we can stop here + return true; + } + } else { + unverifiedMatches.add(w); + } + w = next; + } + + if (verifiedMatches != null) { + return true; + } + + // verify subs that have an two-phase iterator + // least-costly ones first + while (unverifiedMatches.size() > 0) { + DisiWrapper w = unverifiedMatches.pop(); + if (w.twoPhaseView.matches()) { + w.next = null; + verifiedMatches = w; + return true; + } + } + + return false; + } + + @Override + public float matchCost() { + return matchCost; + } + } + + + @Override + public final int docID() { + return subScorers.top().doc; + } + + DisiWrapper getSubMatches() throws IOException { + if (twoPhase == null) { + return subScorers.topList(); + } else { + return twoPhase.getSubMatches(); + } + } + + @Override + public final float score() throws IOException { + return score(getSubMatches()); + } + + protected float score(DisiWrapper topList) throws IOException { + return 1f; + } + + @Override + public final Collection getChildren() throws IOException { + ArrayList children = new ArrayList<>(); + for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) { + children.add(new ChildScorable(scorer.scorer, "SHOULD")); + } + return children; + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java new file mode 100644 index 0000000000000..f4c3d59a3833f --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.metrics; + +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.ConjunctionUtils; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.opensearch.search.aggregations.support.FieldContext; +import org.opensearch.search.aggregations.support.ValuesSource; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector { + + private final LeafReaderContext ctx; + private final DisjunctionWithDynamicPruningScorer disjunctionScorer; + private final DocIdSetIterator disi; + private final CardinalityAggregator.Collector delegateCollector; + + DynamicPruningCollectorWrapper(CardinalityAggregator.Collector delegateCollector, + SearchContext context, LeafReaderContext ctx, FieldContext fieldContext, + ValuesSource.Bytes.WithOrdinals source) throws IOException { + this.ctx = ctx; + this.delegateCollector = delegateCollector; + final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx); + boolean isCardinalityLow = ordinalValues.getValueCount() < 10; + boolean isCardinalityAggregationOnlyAggregation = true; + boolean isFieldSupportedForDynamicPruning = true; + if (isCardinalityLow && isCardinalityAggregationOnlyAggregation && isFieldSupportedForDynamicPruning) { + // create disjunctions from terms + // this logic should be pluggable depending on the type of leaf bucket collector by CardinalityAggregator + TermsEnum terms = ordinalValues.termsEnum(); + Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f); + Map found = new HashMap<>(); + List subScorers = new ArrayList<>(); + while (terms.next() != null && !found.containsKey(terms.ord())) { + // TODO can we get rid of terms previously encountered in other segments? 
+ TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), terms.term())); + Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f); + Scorer scorer = subWeight.scorer(ctx); + if (scorer != null) { + subScorers.add(scorer); + } + found.put(terms.ord(), true); + } + disjunctionScorer = new DisjunctionWithDynamicPruningScorer(weight, subScorers); + disi = ConjunctionUtils.intersectScorers(List.of(disjunctionScorer, weight.scorer(ctx))); + } else { + disjunctionScorer = null; + disi = null; + } + } + + @Override + public void collect(int doc, long bucketOrd) throws IOException { + if (disi == null || disjunctionScorer == null) { + delegateCollector.collect(doc, bucketOrd); + } else { + // perform the full iteration using dynamic pruning of DISIs and return right away + disi.advance(doc); + int currDoc = disi.docID(); + assert currDoc == doc; + final Bits liveDocs = ctx.reader().getLiveDocs(); + assert liveDocs == null || liveDocs.get(currDoc); + do { + if (liveDocs == null || liveDocs.get(currDoc)) { + delegateCollector.collect(currDoc, bucketOrd); + disjunctionScorer.removeAllDISIsOnCurrentDoc(); + } + currDoc = disi.nextDoc(); + } while (currDoc != DocIdSetIterator.NO_MORE_DOCS); + throw new CollectionTerminatedException(); + } + } + + @Override + public void close() { + delegateCollector.close(); + } + + @Override + public void postCollect() throws IOException { + delegateCollector.postCollect(); + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java index cdd17e2fa7dd6..a9966c9e70e76 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java @@ -33,15 +33,22 @@ package org.opensearch.search.aggregations.metrics; import org.apache.lucene.document.BinaryDocValuesField; +import org.apache.lucene.document.Field; import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KeywordField; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.index.Term; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.BytesRef; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.geo.GeoPoint; +import org.opensearch.index.mapper.KeywordFieldMapper; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.mapper.RangeFieldMapper; @@ -56,6 +63,7 @@ import java.util.Set; import java.util.function.Consumer; +import static java.util.Arrays.asList; import static java.util.Collections.singleton; public class CardinalityAggregatorTests extends AggregatorTestCase { @@ -90,6 +98,56 @@ public void testRangeFieldValues() throws IOException { }, fieldType); } + public void testDynamicPruningOrdinalCollector() throws IOException { + final String fieldName = "testField"; + final String filterFieldName = "filterField"; + + MappedFieldType fieldType = new 
KeywordFieldMapper.KeywordFieldType(fieldName); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> { + iw.addDocument(asList( + new KeywordField(fieldName, "1", Field.Store.NO), + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("1")), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + )); + iw.addDocument(asList( + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + )); + iw.addDocument(asList( + new KeywordField(fieldName, "1", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("1")) + )); + iw.addDocument(asList( + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + )); + iw.addDocument(asList( + new KeywordField(fieldName, "3", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("3")) + )); + iw.addDocument(asList( + new KeywordField(fieldName, "4", Field.Store.NO), + new KeywordField(filterFieldName, "bar", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("4")) + )); + iw.addDocument(asList( + new KeywordField(fieldName, "5", Field.Store.NO), + new KeywordField(filterFieldName, "bar", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("5")) + )); + }, card -> { + assertEquals(3.0, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, fieldType); + } + public void testNoMatchingField() throws IOException { testAggregation(new MatchAllDocsQuery(), iw -> { iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 7))); From 0d6f15140924cfc818fec05f56b88cac67464a05 Mon Sep 17 00:00:00 2001 From: bowenlan-amzn Date: Thu, 16 May 2024 08:37:06 -0700 Subject: [PATCH 02/17] Reading Signed-off-by: bowenlan-amzn --- .../metrics/DisjunctionWithDynamicPruningScorer.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java index 6a7e66e8be2f0..75fa3d2eb6f93 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java @@ -24,8 +24,9 @@ /** - * Clone of {@link org.apache.lucene.search} {@code DisjunctionScorer.java} in lucene with following modifications - - * 1. {@link #removeAllDISIsOnCurrentDoc()} - it removes all the DISIs for subscorer pointing to current doc. This is + * Clone of {@link org.apache.lucene.search} {@code DisjunctionScorer.java} in lucene with following modifications + *

+ * 1. {@link #removeAllDISIsOnCurrentDoc()} removes all the DISIs for subscorer pointing to current doc. This is * helpful in dynamic pruning for Cardinality aggregation, where once a term is found, it becomes irrelevant for * rest of the search space, so this term's subscorer DISI can be safely removed from list of subscorer to process. *

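To make the pruning idea concrete before the next patch, here is a minimal, self-contained Java sketch. It is an illustration, not the scorer above: the class name, the hardcoded postings lists, and the use of java.util.PriorityQueue in place of Lucene's DisiPriorityQueue are all assumptions, and the intersection with the main query as well as the two-phase machinery are omitted. Each term owns an iterator over its sorted doc IDs; as soon as a term surfaces at the top of the queue it is counted and its whole iterator is discarded, the plain-Java analogue of removeAllDISIsOnCurrentDoc().

    import java.util.LinkedHashMap;
    import java.util.Map;
    import java.util.PriorityQueue;

    public class DisjunctionPruningSketch {

        // One sub-iterator per term; only its first (smallest) doc ID ever matters,
        // because the iterator is discarded as soon as the term has been counted.
        private static final class Sub {
            final String term;
            final int[] docs; // sorted doc IDs containing this term

            Sub(String term, int[] docs) {
                this.term = term;
                this.docs = docs;
            }

            int doc() {
                return docs[0];
            }
        }

        public static void main(String[] args) {
            // Hypothetical postings lists: term -> sorted doc IDs.
            Map<String, int[]> postings = new LinkedHashMap<>();
            postings.put("a", new int[] { 0, 2, 7 });
            postings.put("b", new int[] { 0, 1, 3 });
            postings.put("c", new int[] { 4, 9 });

            PriorityQueue<Sub> pq = new PriorityQueue<>((x, y) -> Integer.compare(x.doc(), y.doc()));
            for (Map.Entry<String, int[]> e : postings.entrySet()) {
                pq.add(new Sub(e.getKey(), e.getValue()));
            }

            int distinctTerms = 0;
            while (!pq.isEmpty()) {
                int doc = pq.peek().doc();
                // Count every term positioned on this doc exactly once, then drop its
                // iterator for good: the rest of its postings cannot change the count.
                while (!pq.isEmpty() && pq.peek().doc() == doc) {
                    Sub sub = pq.poll();
                    distinctTerms++;
                    System.out.println("doc " + doc + " -> counted term " + sub.term);
                }
                // Iterators still in the queue are already positioned past this doc,
                // so no explicit advance is needed in this simplified model.
            }
            System.out.println("distinct terms: " + distinctTerms);
        }
    }

Once every term has been counted, the queue is empty and iteration stops early, which mirrors the early termination the real collector achieves by throwing CollectionTerminatedException.
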
From a18b597ea8dc47fd0211f599908d84008daf5a77 Mon Sep 17 00:00:00 2001
From: bowenlan-amzn
Date: Tue, 21 May 2024 07:49:37 -0700
Subject: [PATCH 03/17] Fully understand the remaining disjunction scorer

Signed-off-by: bowenlan-amzn
---
 .../metrics/CardinalityAggregator.java             |  8 +++++++-
 .../DisjunctionWithDynamicPruningScorer.java       | 14 ++++++--------
 .../aggregations/metrics/InternalCardinality.java  |  1 -
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
index 91887e2e4a202..dd2e7458d81d2 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
@@ -125,6 +125,7 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
         if (valuesSource instanceof ValuesSource.Bytes.WithOrdinals) {
             ValuesSource.Bytes.WithOrdinals source = (ValuesSource.Bytes.WithOrdinals) valuesSource;
             final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
+            final SortedSetDocValues globalOrdinalValues = source.globalOrdinalsValues(ctx);
             final long maxOrd = ordinalValues.getValueCount();
             if (maxOrd == 0) {
                 emptyCollectorsUsed++;
@@ -179,7 +180,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) {
         if (counts == null || owningBucketOrdinal >= counts.maxOrd() || counts.cardinality(owningBucketOrdinal) == 0) {
             return buildEmptyAggregation();
         }
-        // We need to build a copy because the returned Aggregation needs remain usable after
+        // We need to build a copy because the returned Aggregation needs to remain usable after
         // this Aggregator (and its HLL++ counters) is released.
         AbstractHyperLogLogPlusPlus copy = counts.clone(owningBucketOrdinal, BigArrays.NON_RECYCLING_INSTANCE);
         return new InternalCardinality(name, copy, metadata());
@@ -322,6 +323,9 @@ public void collect(int doc, long bucketOrd) throws IOException {
                     bits.set((int) ord);
                 }
             }
+            // for this owning bucket ord, save the values of the current doc as value ordinal bits
+            // visitedOrds (array with index as owning bucket, value as bits)
+            // ordinals is a number array, each element representing a text term, sorted by the values
         }

         @Override
@@ -336,6 +340,7 @@ public void postCollect() throws IOException {
             try (LongArray hashes = bigArrays.newLongArray(maxOrd, false)) {
                 final MurmurHash3.Hash128 hash = new MurmurHash3.Hash128();
+                // for every ordinal, we want the hash of its value
                 for (long ord = allVisitedOrds.nextSetBit(0); ord < Long.MAX_VALUE; ord = ord + 1 < maxOrd ?
                     allVisitedOrds.nextSetBit(ord + 1) : Long.MAX_VALUE) {
@@ -347,6 +352,7 @@ for (long bucket = visitedOrds.size() - 1; bucket >= 0; --bucket) {
                 final BitArray bits = visitedOrds.get(bucket);
                 if (bits != null) {
+                    // for every ordinal of this bucket, we collect by using its hash
                     for (long ord = bits.nextSetBit(0); ord < Long.MAX_VALUE; ord = ord + 1 < maxOrd ?
bits.nextSetBit(ord + 1) : Long.MAX_VALUE) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java index 75fa3d2eb6f93..9b6a0f42e0fa9 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java @@ -31,7 +31,7 @@ * rest of the search space, so this term's subscorer DISI can be safely removed from list of subscorer to process. *

 * 2. {@link #removeAllDISIsOnCurrentDoc()} breaks the invariant of Conjunction DISI i.e. the docIDs of all sub-scorers should be
- * less than or equal to current docID iterator is pointing to. When we remove elements from priority, it results in heapify action, which modifies
+ * less than or equal to current docID iterator is pointing to. When we remove elements from disi priority queue, it results in heapify action, which modifies
 * the top of the priority queue, which represents the current docID for subscorers here. To address this, we are wrapping the
 * iterator with {@link SlowDocIdPropagatorDISI} which keeps the iterator pointing to last docID before {@link #removeAllDISIsOnCurrentDoc()}
 * is called and updates this docID only when next() or advance() is called.
 */
@@ -97,7 +97,6 @@ public DocIdSetIterator iterator() {

     private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
         DocIdSetIterator disi;
-
         Integer curDocId;

         SlowDocIdPropagatorDISI(DocIdSetIterator disi, Integer curDocId) {
@@ -147,11 +146,6 @@ public TwoPhaseIterator twoPhaseIterator() {
         return twoPhase;
     }

-    @Override
-    public float getMaxScore(int i) throws IOException {
-        return 0;
-    }
-
     private class TwoPhase extends TwoPhaseIterator {

         private final float matchCost;
@@ -231,7 +225,6 @@ public float matchCost() {
         }
     }

-
     @Override
     public final int docID() {
         return subScorers.top().doc;
@@ -262,4 +255,9 @@ public final Collection getChildren() throws IOException {
         }
         return children;
     }
+
+    @Override
+    public float getMaxScore(int i) throws IOException {
+        return 0;
+    }
 }
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java
index 7e9511ffdd379..9f9ad63220fea 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java
@@ -117,7 +117,6 @@ public InternalAggregation reduce(List aggregations, Reduce
             return aggregations.get(0);
         } else {
             return new InternalCardinality(name, reduced, getMetadata());
-
         }
     }

From 85133c4c3ed43aaf339f665597f89cba98bb761d Mon Sep 17 00:00:00 2001
From: bowenlan-amzn
Date: Fri, 24 May 2024 10:53:25 -0700
Subject: [PATCH 04/17] Utilize competitive iterator API to perform pruning

Signed-off-by: bowenlan-amzn
---
 .../metrics/CardinalityAggregator.java             | 124 +++++++++++++++++-
 .../DisjunctionWithDynamicPruningScorer.java       |  25 ++--
 .../DynamicPruningCollectorWrapper.java            |  13 +-
 .../metrics/CardinalityAggregatorTests.java        |  90 +++++++------
 4 files changed, 193 insertions(+), 59 deletions(-)

diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
index dd2e7458d81d2..6f66fc64f6dc7 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
@@ -35,7 +35,15 @@
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedNumericDocValues;
 import org.apache.lucene.index.SortedSetDocValues;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DocIdSetIterator;
 import
org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.RamUsageEstimator; @@ -59,6 +67,7 @@ import org.opensearch.search.internal.SearchContext; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.function.BiConsumer; @@ -137,8 +146,15 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException { // only use ordinals if they don't increase memory usage by more than 25% if (ordinalsMemoryUsage < countsMemoryUsage / 4) { ordinalsCollectorsUsed++; - return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), - context, ctx, fieldContext, source); + // return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), + // context, ctx, fieldContext, source); + return new CompetitiveCollector( + new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), + source, + ctx, + context, + fieldContext + ); } ordinalsCollectorsOverheadTooHigh++; } @@ -217,6 +233,110 @@ abstract static class Collector extends LeafBucketCollector implements Releasabl } + private static class CompetitiveCollector extends Collector { + + private final Collector delegate; + private final DisiPriorityQueue pq; + + CompetitiveCollector( + Collector delegate, + ValuesSource.Bytes.WithOrdinals source, + LeafReaderContext ctx, + SearchContext context, + FieldContext fieldContext + ) throws IOException { + this.delegate = delegate; + + final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx); + TermsEnum terms = ordinalValues.termsEnum(); + Map postingMap = new HashMap<>(); + while (terms.next() != null) { + BytesRef term = terms.term(); + + TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), term)); + Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f); + Scorer scorer = subWeight.scorer(ctx); + + postingMap.put(term, scorer); + } + this.pq = new DisiPriorityQueue(postingMap.size()); + for (Map.Entry entry : postingMap.entrySet()) { + pq.add(new DisiWrapper(entry.getValue())); + } + } + + @Override + public void close() { + delegate.close(); + } + + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + delegate.collect(doc, owningBucketOrd); + } + + @Override + public DocIdSetIterator competitiveIterator() throws IOException { + return new DisjunctionDISIWithPruning(pq); + } + + @Override + public void postCollect() throws IOException { + delegate.postCollect(); + } + } + + private static class DisjunctionDISIWithPruning extends DocIdSetIterator { + + final DisiPriorityQueue queue; + + public DisjunctionDISIWithPruning(DisiPriorityQueue queue) { + this.queue = queue; + } + + @Override + public int docID() { + return queue.top().doc; + } + + @Override + public int nextDoc() throws IOException { + // don't expect this to be called + throw new UnsupportedOperationException(); + } + + @Override + public int advance(int target) throws IOException { + // more than advance to the next doc >= target + // we also do the pruning of current doc here + + DisiWrapper top = queue.top(); + + // after collecting the doc, before advancing to target + // we can safely remove all the iterators that having this doc + if (top.doc != -1) { + int curTopDoc = top.doc; + do { + 
top.doc = top.approximation.advance(Integer.MAX_VALUE); + top = queue.updateTop(); + } while (top.doc == curTopDoc); + } + + if (top.doc >= target) return top.doc; + do { + top.doc = top.approximation.advance(target); + top = queue.updateTop(); + } while (top.doc < target); + return top.doc; + } + + @Override + public long cost() { + // don't expect this to be called + throw new UnsupportedOperationException(); + } + } + /** * Empty Collector for the Cardinality agg * diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java index 9b6a0f42e0fa9..93a01aaa1e053 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java @@ -22,7 +22,6 @@ import java.util.Collection; import java.util.List; - /** * Clone of {@link org.apache.lucene.search} {@code DisjunctionScorer.java} in lucene with following modifications *

@@ -45,8 +44,7 @@ public class DisjunctionWithDynamicPruningScorer extends Scorer { private Integer docID; - public DisjunctionWithDynamicPruningScorer(Weight weight, List subScorers) - throws IOException { + public DisjunctionWithDynamicPruningScorer(Weight weight, List subScorers) throws IOException { super(weight); if (subScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); @@ -90,9 +88,9 @@ public void removeAllDISIsOnCurrentDoc() { @Override public DocIdSetIterator iterator() { - DocIdSetIterator disi = getIterator(); - docID = disi.docID(); - return new SlowDocIdPropagatorDISI(getIterator(), docID); + DocIdSetIterator disi = getIterator(); + docID = disi.docID(); + return new SlowDocIdPropagatorDISI(getIterator(), docID); } private static class SlowDocIdPropagatorDISI extends DocIdSetIterator { @@ -157,13 +155,12 @@ private class TwoPhase extends TwoPhaseIterator { private TwoPhase(DocIdSetIterator approximation, float matchCost) { super(approximation); this.matchCost = matchCost; - unverifiedMatches = - new PriorityQueue(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) { - @Override - protected boolean lessThan(DisiWrapper a, DisiWrapper b) { - return a.matchCost < b.matchCost; - } - }; + unverifiedMatches = new PriorityQueue(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) { + @Override + protected boolean lessThan(DisiWrapper a, DisiWrapper b) { + return a.matchCost < b.matchCost; + } + }; } DisiWrapper getSubMatches() throws IOException { @@ -183,7 +180,7 @@ public boolean matches() throws IOException { verifiedMatches = null; unverifiedMatches.clear(); - for (DisiWrapper w = subScorers.topList(); w != null; ) { + for (DisiWrapper w = subScorers.topList(); w != null;) { DisiWrapper next = w.next; if (w.twoPhaseView == null) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java index f4c3d59a3833f..cb735a3257289 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java @@ -8,7 +8,6 @@ package org.opensearch.search.aggregations.metrics; -import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.Term; @@ -38,9 +37,13 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector { private final DocIdSetIterator disi; private final CardinalityAggregator.Collector delegateCollector; - DynamicPruningCollectorWrapper(CardinalityAggregator.Collector delegateCollector, - SearchContext context, LeafReaderContext ctx, FieldContext fieldContext, - ValuesSource.Bytes.WithOrdinals source) throws IOException { + DynamicPruningCollectorWrapper( + CardinalityAggregator.Collector delegateCollector, + SearchContext context, + LeafReaderContext ctx, + FieldContext fieldContext, + ValuesSource.Bytes.WithOrdinals source + ) throws IOException { this.ctx = ctx; this.delegateCollector = delegateCollector; final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx); @@ -52,7 +55,7 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector { // this logic should be pluggable depending on the type of leaf bucket collector by CardinalityAggregator TermsEnum terms 
= ordinalValues.termsEnum(); Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f); - Map found = new HashMap<>(); + Map found = new HashMap<>(); // ord : found or not List subScorers = new ArrayList<>(); while (terms.next() != null && !found.containsKey(terms.ord())) { // TODO can we get rid of terms previously encountered in other segments? diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java index a9966c9e70e76..d21e7f6ed8550 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java @@ -105,44 +105,58 @@ public void testDynamicPruningOrdinalCollector() throws IOException { MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> { - iw.addDocument(asList( - new KeywordField(fieldName, "1", Field.Store.NO), - new KeywordField(fieldName, "2", Field.Store.NO), - new KeywordField(filterFieldName, "foo", Field.Store.NO), - new SortedSetDocValuesField(fieldName, new BytesRef("1")), - new SortedSetDocValuesField(fieldName, new BytesRef("2")) - )); - iw.addDocument(asList( - new KeywordField(fieldName, "2", Field.Store.NO), - new KeywordField(filterFieldName, "foo", Field.Store.NO), - new SortedSetDocValuesField(fieldName, new BytesRef("2")) - )); - iw.addDocument(asList( - new KeywordField(fieldName, "1", Field.Store.NO), - new KeywordField(filterFieldName, "foo", Field.Store.NO), - new SortedSetDocValuesField(fieldName, new BytesRef("1")) - )); - iw.addDocument(asList( - new KeywordField(fieldName, "2", Field.Store.NO), - new KeywordField(filterFieldName, "foo", Field.Store.NO), - new SortedSetDocValuesField(fieldName, new BytesRef("2")) - )); - iw.addDocument(asList( - new KeywordField(fieldName, "3", Field.Store.NO), - new KeywordField(filterFieldName, "foo", Field.Store.NO), - new SortedSetDocValuesField(fieldName, new BytesRef("3")) - )); - iw.addDocument(asList( - new KeywordField(fieldName, "4", Field.Store.NO), - new KeywordField(filterFieldName, "bar", Field.Store.NO), - new SortedSetDocValuesField(fieldName, new BytesRef("4")) - )); - iw.addDocument(asList( - new KeywordField(fieldName, "5", Field.Store.NO), - new KeywordField(filterFieldName, "bar", Field.Store.NO), - new SortedSetDocValuesField(fieldName, new BytesRef("5")) - )); - }, card -> { + iw.addDocument( + asList( + new KeywordField(fieldName, "1", Field.Store.NO), + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("1")), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "1", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("1")) + ) + ); + 
iw.addDocument( + asList( + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "3", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("3")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "4", Field.Store.NO), + new KeywordField(filterFieldName, "bar", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("4")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "5", Field.Store.NO), + new KeywordField(filterFieldName, "bar", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("5")) + ) + ); + }, card -> { assertEquals(3.0, card.getValue(), 0); assertTrue(AggregationInspectionHelper.hasValue(card)); }, fieldType); From 9d4701c0866c8293e51896af44039574f5ff466f Mon Sep 17 00:00:00 2001 From: bowenlan-amzn Date: Fri, 24 May 2024 15:34:03 -0700 Subject: [PATCH 05/17] handle missing input Signed-off-by: bowenlan-amzn --- .../search/aggregations/metrics/CardinalityAggregator.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java index 6f66fc64f6dc7..6fe1e7345cda4 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java @@ -81,6 +81,7 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue private final int precision; private final ValuesSource valuesSource; + private final ValuesSourceConfig valuesSourceConfig; private final FieldContext fieldContext; // Expensive to initialize, so we only initialize it when we have an actual value source @@ -105,6 +106,7 @@ public CardinalityAggregator( ) throws IOException { super(name, context, parent, metadata); // TODO: Stop using nulls here + this.valuesSourceConfig = valuesSourceConfig; this.valuesSource = valuesSourceConfig.hasValues() ? 
valuesSourceConfig.getValuesSource() : null; this.precision = precision; this.fieldContext = valuesSourceConfig.fieldContext(); @@ -148,6 +150,9 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException { ordinalsCollectorsUsed++; // return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), // context, ctx, fieldContext, source); + if (valuesSourceConfig.missing() != null) { + return new OrdinalsCollector(counts, ordinalValues, context.bigArrays()); + } return new CompetitiveCollector( new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), source, From 77fceeaaed9f9ea6fa621e7aef1341a956e1716c Mon Sep 17 00:00:00 2001 From: bowenlan-amzn Date: Fri, 24 May 2024 15:46:53 -0700 Subject: [PATCH 06/17] add change log Signed-off-by: bowenlan-amzn --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b4a5e5f4f981..c62122ddd9ce9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Make outbound side of transport protocol dependent ([#13293](https://github.com/opensearch-project/OpenSearch/pull/13293)) - [Remote Store] Add dynamic cluster settings to set timeout for segments upload to Remote Store ([#13679](https://github.com/opensearch-project/OpenSearch/pull/13679)) - [Remote Store] Upload translog checkpoint as object metadata to translog.tlog([#13637](https://github.com/opensearch-project/OpenSearch/pull/13637)) +- Support Dynamic Pruning in Cardinality Aggregation ([#13821](https://github.com/opensearch-project/OpenSearch/pull/13821)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) From d0043ae610fefd3c4d417d823c44e6354aec736a Mon Sep 17 00:00:00 2001 From: bowenlan-amzn Date: Tue, 4 Jun 2024 11:51:01 -0700 Subject: [PATCH 07/17] clean up Signed-off-by: bowenlan-amzn --- .idea/runConfigurations/Debug_OpenSearch.xml | 6 +- .../metrics/CardinalityAggregator.java | 11 +- .../DisjunctionWithDynamicPruningScorer.java | 260 ------------------ .../DynamicPruningCollectorWrapper.java | 109 -------- 4 files changed, 8 insertions(+), 378 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java delete mode 100644 server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java diff --git a/.idea/runConfigurations/Debug_OpenSearch.xml b/.idea/runConfigurations/Debug_OpenSearch.xml index 0d8bf59823acf..c18046f873477 100644 --- a/.idea/runConfigurations/Debug_OpenSearch.xml +++ b/.idea/runConfigurations/Debug_OpenSearch.xml @@ -6,6 +6,10 @@

- * 1. {@link #removeAllDISIsOnCurrentDoc()} removes all the DISIs for subscorer pointing to current doc. This is - * helpful in dynamic pruning for Cardinality aggregation, where once a term is found, it becomes irrelevant for - * rest of the search space, so this term's subscorer DISI can be safely removed from list of subscorer to process. - *

- * 2. {@link #removeAllDISIsOnCurrentDoc()} breaks the invariant of Conjunction DISI i.e. the docIDs of all sub-scorers should be
- * less than or equal to current docID iterator is pointing to. When we remove elements from disi priority queue, it results in heapify action, which modifies
- * the top of the priority queue, which represents the current docID for subscorers here. To address this, we are wrapping the
- * iterator with {@link SlowDocIdPropagatorDISI} which keeps the iterator pointing to last docID before {@link #removeAllDISIsOnCurrentDoc()}
- * is called and updates this docID only when next() or advance() is called.
- */
-public class DisjunctionWithDynamicPruningScorer extends Scorer {
-
-    private final boolean needsScores;
-    private final DisiPriorityQueue subScorers;
-    private final DocIdSetIterator approximation;
-    private final TwoPhase twoPhase;
-
-    private Integer docID;
-
-    public DisjunctionWithDynamicPruningScorer(Weight weight, List subScorers) throws IOException {
-        super(weight);
-        if (subScorers.size() <= 1) {
-            throw new IllegalArgumentException("There must be at least 2 subScorers");
-        }
-        this.subScorers = new DisiPriorityQueue(subScorers.size());
-        for (Scorer scorer : subScorers) {
-            final DisiWrapper w = new DisiWrapper(scorer);
-            this.subScorers.add(w);
-        }
-        this.needsScores = false;
-        this.approximation = new DisjunctionDISIApproximation(this.subScorers);
-
-        boolean hasApproximation = false;
-        float sumMatchCost = 0;
-        long sumApproxCost = 0;
-        // Compute matchCost as the average over the matchCost of the subScorers.
-        // This is weighted by the cost, which is an expected number of matching documents.
-        for (DisiWrapper w : this.subScorers) {
-            long costWeight = (w.cost <= 1) ? 1 : w.cost;
-            sumApproxCost += costWeight;
-            if (w.twoPhaseView != null) {
-                hasApproximation = true;
-                sumMatchCost += w.matchCost * costWeight;
-            }
-        }
-
-        if (hasApproximation == false) { // no sub scorer supports approximations
-            twoPhase = null;
-        } else {
-            final float matchCost = sumMatchCost / sumApproxCost;
-            twoPhase = new TwoPhase(approximation, matchCost);
-        }
-    }
-
-    public void removeAllDISIsOnCurrentDoc() {
-        docID = this.docID();
-        while (subScorers.size() > 0 && subScorers.top().doc == docID) {
-            subScorers.pop();
-        }
-    }
-
-    @Override
-    public DocIdSetIterator iterator() {
-        DocIdSetIterator disi = getIterator();
-        docID = disi.docID();
-        return new SlowDocIdPropagatorDISI(getIterator(), docID);
-    }
-
-    private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
-        DocIdSetIterator disi;
-        Integer curDocId;
-
-        SlowDocIdPropagatorDISI(DocIdSetIterator disi, Integer curDocId) {
-            this.disi = disi;
-            this.curDocId = curDocId;
-        }
-
-        @Override
-        public int docID() {
-            assert curDocId <= disi.docID();
-            return curDocId;
-        }
-
-        @Override
-        public int nextDoc() throws IOException {
-            return advance(curDocId + 1);
-        }
-
-        @Override
-        public int advance(int i) throws IOException {
-            if (i <= disi.docID()) {
-                // since we are slow propagating docIDs, it may happen the disi is already advanced to a higher docID than i
-                // in such scenarios we can simply return the docID where disi is pointing to and update the curDocId
-                curDocId = disi.docID();
-                return disi.docID();
-            }
-            curDocId = disi.advance(i);
-            return curDocId;
-        }
-
-        @Override
-        public long cost() {
-            return disi.cost();
-        }
-    }
-
-    private DocIdSetIterator getIterator() {
-        if (twoPhase != null) {
-            return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
-        } else {
-            return approximation;
-        }
-    }
- - @Override - public TwoPhaseIterator twoPhaseIterator() { - return twoPhase; - } - - private class TwoPhase extends TwoPhaseIterator { - - private final float matchCost; - // list of verified matches on the current doc - DisiWrapper verifiedMatches; - // priority queue of approximations on the current doc that have not been verified yet - final PriorityQueue unverifiedMatches; - - private TwoPhase(DocIdSetIterator approximation, float matchCost) { - super(approximation); - this.matchCost = matchCost; - unverifiedMatches = new PriorityQueue(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) { - @Override - protected boolean lessThan(DisiWrapper a, DisiWrapper b) { - return a.matchCost < b.matchCost; - } - }; - } - - DisiWrapper getSubMatches() throws IOException { - // iteration order does not matter - for (DisiWrapper w : unverifiedMatches) { - if (w.twoPhaseView.matches()) { - w.next = verifiedMatches; - verifiedMatches = w; - } - } - unverifiedMatches.clear(); - return verifiedMatches; - } - - @Override - public boolean matches() throws IOException { - verifiedMatches = null; - unverifiedMatches.clear(); - - for (DisiWrapper w = subScorers.topList(); w != null;) { - DisiWrapper next = w.next; - - if (w.twoPhaseView == null) { - // implicitly verified, move it to verifiedMatches - w.next = verifiedMatches; - verifiedMatches = w; - - if (needsScores == false) { - // we can stop here - return true; - } - } else { - unverifiedMatches.add(w); - } - w = next; - } - - if (verifiedMatches != null) { - return true; - } - - // verify subs that have an two-phase iterator - // least-costly ones first - while (unverifiedMatches.size() > 0) { - DisiWrapper w = unverifiedMatches.pop(); - if (w.twoPhaseView.matches()) { - w.next = null; - verifiedMatches = w; - return true; - } - } - - return false; - } - - @Override - public float matchCost() { - return matchCost; - } - } - - @Override - public final int docID() { - return subScorers.top().doc; - } - - DisiWrapper getSubMatches() throws IOException { - if (twoPhase == null) { - return subScorers.topList(); - } else { - return twoPhase.getSubMatches(); - } - } - - @Override - public final float score() throws IOException { - return score(getSubMatches()); - } - - protected float score(DisiWrapper topList) throws IOException { - return 1f; - } - - @Override - public final Collection getChildren() throws IOException { - ArrayList children = new ArrayList<>(); - for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) { - children.add(new ChildScorable(scorer.scorer, "SHOULD")); - } - return children; - } - - @Override - public float getMaxScore(int i) throws IOException { - return 0; - } -} diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java deleted file mode 100644 index cb735a3257289..0000000000000 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. 
- */ - -package org.opensearch.search.aggregations.metrics; - -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.SortedSetDocValues; -import org.apache.lucene.index.Term; -import org.apache.lucene.index.TermsEnum; -import org.apache.lucene.search.CollectionTerminatedException; -import org.apache.lucene.search.ConjunctionUtils; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.Weight; -import org.apache.lucene.util.Bits; -import org.opensearch.search.aggregations.support.FieldContext; -import org.opensearch.search.aggregations.support.ValuesSource; -import org.opensearch.search.internal.SearchContext; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector { - - private final LeafReaderContext ctx; - private final DisjunctionWithDynamicPruningScorer disjunctionScorer; - private final DocIdSetIterator disi; - private final CardinalityAggregator.Collector delegateCollector; - - DynamicPruningCollectorWrapper( - CardinalityAggregator.Collector delegateCollector, - SearchContext context, - LeafReaderContext ctx, - FieldContext fieldContext, - ValuesSource.Bytes.WithOrdinals source - ) throws IOException { - this.ctx = ctx; - this.delegateCollector = delegateCollector; - final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx); - boolean isCardinalityLow = ordinalValues.getValueCount() < 10; - boolean isCardinalityAggregationOnlyAggregation = true; - boolean isFieldSupportedForDynamicPruning = true; - if (isCardinalityLow && isCardinalityAggregationOnlyAggregation && isFieldSupportedForDynamicPruning) { - // create disjunctions from terms - // this logic should be pluggable depending on the type of leaf bucket collector by CardinalityAggregator - TermsEnum terms = ordinalValues.termsEnum(); - Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f); - Map found = new HashMap<>(); // ord : found or not - List subScorers = new ArrayList<>(); - while (terms.next() != null && !found.containsKey(terms.ord())) { - // TODO can we get rid of terms previously encountered in other segments? 
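-                // Each iteration below builds one term scorer for one distinct value of the field.
-                // Together the sub-scorers form a disjunction over all values of the field, which is
-                // then intersected with the scorer of the main query. Once a term has been collected
-                // it cannot change the cardinality again, so its sub-scorer can be pruned from the
-                // disjunction via removeAllDISIsOnCurrentDoc().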
- TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), terms.term())); - Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f); - Scorer scorer = subWeight.scorer(ctx); - if (scorer != null) { - subScorers.add(scorer); - } - found.put(terms.ord(), true); - } - disjunctionScorer = new DisjunctionWithDynamicPruningScorer(weight, subScorers); - disi = ConjunctionUtils.intersectScorers(List.of(disjunctionScorer, weight.scorer(ctx))); - } else { - disjunctionScorer = null; - disi = null; - } - } - - @Override - public void collect(int doc, long bucketOrd) throws IOException { - if (disi == null || disjunctionScorer == null) { - delegateCollector.collect(doc, bucketOrd); - } else { - // perform the full iteration using dynamic pruning of DISIs and return right away - disi.advance(doc); - int currDoc = disi.docID(); - assert currDoc == doc; - final Bits liveDocs = ctx.reader().getLiveDocs(); - assert liveDocs == null || liveDocs.get(currDoc); - do { - if (liveDocs == null || liveDocs.get(currDoc)) { - delegateCollector.collect(currDoc, bucketOrd); - disjunctionScorer.removeAllDISIsOnCurrentDoc(); - } - currDoc = disi.nextDoc(); - } while (currDoc != DocIdSetIterator.NO_MORE_DOCS); - throw new CollectionTerminatedException(); - } - } - - @Override - public void close() { - delegateCollector.close(); - } - - @Override - public void postCollect() throws IOException { - delegateCollector.postCollect(); - } -} From 82a5f0676871e0116269606c837ec82f6b9b58cf Mon Sep 17 00:00:00 2001 From: bowenlan-amzn Date: Tue, 4 Jun 2024 16:46:50 -0700 Subject: [PATCH 08/17] Clean up Signed-off-by: bowenlan-amzn --- .idea/runConfigurations/Debug_OpenSearch.xml | 6 +- .../metrics/CardinalityAggregator.java | 85 +++++++++---------- 2 files changed, 42 insertions(+), 49 deletions(-) diff --git a/.idea/runConfigurations/Debug_OpenSearch.xml b/.idea/runConfigurations/Debug_OpenSearch.xml index c18046f873477..0d8bf59823acf 100644 --- a/.idea/runConfigurations/Debug_OpenSearch.xml +++ b/.idea/runConfigurations/Debug_OpenSearch.xml @@ -6,10 +6,6 @@