Skip to content

Commit

Permalink
Merge branch 'main' into close_121873
Browse files Browse the repository at this point in the history
  • Loading branch information
nik9000 authored Feb 6, 2025
2 parents 4396a7d + e24489f commit 12316e3
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 42 deletions.
12 changes: 12 additions & 0 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,18 @@ tests:
- class: org.elasticsearch.xpack.searchablesnapshots.FrozenSearchableSnapshotsIntegTests
method: testCreateAndRestorePartialSearchableSnapshot
issue: https://github.com/elastic/elasticsearch/issues/121927
- class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT
method: test {yaml=analysis-common/40_token_filters/stemmer_override file access}
issue: https://github.com/elastic/elasticsearch/issues/121625
- class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT
method: test {yaml=update/100_synthetic_source/stored text}
issue: https://github.com/elastic/elasticsearch/issues/121964
- class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT
method: test {yaml=update/100_synthetic_source/keyword}
issue: https://github.com/elastic/elasticsearch/issues/121965
- class: org.elasticsearch.xpack.esql.plugin.DataNodeRequestSenderTests
method: testDoNotRetryOnRequestLevelFailure
issue: https://github.com/elastic/elasticsearch/issues/121966

# Examples:
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@
package org.elasticsearch.xpack.inference.external.elastic;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

import java.util.Locale;
import java.util.concurrent.Flow;

import static org.elasticsearch.core.Strings.format;

public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, true);
Expand All @@ -29,7 +33,8 @@ public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String reques
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec
// EIS uses the unified API spec
var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
Expand All @@ -52,4 +57,30 @@ protected Exception buildError(String message, Request request, HttpResult resul
return super.buildError(message, request, result, errorResponse);
}
}

/**
 * Maps an error observed mid-stream (after a successful handshake) to a
 * {@link UnifiedChatCompletionException}.
 *
 * @param request the originating inference request; its entity id is embedded in the message
 * @param message the raw payload of the server-sent event carrying the error
 * @param e the parse failure that triggered this call, or {@code null} when the provider
 *     explicitly signalled an error event
 * @return the exception to propagate downstream; never {@code null}
 */
private static Exception buildMidStreamError(Request request, String message, Exception e) {
    var parsedError = ElasticInferenceServiceErrorResponseEntity.fromString(message);
    if (parsedError.errorStructureFound()) {
        // The payload matched the EIS error schema, so surface its message verbatim.
        var detailedMessage = format(
            "%s for request from inference entity id [%s]. Error message: [%s]",
            SERVER_ERROR_OBJECT,
            request.getInferenceEntityId(),
            parsedError.getErrorMessage()
        );
        return new UnifiedChatCompletionException(RestStatus.INTERNAL_SERVER_ERROR, detailedMessage, "error", "stream_error");
    }
    if (e != null) {
        // No recognizable error structure; fall back to the underlying parse failure.
        return UnifiedChatCompletionException.fromThrowable(e);
    }
    // Neither a structured error nor an exception: report a generic stream error.
    var genericMessage = format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId());
    return new UnifiedChatCompletionException(RestStatus.INTERNAL_SERVER_ERROR, genericMessage, "error", "stream_error");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
Expand All @@ -29,6 +30,8 @@
import java.util.Optional;
import java.util.concurrent.Flow;

import static org.elasticsearch.core.Strings.format;

public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
Expand All @@ -37,7 +40,7 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor();
var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
Expand All @@ -64,6 +67,33 @@ protected Exception buildError(String message, Request request, HttpResult resul
}
}

/**
 * Maps an error observed mid-stream (after a successful handshake) to a
 * {@link UnifiedChatCompletionException}.
 *
 * @param request the originating inference request; its entity id is embedded in the message
 * @param message the raw payload of the server-sent event carrying the error
 * @param e the parse failure that triggered this call, or {@code null} when the provider
 *     explicitly signalled an error event
 * @return the exception to propagate downstream; never {@code null}
 */
private static Exception buildMidStreamError(Request request, String message, Exception e) {
    var errorResponse = OpenAiErrorResponse.fromString(message);
    if (errorResponse instanceof OpenAiErrorResponse oer) {
        // The payload parsed as an OpenAI error object; preserve its type/code/param fields.
        return new UnifiedChatCompletionException(
            RestStatus.INTERNAL_SERVER_ERROR,
            format(
                "%s for request from inference entity id [%s]. Error message: [%s]",
                SERVER_ERROR_OBJECT,
                request.getInferenceEntityId(),
                oer.getErrorMessage()
            ),
            oer.type(),
            oer.code(),
            oer.param()
        );
    } else if (e != null) {
        // No recognizable error structure; fall back to the underlying parse failure.
        return UnifiedChatCompletionException.fromThrowable(e);
    } else {
        // fromString never returns null (it falls back to ErrorResponse.UNDEFINED_ERROR on any
        // failure), so the former `errorResponse != null ? ... : "unknown"` ternary was dead code.
        return new UnifiedChatCompletionException(
            RestStatus.INTERNAL_SERVER_ERROR,
            format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
            errorResponse.getClass().getSimpleName(),
            "stream_error"
        );
    }
}

private static class OpenAiErrorResponse extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"open_ai_error",
Expand Down Expand Up @@ -103,6 +133,19 @@ private static ErrorResponse fromResponse(HttpResult response) {
return ErrorResponse.UNDEFINED_ERROR;
}

