|
20 | 20 | import org.junit.rules.ExpectedException;
|
21 | 21 | import org.mockito.Mock;
|
22 | 22 | import org.mockito.MockitoAnnotations;
|
| 23 | +import org.opensearch.OpenSearchStatusException; |
23 | 24 | import org.opensearch.cluster.ClusterStateTaskConfig;
|
24 | 25 | import org.opensearch.ingest.TestTemplateService;
|
25 | 26 | import org.opensearch.ml.common.FunctionName;
|
@@ -120,12 +121,34 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
|
120 | 121 | when(executor.getHttpClient()).thenReturn(httpClient);
|
121 | 122 | MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
|
122 | 123 | ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
|
123 |
| - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); |
| 124 | + Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); |
124 | 125 | Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
|
125 | 126 | Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
|
126 | 127 | Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response"));
|
127 | 128 | }
|
128 | 129 |
|
| 130 | + @Test |
| 131 | + public void executePredict_TextDocsInput_LimitExceed() throws IOException { |
| 132 | + exceptionRule.expect(OpenSearchStatusException.class); |
| 133 | + exceptionRule.expectMessage("{\"message\": \"Too many requests\"}"); |
| 134 | + ConnectorAction predictAction = ConnectorAction.builder() |
| 135 | + .actionType(ConnectorAction.ActionType.PREDICT) |
| 136 | + .method("POST") |
| 137 | + .url("http://test.com/mock") |
| 138 | + .requestBody("{\"input\": ${parameters.input}}") |
| 139 | + .build(); |
| 140 | + when(httpClient.execute(any())).thenReturn(response); |
| 141 | + HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}"); |
| 142 | + when(response.getEntity()).thenReturn(entity); |
| 143 | + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK"); |
| 144 | + when(response.getStatusLine()).thenReturn(statusLine); |
| 145 | + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); |
| 146 | + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); |
| 147 | + when(executor.getHttpClient()).thenReturn(httpClient); |
| 148 | + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); |
| 149 | + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); |
| 150 | + } |
| 151 | + |
129 | 152 | @Test
|
130 | 153 | public void executePredict_TextDocsInput() throws IOException {
|
131 | 154 | String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
|
@@ -161,7 +184,7 @@ public void executePredict_TextDocsInput() throws IOException {
|
161 | 184 | when(executor.getHttpClient()).thenReturn(httpClient);
|
162 | 185 | MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
|
163 | 186 | ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
|
164 |
| - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); |
| 187 | + Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); |
165 | 188 | Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
|
166 | 189 | Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
|
167 | 190 | Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
|
|
0 commit comments