Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify precomputation of aggregations behind a common API #16733

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
package org.opensearch.search.aggregations;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.core.common.breaker.CircuitBreaker;
Expand Down Expand Up @@ -200,6 +201,9 @@ public Map<String, Object> metadata() {

@Override
public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
if (tryPrecomputeAggregationForLeaf(ctx)) {
throw new CollectionTerminatedException();
}
preGetSubLeafCollectors(ctx);
final LeafBucketCollector sub = collectableSubAggregators.getLeafCollector(ctx);
return getLeafCollector(ctx, sub);
Expand All @@ -216,6 +220,10 @@ protected void preGetSubLeafCollectors(LeafReaderContext ctx) throws IOException
*/
protected void doPreCollection() throws IOException {}

protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
msfroh marked this conversation as resolved.
Show resolved Hide resolved
msfroh marked this conversation as resolved.
Show resolved Hide resolved
return false;
}

@Override
public final void preCollection() throws IOException {
List<BucketCollector> collectors = Arrays.asList(subAggregators);
Expand Down Expand Up @@ -251,8 +259,8 @@ public Aggregator[] subAggregators() {
public Aggregator subAggregator(String aggName) {
if (subAggregatorbyName == null) {
subAggregatorbyName = new HashMap<>(subAggregators.length);
for (int i = 0; i < subAggregators.length; i++) {
subAggregatorbyName.put(subAggregators[i].name(), subAggregators[i]);
for (Aggregator subAggregator : subAggregators) {
subAggregatorbyName.put(subAggregator.name(), subAggregator);
}
}
return subAggregatorbyName.get(aggName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,13 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
}

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
msfroh marked this conversation as resolved.
Show resolved Hide resolved
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
}

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
finishLeaf();

boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.CollectionUtil;
Expand Down Expand Up @@ -187,22 +186,23 @@ public ScoreMode scoreMode() {
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();

SortedNumericDocValues values = valuesSource.longValues(ctx);
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
if (supportedStarTree != null) {
if (preComputeWithStarTree(ctx, supportedStarTree) == true) {
throw new CollectionTerminatedException();
return true;
}
}
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

SortedNumericDocValues values = valuesSource.longValues(ctx);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
package org.opensearch.search.aggregations.bucket.range;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -310,10 +309,15 @@ public ScoreMode scoreMode() {
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
throw new CollectionTerminatedException();
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
if (segmentMatchAll(context, ctx)) {
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {

final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
return new LeafBucketCollectorBase(sub, values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
Expand Down Expand Up @@ -165,35 +164,32 @@ public void setWeight(Weight weight) {
@return A LeafBucketCollector implementation with collection termination, since collection is complete
@throws IOException If an I/O error occurs during reading
*/
LeafBucketCollector termDocFreqCollector(
LeafReaderContext ctx,
SortedSetDocValues globalOrds,
BiConsumer<Long, Integer> ordCountConsumer
) throws IOException {
boolean tryCollectFromTermFrequencies(LeafReaderContext ctx, SortedSetDocValues globalOrds, BiConsumer<Long, Integer> ordCountConsumer)
throws IOException {
if (weight == null) {
// Weight not assigned - cannot use this optimization
return null;
return false;
} else {
if (weight.count(ctx) == 0) {
// No documents matches top level query on this segment, we can skip the segment entirely
return LeafBucketCollector.NO_OP_COLLECTOR;
return true;
} else if (weight.count(ctx) != ctx.reader().maxDoc()) {
// weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and
// top-level query matches all docs in the segment
return null;
return false;
}
}

Terms segmentTerms = ctx.reader().terms(this.fieldName);
if (segmentTerms == null) {
// Field is not indexed.
return null;
return false;
}

NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
if (docCountValues.nextDoc() != NO_MORE_DOCS) {
// This segment has at least one document with the _doc_count field.
return null;
return false;
}

TermsEnum indexTermsEnum = segmentTerms.iterator();
Expand All @@ -217,31 +213,28 @@ LeafBucketCollector termDocFreqCollector(
ordinalTerm = globalOrdinalTermsEnum.next();
}
}
return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
throw new CollectionTerminatedException();
}
};
return true;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
collectionStrategy.globalOrdsReady(globalOrds);

if (collectionStrategy instanceof DenseGlobalOrds
&& this.resultStrategy instanceof StandardTermsResults
&& sub == LeafBucketCollector.NO_OP_COLLECTOR) {
LeafBucketCollector termDocFreqCollector = termDocFreqCollector(
&& subAggregators.length == 0) {
msfroh marked this conversation as resolved.
Show resolved Hide resolved
return tryCollectFromTermFrequencies(
ctx,
globalOrds,
(ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount)
);
if (termDocFreqCollector != null) {
return termDocFreqCollector;
}
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
collectionStrategy.globalOrdsReady(globalOrds);

SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds);
if (singleValues != null) {
Expand Down Expand Up @@ -436,6 +429,24 @@ static class LowCardinality extends GlobalOrdinalsStringTermsAggregator {
this.segmentDocCounts = context.bigArrays().newLongArray(1, true);
}

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
if (subAggregators.length == 0) {
if (mapping != null) {
mapSegmentCountsToGlobalCounts(mapping);
}
final SortedSetDocValues segmentOrds = valuesSource.ordinalsValues(ctx);
segmentDocCounts = context.bigArrays().grow(segmentDocCounts, 1 + segmentOrds.getValueCount());
mapping = valuesSource.globalOrdinalsMapping(ctx);
return tryCollectFromTermFrequencies(
ctx,
segmentOrds,
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
);
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (mapping != null) {
Expand All @@ -446,17 +457,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
assert sub == LeafBucketCollector.NO_OP_COLLECTOR;
mapping = valuesSource.globalOrdinalsMapping(ctx);

if (this.resultStrategy instanceof StandardTermsResults) {
LeafBucketCollector termDocFreqCollector = this.termDocFreqCollector(
ctx,
segmentOrds,
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
);
if (termDocFreqCollector != null) {
return termDocFreqCollector;
}
}

final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds);
if (singleValues != null) {
segmentsWithSingleValuedOrds++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ public ScoreMode scoreMode() {
return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
}

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
if (supportedStarTree != null) {
if (parent != null && subAggregators.length == 0) {
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
return true;
msfroh marked this conversation as resolved.
Show resolved Hide resolved
}
precomputeLeafUsingStarTree(ctx, supportedStarTree);
return true;
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
Expand All @@ -130,20 +145,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
}
}

CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
if (supportedStarTree != null) {
if (parent != null && subAggregators.length == 0) {
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
return LeafBucketCollector.NO_OP_COLLECTOR;
}
getStarTreeCollector(ctx, sub, supportedStarTree);
}
return getDefaultLeafCollector(ctx, sub);
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {

final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
final NumericDoubleValues values = MultiValueMode.MAX.select(allValues);
Expand All @@ -167,9 +168,9 @@ public void collect(int doc, long bucket) throws IOException {
};
}

public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException {
private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
AtomicReference<Double> max = new AtomicReference<>(maxes.get(0));
StarTreeQueryHelper.getStarTreeLeafCollector(context, valuesSource, ctx, sub, starTree, MetricStat.MAX.getTypeName(), value -> {
StarTreeQueryHelper.precomputeLeafUsingStarTree(context, valuesSource, ctx, starTree, MetricStat.MAX.getTypeName(), value -> {
max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value))));
}, () -> maxes.set(0, max.get()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,22 @@ public ScoreMode scoreMode() {
return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
}

@Override
msfroh marked this conversation as resolved.
Show resolved Hide resolved
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
if (supportedStarTree != null) {
if (parent != null && subAggregators.length == 0) {
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
return true;
}
precomputeLeafUsingStarTree(ctx, supportedStarTree);
return true;
}

return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
Expand All @@ -129,19 +145,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
}
}

CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
if (supportedStarTree != null) {
if (parent != null && subAggregators.length == 0) {
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
return LeafBucketCollector.NO_OP_COLLECTOR;
}
getStarTreeCollector(ctx, sub, supportedStarTree);
}
return getDefaultLeafCollector(ctx, sub);
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
final NumericDoubleValues values = MultiValueMode.MIN.select(allValues);
Expand All @@ -164,9 +167,9 @@ public void collect(int doc, long bucket) throws IOException {
};
}

public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException {
private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
AtomicReference<Double> min = new AtomicReference<>(mins.get(0));
StarTreeQueryHelper.getStarTreeLeafCollector(context, valuesSource, ctx, sub, starTree, MetricStat.MIN.getTypeName(), value -> {
StarTreeQueryHelper.precomputeLeafUsingStarTree(context, valuesSource, ctx, starTree, MetricStat.MIN.getTypeName(), value -> {
min.set(Math.min(min.get(), (NumericUtils.sortableLongToDouble(value))));
}, () -> mins.set(0, min.get()));
}
Expand Down
Loading
Loading