Skip to content

Commit 030d187

Browse files
committed
Add support for Cohere and others.
Signed-off-by: Austin Lee <[email protected]>
1 parent d645c83 commit 030d187

File tree

13 files changed

+284
-48
lines changed

13 files changed

+284
-48
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+169-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import com.google.common.collect.ImmutableList;
3939
import com.google.common.collect.ImmutableMap;
40+
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
4041

4142
public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
4243

@@ -147,6 +148,32 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
147148
private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
148149
? BEDROCK_CONNECTOR_BLUEPRINT2
149150
: BEDROCK_CONNECTOR_BLUEPRINT1;
151+
152+
private static final String COHERE_API_KEY = System.getenv("COHERE_API_KEY");
153+
private static final String COHERE_CONNECTOR_BLUEPRINT = "{\n"
154+
+ " \"name\": \"Cohere Chat Model\",\n"
155+
+ " \"description\": \"The connector to Cohere's public chat API\",\n"
156+
+ " \"version\": \"1\",\n"
157+
+ " \"protocol\": \"http\",\n"
158+
+ " \"credential\": {\n"
159+
+ " \"cohere_key\": \"" + COHERE_API_KEY + "\"\n"
160+
+ " },\n"
161+
+ " \"parameters\": {\n"
162+
+ " \"model\": \"command\"\n"
163+
+ " },\n"
164+
+ " \"actions\": [\n"
165+
+ " {\n"
166+
+ " \"action_type\": \"predict\",\n"
167+
+ " \"method\": \"POST\",\n"
168+
+ " \"url\": \"https://api.cohere.ai/v1/chat\",\n"
169+
+ " \"headers\": {\n"
170+
+ " \"Authorization\": \"Bearer ${credential.cohere_key}\",\n"
171+
+ " \"Request-Source\": \"unspecified:opensearch\"\n"
172+
+ " },\n"
173+
+ " \"request_body\": \"{ \\\"message\\\": \\\"${parameters.inputs}\\\", \\\"model\\\": \\\"${parameters.model}\\\" }\" \n"
174+
// + " \"post_process_function\": \"\\n String escape(def input) { \\n if (input.contains(\\\"\\\\\\\\\\\")) {\\n input = input.replace(\\\"\\\\\\\\\\\", \\\"\\\\\\\\\\\\\\\\\\\");\\n }\\n if (input.contains(\\\"\\\\\\\"\\\")) {\\n input = input.replace(\\\"\\\\\\\"\\\", \\\"\\\\\\\\\\\\\\\"\\\");\\n }\\n if (input.contains('\\r')) {\\n input = input = input.replace('\\r', '\\\\\\\\r');\\n }\\n if (input.contains(\\\"\\\\\\\\t\\\")) {\\n input = input.replace(\\\"\\\\\\\\t\\\", \\\"\\\\\\\\\\\\\\\\\\\\\\\\t\\\");\\n }\\n if (input.contains('\\n')) {\\n input = input.replace('\\n', '\\\\\\\\n');\\n }\\n if (input.contains('\\b')) {\\n input = input.replace('\\b', '\\\\\\\\b');\\n }\\n if (input.contains('\\f')) {\\n input = input.replace('\\f', '\\\\\\\\f');\\n }\\n return input;\\n }\\n def name = 'response';\\n def result = params.text;\\n def json = '{ \\\"name\\\": \\\"' + name + '\\\",' +\\n '\\\"dataAsMap\\\": { \\\"completion\\\": \\\"' + escape(result) +\\n '\\\"}}';\\n return json;\\n \\n \"\n"
175+
+ " }\n" + " ]\n" + "}";
176+
150177
private static final String PIPELINE_TEMPLATE = "{\n"
151178
+ " \"response_processors\": [\n"
152179
+ " {\n"
@@ -195,6 +222,23 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
195222
+ " }\n"
196223
+ "}";
197224

