ESQL: Speed up VALUES for many buckets
Speeds up the VALUES agg when collecting from many buckets. Specifically,
this speeds up the algorithm used to `finish` the aggregation. Most
specifically, this makes the algorithm more tolerant of large numbers of
groups being collected. The old algorithm was `O(n^2)` in the number of
groups; the new one is `O(n)`.
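
The shape of the change, read from the commit message rather than from the
diff itself: if the old `finish` rebuilt each group's values by scanning every
collected (group, value) pair once per group, that is `O(groups * pairs)`; a
counting pass that sizes each group up front and then drops every pair into
its slot is linear. Below is a minimal sketch of that idea only, with
hypothetical names (`finishQuadratic`, `finishLinear`) and plain arrays
standing in for the real `Block` types, not the actual Elasticsearch
implementation.

```
import java.util.ArrayList;
import java.util.List;

class ValuesFinishSketch {
    // Old shape: for every group, scan all collected (group, value) pairs
    // looking for that group's values. O(groups * pairs), which is what
    // falls over once there are hundreds of thousands of groups.
    static List<List<Long>> finishQuadratic(int groupCount, int[] groupIds, long[] values) {
        List<List<Long>> result = new ArrayList<>(groupCount);
        for (int g = 0; g < groupCount; g++) {
            List<Long> groupValues = new ArrayList<>();
            for (int i = 0; i < groupIds.length; i++) {
                if (groupIds[i] == g) {
                    groupValues.add(values[i]);
                }
            }
            result.add(groupValues);
        }
        return result;
    }

    // New shape: one counting pass sizes each group's output, then a second
    // pass drops every pair into its slot. O(pairs + groups) overall.
    static long[][] finishLinear(int groupCount, int[] groupIds, long[] values) {
        int[] counts = new int[groupCount];
        for (int groupId : groupIds) {
            counts[groupId]++;
        }
        long[][] result = new long[groupCount][];
        for (int g = 0; g < groupCount; g++) {
            result[g] = new long[counts[g]];
        }
        int[] next = new int[groupCount]; // next free slot per group
        for (int i = 0; i < groupIds.length; i++) {
            int g = groupIds[i];
            result[g][next[g]++] = values[i];
        }
        return result;
    }
}
```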

```
(groups)         old (ms/op)         ->    new (ms/op)
      1     219.683 ±    1.069  ->   223.477 ±    1.990 ms/op
   1000     426.323 ±   75.963  ->   463.670 ±    7.275 ms/op
 100000   36690.871 ± 4656.350  ->  7800.332 ± 2775.869 ms/op
 200000   89422.113 ± 2972.606  -> 21920.288 ± 3427.962 ms/op
 400000 timed out at 10 minutes -> 40051.524 ± 2011.706 ms/op
```

The `1` group case was not changed at all; the difference there is just noise
in the measurement. The small regression in the `1000` case is real, but
almost certainly worth it. The huge drop in the `100000` and larger cases is
very real.
nik9000 committed Feb 20, 2025
1 parent 268413b commit 0478be2
Showing 7 changed files with 730 additions and 177 deletions.
ValuesAggregatorBenchmark.java (new file)
@@ -0,0 +1,238 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.benchmark.compute.operator;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@Warmup(iterations = 5)
@Measurement(iterations = 7)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Fork(1)
public class ValuesAggregatorBenchmark {
    static final int MIN_BLOCK_LENGTH = 8 * 1024;
    private static final int OP_COUNT = 1024;
    private static final BytesRef[] KEYWORDS = new BytesRef[] {
        new BytesRef("Tokyo"),
        new BytesRef("Delhi"),
        new BytesRef("Shanghai"),
        new BytesRef("São Paulo"),
        new BytesRef("Mexico City"),
        new BytesRef("Cairo") };

    private static final BlockFactory blockFactory = BlockFactory.getInstance(
        new NoopCircuitBreaker("noop"),
        BigArrays.NON_RECYCLING_INSTANCE // TODO real big arrays?
    );

    static {
        // Smoke test all the expected values and force loading subclasses more like prod
        try {
            for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) {
                for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) {
                    run(Integer.parseInt(groups), dataType, 10);
                }
            }
        } catch (NoSuchFieldException e) {
            throw new AssertionError(e);
        }
    }

    private static final String BYTES_REF = "BytesRef";

    @Param({ "1", "1000", /*"1000000"*/ })
    public int groups;

    @Param({ BYTES_REF })
    public String dataType;

    private static Operator operator(DriverContext driverContext, int groups, String dataType) {
        if (groups == 1) {
            // Ungrouped: a single VALUES aggregator over channel 0
            return new AggregationOperator(
                List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
                driverContext
            );
        }
        // Grouped: hash on the long key in channel 0 and aggregate the values in channel 1
        List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
        return new HashAggregationOperator(
            List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
            () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
            driverContext
        );
    }

    private static AggregatorFunctionSupplier supplier(String dataType) {
        return switch (dataType) {
            case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier();
            default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
        };
    }

    private static void checkExpected(int groups, String dataType, Page page) {
        String prefix = String.format("[%s][%s]", groups, dataType);
        int positions = page.getPositionCount();
        if (positions != groups) {
            throw new IllegalArgumentException(prefix + " expected " + groups + " positions, got " + positions);
        }
        if (groups == 1) {
            checkUngrouped(prefix, dataType, page);
            return;
        }
        checkGrouped(prefix, groups, dataType, page);
    }

    private static void checkGrouped(String prefix, int groups, String dataType, Page page) {
        LongVector groupsVector = page.<LongBlock>getBlock(0).asVector();
        for (int p = 0; p < groups; p++) {
            long group = groupsVector.getLong(p);
            if (group != p) {
                throw new IllegalArgumentException(prefix + "[" + p + "] expected group " + p + " but was " + group);
            }
        }
        switch (dataType) {
            case BYTES_REF -> {
                BytesRefBlock values = page.getBlock(1);
                // Build the expected values
                List<Set<BytesRef>> expected = new ArrayList<>(groups);
                for (int g = 0; g < groups; g++) {
                    expected.add(new HashSet<>(KEYWORDS.length));
                }
                int blockLength = blockLength(groups);
                for (int p = 0; p < blockLength; p++) {
                    expected.get(p % groups).add(KEYWORDS[p % KEYWORDS.length]);
                }

                // Check them
                for (int p = 0; p < groups; p++) {
                    checkExpectedBytesRef(prefix, values, p, expected.get(p));
                }
            }
            default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
        }
    }

    private static void checkUngrouped(String prefix, String dataType, Page page) {
        switch (dataType) {
            case BYTES_REF -> {
                BytesRefBlock values = page.getBlock(0);
                checkExpectedBytesRef(prefix, values, 0, Set.of(KEYWORDS));
            }
            default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
        }
    }

    private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set<BytesRef> expected) {
        int valueCount = values.getValueCount(position);
        if (valueCount != expected.size()) {
            throw new IllegalArgumentException(
                prefix + "[" + position + "] expected " + expected.size() + " values but count was " + valueCount
            );
        }
        BytesRef scratch = new BytesRef();
        // Values for this position live at [firstValueIndex, firstValueIndex + valueCount)
        int first = values.getFirstValueIndex(position);
        for (int i = first; i < first + valueCount; i++) {
            BytesRef v = values.getBytesRef(i, scratch);
            if (expected.contains(v) == false) {
                throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
            }
        }
    }

    private static Page page(int groups, String dataType) {
        Block dataBlock = dataBlock(groups, dataType);
        if (groups == 1) {
            return new Page(dataBlock);
        }
        return new Page(groupingBlock(groups), dataBlock);
    }

    private static Block dataBlock(int groups, String dataType) {
        return switch (dataType) {
            case BYTES_REF -> {
                int blockLength = blockLength(groups);
                try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) {
                    for (int i = 0; i < blockLength; i++) {
                        builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]);
                    }
                    yield builder.build();
                }
            }
            default -> throw new IllegalArgumentException("unsupported data type " + dataType);
        };
    }

    private static Block groupingBlock(int groups) {
        int blockLength = blockLength(groups);
        try (LongVector.Builder builder = blockFactory.newLongVectorBuilder(blockLength)) {
            // Round-robin the group ids so every group collects every keyword
            for (int i = 0; i < blockLength; i++) {
                builder.appendLong(i % groups);
            }
            return builder.build().asBlock();
        }
    }

    @Benchmark
    public void run() {
        run(groups, dataType, OP_COUNT);
    }

    private static void run(int groups, String dataType, int opCount) {
        DriverContext driverContext = driverContext();
        try (Operator operator = operator(driverContext, groups, dataType)) {
            Page page = page(groups, dataType);
            for (int i = 0; i < opCount; i++) {
                operator.addInput(page.shallowCopy());
            }
            operator.finish();
            checkExpected(groups, dataType, operator.getOutput());
        }
    }

    static DriverContext driverContext() {
        return new DriverContext(BigArrays.NON_RECYCLING_INSTANCE, blockFactory);
    }

    static int blockLength(int groups) {
        return Math.max(MIN_BLOCK_LENGTH, groups);
    }
}
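
Assuming this class lives alongside the other JMH benchmarks in the
`benchmarks` module, it should be runnable the same way those are, roughly
`./gradlew -p benchmarks run --args 'ValuesAggregatorBenchmark'`; that
invocation is an assumption based on the surrounding module's conventions,
not something stated in this commit.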
