From 67293ba8f4ddd2b9fc9799f4f648b4960fb13101 Mon Sep 17 00:00:00 2001
From: Nik Everett
Date: Sun, 23 Feb 2025 13:29:55 -0500
Subject: [PATCH 01/15] ESQL: Speed up VALUES for many buckets (#123073)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Speeds up the VALUES agg when collecting from many buckets. Specifically,
this speeds up the algorithm used to `finish` the aggregation. Most
specifically, this makes the algorithm more tolerant of large numbers of
groups being collected. The old algorithm was `O(n^2)` with the number of
groups. The new one is `O(n)`:

```
(groups)
     1    219.683 ±    1.069 ->   223.477 ±    1.990 ms/op
  1000    426.323 ±   75.963 ->   463.670 ±    7.275 ms/op
100000  36690.871 ± 4656.350 ->  7800.332 ± 2775.869 ms/op
200000  89422.113 ± 2972.606 -> 21920.288 ± 3427.962 ms/op
400000  timed out at 10 minutes -> 40051.524 ± 2011.706 ms/op
```

The `1` group version was not changed at all; the difference is just noise
in the measurement. The small bump in the `1000` case is real, but almost
certainly worth it. The huge drop in the `100000` case is quite real.
---
 .../compute/operator/AggregatorBenchmark.java |   3 +
 .../operator/ValuesAggregatorBenchmark.java   | 339 ++++++++++++++++++
 docs/changelog/123073.yaml                    |   5 +
 .../aggregation/ValuesBytesRefAggregator.java | 136 +++++--
 .../aggregation/ValuesDoubleAggregator.java   | 134 +++++--
 .../aggregation/ValuesFloatAggregator.java    | 139 +++++--
 .../aggregation/ValuesIntAggregator.java      | 139 +++++--
 .../aggregation/ValuesLongAggregator.java     | 134 +++++--
 .../aggregation/X-ValuesAggregator.java.st    | 179 ++++++---
 9 files changed, 1031 insertions(+), 177 deletions(-)
 create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java
 create mode 100644 docs/changelog/123073.yaml

diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
index 77c70bc3a10f4..a6ee1cc1d21f5 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
@@ -60,6 +60,9 @@
 import java.util.stream.LongStream;
 import java.util.stream.Stream;
 
+/**
+ * Benchmark for many different kinds of aggregators and groupings.
+ */
 @Warmup(iterations = 5)
 @Measurement(iterations = 7)
 @BenchmarkMode(Mode.AverageTime)
diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java
new file mode 100644
index 0000000000000..280e6274d84de
--- /dev/null
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java
@@ -0,0 +1,339 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +package org.elasticsearch.benchmark.compute.operator; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.HashAggregationOperator; +import org.elasticsearch.compute.operator.Operator; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +/** + * Benchmark for the {@code VALUES} aggregator that supports grouping by many many + * many values. + */ +@Warmup(iterations = 5) +@Measurement(iterations = 7) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(1) +public class ValuesAggregatorBenchmark { + static final int MIN_BLOCK_LENGTH = 8 * 1024; + private static final int OP_COUNT = 1024; + private static final int UNIQUE_VALUES = 6; + private static final BytesRef[] KEYWORDS = new BytesRef[] { + new BytesRef("Tokyo"), + new BytesRef("Delhi"), + new BytesRef("Shanghai"), + new BytesRef("São Paulo"), + new BytesRef("Mexico City"), + new BytesRef("Cairo") }; + static { + assert KEYWORDS.length == UNIQUE_VALUES; + } + + private static final BlockFactory blockFactory = BlockFactory.getInstance( + new NoopCircuitBreaker("noop"), + BigArrays.NON_RECYCLING_INSTANCE // TODO real big arrays? 
+ ); + + static { + // Smoke test all the expected values and force loading subclasses more like prod + try { + for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) { + for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) { + run(Integer.parseInt(groups), dataType, 10); + } + } + } catch (NoSuchFieldException e) { + throw new AssertionError(); + } + } + + private static final String BYTES_REF = "BytesRef"; + private static final String INT = "int"; + private static final String LONG = "long"; + + @Param({ "1", "1000", /*"1000000"*/ }) + public int groups; + + @Param({ BYTES_REF, INT, LONG }) + public String dataType; + + private static Operator operator(DriverContext driverContext, int groups, String dataType) { + if (groups == 1) { + return new AggregationOperator( + List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)), + driverContext + ); + } + List groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG)); + return new HashAggregationOperator( + List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), + () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false), + driverContext + ); + } + + private static AggregatorFunctionSupplier supplier(String dataType) { + return switch (dataType) { + case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(); + case INT -> new ValuesIntAggregatorFunctionSupplier(); + case LONG -> new ValuesLongAggregatorFunctionSupplier(); + default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]"); + }; + } + + private static void checkExpected(int groups, String dataType, Page page) { + String prefix = String.format("[%s][%s]", groups, dataType); + int positions = page.getPositionCount(); + if (positions != groups) { + throw new IllegalArgumentException(prefix + " expected " + groups + " positions, got " + positions); + } + if (groups == 1) { + checkUngrouped(prefix, dataType, page); + return; + } + checkGrouped(prefix, groups, dataType, page); + } + + private static void checkGrouped(String prefix, int groups, String dataType, Page page) { + LongVector groupsVector = page.getBlock(0).asVector(); + for (int p = 0; p < groups; p++) { + long group = groupsVector.getLong(p); + if (group != p) { + throw new IllegalArgumentException(prefix + "[" + p + "] expected group " + p + " but was " + groups); + } + } + switch (dataType) { + case BYTES_REF -> { + // Build the expected values + List> expected = new ArrayList<>(groups); + for (int g = 0; g < groups; g++) { + expected.add(new HashSet<>(KEYWORDS.length)); + } + int blockLength = blockLength(groups); + for (int p = 0; p < blockLength; p++) { + expected.get(p % groups).add(KEYWORDS[p % KEYWORDS.length]); + } + + // Check them + BytesRefBlock values = page.getBlock(1); + for (int p = 0; p < groups; p++) { + checkExpectedBytesRef(prefix, values, p, expected.get(p)); + } + } + case INT -> { + // Build the expected values + List> expected = new ArrayList<>(groups); + for (int g = 0; g < groups; g++) { + expected.add(new HashSet<>(UNIQUE_VALUES)); + } + int blockLength = blockLength(groups); + for (int p = 0; p < blockLength; p++) { + expected.get(p % groups).add(p % KEYWORDS.length); + } + + // Check them + IntBlock values = page.getBlock(1); + for (int p = 0; p < groups; p++) { + checkExpectedInt(prefix, values, p, expected.get(p)); + } + } 
+            case LONG -> {
+                // Build the expected values
+                List<Set<Long>> expected = new ArrayList<>(groups);
+                for (int g = 0; g < groups; g++) {
+                    expected.add(new HashSet<>(UNIQUE_VALUES));
+                }
+                int blockLength = blockLength(groups);
+                for (int p = 0; p < blockLength; p++) {
+                    expected.get(p % groups).add((long) p % KEYWORDS.length);
+                }
+
+                // Check them
+                LongBlock values = page.getBlock(1);
+                for (int p = 0; p < groups; p++) {
+                    checkExpectedLong(prefix, values, p, expected.get(p));
+                }
+            }
+            default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
+        }
+    }
+
+    private static void checkUngrouped(String prefix, String dataType, Page page) {
+        switch (dataType) {
+            case BYTES_REF -> {
+                BytesRefBlock values = page.getBlock(0);
+                checkExpectedBytesRef(prefix, values, 0, Set.of(KEYWORDS));
+            }
+            case INT -> {
+                IntBlock values = page.getBlock(0);
+                checkExpectedInt(prefix, values, 0, IntStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet()));
+            }
+            case LONG -> {
+                LongBlock values = page.getBlock(0);
+                checkExpectedLong(prefix, values, 0, LongStream.range(0, UNIQUE_VALUES).boxed().collect(Collectors.toSet()));
+            }
+            default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
+        }
+    }
+
+    private static int checkExpectedBlock(String prefix, Block values, int position, Set<?> expected) {
+        int valueCount = values.getValueCount(position);
+        if (valueCount != expected.size()) {
+            throw new IllegalArgumentException(
+                prefix + "[" + position + "] expected " + expected.size() + " values but count was " + valueCount
+            );
+        }
+        return valueCount;
+    }
+
+    private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set<BytesRef> expected) {
+        int valueCount = checkExpectedBlock(prefix, values, position, expected);
+        BytesRef scratch = new BytesRef();
+        int first = values.getFirstValueIndex(position);
+        for (int i = first; i < first + valueCount; i++) {
+            BytesRef v = values.getBytesRef(i, scratch);
+            if (expected.contains(v) == false) {
+                throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
+            }
+        }
+    }
+
+    private static void checkExpectedInt(String prefix, IntBlock values, int position, Set<Integer> expected) {
+        int valueCount = checkExpectedBlock(prefix, values, position, expected);
+        int first = values.getFirstValueIndex(position);
+        for (int i = first; i < first + valueCount; i++) {
+            Integer v = values.getInt(i);
+            if (expected.contains(v) == false) {
+                throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
+            }
+        }
+    }
+
+    private static void checkExpectedLong(String prefix, LongBlock values, int position, Set<Long> expected) {
+        int valueCount = checkExpectedBlock(prefix, values, position, expected);
+        int first = values.getFirstValueIndex(position);
+        for (int i = first; i < first + valueCount; i++) {
+            Long v = values.getLong(i);
+            if (expected.contains(v) == false) {
+                throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
+            }
+        }
+    }
+
+    private static Page page(int groups, String dataType) {
+        Block dataBlock = dataBlock(groups, dataType);
+        if (groups == 1) {
+            return new Page(dataBlock);
+        }
+        return new Page(groupingBlock(groups), dataBlock);
+    }
+
+    private static Block dataBlock(int groups, String dataType) {
+        int blockLength = blockLength(groups);
+        return switch (dataType) {
+            case BYTES_REF -> {
+                try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) {
+                    for (int i = 0; i < blockLength; i++) {
+
builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]); + } + yield builder.build(); + } + } + case INT -> { + try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(blockLength)) { + for (int i = 0; i < blockLength; i++) { + builder.appendInt(i % UNIQUE_VALUES); + } + yield builder.build(); + } + } + case LONG -> { + try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(blockLength)) { + for (int i = 0; i < blockLength; i++) { + builder.appendLong(i % UNIQUE_VALUES); + } + yield builder.build(); + } + } + default -> throw new IllegalArgumentException("unsupported data type " + dataType); + }; + } + + private static Block groupingBlock(int groups) { + int blockLength = blockLength(groups); + try (LongVector.Builder builder = blockFactory.newLongVectorBuilder(blockLength)) { + for (int i = 0; i < blockLength; i++) { + builder.appendLong(i % groups); + } + return builder.build().asBlock(); + } + } + + @Benchmark + public void run() { + run(groups, dataType, OP_COUNT); + } + + private static void run(int groups, String dataType, int opCount) { + DriverContext driverContext = driverContext(); + try (Operator operator = operator(driverContext, groups, dataType)) { + Page page = page(groups, dataType); + for (int i = 0; i < opCount; i++) { + operator.addInput(page.shallowCopy()); + } + operator.finish(); + checkExpected(groups, dataType, operator.getOutput()); + } + } + + static DriverContext driverContext() { + return new DriverContext(BigArrays.NON_RECYCLING_INSTANCE, blockFactory); + } + + static int blockLength(int groups) { + return Math.max(MIN_BLOCK_LENGTH, groups); + } +} diff --git a/docs/changelog/123073.yaml b/docs/changelog/123073.yaml new file mode 100644 index 0000000000000..95c6a4cbfd6b9 --- /dev/null +++ b/docs/changelog/123073.yaml @@ -0,0 +1,5 @@ +pr: 123073 +summary: Speed up VALUES for many buckets +area: ES|QL +type: bug +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index ad0ab2f7189f6..f326492664fb8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -8,6 +8,7 @@ package org.elasticsearch.compute.aggregation; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongLongHash; @@ -151,46 +152,127 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ Block toBlock(BlockFactory blockFactory, IntVector selected) { if (values.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } - BytesRef scratch = new BytesRef(); - try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + + long selectedCountsSize = 0; + long idsSize = 0; + try { + /* + * Get a count of all groups less than the maximum selected group. 
Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + long adjust = RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES + ); + blockFactory.adjustBreaker(adjust); + selectedCountsSize = adjust; + int[] selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < values.size(); id++) { + int group = (int) values.getKey1(id); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; for (int s = 0; s < selected.getPositionCount(); s++) { - int selectedGroup = selected.getInt(s); - /* - * Count can effectively be in three states - 0, 1, many. We use those - * states to buffer the first value, so we can avoid calling - * beginPositionEntry on single valued fields. - */ - int count = 0; - long first = 0; - for (int id = 0; id < values.size(); id++) { - if (values.getKey1(id) == selectedGroup) { - long value = values.getKey2(id); - switch (count) { - case 0 -> first = value; - case 1 -> { - builder.beginPositionEntry(); - builder.appendBytesRef(bytes.get(first, scratch)); - builder.appendBytesRef(bytes.get(value, scratch)); + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + idsSize = adjust; + int[] ids = new int[total]; + for (int id = 0; id < values.size(); id++) { + int group = (int) values.getKey1(id); + if (group < selectedCounts.length && selectedCounts[group] >= 0) { + ids[selectedCounts[group]++] = id; + } + } + + /* + * Insert the ids in order. 
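+         * After the scatter pass above, selectedCounts[group] holds the end index
+         * (exclusive) of that group's run in ids, so each selected group's values
+         * occupy one contiguous slice of ids, in the order the groups were selected.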
+ */ + BytesRef scratch = new BytesRef(); + try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> append(builder, ids[start], scratch); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + append(builder, ids[i], scratch); } - default -> builder.appendBytesRef(bytes.get(value, scratch)); + builder.endPositionEntry(); } - count++; } + start = end; } - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendBytesRef(bytes.get(first, scratch)); - default -> builder.endPositionEntry(); - } + return builder.build(); } - return builder.build(); + } finally { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } } + private void append(BytesRefBlock.Builder builder, int id, BytesRef scratch) { + BytesRef value = bytes.get(values.getKey2(id), scratch); + builder.appendBytesRef(value); + } + @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 271d7120092ca..752cd53a140f7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -130,45 +131,126 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ Block toBlock(BlockFactory blockFactory, IntVector selected) { if (values.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } - try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { + + long selectedCountsSize = 0; + long idsSize = 0; + try { + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. 
+ */ + int selectedCountsLen = selected.max() + 1; + long adjust = RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES + ); + blockFactory.adjustBreaker(adjust); + selectedCountsSize = adjust; + int[] selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < values.size(); id++) { + int group = (int) values.getKey1(id); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; for (int s = 0; s < selected.getPositionCount(); s++) { - int selectedGroup = selected.getInt(s); - /* - * Count can effectively be in three states - 0, 1, many. We use those - * states to buffer the first value, so we can avoid calling - * beginPositionEntry on single valued fields. - */ - int count = 0; - double first = 0; - for (int id = 0; id < values.size(); id++) { - if (values.getKey1(id) == selectedGroup) { - double value = Double.longBitsToDouble(values.getKey2(id)); - switch (count) { - case 0 -> first = value; - case 1 -> { - builder.beginPositionEntry(); - builder.appendDouble(first); - builder.appendDouble(value); + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + idsSize = adjust; + int[] ids = new int[total]; + for (int id = 0; id < values.size(); id++) { + int group = (int) values.getKey1(id); + if (group < selectedCounts.length && selectedCounts[group] >= 0) { + ids[selectedCounts[group]++] = id; + } + } + + /* + * Insert the ids in order. 
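+         * After the scatter pass above, selectedCounts[group] holds the end index
+         * (exclusive) of that group's run in ids, so each selected group's values
+         * occupy one contiguous slice of ids, in the order the groups were selected.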
+ */ + try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> append(builder, ids[start]); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + append(builder, ids[i]); } - default -> builder.appendDouble(value); + builder.endPositionEntry(); } - count++; } + start = end; } - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendDouble(first); - default -> builder.endPositionEntry(); - } + return builder.build(); } - return builder.build(); + } finally { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } } + private void append(DoubleBlock.Builder builder, int id) { + double value = Double.longBitsToDouble(values.getKey2(id)); + builder.appendDouble(value); + } + @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index b44cad807fba2..91f1730ab3111 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.compute.ann.Aggregator; @@ -135,47 +136,129 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ Block toBlock(BlockFactory blockFactory, IntVector selected) { if (values.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } - try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { + + long selectedCountsSize = 0; + long idsSize = 0; + try { + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + long adjust = RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES + ); + blockFactory.adjustBreaker(adjust); + selectedCountsSize = adjust; + int[] selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < values.size(); id++) { + long both = values.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. 
It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; for (int s = 0; s < selected.getPositionCount(); s++) { - int selectedGroup = selected.getInt(s); - /* - * Count can effectively be in three states - 0, 1, many. We use those - * states to buffer the first value, so we can avoid calling - * beginPositionEntry on single valued fields. - */ - int count = 0; - float first = 0; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == selectedGroup) { - float value = Float.intBitsToFloat((int) both); - switch (count) { - case 0 -> first = value; - case 1 -> { - builder.beginPositionEntry(); - builder.appendFloat(first); - builder.appendFloat(value); + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + idsSize = adjust; + int[] ids = new int[total]; + for (int id = 0; id < values.size(); id++) { + long both = values.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length && selectedCounts[group] >= 0) { + ids[selectedCounts[group]++] = id; + } + } + + /* + * Insert the ids in order. 
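+         * After the scatter pass above, selectedCounts[group] holds the end index
+         * (exclusive) of that group's run in ids, so each selected group's values
+         * occupy one contiguous slice of ids, in the order the groups were selected.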
+ */ + try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> append(builder, ids[start]); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + append(builder, ids[i]); } - default -> builder.appendFloat(value); + builder.endPositionEntry(); } - count++; } + start = end; } - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendFloat(first); - default -> builder.endPositionEntry(); - } + return builder.build(); } - return builder.build(); + } finally { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } } + private void append(FloatBlock.Builder builder, int id) { + long both = values.get(id); + float value = Float.intBitsToFloat((int) both); + builder.appendFloat(value); + } + @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 4d0c518245694..c4f595d938aa9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.compute.ann.Aggregator; @@ -135,47 +136,129 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ Block toBlock(BlockFactory blockFactory, IntVector selected) { if (values.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } - try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { + + long selectedCountsSize = 0; + long idsSize = 0; + try { + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + long adjust = RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES + ); + blockFactory.adjustBreaker(adjust); + selectedCountsSize = adjust; + int[] selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < values.size(); id++) { + long both = values.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. 
It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; for (int s = 0; s < selected.getPositionCount(); s++) { - int selectedGroup = selected.getInt(s); - /* - * Count can effectively be in three states - 0, 1, many. We use those - * states to buffer the first value, so we can avoid calling - * beginPositionEntry on single valued fields. - */ - int count = 0; - int first = 0; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == selectedGroup) { - int value = (int) both; - switch (count) { - case 0 -> first = value; - case 1 -> { - builder.beginPositionEntry(); - builder.appendInt(first); - builder.appendInt(value); + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + idsSize = adjust; + int[] ids = new int[total]; + for (int id = 0; id < values.size(); id++) { + long both = values.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length && selectedCounts[group] >= 0) { + ids[selectedCounts[group]++] = id; + } + } + + /* + * Insert the ids in order. 
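+         * After the scatter pass above, selectedCounts[group] holds the end index
+         * (exclusive) of that group's run in ids, so each selected group's values
+         * occupy one contiguous slice of ids, in the order the groups were selected.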
+ */ + try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> append(builder, ids[start]); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + append(builder, ids[i]); } - default -> builder.appendInt(value); + builder.endPositionEntry(); } - count++; } + start = end; } - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(first); - default -> builder.endPositionEntry(); - } + return builder.build(); } - return builder.build(); + } finally { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } } + private void append(IntBlock.Builder builder, int id) { + long both = values.get(id); + int value = (int) both; + builder.appendInt(value); + } + @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 5471c90147ec4..8ae5da509151e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -130,45 +131,126 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ Block toBlock(BlockFactory blockFactory, IntVector selected) { if (values.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } - try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { + + long selectedCountsSize = 0; + long idsSize = 0; + try { + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + long adjust = RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES + ); + blockFactory.adjustBreaker(adjust); + selectedCountsSize = adjust; + int[] selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < values.size(); id++) { + int group = (int) values.getKey1(id); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. 
Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; for (int s = 0; s < selected.getPositionCount(); s++) { - int selectedGroup = selected.getInt(s); - /* - * Count can effectively be in three states - 0, 1, many. We use those - * states to buffer the first value, so we can avoid calling - * beginPositionEntry on single valued fields. - */ - int count = 0; - long first = 0; - for (int id = 0; id < values.size(); id++) { - if (values.getKey1(id) == selectedGroup) { - long value = values.getKey2(id); - switch (count) { - case 0 -> first = value; - case 1 -> { - builder.beginPositionEntry(); - builder.appendLong(first); - builder.appendLong(value); + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + idsSize = adjust; + int[] ids = new int[total]; + for (int id = 0; id < values.size(); id++) { + int group = (int) values.getKey1(id); + if (group < selectedCounts.length && selectedCounts[group] >= 0) { + ids[selectedCounts[group]++] = id; + } + } + + /* + * Insert the ids in order. 
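+         * After the scatter pass above, selectedCounts[group] holds the end index
+         * (exclusive) of that group's run in ids, so each selected group's values
+         * occupy one contiguous slice of ids, in the order the groups were selected.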
+ */ + try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> append(builder, ids[start]); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + append(builder, ids[i]); } - default -> builder.appendLong(value); + builder.endPositionEntry(); } - count++; } + start = end; } - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendLong(first); - default -> builder.endPositionEntry(); - } + return builder.build(); } - return builder.build(); + } finally { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } } + private void append(LongBlock.Builder builder, int id) { + long value = values.getKey2(id); + builder.appendLong(value); + } + @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 3006af595be1f..68c6a8640cbd0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -10,6 +10,7 @@ package org.elasticsearch.compute.aggregation; $if(BytesRef)$ import org.apache.lucene.util.BytesRef; $endif$ +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; $if(BytesRef)$ import org.elasticsearch.common.util.BytesRefHash; @@ -268,63 +269,157 @@ $endif$ blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ Block toBlock(BlockFactory blockFactory, IntVector selected) { if (values.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } -$if(BytesRef)$ - BytesRef scratch = new BytesRef(); + + long selectedCountsSize = 0; + long idsSize = 0; + try { + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + long adjust = RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES + ); + blockFactory.adjustBreaker(adjust); + selectedCountsSize = adjust; + int[] selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < values.size(); id++) { +$if(long||BytesRef||double)$ + int group = (int) values.getKey1(id); +$elseif(float||int)$ + long both = values.get(id); + int group = (int) (both >>> Float.SIZE); $endif$ - try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. 
It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; for (int s = 0; s < selected.getPositionCount(); s++) { - int selectedGroup = selected.getInt(s); - /* - * Count can effectively be in three states - 0, 1, many. We use those - * states to buffer the first value, so we can avoid calling - * beginPositionEntry on single valued fields. - */ - int count = 0; - $if(BytesRef)$long$else$$type$$endif$ first = 0; - for (int id = 0; id < values.size(); id++) { -$if(long||BytesRef)$ - if (values.getKey1(id) == selectedGroup) { - long value = values.getKey2(id); -$elseif(double)$ - if (values.getKey1(id) == selectedGroup) { - double value = Double.longBitsToDouble(values.getKey2(id)); -$elseif(float)$ - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == selectedGroup) { - float value = Float.intBitsToFloat((int) both); -$elseif(int)$ - long both = values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == selectedGroup) { - int value = (int) both; + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + idsSize = adjust; + int[] ids = new int[total]; + for (int id = 0; id < values.size(); id++) { +$if(long||BytesRef||double)$ + int group = (int) values.getKey1(id); +$elseif(float||int)$ + long both = values.get(id); + int group = (int) (both >>> Float.SIZE); +$endif$ + if (group < selectedCounts.length && selectedCounts[group] >= 0) { + ids[selectedCounts[group]++] = id; + } + } + + /* + * Insert the ids in order. 
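+         * After the scatter pass above, selectedCounts[group] holds the end index
+         * (exclusive) of that group's run in ids, so each selected group's values
+         * occupy one contiguous slice of ids, in the order the groups were selected.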
+ */ +$if(BytesRef)$ + BytesRef scratch = new BytesRef(); $endif$ - switch (count) { - case 0 -> first = value; - case 1 -> { - builder.beginPositionEntry(); - builder.append$Type$($if(BytesRef)$bytes.get(first, scratch)$else$first$endif$); - builder.append$Type$($if(BytesRef)$bytes.get(value, scratch)$else$value$endif$); + try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> append(builder, ids[start]$if(BytesRef)$, scratch$endif$); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + append(builder, ids[i]$if(BytesRef)$, scratch$endif$); } - default -> builder.append$Type$($if(BytesRef)$bytes.get(value, scratch)$else$value$endif$); + builder.endPositionEntry(); } - count++; } + start = end; } - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.append$Type$($if(BytesRef)$bytes.get(first, scratch)$else$first$endif$); - default -> builder.endPositionEntry(); - } + return builder.build(); } - return builder.build(); + } finally { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } } +$if(BytesRef)$ + private void append($Type$Block.Builder builder, int id, BytesRef scratch) { + BytesRef value = bytes.get(values.getKey2(id), scratch); + builder.appendBytesRef(value); + } + +$else$ + private void append($Type$Block.Builder builder, int id) { +$if(long)$ + long value = values.getKey2(id); +$elseif(double)$ + double value = Double.longBitsToDouble(values.getKey2(id)); +$elseif(float)$ + long both = values.get(id); + float value = Float.intBitsToFloat((int) both); +$elseif(int)$ + long both = values.get(id); + int value = (int) both; +$endif$ + builder.append$Type$(value); + } + +$endif$ @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block From 4af51839d2c56354d2ef05ba36632f8bfddd1601 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Mon, 24 Feb 2025 08:00:47 +0100 Subject: [PATCH 02/15] Speedup MultiTermsAggregator (#123220) Creating (and more importantly eventually resizing) a fresh stream output makes up a large chunk of the runtime of this aggregation. Also, recursively calling an inline consumer makes this logic even more confusing and adds additional overhead as escape analysis will not be able to remove the allocation of the consumer. 
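As a sketch, the reuse pattern looks roughly like this (the `KeyPacker` name
and shape are illustrative, not the actual aggregator code):

```java
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.List;

// Illustrative only: serialize many keys through one long-lived output
// instead of allocating (and re-growing) a fresh BytesStreamOutput per key.
class KeyPacker {
    private final BytesStreamOutput output = new BytesStreamOutput();

    BytesRef pack(List<Object> path) throws IOException {
        output.seek(0L); // rewind, keeping the already-grown backing buffer
        output.writeCollection(path, StreamOutput::writeGenericValue);
        // The returned BytesRef may share the output's buffer, so callers must
        // copy it before the next pack() call (BytesKeyedBucketOrds copies on add).
        return output.bytes().toBytesRef();
    }
}
```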
=> just call a method recursively and reuse the output
---
 .../multiterms/MultiTermsAggregator.java      | 66 ++++++++-----------
 1 file changed, 27 insertions(+), 39 deletions(-)

diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java
index 5c10e2c8feeb1..e2d3e300da69b 100644
--- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java
+++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java
@@ -24,7 +24,6 @@
 import org.elasticsearch.common.util.LongArray;
 import org.elasticsearch.common.util.ObjectArray;
 import org.elasticsearch.common.util.ObjectArrayPriorityQueue;
-import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.fielddata.FieldData;
 import org.elasticsearch.index.fielddata.NumericDoubleValues;
@@ -168,20 +167,6 @@ static List<List<Object>> docTerms(List<TermValues> termValuesList, int doc) thr
         return terms;
     }
 
-    /**
-     * Packs a list of terms into ByteRef so we can use BytesKeyedBucketOrds
-     *
-     * TODO: this is a temporary solution, we should replace it with a more optimal mechanism instead of relying on BytesKeyedBucketOrds
-     */
-    static BytesRef packKey(List<Object> terms) {
-        try (BytesStreamOutput output = new BytesStreamOutput()) {
-            output.writeCollection(terms, StreamOutput::writeGenericValue);
-            return output.bytes().toBytesRef();
-        } catch (IOException ex) {
-            throw ExceptionsHelper.convertToRuntime(ex);
-        }
-    }
-
     /**
      * Unpacks ByteRef back into a list of terms
      *
@@ -198,36 +183,39 @@
     @Override
     public LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCtx, LeafBucketCollector sub) throws IOException {
         List<TermValues> termValuesList = termValuesList(aggCtx.getLeafReaderContext());
-
+        BytesStreamOutput output = new BytesStreamOutput();
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long owningBucketOrd) throws IOException {
                 List<List<Object>> terms = docTerms(termValuesList, doc);
                 if (terms != null) {
-                    List<Object> path = new ArrayList<>(terms.size());
-                    new CheckedConsumer<Integer, IOException>() {
-                        @Override
-                        public void accept(Integer start) throws IOException {
-                            for (Object term : terms.get(start)) {
-                                if (start == path.size()) {
-                                    path.add(term);
-                                } else {
-                                    path.set(start, term);
-                                }
-                                if (start < terms.size() - 1) {
-                                    this.accept(start + 1);
-                                } else {
-                                    long bucketOrd = bucketOrds.add(owningBucketOrd, packKey(path));
-                                    if (bucketOrd < 0) { // already seen
-                                        bucketOrd = -1 - bucketOrd;
-                                        collectExistingBucket(sub, doc, bucketOrd);
-                                    } else {
-                                        collectBucket(sub, doc, bucketOrd);
-                                    }
-                                }
-                            }
+                    doCollect(terms, new ArrayList<>(terms.size()), owningBucketOrd, doc, 0);
+                }
+            }
+
+            private void doCollect(List<List<Object>> terms, List<Object> path, long owningBucketOrd, int doc, int start)
+                throws IOException {
+                for (Object term : terms.get(start)) {
+                    if (start == path.size()) {
+                        path.add(term);
+                    } else {
+                        path.set(start, term);
+                    }
+                    if (start < terms.size() - 1) {
+                        doCollect(terms, path, owningBucketOrd, doc, start + 1);
+                    } else {
+                        // TODO: this is a temporary solution, we should replace it with a more optimal mechanism instead of relying on
+                        // BytesKeyedBucketOrds
+                        output.seek(0L);
+                        output.writeCollection(path, StreamOutput::writeGenericValue);
+                        long bucketOrd = bucketOrds.add(owningBucketOrd, 
output.bytes().toBytesRef()); + if (bucketOrd < 0) { // already seen + bucketOrd = -1 - bucketOrd; + collectExistingBucket(sub, doc, bucketOrd); + } else { + collectBucket(sub, doc, bucketOrd); } - }.accept(0); + } } } }; From a5f3da280990e208e8d4b81a282e1a1dab695bf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20R=C3=BChsen?= Date: Mon, 24 Feb 2025 08:09:34 +0100 Subject: [PATCH 03/15] Fix queries for events with missing host.id (#122472) --- .../resources/data/profiling-events-all.ndjson | 2 +- .../xpack/profiling/action/TransportGetStackTracesAction.java | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/profiling/src/internalClusterTest/resources/data/profiling-events-all.ndjson b/x-pack/plugin/profiling/src/internalClusterTest/resources/data/profiling-events-all.ndjson index b09817182eb21..a2aeaed164cc6 100644 --- a/x-pack/plugin/profiling/src/internalClusterTest/resources/data/profiling-events-all.ndjson +++ b/x-pack/plugin/profiling/src/internalClusterTest/resources/data/profiling-events-all.ndjson @@ -1,5 +1,5 @@ {"create": {"_index": "profiling-events-all"}} -{"Stacktrace.count": [1], "profiling.project.id": ["100"], "os.kernel": ["9.9.9-0"], "tags": ["environment:qa", "region:eu-west-1"], "host.ip": ["192.168.1.2"], "@timestamp": ["2023-11-20T19:20:27.000000000Z"], "container.name": ["instance-0000000010"], "ecs.version": ["1.12.0"], "Stacktrace.id": ["S07KmaoGhvNte78xwwRbZQ"], "agent.version": ["head-be593ef3-1688111067"], "host.name": ["ip-192-168-1-2"], "host.id": ["8457605156473051743"], "process.thread.name": ["497295213074376"]} +{"Stacktrace.count": [1], "profiling.project.id": ["100"], "os.kernel": ["9.9.9-0"], "tags": ["environment:qa", "region:eu-west-1"], "host.ip": ["192.168.1.2"], "@timestamp": ["2023-11-20T19:20:27.000000000Z"], "container.name": ["instance-0000000010"], "ecs.version": ["1.12.0"], "Stacktrace.id": ["S07KmaoGhvNte78xwwRbZQ"], "agent.version": ["head-be593ef3-1688111067"], "host.name": ["ip-192-168-1-2"], "process.thread.name": ["497295213074376"]} {"create": {"_index": "profiling-events-all"}} {"Stacktrace.count": [1], "profiling.project.id": ["100"], "os.kernel": ["9.9.9-0"], "tags": ["environment:qa", "region:eu-west-1"], "host.ip": ["192.168.1.2"], "@timestamp": ["1698624000"], "container.name": ["instance-0000000010"], "ecs.version": ["1.12.0"], "Stacktrace.id": ["4tB_mGJrj1xVuMFbXVYwGA"], "agent.version": ["head-be593ef3-1688111067"], "host.name": ["ip-192-168-1-2"], "host.id": ["8457605156473051743"], "process.thread.name": ["497295213074376"]} {"create": {"_index": "profiling-events-all"}} diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java index 982410e8a5345..c7a830f03f3df 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java @@ -340,6 +340,8 @@ private void searchEventGroupedByStackTrace( // 'size' specifies the max number of host ID we support per request. .size(MAX_TRACE_EVENTS_RESULT_SIZE) .field("host.id") + // missing("") is used to include documents where the field is missing. + .missing("") // 'execution_hint: map' skips the slow building of ordinals that we don't need. 
// Especially with high cardinality fields, this makes aggregations really slow. .executionHint("map") From 4ca669ad420227cba369346cb0230d7bb12f2f11 Mon Sep 17 00:00:00 2001 From: Artem Prigoda Date: Mon, 24 Feb 2025 08:22:37 +0100 Subject: [PATCH 04/15] [test] Fix `RetrySearchIntegTests` (#122919) Don't simultaneously restart multiple nodes in a cluster. It causes data races when multiple primaries are trying to mark the `[[.snapshot-blob-cache][0]]` shard as stale. ``` org.elasticsearch.cluster.action.shard.ShardStateAction$NoLongerPrimaryShardException: primary term [2] did not match current primary term [4] at org.elasticsearch.cluster.action.shard.ShardStateAction$ShardFailedClusterStateTaskExecutor.execute(ShardStateAction.java:355) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.cluster.service.MasterService.innerExecuteTasks(MasterService.java:1075) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.cluster.service.MasterService.executeTasks(MasterService.java:1038) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.cluster.service.MasterService.executeAndPublishBatch(MasterService.java:245) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.cluster.service.MasterService$BatchingTaskQueue$Processor.lambda$run$2(MasterService.java:1691) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.action.ActionListener.run(ActionListener.java:452) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.cluster.service.MasterService$BatchingTaskQueue$Processor.run(MasterService.java:1688) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.cluster.service.MasterService$5.lambda$doRun$0(MasterService.java:1283) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.action.ActionListener.run(ActionListener.java:452) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.cluster.service.MasterService$5.doRun(MasterService.java:1262) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.common.util.concurrent.ThreadContext$ContextPreservingAbstractRunnable.doRun(ThreadContext.java:1044) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at org.elasticsearch.common.util.concurrent.AbstractRunnable.run(AbstractRunnable.java:27) ~[elasticsearch-8.18.0-SNAPSHOT.jar:8.18.0-SNAPSHOT] at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144) ~[?:?] at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642) ~[?:?] at java.lang.Thread.run(Thread.java:1575) ~[?:?] 
``` Resolve #118374 Resolve #120442 --- muted-tests.yml | 3 --- .../xpack/searchablesnapshots/RetrySearchIntegTests.java | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 96977744a89ea..d12b04183797a 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -122,9 +122,6 @@ tests: - class: org.elasticsearch.datastreams.DataStreamsClientYamlTestSuiteIT method: test {p0=data_stream/120_data_streams_stats/Multiple data stream} issue: https://github.com/elastic/elasticsearch/issues/118217 -- class: org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests - method: testSearcherId - issue: https://github.com/elastic/elasticsearch/issues/118374 - class: org.elasticsearch.xpack.ccr.rest.ShardChangesRestIT method: testShardChangesNoOperation issue: https://github.com/elastic/elasticsearch/issues/118800 diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/RetrySearchIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/RetrySearchIntegTests.java index c50fe50db8b40..c9a1a82b34118 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/RetrySearchIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/RetrySearchIntegTests.java @@ -90,6 +90,7 @@ public void testSearcherId() throws Exception { for (String allocatedNode : allocatedNodes) { if (randomBoolean()) { internalCluster().restartNode(allocatedNode); + ensureGreen(indexName); } } ensureGreen(indexName); @@ -151,6 +152,7 @@ public void testRetryPointInTime() throws Exception { final Set allocatedNodes = internalCluster().nodesInclude(indexName); for (String allocatedNode : allocatedNodes) { internalCluster().restartNode(allocatedNode); + ensureGreen(indexName); } ensureGreen(indexName); assertNoFailuresAndResponse( From f5e2a92a3171808bd163570c3e1546230ec7c5ac Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Mon, 24 Feb 2025 08:47:05 +0100 Subject: [PATCH 05/15] Add rescore knn vector test coverage (#122801) --- .../search/query/RescoreKnnVectorQueryIT.java | 238 ++++++++++++++++++ .../vectors/DenseVectorFieldMapper.java | 2 +- .../search/vectors/KnnScoreDocQuery.java | 20 +- .../vectors/KnnScoreDocQueryBuilder.java | 10 +- .../search/vectors/RescoreKnnVectorQuery.java | 13 +- .../vectors/RescoreKnnVectorQueryTests.java | 110 ++++---- 6 files changed, 316 insertions(+), 77 deletions(-) create mode 100644 server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java new file mode 100644 index 0000000000000..c8812cfc109f2 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.query; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType; +import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.script.MockScriptPlugin; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.hamcrest.Matchers.equalTo; + +public class RescoreKnnVectorQueryIT extends ESIntegTestCase { + + public static final String INDEX_NAME = "test"; + public static final String VECTOR_FIELD = "vector"; + public static final String VECTOR_SCORE_SCRIPT = "vector_scoring"; + public static final String QUERY_VECTOR_PARAM = "query_vector"; + + @Override + protected Collection> nodePlugins() { + return Collections.singleton(CustomScriptPlugin.class); + } + + public static class CustomScriptPlugin extends MockScriptPlugin { + private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM + .vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT); + + @Override + protected Map, Object>> pluginScripts() { + return Map.of(VECTOR_SCORE_SCRIPT, vars -> { + Map doc = (Map) vars.get("doc"); + return SIMILARITY_FUNCTION.compare( + ((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(), + (float[]) vars.get(QUERY_VECTOR_PARAM) + ); + }); + } + } + + @Before + public void setup() throws 
IOException { + String type = randomFrom( + Arrays.stream(VectorIndexType.values()) + .filter(VectorIndexType::isQuantized) + .map(t -> t.name().toLowerCase(Locale.ROOT)) + .collect(Collectors.toCollection(ArrayList::new)) + ); + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(VECTOR_FIELD) + .field("type", "dense_vector") + .field("similarity", "l2_norm") + .startObject("index_options") + .field("type", type) + .endObject() + .endObject() + .endObject() + .endObject(); + + Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)) + .build(); + prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get(); + ensureGreen(INDEX_NAME); + } + + private record TestParams( + int numDocs, + int numDims, + float[] queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder + ) { + public static TestParams generate() { + int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions + int numDocs = randomIntBetween(10, 100); + int k = randomIntBetween(1, numDocs - 5); + return new TestParams( + numDocs, + numDims, + randomVector(numDims), + k, + (int) (k * randomFloatBetween(1.0f, 10.0f, true)), + new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true)) + ); + } + } + + public void testKnnSearchRescore() { + BiFunction knnSearchGenerator = (testParams, requestBuilder) -> { + KnnSearchBuilder knnSearch = new KnnSearchBuilder( + VECTOR_FIELD, + testParams.queryVector, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setKnnSearch(List.of(knnSearch)); + }; + testKnnRescore(knnSearchGenerator); + } + + public void testKnnQueryRescore() { + BiFunction knnQueryGenerator = (testParams, requestBuilder) -> { + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder( + VECTOR_FIELD, + testParams.queryVector, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setQuery(knnQuery); + }; + testKnnRescore(knnQueryGenerator); + } + + public void testKnnRetriever() { + BiFunction knnQueryGenerator = (testParams, requestBuilder) -> { + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + VECTOR_FIELD, + testParams.queryVector, + null, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever)); + }; + testKnnRescore(knnQueryGenerator); + } + + private void testKnnRescore(BiFunction searchRequestGenerator) { + TestParams testParams = TestParams.generate(); + + int numDocs = testParams.numDocs; + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + + for (int i = 0; i < numDocs; i++) { + docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims)); + } + indexRandom(true, docs); + + float[] queryVector = testParams.queryVector; + float oversample = randomFloatBetween(1.0f, 100f, true); + RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample); + + SearchRequestBuilder requestBuilder = searchRequestGenerator.apply( + testParams, + prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean()) + ); + + assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); }); + } + + private static void 
compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) { + // Do an exact query and compare + Script script = new Script( + ScriptType.INLINE, + CustomScriptPlugin.NAME, + VECTOR_SCORE_SCRIPT, + Map.of(QUERY_VECTOR_PARAM, queryVector) + ); + ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script); + assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> { + assertHitCount(exactResponse, docCount); + + int i = 0; + SearchHit[] exactHits = exactResponse.getHits().getHits(); + for (SearchHit knnHit : knnResponse.getHits().getHits()) { + while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) { + i++; + } + if (i >= exactHits.length) { + fail("Knn doc not found in exact search"); + } + assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore())); + } + }); + } + + private static float[] randomVector(int numDimensions) { + float[] vector = new float[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return vector; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index ce41c2164e205..193b2f8d90433 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1225,7 +1225,7 @@ public final int hashCode() { } } - private enum VectorIndexType { + public enum VectorIndexType { HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index 3d13f3cd82b9c..35906940a6418 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -17,6 +17,7 @@ import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; @@ -24,6 +25,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.Objects; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -31,9 +33,8 @@ /** * A query that matches the provided docs with their scores. * - * Note: this query was adapted from Lucene's DocAndScoreQuery from the class + * Note: this query was originally adapted from Lucene's DocAndScoreQuery from the class * {@link org.apache.lucene.search.KnnFloatVectorQuery}, which is package-private. - * There are no changes to the behavior, just some renames. */ public class KnnScoreDocQuery extends Query { private final int[] docs; @@ -50,13 +51,18 @@ public class KnnScoreDocQuery extends Query { /** * Creates a query. 
* - * @param docs the global doc IDs of documents that match, in ascending order - * @param scores the scores of the matching documents + * @param scoreDocs an array of ScoreDocs to use for the query * @param reader IndexReader */ - KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { - this.docs = docs; - this.scores = scores; + KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) { + // Ensure that the docs are sorted by docId, as they are later searched using binary search + Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); + this.docs = new int[scoreDocs.length]; + this.scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docs[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; + } this.segmentStarts = findSegmentStarts(reader, docs); this.contextIdentity = reader.getContext().id(); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index 6fa83ccfb6ac2..1a81f4b984e93 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -141,15 +141,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { - int numDocs = scoreDocs.length; - int[] docs = new int[numDocs]; - float[] scores = new float[numDocs]; - for (int i = 0; i < numDocs; i++) { - docs[i] = scoreDocs[i].doc; - scores[i] = scoreDocs[i].score; - } - - return new KnnScoreDocQuery(docs, scores, context.getIndexReader()); + return new KnnScoreDocQuery(scoreDocs, context.getIndexReader()); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 31d9767e9a857..99568a507ffb9 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -16,14 +16,12 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Arrays; -import java.util.Comparator; import java.util.Objects; /** @@ -60,16 +58,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException { // Retrieve top k documents from the rescored query TopDocs topDocs = searcher.search(query, k); vectorOperations = topDocs.totalHits.value(); - ScoreDoc[] scoreDocs = topDocs.scoreDocs; - Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); - int[] docIds = new int[scoreDocs.length]; - float[] scores = new float[scoreDocs.length]; - for (int i = 0; i < scoreDocs.length; i++) { - docIds[i] = scoreDocs[i].doc; - scores[i] = scoreDocs[i].score; - } - - return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); } public Query innerQuery() { diff --git 
a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 861a8b11db567..05b7bc9ef4f82 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -9,36 +9,39 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; +import org.elasticsearch.index.codec.Elasticsearch900Lucene101Codec; +import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; +import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.io.UnsupportedEncodingException; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashSet; -import java.util.Map; -import java.util.PriorityQueue; -import java.util.stream.Collectors; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -59,51 +62,45 @@ public void testRescoreDocs() throws Exception { // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query // and thus we're rescoring the top k docs. 
float[] queryVector = randomVector(numDims); + Query innerQuery; + if (randomBoolean()) { + innerQuery = new KnnFloatVectorQuery(FIELD_NAME, queryVector, (int) (k * randomFloatBetween(1.0f, 10.0f, true))); + } else { + innerQuery = new MatchAllDocsQuery(); + } RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, - new MatchAllDocsQuery() + innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); - TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); - Map rescoredDocs = Arrays.stream(docs.scoreDocs) - .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); - - assertThat(rescoredDocs.size(), equalTo(k)); - - Collection rescoredScores = new HashSet<>(rescoredDocs.values()); - - // Collect all docs sequentially, and score them using the similarity function to get the top K scores - PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); - - for (LeafReaderContext leafReaderContext : reader.leaves()) { - FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); - KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); - while (iterator.nextDoc() != NO_MORE_DOCS) { - float[] vectorData = vectorValues.vectorValue(iterator.docID()); - float score = VectorSimilarityFunction.COSINE.compare(queryVector, vectorData); - topK.add(score); - int docId = iterator.docID(); - // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it - // to ensure we found them all - if (rescoredDocs.containsKey(docId)) { - assertThat(rescoredDocs.get(docId), equalTo(score)); - rescoredDocs.remove(docId); - } - } - } - - assertThat(rescoredDocs.size(), equalTo(0)); + TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs); + assertThat(rescoredDocs.scoreDocs.length, equalTo(k)); - // Check top scoring docs are contained in rescored docs - for (int i = 0; i < k; i++) { - Float topScore = topK.poll(); - if (rescoredScores.contains(topScore) == false) { - fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); + // Get real scores + DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE + ); + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource); + TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs); + + int i = 0; + ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs; + for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) { + // There are docs that won't be found in the rescored search, but every doc found must be in the same order + // and have the same score + while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) { + i++; + } + if (i >= realScoreDocs.length) { + fail("Rescored doc not found in real score docs"); } + assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score)); } } } @@ -205,16 +202,33 @@ public void profile(QueryProfiler queryProfiler) { } private static void addRandomDocuments(int numDocs, Directory d, int numDims) throws IOException { + IndexWriterConfig iwc = new IndexWriterConfig(); + // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore + KnnVectorsFormat format = randomFrom( + new ES818BinaryQuantizedVectorsFormat(), + new 
ES818HnswBinaryQuantizedVectorsFormat(), + new ES813Int8FlatVectorFormat(), + new ES813Int8FlatVectorFormat(), + new ES814HnswScalarQuantizedVectorsFormat() + ); + iwc.setCodec(new Elasticsearch900Lucene101Codec(randomFrom(Zstd814StoredFieldsFormat.Mode.values())) { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return format; + } + }); try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { for (int i = 0; i < numDocs; i++) { Document document = new Document(); float[] vector = randomVector(numDims); - KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector, VectorSimilarityFunction.COSINE); document.add(vectorField); w.addDocument(document); + if (randomBoolean() && (i % 10 == 0)) { + w.commit(); + } } w.commit(); - w.forceMerge(1); } } } From c06bde4c622777108716a038b4970fbe528eea04 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Mon, 24 Feb 2025 11:57:37 +0100 Subject: [PATCH 06/15] Cleanup duplication and dead code around ChunkedToXContentHelper (#123217) Cleans up a couple things that are obviously broken: * duplicate array and object constructs where the helper utility generates the exact same iterator * unused helper methods * single iterator concatenations --- .../ingest/geoip/IngestGeoIpMetadata.java | 3 +- .../ClusterAllocationExplanation.java | 6 +-- .../allocation/DesiredBalanceResponse.java | 10 ++-- .../node/stats/NodesStatsResponse.java | 7 ++- .../segments/IndicesSegmentResponse.java | 50 ++++++++----------- .../shards/IndicesShardStoresResponse.java | 8 +-- .../indices/stats/IndicesStatsResponse.java | 16 +++--- .../elasticsearch/cluster/ClusterState.java | 7 ++- .../ShutdownShardMigrationStatus.java | 3 +- .../xcontent/ChunkedToXContentHelper.java | 20 +------- .../indices/NodeIndicesStats.java | 21 ++++---- .../snapshots/RegisteredPolicySnapshots.java | 5 +- .../threadpool/ThreadPoolStats.java | 6 +-- ...StreamingUnifiedChatCompletionResults.java | 2 +- .../security/authz/RoleMappingMetadata.java | 3 +- 15 files changed, 57 insertions(+), 110 deletions(-) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java index 58ff64c97b2e0..11addb156d6a8 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java @@ -15,7 +15,6 @@ import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.cluster.NamedDiff; import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; @@ -92,7 +91,7 @@ public static IngestGeoIpMetadata fromXContent(XContentParser parser) throws IOE @Override public Iterator toXContentChunked(ToXContent.Params ignored) { - return Iterators.concat(ChunkedToXContentHelper.xContentObjectFields(DATABASES_FIELD.getPreferredName(), databases)); + return ChunkedToXContentHelper.xContentObjectFields(DATABASES_FIELD.getPreferredName(), databases); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java 
b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java index 3b6a161c3db8a..b7eecb6b3ddd8 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java @@ -196,11 +196,7 @@ public Iterator toXContentChunked(ToXContent.Params params return builder; }), this.clusterInfo != null - ? Iterators.concat( - ChunkedToXContentHelper.startObject("cluster_info"), - this.clusterInfo.toXContentChunked(params), - ChunkedToXContentHelper.endObject() - ) + ? ChunkedToXContentHelper.object("cluster_info", this.clusterInfo.toXContentChunked(params)) : Collections.emptyIterator(), getShardAllocationDecisionChunked(params), Iterators.single((builder, p) -> builder.endObject()) diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/DesiredBalanceResponse.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/DesiredBalanceResponse.java index 2a35851f0017b..f2682e8c922cd 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/DesiredBalanceResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/DesiredBalanceResponse.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.common.xcontent.ChunkedToXContentObject; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContent; @@ -34,8 +35,6 @@ import java.util.Set; import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk; -import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.endObject; -import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.startObject; public class DesiredBalanceResponse extends ActionResponse implements ChunkedToXContentObject { @@ -96,16 +95,15 @@ public Iterator toXContentChunked(ToXContent.Params params ), Iterators.flatMap( routingTable.entrySet().iterator(), - indexEntry -> Iterators.concat( - startObject(indexEntry.getKey()), + indexEntry -> ChunkedToXContentHelper.object( + indexEntry.getKey(), Iterators.flatMap( indexEntry.getValue().entrySet().iterator(), shardEntry -> Iterators.concat( chunk((builder, p) -> builder.field(String.valueOf(shardEntry.getKey()))), shardEntry.getValue().toXContentChunked(params) ) - ), - endObject() + ) ) ), chunk((builder, p) -> builder.endObject().startObject("cluster_info")), diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsResponse.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsResponse.java index 25d34957b6958..864cd18e2ff3b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsResponse.java @@ -42,14 +42,13 @@ protected void writeNodesTo(StreamOutput out, List nodes) throws IOEx @Override protected Iterator xContentChunks(ToXContent.Params outerParams) { - return Iterators.concat( - ChunkedToXContentHelper.startObject("nodes"), + return ChunkedToXContentHelper.object( + "nodes", Iterators.flatMap(getNodes().iterator(), nodeStats -> 
Iterators.concat(Iterators.single((builder, params) -> { builder.startObject(nodeStats.getNode().getId()); builder.field("timestamp", nodeStats.getTimestamp()); return builder; - }), nodeStats.toXContentChunked(outerParams), ChunkedToXContentHelper.endObject())), - ChunkedToXContentHelper.endObject() + }), nodeStats.toXContentChunked(outerParams), ChunkedToXContentHelper.endObject())) ); } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/segments/IndicesSegmentResponse.java b/server/src/main/java/org/elasticsearch/action/admin/indices/segments/IndicesSegmentResponse.java index d34e71f715a5d..c98a4feed6a66 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/segments/IndicesSegmentResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/segments/IndicesSegmentResponse.java @@ -71,9 +71,8 @@ public void writeTo(StreamOutput out) throws IOException { @Override protected Iterator customXContentChunks(ToXContent.Params params) { - return Iterators.concat( - - ChunkedToXContentHelper.startObject(Fields.INDICES), + return ChunkedToXContentHelper.object( + Fields.INDICES, Iterators.flatMap( getIndices().values().iterator(), indexSegments -> Iterators.concat( @@ -81,9 +80,8 @@ protected Iterator customXContentChunks(ToXContent.Params params) { ChunkedToXContentHelper.chunk((builder, p) -> builder.startObject(indexSegments.getIndex()).startObject(Fields.SHARDS)), Iterators.flatMap( indexSegments.iterator(), - indexSegment -> Iterators.concat( - - ChunkedToXContentHelper.startArray(Integer.toString(indexSegment.shardId().id())), + indexSegment -> ChunkedToXContentHelper.array( + Integer.toString(indexSegment.shardId().id()), Iterators.flatMap( indexSegment.iterator(), shardSegments -> Iterators.concat( @@ -141,14 +139,12 @@ protected Iterator customXContentChunks(ToXContent.Params params) { ), ChunkedToXContentHelper.chunk((builder, p) -> builder.endObject().endObject()) ) - ), - ChunkedToXContentHelper.endArray() + ) ) ), ChunkedToXContentHelper.chunk((builder, p) -> builder.endObject().endObject()) ) - ), - ChunkedToXContentHelper.endObject() + ) ); } @@ -157,25 +153,21 @@ private static Iterator getSegmentSortChunks(@Nullable Sort segmentS return Collections.emptyIterator(); } - return Iterators.concat( - ChunkedToXContentHelper.startArray("sort"), - Iterators.map(Iterators.forArray(segmentSort.getSort()), field -> (builder, p) -> { - builder.startObject(); - builder.field("field", field.getField()); - if (field instanceof SortedNumericSortField sortedNumericSortField) { - builder.field("mode", sortedNumericSortField.getSelector().toString().toLowerCase(Locale.ROOT)); - } else if (field instanceof SortedSetSortField sortedSetSortField) { - builder.field("mode", sortedSetSortField.getSelector().toString().toLowerCase(Locale.ROOT)); - } - if (field.getMissingValue() != null) { - builder.field("missing", field.getMissingValue().toString()); - } - builder.field("reverse", field.getReverse()); - builder.endObject(); - return builder; - }), - ChunkedToXContentHelper.endArray() - ); + return ChunkedToXContentHelper.array("sort", Iterators.map(Iterators.forArray(segmentSort.getSort()), field -> (builder, p) -> { + builder.startObject(); + builder.field("field", field.getField()); + if (field instanceof SortedNumericSortField sortedNumericSortField) { + builder.field("mode", sortedNumericSortField.getSelector().toString().toLowerCase(Locale.ROOT)); + } else if (field instanceof SortedSetSortField sortedSetSortField) { + 
builder.field("mode", sortedSetSortField.getSelector().toString().toLowerCase(Locale.ROOT)); + } + if (field.getMissingValue() != null) { + builder.field("missing", field.getMissingValue().toString()); + } + builder.field("reverse", field.getReverse()); + builder.endObject(); + return builder; + })); } static final class Fields { diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/shards/IndicesShardStoresResponse.java b/server/src/main/java/org/elasticsearch/action/admin/indices/shards/IndicesShardStoresResponse.java index 60a8ae5ad6371..d3fd1c0ab5ef5 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/shards/IndicesShardStoresResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/shards/IndicesShardStoresResponse.java @@ -279,13 +279,7 @@ public Iterator toXContentChunked(ToXContent.Params outerP return Iterators.concat( ChunkedToXContentHelper.startObject(), - failures.isEmpty() - ? Collections.emptyIterator() - : Iterators.concat( - ChunkedToXContentHelper.startArray(Fields.FAILURES), - failures.iterator(), - ChunkedToXContentHelper.endArray() - ), + failures.isEmpty() ? Collections.emptyIterator() : ChunkedToXContentHelper.array(Fields.FAILURES, failures.iterator()), ChunkedToXContentHelper.startObject(Fields.INDICES), diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/stats/IndicesStatsResponse.java b/server/src/main/java/org/elasticsearch/action/admin/indices/stats/IndicesStatsResponse.java index d6c9a3e2e544b..7e5a45fc3d65b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/stats/IndicesStatsResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/stats/IndicesStatsResponse.java @@ -232,22 +232,20 @@ protected Iterator customXContentChunks(ToXContent.Params params) { }), level == ClusterStatsLevel.SHARDS - ? Iterators.concat( - ChunkedToXContentHelper.startObject(Fields.SHARDS), + ? ChunkedToXContentHelper.object( + Fields.SHARDS, Iterators.flatMap( indexStats.iterator(), - indexShardStats -> Iterators.concat( - ChunkedToXContentHelper.startArray(Integer.toString(indexShardStats.getShardId().id())), - Iterators.map(indexShardStats.iterator(), shardStats -> (builder, p) -> { + indexShardStats -> ChunkedToXContentHelper.array( + Integer.toString(indexShardStats.getShardId().id()), + Iterators.map(indexShardStats.iterator(), shardStats -> (builder, p) -> { builder.startObject(); shardStats.toXContent(builder, p); builder.endObject(); return builder; - }), - ChunkedToXContentHelper.endArray() + }) ) - ), - ChunkedToXContentHelper.endObject() + ) ) : Collections.emptyIterator(), diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterState.java b/server/src/main/java/org/elasticsearch/cluster/ClusterState.java index 6b222fb8f5bdc..4ae53075d9e28 100644 --- a/server/src/main/java/org/elasticsearch/cluster/ClusterState.java +++ b/server/src/main/java/org/elasticsearch/cluster/ClusterState.java @@ -747,10 +747,9 @@ public Iterator toXContentChunked(ToXContent.Params outerP metrics.contains(Metric.ROUTING_NODES), (builder, params) -> builder.startObject("nodes"), getRoutingNodes().iterator(), - routingNode -> Iterators.concat( - ChunkedToXContentHelper.startArray(routingNode.nodeId() == null ? "null" : routingNode.nodeId()), - routingNode.iterator(), - ChunkedToXContentHelper.endArray() + routingNode -> ChunkedToXContentHelper.array( + routingNode.nodeId() == null ? 
"null" : routingNode.nodeId(), + routingNode.iterator() ), (builder, params) -> builder.endObject().endObject() ), diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/ShutdownShardMigrationStatus.java b/server/src/main/java/org/elasticsearch/cluster/metadata/ShutdownShardMigrationStatus.java index 72448b56e3a07..0cb51a8a3271e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/ShutdownShardMigrationStatus.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/ShutdownShardMigrationStatus.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.common.xcontent.ChunkedToXContentObject; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContent; @@ -170,7 +171,7 @@ public Iterator toXContentChunked(ToXContent.Params params startObject(), chunk((builder, p) -> buildHeader(builder)), Objects.nonNull(allocationDecision) - ? Iterators.concat(startObject(NODE_ALLOCATION_DECISION_KEY), allocationDecision.toXContentChunked(params), endObject()) + ? ChunkedToXContentHelper.object(NODE_ALLOCATION_DECISION_KEY, allocationDecision.toXContentChunked(params)) : Collections.emptyIterator(), endObject() ); diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java index 7afc33ff265bb..93d5b34c8ec57 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java @@ -83,10 +83,6 @@ public static Iterator field(String name, ChunkedToXContentObject va return Iterators.concat(Iterators.single((builder, innerParam) -> builder.field(name)), value.toXContentChunked(params)); } - public static Iterator array(Iterator contents) { - return Iterators.concat(startArray(), contents, endArray()); - } - public static Iterator array(String name, Iterator contents) { return Iterators.concat(startArray(name), contents, endArray()); } @@ -95,10 +91,6 @@ public static Iterator array(Iterator items, Function Iterator array(String name, Iterator items, Function toXContent) { - return Iterators.concat(startArray(name), Iterators.map(items, toXContent), endArray()); - } - /** * Creates an Iterator to serialize a named field where the value is represented by an iterator of {@link ChunkedToXContentObject}. * Chunked equivalent for {@code XContentBuilder array(String name, ToXContent value)} @@ -108,7 +100,7 @@ public static Iterator array(String name, Iterator items, Fun * @return Iterator composing field name and value serialization */ public static Iterator array(String name, Iterator contents, ToXContent.Params params) { - return Iterators.concat(startArray(name), Iterators.flatMap(contents, c -> c.toXContentChunked(params)), endArray()); + return array(name, Iterators.flatMap(contents, c -> c.toXContentChunked(params))); } /** @@ -129,14 +121,4 @@ public static Iterator chunk(ToXContent item) { return Iterators.single(item); } - /** - * Creates an Iterator of a single ToXContent object that serializes the given object as a single chunk. Just wraps {@link - * Iterators#single}, but still useful because it avoids any type ambiguity. 
- * - * @param item Item to wrap - * @return Singleton iterator for the given item. - */ - public static Iterator singleChunk(ToXContent item) { - return Iterators.single(item); - } } diff --git a/server/src/main/java/org/elasticsearch/indices/NodeIndicesStats.java b/server/src/main/java/org/elasticsearch/indices/NodeIndicesStats.java index faabccdb1eefa..c8026e98c4948 100644 --- a/server/src/main/java/org/elasticsearch/indices/NodeIndicesStats.java +++ b/server/src/main/java/org/elasticsearch/indices/NodeIndicesStats.java @@ -257,22 +257,21 @@ public Iterator toXContentChunked(ToXContent.Params outerP case NODE -> Collections.emptyIterator(); - case INDICES -> Iterators.concat( - ChunkedToXContentHelper.startObject(Fields.INDICES), + case INDICES -> ChunkedToXContentHelper.object( + Fields.INDICES, Iterators.map(createCommonStatsByIndex().entrySet().iterator(), entry -> (builder, params) -> { builder.startObject(entry.getKey().getName()); entry.getValue().toXContent(builder, params); return builder.endObject(); - }), - ChunkedToXContentHelper.endObject() + }) ); - case SHARDS -> Iterators.concat( - ChunkedToXContentHelper.startObject(Fields.SHARDS), + case SHARDS -> ChunkedToXContentHelper.object( + Fields.SHARDS, Iterators.flatMap( statsByShard.entrySet().iterator(), - entry -> Iterators.concat( - ChunkedToXContentHelper.startArray(entry.getKey().getName()), + entry -> ChunkedToXContentHelper.array( + entry.getKey().getName(), Iterators.flatMap( entry.getValue().iterator(), indexShardStats -> Iterators.concat( @@ -282,11 +281,9 @@ public Iterator toXContentChunked(ToXContent.Params outerP Iterators.flatMap(Iterators.forArray(indexShardStats.getShards()), Iterators::single), Iterators.single((b, p) -> b.endObject().endObject()) ) - ), - ChunkedToXContentHelper.endArray() + ) ) - ), - ChunkedToXContentHelper.endObject() + ) ); }, diff --git a/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java b/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java index 231894875b7fa..31c50d3313a75 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java @@ -111,10 +111,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public Iterator toXContentChunked(ToXContent.Params ignored) { - return Iterators.concat(Iterators.single((builder, params) -> { - builder.field(SNAPSHOTS.getPreferredName(), snapshots); - return builder; - })); + return Iterators.single((builder, params) -> builder.field(SNAPSHOTS.getPreferredName(), snapshots)); } public static RegisteredPolicySnapshots parse(XContentParser parser) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/threadpool/ThreadPoolStats.java b/server/src/main/java/org/elasticsearch/threadpool/ThreadPoolStats.java index 8761f9b109a26..1edb7da539418 100644 --- a/server/src/main/java/org/elasticsearch/threadpool/ThreadPoolStats.java +++ b/server/src/main/java/org/elasticsearch/threadpool/ThreadPoolStats.java @@ -162,10 +162,6 @@ static final class Fields { @Override public Iterator toXContentChunked(ToXContent.Params params) { - return Iterators.concat( - ChunkedToXContentHelper.startObject(Fields.THREAD_POOL), - Iterators.flatMap(stats.iterator(), s -> s.toXContentChunked(params)), - ChunkedToXContentHelper.endObject() - ); + return ChunkedToXContentHelper.object(Fields.THREAD_POOL, Iterators.flatMap(stats.iterator(), s -> 
s.toXContentChunked(params))); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index 4604a522c147b..99404d9ce66b0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -76,7 +76,7 @@ public Results(StreamInput in) throws IOException { @Override public Iterator toXContentChunked(ToXContent.Params params) { - return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params))); + return Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params)); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleMappingMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleMappingMetadata.java index 31fe86ca77edd..8e366547ba8a7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleMappingMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleMappingMetadata.java @@ -15,7 +15,6 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.NamedDiff; import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; @@ -97,7 +96,7 @@ public static NamedDiff readDiffFrom(StreamInput streamInput) t @Override public Iterator toXContentChunked(ToXContent.Params params) { // role mappings are serialized without their names - return Iterators.concat(ChunkedToXContentHelper.startArray(TYPE), roleMappings.iterator(), ChunkedToXContentHelper.endArray()); + return ChunkedToXContentHelper.array(TYPE, roleMappings.iterator()); } public static RoleMappingMetadata fromXContent(XContentParser parser) throws IOException { From 2a9de3f57f1bb871b5d4dae443549b7d668323a8 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Mon, 24 Feb 2025 11:59:39 +0100 Subject: [PATCH 07/15] Avoid creating IndexSearcher in Engine.refreshNeeded (#123218) Checking whether we need to refresh does not require a searcher so we can simplify this to just work based on the reader and avoid lots of contention etc. for setting up the searcher. relates #122374 --- .../elasticsearch/index/engine/Engine.java | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index 0589741a70281..3298d8757ca92 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -1214,25 +1214,29 @@ private void fillSegmentInfo( public abstract List segments(boolean includeVectorFormatsInfo); public boolean refreshNeeded() { - if (store.tryIncRef()) { - /* - we need to inc the store here since we acquire a searcher and that might keep a file open on the - store. 
this violates the assumption that all files are closed when - the store is closed so we need to make sure we increment it here - */ + if (store.tryIncRef() == false) { + return false; + } + /* + we need to inc the store here since we acquire a directory reader and that might open a file on the store. + This violates the assumption that all files are closed when the store is closed so we need to make + sure we increment it here. + */ + try { + var refManager = getReferenceManager(SearcherScope.EXTERNAL); + var reader = refManager.acquire(); try { - try (Searcher searcher = acquireSearcher("refresh_needed", SearcherScope.EXTERNAL)) { - return searcher.getDirectoryReader().isCurrent() == false; - } - } catch (IOException e) { - logger.error("failed to access searcher manager", e); - failEngine("failed to access searcher manager", e); - throw new EngineException(shardId, "failed to access searcher manager", e); + return reader.isCurrent() == false; } finally { - store.decRef(); + refManager.release(reader); } + } catch (IOException e) { + logger.error("failed to access directory reader", e); + failEngine("failed to access directory reader", e); + throw new EngineException(shardId, "failed to access directory reader", e); + } finally { + store.decRef(); } - return false; } /** From 236b955b95e9b13ddc45d059ae87201ff44d953e Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Mon, 24 Feb 2025 12:00:29 +0100 Subject: [PATCH 08/15] Speedup SearchResponse serialization (#123211) No need to have a nested concat here. There's obviously lots and lots of room for optimization on this one, but just flattening out one obvious step here outright halves the number of method calls required when serializing a search response. Given that method calls can consume up to half the serialization cost this change might massively speed up some usecases. --- .../action/search/SearchResponse.java | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java b/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java index 787dc14f6cd96..2a9acd465f727 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java @@ -382,23 +382,23 @@ public Clusters getClusters() { @Override public Iterator toXContentChunked(ToXContent.Params params) { assert hasReferences(); - return Iterators.concat( - ChunkedToXContentHelper.startObject(), - this.innerToXContentChunked(params), - ChunkedToXContentHelper.endObject() - ); + return getToXContentIterator(true, params); } public Iterator innerToXContentChunked(ToXContent.Params params) { + return getToXContentIterator(false, params); + } + + private Iterator getToXContentIterator(boolean wrapInObject, ToXContent.Params params) { return Iterators.concat( + wrapInObject ? ChunkedToXContentHelper.startObject() : Collections.emptyIterator(), ChunkedToXContentHelper.chunk(SearchResponse.this::headerToXContent), Iterators.single(clusters), - Iterators.concat( - hits.toXContentChunked(params), - aggregations == null ? Collections.emptyIterator() : ChunkedToXContentHelper.chunk(aggregations), - suggest == null ? Collections.emptyIterator() : ChunkedToXContentHelper.chunk(suggest), - profileResults == null ? Collections.emptyIterator() : ChunkedToXContentHelper.chunk(profileResults) - ) + hits.toXContentChunked(params), + aggregations == null ? 
Collections.emptyIterator() : ChunkedToXContentHelper.chunk(aggregations), + suggest == null ? Collections.emptyIterator() : ChunkedToXContentHelper.chunk(suggest), + profileResults == null ? Collections.emptyIterator() : ChunkedToXContentHelper.chunk(profileResults), + wrapInObject ? ChunkedToXContentHelper.endObject() : Collections.emptyIterator() ); } From 135f00a9ff760122a5faf9cc38ec1c720d65b631 Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Mon, 24 Feb 2025 11:10:00 +0000 Subject: [PATCH 09/15] Enable testCommandNamesAsIdentifiers (#123249) --- .../xpack/esql/parser/ExpressionTests.java | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java index 4f7233a399d09..384b80c16c78b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java @@ -208,21 +208,8 @@ public void testParenthesizedExpression() { assertThat(((UnresolvedAttribute) and.left()).name(), equalTo("a")); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/121950") - public void testCommandNamesAsIdentifiersWithLimit() { - Expression expr = whereExpression("from and limit"); - assertThat(expr, instanceOf(And.class)); - And and = (And) expr; - - assertThat(and.left(), instanceOf(UnresolvedAttribute.class)); - assertThat(((UnresolvedAttribute) and.left()).name(), equalTo("from")); - - assertThat(and.right(), instanceOf(UnresolvedAttribute.class)); - assertThat(((UnresolvedAttribute) and.right()).name(), equalTo("limit")); - } - public void testCommandNamesAsIdentifiers() { - for (var commandName : List.of("keep", "drop", "dissect", "eval")) { + for (var commandName : List.of("dissect", "drop", "enrich", "eval", "keep", "limit", "sort")) { Expression expr = whereExpression("from and " + commandName); assertThat(expr, instanceOf(And.class)); And and = (And) expr; From 329514965ec3db2b1e414db553d7b34767fe88c1 Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Mon, 24 Feb 2025 12:31:15 +0100 Subject: [PATCH 10/15] Unmute AsyncSearchSecurityIT hitting potential JDK bug to gather more samples (#123253) --- muted-tests.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index d12b04183797a..81d56d427e0c8 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -309,8 +309,6 @@ tests: - class: org.elasticsearch.xpack.esql.heap_attack.HeapAttackIT method: testEnrichExplosionManyMatches issue: https://github.com/elastic/elasticsearch/issues/122913 -- class: org.elasticsearch.xpack.search.AsyncSearchSecurityIT - issue: https://github.com/elastic/elasticsearch/issues/122940 - class: org.elasticsearch.test.apmintegration.TracesApmIT method: testApmIntegration issue: https://github.com/elastic/elasticsearch/issues/122129 From 4e71c0435310c9024d8491fa2829048d3cdefd57 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 24 Feb 2025 07:39:17 -0500 Subject: [PATCH 11/15] Adding more tests for rank_vectors for hex string case (#123185) --- .../rank_vectors/rank_vectors_max_sim.yml | 60 ++++++++++++++++--- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/rank_vectors/rank_vectors_max_sim.yml 
b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/rank_vectors/rank_vectors_max_sim.yml index acaf1b99b626e..51fb47e1c4e7d 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/rank_vectors/rank_vectors_max_sim.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/rank_vectors/rank_vectors_max_sim.yml @@ -14,12 +14,12 @@ setup: mappings: properties: vector: - type: rank_vectors - dims: 5 + type: rank_vectors + dims: 5 byte_vector: - type: rank_vectors - dims: 5 - element_type: byte + type: rank_vectors + dims: 5 + element_type: byte bit_vector: type: rank_vectors dims: 40 @@ -39,8 +39,8 @@ setup: id: "3" body: vector: [[0.5, 111.3, -13.0, 14.8, -156.0]] - byte_vector: [[2, 18, -5, 0, -124]] - bit_vector: [[2, 18, -5, 0, -124]] + byte_vector: ["0212fb0084"] + bit_vector: ["0212fb0084"] - do: indices.refresh: {} @@ -115,7 +115,7 @@ setup: - match: {hits.hits.1._id: "3"} - close_to: {hits.hits.1._score: {value: 2, error: 0.01}} -# doing max-sim dot product with a vector where the stored bit vectors are used as masks + # doing max-sim dot product with a vector where the stored bit vectors are used as masks - do: headers: Content-Type: application/json @@ -179,6 +179,28 @@ setup: - match: {hits.hits.1._id: "1"} - close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}} + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimInvHamming(params.query_vector, 'byte_vector')" + params: + query_vector: ["0102010101"] + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "3"} + - close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}} + - do: headers: Content-Type: application/json @@ -200,3 +222,25 @@ setup: - match: {hits.hits.1._id: "1"} - close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimInvHamming(params.query_vector, 'bit_vector')" + params: + query_vector: ["0102010101"] + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "3"} + - close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}} From 187b192dfe7c6dc60bfa59fcd90c0ce1e6e3a857 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 24 Feb 2025 13:21:10 +0000 Subject: [PATCH 12/15] Deduplicate allocation stats calls (#123246) These things can be quite expensive and there's no need to recompute them in parallel across all management threads as done today. This commit adds a deduplicator to avoid redundant work. 
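To make the mechanism concrete, here is a minimal sketch of the single-flight idea the patch applies (the real class is `SingleResultDeduplicator`; the toy version below is an illustrative assumption that omits thread-context handling, failure propagation, and the forking onto the MANAGEMENT executor that the patch adds):

```
// Hedged sketch of a single-flight deduplicator; not the real implementation.
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;

final class SingleFlight<T> {
    private final Supplier<T> compute;            // the expensive computation
    private boolean running = false;              // guarded by 'this'
    private List<Consumer<T>> pending = new ArrayList<>();

    SingleFlight(Supplier<T> compute) {
        this.compute = compute;
    }

    void execute(Consumer<T> listener) {
        synchronized (this) {
            pending.add(listener);
            if (running) {
                return; // the in-flight run will be followed by one more batch for us
            }
            running = true;
        }
        while (true) {
            List<Consumer<T>> batch;
            synchronized (this) {
                if (pending.isEmpty()) {
                    running = false;
                    return;
                }
                batch = pending;
                pending = new ArrayList<>();
            }
            T result = compute.get(); // runs once per batch, however many callers queued
            batch.forEach(l -> l.accept(result));
        }
    }
}
```

Callers that arrive while a computation is running are queued and served by a single follow-up computation, so no caller observes a result computed before it subscribed; that freshness property is what the new `testDeduplicatesStatsComputations` below asserts.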
--- docs/changelog/123246.yaml | 5 ++ .../TransportGetAllocationStatsAction.java | 31 +++++++---- .../allocation/AllocationStatsService.java | 3 ++ ...ransportGetAllocationStatsActionTests.java | 51 +++++++++++++++++++ 4 files changed, 81 insertions(+), 9 deletions(-) create mode 100644 docs/changelog/123246.yaml diff --git a/docs/changelog/123246.yaml b/docs/changelog/123246.yaml new file mode 100644 index 0000000000000..3477cf70ac89b --- /dev/null +++ b/docs/changelog/123246.yaml @@ -0,0 +1,5 @@ +pr: 123246 +summary: Deduplicate allocation stats calls +area: Allocation +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java index 15c1456f98eaa..b2e995a116331 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java @@ -13,9 +13,12 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.SingleResultDeduplicator; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.MasterNodeReadRequest; import org.elasticsearch.action.support.master.TransportMasterNodeReadAction; import org.elasticsearch.cluster.ClusterState; @@ -28,6 +31,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.injection.guice.Inject; @@ -46,7 +50,7 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc public static final ActionType TYPE = new ActionType<>("cluster:monitor/allocation/stats"); - private final AllocationStatsService allocationStatsService; + private final SingleResultDeduplicator> allocationStatsSupplier; private final DiskThresholdSettings diskThresholdSettings; @Inject @@ -66,9 +70,15 @@ public TransportGetAllocationStatsAction( actionFilters, TransportGetAllocationStatsAction.Request::new, TransportGetAllocationStatsAction.Response::new, - threadPool.executor(ThreadPool.Names.MANAGEMENT) + // DIRECT is ok here because we fork the allocation stats computation onto a MANAGEMENT thread if needed, or else we return + // very cheaply. 
+ EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT); + this.allocationStatsSupplier = new SingleResultDeduplicator<>( + threadPool.getThreadContext(), + l -> managementExecutor.execute(ActionRunnable.supply(l, allocationStatsService::stats)) ); - this.allocationStatsService = allocationStatsService; this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings()); } @@ -84,12 +94,15 @@ protected void doExecute(Task task, Request request, ActionListener li @Override protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { - listener.onResponse( - new Response( - request.metrics().contains(Metric.ALLOCATIONS) ? allocationStatsService.stats() : Map.of(), - request.metrics().contains(Metric.FS) ? diskThresholdSettings : null - ) - ); + // NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool + + final SubscribableListener> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS) + ? SubscribableListener.newForked(allocationStatsSupplier::execute) + : SubscribableListener.newSucceeded(Map.of()); + + allocationStatsStep.andThenApply( + allocationStats -> new Response(allocationStats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null) + ).addListener(listener); } @Override diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java index dfcc4da9dcc56..926a6926c9aea 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java @@ -14,6 +14,7 @@ import org.elasticsearch.cluster.routing.allocation.allocator.DesiredBalanceShardsAllocator; import org.elasticsearch.cluster.routing.allocation.allocator.ShardsAllocator; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.transport.Transports; import java.util.Map; import java.util.function.Supplier; @@ -46,6 +47,8 @@ public AllocationStatsService( * Returns a map of node IDs to node allocation stats. 
*/ public Map stats() { + assert Transports.assertNotTransportThread("too expensive for a transport worker"); + var clusterState = clusterService.state(); var nodesStatsAndWeights = nodeAllocationStatsAndWeightsCalculator.nodesAllocationStatsAndWeights( clusterState.metadata(), diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java index d9bf3e0e99c81..d362be6a3e321 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java @@ -31,8 +31,13 @@ import java.util.EnumSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -112,4 +117,50 @@ public void testReturnsOnlyRequestedStats() throws Exception { assertNull(response.getDiskThresholdSettings()); } } + + public void testDeduplicatesStatsComputations() throws InterruptedException { + final var requestCounter = new AtomicInteger(); + final var isExecuting = new AtomicBoolean(); + when(allocationStatsService.stats()).thenAnswer(invocation -> { + try { + assertTrue(isExecuting.compareAndSet(false, true)); + assertThat(Thread.currentThread().getName(), containsString("[management]")); + return Map.of(Integer.toString(requestCounter.incrementAndGet()), NodeAllocationStatsTests.randomNodeAllocationStats()); + } finally { + Thread.yield(); + assertTrue(isExecuting.compareAndSet(true, false)); + } + }); + + final var threads = new Thread[between(1, 5)]; + final var startBarrier = new CyclicBarrier(threads.length); + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + safeAwait(startBarrier); + + final var minRequestIndex = requestCounter.get(); + + final TransportGetAllocationStatsAction.Response response = safeAwait( + l -> action.masterOperation( + mock(Task.class), + new TransportGetAllocationStatsAction.Request( + TEST_REQUEST_TIMEOUT, + TaskId.EMPTY_TASK_ID, + EnumSet.of(Metric.ALLOCATIONS) + ), + ClusterState.EMPTY_STATE, + l + ) + ); + + final var requestIndex = Integer.valueOf(response.getNodeAllocationStats().keySet().iterator().next()); + assertThat(requestIndex, greaterThanOrEqualTo(minRequestIndex)); // did not get a stale result + }, "thread-" + i); + threads[i].start(); + } + + for (final var thread : threads) { + thread.join(); + } + } } From cae7f0a80973310cf321aabcdb409276499e3950 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Mon, 24 Feb 2025 14:26:22 +0100 Subject: [PATCH 13/15] Use inheritance instead of composition to simplify search phase transitions (#119272) We only need the extensibility for testing and it's a lot easier to reason about the code if we have explicit methods instead of overly complicated composition with lots of redundant references being retained all over the place. 
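As a hedged illustration of the shape of the change (toy names only, not the actual search-phase classes):

```
// Before: composition. The transition to the next phase is injected as a
// factory, so each phase instance retains an extra lambda (plus whatever it
// captures) purely so that tests can substitute the transition.
record Result(String data) {}

class ComposedPhase {
    private final java.util.function.Function<Result, Runnable> nextPhaseFactory;

    ComposedPhase(java.util.function.Function<Result, Runnable> nextPhaseFactory) {
        this.nextPhaseFactory = nextPhaseFactory;
    }

    void onFinish(Result result) {
        nextPhaseFactory.apply(result).run();
    }
}

// After: inheritance. The production transition is an explicit overridable
// method; tests subclass and override nextPhase(...), and no factory
// reference is retained per instance.
class InheritedPhase {
    void onFinish(Result result) {
        nextPhase(result).run();
    }

    protected Runnable nextPhase(Result result) {
        return () -> System.out.println("production next phase for " + result.data());
    }
}
```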
-> let's simplify to inheritance and get shorter code that performs more predictably (especially when it comes to memory) as a first step. This also opens up the possibility of further simplifications and removing more retained state/memory as we go through the search phases. --- .../action/search/DfsQueryPhase.java | 150 ++++++++++++++---- .../action/search/ExpandSearchPhase.java | 31 ++-- .../action/search/FetchLookupFieldsPhase.java | 14 +- .../action/search/FetchSearchPhase.java | 36 ++--- .../SearchDfsQueryThenFetchAsyncAction.java | 115 +------------- .../action/search/DfsQueryPhaseTests.java | 49 +++--- .../action/search/ExpandSearchPhaseTests.java | 88 +++++----- .../action/search/FetchSearchPhaseTests.java | 82 ++++------ 8 files changed, 262 insertions(+), 303 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index faeb552530e47..d67e656773495 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -8,8 +8,15 @@ */ package org.elasticsearch.action.search; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -27,9 +34,11 @@ import org.elasticsearch.transport.Transport; import java.util.ArrayList; +import java.util.Collection; import java.util.Comparator; +import java.util.HashMap; import java.util.List; -import java.util.function.Function; +import java.util.Map; /** * This search phase fans out to every shards to execute a distributed search with a pre-collected distributed frequencies for all * operation.
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception) */ -final class DfsQueryPhase extends SearchPhase { +class DfsQueryPhase extends SearchPhase { public static final String NAME = "dfs_query"; private final SearchPhaseResults queryResult; - private final List searchResults; - private final AggregatedDfs dfs; - private final List knnResults; - private final Function, SearchPhase> nextPhaseFactory; + private final Client client; private final AbstractSearchAsyncAction context; - private final SearchTransportService searchTransportService; private final SearchProgressListener progressListener; - DfsQueryPhase( - List searchResults, - AggregatedDfs dfs, - List knnResults, - SearchPhaseResults queryResult, - Function, SearchPhase> nextPhaseFactory, - AbstractSearchAsyncAction context - ) { + DfsQueryPhase(SearchPhaseResults queryResult, Client client, AbstractSearchAsyncAction context) { super(NAME); this.progressListener = context.getTask().getProgressListener(); this.queryResult = queryResult; - this.searchResults = searchResults; - this.dfs = dfs; - this.knnResults = knnResults; - this.nextPhaseFactory = nextPhaseFactory; + this.client = client; this.context = context; - this.searchTransportService = context.getSearchTransport(); } + // protected for testing + protected SearchPhase nextPhase(AggregatedDfs dfs) { + return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs); + } + + @SuppressWarnings("unchecked") @Override protected void run() { + List searchResults = (List) context.results.getAtomicArray().asList(); + AggregatedDfs dfs = aggregateDfs(searchResults); // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs // to free up memory early final CountedCollector counter = new CountedCollector<>( queryResult, searchResults.size(), - () -> context.executeNextPhase(NAME, () -> nextPhaseFactory.apply(queryResult)), + () -> context.executeNextPhase(NAME, () -> nextPhase(dfs)), context ); + List knnResults = mergeKnnResults(context.getRequest(), searchResults); for (final DfsSearchResult dfsResult : searchResults) { final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget(); final int shardIndex = dfsResult.getShardIndex(); QuerySearchRequest querySearchRequest = new QuerySearchRequest( context.getOriginalIndices(shardIndex), dfsResult.getContextId(), - rewriteShardSearchRequest(dfsResult.getShardSearchRequest()), + rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()), dfs ); final Transport.Connection connection; @@ -97,11 +100,8 @@ protected void run() { shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter); continue; } - searchTransportService.sendExecuteQuery( - connection, - querySearchRequest, - context.getTask(), - new SearchActionListener<>(shardTarget, shardIndex) { + context.getSearchTransport() + .sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) { @Override protected void innerOnResponse(QuerySearchResult response) { @@ -126,8 +126,7 @@ public void onFailure(Exception exception) { } } } - } - ); + }); } } @@ -144,7 +143,7 @@ private void shardFailure( } // package private for testing - ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { + ShardSearchRequest rewriteShardSearchRequest(List knnResults, ShardSearchRequest request) { SearchSourceBuilder source = request.source(); if (source == null || source.knnSearch().isEmpty()) { return 
request; @@ -180,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { return request; } + + private static List mergeKnnResults(SearchRequest request, List dfsSearchResults) { + if (request.hasKnnSearch() == false) { + return null; + } + SearchSourceBuilder source = request.source(); + List> topDocsLists = new ArrayList<>(source.knnSearch().size()); + List> nestedPath = new ArrayList<>(source.knnSearch().size()); + for (int i = 0; i < source.knnSearch().size(); i++) { + topDocsLists.add(new ArrayList<>()); + nestedPath.add(new SetOnce<>()); + } + + for (DfsSearchResult dfsSearchResult : dfsSearchResults) { + if (dfsSearchResult.knnResults() != null) { + for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) { + DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i); + ScoreDoc[] scoreDocs = knnResults.scoreDocs(); + TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO); + TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs); + SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex()); + topDocsLists.get(i).add(shardTopDocs); + nestedPath.get(i).trySet(knnResults.getNestedPath()); + } + } + } + + List mergedResults = new ArrayList<>(source.knnSearch().size()); + for (int i = 0; i < source.knnSearch().size(); i++) { + TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0])); + mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs)); + } + return mergedResults; + } + + private static AggregatedDfs aggregateDfs(Collection results) { + Map termStatistics = new HashMap<>(); + Map fieldStatistics = new HashMap<>(); + long aggMaxDoc = 0; + for (DfsSearchResult lEntry : results) { + final Term[] terms = lEntry.terms(); + final TermStatistics[] stats = lEntry.termStatistics(); + assert terms.length == stats.length; + for (int i = 0; i < terms.length; i++) { + assert terms[i] != null; + if (stats[i] == null) { + continue; + } + TermStatistics existing = termStatistics.get(terms[i]); + if (existing != null) { + assert terms[i].bytes().equals(existing.term()); + termStatistics.put( + terms[i], + new TermStatistics( + existing.term(), + existing.docFreq() + stats[i].docFreq(), + existing.totalTermFreq() + stats[i].totalTermFreq() + ) + ); + } else { + termStatistics.put(terms[i], stats[i]); + } + + } + + assert lEntry.fieldStatistics().containsKey(null) == false; + for (var entry : lEntry.fieldStatistics().entrySet()) { + String key = entry.getKey(); + CollectionStatistics value = entry.getValue(); + if (value == null) { + continue; + } + assert key != null; + CollectionStatistics existing = fieldStatistics.get(key); + if (existing != null) { + CollectionStatistics merged = new CollectionStatistics( + key, + existing.maxDoc() + value.maxDoc(), + existing.docCount() + value.docCount(), + existing.sumTotalTermFreq() + value.sumTotalTermFreq(), + existing.sumDocFreq() + value.sumDocFreq() + ); + fieldStatistics.put(key, merged); + } else { + fieldStatistics.put(key, value); + } + } + aggMaxDoc += lEntry.maxDoc(); + } + return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java index b0b3f15265920..8055ebb1a7358 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java +++ 
b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java @@ -12,37 +12,47 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.Maps; +import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.InnerHitBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import java.util.Iterator; import java.util.List; -import java.util.function.Supplier; /** * This search phase is an optional phase that will be executed once all hits are fetched from the shards that executes * field-collapsing on the inner hits. This phase only executes if field collapsing is requested in the search request and otherwise * forwards to the next phase immediately. */ -final class ExpandSearchPhase extends SearchPhase { +class ExpandSearchPhase extends SearchPhase { static final String NAME = "expand"; private final AbstractSearchAsyncAction context; - private final SearchHits searchHits; - private final Supplier nextPhase; + private final SearchResponseSections searchResponseSections; + private final AtomicArray queryPhaseResults; - ExpandSearchPhase(AbstractSearchAsyncAction context, SearchHits searchHits, Supplier nextPhase) { + ExpandSearchPhase( + AbstractSearchAsyncAction context, + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { super(NAME); this.context = context; - this.searchHits = searchHits; - this.nextPhase = nextPhase; + this.searchResponseSections = searchResponseSections; + this.queryPhaseResults = queryPhaseResults; + } + + // protected for tests + protected SearchPhase nextPhase() { + return new FetchLookupFieldsPhase(context, searchResponseSections, queryPhaseResults); } /** @@ -55,14 +65,15 @@ private boolean isCollapseRequest() { @Override protected void run() { + var searchHits = searchResponseSections.hits(); if (isCollapseRequest() == false || searchHits.getHits().length == 0) { onPhaseDone(); } else { - doRun(); + doRun(searchHits); } } - private void doRun() { + private void doRun(SearchHits searchHits) { SearchRequest searchRequest = context.getRequest(); CollapseBuilder collapseBuilder = searchRequest.source().collapse(); final List innerHitBuilders = collapseBuilder.getInnerHits(); @@ -171,6 +182,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde } private void onPhaseDone() { - context.executeNextPhase(NAME, nextPhase); + context.executeNextPhase(NAME, this::nextPhase); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java index 2e98d50196490..9aba4efa03bf4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java @@ -51,9 +51,7 @@ final class FetchLookupFieldsPhase extends SearchPhase { this.queryResults = queryResults; } - private record Cluster(String clusterAlias, List hitsWithLookupFields, List lookupFields) { - - } + private record Cluster(String clusterAlias, List 
hitsWithLookupFields, List lookupFields) {} private static List groupLookupFieldsByClusterAlias(SearchHits searchHits) { final Map> perClusters = new HashMap<>(); @@ -80,7 +78,7 @@ private static List groupLookupFieldsByClusterAlias(SearchHits searchHi protected void run() { final List clusters = groupLookupFieldsByClusterAlias(searchResponse.hits); if (clusters.isEmpty()) { - context.sendSearchResponse(searchResponse, queryResults); + sendResponse(); return; } doRun(clusters); @@ -132,9 +130,9 @@ public void onResponse(MultiSearchResponse items) { } } if (failure != null) { - context.onPhaseFailure(NAME, "failed to fetch lookup fields", failure); + onFailure(failure); } else { - context.sendSearchResponse(searchResponse, queryResults); + sendResponse(); } } @@ -144,4 +142,8 @@ public void onFailure(Exception e) { } }); } + + private void sendResponse() { + context.sendSearchResponse(searchResponse, queryResults); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 080295210fced..e63a3ef5b979f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -27,18 +27,16 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * This search phase merges the query results from the previous phase together and calculates the topN hits for this search. * Then it reaches out to all relevant shards to fetch the topN hits. */ -final class FetchSearchPhase extends SearchPhase { +class FetchSearchPhase extends SearchPhase { static final String NAME = "fetch"; private final AtomicArray searchPhaseShardResults; - private final BiFunction, SearchPhase> nextPhaseFactory; private final AbstractSearchAsyncAction context; private final Logger logger; private final SearchProgressListener progressListener; @@ -52,26 +50,6 @@ final class FetchSearchPhase extends SearchPhase { AggregatedDfs aggregatedDfs, AbstractSearchAsyncAction context, @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase - ) { - this( - resultConsumer, - aggregatedDfs, - context, - reducedQueryPhase, - (response, queryPhaseResults) -> new ExpandSearchPhase( - context, - response.hits, - () -> new FetchLookupFieldsPhase(context, response, queryPhaseResults) - ) - ); - } - - FetchSearchPhase( - SearchPhaseResults resultConsumer, - AggregatedDfs aggregatedDfs, - AbstractSearchAsyncAction context, - @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase, - BiFunction, SearchPhase> nextPhaseFactory ) { super(NAME); if (context.getNumShards() != resultConsumer.getNumShards()) { @@ -84,7 +62,6 @@ final class FetchSearchPhase extends SearchPhase { } this.searchPhaseShardResults = resultConsumer.getAtomicArray(); this.aggregatedDfs = aggregatedDfs; - this.nextPhaseFactory = nextPhaseFactory; this.context = context; this.logger = context.getLogger(); this.progressListener = context.getTask().getProgressListener(); @@ -92,6 +69,11 @@ final class FetchSearchPhase extends SearchPhase { this.resultConsumer = reducedQueryPhase == null ? 
resultConsumer : null; } + // protected for tests + protected SearchPhase nextPhase(SearchResponseSections searchResponseSections, AtomicArray queryPhaseResults) { + return new ExpandSearchPhase(context, searchResponseSections, queryPhaseResults); + } + @Override protected void run() { context.execute(new AbstractRunnable() { @@ -115,7 +97,7 @@ private void innerRun() throws Exception { final int numShards = context.getNumShards(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. - final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 + final boolean queryAndFetchOptimization = numShards == 1 && context.getRequest().hasKnnSearch() == false && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null && (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null); @@ -130,7 +112,7 @@ private void innerRun() throws Exception { // we have to release contexts here to free up resources searchPhaseShardResults.asList() .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context)); - moveToNextPhase(new AtomicArray<>(numShards), reducedQueryPhase); + moveToNextPhase(new AtomicArray<>(0), reducedQueryPhase); } else { innerRunFetch(scoreDocs, numShards, reducedQueryPhase); } @@ -281,7 +263,7 @@ private void moveToNextPhase( context.executeNextPhase(NAME, () -> { var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); context.addReleasable(resp); - return nextPhaseFactory.apply(resp, searchPhaseShardResults); + return nextPhase(resp, searchPhaseShardResults); }); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 056806fbb0b00..dd97f02dd8f40 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -10,29 +10,16 @@ package org.elasticsearch.action.search; import org.apache.logging.log4j.Logger; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.CollectionStatistics; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TermStatistics; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; -import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.dfs.AggregatedDfs; -import org.elasticsearch.search.dfs.DfsKnnResults; import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.transport.Transport; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; @@ -102,111 +89,11 @@ protected void executePhaseOnShard( @Override protected 
SearchPhase getNextPhase() { - final List dfsSearchResults = results.getAtomicArray().asList(); - final AggregatedDfs aggregatedDfs = aggregateDfs(dfsSearchResults); - return new DfsQueryPhase( - dfsSearchResults, - aggregatedDfs, - mergeKnnResults(getRequest(), dfsSearchResults), - queryPhaseResultConsumer, - (queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, this, queryResults, aggregatedDfs), - this - ); + return new DfsQueryPhase(queryPhaseResultConsumer, client, this); } @Override protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { progressListener.notifyQueryFailure(shardIndex, shardTarget, exc); } - - private static List mergeKnnResults(SearchRequest request, List dfsSearchResults) { - if (request.hasKnnSearch() == false) { - return null; - } - SearchSourceBuilder source = request.source(); - List> topDocsLists = new ArrayList<>(source.knnSearch().size()); - List> nestedPath = new ArrayList<>(source.knnSearch().size()); - for (int i = 0; i < source.knnSearch().size(); i++) { - topDocsLists.add(new ArrayList<>()); - nestedPath.add(new SetOnce<>()); - } - - for (DfsSearchResult dfsSearchResult : dfsSearchResults) { - if (dfsSearchResult.knnResults() != null) { - for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) { - DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i); - ScoreDoc[] scoreDocs = knnResults.scoreDocs(); - TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO); - TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs); - SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex()); - topDocsLists.get(i).add(shardTopDocs); - nestedPath.get(i).trySet(knnResults.getNestedPath()); - } - } - } - - List mergedResults = new ArrayList<>(source.knnSearch().size()); - for (int i = 0; i < source.knnSearch().size(); i++) { - TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0])); - mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs)); - } - return mergedResults; - } - - private static AggregatedDfs aggregateDfs(Collection results) { - Map termStatistics = new HashMap<>(); - Map fieldStatistics = new HashMap<>(); - long aggMaxDoc = 0; - for (DfsSearchResult lEntry : results) { - final Term[] terms = lEntry.terms(); - final TermStatistics[] stats = lEntry.termStatistics(); - assert terms.length == stats.length; - for (int i = 0; i < terms.length; i++) { - assert terms[i] != null; - if (stats[i] == null) { - continue; - } - TermStatistics existing = termStatistics.get(terms[i]); - if (existing != null) { - assert terms[i].bytes().equals(existing.term()); - termStatistics.put( - terms[i], - new TermStatistics( - existing.term(), - existing.docFreq() + stats[i].docFreq(), - existing.totalTermFreq() + stats[i].totalTermFreq() - ) - ); - } else { - termStatistics.put(terms[i], stats[i]); - } - - } - - assert lEntry.fieldStatistics().containsKey(null) == false; - for (var entry : lEntry.fieldStatistics().entrySet()) { - String key = entry.getKey(); - CollectionStatistics value = entry.getValue(); - if (value == null) { - continue; - } - assert key != null; - CollectionStatistics existing = fieldStatistics.get(key); - if (existing != null) { - CollectionStatistics merged = new CollectionStatistics( - key, - existing.maxDoc() + value.maxDoc(), - existing.docCount() + value.docCount(), - existing.sumTotalTermFreq() + value.sumTotalTermFreq(), - existing.sumDocFreq() + 
value.sumDocFreq() - ); - fieldStatistics.put(key, merged); - } else { - fieldStatistics.put(key, value); - } - } - aggMaxDoc += lEntry.maxDoc(); - } - return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc); - } } diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index bf62973b9b052..43292c4f65245 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.dfs.DfsKnnResults; import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.internal.AliasFilter; @@ -139,12 +140,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - protected void run() { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -225,12 +221,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - protected void run() { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -313,12 +304,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - protected void run() { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); assertThat(mockSearchPhaseContext.failures, hasSize(1)); @@ -328,6 +314,29 @@ protected void run() { } } + private static DfsQueryPhase makeDfsPhase( + AtomicArray results, + SearchPhaseResults consumer, + MockSearchPhaseContext mockSearchPhaseContext, + AtomicReference> responseRef + ) { + int shards = mockSearchPhaseContext.numShards; + for (int i = 0; i < shards; i++) { + mockSearchPhaseContext.results.getAtomicArray().set(i, results.get(i)); + } + return new DfsQueryPhase(consumer, null, mockSearchPhaseContext) { + @Override + protected SearchPhase nextPhase(AggregatedDfs dfs) { + return new SearchPhase("test") { + @Override + public void run() { + responseRef.set(((QueryPhaseResultConsumer) consumer).results); + } + }; + } + }; + } + public void testRewriteShardSearchRequestWithRank() { List dkrs = List.of( new DfsKnnResults(null, new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1), new ScoreDoc(7, 0.1f, 2) }), @@ -338,7 +347,7 @@ public void testRewriteShardSearchRequestWithRank() { ); 
MockSearchPhaseContext mspc = new MockSearchPhaseContext(2); mspc.searchTransport = new SearchTransportService(null, null, null); - DfsQueryPhase dqp = new DfsQueryPhase(null, null, dkrs, mock(QueryPhaseResultConsumer.class), null, mspc); + DfsQueryPhase dqp = new DfsQueryPhase(mock(QueryPhaseResultConsumer.class), null, mspc); QueryBuilder bm25 = new TermQueryBuilder("field", "term"); SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) @@ -352,7 +361,7 @@ public void testRewriteShardSearchRequestWithRank() { SearchRequest sr = new SearchRequest().allowPartialSearchResults(true).source(ssb); ShardSearchRequest ssr = new ShardSearchRequest(null, sr, new ShardId("test", "testuuid", 1), 1, 1, null, 1.0f, 0, null); - dqp.rewriteShardSearchRequest(ssr); + dqp.rewriteShardSearchRequest(dkrs, ssr); KnnScoreDocQueryBuilder ksdqb0 = new KnnScoreDocQueryBuilder( new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) }, diff --git a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java index 99e695228fd33..cdb98dce2533d 100644 --- a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.AbstractSearchTestCase; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchResponseUtils; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -118,14 +119,11 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(collapseValue))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - protected void run() { - try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } - } - }); + ExpandSearchPhase phase = newExpandSearchPhase( + mockSearchPhaseContext, + new SearchResponseSections(hits, null, null, false, null, null, 1), + null + ); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -149,7 +147,6 @@ protected void run() { if (resp != null) { resp.decRef(); } - } } } @@ -192,15 +189,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit2 = new SearchHit(2, "ID2"); hit2.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(collapseValue))); SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - protected void run() { - try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } - } - }); + try (SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { + ExpandSearchPhase phase = 
newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); assertThat(mockSearchPhaseContext.phaseFailure.get(), Matchers.instanceOf(RuntimeException.class)); assertEquals("boom", mockSearchPhaseContext.phaseFailure.get().getMessage()); @@ -229,14 +219,11 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL hit2.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(null))); SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - protected void run() { - try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } - } - }); + ExpandSearchPhase phase = newExpandSearchPhase( + mockSearchPhaseContext, + new SearchResponseSections(hits, null, null, false, null, null, 1), + null + ); phase.run(); mockSearchPhaseContext.assertNoFailure(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -269,12 +256,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL ); SearchHits hits = SearchHits.empty(new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - protected void run() { - mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null); - } - }); + final SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -318,13 +301,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit = new SearchHit(1, "ID"); hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList("foo"))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - protected void run() { - mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null); - } - }); + try (SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); } finally { @@ -339,6 +317,26 @@ protected void run() { } } + private static ExpandSearchPhase newExpandSearchPhase( + MockSearchPhaseContext mockSearchPhaseContext, + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return new ExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, queryPhaseResults) { + @Override + protected SearchPhase nextPhase() { + return new SearchPhase("test") { + @Override + public void run() { + try (searchResponseSections) { + mockSearchPhaseContext.sendSearchResponse(searchResponseSections, queryPhaseResults); + } + } + }; + } + }; + } + public 
void testExpandSearchRespectsOriginalPIT() { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); final PointInTimeBuilder pit = new PointInTimeBuilder(new BytesArray("foo")); @@ -367,16 +365,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit = new SearchHit(1, "ID"); hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList("foo"))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - protected void run() { - mockSearchPhaseContext.sendSearchResponse( - new SearchResponseSections(hits, null, null, false, null, null, 1), - new AtomicArray<>(0) - ); - } - }); + try (SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, new AtomicArray<>(0)); phase.run(); mockSearchPhaseContext.assertNoFailure(); } finally { diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index fd60621c7e400..2be679c91bd36 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -134,13 +134,7 @@ public void testShortcutQueryAndFetchOptimization() throws Exception { numHits = 0; } SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -263,13 +257,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -373,13 +361,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -481,19 +463,21 @@ public void sendExecuteFetch( }; CountDownLatch latch = new CountDownLatch(1); SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - (searchResponse, scrollId) -> new SearchPhase("test") { - @Override - protected void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, 
null); - latch.countDown(); - } + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponseSections, null); + latch.countDown(); + } + }; } - ); + }; assertEquals("fetch", phase.getName()); phase.run(); latch.await(); @@ -621,13 +605,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -641,6 +619,22 @@ public void sendExecuteFetch( } } + private static FetchSearchPhase getFetchSearchPhase( + SearchPhaseResults results, + MockSearchPhaseContext mockSearchPhaseContext, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { + return new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + } + public void testCleanupIrrelevantContexts() throws Exception { // contexts that are not fetched should be cleaned up MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder()); @@ -723,13 +717,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); From c40c5a6c0a0d23a665f9c1ccddff00b8d5a68fed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 24 Feb 2025 14:52:41 +0100 Subject: [PATCH 14/15] ESQL: Fix functions emitting warnings with no source (#122821) Fixes https://github.com/elastic/elasticsearch/issues/122588 - Replaced `Source.EMPTY.writeTo(out)` with `source().writeTo(out)` in functions emitting warnings - Did the same on all aggs, as Top emits an error on type resolution. This is not a bug, as type resolution errors should only happen in the coordinator. Another option would be changing Top to not generate that error there, and make it implement `PostAnalysisVerificationAware` instead - In some cases, we don't even serialize an empty source. So I had to add a new `TransportVersion` to do so - As a special case, `ToLower` and `ToUpper` weren't serializing a source, but they don't emit warnings.
As they were the only remaining functions not serializing the source, I added it there too --- docs/changelog/122821.yaml | 6 ++ .../org/elasticsearch/TransportVersions.java | 1 + .../src/main/resources/boolean.csv-spec | 20 +++++ .../src/main/resources/date.csv-spec | 21 +++++ .../src/main/resources/math.csv-spec | 27 ++++++ .../src/main/resources/string.csv-spec | 19 +++++ .../xpack/esql/action/EsqlCapabilities.java | 5 ++ .../function/aggregate/AggregateFunction.java | 2 +- .../function/scalar/date/DateDiff.java | 2 +- .../AbstractMultivalueFunction.java | 2 +- .../function/scalar/multivalue/MvSlice.java | 2 +- .../scalar/spatial/BinarySpatialFunction.java | 9 +- .../function/scalar/string/Replace.java | 9 +- .../function/scalar/string/ToLower.java | 12 ++- .../function/scalar/string/ToUpper.java | 12 ++- .../expression/function/TestCaseSupplier.java | 3 + .../aggregate/AvgSerializationTests.java | 5 -- .../CountDistinctSerializationTests.java | 5 -- .../aggregate/CountSerializationTests.java | 5 -- .../aggregate/MaxSerializationTests.java | 5 -- ...anAbsoluteDeviationSerializationTests.java | 5 -- .../aggregate/MedianSerializationTests.java | 5 -- .../aggregate/MinSerializationTests.java | 5 -- .../PercentileSerializationTests.java | 5 -- .../aggregate/RateSerializationTests.java | 5 -- .../SpatialCentroidSerializationTests.java | 5 -- .../aggregate/SumSerializationTests.java | 5 -- .../aggregate/TopSerializationTests.java | 5 -- .../function/aggregate/TopTests.java | 4 +- .../aggregate/ValuesSerializationTests.java | 5 -- .../scalar/conditional/CaseTests.java | 2 +- .../date/DateDiffSerializationTests.java | 5 -- .../multivalue/MvAvgSerializationTests.java | 5 -- .../multivalue/MvCountSerializationTests.java | 5 -- .../MvDedupeSerializationTests.java | 5 -- .../multivalue/MvFirstSerializationTests.java | 5 -- .../multivalue/MvLastSerializationTests.java | 5 -- .../multivalue/MvMaxSerializationTests.java | 5 -- ...anAbsoluteDeviationSerializationTests.java | 5 -- .../MvMedianSerializationTests.java | 5 -- .../multivalue/MvMinSerializationTests.java | 5 -- .../multivalue/MvSliceSerializationTests.java | 5 -- .../scalar/multivalue/MvSliceTests.java | 59 +++++++++++++ .../multivalue/MvSumSerializationTests.java | 5 -- ...ySpatialFunctionSerializationTestCase.java | 5 -- .../string/ReplaceSerializationTests.java | 5 -- .../string/ToLowerSerializationTests.java | 5 -- .../string/ToUpperSerializationTests.java | 5 -- .../operator/comparison/InTests.java | 84 ++++++------------- 49 files changed, 230 insertions(+), 216 deletions(-) create mode 100644 docs/changelog/122821.yaml diff --git a/docs/changelog/122821.yaml b/docs/changelog/122821.yaml new file mode 100644 index 0000000000000..8773b6f77c4b6 --- /dev/null +++ b/docs/changelog/122821.yaml @@ -0,0 +1,6 @@ +pr: 122821 +summary: Fix functions emitting warnings with no source +area: ES|QL +type: bug +issues: + - 122588 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 1022dfc3b48fe..35cd890af7fc9 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -203,6 +203,7 @@ static TransportVersion def(int id) { public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_013_0_00); public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_014_0_00); public static final TransportVersion BYTE_SIZE_VALUE_ALWAYS_USES_BYTES = def(9_015_0_00); + 
public static final TransportVersion ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS = def(9_016_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec index 1e23cf62917fc..b639129ab084d 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec @@ -287,6 +287,26 @@ emp_no:integer | is_rehired:boolean | a1:boolean 10005 | [false,false,false,true] | false ; +mvSliceWarnings +required_capability: functions_source_serialization_warnings + +FROM employees +| SORT first_name ASC +| EVAL + start = CASE(first_name == "Alejandro", 1, 0), + end = CASE(first_name == "Alejandro", 0, 1), + result = MV_SLICE(is_rehired, start, end) +| KEEP first_name, result +| LIMIT 1 +; + +warning:Line 6:10: evaluation of [MV_SLICE(is_rehired, start, end)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 6:10: org.elasticsearch.xpack.esql.core.InvalidArgumentException: Start offset is greater than end offset + +first_name:keyword | result:boolean +Alejandro | null +; + values required_capability: agg_values diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec index e5cab8de8092b..1cb4784923e10 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec @@ -366,6 +366,27 @@ date1:date | dd_ms:integer 2023-12-02T11:00:00.000Z | 1000 ; +dateDiffTestWarnings +required_capability: functions_source_serialization_warnings + +FROM employees +| WHERE first_name IN ("Alejandro", "Mary") +| SORT first_name ASC +| EVAL date = TO_DATETIME("2023-12-02T11:00:00.000Z") +| EVAL dd_ms = DATE_DIFF(first_name, date, date) +| KEEP date, dd_ms +| LIMIT 2 +; + +warning:Line 5:16: evaluation of [DATE_DIFF(first_name, date, date)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 5:16: java.lang.IllegalArgumentException: A value of [YEAR, QUARTER, MONTH, DAYOFYEAR, DAY, WEEK, WEEKDAY, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, NANOSECOND] or their aliases is required; received [Alejandro] +warning:Line 5:16: java.lang.IllegalArgumentException: Received value [Mary] is not valid date part to add; did you mean [day]? 
+ +date:date | dd_ms:integer +2023-12-02T11:00:00.000Z | null +2023-12-02T11:00:00.000Z | null +; + evalDateDiffMonthAsWhole0Months#[skip:-8.14.1, reason:omitting millis/timezone not allowed before 8.14] ROW from=TO_DATETIME("2023-12-31T23:59:59.999Z"), to=TO_DATETIME("2024-01-01T00:00:00") diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec index 56486b8954abe..f2d5451c16316 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec @@ -1009,6 +1009,33 @@ mvsum:unsigned_long null ; +mvSumOverflowOnNode +required_capability: functions_source_serialization_warnings + +FROM employees +| SORT first_name ASC +| EVAL + ints = CASE(first_name == "Alejandro", [0, 1, 2147483647], 0), + longs = CASE(first_name == "Alejandro", [0, 1, 9223372036854775807], 0), + ulongs = CASE(first_name == "Alejandro", [0, 1, 18446744073709551615], 0), + intsResult = mv_sum(ints), + longsResult = mv_sum(longs), + ulongsResult = mv_sum(ulongs) +| KEEP first_name, intsResult, longsResult, ulongsResult +| LIMIT 1 +; + +warning:Line 7:14: evaluation of [mv_sum(ints)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 7:14: java.lang.ArithmeticException: integer overflow +warning:Line 8:15: evaluation of [mv_sum(longs)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 8:15: java.lang.ArithmeticException: long overflow +warning:Line 9:16: evaluation of [mv_sum(ulongs)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 9:16: java.lang.ArithmeticException: unsigned_long overflow + +first_name:keyword | intsResult:integer | longsResult:long | ulongsResult:unsigned_long +Alejandro | null | null | null +; + e // tag::e[] ROW E() diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec index cffb1f950fc6f..d2066c5ebe321 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec @@ -1068,6 +1068,25 @@ Gatewayinstances|Gateway instances null |null ; +replaceWarnings +required_capability: functions_source_serialization_warnings + +FROM employees +| SORT first_name ASC +| EVAL + regex = CASE(first_name == "Alejandro", "(", ""), + result = replace(first_name, regex, "") +| KEEP first_name, result +| LIMIT 1 +; + +warningRegex:Line 5:10: evaluation of \[replace\(first_name, regex, \\"\\"\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:Line 5:10: java.util.regex.PatternSyntaxException: Unclosed group near index 1(%0D)?%0A\( + +first_name:keyword | result:keyword +Alejandro | null +; + left // tag::left[] FROM employees diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 980e5402ea560..5f88f9f348276 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -204,6 +204,11 @@ public enum Cap { */ FN_ROUND_UL_FIXES, + /** + * Fixes for multiple functions not serializing their source, and emitting warnings with wrong line number and text. 
+ */ + FUNCTIONS_SOURCE_SERIALIZATION_WARNINGS, + /** * All functions that take TEXT should never emit TEXT, only KEYWORD. #114334 */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 8aa7f697489c6..aab893e6ed5cc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -69,7 +69,7 @@ protected AggregateFunction(StreamInput in) throws IOException { @Override public final void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); + source().writeTo(out); out.writeNamedWriteable(field); if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeNamedWriteable(filter); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java index 4d843ea7180a9..f3da7e07f09c9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java @@ -197,7 +197,7 @@ private DateDiff(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); + source().writeTo(out); out.writeNamedWriteable(unit); out.writeNamedWriteable(startTimestamp); out.writeNamedWriteable(endTimestamp); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java index a32761cfd9948..03ede5a424b55 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java @@ -41,7 +41,7 @@ protected AbstractMultivalueFunction(StreamInput in) throws IOException { @Override public final void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); + source().writeTo(out); out.writeNamedWriteable(field); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java index 3bffb7853d3b4..8cd10f3045fea 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java @@ -132,7 +132,7 @@ private MvSlice(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); + source().writeTo(out); out.writeNamedWriteable(field); out.writeNamedWriteable(start); out.writeOptionalNamedWriteable(end); diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java index 25c0607155afd..f2c16c78b291b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.geometry.Geometry; @@ -20,6 +21,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes; import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import java.io.IOException; @@ -63,7 +65,9 @@ protected BinarySpatialFunction( protected BinarySpatialFunction(StreamInput in, boolean leftDocValues, boolean rightDocValues, boolean pointsOnly) throws IOException { // The doc-values fields are only used during local planning on data nodes, and therefore never serialized this( - Source.EMPTY, + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS) + ? Source.readFrom((PlanStreamInput) in) + : Source.EMPTY, in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), leftDocValues, @@ -74,6 +78,9 @@ protected BinarySpatialFunction(StreamInput in, boolean leftDocValues, boolean r @Override public void writeTo(StreamOutput out) throws IOException { + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS)) { + source().writeTo(out); + } out.writeNamedWriteable(left()); out.writeNamedWriteable(right()); // The doc-values fields are only used during local planning on data nodes, and therefore never serialized diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java index 4b963b794aef0..1fbe6bec85121 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.string; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -22,6 +23,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import java.io.IOException; import java.util.Arrays; @@ -66,7 +68,9 @@ public Replace(
private Replace(StreamInput in) throws IOException { this( - Source.EMPTY, + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS) + ? Source.readFrom((PlanStreamInput) in) + : Source.EMPTY, in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class) @@ -75,6 +79,9 @@ private Replace(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS)) { + source().writeTo(out); + } out.writeNamedWriteable(str); out.writeNamedWriteable(regex); out.writeNamedWriteable(newStr); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLower.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLower.java index 084afb1b69996..eef0241fcd8a9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLower.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLower.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.string; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -42,11 +43,20 @@ public ToLower( } private ToLower(StreamInput in) throws IOException { - this(Source.EMPTY, in.readNamedWriteable(Expression.class), ((PlanStreamInput) in).configuration()); + this( + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS) + ? Source.readFrom((PlanStreamInput) in) + : Source.EMPTY, + in.readNamedWriteable(Expression.class), + ((PlanStreamInput) in).configuration() + ); } @Override public void writeTo(StreamOutput out) throws IOException { + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS)) { + source().writeTo(out); + } out.writeNamedWriteable(field()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpper.java index 4509404754f36..01994c1b087fc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpper.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.string; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -42,11 +43,20 @@ public ToUpper( } private ToUpper(StreamInput in) throws IOException { - this(Source.EMPTY, in.readNamedWriteable(Expression.class), ((PlanStreamInput) in).configuration()); + this( + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS) + ? 
Source.readFrom((PlanStreamInput) in) + : Source.EMPTY, + in.readNamedWriteable(Expression.class), + ((PlanStreamInput) in).configuration() + ); } @Override public void writeTo(StreamOutput out) throws IOException { + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS)) { + source().writeTo(out); + } out.writeNamedWriteable(field()); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index f0e338c0cf284..66ca512eb135d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -117,6 +117,9 @@ public static List stringCases( @Override public TestCase get() { TestCase supplied = supplier.get(); + if (types.size() != supplied.getData().size()) { + throw new IllegalStateException(name + ": type/data size mismatch " + types.size() + "/" + supplied.getData().size()); + } for (int i = 0; i < types.size(); i++) { if (supplied.getData().get(i).type() != types.get(i)) { throw new IllegalStateException( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java index 52d3128af5c1c..704a57be15310 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java @@ -21,9 +21,4 @@ protected Avg createTestInstance() { protected Avg mutateInstance(Avg instance) throws IOException { return new Avg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinctSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinctSerializationTests.java index ab06b0b58f7f0..1dbaaf16a559c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinctSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinctSerializationTests.java @@ -34,9 +34,4 @@ protected CountDistinct mutateInstance(CountDistinct instance) throws IOExceptio } return new CountDistinct(source, field, precision); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountSerializationTests.java index 133979f66860c..39defdd9c0777 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountSerializationTests.java @@ -21,9 +21,4 
@@ protected Count createTestInstance() { protected Count mutateInstance(Count instance) throws IOException { return new Count(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxSerializationTests.java index 7a732883a99a5..c1c4898f1e057 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxSerializationTests.java @@ -21,9 +21,4 @@ protected Max createTestInstance() { protected Max mutateInstance(Max instance) throws IOException { return new Max(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviationSerializationTests.java index bdbe839c46a75..1d6497541e29b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviationSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviationSerializationTests.java @@ -24,9 +24,4 @@ protected MedianAbsoluteDeviation mutateInstance(MedianAbsoluteDeviation instanc randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild) ); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianSerializationTests.java index 75161977319ea..f0b2d63da70d1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianSerializationTests.java @@ -21,9 +21,4 @@ protected Median createTestInstance() { protected Median mutateInstance(Median instance) throws IOException { return new Median(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinSerializationTests.java index 1ff434d8d2a76..18e813acdb74c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinSerializationTests.java @@ -21,9 +21,4 @@ protected Min 
createTestInstance() { protected Min mutateInstance(Min instance) throws IOException { return new Min(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PercentileSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PercentileSerializationTests.java index a6349b9cb5c81..5c01b62f7d663 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PercentileSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PercentileSerializationTests.java @@ -34,9 +34,4 @@ protected Percentile mutateInstance(Percentile instance) throws IOException { } return new Percentile(source, field, percentile); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java index ea7c480817317..94b2a81b308d7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java @@ -36,9 +36,4 @@ protected Rate mutateInstance(Rate instance) throws IOException { } return new Rate(source, field, timestamp, unit); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidSerializationTests.java index 037b7dc229b03..4d31cd3cf31ff 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidSerializationTests.java @@ -24,9 +24,4 @@ protected SpatialCentroid mutateInstance(SpatialCentroid instance) throws IOExce randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild) ); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java index 863392f7eb451..8126b4a30bdb0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java @@ -21,9 +21,4 @@ protected Sum createTestInstance() { protected Sum mutateInstance(Sum instance) throws IOException { return new Sum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean 
alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java index e74b26c87c84f..82bf57d1a194e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java @@ -36,9 +36,4 @@ protected Top mutateInstance(Top instance) throws IOException { } return new Top(source, field, limit, order); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java index fd2d42c6e988b..052f0f491dbed 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java @@ -126,7 +126,7 @@ public static Iterable parameters() { ) ), new TestCaseSupplier( - List.of(DataType.IP), + List.of(DataType.IP, DataType.INTEGER, DataType.KEYWORD), () -> new TestCaseSupplier.TestCase( List.of( TestCaseSupplier.TypedData.multiRow( @@ -215,7 +215,7 @@ public static Iterable parameters() { ) ), new TestCaseSupplier( - List.of(DataType.IP), + List.of(DataType.IP, DataType.INTEGER, DataType.KEYWORD), () -> new TestCaseSupplier.TestCase( List.of( TestCaseSupplier.TypedData.multiRow( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesSerializationTests.java index 6787e8d1ad09a..aac10e14a6999 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesSerializationTests.java @@ -21,9 +21,4 @@ protected Values createTestInstance() { protected Values mutateInstance(Values instance) throws IOException { return new Values(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java index 2fa82b9f1caa2..184e8bc68ba69 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java @@ -532,7 +532,7 @@ private static void fourAndFiveArgs( suppliers.add( new TestCaseSupplier( "partial foldable 1 " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type)), - List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), () -> { Object r1 = 
randomLiteral(type).value(); Object r2 = randomLiteral(type).value(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffSerializationTests.java index 77158b6f1866f..f1f8d1b0f8dad 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffSerializationTests.java @@ -36,9 +36,4 @@ protected DateDiff mutateInstance(DateDiff instance) throws IOException { } return new DateDiff(source, unit, startTimestamp, endTimestamp); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAvgSerializationTests.java index 271312622245d..0a3ddaf576c38 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAvgSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAvgSerializationTests.java @@ -21,9 +21,4 @@ protected MvAvg createTestInstance() { protected MvAvg mutateInstance(MvAvg instance) throws IOException { return new MvAvg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvCountSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvCountSerializationTests.java index 0ec51d73982ec..be641e210782c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvCountSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvCountSerializationTests.java @@ -21,9 +21,4 @@ protected MvCount createTestInstance() { protected MvCount mutateInstance(MvCount instance) throws IOException { return new MvCount(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvDedupeSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvDedupeSerializationTests.java index 410167addf163..151b73c3fb9a3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvDedupeSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvDedupeSerializationTests.java @@ -21,9 +21,4 @@ protected MvDedupe createTestInstance() { protected MvDedupe mutateInstance(MvDedupe instance) throws IOException { return new MvDedupe(instance.source(), randomValueOtherThan(instance.field(), 
AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFirstSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFirstSerializationTests.java index 8934dde1717c6..59d42abfbf291 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFirstSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFirstSerializationTests.java @@ -21,9 +21,4 @@ protected MvFirst createTestInstance() { protected MvFirst mutateInstance(MvFirst instance) throws IOException { return new MvFirst(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvLastSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvLastSerializationTests.java index 9c4ad7ab059ef..d3aada6879caa 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvLastSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvLastSerializationTests.java @@ -21,9 +21,4 @@ protected MvLast createTestInstance() { protected MvLast mutateInstance(MvLast instance) throws IOException { return new MvLast(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMaxSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMaxSerializationTests.java index 4ce5112c4e8e7..d4880b6814905 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMaxSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMaxSerializationTests.java @@ -21,9 +21,4 @@ protected MvMax createTestInstance() { protected MvMax mutateInstance(MvMax instance) throws IOException { return new MvMax(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java index 6a63c38c924d9..3e402a9f7422e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java @@ -24,9 +24,4 @@ protected MvMedianAbsoluteDeviation mutateInstance(MvMedianAbsoluteDeviation ins randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild) ); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianSerializationTests.java index 0e35ec6f77150..6b7a1dd5a8d0f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianSerializationTests.java @@ -21,9 +21,4 @@ protected MvMedian createTestInstance() { protected MvMedian mutateInstance(MvMedian instance) throws IOException { return new MvMedian(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMinSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMinSerializationTests.java index 0769e41a09921..f53aa660834cf 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMinSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMinSerializationTests.java @@ -21,9 +21,4 @@ protected MvMin createTestInstance() { protected MvMin mutateInstance(MvMin instance) throws IOException { return new MvMin(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceSerializationTests.java index ffa355178b460..70f106b65f78d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceSerializationTests.java @@ -36,9 +36,4 @@ protected MvSlice mutateInstance(MvSlice instance) throws IOException { } return new MvSlice(source, field, start, end); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceTests.java index 24da717630733..e847c90bfb838 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSliceTests.java @@ -13,6 +13,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.geo.GeometryTestUtils; import org.elasticsearch.geo.ShapeTestUtils; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -44,6 +45,64 @@ public static Iterable parameters() { longs(suppliers); doubles(suppliers); bytesRefs(suppliers); + + // Warnings cases + suppliers.stream().toList().forEach(supplier -> { + DataType firstArgumentType = supplier.types().get(0); + String evaluatorTypePart = switch (firstArgumentType) { + case BOOLEAN -> "Boolean"; + case INTEGER -> "Int"; + case LONG, DATE_NANOS, DATETIME, UNSIGNED_LONG -> "Long"; + case DOUBLE -> "Double"; + case KEYWORD, TEXT, SEMANTIC_TEXT, IP, VERSION, GEO_POINT, CARTESIAN_POINT, GEO_SHAPE, CARTESIAN_SHAPE -> "BytesRef"; + default -> throw new IllegalArgumentException("Unsupported type: " + firstArgumentType); + }; + + // Start offset greater than end offset + suppliers.add(new TestCaseSupplier(List.of(firstArgumentType, DataType.INTEGER, DataType.INTEGER), () -> { + int end = randomIntBetween(0, 10); + int start = randomIntBetween(end + 1, end + 10); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(List.of(randomLiteral(firstArgumentType).value()), firstArgumentType, "field"), + new TestCaseSupplier.TypedData(start, DataType.INTEGER, "start"), + new TestCaseSupplier.TypedData(end, DataType.INTEGER, "end") + ), + "MvSlice" + + evaluatorTypePart + + "Evaluator[field=Attribute[channel=0], start=Attribute[channel=1], end=Attribute[channel=2]]", + firstArgumentType, + nullValue() + ).withFoldingException(InvalidArgumentException.class, "Start offset is greater than end offset") + .withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") + .withWarning( + "Line 1:1: org.elasticsearch.xpack.esql.core.InvalidArgumentException: Start offset is greater than end offset" + ); + })); + + // Negative start with positive end + suppliers.add(new TestCaseSupplier(List.of(firstArgumentType, DataType.INTEGER, DataType.INTEGER), () -> { + int start = randomIntBetween(-10, -1); + int end = randomIntBetween(0, 10); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(List.of(randomLiteral(firstArgumentType).value()), firstArgumentType, "field"), + new TestCaseSupplier.TypedData(start, DataType.INTEGER, "start"), + new TestCaseSupplier.TypedData(end, DataType.INTEGER, "end") + ), + "MvSlice" + + evaluatorTypePart + + "Evaluator[field=Attribute[channel=0], start=Attribute[channel=1], end=Attribute[channel=2]]", + firstArgumentType, + nullValue() + ).withFoldingException(InvalidArgumentException.class, "Start and end offset have different signs") + .withWarning("Line 1:1: evaluation of [source] failed, treating result as null. 
Only first 20 failures recorded.") + .withWarning( + "Line 1:1: org.elasticsearch.xpack.esql.core.InvalidArgumentException: Start and end offset have different signs" + ); + })); + }); + return parameterSuppliersFromTypedData( anyNullIsNull( suppliers, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSumSerializationTests.java index 15f6d94b44066..01db7335d7901 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSumSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSumSerializationTests.java @@ -21,9 +21,4 @@ protected MvSum createTestInstance() { protected MvSum mutateInstance(MvSum instance) throws IOException { return new MvSum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/AbstractBinarySpatialFunctionSerializationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/AbstractBinarySpatialFunctionSerializationTestCase.java index 006fdf6865340..b4862c52ef710 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/AbstractBinarySpatialFunctionSerializationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/AbstractBinarySpatialFunctionSerializationTestCase.java @@ -38,9 +38,4 @@ protected final T mutateInstance(T instance) throws IOException { } return build(source, left, right); } - - @Override - protected final boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ReplaceSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ReplaceSerializationTests.java index 4bc54241eca2c..21e1f51063ceb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ReplaceSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ReplaceSerializationTests.java @@ -36,9 +36,4 @@ protected Replace mutateInstance(Replace instance) throws IOException { } return new Replace(source, str, regex, newStr); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerSerializationTests.java index caff331755f44..5950aeffea79a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerSerializationTests.java @@ -25,9 +25,4 @@ protected ToLower mutateInstance(ToLower instance) throws IOException { Expression child = 
randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild); return new ToLower(source, child, configuration()); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperSerializationTests.java index 97316c9cc7681..74d891ac2f5cd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperSerializationTests.java @@ -25,9 +25,4 @@ protected ToUpper mutateInstance(ToUpper instance) throws IOException { Expression child = randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild); return new ToUpper(source, child, configuration()); } - - @Override - protected boolean alwaysEmptySource() { - return true; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java index 03a4b063d6294..34b82dc3878af 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java @@ -26,7 +26,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.function.Supplier; +import java.util.stream.IntStream; import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; @@ -91,7 +93,7 @@ public static Iterable parameters() { } private static void booleans(List suppliers, int items) { - suppliers.add(new TestCaseSupplier("boolean", List.of(DataType.BOOLEAN, DataType.BOOLEAN), () -> { + suppliers.add(new TestCaseSupplier("boolean", typesList(DataType.BOOLEAN, DataType.BOOLEAN, items), () -> { List inlist = randomList(items, items, () -> randomBoolean()); boolean field = randomBoolean(); List args = new ArrayList<>(inlist.size() + 1); @@ -109,7 +111,7 @@ private static void booleans(List suppliers, int items) { } private static void numerics(List suppliers, int items) { - suppliers.add(new TestCaseSupplier("integer", List.of(DataType.INTEGER, DataType.INTEGER), () -> { + suppliers.add(new TestCaseSupplier("integer", typesList(DataType.INTEGER, DataType.INTEGER, items), () -> { List inlist = randomList(items, items, () -> randomInt()); int field = inlist.get(inlist.size() - 1); List args = new ArrayList<>(inlist.size() + 1); @@ -125,7 +127,7 @@ private static void numerics(List suppliers, int items) { ); })); - suppliers.add(new TestCaseSupplier("long", List.of(DataType.LONG, DataType.LONG), () -> { + suppliers.add(new TestCaseSupplier("long", typesList(DataType.LONG, DataType.LONG, items), () -> { List inlist = randomList(items, items, () -> randomLong()); long field = randomLong(); List args = new ArrayList<>(inlist.size() + 1); @@ -141,7 +143,7 @@ private static void numerics(List suppliers, int items) { ); })); - suppliers.add(new TestCaseSupplier("double", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> 
{ + suppliers.add(new TestCaseSupplier("double", typesList(DataType.DOUBLE, DataType.DOUBLE, items), () -> { List inlist = randomList(items, items, () -> randomDouble()); double field = inlist.get(0); List args = new ArrayList<>(inlist.size() + 1); @@ -159,58 +161,10 @@ private static void numerics(List suppliers, int items) { } private static void bytesRefs(List suppliers, int items) { - suppliers.add(new TestCaseSupplier("keyword", List.of(DataType.KEYWORD, DataType.KEYWORD), () -> { - List inlist = randomList(items, items, () -> randomLiteral(DataType.KEYWORD).value()); - Object field = inlist.get(inlist.size() - 1); - List args = new ArrayList<>(inlist.size() + 1); - for (Object i : inlist) { - args.add(new TestCaseSupplier.TypedData(i, DataType.KEYWORD, "inlist" + i)); - } - args.add(new TestCaseSupplier.TypedData(field, DataType.KEYWORD, "field")); - return new TestCaseSupplier.TestCase( - args, - matchesPattern("InBytesRefEvaluator.*"), - DataType.BOOLEAN, - equalTo(inlist.contains(field)) - ); - })); - - suppliers.add(new TestCaseSupplier("text", List.of(DataType.TEXT, DataType.TEXT), () -> { - List inlist = randomList(items, items, () -> randomLiteral(DataType.TEXT).value()); - Object field = inlist.get(0); - List args = new ArrayList<>(inlist.size() + 1); - for (Object i : inlist) { - args.add(new TestCaseSupplier.TypedData(i, DataType.TEXT, "inlist" + i)); - } - args.add(new TestCaseSupplier.TypedData(field, DataType.TEXT, "field")); - return new TestCaseSupplier.TestCase( - args, - matchesPattern("InBytesRefEvaluator.*"), - DataType.BOOLEAN, - equalTo(inlist.contains(field)) - ); - })); - - suppliers.add(new TestCaseSupplier("semantic_text", List.of(DataType.SEMANTIC_TEXT, DataType.SEMANTIC_TEXT), () -> { - List inlist = randomList(items, items, () -> randomLiteral(DataType.SEMANTIC_TEXT).value()); - Object field = inlist.get(0); - List args = new ArrayList<>(inlist.size() + 1); - for (Object i : inlist) { - args.add(new TestCaseSupplier.TypedData(i, DataType.SEMANTIC_TEXT, "inlist" + i)); - } - args.add(new TestCaseSupplier.TypedData(field, DataType.SEMANTIC_TEXT, "field")); - return new TestCaseSupplier.TestCase( - args, - matchesPattern("InBytesRefEvaluator.*"), - DataType.BOOLEAN, - equalTo(inlist.contains(field)) - ); - })); - for (DataType type1 : DataType.stringTypes()) { for (DataType type2 : DataType.stringTypes()) { - if (type1 == type2 || items > 1) continue; - suppliers.add(new TestCaseSupplier(type1 + " " + type2, List.of(type1, type2), () -> { + String name = type1 == type2 ? 
type1.toString() : type1 + " " + type2; + suppliers.add(new TestCaseSupplier(name.toLowerCase(Locale.ROOT), typesList(type1, type2, items), () -> { List inlist = randomList(items, items, () -> randomLiteral(type1).value()); Object field = randomLiteral(type2).value(); List args = new ArrayList<>(inlist.size() + 1); @@ -227,7 +181,7 @@ private static void bytesRefs(List suppliers, int items) { })); } } - suppliers.add(new TestCaseSupplier("ip", List.of(DataType.IP, DataType.IP), () -> { + suppliers.add(new TestCaseSupplier("ip", typesList(DataType.IP, DataType.IP, items), () -> { List inlist = randomList(items, items, () -> randomLiteral(DataType.IP).value()); Object field = randomLiteral(DataType.IP).value(); List args = new ArrayList<>(inlist.size() + 1); @@ -243,7 +197,7 @@ private static void bytesRefs(List suppliers, int items) { ); })); - suppliers.add(new TestCaseSupplier("version", List.of(DataType.VERSION, DataType.VERSION), () -> { + suppliers.add(new TestCaseSupplier("version", typesList(DataType.VERSION, DataType.VERSION, items), () -> { List inlist = randomList(items, items, () -> randomLiteral(DataType.VERSION).value()); Object field = randomLiteral(DataType.VERSION).value(); List args = new ArrayList<>(inlist.size() + 1); @@ -259,7 +213,7 @@ private static void bytesRefs(List suppliers, int items) { ); })); - suppliers.add(new TestCaseSupplier("geo_point", List.of(GEO_POINT, GEO_POINT), () -> { + suppliers.add(new TestCaseSupplier("geo_point", typesList(GEO_POINT, GEO_POINT, items), () -> { List inlist = randomList(items, items, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomPoint()))); Object field = inlist.get(0); List args = new ArrayList<>(inlist.size() + 1); @@ -275,7 +229,7 @@ private static void bytesRefs(List suppliers, int items) { ); })); - suppliers.add(new TestCaseSupplier("geo_shape", List.of(GEO_SHAPE, GEO_SHAPE), () -> { + suppliers.add(new TestCaseSupplier("geo_shape", typesList(GEO_SHAPE, GEO_SHAPE, items), () -> { List inlist = randomList( items, items, @@ -295,7 +249,7 @@ private static void bytesRefs(List suppliers, int items) { ); })); - suppliers.add(new TestCaseSupplier("cartesian_point", List.of(CARTESIAN_POINT, CARTESIAN_POINT), () -> { + suppliers.add(new TestCaseSupplier("cartesian_point", typesList(CARTESIAN_POINT, CARTESIAN_POINT, items), () -> { List inlist = randomList(items, items, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomPoint()))); Object field = new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomPoint())); List args = new ArrayList<>(inlist.size() + 1); @@ -311,7 +265,7 @@ private static void bytesRefs(List suppliers, int items) { ); })); - suppliers.add(new TestCaseSupplier("cartesian_shape", List.of(CARTESIAN_SHAPE, CARTESIAN_SHAPE), () -> { + suppliers.add(new TestCaseSupplier("cartesian_shape", typesList(CARTESIAN_SHAPE, CARTESIAN_SHAPE, items), () -> { List inlist = randomList( items, items, @@ -332,6 +286,16 @@ private static void bytesRefs(List suppliers, int items) { })); } + /** + * Returns a list with n copies of inListType, followed by a single fieldType.
+ */ + private static List typesList(DataType inListType, DataType fieldType, int n) { + List types = new ArrayList<>(n + 1); + IntStream.range(0, n).forEach(i -> types.add(inListType)); + types.add(fieldType); + return types; + } + @Override protected Expression build(Source source, List args) { return new In(source, args.get(args.size() - 1), args.subList(0, args.size() - 1)); From 3f51012481c2e44dffb99088079093769cc14341 Mon Sep 17 00:00:00 2001 From: Niels Bauman <33722607+nielsbauman@users.noreply.github.com> Date: Mon, 24 Feb 2025 14:56:37 +0100 Subject: [PATCH 15/15] Fix NPE in `ReindexDataStreamTransportAction` (#123262) In the multi-project branch, we're making some changes to persistent tasks, and those changes can cause the persistent tasks custom metadata to still be `null`. That resulted in an NPE, so I'm fixing the check here. --- .../migrate/action/ReindexDataStreamTransportAction.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java index 15d21b79fee90..ab2ffaaf4bd0e 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java @@ -83,11 +83,7 @@ protected void doExecute(Task task, ReindexDataStreamRequest request, ActionList ClientHelper.getPersistableSafeSecurityHeaders(transportService.getThreadPool().getThreadContext(), clusterService.state()) ); String persistentTaskId = getPersistentTaskId(sourceDataStreamName); - - PersistentTasksCustomMetadata persistentTasksCustomMetadata = clusterService.state() - .getMetadata() - .custom(PersistentTasksCustomMetadata.TYPE); - PersistentTasksCustomMetadata.PersistentTask persistentTask = persistentTasksCustomMetadata.getTask(persistentTaskId); + final var persistentTask = PersistentTasksCustomMetadata.getTaskWithId(clusterService.state(), persistentTaskId); if (persistentTask == null) { startTask(listener, persistentTaskId, params);
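
The final hunk works because `PersistentTasksCustomMetadata.getTaskWithId` tolerates a cluster state with no persistent-tasks custom section, while the replaced code dereferenced that section unconditionally. A rough sketch of the difference (the helper's null-safety is inferred from its use in this hunk, not copied from its source):

```
// Before: throws NullPointerException when the persistent-tasks custom metadata is absent.
PersistentTasksCustomMetadata custom = state.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
PersistentTasksCustomMetadata.PersistentTask<?> task = custom.getTask(id); // custom may be null

// After: the static helper returns null instead of throwing when the custom is missing.
final var task = PersistentTasksCustomMetadata.getTaskWithId(state, id);
```

The serialization changes in #122821 above all follow one pattern: `Source` only goes on the wire when both sides of the connection understand the new transport version, so mixed-version clusters keep working during a rolling upgrade. A minimal sketch of that read/write pair, as a fragment of a hypothetical single-argument `MyFunction` (only `Source`, `PlanStreamInput`, `TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS`, and the stream APIs are taken from the diffs above; the class and its `field` child are illustrative):

```
// Sketch of the version-gated Source serialization applied in #122821.
// MyFunction is hypothetical; it delegates to a constructor taking (Source, Expression).
private MyFunction(StreamInput in) throws IOException {
    this(
        // Older senders never wrote a Source, so only read one when the sender
        // is new enough; otherwise fall back to Source.EMPTY as before.
        in.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS)
            ? Source.readFrom((PlanStreamInput) in)
            : Source.EMPTY,
        in.readNamedWriteable(Expression.class)
    );
}

@Override
public void writeTo(StreamOutput out) throws IOException {
    // Mirror the read side exactly: an asymmetric check would corrupt the
    // stream in one direction of a mixed-version connection.
    if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS)) {
        source().writeTo(out);
    }
    out.writeNamedWriteable(field);
}
```

Classes whose wire format already carried a `Source` slot, such as `AggregateFunction`, `DateDiff`, `AbstractMultivalueFunction`, and `MvSlice` above, did not need the gate: they previously wrote `Source.EMPTY` and now write `source()` into the same slot, which is wire-compatible across versions.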