225+
private static final String BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE = "{\n"
226+
+ " \"_source\": [\"%s\"],\n"
227+
+ " \"query\" : {\n"
228+
+ " \"match\": {\"%s\": \"%s\"}\n"
229+
+ " },\n"
230+
+ " \"ext\": {\n"
231+
+ " \"generative_qa_parameters\": {\n"
232+
+ " \"llm_model\": \"%s\",\n"
233+
+ " \"llm_question\": \"%s\",\n"
234+
+ " \"context_size\": %d,\n"
235+
+ " \"message_size\": %d,\n"
236+
+ " \"timeout\": %d,\n"
237+
+ " \"llm_response_field\": \"%s\"\n"
238+
+ " }\n"
239+
+ " }\n"
240+
+ "}";
241+
198242
private static final String OPENAI_MODEL = "gpt-3.5-turbo";
199243
private static final String BEDROCK_ANTHROPIC_CLAUDE = "bedrock/anthropic-claude";
200244
private static final String TEST_DOC_PATH = "org/opensearch/ml/rest/test_data/";
@@ -466,6 +510,111 @@ public void testBM25WithBedrockWithConversation() throws Exception {
466510
assertNotNull(interactionId);
467511
}
468512

513+
public void testBM25WithCohere() throws Exception {
514+
// Skip test if key is null
515+
if (COHERE_API_KEY == null) {
516+
return;
517+
}
518+
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
519+
Map responseMap = parseResponseToMap(response);
520+
String connectorId = (String) responseMap.get("connector_id");
521+
response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
522+
responseMap = parseResponseToMap(response);
523+
String taskId = (String) responseMap.get("task_id");
524+
waitForTask(taskId, MLTaskState.COMPLETED);
525+
response = getTask(taskId);
526+
responseMap = parseResponseToMap(response);
527+
String modelId = (String) responseMap.get("model_id");
528+
response = deployRemoteModel(modelId);
529+
responseMap = parseResponseToMap(response);
530+
taskId = (String) responseMap.get("task_id");
531+
waitForTask(taskId, MLTaskState.COMPLETED);
532+
533+
PipelineParameters pipelineParameters = new PipelineParameters();
534+
pipelineParameters.tag = "testBM25WithCohere";
535+
pipelineParameters.description = "desc";
536+
pipelineParameters.modelId = modelId;
537+
pipelineParameters.systemPrompt = "You are a helpful assistant";
538+
pipelineParameters.userInstructions = "none";
539+
pipelineParameters.context_field = "text";
540+
Response response1 = createSearchPipeline("pipeline_test", pipelineParameters);
541+
assertEquals(200, response1.getStatusLine().getStatusCode());
542+
543+
SearchRequestParameters requestParameters = new SearchRequestParameters();
544+
requestParameters.source = "text";
545+
requestParameters.match = "president";
546+
requestParameters.llmModel = LlmIOUtil.COHERE_PROVIDER_PREFIX + "command";
547+
requestParameters.llmQuestion = "who is lincoln";
548+
requestParameters.contextSize = 5;
549+
requestParameters.interactionSize = 5;
550+
requestParameters.timeout = 60;
551+
Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
552+
assertEquals(200, response2.getStatusLine().getStatusCode());
553+
554+
Map responseMap2 = parseResponseToMap(response2);
555+
Map ext = (Map) responseMap2.get("ext");
556+
assertNotNull(ext);
557+
Map rag = (Map) ext.get("retrieval_augmented_generation");
558+
assertNotNull(rag);
559+
560+
// TODO handle errors such as throttling
561+
String answer = (String) rag.get("answer");
562+
assertNotNull(answer);
563+
}
564+
565+
public void testBM25WithCohereUsingLlmResponseField() throws Exception {
566+
// Skip test if key is null
567+
if (COHERE_API_KEY == null) {
568+
return;
569+
}
570+
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
571+
Map responseMap = parseResponseToMap(response);
572+
String connectorId = (String) responseMap.get("connector_id");
573+
response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
574+
responseMap = parseResponseToMap(response);
575+
String taskId = (String) responseMap.get("task_id");
576+
waitForTask(taskId, MLTaskState.COMPLETED);
577+
response = getTask(taskId);
578+
responseMap = parseResponseToMap(response);
579+
String modelId = (String) responseMap.get("model_id");
580+
response = deployRemoteModel(modelId);
581+
responseMap = parseResponseToMap(response);
582+
taskId = (String) responseMap.get("task_id");
583+
waitForTask(taskId, MLTaskState.COMPLETED);
584+
585+
PipelineParameters pipelineParameters = new PipelineParameters();
586+
pipelineParameters.tag = "testBM25WithCohereLlmResponseField";
587+
pipelineParameters.description = "desc";
588+
pipelineParameters.modelId = modelId;
589+
pipelineParameters.systemPrompt = "You are a helpful assistant";
590+
pipelineParameters.userInstructions = "none";
591+
pipelineParameters.context_field = "text";
592+
Response response1 = createSearchPipeline("pipeline_test", pipelineParameters);
593+
assertEquals(200, response1.getStatusLine().getStatusCode());
594+
595+
SearchRequestParameters requestParameters = new SearchRequestParameters();
596+
requestParameters.source = "text";
597+
requestParameters.match = "president";
598+
requestParameters.llmModel = "command";
599+
requestParameters.llmQuestion = "who is lincoln";
600+
requestParameters.contextSize = 5;
601+
requestParameters.interactionSize = 5;
602+
requestParameters.timeout = 60;
603+
requestParameters.llmResponseField = "text";
604+
Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
605+
assertEquals(200, response2.getStatusLine().getStatusCode());
606+
607+
Map responseMap2 = parseResponseToMap(response2);
608+
Map ext = (Map) responseMap2.get("ext");
609+
assertNotNull(ext);
610+
Map rag = (Map) ext.get("retrieval_augmented_generation");
611+
assertNotNull(rag);
612+
613+
// TODO handle errors such as throttling
614+
String answer = (String) rag.get("answer");
615+
assertNotNull(answer);
616+
}
617+
469618
private Response createSearchPipeline(String pipeline, PipelineParameters parameters) throws Exception {
470619
return makeRequest(
471620
client(),
@@ -492,7 +641,24 @@ private Response createSearchPipeline(String pipeline, PipelineParameters parame
492641
private Response performSearch(String indexName, String pipeline, int size, SearchRequestParameters requestParameters)
493642
throws Exception {
494643

495-
String httpEntity = (requestParameters.conversationId == null)
644+
String httpEntity =
645+
requestParameters.llmResponseField != null ?
646+
String
647+
.format(
648+
Locale.ROOT,
649+
BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE,
650+
requestParameters.source,
651+
requestParameters.source,
652+
requestParameters.match,
653+
requestParameters.llmModel,
654+
requestParameters.llmQuestion,
655+
requestParameters.contextSize,
656+
requestParameters.interactionSize,
657+
requestParameters.timeout,
658+
requestParameters.llmResponseField
659+
)
660+
:
661+
(requestParameters.conversationId == null)
496662
? String
497663
.format(
498664
Locale.ROOT,
@@ -560,5 +726,7 @@ static class SearchRequestParameters {
560726
int interactionSize;
561727
int timeout;
562728
String conversationId;
729+
730+
String llmResponseField;
563731
}
564732
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
143143
}
144144
List<String> searchResults = getSearchResults(response, topN);
145145

146+
log.info("RAG request params: [{}]", params.getLlmResponseField());
146147
start = Instant.now();
147148
try {
148149
ChatCompletionOutput output = llm
@@ -155,7 +156,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
155156
llmQuestion,
156157
chatHistory,
157158
searchResults,
158-
timeout
159+
timeout, params.getLlmResponseField()
159160
)
160161
);
161162
log.info("doChatCompletion complete. ({})", getDuration(start));

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+18-3
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
7070
// from a remote inference endpoint before timing out the request.
7171
private static final ParseField TIMEOUT = new ParseField("timeout");
7272

73+
// Optional parameter; this parameter indicates the name of the field in the LLM response
74+
// that contains the chat completion text, e.g. "answer".
75+
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
76+
7377
public static final int SIZE_NULL_VALUE = -1;
7478

7579
static {
@@ -80,6 +84,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
8084
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
8185
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
8286
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
87+
PARSER.declareString(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
8388
}
8489

8590
@Setter
@@ -106,13 +111,18 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
106111
@Getter
107112
private Integer timeout;
108113

114+
@Setter
115+
@Getter
116+
private String llmResponseField;
117+
109118
public GenerativeQAParameters(
110119
String conversationId,
111120
String llmModel,
112121
String llmQuestion,
113122
Integer contextSize,
114123
Integer interactionSize,
115-
Integer timeout
124+
Integer timeout,
125+
String llmResponseField
116126
) {
117127
this.conversationId = conversationId;
118128
this.llmModel = llmModel;
@@ -124,6 +134,7 @@ public GenerativeQAParameters(
124134
this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize;
125135
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
126136
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
137+
this.llmResponseField = llmResponseField;
127138
}
128139

129140
public GenerativeQAParameters(StreamInput input) throws IOException {
@@ -133,6 +144,7 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
133144
this.contextSize = input.readInt();
134145
this.interactionSize = input.readInt();
135146
this.timeout = input.readInt();
147+
this.llmResponseField = input.readOptionalString();
136148
}
137149

138150
@Override
@@ -143,7 +155,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
143155
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
144156
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
145157
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
146-
.field(TIMEOUT.getPreferredName(), this.timeout);
158+
.field(TIMEOUT.getPreferredName(), this.timeout)
159+
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
147160
}
148161

149162
@Override
@@ -156,6 +169,7 @@ public void writeTo(StreamOutput out) throws IOException {
156169
out.writeInt(contextSize);
157170
out.writeInt(interactionSize);
158171
out.writeInt(timeout);
172+
out.writeOptionalString(llmResponseField);
159173
}
160174

161175
public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
@@ -177,6 +191,7 @@ public boolean equals(Object o) {
177191
&& Objects.equals(this.llmQuestion, other.getLlmQuestion())
178192
&& (this.contextSize == other.getContextSize())
179193
&& (this.interactionSize == other.getInteractionSize())
180-
&& (this.timeout == other.getTimeout());
194+
&& (this.timeout == other.getTimeout())
195+
&& Objects.equals(this.llmResponseField, other.getLlmResponseField());
181196
}
182197
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java

+1
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@ public class ChatCompletionInput {
4343
private String systemPrompt;
4444
private String userInstructions;
4545
private Llm.ModelProvider modelProvider;
46+
private String llmResponseField;
4647
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

+31-4
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI
8787

8888
// TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.
8989

90-
return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap);
90+
return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap, chatCompletionInput.getLlmResponseField());
9191
}
9292

9393
protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
@@ -105,7 +105,9 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
105105
);
106106
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
107107
// log.info("Messages to LLM: {}", messages);
108-
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) {
108+
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
109+
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE
110+
|| chatCompletionInput.getLlmResponseField() != null) {
109111
inputParameters
110112
.put(
111113
"inputs",
@@ -126,12 +128,24 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
126128
return inputParameters;
127129
}
128130

129-
protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap) {
131+
protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap, String llmResponseField) {
130132

131133
List<Object> answers = null;
132134
List<String> errors = null;
133135

134-
if (provider == ModelProvider.OPENAI) {
136+
if (llmResponseField != null) {
137+
String response = (String) dataAsMap.get(llmResponseField);
138+
if (response != null) {
139+
answers = List.of(response);
140+
} else {
141+
Map error = (Map) dataAsMap.get("error");
142+
if (error != null) {
143+
errors = List.of((String) error.get("message"));
144+
} else {
145+
errors = List.of("Unknown error or response.");
146+
}
147+
}
148+
} else if (provider == ModelProvider.OPENAI) {
135149
List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
136150
if (choices == null) {
137151
Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
@@ -161,6 +175,19 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
161175
errors = List.of("Unknown error or response.");
162176
}
163177
}
178+
} else if (provider == ModelProvider.COHERE) {
179+
String response = (String) dataAsMap.get("text");
180+
if (response != null) {
181+
answers = List.of(response);
182+
} else {
183+
Map error = (Map) dataAsMap.get("error");
184+
if (error != null) {
185+
errors = List.of((String) error.get("message"));
186+
} else {
187+
errors = List.of("Unknown error or response.");
188+
log.error("{}", dataAsMap);
189+
}
190+
}
164191
} else {
165192
throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
166193
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ public interface Llm {
2525
// TODO Ensure the current implementation works with all models supported by Bedrock.
2626
/** Identifies which LLM provider's request/response conventions apply to a chat completion. */
enum ModelProvider {
    /** OpenAI chat models. */
    OPENAI,
    /** Models served through Amazon Bedrock. */
    BEDROCK,
    /** Cohere chat models. */
    COHERE
}
3031

3132
ChatCompletionOutput doChatCompletion(ChatCompletionInput input);

0 commit comments

Comments
 (0)