Skip to content

Commit 833c9c1

Browse files
committed
Fix loading of default feature flag settings, update tests to pick up
changes in memory API. Signed-off-by: Austin Lee <[email protected]>
1 parent d57fa56 commit 833c9c1

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+6
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,12 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc
347347
private ScriptService scriptService;
348348
private Encryptor encryptor;
349349

350+
public MachineLearningPlugin(Settings settings) {
351+
// Handle this here as this feature is tied to Search/Query API, not to a ml-common API
352+
// and as such, it can't be lazy-loaded when a ml-commons API is invoked.
353+
this.ragSearchPipelineEnabled = MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
354+
}
355+
350356
@Override
351357
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
352358
return ImmutableList

plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.mockito.Mock;
3535
import org.mockito.MockitoAnnotations;
3636
import org.opensearch.client.Client;
37+
import org.opensearch.common.settings.Settings;
3738
import org.opensearch.ml.common.spi.MLCommonsExtension;
3839
import org.opensearch.ml.common.spi.tools.Tool;
3940
import org.opensearch.ml.engine.tools.MLModelTool;
@@ -47,7 +48,7 @@
4748

4849
public class MachineLearningPluginTests {
4950

50-
MachineLearningPlugin plugin = new MachineLearningPlugin();
51+
MachineLearningPlugin plugin = new MachineLearningPlugin(Settings.EMPTY);
5152

5253
@Rule
5354
public ExpectedException exceptionRule = ExpectedException.none();

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+11-9
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
172172
+ " \"llm_model\": \"%s\",\n"
173173
+ " \"llm_question\": \"%s\",\n"
174174
+ " \"context_size\": %d,\n"
175-
+ " \"interaction_size\": %d,\n"
175+
+ " \"message_size\": %d,\n"
176176
+ " \"timeout\": %d\n"
177177
+ " }\n"
178178
+ " }\n"
@@ -187,9 +187,9 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
187187
+ " \"generative_qa_parameters\": {\n"
188188
+ " \"llm_model\": \"%s\",\n"
189189
+ " \"llm_question\": \"%s\",\n"
190-
+ " \"conversation_id\": \"%s\",\n"
190+
+ " \"memory_id\": \"%s\",\n"
191191
+ " \"context_size\": %d,\n"
192-
+ " \"interaction_size\": %d,\n"
192+
+ " \"message_size\": %d,\n"
193193
+ " \"timeout\": %d\n"
194194
+ " }\n"
195195
+ " }\n"
@@ -208,6 +208,7 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
208208
@Before
209209
public void init() throws Exception {
210210

211+
/*
211212
Response response = TestHelper
212213
.makeRequest(
213214
client(),
@@ -218,7 +219,7 @@ public void init() throws Exception {
218219
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
219220
);
220221
assertEquals(200, response.getStatusLine().getStatusCode());
221-
222+
222223
response = TestHelper
223224
.makeRequest(
224225
client(),
@@ -229,8 +230,9 @@ public void init() throws Exception {
229230
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
230231
);
231232
assertEquals(200, response.getStatusLine().getStatusCode());
233+
*/
232234

233-
response = TestHelper
235+
Response response = TestHelper
234236
.makeRequest(
235237
client(),
236238
"PUT",
@@ -428,7 +430,7 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
428430
String answer = (String) rag.get("answer");
429431
assertNotNull(answer);
430432

431-
String interactionId = (String) rag.get("interaction_id");
433+
String interactionId = (String) rag.get("message_id");
432434
assertNotNull(interactionId);
433435
}
434436

@@ -485,7 +487,7 @@ public void testBM25WithBedrockWithConversation() throws Exception {
485487
String answer = (String) rag.get("answer");
486488
assertNotNull(answer);
487489

488-
String interactionId = (String) rag.get("interaction_id");
490+
String interactionId = (String) rag.get("message_id");
489491
assertNotNull(interactionId);
490492
}
491493

@@ -557,12 +559,12 @@ private String createConversation(String name) throws Exception {
557559
Response response = makeRequest(
558560
client(),
559561
"POST",
560-
"/_plugins/_ml/memory/conversation",
562+
"/_plugins/_ml/memory",
561563
null,
562564
toHttpEntity(String.format(Locale.ROOT, "{\"name\": \"%s\"}", name)),
563565
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
564566
);
565-
return (String) ((Map) parseResponseToMap(response)).get("conversation_id");
567+
return (String) ((Map) parseResponseToMap(response)).get("memory_id");
566568
}
567569

568570
static class PipelineParameters {

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class GenerativeSearchResponse extends SearchResponse {
3535
private static final String EXT_SECTION_NAME = "ext";
3636
private static final String GENERATIVE_QA_ANSWER_FIELD_NAME = "answer";
3737
private static final String GENERATIVE_QA_ERROR_FIELD_NAME = "error";
38-
private static final String INTERACTION_ID_FIELD_NAME = "interaction_id";
38+
private static final String INTERACTION_ID_FIELD_NAME = "message_id";
3939

4040
private final String answer;
4141
private String errorMessage;

0 commit comments

Comments
 (0)