Skip to content

Commit bb6f5c2

Browse files
committed
Add test coverage for error cases.
Signed-off-by: Austin Lee <[email protected]>
1 parent 266399c commit bb6f5c2

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java

+46
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import org.json.JSONArray;
3131
import org.json.JSONException;
3232
import org.json.JSONObject;
33+
import org.junit.Rule;
34+
import org.junit.rules.ExpectedException;
3335
import org.mockito.ArgumentCaptor;
3436
import org.mockito.Mock;
3537
import org.opensearch.client.Client;
@@ -52,6 +54,9 @@ public class DefaultLlmImplTests extends OpenSearchTestCase {
5254
@Mock
5355
Client client;
5456

57+
@Rule
58+
public ExpectedException exceptionRule = ExpectedException.none();
59+
5560
public void testBuildMessageParameter() {
5661
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
5762
String question = "Who am I";
@@ -422,6 +427,47 @@ public void testChatCompletionBedrockThrowingError() throws Exception {
422427
assertEquals(errorMessage, (String) output.getErrors().get(0));
423428
}
424429

430+
public void testIllegalArgument1() {
431+
exceptionRule.expect(IllegalArgumentException.class);
432+
exceptionRule.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field.");
433+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
434+
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
435+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
436+
connector.setMlClient(mlClient);
437+
438+
String errorMessage = "throttled";
439+
Map<String, String> messageMap = Map.of("message", errorMessage);
440+
Map<String, ?> dataAsMap = Map.of("error", messageMap);
441+
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
442+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
443+
ActionFuture<MLOutput> future = mock(ActionFuture.class);
444+
when(future.actionGet(anyLong())).thenReturn(mlOutput);
445+
when(mlClient.predict(any(), any())).thenReturn(future);
446+
ChatCompletionInput input = new ChatCompletionInput(
447+
"model",
448+
"question",
449+
Collections.emptyList(),
450+
Collections.emptyList(),
451+
0,
452+
"prompt",
453+
"instructions",
454+
null,
455+
null
456+
);
457+
ChatCompletionOutput output = connector.doChatCompletion(input);
458+
}
459+
460+
public void testIllegalArgument2() {
461+
exceptionRule.expect(IllegalArgumentException.class);
462+
exceptionRule.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field.");
463+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
464+
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
465+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
466+
connector.setMlClient(mlClient);
467+
468+
connector.buildChatCompletionOutput(null, Collections.emptyMap(), null);
469+
}
470+
425471
private boolean isJson(String Json) {
426472
try {
427473
new JSONObject(Json);

0 commit comments

Comments
 (0)