From ad91d47c2278e74840eb809e960975f9b2561c58 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 5 Feb 2025 15:58:58 +0000 Subject: [PATCH 01/22] Add a lenient option to text similarity reranking --- .../org/elasticsearch/TransportVersions.java | 2 +- .../action/search/RankFeaturePhase.java | 23 +++++++++++++----- .../search/rank/RankBuilder.java | 21 +++++++++++++++- ...ankFeaturePhaseRankCoordinatorContext.java | 24 +++++++++---------- .../action/search/RankFeaturePhaseTests.java | 16 ++++--------- .../search/SearchServiceSingleNodeTests.java | 8 +++---- .../rank/RankFeatureShardPhaseTests.java | 4 ++-- .../search/rank/TestRankBuilder.java | 2 +- .../rank/random/RandomRankBuilder.java | 2 +- ...ankFeaturePhaseRankCoordinatorContext.java | 16 +------------ .../TextSimilarityRankBuilder.java | 22 ++++++++++++----- ...ankFeaturePhaseRankCoordinatorContext.java | 11 +++++---- .../TextSimilarityRankRetrieverBuilder.java | 3 ++- .../TextSimilarityRankBuilderTests.java | 5 ++-- ...aturePhaseRankCoordinatorContextTests.java | 3 ++- .../TextSimilarityRankMultiNodeTests.java | 2 +- .../TextSimilarityRankTests.java | 11 +++++---- .../TextSimilarityTestPlugin.java | 5 ++-- .../xpack/rank/rrf/RRFRankBuilder.java | 2 +- 19 files changed, 103 insertions(+), 79 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5065d84f84978..c64b154e977b2 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -177,8 +177,8 @@ static TransportVersion def(int id) { public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES = def(9_002_0_00); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED = def(9_003_0_00); - public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00); + public static final TransportVersion LENIENT_RERANKERS = def(9_005_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index e9302883457e1..3f857678c0b6c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -26,6 +26,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; import org.elasticsearch.transport.Transport; +import java.util.Arrays; import java.util.List; /** @@ -181,6 +182,11 @@ private void onPhaseDone( RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext, SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { + RankFeatureDoc[] docs = rankPhaseResults.getSuccessfulResults() + .flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs)) + .filter(rfd -> rfd.featureData != null) + .toArray(RankFeatureDoc[]::new); + ThreadedActionListener rankResultListener = new ThreadedActionListener<>( context::execute, new ActionListener<>() { @@ -196,21 +202,26 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { @Override public void onFailure(Exception e) { - context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e); + if (rankFeaturePhaseRankCoordinatorContext.isLenient()) { + // TODO: handle the exception somewhere + logger.warn("Exception computing updated ranks. Continuing with existing ranks.", e); + // use the existing docs as-is + // AbstractThreadedActionListener forks onFailure to the same executor as onResponse, + // so we can just call this direct + onResponse(docs); + } else { + context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e); + } } } ); - rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults( - rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(), - rankResultListener - ); + rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(docs, rankResultListener); } private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults( SearchPhaseController.ReducedQueryPhase reducedQueryPhase, ScoreDoc[] scoreDocs ) { - return new SearchPhaseController.ReducedQueryPhase( reducedQueryPhase.totalHits(), reducedQueryPhase.fetchHits(), diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java index 9176410f6ea35..ce5dac7bf9ec1 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Query; +import org.elasticsearch.TransportVersions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -42,21 +43,32 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContentObject { public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); + public static final ParseField LENIENT_FIELD = new ParseField("lenient"); public static final int DEFAULT_RANK_WINDOW_SIZE = SearchService.DEFAULT_SIZE; private final int rankWindowSize; + private final boolean lenient; - public RankBuilder(int rankWindowSize) { + public RankBuilder(int rankWindowSize, boolean lenient) { this.rankWindowSize = rankWindowSize; + this.lenient = lenient; } public RankBuilder(StreamInput in) throws IOException { rankWindowSize = in.readVInt(); + if (in.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { + lenient = in.readBoolean(); + } else { + lenient = false; + } } public final void writeTo(StreamOutput out) throws IOException { out.writeVInt(rankWindowSize); + if (out.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { + out.writeBoolean(lenient); + } doWriteTo(out); } @@ -67,6 +79,9 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params) builder.startObject(); builder.startObject(getWriteableName()); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); + if (lenient) { + builder.field(LENIENT_FIELD.getPreferredName(), lenient); + } doXContent(builder, params); builder.endObject(); builder.endObject(); @@ -79,6 +94,10 @@ public int rankWindowSize() { return rankWindowSize; } + public boolean isLenient() { + return lenient; + } + /** * Specify whether this rank builder is a compound builder or not. A compound builder is a rank builder that requires * two or more queries to be executed in order to generate the final result. diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java index 3be3c5db001f3..8b5e49b0964c7 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -33,11 +33,17 @@ public abstract class RankFeaturePhaseRankCoordinatorContext { protected final int size; protected final int from; protected final int rankWindowSize; + protected final boolean lenient; - public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { + public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean lenient) { this.size = size < 0 ? DEFAULT_SIZE : size; this.from = from < 0 ? DEFAULT_FROM : from; this.rankWindowSize = rankWindowSize; + this.lenient = lenient; + } + + public boolean isLenient() { + return lenient; } /** @@ -51,9 +57,9 @@ public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindow * @param originalDocs documents to process */ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) { - return Arrays.stream(originalDocs) - .sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()) - .toArray(RankFeatureDoc[]::new); + RankFeatureDoc[] sorted = originalDocs.clone(); + Arrays.sort(sorted, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + return sorted; } /** @@ -64,16 +70,10 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) { * Once all the scores have been computed, we sort the results, perform any pagination needed, and then call the `onFinish` consumer * with the final array of {@link ScoreDoc} results. * - * @param rankSearchResults a list of rank feature results from each shard + * @param featureDocs an array of rank feature results from each shard * @param rankListener a rankListener to handle the global ranking result */ - public void computeRankScoresForGlobalResults( - List rankSearchResults, - ActionListener rankListener - ) { - // extract feature data from each shard rank-feature phase result - RankFeatureDoc[] featureDocs = extractFeatureDocs(rankSearchResults); - + public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener rankListener) { // generate the final `topResults` results, and pass them to fetch phase through the `rankListener` computeScores(featureDocs, rankListener.delegateFailureAndWrap((listener, scores) -> { for (int i = 0; i < featureDocs.length; i++) { diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java index 3d104758b096b..533494726bdd0 100644 --- a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java @@ -775,7 +775,7 @@ public void sendExecuteRankFeature( } private RankFeaturePhaseRankCoordinatorContext defaultRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize, false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { @@ -785,16 +785,8 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener rankSearchResults, - ActionListener rankListener - ) { - List features = new ArrayList<>(); - for (RankFeatureResult rankFeatureResult : rankSearchResults) { - RankFeatureShardResult shardResult = rankFeatureResult.shardResult(); - features.addAll(Arrays.stream(shardResult.rankFeatureDocs).toList()); - } - rankListener.onResponse(features.toArray(new RankFeatureDoc[0])); + public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener rankListener) { + rankListener.onResponse(featureDocs); } @Override @@ -875,7 +867,7 @@ private RankBuilder rankBuilder( RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext, RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext ) { - return new RankBuilder(rankWindowSize) { + return new RankBuilder(rankWindowSize, false) { @Override protected void doWriteTo(StreamOutput out) throws IOException { // no-op diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java index fe602d2854c8c..9c3dc7ca8ab9d 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java @@ -687,7 +687,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo int from, Client client ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { float[] scores = new float[featureDocs.length]; @@ -831,7 +831,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo int from, Client client ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { throw new IllegalStateException("should have failed earlier"); @@ -947,7 +947,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo int from, Client client ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { float[] scores = new float[featureDocs.length]; @@ -1075,7 +1075,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo int from, Client client ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { float[] scores = new float[featureDocs.length]; diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java index 41febe77d54aa..4b6144107a1f3 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java @@ -117,7 +117,7 @@ public boolean isCancelled() { } private RankBuilder getRankBuilder(final String field) { - return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE) { + return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE, false) { @Override protected void doWriteTo(StreamOutput out) throws IOException { // no-op @@ -171,7 +171,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) // no work to be done on the coordinator node for the rank feature phase @Override public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { throw new AssertionError("not expected"); diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java index b44fb7ec77462..9915698c6b452 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java @@ -52,7 +52,7 @@ public static TestRankBuilder randomRankBuilder() { } public TestRankBuilder(int windowSize) { - super(windowSize); + super(windowSize, false); } public TestRankBuilder(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java index 15d41301d0a3c..afef043a9de9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java @@ -61,7 +61,7 @@ public class RandomRankBuilder extends RankBuilder { private final Integer seed; public RandomRankBuilder(int rankWindowSize, String field, Integer seed) { - super(rankWindowSize); + super(rankWindowSize, false); if (field == null || field.isEmpty()) { throw new IllegalArgumentException("field is required"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java index 446d8e5862dd2..de593e0943f42 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java @@ -11,8 +11,6 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import java.util.Arrays; -import java.util.Comparator; import java.util.Random; /** @@ -24,7 +22,7 @@ public class RandomRankFeaturePhaseRankCoordinatorContext extends RankFeaturePha private final Integer seed; public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Integer seed) { - super(size, from, rankWindowSize); + super(size, from, rankWindowSize, false); this.seed = seed; } @@ -40,16 +38,4 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener doc.score).reversed()) - .toArray(RankFeatureDoc[]::new); - } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index e4093a91c2359..de83a6e1c27de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -57,8 +57,9 @@ public class TextSimilarityRankBuilder extends RankBuilder { String field = (String) args[2]; Integer rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (Integer) args[3]; Float minScore = (Float) args[4]; + boolean lenient = args[5] != null && (Boolean) args[5]; - return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore); + return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient); }); static { @@ -67,6 +68,7 @@ public class TextSimilarityRankBuilder extends RankBuilder { PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); PARSER.declareFloat(optionalConstructorArg(), MIN_SCORE_FIELD); + PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD); } private final String inferenceId; @@ -74,8 +76,15 @@ public class TextSimilarityRankBuilder extends RankBuilder { private final String field; private final Float minScore; - public TextSimilarityRankBuilder(String field, String inferenceId, String inferenceText, int rankWindowSize, Float minScore) { - super(rankWindowSize); + public TextSimilarityRankBuilder( + String field, + String inferenceId, + String inferenceText, + int rankWindowSize, + Float minScore, + boolean lenient + ) { + super(rankWindowSize, lenient); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; @@ -103,7 +112,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void doWriteTo(StreamOutput out) throws IOException { - // rankWindowSize serialization is handled by the parent class RankBuilder + // rankWindowSize & lenient serialization is handled by the parent class RankBuilder out.writeString(inferenceId); out.writeString(inferenceText); out.writeString(field); @@ -112,7 +121,7 @@ public void doWriteTo(StreamOutput out) throws IOException { @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { - // rankWindowSize serialization is handled by the parent class RankBuilder + // rankWindowSize & lenient serialization is handled by the parent class RankBuilder builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); builder.field(FIELD_FIELD.getPreferredName(), field); @@ -177,7 +186,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo client, inferenceId, inferenceText, - minScore + minScore, + isLenient() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 63274e5104207..f86c2b0666dc2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -43,9 +43,10 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( Client client, String inferenceId, String inferenceText, - Float minScore + Float minScore, + boolean lenient ) { - super(size, from, rankWindowSize); + super(size, from, rankWindowSize, lenient); this.client = client; this.inferenceId = inferenceId; this.inferenceText = inferenceText; @@ -130,15 +131,15 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener docs = new ArrayList<>(); + List docs = new ArrayList<>(originalDocs.length); for (RankFeatureDoc doc : originalDocs) { if (minScore == null || doc.score >= minScore) { doc.score = normalizeScore(doc.score); docs.add(doc); } } - docs.sort(RankFeatureDoc::compareTo); - return docs.toArray(new RankFeatureDoc[0]); + docs.sort(null); + return docs.toArray(RankFeatureDoc[]::new); } protected InferenceAction.Request generateRequest(List docFeatures) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index fa6cc3db0ef9f..1dbcf93467148 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -52,6 +52,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder String inferenceText = (String) args[2]; String field = (String) args[3]; int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; + boolean lenient = args[5] != null && (Boolean) args[5]; return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize); }); @@ -163,7 +164,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b @Override protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { sourceBuilder.rankBuilder( - new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore) + new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore, false) ); return sourceBuilder; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java index 9ea28242f3605..d740ebaf53d10 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java @@ -25,7 +25,8 @@ protected TextSimilarityRankBuilder createTestInstance() { "my-inference-id", "my-inference-text", randomIntBetween(1, 1000), - randomBoolean() ? null : randomFloat() + randomBoolean() ? null : randomFloat(), + false ); } @@ -46,7 +47,7 @@ protected TextSimilarityRankBuilder mutateInstance(TextSimilarityRankBuilder ins case 4 -> minScore = randomValueOtherThan(instance.minScore(), this::randomMinScore); default -> throw new IllegalStateException("Requested to modify more than available parameters."); } - return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore); + return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, false); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java index d6c476cdc15d6..8d1daa67d1194 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java @@ -31,7 +31,8 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E mockClient, "my-inference-id", "some query", - 0.0f + 0.0f, + false ); public void testComputeScores() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index 27a8f0e962761..8bcb4d5572948 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -23,7 +23,7 @@ public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT { @Override protected RankBuilder getRankBuilder(int rankWindowSize, String rankFeatureField) { - return new TextSimilarityRankBuilder(rankFeatureField, inferenceId, inferenceText, rankWindowSize, minScore); + return new TextSimilarityRankBuilder(rankFeatureField, inferenceId, inferenceText, rankWindowSize, minScore, false); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index 0969a902870b6..d8abf40a28d33 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -49,7 +49,7 @@ public TopNConfigurationAcceptingTextSimilarityRankBuilder( Float minScore, int topN ) { - super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore); + super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore, false); } } @@ -68,7 +68,7 @@ public InferenceResultCountAcceptingTextSimilarityRankBuilder( Float minScore, int inferenceResultCount ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, false); this.inferenceResultCount = inferenceResultCount; } @@ -81,7 +81,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo client, inferenceId, inferenceText, - minScore + minScore, + false ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { @@ -125,7 +126,7 @@ public void testRerank() { ElasticsearchAssertions.assertNoFailuresAndResponse( // Execute search with text similarity reranking client.prepareSearch() - .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f)) + .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f, false)) .setQuery(QueryBuilders.matchAllQuery()), response -> { // Verify order, rank and score of results @@ -145,7 +146,7 @@ public void testRerankWithMinScore() { ElasticsearchAssertions.assertNoFailuresAndResponse( // Execute search with text similarity reranking client.prepareSearch() - .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f)) + .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false)) .setQuery(QueryBuilders.matchAllQuery()), response -> { // Verify order, rank and score of results diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 358aa9804b916..051792ae10ddb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -218,7 +218,7 @@ public ThrowingMockRequestActionBasedRankBuilder( final float minScore, final String throwingType ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, false); this.throwingRankBuilderType = AbstractRerankerIT.ThrowingRankBuilderType.valueOf(throwingType); } @@ -263,7 +263,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo client, inferenceId, inferenceText, - minScore + minScore, + false ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index 91bc19a3e0903..6ab2bc901df91 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -79,7 +79,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio private final int rankConstant; public RRFRankBuilder(int rankWindowSize, int rankConstant) { - super(rankWindowSize); + super(rankWindowSize, false); if (rankConstant < 1) { throw new IllegalArgumentException("[rank_constant] must be greater or equal to [1] for [rrf]"); From b447b33a8b25224866008a2e865a7e4c6f59511e Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 5 Feb 2025 16:01:04 +0000 Subject: [PATCH 02/22] Update docs/changelog/121784.yaml --- docs/changelog/121784.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/121784.yaml diff --git a/docs/changelog/121784.yaml b/docs/changelog/121784.yaml new file mode 100644 index 0000000000000..beec96a8e1f0f --- /dev/null +++ b/docs/changelog/121784.yaml @@ -0,0 +1,5 @@ +pr: 121784 +summary: Add a lenient option to text similarity reranking +area: Search +type: enhancement +issues: [] From 1ac18fe696b534d505dd47fa390c2c0ec7841520 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 6 Feb 2025 09:36:43 +0000 Subject: [PATCH 03/22] propagate --- .../search/rank/FieldBasedRerankerIT.java | 6 ++--- .../MockedRequestActionBasedRerankerIT.java | 4 +-- .../search/rank/RankBuilder.java | 5 ++-- .../TextSimilarityRankRetrieverBuilder.java | 26 +++++++++++++------ ...xtSimilarityRankRetrieverBuilderTests.java | 3 ++- ...SimilarityRankRetrieverTelemetryTests.java | 3 ++- 6 files changed, 29 insertions(+), 18 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java index c8c1f50444c1d..d6e9d64b611f3 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java @@ -100,7 +100,7 @@ public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws I } public FieldBasedRankBuilder(final int rankWindowSize, final String field) { - super(rankWindowSize); + super(rankWindowSize, false); this.field = field; } @@ -205,7 +205,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) @Override public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), isLenient()) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { float[] scores = new float[featureDocs.length]; @@ -346,7 +346,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) @Override public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) { if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT) - return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), isLenient()) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { throw new UnsupportedOperationException("rfc - simulated failure"); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java index dbdd6e8c50027..a159ba5078912 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java @@ -249,7 +249,7 @@ public static class TestRerankingRankFeaturePhaseRankCoordinatorContext extends String inferenceText, float minScore ) { - super(size, from, windowSize); + super(size, from, windowSize, false); this.client = client; this.inferenceId = inferenceId; this.inferenceText = inferenceText; @@ -337,7 +337,7 @@ public MockRequestActionBasedRankBuilder( final String inferenceText, final float minScore ) { - super(rankWindowSize); + super(rankWindowSize, false); this.field = field; this.inferenceId = inferenceId; this.inferenceText = inferenceText; diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java index ce5dac7bf9ec1..4a5b70aada05c 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java @@ -152,16 +152,15 @@ public final boolean equals(Object obj) { if (obj == null || getClass() != obj.getClass()) { return false; } - @SuppressWarnings("unchecked") RankBuilder other = (RankBuilder) obj; - return Objects.equals(rankWindowSize, other.rankWindowSize()) && doEquals(other); + return rankWindowSize == other.rankWindowSize && lenient == other.lenient && doEquals(other); } protected abstract boolean doEquals(RankBuilder other); @Override public final int hashCode() { - return Objects.hash(getClass(), rankWindowSize, doHashCode()); + return Objects.hash(getClass(), rankWindowSize, lenient, doHashCode()); } protected abstract int doHashCode(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 1dbcf93467148..7e564a96481af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -44,6 +44,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); + public static final ParseField LENIENT_FIELD = new ParseField("lenient"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { @@ -54,7 +55,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; boolean lenient = args[5] != null && (Boolean) args[5]; - return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize); + return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, lenient); }); static { @@ -67,6 +68,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder PARSER.declareString(constructorArg(), INFERENCE_TEXT_FIELD); PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD); RetrieverBuilder.declareBaseParserFields(TextSimilarityRankBuilder.NAME, PARSER); } @@ -85,18 +87,21 @@ public static TextSimilarityRankRetrieverBuilder fromXContent( private final String inferenceId; private final String inferenceText; private final String field; + private final boolean lenient; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, String inferenceId, String inferenceText, String field, - int rankWindowSize + int rankWindowSize, + boolean lenient ) { super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; + this.lenient = lenient; } public TextSimilarityRankRetrieverBuilder( @@ -106,6 +111,7 @@ public TextSimilarityRankRetrieverBuilder( String field, int rankWindowSize, Float minScore, + boolean lenient, String retrieverName, List preFilterQueryBuilders ) { @@ -117,6 +123,7 @@ public TextSimilarityRankRetrieverBuilder( this.inferenceText = inferenceText; this.field = field; this.minScore = minScore; + this.lenient = lenient; this.retrieverName = retrieverName; this.preFilterQueryBuilders = preFilterQueryBuilders; } @@ -133,6 +140,7 @@ protected TextSimilarityRankRetrieverBuilder clone( field, rankWindowSize, minScore, + lenient, retrieverName, newPreFilterQueryBuilders ); @@ -163,9 +171,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b @Override protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { - sourceBuilder.rankBuilder( - new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore, false) - ); + sourceBuilder.rankBuilder(new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient)); return sourceBuilder; } @@ -189,6 +195,9 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); builder.field(FIELD_FIELD.getPreferredName(), field); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); + if (lenient) { + builder.field(LENIENT_FIELD.getPreferredName(), lenient); + } } @Override @@ -198,12 +207,13 @@ public boolean doEquals(Object other) { && Objects.equals(inferenceId, that.inferenceId) && Objects.equals(inferenceText, that.inferenceText) && Objects.equals(field, that.field) - && Objects.equals(rankWindowSize, that.rankWindowSize) - && Objects.equals(minScore, that.minScore); + && rankWindowSize == that.rankWindowSize + && Objects.equals(minScore, that.minScore) + && lenient == that.lenient; } @Override public int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore); + return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, lenient); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 93c3ffe5d14fb..3fb7e05f44f79 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -57,7 +57,8 @@ public static TextSimilarityRankRetrieverBuilder createRandomTextSimilarityRankR randomAlphaOfLength(10), randomAlphaOfLength(20), randomAlphaOfLength(50), - randomIntBetween(100, 10000) + randomIntBetween(100, 10000), + randomBoolean() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index ba6924ba0ff3b..237ee1cde1ef8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -139,7 +139,8 @@ public void testTelemetryForRRFRetriever() throws IOException { "some_inference_id", "some_inference_text", "some_field", - 10 + 10, + false ) ) ); From 22bd2a0065e59d0a4cb9381272676d432eba4a70 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 6 Feb 2025 10:53:14 +0000 Subject: [PATCH 04/22] leniency is only part of the text similarity builders --- .../search/rank/FieldBasedRerankerIT.java | 6 ++--- .../MockedRequestActionBasedRerankerIT.java | 2 +- .../search/rank/RankBuilder.java | 25 +++---------------- .../action/search/RankFeaturePhaseTests.java | 2 +- .../rank/RankFeatureShardPhaseTests.java | 2 +- .../search/rank/TestRankBuilder.java | 2 +- .../rank/random/RandomRankBuilder.java | 2 +- .../TextSimilarityRankBuilder.java | 22 +++++++++++++--- .../xpack/rank/rrf/RRFRankBuilder.java | 2 +- 9 files changed, 30 insertions(+), 35 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java index d6e9d64b611f3..cefabe277eb31 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java @@ -100,7 +100,7 @@ public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws I } public FieldBasedRankBuilder(final int rankWindowSize, final String field) { - super(rankWindowSize, false); + super(rankWindowSize); this.field = field; } @@ -205,7 +205,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) @Override public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), isLenient()) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { float[] scores = new float[featureDocs.length]; @@ -346,7 +346,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) @Override public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) { if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT) - return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), isLenient()) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), false) { @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { throw new UnsupportedOperationException("rfc - simulated failure"); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java index a159ba5078912..83947b88d1de1 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java @@ -337,7 +337,7 @@ public MockRequestActionBasedRankBuilder( final String inferenceText, final float minScore ) { - super(rankWindowSize, false); + super(rankWindowSize); this.field = field; this.inferenceId = inferenceId; this.inferenceText = inferenceText; diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java index 4a5b70aada05c..af53273f8bc93 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java @@ -11,7 +11,6 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Query; -import org.elasticsearch.TransportVersions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -43,32 +42,21 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContentObject { public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); - public static final ParseField LENIENT_FIELD = new ParseField("lenient"); public static final int DEFAULT_RANK_WINDOW_SIZE = SearchService.DEFAULT_SIZE; private final int rankWindowSize; - private final boolean lenient; - public RankBuilder(int rankWindowSize, boolean lenient) { + public RankBuilder(int rankWindowSize) { this.rankWindowSize = rankWindowSize; - this.lenient = lenient; } public RankBuilder(StreamInput in) throws IOException { rankWindowSize = in.readVInt(); - if (in.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { - lenient = in.readBoolean(); - } else { - lenient = false; - } } public final void writeTo(StreamOutput out) throws IOException { out.writeVInt(rankWindowSize); - if (out.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { - out.writeBoolean(lenient); - } doWriteTo(out); } @@ -79,9 +67,6 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params) builder.startObject(); builder.startObject(getWriteableName()); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); - if (lenient) { - builder.field(LENIENT_FIELD.getPreferredName(), lenient); - } doXContent(builder, params); builder.endObject(); builder.endObject(); @@ -94,10 +79,6 @@ public int rankWindowSize() { return rankWindowSize; } - public boolean isLenient() { - return lenient; - } - /** * Specify whether this rank builder is a compound builder or not. A compound builder is a rank builder that requires * two or more queries to be executed in order to generate the final result. @@ -153,14 +134,14 @@ public final boolean equals(Object obj) { return false; } RankBuilder other = (RankBuilder) obj; - return rankWindowSize == other.rankWindowSize && lenient == other.lenient && doEquals(other); + return rankWindowSize == other.rankWindowSize && doEquals(other); } protected abstract boolean doEquals(RankBuilder other); @Override public final int hashCode() { - return Objects.hash(getClass(), rankWindowSize, lenient, doHashCode()); + return Objects.hash(getClass(), rankWindowSize, doHashCode()); } protected abstract int doHashCode(); diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java index 533494726bdd0..91620e323067b 100644 --- a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java @@ -867,7 +867,7 @@ private RankBuilder rankBuilder( RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext, RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext ) { - return new RankBuilder(rankWindowSize, false) { + return new RankBuilder(rankWindowSize) { @Override protected void doWriteTo(StreamOutput out) throws IOException { // no-op diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java index 4b6144107a1f3..b6c5eeca26bc4 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java @@ -117,7 +117,7 @@ public boolean isCancelled() { } private RankBuilder getRankBuilder(final String field) { - return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE, false) { + return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE) { @Override protected void doWriteTo(StreamOutput out) throws IOException { // no-op diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java index 9915698c6b452..b44fb7ec77462 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java @@ -52,7 +52,7 @@ public static TestRankBuilder randomRankBuilder() { } public TestRankBuilder(int windowSize) { - super(windowSize, false); + super(windowSize); } public TestRankBuilder(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java index afef043a9de9b..15d41301d0a3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java @@ -61,7 +61,7 @@ public class RandomRankBuilder extends RankBuilder { private final Integer seed; public RandomRankBuilder(int rankWindowSize, String field, Integer seed) { - super(rankWindowSize, false); + super(rankWindowSize); if (field == null || field.isEmpty()) { throw new IllegalArgumentException("field is required"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index de83a6e1c27de..17aafff018442 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -36,6 +36,7 @@ import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.FIELD_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD; +import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.LENIENT_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD; /** @@ -75,6 +76,7 @@ public class TextSimilarityRankBuilder extends RankBuilder { private final String inferenceText; private final String field; private final Float minScore; + private final boolean lenient; public TextSimilarityRankBuilder( String field, @@ -84,11 +86,12 @@ public TextSimilarityRankBuilder( Float minScore, boolean lenient ) { - super(rankWindowSize, lenient); + super(rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; this.minScore = minScore; + this.lenient = lenient; } public TextSimilarityRankBuilder(StreamInput in) throws IOException { @@ -98,6 +101,11 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException { this.inferenceText = in.readString(); this.field = in.readString(); this.minScore = in.readOptionalFloat(); + if (in.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { + this.lenient = in.readBoolean(); + } else { + this.lenient = false; + } } @Override @@ -112,22 +120,28 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void doWriteTo(StreamOutput out) throws IOException { - // rankWindowSize & lenient serialization is handled by the parent class RankBuilder + // rankWindowSize serialization is handled by the parent class RankBuilder out.writeString(inferenceId); out.writeString(inferenceText); out.writeString(field); out.writeOptionalFloat(minScore); + if (out.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { + out.writeBoolean(lenient); + } } @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { - // rankWindowSize & lenient serialization is handled by the parent class RankBuilder + // rankWindowSize serialization is handled by the parent class RankBuilder builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); builder.field(FIELD_FIELD.getPreferredName(), field); if (minScore != null) { builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); } + if (lenient) { + builder.field(LENIENT_FIELD.getPreferredName(), lenient); + } } @Override @@ -187,7 +201,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - isLenient() + lenient ); } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index 6ab2bc901df91..91bc19a3e0903 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -79,7 +79,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio private final int rankConstant; public RRFRankBuilder(int rankWindowSize, int rankConstant) { - super(rankWindowSize, false); + super(rankWindowSize); if (rankConstant < 1) { throw new IllegalArgumentException("[rank_constant] must be greater or equal to [1] for [rrf]"); From c47b8a5dd261dfd3fe660dd505b58bba05261c8f Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 6 Feb 2025 11:34:32 +0000 Subject: [PATCH 05/22] Remove superfluous xcontent serialization --- .../TextSimilarityRankBuilder.java | 41 +------- .../TextSimilarityRankBuilderTests.java | 95 ------------------- 2 files changed, 2 insertions(+), 134 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index 17aafff018442..56157e690049f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -24,21 +24,13 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext; -import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.util.List; import java.util.Objects; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.FIELD_FIELD; -import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD; -import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD; -import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.LENIENT_FIELD; -import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD; - /** * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call. */ @@ -52,26 +44,6 @@ public class TextSimilarityRankBuilder extends RankBuilder { License.OperationMode.ENTERPRISE ); - static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - String inferenceId = (String) args[0]; - String inferenceText = (String) args[1]; - String field = (String) args[2]; - Integer rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (Integer) args[3]; - Float minScore = (Float) args[4]; - boolean lenient = args[5] != null && (Boolean) args[5]; - - return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient); - }); - - static { - PARSER.declareString(constructorArg(), INFERENCE_ID_FIELD); - PARSER.declareString(constructorArg(), INFERENCE_TEXT_FIELD); - PARSER.declareString(constructorArg(), FIELD_FIELD); - PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); - PARSER.declareFloat(optionalConstructorArg(), MIN_SCORE_FIELD); - PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD); - } - private final String inferenceId; private final String inferenceText; private final String field; @@ -132,16 +104,7 @@ public void doWriteTo(StreamOutput out) throws IOException { @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { - // rankWindowSize serialization is handled by the parent class RankBuilder - builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); - builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); - builder.field(FIELD_FIELD.getPreferredName(), field); - if (minScore != null) { - builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); - } - if (lenient) { - builder.field(LENIENT_FIELD.getPreferredName(), lenient); - } + throw new UnsupportedEncodingException("This should not be XContent serialized"); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java deleted file mode 100644 index d740ebaf53d10..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilderTests.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.rank.textsimilarity; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractXContentSerializingTestCase; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.json.JsonXContent; - -import java.io.IOException; - -import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; - -public class TextSimilarityRankBuilderTests extends AbstractXContentSerializingTestCase { - - @Override - protected TextSimilarityRankBuilder createTestInstance() { - return new TextSimilarityRankBuilder( - "my-field", - "my-inference-id", - "my-inference-text", - randomIntBetween(1, 1000), - randomBoolean() ? null : randomFloat(), - false - ); - } - - @Override - protected TextSimilarityRankBuilder mutateInstance(TextSimilarityRankBuilder instance) throws IOException { - String field = instance.field(); - String inferenceId = instance.inferenceId(); - String inferenceText = instance.inferenceText(); - int rankWindowSize = instance.rankWindowSize(); - Float minScore = instance.minScore(); - - int mutate = randomIntBetween(0, 4); - switch (mutate) { - case 0 -> field = field + randomAlphaOfLength(2); - case 1 -> inferenceId = inferenceId + randomAlphaOfLength(2); - case 2 -> inferenceText = inferenceText + randomAlphaOfLength(2); - case 3 -> rankWindowSize = randomValueOtherThan(instance.rankWindowSize(), this::randomRankWindowSize); - case 4 -> minScore = randomValueOtherThan(instance.minScore(), this::randomMinScore); - default -> throw new IllegalStateException("Requested to modify more than available parameters."); - } - return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, false); - } - - @Override - protected Writeable.Reader instanceReader() { - return TextSimilarityRankBuilder::new; - } - - @Override - protected TextSimilarityRankBuilder doParseInstance(XContentParser parser) throws IOException { - parser.nextToken(); - assertEquals(parser.currentToken(), XContentParser.Token.START_OBJECT); - parser.nextToken(); - assertEquals(parser.currentToken(), XContentParser.Token.FIELD_NAME); - assertEquals(parser.currentName(), TextSimilarityRankBuilder.NAME); - TextSimilarityRankBuilder builder = TextSimilarityRankBuilder.PARSER.parse(parser, null); - parser.nextToken(); - assertEquals(parser.currentToken(), XContentParser.Token.END_OBJECT); - parser.nextToken(); - assertNull(parser.currentToken()); - return builder; - } - - private int randomRankWindowSize() { - return randomIntBetween(0, 1000); - } - - private float randomMinScore() { - return randomFloatBetween(-1.0f, 1.0f, true); - } - - public void testParserDefaults() throws IOException { - String json = """ - { - "field": "my-field", - "inference_id": "my-inference-id", - "inference_text": "my-inference-text" - }"""; - - try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) { - TextSimilarityRankBuilder parsed = TextSimilarityRankBuilder.PARSER.parse(parser, null); - assertEquals(DEFAULT_RANK_WINDOW_SIZE, parsed.rankWindowSize()); - } - } - -} From bc81bd03783c2da189a41f8728088ce590d9183c Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 6 Feb 2025 14:30:50 +0000 Subject: [PATCH 06/22] Add a test for lenient rerankers --- .../action/search/RankFeaturePhase.java | 24 ++++-- ...ankFeaturePhaseRankCoordinatorContext.java | 17 ---- .../TextSimilarityRankBuilder.java | 9 +- .../TextSimilarityRankRetrieverBuilder.java | 4 + .../TextSimilarityRankMultiNodeTests.java | 1 + ...xtSimilarityRankRetrieverBuilderTests.java | 43 +++++----- .../TextSimilarityRankTests.java | 84 ++++++++++++++----- .../TextSimilarityTestPlugin.java | 65 ++------------ 8 files changed, 122 insertions(+), 125 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 3f857678c0b6c..6eadfd8d74b3d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -182,11 +182,6 @@ private void onPhaseDone( RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext, SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { - RankFeatureDoc[] docs = rankPhaseResults.getSuccessfulResults() - .flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs)) - .filter(rfd -> rfd.featureData != null) - .toArray(RankFeatureDoc[]::new); - ThreadedActionListener rankResultListener = new ThreadedActionListener<>( context::execute, new ActionListener<>() { @@ -204,18 +199,29 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { public void onFailure(Exception e) { if (rankFeaturePhaseRankCoordinatorContext.isLenient()) { // TODO: handle the exception somewhere - logger.warn("Exception computing updated ranks. Continuing with existing ranks.", e); - // use the existing docs as-is + // don't want to log the entire stack trace, it's not helpful here + logger.warn("Exception computing updated ranks: {}. Continuing with existing ranks.", e.toString()); + // use the existing score docs as-is + RankFeatureDoc[] existingScores = Arrays.stream(reducedQueryPhase.sortedTopDocs().scoreDocs()) + .map(sd -> new RankFeatureDoc(sd.doc, sd.score, sd.shardIndex)) + .toArray(RankFeatureDoc[]::new); + // AbstractThreadedActionListener forks onFailure to the same executor as onResponse, // so we can just call this direct - onResponse(docs); + onResponse(existingScores); } else { context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e); } } } ); - rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(docs, rankResultListener); + rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults( + rankPhaseResults.getSuccessfulResults() + .flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs)) + .filter(rfd -> rfd.featureData != null) + .toArray(RankFeatureDoc[]::new), + rankResultListener + ); } private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults( diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java index 8b5e49b0964c7..d4603e056e485 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -12,13 +12,9 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.action.ActionListener; import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import org.elasticsearch.search.rank.feature.RankFeatureResult; -import org.elasticsearch.search.rank.feature.RankFeatureShardResult; -import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; -import java.util.List; import static org.elasticsearch.search.SearchService.DEFAULT_FROM; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; @@ -97,17 +93,4 @@ public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) { } return topResults; } - - private RankFeatureDoc[] extractFeatureDocs(List rankSearchResults) { - List docFeatures = new ArrayList<>(); - for (RankFeatureResult rankFeatureResult : rankSearchResults) { - RankFeatureShardResult shardResult = rankFeatureResult.shardResult(); - for (RankFeatureDoc rankFeatureDoc : shardResult.rankFeatureDocs) { - if (rankFeatureDoc.featureData != null) { - docFeatures.add(rankFeatureDoc); - } - } - } - return docFeatures.toArray(new RankFeatureDoc[0]); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index 56157e690049f..f7dc706c4043e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -184,17 +184,22 @@ public Float minScore() { return minScore; } + public boolean isLenient() { + return lenient; + } + @Override protected boolean doEquals(RankBuilder other) { TextSimilarityRankBuilder that = (TextSimilarityRankBuilder) other; return Objects.equals(inferenceId, that.inferenceId) && Objects.equals(inferenceText, that.inferenceText) && Objects.equals(field, that.field) - && Objects.equals(minScore, that.minScore); + && Objects.equals(minScore, that.minScore) + && lenient == that.lenient; } @Override protected int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, minScore); + return Objects.hash(inferenceId, inferenceText, field, minScore, lenient); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 7e564a96481af..bf27308c03fe1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -188,6 +188,10 @@ public int rankWindowSize() { return rankWindowSize; } + public boolean isLenient() { + return lenient; + } + @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index 8bcb4d5572948..18415b9ed6ca3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -34,6 +34,7 @@ protected RankBuilder getThrowingRankBuilder(int rankWindowSize, String rankFeat inferenceId, inferenceText, minScore, + false, type.name() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 3fb7e05f44f79..2f46d631350bf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -118,29 +118,33 @@ public void testParserDefaults() throws IOException { parser, new RetrieverParserContext(new SearchUsage(), nf -> true) ); - assertEquals(DEFAULT_RANK_WINDOW_SIZE, parsed.rankWindowSize()); - assertEquals(DEFAULT_RERANK_ID, parsed.inferenceId()); + assertThat(parsed.rankWindowSize(), equalTo(DEFAULT_RANK_WINDOW_SIZE)); + assertThat(parsed.inferenceId(), equalTo(DEFAULT_RERANK_ID)); + assertThat(parsed.isLenient(), equalTo(false)); + } } public void testTextSimilarityRetrieverParsing() throws IOException { - String restContent = "{" - + " \"retriever\": {" - + " \"text_similarity_reranker\": {" - + " \"retriever\": {" - + " \"test\": {" - + " \"value\": \"my-test-retriever\"" - + " }" - + " }," - + " \"field\": \"my-field\"," - + " \"inference_id\": \"my-inference-id\"," - + " \"inference_text\": \"my-inference-text\"," - + " \"rank_window_size\": 100," - + " \"min_score\": 20.0," - + " \"_name\": \"foo_reranker\"" - + " }" - + " }" - + "}"; + String restContent = """ + { + "retriever": { + "text_similarity_reranker": { + "retriever": { + "test": { + "value": "my-test-retriever" + } + }, + "field": "my-field", + "inference_id": "my-inference-id", + "inference_text": "my-inference-text", + "rank_window_size": 100, + "min_score": 20.0, + "lenient": true, + "_name": "foo_reranker" + } + } + }"""; SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); @@ -148,6 +152,7 @@ public void testTextSimilarityRetrieverParsing() throws IOException { TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever(); assertThat(parsed.minScore(), equalTo(20f)); assertThat(parsed.retrieverName(), equalTo("foo_reranker")); + assertThat(parsed.isLenient(), equalTo(true)); try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( parseSerialized, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index d8abf40a28d33..1f3f333f09e10 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -21,14 +21,17 @@ import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.hamcrest.Matcher; import org.junit.Before; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; +import static org.elasticsearch.test.LambdaMatchers.transformedMatch; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -66,9 +69,10 @@ public InferenceResultCountAcceptingTextSimilarityRankBuilder( String inferenceText, int rankWindowSize, Float minScore, + boolean lenient, int inferenceResultCount ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore, false); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient); this.inferenceResultCount = inferenceResultCount; } @@ -82,7 +86,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - false + isLenient() ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { @@ -130,14 +134,17 @@ public void testRerank() { .setQuery(QueryBuilders.matchAllQuery()), response -> { // Verify order, rank and score of results - SearchHit[] hits = response.getHits().getHits(); - assertEquals(5, hits.length); - // we add + 1 to all expected scores due to the default normalization being applied which shifts positive scores to by 1 - assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4"); - assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3"); - assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2"); - assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1"); - assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0"); + assertThat( + response.getHits().getHits(), + arrayContaining( + // add 1 to all expected scores due to the default normalization being applied which shifts positive scores by 1 + searchHitWith(1, 4.0f + 1f, "4"), + searchHitWith(2, 3.0f + 1f, "3"), + searchHitWith(3, 2.0f + 1f, "2"), + searchHitWith(4, 1.0f + 1f, "1"), + searchHitWith(5, 0.0f + 1f, "0") + ) + ); } ); } @@ -150,11 +157,10 @@ public void testRerankWithMinScore() { .setQuery(QueryBuilders.matchAllQuery()), response -> { // Verify order, rank and score of results - SearchHit[] hits = response.getHits().getHits(); - assertEquals(3, hits.length); - assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4"); - assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3"); - assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2"); + assertThat( + response.getHits().getHits(), + arrayContaining(searchHitWith(1, 4.0f + 1f, "4"), searchHitWith(2, 3.0f + 1f, "3"), searchHitWith(3, 2.0f + 1f, "2")) + ); } ); } @@ -170,6 +176,7 @@ public void testRerankInferenceFailure() { "my-rerank-model", "my query", 0.7f, + false, AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name() ) ) @@ -179,6 +186,38 @@ public void testRerankInferenceFailure() { ); } + public void testLenientRerankInference() { + ElasticsearchAssertions.assertNoFailuresAndResponse( + // Execute search with text similarity reranking + client.prepareSearch() + .setRankBuilder( + new TextSimilarityTestPlugin.ThrowingMockRequestActionBasedRankBuilder( + 100, + "text", + "my-rerank-model", + "my query", + null, + true, + AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name() + ) + ) + .setQuery(QueryBuilders.matchAllQuery()), + response -> { + // these will all have a score of 2 (default 1 + normalization) + assertThat( + response.getHits().getHits(), + arrayContaining( + searchHitWith(1, 2.0f, "0"), + searchHitWith(2, 2.0f, "1"), + searchHitWith(3, 2.0f, "2"), + searchHitWith(4, 2.0f, "3"), + searchHitWith(5, 2.0f, "4") + ) + ); + } + ); + } + public void testRerankTopNConfigurationAndRankWindowSizeMismatch() { SearchPhaseExecutionException ex = expectThrows( SearchPhaseExecutionException.class, @@ -205,7 +244,7 @@ public void testRerankInputSizeAndInferenceResultsMismatch() { client.prepareSearch() .setRankBuilder( // Simulate reranker returning different number of results from input - new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 4) + new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false, 4) ) .setQuery(QueryBuilders.matchAllQuery()) ); @@ -213,10 +252,11 @@ public void testRerankInputSizeAndInferenceResultsMismatch() { assertThat(ex.getDetailedMessage(), containsString("Reranker input document count and returned score count mismatch")); } - private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) { - assertEquals(expectedRank, hit.getRank()); - assertEquals(expectedScore, hit.getScore(), 0.0f); - assertEquals(expectedText, Objects.requireNonNull(hit.getSourceAsMap()).get("text")); + private static Matcher searchHitWith(int expectedRank, float expectedScore, String expectedText) { + return allOf( + transformedMatch(SearchHit::getRank, equalTo(expectedRank)), + transformedMatch(SearchHit::getScore, equalTo(expectedScore)), + transformedMatch(hit -> hit.getSourceAsMap().get("text"), equalTo(expectedText)) + ); } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 051792ae10ddb..7ab22eae07473 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -30,10 +30,6 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.rerank.AbstractRerankerIT; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; @@ -51,8 +47,6 @@ import java.util.regex.Pattern; import static java.util.Collections.singletonList; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** * Plugin for text similarity tests. Defines a filter for modifying inference call behavior, as well as a {@code TextSimilarityRankBuilder} @@ -172,53 +166,18 @@ private void handleInferenceActionRequest( public static class ThrowingMockRequestActionBasedRankBuilder extends TextSimilarityRankBuilder { - public static final ParseField FIELD_FIELD = new ParseField("field"); - public static final ParseField INFERENCE_ID = new ParseField("inference_id"); - public static final ParseField INFERENCE_TEXT = new ParseField("inference_text"); - public static final ParseField THROWING_TYPE_FIELD = new ParseField("throwing-type"); - - static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "throwing_request_action_based_rank", - args -> { - int rankWindowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0]; - String field = (String) args[1]; - if (field == null || field.isEmpty()) { - throw new IllegalArgumentException("Field cannot be null or empty"); - } - final String inferenceId = (String) args[2]; - final String inferenceText = (String) args[3]; - final float minScore = (float) args[4]; - String throwingType = (String) args[5]; - return new ThrowingMockRequestActionBasedRankBuilder( - rankWindowSize, - field, - inferenceId, - inferenceText, - minScore, - throwingType - ); - } - ); - - static { - PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); - PARSER.declareString(constructorArg(), FIELD_FIELD); - PARSER.declareString(constructorArg(), INFERENCE_ID); - PARSER.declareString(constructorArg(), INFERENCE_TEXT); - PARSER.declareString(constructorArg(), THROWING_TYPE_FIELD); - } - protected final AbstractRerankerIT.ThrowingRankBuilderType throwingRankBuilderType; public ThrowingMockRequestActionBasedRankBuilder( - final int rankWindowSize, - final String field, - final String inferenceId, - final String inferenceText, - final float minScore, - final String throwingType + int rankWindowSize, + String field, + String inferenceId, + String inferenceText, + Float minScore, + boolean lenient, + String throwingType ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore, false); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient); this.throwingRankBuilderType = AbstractRerankerIT.ThrowingRankBuilderType.valueOf(throwingType); } @@ -233,12 +192,6 @@ public void doWriteTo(StreamOutput out) throws IOException { out.writeEnum(throwingRankBuilderType); } - @Override - public void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - super.doXContent(builder, params); - builder.field(THROWING_TYPE_FIELD.getPreferredName(), throwingRankBuilderType); - } - @Override public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { if (this.throwingRankBuilderType == AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT) @@ -264,7 +217,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - false + isLenient() ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { From 2ef4e4e9794523e2c7700cb7b6f2374917d76a34 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Feb 2025 08:37:13 +0000 Subject: [PATCH 07/22] Rename lenient to failuresAllowed --- .../org/elasticsearch/TransportVersions.java | 2 +- .../action/search/RankFeaturePhase.java | 13 +++--- ...ankFeaturePhaseRankCoordinatorContext.java | 22 +++++---- .../action/search/RankFeaturePhaseTests.java | 2 +- .../TextSimilarityRankBuilder.java | 26 +++++------ ...ankFeaturePhaseRankCoordinatorContext.java | 15 +++++-- .../TextSimilarityRankRetrieverBuilder.java | 45 +++++++++++-------- ...xtSimilarityRankRetrieverBuilderTests.java | 4 +- .../TextSimilarityRankTests.java | 14 +++--- .../TextSimilarityTestPlugin.java | 2 +- 10 files changed, 84 insertions(+), 61 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 62bdd76adf4e5..ded794c871c31 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -180,7 +180,7 @@ static TransportVersion def(int id) { public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00); public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION = def(9_005_0_00); public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00); - public static final TransportVersion LENIENT_RERANKERS = def(9_007_0_00); + public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_007_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 6eadfd8d74b3d..fadbebadef097 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -187,7 +187,7 @@ private void onPhaseDone( new ActionListener<>() { @Override public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { - RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores); + RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores, true); SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults( reducedQueryPhase, topResults @@ -197,7 +197,7 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { @Override public void onFailure(Exception e) { - if (rankFeaturePhaseRankCoordinatorContext.isLenient()) { + if (rankFeaturePhaseRankCoordinatorContext.failuresAllowed()) { // TODO: handle the exception somewhere // don't want to log the entire stack trace, it's not helpful here logger.warn("Exception computing updated ranks: {}. Continuing with existing ranks.", e.toString()); @@ -206,9 +206,12 @@ public void onFailure(Exception e) { .map(sd -> new RankFeatureDoc(sd.doc, sd.score, sd.shardIndex)) .toArray(RankFeatureDoc[]::new); - // AbstractThreadedActionListener forks onFailure to the same executor as onResponse, - // so we can just call this direct - onResponse(existingScores); + RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(existingScores, false); + SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults( + reducedQueryPhase, + topResults + ); + moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase); } else { context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e); } diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java index d4603e056e485..f9a4e8c8dd162 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -29,17 +29,17 @@ public abstract class RankFeaturePhaseRankCoordinatorContext { protected final int size; protected final int from; protected final int rankWindowSize; - protected final boolean lenient; + protected final boolean failuresAllowed; - public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean lenient) { + public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed) { this.size = size < 0 ? DEFAULT_SIZE : size; this.from = from < 0 ? DEFAULT_FROM : from; this.rankWindowSize = rankWindowSize; - this.lenient = lenient; + this.failuresAllowed = failuresAllowed; } - public boolean isLenient() { - return lenient; + public boolean failuresAllowed() { + return failuresAllowed; } /** @@ -50,9 +50,11 @@ public boolean isLenient() { /** * Preprocesses the provided documents: sorts them by score descending. - * @param originalDocs documents to process + * + * @param originalDocs documents to process + * @param rerankedScores {@code true} if the document scores have been reranked */ - protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) { + protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rerankedScores) { RankFeatureDoc[] sorted = originalDocs.clone(); Arrays.sort(sorted, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); return sorted; @@ -82,10 +84,12 @@ public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, Acti /** * Ranks the provided {@link RankFeatureDoc} array and paginates the results based on the `from` and `size` parameters. Filters out * documents that have a relevance score less than min_score. + * * @param rankFeatureDocs documents to process + * @param rerankedScores {@code true} if the document scores have been reranked */ - public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) { - RankFeatureDoc[] sortedDocs = preprocess(rankFeatureDocs); + public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs, boolean rerankedScores) { + RankFeatureDoc[] sortedDocs = preprocess(rankFeatureDocs, rerankedScores); RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, sortedDocs.length - from))]; for (int rank = 0; rank < topResults.length; ++rank) { topResults[rank] = sortedDocs[from + rank]; diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java index 91620e323067b..53045710b5ec1 100644 --- a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java @@ -790,7 +790,7 @@ public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, Acti } @Override - public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) { + public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs, boolean rerankedScores) { Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))]; // perform pagination diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index f7dc706c4043e..19f87f8d630d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -48,7 +48,7 @@ public class TextSimilarityRankBuilder extends RankBuilder { private final String inferenceText; private final String field; private final Float minScore; - private final boolean lenient; + private final boolean failuresAllowed; public TextSimilarityRankBuilder( String field, @@ -56,14 +56,14 @@ public TextSimilarityRankBuilder( String inferenceText, int rankWindowSize, Float minScore, - boolean lenient + boolean failuresAllowed ) { super(rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; this.minScore = minScore; - this.lenient = lenient; + this.failuresAllowed = failuresAllowed; } public TextSimilarityRankBuilder(StreamInput in) throws IOException { @@ -73,10 +73,10 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException { this.inferenceText = in.readString(); this.field = in.readString(); this.minScore = in.readOptionalFloat(); - if (in.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { - this.lenient = in.readBoolean(); + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANKER_FAILURES_ALLOWED)) { + this.failuresAllowed = in.readBoolean(); } else { - this.lenient = false; + this.failuresAllowed = false; } } @@ -97,8 +97,8 @@ public void doWriteTo(StreamOutput out) throws IOException { out.writeString(inferenceText); out.writeString(field); out.writeOptionalFloat(minScore); - if (out.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) { - out.writeBoolean(lenient); + if (out.getTransportVersion().onOrAfter(TransportVersions.RERANKER_FAILURES_ALLOWED)) { + out.writeBoolean(failuresAllowed); } } @@ -164,7 +164,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - lenient + failuresAllowed ); } @@ -184,8 +184,8 @@ public Float minScore() { return minScore; } - public boolean isLenient() { - return lenient; + public boolean failuresAllowed() { + return failuresAllowed; } @Override @@ -195,11 +195,11 @@ protected boolean doEquals(RankBuilder other) { && Objects.equals(inferenceText, that.inferenceText) && Objects.equals(field, that.field) && Objects.equals(minScore, that.minScore) - && lenient == that.lenient; + && failuresAllowed == that.failuresAllowed; } @Override protected int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, minScore, lenient); + return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index f86c2b0666dc2..2a41f1d03cf51 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -44,9 +44,9 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( String inferenceId, String inferenceText, Float minScore, - boolean lenient + boolean failuresAllowed ) { - super(size, from, rankWindowSize, lenient); + super(size, from, rankWindowSize, failuresAllowed); this.client = client; this.inferenceId = inferenceId; this.inferenceText = inferenceText; @@ -127,10 +127,17 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener docs = new ArrayList<>(originalDocs.length); for (RankFeatureDoc doc : originalDocs) { if (minScore == null || doc.score >= minScore) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index bf27308c03fe1..04ff271f4c784 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -44,7 +44,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); - public static final ParseField LENIENT_FIELD = new ParseField("lenient"); + public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("failures_allowed"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { @@ -53,9 +53,16 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder String inferenceText = (String) args[2]; String field = (String) args[3]; int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; - boolean lenient = args[5] != null && (Boolean) args[5]; - - return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, lenient); + boolean failuresAllowed = args[5] != null && (Boolean) args[5]; + + return new TextSimilarityRankRetrieverBuilder( + retrieverBuilder, + inferenceId, + inferenceText, + field, + rankWindowSize, + failuresAllowed + ); }); static { @@ -68,7 +75,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder PARSER.declareString(constructorArg(), INFERENCE_TEXT_FIELD); PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); - PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD); + PARSER.declareBoolean(optionalConstructorArg(), FAILURES_ALLOWED_FIELD); RetrieverBuilder.declareBaseParserFields(TextSimilarityRankBuilder.NAME, PARSER); } @@ -87,7 +94,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent( private final String inferenceId; private final String inferenceText; private final String field; - private final boolean lenient; + private final boolean failuresAllowed; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, @@ -95,13 +102,13 @@ public TextSimilarityRankRetrieverBuilder( String inferenceText, String field, int rankWindowSize, - boolean lenient + boolean failuresAllowed ) { super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; - this.lenient = lenient; + this.failuresAllowed = failuresAllowed; } public TextSimilarityRankRetrieverBuilder( @@ -111,7 +118,7 @@ public TextSimilarityRankRetrieverBuilder( String field, int rankWindowSize, Float minScore, - boolean lenient, + boolean failuresAllowed, String retrieverName, List preFilterQueryBuilders ) { @@ -123,7 +130,7 @@ public TextSimilarityRankRetrieverBuilder( this.inferenceText = inferenceText; this.field = field; this.minScore = minScore; - this.lenient = lenient; + this.failuresAllowed = failuresAllowed; this.retrieverName = retrieverName; this.preFilterQueryBuilders = preFilterQueryBuilders; } @@ -140,7 +147,7 @@ protected TextSimilarityRankRetrieverBuilder clone( field, rankWindowSize, minScore, - lenient, + failuresAllowed, retrieverName, newPreFilterQueryBuilders ); @@ -171,7 +178,9 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b @Override protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { - sourceBuilder.rankBuilder(new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient)); + sourceBuilder.rankBuilder( + new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed) + ); return sourceBuilder; } @@ -188,8 +197,8 @@ public int rankWindowSize() { return rankWindowSize; } - public boolean isLenient() { - return lenient; + public boolean failuresAllowed() { + return failuresAllowed; } @Override @@ -199,8 +208,8 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); builder.field(FIELD_FIELD.getPreferredName(), field); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); - if (lenient) { - builder.field(LENIENT_FIELD.getPreferredName(), lenient); + if (failuresAllowed) { + builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), failuresAllowed); } } @@ -213,11 +222,11 @@ public boolean doEquals(Object other) { && Objects.equals(field, that.field) && rankWindowSize == that.rankWindowSize && Objects.equals(minScore, that.minScore) - && lenient == that.lenient; + && failuresAllowed == that.failuresAllowed; } @Override public int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, lenient); + return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 2f46d631350bf..9ef798f3fca38 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -120,7 +120,7 @@ public void testParserDefaults() throws IOException { ); assertThat(parsed.rankWindowSize(), equalTo(DEFAULT_RANK_WINDOW_SIZE)); assertThat(parsed.inferenceId(), equalTo(DEFAULT_RERANK_ID)); - assertThat(parsed.isLenient(), equalTo(false)); + assertThat(parsed.failuresAllowed(), equalTo(false)); } } @@ -152,7 +152,7 @@ public void testTextSimilarityRetrieverParsing() throws IOException { TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever(); assertThat(parsed.minScore(), equalTo(20f)); assertThat(parsed.retrieverName(), equalTo("foo_reranker")); - assertThat(parsed.isLenient(), equalTo(true)); + assertThat(parsed.failuresAllowed(), equalTo(true)); try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( parseSerialized, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index 1f3f333f09e10..dd23e91d36187 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -86,7 +86,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - isLenient() + failuresAllowed() ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { @@ -203,15 +203,15 @@ public void testLenientRerankInference() { ) .setQuery(QueryBuilders.matchAllQuery()), response -> { - // these will all have a score of 2 (default 1 + normalization) + // these will all have a score of 1, the score from matchAllQuery assertThat( response.getHits().getHits(), arrayContaining( - searchHitWith(1, 2.0f, "0"), - searchHitWith(2, 2.0f, "1"), - searchHitWith(3, 2.0f, "2"), - searchHitWith(4, 2.0f, "3"), - searchHitWith(5, 2.0f, "4") + searchHitWith(1, 1.0f, "0"), + searchHitWith(2, 1.0f, "1"), + searchHitWith(3, 1.0f, "2"), + searchHitWith(4, 1.0f, "3"), + searchHitWith(5, 1.0f, "4") ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 7ab22eae07473..7e8b29523beb8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -217,7 +217,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - isLenient() + failuresAllowed() ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { From 9bb1d20802df23ce03f48873422b32f5b5f344be Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Feb 2025 08:40:03 +0000 Subject: [PATCH 08/22] Use correct exception --- .../rank/textsimilarity/TextSimilarityRankBuilder.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index 19f87f8d630d9..8c76db40f832f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -27,7 +27,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.io.UnsupportedEncodingException; import java.util.List; import java.util.Objects; @@ -104,7 +103,7 @@ public void doWriteTo(StreamOutput out) throws IOException { @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { - throw new UnsupportedEncodingException("This should not be XContent serialized"); + throw new UnsupportedOperationException("This should not be XContent serialized"); } @Override From fd396d9d3e59ce2ce108de61ed10d549660e312f Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Feb 2025 09:52:44 +0000 Subject: [PATCH 09/22] Update xcontent field name --- docs/changelog/121784.yaml | 4 ++-- .../textsimilarity/TextSimilarityRankRetrieverBuilder.java | 2 +- .../TextSimilarityRankRetrieverBuilderTests.java | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/changelog/121784.yaml b/docs/changelog/121784.yaml index beec96a8e1f0f..2ad240d400eb9 100644 --- a/docs/changelog/121784.yaml +++ b/docs/changelog/121784.yaml @@ -1,5 +1,5 @@ pr: 121784 -summary: Add a lenient option to text similarity reranking +summary: Optionally allow text similarity reranking to fail area: Search type: enhancement -issues: [] +issues: [116796] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 04ff271f4c784..b52419602667c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -44,7 +44,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); - public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("failures_allowed"); + public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("allow_rerank_failures"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 9ef798f3fca38..9977da9044d44 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -140,7 +140,7 @@ public void testTextSimilarityRetrieverParsing() throws IOException { "inference_text": "my-inference-text", "rank_window_size": 100, "min_score": 20.0, - "lenient": true, + "allow_rerank_failures": true, "_name": "foo_reranker" } } From 3ee20eda1d4df2c13250c3932d3723398697ed9c Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Feb 2025 10:21:46 +0000 Subject: [PATCH 10/22] Add another test --- .../rank/rerank/AbstractRerankerIT.java | 2 +- .../TextSimilarityRankMultiNodeTests.java | 60 ++++++++++++++++++- .../TextSimilarityTestPlugin.java | 4 +- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java index ad4e5842629e7..21f5e66d907c7 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java @@ -496,7 +496,7 @@ public void testRankFeaturePhaseCoordinatorThrowingAllShardsFail() throws Except assertNoOpenContext(indexName); } - private void assertNoOpenContext(final String indexName) throws Exception { + protected void assertNoOpenContext(final String indexName) throws Exception { assertBusy( () -> assertThat(indicesAdmin().prepareStats(indexName).get().getTotal().getSearch().getOpenContexts(), equalTo(0L)), 1, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index 18415b9ed6ca3..ba5b6d8a24b78 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -15,6 +15,11 @@ import java.util.Collection; import java.util.List; +import static org.elasticsearch.index.query.QueryBuilders.boolQuery; +import static org.elasticsearch.index.query.QueryBuilders.matchQuery; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; + public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT { private static final String inferenceId = "inference-id"; @@ -28,13 +33,22 @@ protected RankBuilder getRankBuilder(int rankWindowSize, String rankFeatureField @Override protected RankBuilder getThrowingRankBuilder(int rankWindowSize, String rankFeatureField, ThrowingRankBuilderType type) { + return getThrowingRankBuilder(rankWindowSize, rankFeatureField, type, false); + } + + protected RankBuilder getThrowingRankBuilder( + int rankWindowSize, + String rankFeatureField, + ThrowingRankBuilderType type, + boolean failuresAllowed + ) { return new TextSimilarityTestPlugin.ThrowingMockRequestActionBasedRankBuilder( rankWindowSize, rankFeatureField, inferenceId, inferenceText, minScore, - false, + failuresAllowed, type.name() ); } @@ -52,6 +66,50 @@ public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception { // no-op } + public void testRerankerAllowedFailureNoExceptions() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery( + boolQuery().should(matchQuery(searchField, "A")) + .should(matchQuery(searchField, "B")) + .should(matchQuery(searchField, "C")) + .should(matchQuery(searchField, "D")) + .should(matchQuery(searchField, "E")) + ) + .setRankBuilder( + getThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT, + true + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10), + response -> { + // just check it returns 5 documents, the order will be random due to not getting reranked + assertHitCount(response, 5L); + } + ); + assertNoOpenContext(indexName); + } + @Override protected boolean shouldCheckScores() { return false; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 7e8b29523beb8..44c80b2a94b12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -174,10 +174,10 @@ public ThrowingMockRequestActionBasedRankBuilder( String inferenceId, String inferenceText, Float minScore, - boolean lenient, + boolean failuresAllowed, String throwingType ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed); this.throwingRankBuilderType = AbstractRerankerIT.ThrowingRankBuilderType.valueOf(throwingType); } From d93a8dcab29864cb6127f6a26f00a7e916265061 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 7 Feb 2025 10:30:00 +0000 Subject: [PATCH 11/22] [CI] Auto commit changes from spotless --- .../TextSimilarityRankMultiNodeTests.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index ba5b6d8a24b78..60b52f1c6f03f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -84,12 +84,12 @@ public void testRerankerAllowedFailureNoExceptions() throws Exception { assertNoFailuresAndResponse( prepareSearch().setQuery( - boolQuery().should(matchQuery(searchField, "A")) - .should(matchQuery(searchField, "B")) - .should(matchQuery(searchField, "C")) - .should(matchQuery(searchField, "D")) - .should(matchQuery(searchField, "E")) - ) + boolQuery().should(matchQuery(searchField, "A")) + .should(matchQuery(searchField, "B")) + .should(matchQuery(searchField, "C")) + .should(matchQuery(searchField, "D")) + .should(matchQuery(searchField, "E")) + ) .setRankBuilder( getThrowingRankBuilder( rankWindowSize, From 551f640224f9eaeb02a3c96f9dd55af6366446d8 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Feb 2025 10:35:41 +0000 Subject: [PATCH 12/22] Some more test tweaks --- ...aturePhaseRankCoordinatorContextTests.java | 26 +++---------------- .../TextSimilarityRankTests.java | 15 ++++++----- 2 files changed, 11 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java index 8d1daa67d1194..717fb9437ad52 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java @@ -7,13 +7,13 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.inference.TaskType; import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -44,17 +44,7 @@ public void testComputeScores() { featureDoc3.featureData("text 3"); RankFeatureDoc[] featureDocs = new RankFeatureDoc[] { featureDoc1, featureDoc2, featureDoc3 }; - subject.computeScores(featureDocs, new ActionListener<>() { - @Override - public void onResponse(float[] floats) { - assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, floats, 0.0f); - } - - @Override - public void onFailure(Exception e) { - fail(); - } - }); + subject.computeScores(featureDocs, assertNoFailureListener(f -> assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, f, 0.0f))); verify(mockClient).execute( eq(GetInferenceModelAction.INSTANCE), argThat(actionRequest -> ((GetInferenceModelAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)), @@ -63,17 +53,7 @@ public void onFailure(Exception e) { } public void testComputeScoresForEmpty() { - subject.computeScores(new RankFeatureDoc[0], new ActionListener<>() { - @Override - public void onResponse(float[] floats) { - assertArrayEquals(new float[0], floats, 0.0f); - } - - @Override - public void onFailure(Exception e) { - fail(); - } - }); + subject.computeScores(new RankFeatureDoc[0], assertNoFailureListener(f -> assertArrayEquals(new float[0], f, 0.0f))); verify(mockClient).execute( eq(GetInferenceModelAction.INSTANCE), argThat(actionRequest -> ((GetInferenceModelAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index dd23e91d36187..540d587c58d6d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -30,6 +30,8 @@ import java.util.Map; import static org.elasticsearch.test.LambdaMatchers.transformedMatch; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.containsString; @@ -69,10 +71,9 @@ public InferenceResultCountAcceptingTextSimilarityRankBuilder( String inferenceText, int rankWindowSize, Float minScore, - boolean lenient, int inferenceResultCount ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, false); this.inferenceResultCount = inferenceResultCount; } @@ -186,9 +187,9 @@ public void testRerankInferenceFailure() { ); } - public void testLenientRerankInference() { + public void testRerankInferenceAllowedFailure() { ElasticsearchAssertions.assertNoFailuresAndResponse( - // Execute search with text similarity reranking + // Execute search with text similarity reranking that fails, but it is allowed client.prepareSearch() .setRankBuilder( new TextSimilarityTestPlugin.ThrowingMockRequestActionBasedRankBuilder( @@ -244,7 +245,7 @@ public void testRerankInputSizeAndInferenceResultsMismatch() { client.prepareSearch() .setRankBuilder( // Simulate reranker returning different number of results from input - new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false, 4) + new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 4) ) .setQuery(QueryBuilders.matchAllQuery()) ); @@ -254,8 +255,8 @@ public void testRerankInputSizeAndInferenceResultsMismatch() { private static Matcher searchHitWith(int expectedRank, float expectedScore, String expectedText) { return allOf( - transformedMatch(SearchHit::getRank, equalTo(expectedRank)), - transformedMatch(SearchHit::getScore, equalTo(expectedScore)), + hasRank(expectedRank), + hasScore(expectedScore), transformedMatch(hit -> hit.getSourceAsMap().get("text"), equalTo(expectedText)) ); } From df8b1aca5054eb1e76b5a1b0257ecd2df00534be Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Feb 2025 11:04:38 +0000 Subject: [PATCH 13/22] Update docs/changelog/121784.yaml --- docs/changelog/121784.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog/121784.yaml b/docs/changelog/121784.yaml index 2ad240d400eb9..c336205767803 100644 --- a/docs/changelog/121784.yaml +++ b/docs/changelog/121784.yaml @@ -2,4 +2,4 @@ pr: 121784 summary: Optionally allow text similarity reranking to fail area: Search type: enhancement -issues: [116796] +issues: [] From 108b89c2fa339dc10d96e48f3386ebd06ac1b4ee Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Feb 2025 12:19:43 +0000 Subject: [PATCH 14/22] Actually check scores are passed through --- .../rank/rerank/AbstractRerankerIT.java | 3 ++- .../TextSimilarityRankMultiNodeTests.java | 24 ++++++++++++++----- .../TextSimilarityRankTests.java | 23 ++++++++++++------ 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java index 21f5e66d907c7..2dc35b0bdf904 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java @@ -30,6 +30,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank; +import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.equalTo; /** @@ -403,7 +404,7 @@ public void testRankFeaturePhaseShardThrowingPartialFailures() throws Exception .allMatch(failure -> failure.getCause().getMessage().contains("rfs - simulated failure")) ); assertHitCount(response, 5); - assertTrue(response.getHits().getHits().length == 0); + assertThat(response.getHits().getHits(), arrayWithSize(5)); } ); assertNoOpenContext(indexName); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index 60b52f1c6f03f..4acebd9c956b1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.rerank.AbstractRerankerIT; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; @@ -16,9 +17,12 @@ import java.util.List; import static org.elasticsearch.index.query.QueryBuilders.boolQuery; +import static org.elasticsearch.index.query.QueryBuilders.constantScoreQuery; import static org.elasticsearch.index.query.QueryBuilders.matchQuery; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank; public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT { @@ -84,11 +88,11 @@ public void testRerankerAllowedFailureNoExceptions() throws Exception { assertNoFailuresAndResponse( prepareSearch().setQuery( - boolQuery().should(matchQuery(searchField, "A")) - .should(matchQuery(searchField, "B")) - .should(matchQuery(searchField, "C")) - .should(matchQuery(searchField, "D")) - .should(matchQuery(searchField, "E")) + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(10)) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(20)) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(30)) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(40)) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(50)) ) .setRankBuilder( getThrowingRankBuilder( @@ -103,8 +107,16 @@ public void testRerankerAllowedFailureNoExceptions() throws Exception { .setAllowPartialSearchResults(true) .setSize(10), response -> { - // just check it returns 5 documents, the order will be random due to not getting reranked assertHitCount(response, 5L); + int rank = 1; + for (SearchHit searchHit : response.getHits().getHits()) { + int id = 5 - (rank - 1); + assertThat(searchHit, hasId(String.valueOf(id))); + assertThat(searchHit, hasRank(rank)); + assertNotNull(searchHit.getFields().get(searchField)); + assertEquals(id * 10, searchHit.getScore(), 0f); + rank++; + } } ); assertNoOpenContext(indexName); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index 540d587c58d6d..6f4b00c5768fa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -29,6 +29,9 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.index.query.QueryBuilders.boolQuery; +import static org.elasticsearch.index.query.QueryBuilders.constantScoreQuery; +import static org.elasticsearch.index.query.QueryBuilders.matchQuery; import static org.elasticsearch.test.LambdaMatchers.transformedMatch; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore; @@ -202,17 +205,23 @@ public void testRerankInferenceAllowedFailure() { AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name() ) ) - .setQuery(QueryBuilders.matchAllQuery()), + .setQuery( + boolQuery().should(constantScoreQuery(matchQuery("text", "0")).boost(50)) + .should(constantScoreQuery(matchQuery("text", "1")).boost(40)) + .should(constantScoreQuery(matchQuery("text", "2")).boost(30)) + .should(constantScoreQuery(matchQuery("text", "3")).boost(20)) + .should(constantScoreQuery(matchQuery("text", "4")).boost(10)) + ), response -> { - // these will all have a score of 1, the score from matchAllQuery + // these will all have the scores from the constant score clauses assertThat( response.getHits().getHits(), arrayContaining( - searchHitWith(1, 1.0f, "0"), - searchHitWith(2, 1.0f, "1"), - searchHitWith(3, 1.0f, "2"), - searchHitWith(4, 1.0f, "3"), - searchHitWith(5, 1.0f, "4") + searchHitWith(1, 50, "0"), + searchHitWith(2, 40, "1"), + searchHitWith(3, 30, "2"), + searchHitWith(4, 20, "3"), + searchHitWith(5, 10, "4") ) ); } From 6cde90659da1436a9770f7eefacdb33ac557d787 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Feb 2025 13:01:38 +0000 Subject: [PATCH 15/22] Remove clone --- .../rank/context/RankFeaturePhaseRankCoordinatorContext.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java index f9a4e8c8dd162..819d04e12eeeb 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -55,9 +55,8 @@ public boolean failuresAllowed() { * @param rerankedScores {@code true} if the document scores have been reranked */ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rerankedScores) { - RankFeatureDoc[] sorted = originalDocs.clone(); - Arrays.sort(sorted, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); - return sorted; + Arrays.sort(originalDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + return originalDocs; } /** From bfbcca3aa09f85e0b61f35cbe0bd9c5f607975f0 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Feb 2025 13:25:47 +0000 Subject: [PATCH 16/22] Fix test checks --- .../search/rank/rerank/AbstractRerankerIT.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java index 2dc35b0bdf904..954019f0055c8 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java @@ -31,7 +31,9 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank; import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.emptyArray; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; /** * this base class acts as a wrapper for testing different rerankers, and their behavior when exceptions are thrown @@ -191,7 +193,7 @@ public void testRerankerPaginationOutsideOfBounds() throws Exception { .setFrom(10), response -> { assertHitCount(response, 5L); - assertEquals(0, response.getHits().getHits().length); + assertThat(response.getHits().getHits(), emptyArray()); } ); assertNoOpenContext(indexName); @@ -227,7 +229,7 @@ public void testNotAllShardsArePresentInFetchPhase() throws Exception { .setSize(2), response -> { assertHitCount(response, 4L); - assertEquals(2, response.getHits().getHits().length); + assertThat(response.getHits().getHits(), arrayWithSize(2)); int rank = 1; for (SearchHit searchHit : response.getHits().getHits()) { assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); @@ -398,13 +400,13 @@ public void testRankFeaturePhaseShardThrowingPartialFailures() throws Exception .setAllowPartialSearchResults(true) .setSize(10), response -> { - assertTrue(response.getFailedShards() > 0); + assertThat(response.getFailedShards(), greaterThan(0)); assertTrue( Arrays.stream(response.getShardFailures()) .allMatch(failure -> failure.getCause().getMessage().contains("rfs - simulated failure")) ); assertHitCount(response, 5); - assertThat(response.getHits().getHits(), arrayWithSize(5)); + assertThat(response.getHits().getHits(), emptyArray()); } ); assertNoOpenContext(indexName); From f4f13c9df802c7a23db2d9293fc29ab8331d7165 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Feb 2025 15:01:05 +0000 Subject: [PATCH 17/22] Set scores where there are none --- .../action/search/RankFeaturePhase.java | 20 +++-- ...ith_text_similarity_reranker_retriever.yml | 80 +++++++++++++++++++ 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index fadbebadef097..398e3bbac17dd 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -20,6 +20,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureResult; @@ -187,7 +188,7 @@ private void onPhaseDone( new ActionListener<>() { @Override public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { - RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores, true); + RankDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores, true); SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults( reducedQueryPhase, topResults @@ -200,13 +201,18 @@ public void onFailure(Exception e) { if (rankFeaturePhaseRankCoordinatorContext.failuresAllowed()) { // TODO: handle the exception somewhere // don't want to log the entire stack trace, it's not helpful here - logger.warn("Exception computing updated ranks: {}. Continuing with existing ranks.", e.toString()); + logger.warn("Exception computing updated ranks, continuing with existing ranks: {}", e.toString()); // use the existing score docs as-is - RankFeatureDoc[] existingScores = Arrays.stream(reducedQueryPhase.sortedTopDocs().scoreDocs()) - .map(sd -> new RankFeatureDoc(sd.doc, sd.score, sd.shardIndex)) - .toArray(RankFeatureDoc[]::new); - - RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(existingScores, false); + // downstream things expect every doc to have a score, so we need to infer a score here + // if the doc doesn't otherwise have a score. We can use the rank. + ScoreDoc[] inputDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); + // use RankDoc to indicate there was a problem using the specified features + RankFeatureDoc[] rankDocs = new RankFeatureDoc[inputDocs.length]; + for (int i = 0; i < inputDocs.length; i++) { + ScoreDoc doc = inputDocs[i]; + rankDocs[i] = new RankFeatureDoc(doc.doc, Float.isNaN(doc.score) ? 1f / (i+1) : doc.score, doc.shardIndex); + } + RankDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(rankDocs, false); SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults( reducedQueryPhase, topResults diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml index 98c4ae9e642d1..8247fbfe9535c 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml @@ -334,3 +334,83 @@ setup: - match: {hits.hits.0._explanation.details.1.description: "/rrf.score:.\\[0.5\\].*/" } - match: {hits.hits.0._explanation.details.1.details.0.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" } - match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/weight.*astronomy.*/" } + +--- +"rrf retriever with failed text similarity reranker": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + rrf: { + retrievers: + [ + { + standard: { + query: { + bool: { + should: + [ + { + constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 10 + } + }, + { + constant_score: + { + filter: + { + term: + { + integer: 2 + } + }, + boost: 1 + } + } + ] + } + } + } + }, + { + text_similarity_reranker: { + retriever: + { + standard: { + query: { + match_all: {} + }, + sort: { + integer: "asc" + } + } + }, + rank_window_size: 10, + inference_id: failure-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text, + allow_rerank_failures: true + } + } + ], + rank_window_size: 10, + rank_constant: 1 + } + size: 10 + + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + - match: { hits.hits.2._id: "doc_3" } From 559aa322ad215f6f3ac84836793aedc1d16863e8 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Feb 2025 16:54:50 +0000 Subject: [PATCH 18/22] Stray comment --- .../java/org/elasticsearch/action/search/RankFeaturePhase.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 398e3bbac17dd..a06d008ae48a1 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -204,9 +204,8 @@ public void onFailure(Exception e) { logger.warn("Exception computing updated ranks, continuing with existing ranks: {}", e.toString()); // use the existing score docs as-is // downstream things expect every doc to have a score, so we need to infer a score here - // if the doc doesn't otherwise have a score. We can use the rank. + // if the doc doesn't otherwise have one. We can use the rank to infer a possible score instead (1/rank). ScoreDoc[] inputDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); - // use RankDoc to indicate there was a problem using the specified features RankFeatureDoc[] rankDocs = new RankFeatureDoc[inputDocs.length]; for (int i = 0; i < inputDocs.length; i++) { ScoreDoc doc = inputDocs[i]; From 7b81006e85cd6b98acf0f9c26f8bddb061b4a1d3 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Feb 2025 16:55:12 +0000 Subject: [PATCH 19/22] splotless --- .../java/org/elasticsearch/action/search/RankFeaturePhase.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index a06d008ae48a1..25238711c5c1c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -209,7 +209,7 @@ public void onFailure(Exception e) { RankFeatureDoc[] rankDocs = new RankFeatureDoc[inputDocs.length]; for (int i = 0; i < inputDocs.length; i++) { ScoreDoc doc = inputDocs[i]; - rankDocs[i] = new RankFeatureDoc(doc.doc, Float.isNaN(doc.score) ? 1f / (i+1) : doc.score, doc.shardIndex); + rankDocs[i] = new RankFeatureDoc(doc.doc, Float.isNaN(doc.score) ? 1f / (i + 1) : doc.score, doc.shardIndex); } RankDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(rankDocs, false); SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults( From c3c08f63c022daeb65c640dcb1096312dd07b7a1 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 12 Feb 2025 09:38:57 +0000 Subject: [PATCH 20/22] Change to constant_score for determinism --- ...ith_text_similarity_reranker_retriever.yml | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml index 8247fbfe9535c..139bcf3fca8cf 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml @@ -388,10 +388,41 @@ setup: { standard: { query: { - match_all: {} - }, - sort: { - integer: "asc" + bool: { + should: + [ + { + constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 3 + } + }, + { + constant_score: { + filter: { + term: { + integer: 2 + } + }, + boost: 2 + } + }, + { + constant_score: { + filter: { + term: { + integer: 3 + } + }, + boost: 1 + } + } + ] + } } } }, From b0670a268a5cac6ab920684d4f83216ca4d389c0 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 17 Feb 2025 10:32:37 +0000 Subject: [PATCH 21/22] Maintain existing sort order on failure --- ...ankFeaturePhaseRankCoordinatorContext.java | 3 +- .../70_text_similarity_rank_retriever.yml | 30 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 2a41f1d03cf51..a4c07deacf72c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -134,8 +134,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener docs = new ArrayList<>(originalDocs.length); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index 9a6ecffe29d4d..48a73a5ebdec7 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -172,6 +172,36 @@ setup: field: text size: 10 +--- +"Text similarity reranking with allowed failure maintains custom sorting": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + term: + topic: "science" + sort: + - "subtopic" + rank_window_size: 10 + inference_id: failing-rerank-model + inference_text: "science" + field: text + allow_rerank_failures: true + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_2" } + - match: { hits.hits.1._id: "doc_1" } --- "text similarity reranking with explain": From 01654d5017b199dbc12790df838103b44972ad16 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 17 Feb 2025 11:38:51 +0000 Subject: [PATCH 22/22] Fix transport version --- server/src/main/java/org/elasticsearch/TransportVersions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 74b7be707b1d2..1b748b616189f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -197,7 +197,7 @@ static TransportVersion def(int id) { public static final TransportVersion SLM_UNHEALTHY_IF_NO_SNAPSHOT_WITHIN = def(9_010_0_00); public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS = def(9_011_0_00); public static final TransportVersion REMOVE_REPOSITORY_CONFLICT_MESSAGE = def(9_012_0_00); - public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_012_0_00); + public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_013_0_00); /* * STOP! READ THIS FIRST! No, really,