Skip to content

Commit 5fc555d

Browse files
authored
add status code to model tensor (opensearch-project#1443)
* add status code to model tensor
  Signed-off-by: Yaliang Wu <[email protected]>
* fix ut
  Signed-off-by: Yaliang Wu <[email protected]>
---------
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 285f71b commit 5fc555d

File tree

5 files changed: +38 -0 lines changed

5 files changed: +38 -0 lines changed

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import lombok.Builder;
99
import lombok.Getter;
10+
import lombok.Setter;
1011
import org.opensearch.common.bytes.BytesReference;
1112
import org.opensearch.common.io.stream.BytesStreamOutput;
1213
import org.opensearch.common.io.stream.StreamInput;
@@ -24,7 +25,10 @@
2425
@Getter
2526
public class ModelTensors implements Writeable, ToXContentObject {
2627
public static final String OUTPUT_FIELD = "output";
28+
public static final String STATUS_CODE_FIELD = "status_code";
2729
private List<ModelTensor> mlModelTensors;
30+
@Setter
31+
private Integer statusCode;
2832

2933
@Builder
3034
public ModelTensors(List<ModelTensor> mlModelTensors) {
@@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4145
}
4246
builder.endArray();
4347
}
48+
if (statusCode != null) {
49+
builder.field(STATUS_CODE_FIELD, statusCode);
50+
}
4451
builder.endObject();
4552
return builder;
4653
}
@@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException {
5360
mlModelTensors.add(new ModelTensor(in));
5461
}
5562
}
63+
statusCode = in.readOptionalInt();
5664
}
5765

5866
@Override
@@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException {
6674
} else {
6775
out.writeBoolean(false);
6876
}
77+
out.writeOptionalInt(statusCode);
6978
}
7079

7180
public void filter(ModelResultFilter resultFilter) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
8282
HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> {
8383
return httpClient.prepareRequest(executeRequest).call();
8484
});
85+
int statusCode = response.httpResponse().statusCode();
8586

8687
AbortableInputStream body = null;
8788
if (response.responseBody().isPresent()) {
@@ -102,6 +103,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
102103
String modelResponse = responseBuilder.toString();
103104

104105
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
106+
tensors.setStatusCode(statusCode);
105107
tensorOutputs.add(tensors);
106108
} catch (RuntimeException exception) {
107109
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

+3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public HttpJsonConnectorExecutor(Connector connector) {
5252
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
5353
try {
5454
AtomicReference<String> responseRef = new AtomicReference<>("");
55+
AtomicReference<Integer> statusCodeRef = new AtomicReference<>();
5556

5657
HttpUriRequest request;
5758
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
@@ -97,12 +98,14 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
9798
String responseBody = EntityUtils.toString(responseEntity);
9899
EntityUtils.consume(responseEntity);
99100
responseRef.set(responseBody);
101+
statusCodeRef.set(response.getStatusLine().getStatusCode());
100102
}
101103
return null;
102104
});
103105
String modelResponse = responseRef.get();
104106

