Commit 513ca39

throw exception if remote model doesn't return 2xx status code; fix predict runner (opensearch-project#1473)
* throw exception if remote model doesn't return 2xx status code; fix predict runner
* fix kmeans model deploy bug
* support multiple docs for remote embedding model
* fix ut

Signed-off-by: Yaliang Wu <[email protected]>
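
In short, both connector executors now check the HTTP status code before parsing the model response. A condensed sketch of that guard, using the same OpenSearch types that appear in the diffs below (the class and method names here are illustrative only):

import org.opensearch.OpenSearchStatusException;
import org.opensearch.rest.RestStatus;

// Sketch of the guard added in both executor diffs below; statusCode and
// modelResponse come from the remote HTTP response in the real code.
final class StatusCheckSketch {
    static void failOnNon2xx(int statusCode, String modelResponse) {
        if (statusCode < 200 || statusCode >= 300) {
            // Surface the remote model's error body with its original status code.
            throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
        }
    }
}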
1 parent 5fc555d commit 513ca39

8 files changed (+92, -18 lines)

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

+3
@@ -101,6 +101,9 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
                 throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
             }
             String modelResponse = responseBuilder.toString();
+            if (statusCode < 200 || statusCode >= 300) {
+                throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
+            }

             ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
             tensors.setStatusCode(statusCode);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+2 -2
@@ -102,7 +102,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
                 docs.add(null);
             }
         }
-        if (preProcessFunction.contains("${parameters")) {
+        if (preProcessFunction.contains("${parameters.")) {
             StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
             preProcessFunction = substitutor.replace(preProcessFunction);
         }
@@ -164,7 +164,7 @@ public static ModelTensors processOutput(String modelResponse, Connector connect
         // execute user defined painless script.
         Optional<String> processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
         String response = processedResponse.orElse(modelResponse);
-        boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent();
+        boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response);
         if (responseFilter == null) {
             connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
         } else {
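
For context on the tightened contains("${parameters.") check above: the StringSubstitutor is built with "${parameters." as the variable prefix and "}" as the suffix, so only placeholders written with the trailing dot are actually substituted. A minimal standalone sketch, assuming Apache Commons Text is on the classpath; the preProcessFunction string here is a made-up placeholder, not the plugin's real pre-process script:

import java.util.Map;
import org.apache.commons.text.StringSubstitutor;

public class ParametersSubstitutionSketch {
    public static void main(String[] args) {
        Map<String, String> parameters = Map.of("input", "hello world");
        // Hypothetical pre-process function body containing one placeholder.
        String preProcessFunction = "def text = \"${parameters.input}\"; return text;";
        // Only ${parameters.<name>} placeholders match this prefix/suffix pair.
        StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
        System.out.println(substitutor.replace(preProcessFunction));
        // prints: def text = "hello world"; return text;
    }
}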

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

+7 -1
@@ -16,13 +16,15 @@
 import org.apache.http.entity.StringEntity;
 import org.apache.http.impl.client.CloseableHttpClient;
 import org.apache.http.util.EntityUtils;
+import org.opensearch.OpenSearchStatusException;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.connector.HttpConnector;
 import org.opensearch.ml.common.exception.MLException;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.engine.annotation.ConnectorExecutor;
 import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
+import org.opensearch.rest.RestStatus;
 import org.opensearch.script.ScriptService;

 import java.security.AccessController;
@@ -103,9 +105,13 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
                 return null;
             });
             String modelResponse = responseRef.get();
+            Integer statusCode = statusCodeRef.get();
+            if (statusCode < 200 || statusCode >= 300) {
+                throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
+            }

             ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
-            tensors.setStatusCode(statusCodeRef.get());
+            tensors.setStatusCode(statusCode);
             tensorOutputs.add(tensors);
         } catch (RuntimeException e) {
             log.error("Fail to execute http connector", e);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+9 -2
@@ -32,8 +32,15 @@ default ModelTensorOutput executePredict(MLInput mlInput) {

         if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
             TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
-            List<String> textDocs = new ArrayList<>(textDocsInputDataSet.getDocs());
-            preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs);
+            int processedDocs = 0;
+            while(processedDocs < textDocsInputDataSet.getDocs().size()) {
+                List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
+                List<ModelTensors> tempTensorOutputs = new ArrayList<>();
+                preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
+                processedDocs += Math.max(tempTensorOutputs.size(), 1);
+                tensorOutputs.addAll(tempTensorOutputs);
+            }
+
         } else {
             preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs);
         }
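
The loop above hands the remaining docs to the connector on each pass and advances by however many ModelTensors came back, so a connector that only embeds one doc per request still covers every doc. A self-contained sketch of the same pattern with hypothetical stand-in types (invoke and ChunkedInvokeSketch are illustrative, not plugin classes):

import java.util.ArrayList;
import java.util.List;

public class ChunkedInvokeSketch {
    // Mimics a connector call that only embeds the first doc it is handed per request.
    static void invoke(List<String> docs, List<String> outputs) {
        outputs.add("embedding(" + docs.get(0) + ")");
    }

    public static void main(String[] args) {
        List<String> docs = List.of("doc1", "doc2", "doc3");
        List<String> allOutputs = new ArrayList<>();
        int processedDocs = 0;
        while (processedDocs < docs.size()) {
            List<String> remaining = docs.subList(processedDocs, docs.size());
            List<String> tempOutputs = new ArrayList<>();
            invoke(remaining, tempOutputs);
            // Advance by however many outputs came back; the Math.max(..., 1) guard
            // prevents an infinite loop if a call produces no output.
            processedDocs += Math.max(tempOutputs.size(), 1);
            allOutputs.addAll(tempOutputs);
        }
        System.out.println(allOutputs); // [embedding(doc1), embedding(doc2), embedding(doc3)]
    }
}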

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

+31 -1
@@ -118,6 +118,36 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
         executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
     }

+    @Test
+    public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException {
+        exceptionRule.expect(OpenSearchStatusException.class);
+        exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}");
+        String jsonString = "{\"message\":\"The security token included in the request is invalid\"}";
+        InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
+        AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
+        when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
+        when(httpRequest.call()).thenReturn(response);
+        SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
+        when(httpResponse.statusCode()).thenReturn(403);
+        when(response.httpResponse()).thenReturn(httpResponse);
+        when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
+
+        ConnectorAction predictAction = ConnectorAction.builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://test.com/mock")
+            .requestBody("{\"input\": \"${parameters.input}\"}")
+            .build();
+        Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
+        Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
+        Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
+        connector.decrypt((c) -> encryptor.decrypt(c));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
+
+        MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
+        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+    }
+
     @Test
     public void executePredict_RemoteInferenceInput() throws IOException {
         String jsonString = "{\"key\":\"value\"}";
@@ -176,7 +206,7 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
         connector.decrypt((c) -> encryptor.decrypt(c));
         AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

-        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build();
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
         ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
         Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
         Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java

+25 -2
@@ -20,6 +20,7 @@
 import org.junit.rules.ExpectedException;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
+import org.opensearch.OpenSearchStatusException;
 import org.opensearch.cluster.ClusterStateTaskConfig;
 import org.opensearch.ingest.TestTemplateService;
 import org.opensearch.ml.common.FunctionName;
@@ -120,12 +121,34 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
         when(executor.getHttpClient()).thenReturn(httpClient);
         MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
         ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
+        Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
         Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
         Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
         Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response"));
     }

+    @Test
+    public void executePredict_TextDocsInput_LimitExceed() throws IOException {
+        exceptionRule.expect(OpenSearchStatusException.class);
+        exceptionRule.expectMessage("{\"message\": \"Too many requests\"}");
+        ConnectorAction predictAction = ConnectorAction.builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://test.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .build();
+        when(httpClient.execute(any())).thenReturn(response);
+        HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}");
+        when(response.getEntity()).thenReturn(entity);
+        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK");
+        when(response.getStatusLine()).thenReturn(statusLine);
+        Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
+        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
+        when(executor.getHttpClient()).thenReturn(httpClient);
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
+        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+    }
+
     @Test
     public void executePredict_TextDocsInput() throws IOException {
         String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
@@ -161,7 +184,7 @@ public void executePredict_TextDocsInput() throws IOException {
         when(executor.getHttpClient()).thenReturn(httpClient);
         MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
         ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
+        Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
         Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
         Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
         Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+3 -2
@@ -746,8 +746,8 @@ public void deployModel(
                 CLUSTER_SERVICE,
                 clusterService
             );
-            // deploy remote model or model trained by built-in algorithm like kmeans
-            if (mlModel.getConnector() != null) {
+            // deploy remote model with internal connector or model trained by built-in algorithm like kmeans
+            if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) {
                 setupPredictable(modelId, mlModel, params);
                 wrappedListener.onResponse("successful");
                 return;
@@ -756,6 +756,7 @@ public void deployModel(
             GetRequest getConnectorRequest = new GetRequest();
             FetchSourceContext fetchContext = new FetchSourceContext(true, null, null);
             getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext);
+            // get connector and deploy remote model with standalone connector
             client.get(getConnectorRequest, ActionListener.wrap(getResponse -> {
                 if (getResponse != null && getResponse.isExists()) {
                     try (

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+12 -8
@@ -210,9 +210,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
         FunctionName algorithm = mlInput.getAlgorithm();
         // run predict
         if (modelId != null) {
-            try {
-                Predictable predictor = mlModelManager.getPredictor(modelId);
-                if (predictor != null) {
+            Predictable predictor = mlModelManager.getPredictor(modelId);
+            if (predictor != null) {
+                try {
                     if (!predictor.isModelReady()) {
                         throw new IllegalArgumentException("Model not ready: " + modelId);
                     }
@@ -226,11 +226,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
                     MLTaskResponse response = MLTaskResponse.builder().output(output).build();
                     internalListener.onResponse(response);
                     return;
-                } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
-                    throw new IllegalArgumentException("Model not ready to be used: " + modelId);
+                } catch (Exception e) {
+                    handlePredictFailure(mlTask, internalListener, e, false);
+                    return;
                 }
-            } catch (Exception e) {
-                handlePredictFailure(mlTask, internalListener, e, false);
+            } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
+                throw new IllegalArgumentException("Model not ready to be used: " + modelId);
             }

             // search model by model id.
@@ -249,6 +250,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
             GetResponse getResponse = r;
             String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString();
             MLModel mlModel = MLModel.parse(xContentParser, algorithmName);
+            mlModel.setModelId(modelId);
             User resourceUser = mlModel.getUser();
             User requestUser = getUserContext(client);
             if (!checkUserPermissions(requestUser, resourceUser, modelId)) {
@@ -260,7 +262,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
                 return;
             }
             // run predict
-            mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
+            if (mlTaskManager.contains(mlTask.getTaskId())) {
+                mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
+            }
             MLOutput output = mlEngine.predict(mlInput, mlModel);
             if (output instanceof MLPredictionOutput) {
                 ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
