37
37
38
38
import com .google .common .collect .ImmutableList ;
39
39
import com .google .common .collect .ImmutableMap ;
40
+ import org .opensearch .searchpipelines .questionanswering .generative .llm .LlmIOUtil ;
40
41
41
42
public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
42
43
@@ -147,6 +148,32 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
147
148
private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
148
149
? BEDROCK_CONNECTOR_BLUEPRINT2
149
150
: BEDROCK_CONNECTOR_BLUEPRINT1 ;
151
+
152
+ private static final String COHERE_API_KEY = System .getenv ("COHERE_API_KEY" );
153
+ private static final String COHERE_CONNECTOR_BLUEPRINT = "{\n "
154
+ + " \" name\" : \" Cohere Chat Model\" ,\n "
155
+ + " \" description\" : \" The connector to Cohere's public chat API\" ,\n "
156
+ + " \" version\" : \" 1\" ,\n "
157
+ + " \" protocol\" : \" http\" ,\n "
158
+ + " \" credential\" : {\n "
159
+ + " \" cohere_key\" : \" " + COHERE_API_KEY + "\" \n "
160
+ + " },\n "
161
+ + " \" parameters\" : {\n "
162
+ + " \" model\" : \" command\" \n "
163
+ + " },\n "
164
+ + " \" actions\" : [\n "
165
+ + " {\n "
166
+ + " \" action_type\" : \" predict\" ,\n "
167
+ + " \" method\" : \" POST\" ,\n "
168
+ + " \" url\" : \" https://api.cohere.ai/v1/chat\" ,\n "
169
+ + " \" headers\" : {\n "
170
+ + " \" Authorization\" : \" Bearer ${credential.cohere_key}\" ,\n "
171
+ + " \" Request-Source\" : \" unspecified:opensearch\" \n "
172
+ + " },\n "
173
+ + " \" request_body\" : \" { \\ \" message\\ \" : \\ \" ${parameters.inputs}\\ \" , \\ \" model\\ \" : \\ \" ${parameters.model}\\ \" }\" \n "
174
+ // + " \"post_process_function\": \"\\n String escape(def input) { \\n if (input.contains(\\\"\\\\\\\\\\\")) {\\n input = input.replace(\\\"\\\\\\\\\\\", \\\"\\\\\\\\\\\\\\\\\\\");\\n }\\n if (input.contains(\\\"\\\\\\\"\\\")) {\\n input = input.replace(\\\"\\\\\\\"\\\", \\\"\\\\\\\\\\\\\\\"\\\");\\n }\\n if (input.contains('\\r')) {\\n input = input = input.replace('\\r', '\\\\\\\\r');\\n }\\n if (input.contains(\\\"\\\\\\\\t\\\")) {\\n input = input.replace(\\\"\\\\\\\\t\\\", \\\"\\\\\\\\\\\\\\\\\\\\\\\\t\\\");\\n }\\n if (input.contains('\\n')) {\\n input = input.replace('\\n', '\\\\\\\\n');\\n }\\n if (input.contains('\\b')) {\\n input = input.replace('\\b', '\\\\\\\\b');\\n }\\n if (input.contains('\\f')) {\\n input = input.replace('\\f', '\\\\\\\\f');\\n }\\n return input;\\n }\\n def name = 'response';\\n def result = params.text;\\n def json = '{ \\\"name\\\": \\\"' + name + '\\\",' +\\n '\\\"dataAsMap\\\": { \\\"completion\\\": \\\"' + escape(result) +\\n '\\\"}}';\\n return json;\\n \\n \"\n"
175
+ + " }\n " + " ]\n " + "}" ;
176
+
150
177
private static final String PIPELINE_TEMPLATE = "{\n "
151
178
+ " \" response_processors\" : [\n "
152
179
+ " {\n "
@@ -195,6 +222,23 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
195
222
+ " }\n "
196
223
+ "}" ;
197
224
225
+ private static final String BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE = "{\n "
226
+ + " \" _source\" : [\" %s\" ],\n "
227
+ + " \" query\" : {\n "
228
+ + " \" match\" : {\" %s\" : \" %s\" }\n "
229
+ + " },\n "
230
+ + " \" ext\" : {\n "
231
+ + " \" generative_qa_parameters\" : {\n "
232
+ + " \" llm_model\" : \" %s\" ,\n "
233
+ + " \" llm_question\" : \" %s\" ,\n "
234
+ + " \" context_size\" : %d,\n "
235
+ + " \" message_size\" : %d,\n "
236
+ + " \" timeout\" : %d,\n "
237
+ + " \" llm_response_field\" : \" %s\" \n "
238
+ + " }\n "
239
+ + " }\n "
240
+ + "}" ;
241
+
198
242
private static final String OPENAI_MODEL = "gpt-3.5-turbo" ;
199
243
private static final String BEDROCK_ANTHROPIC_CLAUDE = "bedrock/anthropic-claude" ;
200
244
private static final String TEST_DOC_PATH = "org/opensearch/ml/rest/test_data/" ;
@@ -466,6 +510,111 @@ public void testBM25WithBedrockWithConversation() throws Exception {
466
510
assertNotNull (interactionId );
467
511
}
468
512
513
+ public void testBM25WithCohere () throws Exception {
514
+ // Skip test if key is null
515
+ if (COHERE_API_KEY == null ) {
516
+ return ;
517
+ }
518
+ Response response = createConnector (COHERE_CONNECTOR_BLUEPRINT );
519
+ Map responseMap = parseResponseToMap (response );
520
+ String connectorId = (String ) responseMap .get ("connector_id" );
521
+ response = registerRemoteModel ("Cohere Chat Completion v1" , connectorId );
522
+ responseMap = parseResponseToMap (response );
523
+ String taskId = (String ) responseMap .get ("task_id" );
524
+ waitForTask (taskId , MLTaskState .COMPLETED );
525
+ response = getTask (taskId );
526
+ responseMap = parseResponseToMap (response );
527
+ String modelId = (String ) responseMap .get ("model_id" );
528
+ response = deployRemoteModel (modelId );
529
+ responseMap = parseResponseToMap (response );
530
+ taskId = (String ) responseMap .get ("task_id" );
531
+ waitForTask (taskId , MLTaskState .COMPLETED );
532
+
533
+ PipelineParameters pipelineParameters = new PipelineParameters ();
534
+ pipelineParameters .tag = "testBM25WithCohere" ;
535
+ pipelineParameters .description = "desc" ;
536
+ pipelineParameters .modelId = modelId ;
537
+ pipelineParameters .systemPrompt = "You are a helpful assistant" ;
538
+ pipelineParameters .userInstructions = "none" ;
539
+ pipelineParameters .context_field = "text" ;
540
+ Response response1 = createSearchPipeline ("pipeline_test" , pipelineParameters );
541
+ assertEquals (200 , response1 .getStatusLine ().getStatusCode ());
542
+
543
+ SearchRequestParameters requestParameters = new SearchRequestParameters ();
544
+ requestParameters .source = "text" ;
545
+ requestParameters .match = "president" ;
546
+ requestParameters .llmModel = LlmIOUtil .COHERE_PROVIDER_PREFIX + "command" ;
547
+ requestParameters .llmQuestion = "who is lincoln" ;
548
+ requestParameters .contextSize = 5 ;
549
+ requestParameters .interactionSize = 5 ;
550
+ requestParameters .timeout = 60 ;
551
+ Response response2 = performSearch (INDEX_NAME , "pipeline_test" , 5 , requestParameters );
552
+ assertEquals (200 , response2 .getStatusLine ().getStatusCode ());
553
+
554
+ Map responseMap2 = parseResponseToMap (response2 );
555
+ Map ext = (Map ) responseMap2 .get ("ext" );
556
+ assertNotNull (ext );
557
+ Map rag = (Map ) ext .get ("retrieval_augmented_generation" );
558
+ assertNotNull (rag );
559
+
560
+ // TODO handle errors such as throttling
561
+ String answer = (String ) rag .get ("answer" );
562
+ assertNotNull (answer );
563
+ }
564
+
565
+ public void testBM25WithCohereUsingLlmResponseField () throws Exception {
566
+ // Skip test if key is null
567
+ if (COHERE_API_KEY == null ) {
568
+ return ;
569
+ }
570
+ Response response = createConnector (COHERE_CONNECTOR_BLUEPRINT );
571
+ Map responseMap = parseResponseToMap (response );
572
+ String connectorId = (String ) responseMap .get ("connector_id" );
573
+ response = registerRemoteModel ("Cohere Chat Completion v1" , connectorId );
574
+ responseMap = parseResponseToMap (response );
575
+ String taskId = (String ) responseMap .get ("task_id" );
576
+ waitForTask (taskId , MLTaskState .COMPLETED );
577
+ response = getTask (taskId );
578
+ responseMap = parseResponseToMap (response );
579
+ String modelId = (String ) responseMap .get ("model_id" );
580
+ response = deployRemoteModel (modelId );
581
+ responseMap = parseResponseToMap (response );
582
+ taskId = (String ) responseMap .get ("task_id" );
583
+ waitForTask (taskId , MLTaskState .COMPLETED );
584
+
585
+ PipelineParameters pipelineParameters = new PipelineParameters ();
586
+ pipelineParameters .tag = "testBM25WithCohereLlmResponseField" ;
587
+ pipelineParameters .description = "desc" ;
588
+ pipelineParameters .modelId = modelId ;
589
+ pipelineParameters .systemPrompt = "You are a helpful assistant" ;
590
+ pipelineParameters .userInstructions = "none" ;
591
+ pipelineParameters .context_field = "text" ;
592
+ Response response1 = createSearchPipeline ("pipeline_test" , pipelineParameters );
593
+ assertEquals (200 , response1 .getStatusLine ().getStatusCode ());
594
+
595
+ SearchRequestParameters requestParameters = new SearchRequestParameters ();
596
+ requestParameters .source = "text" ;
597
+ requestParameters .match = "president" ;
598
+ requestParameters .llmModel = "command" ;
599
+ requestParameters .llmQuestion = "who is lincoln" ;
600
+ requestParameters .contextSize = 5 ;
601
+ requestParameters .interactionSize = 5 ;
602
+ requestParameters .timeout = 60 ;
603
+ requestParameters .llmResponseField = "text" ;
604
+ Response response2 = performSearch (INDEX_NAME , "pipeline_test" , 5 , requestParameters );
605
+ assertEquals (200 , response2 .getStatusLine ().getStatusCode ());
606
+
607
+ Map responseMap2 = parseResponseToMap (response2 );
608
+ Map ext = (Map ) responseMap2 .get ("ext" );
609
+ assertNotNull (ext );
610
+ Map rag = (Map ) ext .get ("retrieval_augmented_generation" );
611
+ assertNotNull (rag );
612
+
613
+ // TODO handle errors such as throttling
614
+ String answer = (String ) rag .get ("answer" );
615
+ assertNotNull (answer );
616
+ }
617
+
469
618
private Response createSearchPipeline (String pipeline , PipelineParameters parameters ) throws Exception {
470
619
return makeRequest (
471
620
client (),
@@ -492,7 +641,24 @@ private Response createSearchPipeline(String pipeline, PipelineParameters parame
492
641
private Response performSearch (String indexName , String pipeline , int size , SearchRequestParameters requestParameters )
493
642
throws Exception {
494
643
495
- String httpEntity = (requestParameters .conversationId == null )
644
+ String httpEntity =
645
+ requestParameters .llmResponseField != null ?
646
+ String
647
+ .format (
648
+ Locale .ROOT ,
649
+ BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE ,
650
+ requestParameters .source ,
651
+ requestParameters .source ,
652
+ requestParameters .match ,
653
+ requestParameters .llmModel ,
654
+ requestParameters .llmQuestion ,
655
+ requestParameters .contextSize ,
656
+ requestParameters .interactionSize ,
657
+ requestParameters .timeout ,
658
+ requestParameters .llmResponseField
659
+ )
660
+ :
661
+ (requestParameters .conversationId == null )
496
662
? String
497
663
.format (
498
664
Locale .ROOT ,
@@ -560,5 +726,7 @@ static class SearchRequestParameters {
560
726
int interactionSize ;
561
727
int timeout ;
562
728
String conversationId ;
729
+
730
+ String llmResponseField ;
563
731
}
564
732
}
0 commit comments