Skip to content

Commit 65f59d4

Browse files
committed
Use suite specific model group name.
Signed-off-by: Austin Lee <[email protected]>
1 parent 06b74cb commit 65f59d4

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+22-20
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,8 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
475475
protected ClassLoader classLoader = RestMLRAGSearchProcessorIT.class.getClassLoader();
476476
private static final String INDEX_NAME = "test";
477477

478+
private static final String ML_RAG_REMOTE_MODEL_GROUP = "rag_remote_model_group";
479+
478480
// "client" gets initialized by the test framework at the instance level
479481
// so we perform this per test case, not via @BeforeClass.
480482
@Before
@@ -528,7 +530,7 @@ public void testBM25WithOpenAI() throws Exception {
528530
Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
529531
Map responseMap = parseResponseToMap(response);
530532
String connectorId = (String) responseMap.get("connector_id");
531-
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
533+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-3.5 completions", connectorId);
532534
responseMap = parseResponseToMap(response);
533535
String taskId = (String) responseMap.get("task_id");
534536
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -582,7 +584,7 @@ public void testBM25WithOpenAIWithImage() throws Exception {
582584
Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
583585
Map responseMap = parseResponseToMap(response);
584586
String connectorId = (String) responseMap.get("connector_id");
585-
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4o-mini completions", connectorId);
587+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-4o-mini completions", connectorId);
586588
responseMap = parseResponseToMap(response);
587589
String taskId = (String) responseMap.get("task_id");
588590
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -595,7 +597,7 @@ public void testBM25WithOpenAIWithImage() throws Exception {
595597
waitForTask(taskId, MLTaskState.COMPLETED);
596598

597599
PipelineParameters pipelineParameters = new PipelineParameters();
598-
pipelineParameters.tag = "testBM25WithOpenAI";
600+
pipelineParameters.tag = "testBM25WithOpenAIWithImage";
599601
pipelineParameters.description = "desc";
600602
pipelineParameters.modelId = modelId;
601603
pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -669,7 +671,7 @@ public void testBM25WithBedrock() throws Exception {
669671
Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
670672
Map responseMap = parseResponseToMap(response);
671673
String connectorId = (String) responseMap.get("connector_id");
672-
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
674+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId);
673675
responseMap = parseResponseToMap(response);
674676
String taskId = (String) responseMap.get("task_id");
675677
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -682,7 +684,7 @@ public void testBM25WithBedrock() throws Exception {
682684
waitForTask(taskId, MLTaskState.COMPLETED);
683685

684686
PipelineParameters pipelineParameters = new PipelineParameters();
685-
pipelineParameters.tag = "testBM25WithOpenAI";
687+
pipelineParameters.tag = "testBM25WithBedrock";
686688
pipelineParameters.description = "desc";
687689
pipelineParameters.modelId = modelId;
688690
pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -721,7 +723,7 @@ public void testBM25WithBedrockConverse() throws Exception {
721723
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT);
722724
Map responseMap = parseResponseToMap(response);
723725
String connectorId = (String) responseMap.get("connector_id");
724-
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
726+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId);
725727
responseMap = parseResponseToMap(response);
726728
String taskId = (String) responseMap.get("task_id");
727729
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -734,7 +736,7 @@ public void testBM25WithBedrockConverse() throws Exception {
734736
waitForTask(taskId, MLTaskState.COMPLETED);
735737

736738
PipelineParameters pipelineParameters = new PipelineParameters();
737-
pipelineParameters.tag = "testBM25WithOpenAI";
739+
pipelineParameters.tag = "testBM25WithBedrockConverse";
738740
pipelineParameters.description = "desc";
739741
pipelineParameters.modelId = modelId;
740742
pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -773,7 +775,7 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
773775
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2);
774776
Map responseMap = parseResponseToMap(response);
775777
String connectorId = (String) responseMap.get("connector_id");
776-
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
778+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId);
777779
responseMap = parseResponseToMap(response);
778780
String taskId = (String) responseMap.get("task_id");
779781
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -786,7 +788,7 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
786788
waitForTask(taskId, MLTaskState.COMPLETED);
787789

788790
PipelineParameters pipelineParameters = new PipelineParameters();
789-
pipelineParameters.tag = "testBM25WithOpenAI";
791+
pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessages";
790792
pipelineParameters.description = "desc";
791793
pipelineParameters.modelId = modelId;
792794
pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -833,7 +835,7 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
833835
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
834836
Map responseMap = parseResponseToMap(response);
835837
String connectorId = (String) responseMap.get("connector_id");
836-
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
838+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId);
837839
responseMap = parseResponseToMap(response);
838840
String taskId = (String) responseMap.get("task_id");
839841
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -846,7 +848,7 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
846848
waitForTask(taskId, MLTaskState.COMPLETED);
847849

848850
PipelineParameters pipelineParameters = new PipelineParameters();
849-
pipelineParameters.tag = "testBM25WithOpenAI";
851+
pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat";
850852
pipelineParameters.description = "desc";
851853
pipelineParameters.modelId = modelId;
852854
// pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -892,7 +894,7 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
892894
Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
893895
Map responseMap = parseResponseToMap(response);
894896
String connectorId = (String) responseMap.get("connector_id");
895-
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
897+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-3.5 completions", connectorId);
896898
responseMap = parseResponseToMap(response);
897899
String taskId = (String) responseMap.get("task_id");
898900
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -905,7 +907,7 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
905907
waitForTask(taskId, MLTaskState.COMPLETED);
906908

907909
PipelineParameters pipelineParameters = new PipelineParameters();
908-
pipelineParameters.tag = "testBM25WithOpenAI";
910+
pipelineParameters.tag = "testBM25WithOpenAIWithConversation";
909911
pipelineParameters.description = "desc";
910912
pipelineParameters.modelId = modelId;
911913
pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -949,7 +951,7 @@ public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
949951
Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
950952
Map responseMap = parseResponseToMap(response);
951953
String connectorId = (String) responseMap.get("connector_id");
952-
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4 completions", connectorId);
954+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-4 completions", connectorId);
953955
responseMap = parseResponseToMap(response);
954956
String taskId = (String) responseMap.get("task_id");
955957
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -962,7 +964,7 @@ public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
962964
waitForTask(taskId, MLTaskState.COMPLETED);
963965

964966
PipelineParameters pipelineParameters = new PipelineParameters();
965-
pipelineParameters.tag = "testBM25WithOpenAI";
967+
pipelineParameters.tag = "testBM25WithOpenAIWithConversationAndImage";
966968
pipelineParameters.description = "desc";
967969
pipelineParameters.modelId = modelId;
968970
pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -1010,7 +1012,7 @@ public void testBM25WithBedrockWithConversation() throws Exception {
10101012
Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
10111013
Map responseMap = parseResponseToMap(response);
10121014
String connectorId = (String) responseMap.get("connector_id");
1013-
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock", connectorId);
1015+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock", connectorId);
10141016
responseMap = parseResponseToMap(response);
10151017
String taskId = (String) responseMap.get("task_id");
10161018
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -1023,7 +1025,7 @@ public void testBM25WithBedrockWithConversation() throws Exception {
10231025
waitForTask(taskId, MLTaskState.COMPLETED);
10241026

10251027
PipelineParameters pipelineParameters = new PipelineParameters();
1026-
pipelineParameters.tag = "testBM25WithBedrock";
1028+
pipelineParameters.tag = "testBM25WithBedrockWithConversation";
10271029
pipelineParameters.description = "desc";
10281030
pipelineParameters.modelId = modelId;
10291031
pipelineParameters.systemPrompt = "You are a helpful assistant";
@@ -1067,7 +1069,7 @@ public void testBM25WithCohere() throws Exception {
10671069
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
10681070
Map responseMap = parseResponseToMap(response);
10691071
String connectorId = (String) responseMap.get("connector_id");
1070-
response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId);
1072+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Cohere Chat Completion v1", connectorId);
10711073
responseMap = parseResponseToMap(response);
10721074
String taskId = (String) responseMap.get("task_id");
10731075
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -1119,7 +1121,7 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception {
11191121
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
11201122
Map responseMap = parseResponseToMap(response);
11211123
String connectorId = (String) responseMap.get("connector_id");
1122-
response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId);
1124+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Cohere Chat Completion v1", connectorId);
11231125
responseMap = parseResponseToMap(response);
11241126
String taskId = (String) responseMap.get("task_id");
11251127
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -1132,7 +1134,7 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception {
11321134
waitForTask(taskId, MLTaskState.COMPLETED);
11331135

11341136
PipelineParameters pipelineParameters = new PipelineParameters();
1135-
pipelineParameters.tag = "testBM25WithCohereLlmResponseField";
1137+
pipelineParameters.tag = "testBM25WithCohereUsingLlmResponseField";
11361138
pipelineParameters.description = "desc";
11371139
pipelineParameters.modelId = modelId;
11381140
pipelineParameters.systemPrompt = "You are a helpful assistant";

0 commit comments

Comments
 (0)