105107
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
108+
tensors.setStatusCode(statusCodeRef.get());
106109
tensorOutputs.add(tensors);
107110
} catch (RuntimeException e) {
108111
log.error("Fail to execute http connector", e);

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

+14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import com.google.common.collect.ImmutableList;
99
import com.google.common.collect.ImmutableMap;
10+
import org.apache.http.ProtocolVersion;
11+
import org.apache.http.StatusLine;
12+
import org.apache.http.message.BasicStatusLine;
1013
import org.junit.Assert;
1114
import org.junit.Before;
1215
import org.junit.Rule;
@@ -32,6 +35,7 @@
3235
import software.amazon.awssdk.http.ExecutableHttpRequest;
3336
import software.amazon.awssdk.http.HttpExecuteResponse;
3437
import software.amazon.awssdk.http.SdkHttpClient;
38+
import software.amazon.awssdk.http.SdkHttpResponse;
3539

3640
import java.io.ByteArrayInputStream;
3741
import java.io.IOException;
@@ -41,6 +45,7 @@
4145
import java.util.Optional;
4246

4347
import static org.mockito.ArgumentMatchers.any;
48+
import static org.mockito.Mockito.mock;
4449
import static org.mockito.Mockito.spy;
4550
import static org.mockito.Mockito.when;
4651
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
@@ -92,6 +97,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
9297
exceptionRule.expectMessage("No response from model");
9398
when(response.responseBody()).thenReturn(Optional.empty());
9499
when(httpRequest.call()).thenReturn(response);
100+
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
101+
when(httpResponse.statusCode()).thenReturn(200);
102+
when(response.httpResponse()).thenReturn(httpResponse);
95103
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
96104

97105
ConnectorAction predictAction = ConnectorAction.builder()
@@ -116,6 +124,9 @@ public void executePredict_RemoteInferenceInput() throws IOException {
116124
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
117125
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
118126
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
127+
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
128+
when(httpResponse.statusCode()).thenReturn(200);
129+
when(response.httpResponse()).thenReturn(httpResponse);
119130
when(httpRequest.call()).thenReturn(response);
120131
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
121132

@@ -147,6 +158,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
147158
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
148159
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
149160
when(httpRequest.call()).thenReturn(response);
161+
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
162+
when(httpResponse.statusCode()).thenReturn(200);
163+
when(response.httpResponse()).thenReturn(httpResponse);
150164
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
151165

152166
ConnectorAction predictAction = ConnectorAction.builder()

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java

+10
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,20 @@
77

88
import com.google.common.collect.ImmutableMap;
99
import org.apache.http.HttpEntity;
10+
import org.apache.http.ProtocolVersion;
11+
import org.apache.http.StatusLine;
1012
import org.apache.http.client.methods.CloseableHttpResponse;
1113
import org.apache.http.entity.StringEntity;
1214
import org.apache.http.impl.client.CloseableHttpClient;
15+
import org.apache.http.message.BasicStatusLine;
1316
import org.junit.Assert;
1417
import org.junit.Before;
1518
import org.junit.Rule;
1619
import org.junit.Test;
1720
import org.junit.rules.ExpectedException;
1821
import org.mockito.Mock;
1922
import org.mockito.MockitoAnnotations;
23+
import org.opensearch.cluster.ClusterStateTaskConfig;
2024
import org.opensearch.ingest.TestTemplateService;
2125
import org.opensearch.ml.common.FunctionName;
2226
import org.opensearch.ml.common.connector.Connector;
@@ -87,6 +91,8 @@ public void executePredict_RemoteInferenceInput() throws IOException {
8791
when(httpClient.execute(any())).thenReturn(response);
8892
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
8993
when(response.getEntity()).thenReturn(entity);
94+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
95+
when(response.getStatusLine()).thenReturn(statusLine);
9096
when(executor.getHttpClient()).thenReturn(httpClient);
9197
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
9298
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
@@ -107,6 +113,8 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
107113
when(httpClient.execute(any())).thenReturn(response);
108114
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
109115
when(response.getEntity()).thenReturn(entity);
116+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
117+
when(response.getStatusLine()).thenReturn(statusLine);
110118
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
111119
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
112120
when(executor.getHttpClient()).thenReturn(httpClient);
@@ -146,6 +154,8 @@ public void executePredict_TextDocsInput() throws IOException {
146154
+ " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n"
147155
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
148156
+ " \"total_tokens\": 5\n" + " }\n" + "}";
157+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
158+
when(response.getStatusLine()).thenReturn(statusLine);
149159
HttpEntity entity = new StringEntity(modelResponse);
150160
when(response.getEntity()).thenReturn(entity);
151161
when(executor.getHttpClient()).thenReturn(httpClient);

0 commit comments

Comments
 (0)