Optionally allow text similarity reranking to fail #121784

Merged 29 commits on Feb 17, 2025.

Changes shown below are from 1 commit.

Commits (29):
ad91d47 Add a lenient option to text similarity reranking (thecoop, Feb 5, 2025)
b447b33 Update docs/changelog/121784.yaml (thecoop, Feb 5, 2025)
1ac18fe propagate (thecoop, Feb 6, 2025)
22bd2a0 leniency is only part of the text similarity builders (thecoop, Feb 6, 2025)
505ae8b Merge remote-tracking branch 'upstream/main' into lenient-rerankers (thecoop, Feb 6, 2025)
c47b8a5 Remove superfluous xcontent serialization (thecoop, Feb 6, 2025)
bc81bd0 Add a test for lenient rerankers (thecoop, Feb 6, 2025)
2ef4e4e Rename lenient to failuresAllowed (thecoop, Feb 7, 2025)
9bb1d20 Use correct exception (thecoop, Feb 7, 2025)
781e736 Merge remote-tracking branch 'upstream/main' into lenient-rerankers (thecoop, Feb 7, 2025)
fd396d9 Update xcontent field name (thecoop, Feb 7, 2025)
3ee20ed Add another test (thecoop, Feb 7, 2025)
d93a8dc [CI] Auto commit changes from spotless (elasticsearchmachine, Feb 7, 2025)
551f640 Some more test tweaks (thecoop, Feb 7, 2025)
df8b1ac Update docs/changelog/121784.yaml (thecoop, Feb 7, 2025)
108b89c Actually check scores are passed through (thecoop, Feb 10, 2025)
6cde906 Remove clone (thecoop, Feb 10, 2025)
36a16e1 Merge branch 'main' into lenient-rerankers (thecoop, Feb 10, 2025)
bfbcca3 Fix test checks (thecoop, Feb 10, 2025)
f4f13c9 Set scores where there are none (thecoop, Feb 10, 2025)
559aa32 Stray comment (thecoop, Feb 10, 2025)
cbd13b9 Merge remote-tracking branch 'upstream/main' into lenient-rerankers (thecoop, Feb 10, 2025)
7b81006 splotless (thecoop, Feb 10, 2025)
3112cec Merge remote-tracking branch 'upstream/main' into lenient-rerankers (thecoop, Feb 12, 2025)
c3c08f6 Change to constant_score for determinism (thecoop, Feb 12, 2025)
b0670a2 Maintain existing sort order on failure (thecoop, Feb 17, 2025)
773d610 Merge remote-tracking branch 'upstream/main' into lenient-rerankers (thecoop, Feb 17, 2025)
7fa15e6 Merge remote-tracking branch 'upstream/main' into lenient-rerankers (thecoop, Feb 17, 2025)
01654d5 Fix transport version (thecoop, Feb 17, 2025)

Diff view

@@ -177,8 +177,8 @@ static TransportVersion def(int id) {
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES = def(9_002_0_00);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED = def(9_003_0_00);

public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00);
public static final TransportVersion LENIENT_RERANKERS = def(9_005_0_00);

/*
* STOP! READ THIS FIRST! No, really,

@@ -26,6 +26,7 @@
import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
import org.elasticsearch.transport.Transport;

import java.util.Arrays;
import java.util.List;

/**
@@ -181,6 +182,11 @@ private void onPhaseDone(
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
RankFeatureDoc[] docs = rankPhaseResults.getSuccessfulResults()
[Review comment from the PR author on this line: "This changes the thread this happens in - not sure if this matters or not?"]

.flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs))
.filter(rfd -> rfd.featureData != null)
.toArray(RankFeatureDoc[]::new);

ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(
context::execute,
new ActionListener<>() {
@@ -196,21 +202,26 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {

@Override
public void onFailure(Exception e) {
context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e);
if (rankFeaturePhaseRankCoordinatorContext.isLenient()) {
// TODO: handle the exception somewhere
logger.warn("Exception computing updated ranks. Continuing with existing ranks.", e);
// use the existing docs as-is
// AbstractThreadedActionListener forks onFailure to the same executor as onResponse,
// so we can just call this direct
onResponse(docs);
} else {
context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e);
}
}
}
);
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(
rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(),
rankResultListener
);
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(docs, rankResultListener);
}
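
For readers skimming the diff, the following is a distilled, illustrative sketch of the fallback pattern the new onFailure branch introduces; it is not part of this change, imports are omitted, and "onRanked" and "failPhase" are hypothetical stand-ins for the surrounding phase machinery. When the coordinator context is lenient, a failure while computing updated ranks is logged and the docs extracted before the call keep their existing ranks; otherwise the phase fails as before.

    ActionListener<RankFeatureDoc[]> rankResultListener = new ActionListener<>() {
        @Override
        public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
            onRanked(docsWithUpdatedScores); // hypothetical: hand the re-ranked docs to the fetch phase
        }

        @Override
        public void onFailure(Exception e) {
            if (rankFeaturePhaseRankCoordinatorContext.isLenient()) {
                // lenient: log and continue with the original docs and their existing scores
                logger.warn("Exception computing updated ranks. Continuing with existing ranks.", e);
                onResponse(docs);
            } else {
                failPhase(e); // hypothetical: surface the failure and abort the phase, as before
            }
        }
    };
    rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(docs, rankResultListener);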

private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults(
SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
ScoreDoc[] scoreDocs
) {

return new SearchPhaseController.ReducedQueryPhase(
reducedQueryPhase.totalHits(),
reducedQueryPhase.fetchHits(),

@@ -11,6 +11,7 @@

import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@@ -42,21 +43,32 @@
public abstract class RankBuilder implements VersionedNamedWriteable, ToXContentObject {

public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
public static final ParseField LENIENT_FIELD = new ParseField("lenient");

public static final int DEFAULT_RANK_WINDOW_SIZE = SearchService.DEFAULT_SIZE;

private final int rankWindowSize;
private final boolean lenient;

public RankBuilder(int rankWindowSize) {
public RankBuilder(int rankWindowSize, boolean lenient) {
this.rankWindowSize = rankWindowSize;
this.lenient = lenient;
}

public RankBuilder(StreamInput in) throws IOException {
rankWindowSize = in.readVInt();
if (in.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) {
lenient = in.readBoolean();
} else {
lenient = false;
}
}

public final void writeTo(StreamOutput out) throws IOException {
out.writeVInt(rankWindowSize);
if (out.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) {
out.writeBoolean(lenient);
}
doWriteTo(out);
}

@@ -67,6 +79,9 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)
builder.startObject();
builder.startObject(getWriteableName());
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
if (lenient) {
builder.field(LENIENT_FIELD.getPreferredName(), lenient);
}
doXContent(builder, params);
builder.endObject();
builder.endObject();
@@ -79,6 +94,10 @@ public int rankWindowSize() {
return rankWindowSize;
}

public boolean isLenient() {
return lenient;
}

/**
* Specify whether this rank builder is a compound builder or not. A compound builder is a rank builder that requires
* two or more queries to be executed in order to generate the final result.

@@ -33,11 +33,17 @@ public abstract class RankFeaturePhaseRankCoordinatorContext {
protected final int size;
protected final int from;
protected final int rankWindowSize;
protected final boolean lenient;

public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean lenient) {
this.size = size < 0 ? DEFAULT_SIZE : size;
this.from = from < 0 ? DEFAULT_FROM : from;
this.rankWindowSize = rankWindowSize;
this.lenient = lenient;
}

public boolean isLenient() {
return lenient;
}
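
A quick aside on how an implementation is expected to opt in to the new constructor argument: the anonymous-subclass shape below mirrors the test fixtures later in this diff, with a stand-in scorer (a real context would call out to an inference service); only the trailing boolean is new, and imports are omitted.

    RankFeaturePhaseRankCoordinatorContext lenientContext =
        new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize, true) {
            @Override
            protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
                // stand-in scorer: give every doc the same score
                float[] scores = new float[featureDocs.length];
                Arrays.fill(scores, 1.0f);
                scoreListener.onResponse(scores);
            }
        };
    assert lenientContext.isLenient();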

/**
@@ -51,9 +57,9 @@ public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindow
* @param originalDocs documents to process
*/
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
return Arrays.stream(originalDocs)
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
.toArray(RankFeatureDoc[]::new);
RankFeatureDoc[] sorted = originalDocs.clone();
Arrays.sort(sorted, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
return sorted;
}

/**
@@ -64,16 +70,10 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
* Once all the scores have been computed, we sort the results, perform any pagination needed, and then call the `onFinish` consumer
* with the final array of {@link ScoreDoc} results.
*
* @param rankSearchResults a list of rank feature results from each shard
* @param featureDocs an array of rank feature results from each shard
* @param rankListener a rankListener to handle the global ranking result
*/
public void computeRankScoresForGlobalResults(
List<RankFeatureResult> rankSearchResults,
ActionListener<RankFeatureDoc[]> rankListener
) {
// extract feature data from each shard rank-feature phase result
RankFeatureDoc[] featureDocs = extractFeatureDocs(rankSearchResults);

public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener<RankFeatureDoc[]> rankListener) {
// generate the final `topResults` results, and pass them to fetch phase through the `rankListener`
computeScores(featureDocs, rankListener.delegateFailureAndWrap((listener, scores) -> {
for (int i = 0; i < featureDocs.length; i++) {

@@ -775,7 +775,7 @@ public void sendExecuteRankFeature(
}

private RankFeaturePhaseRankCoordinatorContext defaultRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize, false) {

@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
@@ -785,16 +785,8 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
}

@Override
public void computeRankScoresForGlobalResults(
List<RankFeatureResult> rankSearchResults,
ActionListener<RankFeatureDoc[]> rankListener
) {
List<RankFeatureDoc> features = new ArrayList<>();
for (RankFeatureResult rankFeatureResult : rankSearchResults) {
RankFeatureShardResult shardResult = rankFeatureResult.shardResult();
features.addAll(Arrays.stream(shardResult.rankFeatureDocs).toList());
}
rankListener.onResponse(features.toArray(new RankFeatureDoc[0]));
public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener<RankFeatureDoc[]> rankListener) {
rankListener.onResponse(featureDocs);
}

@Override
@@ -875,7 +867,7 @@ private RankBuilder rankBuilder(
RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext,
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext
) {
return new RankBuilder(rankWindowSize) {
return new RankBuilder(rankWindowSize, false) {
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
// no-op

@@ -687,7 +687,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
int from,
Client client
) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
float[] scores = new float[featureDocs.length];
@@ -831,7 +831,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
int from,
Client client
) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
throw new IllegalStateException("should have failed earlier");
@@ -947,7 +947,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
int from,
Client client
) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
float[] scores = new float[featureDocs.length];
@@ -1075,7 +1075,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
int from,
Client client
) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
float[] scores = new float[featureDocs.length];

@@ -117,7 +117,7 @@ public boolean isCancelled() {
}

private RankBuilder getRankBuilder(final String field) {
return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE) {
return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE, false) {
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
// no-op
@@ -171,7 +171,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
// no work to be done on the coordinator node for the rank feature phase
@Override
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
throw new AssertionError("not expected");

@@ -52,7 +52,7 @@ public static TestRankBuilder randomRankBuilder() {
}

public TestRankBuilder(int windowSize) {
super(windowSize);
super(windowSize, false);
}

public TestRankBuilder(StreamInput in) throws IOException {

@@ -61,7 +61,7 @@ public class RandomRankBuilder extends RankBuilder {
private final Integer seed;

public RandomRankBuilder(int rankWindowSize, String field, Integer seed) {
super(rankWindowSize);
super(rankWindowSize, false);

if (field == null || field.isEmpty()) {
throw new IllegalArgumentException("field is required");

@@ -11,8 +11,6 @@
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

/**
@@ -24,7 +22,7 @@ public class RandomRankFeaturePhaseRankCoordinatorContext extends RankFeaturePha
private final Integer seed;

public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Integer seed) {
super(size, from, rankWindowSize);
super(size, from, rankWindowSize, false);
this.seed = seed;
}

@@ -40,16 +38,4 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
}
scoreListener.onResponse(scores);
}

/**
* Sorts documents by score descending.
* @param originalDocs documents to process
*/
@Override
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
return Arrays.stream(originalDocs)
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
.toArray(RankFeatureDoc[]::new);
}

}

@@ -57,8 +57,9 @@ public class TextSimilarityRankBuilder extends RankBuilder {
String field = (String) args[2];
Integer rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (Integer) args[3];
Float minScore = (Float) args[4];
boolean lenient = args[5] != null && (Boolean) args[5];

return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore);
return new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient);
});

static {
@@ -67,15 +68,23 @@ public class TextSimilarityRankBuilder extends RankBuilder {
PARSER.declareString(constructorArg(), FIELD_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareFloat(optionalConstructorArg(), MIN_SCORE_FIELD);
PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD);
}

private final String inferenceId;
private final String inferenceText;
private final String field;
private final Float minScore;

public TextSimilarityRankBuilder(String field, String inferenceId, String inferenceText, int rankWindowSize, Float minScore) {
super(rankWindowSize);
public TextSimilarityRankBuilder(
String field,
String inferenceId,
String inferenceText,
int rankWindowSize,
Float minScore,
boolean lenient
) {
super(rankWindowSize, lenient);
this.inferenceId = inferenceId;
this.inferenceText = inferenceText;
this.field = field;
@@ -103,7 +112,7 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void doWriteTo(StreamOutput out) throws IOException {
// rankWindowSize serialization is handled by the parent class RankBuilder
// rankWindowSize & lenient serialization is handled by the parent class RankBuilder
out.writeString(inferenceId);
out.writeString(inferenceText);
out.writeString(field);
@@ -112,7 +121,7 @@ public void doWriteTo(StreamOutput out) throws IOException {

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
// rankWindowSize serialization is handled by the parent class RankBuilder
// rankWindowSize & lenient serialization is handled by the parent class RankBuilder
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText);
builder.field(FIELD_FIELD.getPreferredName(), field);
@@ -177,7 +186,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
client,
inferenceId,
inferenceText,
minScore
minScore,
isLenient()
);
}
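
To close the loop, a hedged usage sketch of the widened TextSimilarityRankBuilder constructor from this commit; all argument values are invented for illustration and imports are omitted. On the REST side the parser above accepts the same option as an optional boolean "lenient" field next to "rank_window_size"; note that later commits in this PR rename lenient to failuresAllowed and update the xcontent field name, so the syntax that finally ships may differ from what is shown here.

    TextSimilarityRankBuilder builder = new TextSimilarityRankBuilder(
        "content",                    // field whose text is sent to the reranker (example value)
        "my-rerank-endpoint",         // inference endpoint id (example value)
        "how do I tune rescoring?",   // inference text (example value)
        100,                          // rank_window_size
        null,                         // min_score left unset
        true                          // lenient: keep the existing ranking if reranking fails
    );
    assert builder.isLenient();       // exposed via the new accessor on RankBuilder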
