7
7
8
8
import com .google .common .collect .ImmutableMap ;
9
9
import org .apache .http .HttpEntity ;
10
+ import org .apache .http .ProtocolVersion ;
11
+ import org .apache .http .StatusLine ;
10
12
import org .apache .http .client .methods .CloseableHttpResponse ;
11
13
import org .apache .http .entity .StringEntity ;
12
14
import org .apache .http .impl .client .CloseableHttpClient ;
15
+ import org .apache .http .message .BasicStatusLine ;
13
16
import org .junit .Assert ;
14
17
import org .junit .Before ;
15
18
import org .junit .Rule ;
16
19
import org .junit .Test ;
17
20
import org .junit .rules .ExpectedException ;
18
21
import org .mockito .Mock ;
19
22
import org .mockito .MockitoAnnotations ;
23
+ import org .opensearch .cluster .ClusterStateTaskConfig ;
20
24
import org .opensearch .ingest .TestTemplateService ;
21
25
import org .opensearch .ml .common .FunctionName ;
22
26
import org .opensearch .ml .common .connector .Connector ;
@@ -87,6 +91,8 @@ public void executePredict_RemoteInferenceInput() throws IOException {
87
91
when (httpClient .execute (any ())).thenReturn (response );
88
92
HttpEntity entity = new StringEntity ("{\" response\" : \" test result\" }" );
89
93
when (response .getEntity ()).thenReturn (entity );
94
+ StatusLine statusLine = new BasicStatusLine (new ProtocolVersion ("HTTP" , 1 , 1 ), 200 , "OK" );
95
+ when (response .getStatusLine ()).thenReturn (statusLine );
90
96
when (executor .getHttpClient ()).thenReturn (httpClient );
91
97
MLInputDataset inputDataSet = RemoteInferenceInputDataSet .builder ().parameters (ImmutableMap .of ("input" , "test input data" )).build ();
92
98
ModelTensorOutput modelTensorOutput = executor .executePredict (MLInput .builder ().algorithm (FunctionName .REMOTE ).inputDataset (inputDataSet ).build ());
@@ -107,6 +113,8 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
107
113
when (httpClient .execute (any ())).thenReturn (response );
108
114
HttpEntity entity = new StringEntity ("{\" response\" : \" test result\" }" );
109
115
when (response .getEntity ()).thenReturn (entity );
116
+ StatusLine statusLine = new BasicStatusLine (new ProtocolVersion ("HTTP" , 1 , 1 ), 200 , "OK" );
117
+ when (response .getStatusLine ()).thenReturn (statusLine );
110
118
Connector connector = HttpConnector .builder ().name ("test connector" ).version ("1" ).protocol ("http" ).actions (Arrays .asList (predictAction )).build ();
111
119
HttpJsonConnectorExecutor executor = spy (new HttpJsonConnectorExecutor (connector ));
112
120
when (executor .getHttpClient ()).thenReturn (httpClient );
@@ -146,6 +154,8 @@ public void executePredict_TextDocsInput() throws IOException {
146
154
+ " 0.0035105038\n " + " ]\n " + " }\n " + " ],\n "
147
155
+ " \" model\" : \" text-embedding-ada-002-v2\" ,\n " + " \" usage\" : {\n " + " \" prompt_tokens\" : 5,\n "
148
156
+ " \" total_tokens\" : 5\n " + " }\n " + "}" ;
157
+ StatusLine statusLine = new BasicStatusLine (new ProtocolVersion ("HTTP" , 1 , 1 ), 200 , "OK" );
158
+ when (response .getStatusLine ()).thenReturn (statusLine );
149
159
HttpEntity entity = new StringEntity (modelResponse );
150
160
when (response .getEntity ()).thenReturn (entity );
151
161
when (executor .getHttpClient ()).thenReturn (httpClient );
0 commit comments