|
17 | 17 | */
|
18 | 18 | package org.opensearch.ml.rest;
|
19 | 19 |
|
| 20 | +import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.createConnector; |
| 21 | +import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.deployRemoteModel; |
20 | 22 | import static org.opensearch.ml.utils.TestHelper.makeRequest;
|
21 | 23 | import static org.opensearch.ml.utils.TestHelper.toHttpEntity;
|
22 | 24 |
|
|
41 | 43 | import com.google.common.collect.ImmutableList;
|
42 | 44 | import com.google.common.collect.ImmutableMap;
|
43 | 45 |
|
44 |
| -public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { |
| 46 | +public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { |
45 | 47 |
|
46 | 48 | private static final String OPENAI_KEY = System.getenv("OPENAI_KEY");
|
47 | 49 | private static final String OPENAI_CONNECTOR_BLUEPRINT = "{\n"
|
@@ -526,11 +528,11 @@ public void testBM25WithOpenAI() throws Exception {
|
526 | 528 | Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
|
527 | 529 | Map responseMap = parseResponseToMap(response);
|
528 | 530 | String connectorId = (String) responseMap.get("connector_id");
|
529 |
| - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); |
| 531 | + response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId); |
530 | 532 | responseMap = parseResponseToMap(response);
|
531 | 533 | String taskId = (String) responseMap.get("task_id");
|
532 | 534 | waitForTask(taskId, MLTaskState.COMPLETED);
|
533 |
| - response = getTask(taskId); |
| 535 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
534 | 536 | responseMap = parseResponseToMap(response);
|
535 | 537 | String modelId = (String) responseMap.get("model_id");
|
536 | 538 | response = deployRemoteModel(modelId);
|
@@ -580,11 +582,11 @@ public void testBM25WithOpenAIWithImage() throws Exception {
|
580 | 582 | Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
|
581 | 583 | Map responseMap = parseResponseToMap(response);
|
582 | 584 | String connectorId = (String) responseMap.get("connector_id");
|
583 |
| - response = registerRemoteModel("openAI-GPT-4o-mini completions", connectorId); |
| 585 | + response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4o-mini completions", connectorId); |
584 | 586 | responseMap = parseResponseToMap(response);
|
585 | 587 | String taskId = (String) responseMap.get("task_id");
|
586 | 588 | waitForTask(taskId, MLTaskState.COMPLETED);
|
587 |
| - response = getTask(taskId); |
| 589 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
588 | 590 | responseMap = parseResponseToMap(response);
|
589 | 591 | String modelId = (String) responseMap.get("model_id");
|
590 | 592 | response = deployRemoteModel(modelId);
|
@@ -667,11 +669,11 @@ public void testBM25WithBedrock() throws Exception {
|
667 | 669 | Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
|
668 | 670 | Map responseMap = parseResponseToMap(response);
|
669 | 671 | String connectorId = (String) responseMap.get("connector_id");
|
670 |
| - response = registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
| 672 | + response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
671 | 673 | responseMap = parseResponseToMap(response);
|
672 | 674 | String taskId = (String) responseMap.get("task_id");
|
673 | 675 | waitForTask(taskId, MLTaskState.COMPLETED);
|
674 |
| - response = getTask(taskId); |
| 676 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
675 | 677 | responseMap = parseResponseToMap(response);
|
676 | 678 | String modelId = (String) responseMap.get("model_id");
|
677 | 679 | response = deployRemoteModel(modelId);
|
@@ -719,11 +721,11 @@ public void testBM25WithBedrockConverse() throws Exception {
|
719 | 721 | Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT);
|
720 | 722 | Map responseMap = parseResponseToMap(response);
|
721 | 723 | String connectorId = (String) responseMap.get("connector_id");
|
722 |
| - response = registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
| 724 | + response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
723 | 725 | responseMap = parseResponseToMap(response);
|
724 | 726 | String taskId = (String) responseMap.get("task_id");
|
725 | 727 | waitForTask(taskId, MLTaskState.COMPLETED);
|
726 |
| - response = getTask(taskId); |
| 728 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
727 | 729 | responseMap = parseResponseToMap(response);
|
728 | 730 | String modelId = (String) responseMap.get("model_id");
|
729 | 731 | response = deployRemoteModel(modelId);
|
@@ -771,11 +773,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
|
771 | 773 | Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2);
|
772 | 774 | Map responseMap = parseResponseToMap(response);
|
773 | 775 | String connectorId = (String) responseMap.get("connector_id");
|
774 |
| - response = registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
| 776 | + response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
775 | 777 | responseMap = parseResponseToMap(response);
|
776 | 778 | String taskId = (String) responseMap.get("task_id");
|
777 | 779 | waitForTask(taskId, MLTaskState.COMPLETED);
|
778 |
| - response = getTask(taskId); |
| 780 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
779 | 781 | responseMap = parseResponseToMap(response);
|
780 | 782 | String modelId = (String) responseMap.get("model_id");
|
781 | 783 | response = deployRemoteModel(modelId);
|
@@ -831,11 +833,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
|
831 | 833 | Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
|
832 | 834 | Map responseMap = parseResponseToMap(response);
|
833 | 835 | String connectorId = (String) responseMap.get("connector_id");
|
834 |
| - response = registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
| 836 | + response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId); |
835 | 837 | responseMap = parseResponseToMap(response);
|
836 | 838 | String taskId = (String) responseMap.get("task_id");
|
837 | 839 | waitForTask(taskId, MLTaskState.COMPLETED);
|
838 |
| - response = getTask(taskId); |
| 840 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
839 | 841 | responseMap = parseResponseToMap(response);
|
840 | 842 | String modelId = (String) responseMap.get("model_id");
|
841 | 843 | response = deployRemoteModel(modelId);
|
@@ -890,11 +892,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
|
890 | 892 | Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
|
891 | 893 | Map responseMap = parseResponseToMap(response);
|
892 | 894 | String connectorId = (String) responseMap.get("connector_id");
|
893 |
| - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); |
| 895 | + response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId); |
894 | 896 | responseMap = parseResponseToMap(response);
|
895 | 897 | String taskId = (String) responseMap.get("task_id");
|
896 | 898 | waitForTask(taskId, MLTaskState.COMPLETED);
|
897 |
| - response = getTask(taskId); |
| 899 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
898 | 900 | responseMap = parseResponseToMap(response);
|
899 | 901 | String modelId = (String) responseMap.get("model_id");
|
900 | 902 | response = deployRemoteModel(modelId);
|
@@ -947,11 +949,11 @@ public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
|
947 | 949 | Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
|
948 | 950 | Map responseMap = parseResponseToMap(response);
|
949 | 951 | String connectorId = (String) responseMap.get("connector_id");
|
950 |
| - response = registerRemoteModel("openAI-GPT-4 completions", connectorId); |
| 952 | + response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4 completions", connectorId); |
951 | 953 | responseMap = parseResponseToMap(response);
|
952 | 954 | String taskId = (String) responseMap.get("task_id");
|
953 | 955 | waitForTask(taskId, MLTaskState.COMPLETED);
|
954 |
| - response = getTask(taskId); |
| 956 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
955 | 957 | responseMap = parseResponseToMap(response);
|
956 | 958 | String modelId = (String) responseMap.get("model_id");
|
957 | 959 | response = deployRemoteModel(modelId);
|
@@ -1008,11 +1010,11 @@ public void testBM25WithBedrockWithConversation() throws Exception {
|
1008 | 1010 | Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
|
1009 | 1011 | Map responseMap = parseResponseToMap(response);
|
1010 | 1012 | String connectorId = (String) responseMap.get("connector_id");
|
1011 |
| - response = registerRemoteModel("Bedrock", connectorId); |
| 1013 | + response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock", connectorId); |
1012 | 1014 | responseMap = parseResponseToMap(response);
|
1013 | 1015 | String taskId = (String) responseMap.get("task_id");
|
1014 | 1016 | waitForTask(taskId, MLTaskState.COMPLETED);
|
1015 |
| - response = getTask(taskId); |
| 1017 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
1016 | 1018 | responseMap = parseResponseToMap(response);
|
1017 | 1019 | String modelId = (String) responseMap.get("model_id");
|
1018 | 1020 | response = deployRemoteModel(modelId);
|
@@ -1065,11 +1067,11 @@ public void testBM25WithCohere() throws Exception {
|
1065 | 1067 | Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
|
1066 | 1068 | Map responseMap = parseResponseToMap(response);
|
1067 | 1069 | String connectorId = (String) responseMap.get("connector_id");
|
1068 |
| - response = registerRemoteModel("Cohere Chat Completion v1", connectorId); |
| 1070 | + response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId); |
1069 | 1071 | responseMap = parseResponseToMap(response);
|
1070 | 1072 | String taskId = (String) responseMap.get("task_id");
|
1071 | 1073 | waitForTask(taskId, MLTaskState.COMPLETED);
|
1072 |
| - response = getTask(taskId); |
| 1074 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
1073 | 1075 | responseMap = parseResponseToMap(response);
|
1074 | 1076 | String modelId = (String) responseMap.get("model_id");
|
1075 | 1077 | response = deployRemoteModel(modelId);
|
@@ -1117,11 +1119,11 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception {
|
1117 | 1119 | Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
|
1118 | 1120 | Map responseMap = parseResponseToMap(response);
|
1119 | 1121 | String connectorId = (String) responseMap.get("connector_id");
|
1120 |
| - response = registerRemoteModel("Cohere Chat Completion v1", connectorId); |
| 1122 | + response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId); |
1121 | 1123 | responseMap = parseResponseToMap(response);
|
1122 | 1124 | String taskId = (String) responseMap.get("task_id");
|
1123 | 1125 | waitForTask(taskId, MLTaskState.COMPLETED);
|
1124 |
| - response = getTask(taskId); |
| 1126 | + response = RestMLRemoteInferenceIT.getTask(taskId); |
1125 | 1127 | responseMap = parseResponseToMap(response);
|
1126 | 1128 | String modelId = (String) responseMap.get("model_id");
|
1127 | 1129 | response = deployRemoteModel(modelId);
|
|
0 commit comments