Skip to content

Commit

Permalink
int/long
Browse files Browse the repository at this point in the history
  • Loading branch information
nik9000 committed Feb 23, 2025
1 parent 5f9e2e9 commit c0c6bde
Showing 1 changed file with 101 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
Expand All @@ -43,6 +46,9 @@
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

/**
* Benchmark for the {@code VALUES} aggregator that supports grouping by many many
Expand All @@ -57,13 +63,17 @@
public class ValuesAggregatorBenchmark {
static final int MIN_BLOCK_LENGTH = 8 * 1024;
private static final int OP_COUNT = 1024;
private static final int UNIQUE_VALUES = 6;
private static final BytesRef[] KEYWORDS = new BytesRef[] {
new BytesRef("Tokyo"),
new BytesRef("Delhi"),
new BytesRef("Shanghai"),
new BytesRef("São Paulo"),
new BytesRef("Mexico City"),
new BytesRef("Cairo") };
static {
assert KEYWORDS.length == UNIQUE_VALUES;
}

private static final BlockFactory blockFactory = BlockFactory.getInstance(
new NoopCircuitBreaker("noop"),
Expand All @@ -84,11 +94,13 @@ public class ValuesAggregatorBenchmark {
}

private static final String BYTES_REF = "BytesRef";
private static final String INT = "int";
private static final String LONG = "long";

@Param({ "1", "1000", /*"1000000"*/ })
public int groups;

@Param({ BYTES_REF })
@Param({ BYTES_REF, INT, LONG })
public String dataType;

private static Operator operator(DriverContext driverContext, int groups, String dataType) {
Expand All @@ -109,6 +121,8 @@ private static Operator operator(DriverContext driverContext, int groups, String
/**
 * Picks the {@code VALUES} aggregator supplier that matches the benchmark's
 * {@code dataType} parameter.
 *
 * @param dataType one of {@code BYTES_REF}, {@code INT} or {@code LONG}
 * @return a new supplier for the requested element type
 * @throws IllegalArgumentException if {@code dataType} is not one of the supported names
 */
private static AggregatorFunctionSupplier supplier(String dataType) {
    switch (dataType) {
        case BYTES_REF -> {
            return new ValuesBytesRefAggregatorFunctionSupplier();
        }
        case INT -> {
            return new ValuesIntAggregatorFunctionSupplier();
        }
        case LONG -> {
            return new ValuesLongAggregatorFunctionSupplier();
        }
        default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
    }
}
Expand Down Expand Up @@ -136,7 +150,6 @@ private static void checkGrouped(String prefix, int groups, String dataType, Pag
}
switch (dataType) {
case BYTES_REF -> {
BytesRefBlock values = page.getBlock(1);
// Build the expected values
List<Set<BytesRef>> expected = new ArrayList<>(groups);
for (int g = 0; g < groups; g++) {
Expand All @@ -148,10 +161,45 @@ private static void checkGrouped(String prefix, int groups, String dataType, Pag
}

// Check them
BytesRefBlock values = page.getBlock(1);
for (int p = 0; p < groups; p++) {
checkExpectedBytesRef(prefix, values, p, expected.get(p));
}
}
case INT -> {
// Build the expected values
List<Set<Integer>> expected = new ArrayList<>(groups);
for (int g = 0; g < groups; g++) {
expected.add(new HashSet<>(UNIQUE_VALUES));
}
int blockLength = blockLength(groups);
for (int p = 0; p < blockLength; p++) {
expected.get(p % groups).add(p % KEYWORDS.length);
}

// Check them
IntBlock values = page.getBlock(1);
for (int p = 0; p < groups; p++) {
checkExpectedInt(prefix, values, p, expected.get(p));
}
}
case LONG -> {
// Build the expected values
List<Set<Long>> expected = new ArrayList<>(groups);
for (int g = 0; g < groups; g++) {
expected.add(new HashSet<>(UNIQUE_VALUES));
}
int blockLength = blockLength(groups);
for (int p = 0; p < blockLength; p++) {
expected.get(p % groups).add((long) p % KEYWORDS.length);
}

// Check them
LongBlock values = page.getBlock(1);
for (int p = 0; p < groups; p++) {
checkExpectedLong(prefix, values, p, expected.get(p));
}
}
default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
}
}
Expand All @@ -162,17 +210,30 @@ private static void checkUngrouped(String prefix, String dataType, Page page) {
BytesRefBlock values = page.getBlock(0);
checkExpectedBytesRef(prefix, values, 0, Set.of(KEYWORDS));
}
case INT -> {
IntBlock values = page.getBlock(0);
checkExpectedInt(prefix, values, 0, IntStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet()));
}
case LONG -> {
LongBlock values = page.getBlock(0);
checkExpectedLong(prefix, values, 0, LongStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet()));
}
default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
}
}

