Skip to content

Commit 06b74cb

Browse files
committed
Fix MLRAGSearchProcessorIT not to extend RestMLRemoteInferenceIT.
Signed-off-by: Austin Lee <[email protected]>
1 parent e6b21da commit 06b74cb

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+25-23
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
*/
1818
package org.opensearch.ml.rest;
1919

20+
import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.createConnector;
21+
import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.deployRemoteModel;
2022
import static org.opensearch.ml.utils.TestHelper.makeRequest;
2123
import static org.opensearch.ml.utils.TestHelper.toHttpEntity;
2224

@@ -41,7 +43,7 @@
4143
import com.google.common.collect.ImmutableList;
4244
import com.google.common.collect.ImmutableMap;
4345

44-
public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
46+
public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
4547

4648
private static final String OPENAI_KEY = System.getenv("OPENAI_KEY");
4749
private static final String OPENAI_CONNECTOR_BLUEPRINT = "{\n"
@@ -526,11 +528,11 @@ public void testBM25WithOpenAI() throws Exception {
526528
Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
527529
Map responseMap = parseResponseToMap(response);
528530
String connectorId = (String) responseMap.get("connector_id");
529-
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
531+
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
530532
responseMap = parseResponseToMap(response);
531533
String taskId = (String) responseMap.get("task_id");
532534
waitForTask(taskId, MLTaskState.COMPLETED);
533-
response = getTask(taskId);
535+
response = RestMLRemoteInferenceIT.getTask(taskId);
534536
responseMap = parseResponseToMap(response);
535537
String modelId = (String) responseMap.get("model_id");
536538
response = deployRemoteModel(modelId);
@@ -580,11 +582,11 @@ public void testBM25WithOpenAIWithImage() throws Exception {
580582
Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
581583
Map responseMap = parseResponseToMap(response);
582584
String connectorId = (String) responseMap.get("connector_id");
583-
response = registerRemoteModel("openAI-GPT-4o-mini completions", connectorId);
585+
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4o-mini completions", connectorId);
584586
responseMap = parseResponseToMap(response);
585587
String taskId = (String) responseMap.get("task_id");
586588
waitForTask(taskId, MLTaskState.COMPLETED);
587-
response = getTask(taskId);
589+
response = RestMLRemoteInferenceIT.getTask(taskId);
588590
responseMap = parseResponseToMap(response);
589591
String modelId = (String) responseMap.get("model_id");
590592
response = deployRemoteModel(modelId);
@@ -667,11 +669,11 @@ public void testBM25WithBedrock() throws Exception {
667669
Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
668670
Map responseMap = parseResponseToMap(response);
669671
String connectorId = (String) responseMap.get("connector_id");
670-
response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
672+
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
671673
responseMap = parseResponseToMap(response);
672674
String taskId = (String) responseMap.get("task_id");
673675
waitForTask(taskId, MLTaskState.COMPLETED);
674-
response = getTask(taskId);
676+
response = RestMLRemoteInferenceIT.getTask(taskId);
675677
responseMap = parseResponseToMap(response);
676678
String modelId = (String) responseMap.get("model_id");
677679
response = deployRemoteModel(modelId);
@@ -719,11 +721,11 @@ public void testBM25WithBedrockConverse() throws Exception {
719721
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT);
720722
Map responseMap = parseResponseToMap(response);
721723
String connectorId = (String) responseMap.get("connector_id");
722-
response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
724+
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
723725
responseMap = parseResponseToMap(response);
724726
String taskId = (String) responseMap.get("task_id");
725727
waitForTask(taskId, MLTaskState.COMPLETED);
726-
response = getTask(taskId);
728+
response = RestMLRemoteInferenceIT.getTask(taskId);
727729
responseMap = parseResponseToMap(response);
728730
String modelId = (String) responseMap.get("model_id");
729731
response = deployRemoteModel(modelId);
@@ -771,11 +773,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
771773
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2);
772774
Map responseMap = parseResponseToMap(response);
773775
String connectorId = (String) responseMap.get("connector_id");
774-
response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
776+
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
775777
responseMap = parseResponseToMap(response);
776778
String taskId = (String) responseMap.get("task_id");
777779
waitForTask(taskId, MLTaskState.COMPLETED);
778-
response = getTask(taskId);
780+
response = RestMLRemoteInferenceIT.getTask(taskId);
779781
responseMap = parseResponseToMap(response);
780782
String modelId = (String) responseMap.get("model_id");
781783
response = deployRemoteModel(modelId);
@@ -831,11 +833,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
831833
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
832834
Map responseMap = parseResponseToMap(response);
833835
String connectorId = (String) responseMap.get("connector_id");
834-
response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
836+
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
835837
responseMap = parseResponseToMap(response);
836838
String taskId = (String) responseMap.get("task_id");
837839
waitForTask(taskId, MLTaskState.COMPLETED);
838-
response = getTask(taskId);
840+
response = RestMLRemoteInferenceIT.getTask(taskId);
839841
responseMap = parseResponseToMap(response);
840842
String modelId = (String) responseMap.get("model_id");
841843
response = deployRemoteModel(modelId);
@@ -890,11 +892,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
890892
Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
891893
Map responseMap = parseResponseToMap(response);
892894
String connectorId = (String) responseMap.get("connector_id");
893-
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
895+
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
894896
responseMap = parseResponseToMap(response);
895897
String taskId = (String) responseMap.get("task_id");
896898
waitForTask(taskId, MLTaskState.COMPLETED);
897-
response = getTask(taskId);
899+
response = RestMLRemoteInferenceIT.getTask(taskId);
898900
responseMap = parseResponseToMap(response);
899901
String modelId = (String) responseMap.get("model_id");
900902
response = deployRemoteModel(modelId);
@@ -947,11 +949,11 @@ public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
947949
Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
948950
Map responseMap = parseResponseToMap(response);
949951
String connectorId = (String) responseMap.get("connector_id");
950-
response = registerRemoteModel("openAI-GPT-4 completions", connectorId);
952+
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4 completions", connectorId);
951953
responseMap = parseResponseToMap(response);
952954
String taskId = (String) responseMap.get("task_id");
953955
waitForTask(taskId, MLTaskState.COMPLETED);
954-
response = getTask(taskId);
956+
response = RestMLRemoteInferenceIT.getTask(taskId);
955957
responseMap = parseResponseToMap(response);
956958
String modelId = (String) responseMap.get("model_id");
957959
response = deployRemoteModel(modelId);
@@ -1008,11 +1010,11 @@ public void testBM25WithBedrockWithConversation() throws Exception {
10081010
Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
10091011
Map responseMap = parseResponseToMap(response);
10101012
String connectorId = (String) responseMap.get("connector_id");
1011-
response = registerRemoteModel("Bedrock", connectorId);
1013+
response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock", connectorId);
10121014
responseMap = parseResponseToMap(response);
10131015
String taskId = (String) responseMap.get("task_id");
10141016
waitForTask(taskId, MLTaskState.COMPLETED);
1015-
response = getTask(taskId);
1017+
response = RestMLRemoteInferenceIT.getTask(taskId);
10161018
responseMap = parseResponseToMap(response);
10171019
String modelId = (String) responseMap.get("model_id");
10181020
response = deployRemoteModel(modelId);
@@ -1065,11 +1067,11 @@ public void testBM25WithCohere() throws Exception {
10651067
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
10661068
Map responseMap = parseResponseToMap(response);
10671069
String connectorId = (String) responseMap.get("connector_id");
1068-
response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
1070+
response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId);
10691071
responseMap = parseResponseToMap(response);
10701072
String taskId = (String) responseMap.get("task_id");
10711073
waitForTask(taskId, MLTaskState.COMPLETED);
1072-
response = getTask(taskId);
1074+
response = RestMLRemoteInferenceIT.getTask(taskId);
10731075
responseMap = parseResponseToMap(response);
10741076
String modelId = (String) responseMap.get("model_id");
10751077
response = deployRemoteModel(modelId);
@@ -1117,11 +1119,11 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception {
11171119
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
11181120
Map responseMap = parseResponseToMap(response);
11191121
String connectorId = (String) responseMap.get("connector_id");
1120-
response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
1122+
response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId);
11211123
responseMap = parseResponseToMap(response);
11221124
String taskId = (String) responseMap.get("task_id");
11231125
waitForTask(taskId, MLTaskState.COMPLETED);
1124-
response = getTask(taskId);
1126+
response = RestMLRemoteInferenceIT.getTask(taskId);
11251127
responseMap = parseResponseToMap(response);
11261128
String modelId = (String) responseMap.get("model_id");
11271129
response = deployRemoteModel(modelId);

0 commit comments

Comments
 (0)