Skip to content

Commit 266399c

Browse files
committed
Apply review comments, add more tests, simplify code.
Signed-off-by: Austin Lee <[email protected]>
1 parent 0dfbcc8 commit 266399c

File tree

4 files changed

+112
-59
lines changed

4 files changed

+112
-59
lines changed

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java

+4-9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import java.util.List;
2121

22+
import org.opensearch.core.common.util.CollectionUtils;
23+
2224
import lombok.Getter;
2325
import lombok.Setter;
2426
import lombok.extern.log4j.Log4j2;
@@ -38,19 +40,12 @@ public class ChatCompletionOutput {
3840

3941
public ChatCompletionOutput(List<Object> answers, List<String> errors) {
4042

41-
if (answers == null && errors == null) {
43+
if (CollectionUtils.isEmpty(answers) && CollectionUtils.isEmpty(errors)) {
4244
throw new IllegalArgumentException("answers and errors can't both be null.");
4345
}
4446

45-
if (answers == null) {
46-
if (errors.isEmpty()) {
47-
throw new IllegalArgumentException("If answers is not provided, one or more errors must be provided.");
48-
}
47+
if (CollectionUtils.isEmpty(answers)) {
4948
this.errorOccurred = true;
50-
} else if (errors == null) {
51-
if (answers.isEmpty()) {
52-
throw new IllegalArgumentException("If errors is not provided, one or more answers must be provided.");
53-
}
5449
}
5550

5651
this.answers = answers;

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

+38-37
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import static com.google.common.base.Preconditions.checkNotNull;
2121

22+
import java.util.ArrayList;
2223
import java.util.HashMap;
2324
import java.util.List;
2425
import java.util.Map;
@@ -134,21 +135,16 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
134135

135136
protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap, String llmResponseField) {
136137

137-
List<Object> answers = null;
138-
List<String> errors = null;
138+
List<Object> answers = new ArrayList<>();
139+
List<String> errors = new ArrayList<>();
140+
141+
String answerField = null;
142+
String errorField = "error";
143+
String defaultErrorMessageField = "message";
139144

140145
if (llmResponseField != null) {
141-
String response = (String) dataAsMap.get(llmResponseField);
142-
if (response != null) {
143-
answers = List.of(response);
144-
} else {
145-
Map error = (Map) dataAsMap.get("error");
146-
if (error != null) {
147-
errors = List.of((String) error.get("message"));
148-
} else {
149-
errors = List.of("Unknown error or response.");
150-
}
151-
}
146+
answerField = llmResponseField;
147+
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
152148
} else if (provider == ModelProvider.OPENAI) {
153149
List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
154150
if (choices == null) {
@@ -168,34 +164,39 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
168164
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
169165
}
170166
} else if (provider == ModelProvider.BEDROCK) {
171-
String response = (String) dataAsMap.get("completion");
172-
if (response != null) {
173-
answers = List.of(response);
174-
} else {
175-
Map error = (Map) dataAsMap.get("error");
176-
if (error != null) {
177-
errors = List.of((String) error.get("message"));
178-
} else {
179-
errors = List.of("Unknown error or response.");
180-
}
181-
}
167+
answerField = "completion";
168+
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
182169
} else if (provider == ModelProvider.COHERE) {
183-
String response = (String) dataAsMap.get("text");
184-
if (response != null) {
185-
answers = List.of(response);
186-
} else {
187-
Map error = (Map) dataAsMap.get("error");
188-
if (error != null) {
189-
errors = List.of((String) error.get("message"));
190-
} else {
191-
errors = List.of("Unknown error or response.");
192-
log.error("{}", dataAsMap);
193-
}
194-
}
170+
answerField = "text";
171+
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
195172
} else {
196-
throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
173+
throw new IllegalArgumentException(
174+
"Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field."
175+
);
197176
}
198177

199178
return new ChatCompletionOutput(answers, errors);
200179
}
180+
181+
private void fillAnswersOrErrors(
182+
final Map<String, ?> dataAsMap,
183+
List<Object> answers,
184+
List<String> errors,
185+
String answerField,
186+
String errorField,
187+
String defaultErrorMessageField
188+
) {
189+
String response = (String) dataAsMap.get(answerField);
190+
if (response != null) {
191+
answers.add(response);
192+
} else {
193+
Map error = (Map) dataAsMap.get(errorField);
194+
if (error != null && error.get(defaultErrorMessageField) != null) {
195+
errors.add((String) error.get(defaultErrorMessageField));
196+
} else {
197+
errors.add("Unknown error or response.");
198+
log.error("{}", dataAsMap);
199+
}
200+
}
201+
}
201202
}

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java

-13
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
*/
1818
package org.opensearch.searchpipelines.questionanswering.generative.llm;
1919

20-
import java.util.ArrayList;
2120
import java.util.List;
2221

2322
import org.junit.Rule;
@@ -45,16 +44,4 @@ public void testIllegalArgument1() {
4544
exceptionRule.expectMessage("answers and errors can't both be null.");
4645
new ChatCompletionOutput(null, null);
4746
}
48-
49-
public void testIllegalArgument2() {
50-
exceptionRule.expect(IllegalArgumentException.class);
51-
exceptionRule.expectMessage("If answers is not provided, one or more errors must be provided.");
52-
new ChatCompletionOutput(null, new ArrayList<>());
53-
}
54-
55-
public void testIllegalArgument3() {
56-
exceptionRule.expect(IllegalArgumentException.class);
57-
exceptionRule.expectMessage("If errors is not provided, one or more answers must be provided.");
58-
new ChatCompletionOutput(new ArrayList<>(), null);
59-
}
6047
}

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java

+70
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,76 @@ public void testChatCompletionApiForFooWithError() throws Exception {
286286
assertEquals(errorMessage, (String) output.getErrors().get(0));
287287
}
288288

289+
public void testChatCompletionApiForFooWithErrorUnknowMessageField() throws Exception {
290+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
291+
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
292+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
293+
connector.setMlClient(mlClient);
294+
295+
String llmRespondField = UUID.randomUUID().toString();
296+
297+
String errorMessage = "throttled";
298+
Map<String, String> messageMap = Map.of("msg", errorMessage);
299+
Map<String, ?> dataAsMap = Map.of("error", messageMap);
300+
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
301+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
302+
ActionFuture<MLOutput> future = mock(ActionFuture.class);
303+
when(future.actionGet(anyLong())).thenReturn(mlOutput);
304+
when(mlClient.predict(any(), any())).thenReturn(future);
305+
ChatCompletionInput input = new ChatCompletionInput(
306+
"model_foo",
307+
"question",
308+
Collections.emptyList(),
309+
Collections.emptyList(),
310+
0,
311+
"prompt",
312+
"instructions",
313+
null,
314+
llmRespondField
315+
);
316+
ChatCompletionOutput output = connector.doChatCompletion(input);
317+
verify(mlClient, times(1)).predict(any(), captor.capture());
318+
MLInput mlInput = captor.getValue();
319+
assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet);
320+
assertTrue(output.isErrorOccurred());
321+
assertEquals("Unknown error or response.", (String) output.getErrors().get(0));
322+
}
323+
324+
public void testChatCompletionApiForFooWithErrorUnknowErrorField() throws Exception {
325+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
326+
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
327+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
328+
connector.setMlClient(mlClient);
329+
330+
String llmRespondField = UUID.randomUUID().toString();
331+
332+
String errorMessage = "throttled";
333+
Map<String, String> messageMap = Map.of("msg", errorMessage);
334+
Map<String, ?> dataAsMap = Map.of("err", messageMap);
335+
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
336+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
337+
ActionFuture<MLOutput> future = mock(ActionFuture.class);
338+
when(future.actionGet(anyLong())).thenReturn(mlOutput);
339+
when(mlClient.predict(any(), any())).thenReturn(future);
340+
ChatCompletionInput input = new ChatCompletionInput(
341+
"model_foo",
342+
"question",
343+
Collections.emptyList(),
344+
Collections.emptyList(),
345+
0,
346+
"prompt",
347+
"instructions",
348+
null,
349+
llmRespondField
350+
);
351+
ChatCompletionOutput output = connector.doChatCompletion(input);
352+
verify(mlClient, times(1)).predict(any(), captor.capture());
353+
MLInput mlInput = captor.getValue();
354+
assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet);
355+
assertTrue(output.isErrorOccurred());
356+
assertEquals("Unknown error or response.", (String) output.getErrors().get(0));
357+
}
358+
289359
public void testChatCompletionThrowingError() throws Exception {
290360
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
291361
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);

0 commit comments

Comments
 (0)