Skip to content

Commit c76d8d2

Browse files
committed
Improve code coverage, add more tests.
Signed-off-by: Austin Lee <[email protected]>
1 parent d0a56cf commit c76d8d2

File tree

1 file changed

+191
-0
lines changed

1 file changed

+191
-0
lines changed

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java

+191
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.junit.Rule;
3535
import org.junit.rules.ExpectedException;
3636
import org.mockito.ArgumentCaptor;
37+
import org.opensearch.OpenSearchException;
3738
import org.opensearch.action.search.SearchRequest;
3839
import org.opensearch.action.search.SearchResponse;
3940
import org.opensearch.action.search.SearchResponseSections;
@@ -230,6 +231,90 @@ public void testProcessResponse() throws Exception {
230231
assertTrue(res instanceof GenerativeSearchResponse);
231232
}
232233

234+
public void testProcessResponseWithErrorFromLlm() throws Exception {
235+
Client client = mock(Client.class);
236+
Map<String, Object> config = new HashMap<>();
237+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
238+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));
239+
240+
GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
241+
client,
242+
alwaysOn
243+
).create(null, "tag", "desc", true, config, null);
244+
245+
ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
246+
when(memoryClient.getInteractions(any(), anyInt()))
247+
.thenReturn(
248+
List
249+
.of(
250+
new Interaction(
251+
"0",
252+
Instant.now(),
253+
"1",
254+
"question",
255+
"",
256+
"answer",
257+
"foo",
258+
Collections.singletonMap("meta data", "some meta")
259+
)
260+
)
261+
);
262+
processor.setMemoryClient(memoryClient);
263+
264+
SearchRequest request = new SearchRequest();
265+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
266+
GenerativeQAParameters params = new GenerativeQAParameters(
267+
"12345",
268+
"llm_model",
269+
"You are kind.",
270+
"system_promt",
271+
"user_insturctions",
272+
null,
273+
null,
274+
null
275+
);
276+
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
277+
extBuilder.setParams(params);
278+
request.source(sourceBuilder);
279+
sourceBuilder.ext(List.of(extBuilder));
280+
281+
int numHits = 10;
282+
SearchHit[] hitsArray = new SearchHit[numHits];
283+
for (int i = 0; i < numHits; i++) {
284+
XContentBuilder sourceContent = JsonXContent
285+
.contentBuilder()
286+
.startObject()
287+
.field("_id", String.valueOf(i))
288+
.field("text", "passage" + i)
289+
.field("title", "This is the title for document " + i)
290+
.endObject();
291+
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
292+
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
293+
}
294+
295+
SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
296+
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
297+
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);
298+
299+
Llm llm = mock(Llm.class);
300+
ChatCompletionOutput output = mock(ChatCompletionOutput.class);
301+
when(llm.doChatCompletion(any())).thenReturn(output);
302+
when(output.isErrorOccurred()).thenReturn(true);
303+
when(output.getErrors()).thenReturn(List.of("something bad has occurred."));
304+
processor.setLlm(llm);
305+
306+
ArgumentCaptor<ChatCompletionInput> captor = ArgumentCaptor.forClass(ChatCompletionInput.class);
307+
SearchResponse res = processor.processResponse(request, response);
308+
verify(llm).doChatCompletion(captor.capture());
309+
ChatCompletionInput input = captor.getValue();
310+
assertTrue(input instanceof ChatCompletionInput);
311+
List<String> passages = ((ChatCompletionInput) input).getContexts();
312+
assertEquals("passage0", passages.get(0));
313+
assertEquals("passage1", passages.get(1));
314+
assertEquals(numHits, passages.size());
315+
assertTrue(res instanceof GenerativeSearchResponse);
316+
}
317+
233318
public void testProcessResponseSmallerContextSize() throws Exception {
234319
Client client = mock(Client.class);
235320
Map<String, Object> config = new HashMap<>();
@@ -497,4 +582,110 @@ public void testProcessResponseNullValueInteractions() throws Exception {
497582

498583
SearchResponse res = processor.processResponse(request, response);
499584
}
585+
586+
public void testProcessResponseIllegalArgument() throws Exception {
587+
exceptionRule.expect(IllegalArgumentException.class);
588+
exceptionRule.expectMessage("llm_model cannot be null.");
589+
590+
Client client = mock(Client.class);
591+
Map<String, Object> config = new HashMap<>();
592+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
593+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));
594+
595+
GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
596+
client,
597+
alwaysOn
598+
).create(null, "tag", "desc", true, config, null);
599+
600+
ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
601+
when(memoryClient.getInteractions(any(), anyInt()))
602+
.thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null)));
603+
processor.setMemoryClient(memoryClient);
604+
605+
SearchRequest request = new SearchRequest();
606+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
607+
int contextSize = 5;
608+
GenerativeQAParameters params = new GenerativeQAParameters("12345", null, "Question", "You are kind.", null, contextSize, null, null);
609+
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
610+
extBuilder.setParams(params);
611+
request.source(sourceBuilder);
612+
sourceBuilder.ext(List.of(extBuilder));
613+
614+
int numHits = 10;
615+
SearchHit[] hitsArray = new SearchHit[numHits];
616+
for (int i = 0; i < numHits; i++) {
617+
XContentBuilder sourceContent = JsonXContent
618+
.contentBuilder()
619+
.startObject()
620+
.field("_id", String.valueOf(i))
621+
.field("text", "passage" + i)
622+
.field("title", "This is the title for document " + i)
623+
.endObject();
624+
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
625+
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
626+
}
627+
628+
SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
629+
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
630+
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);
631+
632+
Llm llm = mock(Llm.class);
633+
// when(llm.doChatCompletion(any())).thenThrow(new NullPointerException("Null Pointer in Interactions"));
634+
processor.setLlm(llm);
635+
636+
SearchResponse res = processor.processResponse(request, response);
637+
}
638+
639+
public void testProcessResponseOpenSearchException() throws Exception {
640+
exceptionRule.expect(OpenSearchException.class);
641+
exceptionRule.expectMessage("GenerativeQAResponseProcessor failed in precessing response");
642+
643+
Client client = mock(Client.class);
644+
Map<String, Object> config = new HashMap<>();
645+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
646+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));
647+
648+
GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
649+
client,
650+
alwaysOn
651+
).create(null, "tag", "desc", true, config, null);
652+
653+
ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
654+
when(memoryClient.getInteractions(any(), anyInt()))
655+
.thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null)));
656+
processor.setMemoryClient(memoryClient);
657+
658+
SearchRequest request = new SearchRequest();
659+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
660+
int contextSize = 5;
661+
GenerativeQAParameters params = new GenerativeQAParameters("12345", "model", "Question", "You are kind.", null, contextSize, null, null);
662+
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
663+
extBuilder.setParams(params);
664+
request.source(sourceBuilder);
665+
sourceBuilder.ext(List.of(extBuilder));
666+
667+
int numHits = 10;
668+
SearchHit[] hitsArray = new SearchHit[numHits];
669+
for (int i = 0; i < numHits; i++) {
670+
XContentBuilder sourceContent = JsonXContent
671+
.contentBuilder()
672+
.startObject()
673+
.field("_id", String.valueOf(i))
674+
.field("text", "passage" + i)
675+
.field("title", "This is the title for document " + i)
676+
.endObject();
677+
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
678+
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
679+
}
680+
681+
SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
682+
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
683+
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);
684+
685+
Llm llm = mock(Llm.class);
686+
when(llm.doChatCompletion(any())).thenThrow(new RuntimeException());
687+
processor.setLlm(llm);
688+
689+
SearchResponse res = processor.processResponse(request, response);
690+
}
500691
}

0 commit comments

Comments
 (0)