private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set<BytesRef> expected) {
/**
 * Confirms that the number of values stored at {@code position} equals the
 * size of the expected set.
 *
 * @param prefix   text prepended to the failure message
 * @param values   block holding the aggregation result
 * @param position position within {@code values} to inspect
 * @param expected the set of values expected at {@code position}
 * @return the number of values at {@code position}
 * @throws IllegalArgumentException if the value count differs from {@code expected.size()}
 */
private static int checkExpectedBlock(String prefix, Block values, int position, Set<?> expected) {
    final int actualCount = values.getValueCount(position);
    final int expectedCount = expected.size();
    if (actualCount == expectedCount) {
        return actualCount;
    }
    throw new IllegalArgumentException(
        prefix + "[" + position + "] expected " + expectedCount + " values but count was " + actualCount
    );
}

private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set<BytesRef> expected) {
int valueCount = checkExpectedBlock(prefix, values, position, expected);
BytesRef scratch = new BytesRef();
for (int i = values.getFirstValueIndex(position); i < valueCount; i++) {
BytesRef v = values.getBytesRef(i, scratch);
Expand All @@ -182,6 +243,26 @@ private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, i
}
}

/**
 * Verifies that the {@code int} values stored at {@code position} are exactly
 * the expected ones: the count must match (checked by {@code checkExpectedBlock})
 * and every stored value must be a member of {@code expected}.
 *
 * @param prefix   text prepended to the failure message
 * @param values   block holding the aggregation result
 * @param position position within {@code values} to inspect
 * @param expected the set of values expected at {@code position}
 * @throws IllegalArgumentException if the count differs or a stored value is not expected
 */
private static void checkExpectedInt(String prefix, IntBlock values, int position, Set<Integer> expected) {
    int valueCount = checkExpectedBlock(prefix, values, position, expected);
    // The first value index is an offset into the block's values, so the end of
    // this position's range is first + count. Comparing the index against the
    // bare count skipped the check for positions whose first index >= count.
    int first = values.getFirstValueIndex(position);
    int end = first + valueCount;
    for (int i = first; i < end; i++) {
        int v = values.getInt(i);
        if (expected.contains(v) == false) {
            throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
        }
    }
}

/**
 * Verifies that the {@code long} values stored at {@code position} are exactly
 * the expected ones: the count must match (checked by {@code checkExpectedBlock})
 * and every stored value must be a member of {@code expected}.
 *
 * @param prefix   text prepended to the failure message
 * @param values   block holding the aggregation result
 * @param position position within {@code values} to inspect
 * @param expected the set of values expected at {@code position}
 * @throws IllegalArgumentException if the count differs or a stored value is not expected
 */
private static void checkExpectedLong(String prefix, LongBlock values, int position, Set<Long> expected) {
    int valueCount = checkExpectedBlock(prefix, values, position, expected);
    // The first value index is an offset into the block's values, so the end of
    // this position's range is first + count. Comparing the index against the
    // bare count skipped the check for positions whose first index >= count.
    int first = values.getFirstValueIndex(position);
    int end = first + valueCount;
    for (int i = first; i < end; i++) {
        long v = values.getLong(i);
        if (expected.contains(v) == false) {
            throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
        }
    }
}

private static Page page(int groups, String dataType) {
Block dataBlock = dataBlock(groups, dataType);
if (groups == 1) {
Expand All @@ -191,16 +272,32 @@ private static Page page(int groups, String dataType) {
}

/**
 * Builds the block of input values fed to the aggregator. The block has
 * {@code blockLength(groups)} positions and cycles through a small fixed set
 * of distinct values: the entries of {@code KEYWORDS} for {@code BYTES_REF},
 * and {@code 0 .. UNIQUE_VALUES - 1} for {@code INT} and {@code LONG}.
 *
 * @param groups   number of groups being benchmarked; only used to size the block
 * @param dataType one of {@code BYTES_REF}, {@code INT} or {@code LONG}
 * @return the populated block
 * @throws IllegalArgumentException if {@code dataType} is not one of the supported names
 */
private static Block dataBlock(int groups, String dataType) {
    // Computed once for every case; the BYTES_REF branch used to redeclare it,
    // which clashes with this declaration.
    int blockLength = blockLength(groups);
    return switch (dataType) {
        case BYTES_REF -> {
            try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) {
                for (int i = 0; i < blockLength; i++) {
                    builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]);
                }
                yield builder.build();
            }
        }
        case INT -> {
            try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(blockLength)) {
                for (int i = 0; i < blockLength; i++) {
                    builder.appendInt(i % UNIQUE_VALUES);
                }
                yield builder.build();
            }
        }
        case LONG -> {
            try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(blockLength)) {
                for (int i = 0; i < blockLength; i++) {
                    builder.appendLong(i % UNIQUE_VALUES);
                }
                yield builder.build();
            }
        }
        default -> throw new IllegalArgumentException("unsupported data type " + dataType);
    };
}
Expand Down

0 comments on commit c0c6bde

Please sign in to comment.