/**
 * Parses a raw string as an OpenAI error payload.
 *
 * @param response the raw event payload, expected to be JSON
 * @return the parsed error, or {@link ErrorResponse#UNDEFINED_ERROR} when the payload does
 *     not match the expected structure; never {@code null}
 */
private static ErrorResponse fromString(String response) {
    var jsonXContent = XContentFactory.xContent(XContentType.JSON);
    try (XContentParser parser = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response)) {
        var parsedError = ERROR_PARSER.apply(parser, null);
        return parsedError.orElse(ErrorResponse.UNDEFINED_ERROR);
    } catch (Exception ignored) {
        // Deliberately swallowed: a payload that is not valid JSON (or not an error object)
        // is reported as the undefined error rather than failing the caller.
        return ErrorResponse.UNDEFINED_ERROR;
    }
}

@Nullable
private final String code;
@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand All @@ -28,6 +29,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.BiFunction;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
Expand Down Expand Up @@ -57,7 +59,13 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
public static final String TOTAL_TOKENS_FIELD = "total_tokens";

private final BiFunction<String, Exception, Exception> errorParser;
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
private volatile boolean previousEventWasError = false;

/**
 * @param errorParser converts a raw server-sent event payload, plus an optional parse failure,
 *     into the exception thrown downstream when the provider signals or produces a
 *     mid-stream error
 */
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
    this.errorParser = errorParser;
}

@Override
protected void upstreamRequest(long n) {
Expand All @@ -71,7 +79,25 @@ protected void upstreamRequest(long n) {
@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger);

var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
for (var event : item) {
if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) {
previousEventWasError = true;
} else if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
if (previousEventWasError) {
throw errorParser.apply(event.value(), null);
}

try {
var delta = parse(parserConfig, event);
delta.forEachRemaining(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw errorParser.apply(event.value(), e);
}
}
}

if (results.isEmpty()) {
upstream().request(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,26 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;

import java.io.IOException;

/**
* An example error response would look like
*
* <code>
* {
* "error": "some error"
* }
* </code>
*
*/
public class ElasticInferenceServiceErrorResponseEntity extends ErrorResponse {

private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceErrorResponseEntity.class);
Expand All @@ -24,24 +37,18 @@ private ElasticInferenceServiceErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

/**
* An example error response would look like
*
* <code>
* {
* "error": "some error"
* }
* </code>
*
* @param response The error response
* @return An error entity if the response is JSON with the above structure
* or {@link ErrorResponse#UNDEFINED_ERROR} if the error field wasn't found
*/
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
return fromParser(
() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())
);
}

public static ErrorResponse fromString(String response) {
return fromParser(() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response));
}

private static ErrorResponse fromParser(CheckedSupplier<XContentParser, IOException> jsonParserFactory) {
try (XContentParser jsonParser = jsonParserFactory.get()) {
var responseMap = jsonParser.map();
var error = (String) responseMap.get("error");
if (error != null) {
Expand All @@ -50,7 +57,6 @@ public static ErrorResponse fromResponse(HttpResult response) {
} catch (Exception e) {
logger.debug("Failed to parse error response", e);
}

return ErrorResponse.UNDEFINED_ERROR;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,51 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
}

public void testUnifiedCompletionError() throws Exception {
    // A 404 with a JSON error body should be surfaced as a unified error whose code reflects
    // the HTTP status ("not_found") and whose message echoes the upstream error text.
    // Stubbing and assertion are handled by testUnifiedStreamError.
    testUnifiedStreamError(404, """
        {
            "error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
        }""", """
        {\
        "error":{\
        "code":"not_found",\
        "message":"Received an unsuccessful status code for request from inference entity id [id] status \
        [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
        "type":"error"\
        }}""");
}

public void testUnifiedCompletionErrorMidStream() throws Exception {
    // A 200 streaming response whose SSE data frame carries an error object should map to a
    // unified error with code "stream_error" that echoes the upstream error message.
    testUnifiedStreamError(200, """
        data: { "error": "some error" }
        """, """
        {\
        "error":{\
        "code":"stream_error",\
        "message":"Received an error response for request from inference entity id [id]. Error message: [some error]",\
        "type":"error"\
        }}""");
}

public void testUnifiedCompletionMalformedError() throws Exception {
    // A 200 streaming response whose SSE data frame is not valid JSON should surface the
    // underlying parse failure as an x_content_parse_exception with code "bad_request".
    testUnifiedStreamError(200, """
        data: { i am not json }
        """, """
        {\
        "error":{\
        "code":"bad_request",\
        "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
        at [Source: (String)\\"{ i am not json }\\"; line: 1, column: 3]",\
        "type":"x_content_parse_exception"\
        }}""");
}

private void testUnifiedStreamError(int responseCode, String responseJson, String expectedJson) throws Exception {
var eisGatewayUrl = getUrl(webServer);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createService(senderFactory, eisGatewayUrl)) {
var responseJson = """
{
"error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
}""";
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson));
var model = new ElasticInferenceServiceCompletionModel(
"id",
TaskType.COMPLETION,
Expand Down Expand Up @@ -1012,14 +1049,7 @@ public void testUnifiedCompletionError() throws Exception {
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());

assertThat(json, is("""
{\
"error":{\
"code":"not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
"type":"error"\
}}"""));
assertThat(json, is(expectedJson));
}
});
}
Expand Down
Loading

0 comments on commit 12316e3

Please sign in to comment.