diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index 2f209b4cf357b..280e6274d84de 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -15,11 +15,14 @@ import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; @@ -43,6 +46,9 @@ import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; /** * Benchmark for the {@code VALUES} aggregator that supports grouping by many many @@ -57,6 +63,7 @@ public class ValuesAggregatorBenchmark { static final int MIN_BLOCK_LENGTH = 8 * 1024; private static final int OP_COUNT = 1024; + private static final int UNIQUE_VALUES = 6; private static final BytesRef[] KEYWORDS = new BytesRef[] { new BytesRef("Tokyo"), new BytesRef("Delhi"), @@ -64,6 +71,9 @@ public class ValuesAggregatorBenchmark { new BytesRef("São Paulo"), new BytesRef("Mexico City"), new BytesRef("Cairo") }; + static { + assert KEYWORDS.length == UNIQUE_VALUES; + } private static final BlockFactory blockFactory = BlockFactory.getInstance( new NoopCircuitBreaker("noop"), @@ -84,11 +94,13 @@ public class ValuesAggregatorBenchmark { } private static final String BYTES_REF = "BytesRef"; + private static final String INT = "int"; + private static final String LONG = "long"; @Param({ "1", "1000", /*"1000000"*/ }) public int groups; - @Param({ BYTES_REF }) + @Param({ BYTES_REF, INT, LONG }) public String dataType; private static Operator operator(DriverContext driverContext, int groups, String dataType) { @@ -109,6 +121,8 @@ private static Operator operator(DriverContext driverContext, int groups, String private static AggregatorFunctionSupplier supplier(String dataType) { return switch (dataType) { case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(); + case INT -> new ValuesIntAggregatorFunctionSupplier(); + case LONG -> new ValuesLongAggregatorFunctionSupplier(); default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]"); }; } @@ -136,7 +150,6 @@ private static void checkGrouped(String prefix, int groups, String dataType, Pag } switch (dataType) { case BYTES_REF -> { - BytesRefBlock values = page.getBlock(1); // Build the expected values List> expected = new ArrayList<>(groups); for (int g = 0; g < groups; g++) { @@ -148,10 +161,45 @@ private static void checkGrouped(String prefix, int groups, String dataType, Pag } // Check them + BytesRefBlock values = page.getBlock(1); for (int p = 0; p < groups; p++) { checkExpectedBytesRef(prefix, values, p, expected.get(p)); } } + case INT -> { + // Build the expected values + List> expected = new ArrayList<>(groups); + for (int g = 0; g < groups; g++) { + expected.add(new HashSet<>(UNIQUE_VALUES)); + } + int blockLength = blockLength(groups); + for (int p = 0; p < blockLength; p++) { + expected.get(p % groups).add(p % KEYWORDS.length); + } + + // Check them + IntBlock values = page.getBlock(1); + for (int p = 0; p < groups; p++) { + checkExpectedInt(prefix, values, p, expected.get(p)); + } + } + case LONG -> { + // Build the expected values + List> expected = new ArrayList<>(groups); + for (int g = 0; g < groups; g++) { + expected.add(new HashSet<>(UNIQUE_VALUES)); + } + int blockLength = blockLength(groups); + for (int p = 0; p < blockLength; p++) { + expected.get(p % groups).add((long) p % KEYWORDS.length); + } + + // Check them + LongBlock values = page.getBlock(1); + for (int p = 0; p < groups; p++) { + checkExpectedLong(prefix, values, p, expected.get(p)); + } + } default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType); } } @@ -162,17 +210,30 @@ private static void checkUngrouped(String prefix, String dataType, Page page) { BytesRefBlock values = page.getBlock(0); checkExpectedBytesRef(prefix, values, 0, Set.of(KEYWORDS)); } + case INT -> { + IntBlock values = page.getBlock(0); + checkExpectedInt(prefix, values, 0, IntStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet())); + } + case LONG -> { + LongBlock values = page.getBlock(0); + checkExpectedLong(prefix, values, 0, LongStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet())); + } default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType); } } - private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set expected) { + private static int checkExpectedBlock(String prefix, Block values, int position, Set expected) { int valueCount = values.getValueCount(position); if (valueCount != expected.size()) { throw new IllegalArgumentException( prefix + "[" + position + "] expected " + expected.size() + " values but count was " + valueCount ); } + return valueCount; + } + + private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set expected) { + int valueCount = checkExpectedBlock(prefix, values, position, expected); BytesRef scratch = new BytesRef(); for (int i = values.getFirstValueIndex(position); i < valueCount; i++) { BytesRef v = values.getBytesRef(i, scratch); @@ -182,6 +243,26 @@ private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, i } } + private static void checkExpectedInt(String prefix, IntBlock values, int position, Set expected) { + int valueCount = checkExpectedBlock(prefix, values, position, expected); + for (int i = values.getFirstValueIndex(position); i < valueCount; i++) { + Integer v = values.getInt(i); + if (expected.contains(v) == false) { + throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected); + } + } + } + + private static void checkExpectedLong(String prefix, LongBlock values, int position, Set expected) { + int valueCount = checkExpectedBlock(prefix, values, position, expected); + for (int i = values.getFirstValueIndex(position); i < valueCount; i++) { + Long v = values.getLong(i); + if (expected.contains(v) == false) { + throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected); + } + } + } + private static Page page(int groups, String dataType) { Block dataBlock = dataBlock(groups, dataType); if (groups == 1) { @@ -191,9 +272,9 @@ private static Page page(int groups, String dataType) { } private static Block dataBlock(int groups, String dataType) { + int blockLength = blockLength(groups); return switch (dataType) { case BYTES_REF -> { - int blockLength = blockLength(groups); try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) { for (int i = 0; i < blockLength; i++) { builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]); @@ -201,6 +282,22 @@ private static Block dataBlock(int groups, String dataType) { yield builder.build(); } } + case INT -> { + try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(blockLength)) { + for (int i = 0; i < blockLength; i++) { + builder.appendInt(i % UNIQUE_VALUES); + } + yield builder.build(); + } + } + case LONG -> { + try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(blockLength)) { + for (int i = 0; i < blockLength; i++) { + builder.appendLong(i % UNIQUE_VALUES); + } + yield builder.build(); + } + } default -> throw new IllegalArgumentException("unsupported data type " + dataType); }; }