From 57f9afa7b18575ee34b8cc49067f0c9dbfdbf366 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Thu, 23 Jan 2025 15:05:42 -0800 Subject: [PATCH] applying multi-tenancy in search Signed-off-by: Dhrubo Saha --- .../search/MLSearchActionRequest.java | 70 +++++++++ .../search/MLSearchActionRequestTest.java | 147 ++++++++++++++++++ .../SearchConversationsTransportAction.java | 10 +- ...archConversationsTransportActionTests.java | 8 +- .../agents/TransportSearchAgentAction.java | 38 ++++- .../SearchConnectorTransportAction.java | 69 +++++++- .../ml/action/handler/MLSearchHandler.java | 15 +- .../SearchModelGroupTransportAction.java | 55 +++++-- .../models/SearchModelTransportAction.java | 31 +++- .../tasks/SearchTaskTransportAction.java | 48 +++++- .../ml/plugin/MachineLearningPlugin.java | 10 +- .../ml/rest/AbstractMLSearchAction.java | 23 ++- .../ml/rest/RestMLSearchAgentAction.java | 3 +- .../ml/rest/RestMLSearchConnectorAction.java | 11 +- .../ml/rest/RestMLSearchModelAction.java | 5 +- .../ml/rest/RestMLSearchModelGroupAction.java | 11 +- .../ml/rest/RestMLSearchTaskAction.java | 5 +- .../RestMemorySearchConversationsAction.java | 6 +- .../TransportSearchAgentActionTests.java | 113 +++++++++++--- .../SearchConnectorTransportActionTests.java | 105 +++++++++++-- .../SearchModelGroupTransportActionTests.java | 91 ++++++++++- .../ml/action/models/SearchModelITTests.java | 22 ++- .../SearchModelTransportActionTests.java | 122 +++++++++++---- .../tasks/SearchTaskTransportActionTests.java | 74 ++++++++- .../RestMLSearchConnectorActionTests.java | 70 ++++++++- .../ml/rest/RestMLSearchModelActionTests.java | 41 ++++- .../ml/rest/RestMLSearchTaskActionTests.java | 9 +- ...tMemorySearchConversationsActionTests.java | 20 ++- .../org/opensearch/ml/utils/TestHelper.java | 4 + 29 files changed, 1079 insertions(+), 157 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java new file mode 100644 index 0000000000..b5fd54ed4a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java @@ -0,0 +1,70 @@ +package org.opensearch.ml.common.transport.search; + +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.Version; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class MLSearchActionRequest extends SearchRequest { + SearchRequest searchRequest; + String tenantId; + + @Builder + public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) { + this.searchRequest = searchRequest; + this.tenantId = tenantId; + } + + public MLSearchActionRequest(StreamInput input) throws IOException { + super(input); + Version streamInputVersion = input.getVersion(); + if (input.readBoolean()) { + searchRequest = new SearchRequest(input); + } + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + Version streamOutputVersion = output.getVersion(); + if (searchRequest != null) { + output.writeBoolean(true); // user exists + searchRequest.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + output.writeOptionalString(tenantId); + } + } + + public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLSearchActionRequest) { + return (MLSearchActionRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLSearchActionRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java new file mode 100644 index 0000000000..f4c9abd431 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java @@ -0,0 +1,147 @@ +package org.opensearch.ml.common.transport.search; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.Version; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLSearchActionRequestTest { + + private SearchRequest searchRequest; + + @Before + public void setUp() { + searchRequest = new SearchRequest("test-index"); + } + + @Test + public void testConstructorAndGetters() { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + assertEquals("test-index", request.getSearchRequest().indices()[0]); + assertEquals("test-tenant", request.getTenantId()); + } + + @Test + public void testStreamConstructorAndWriteTo() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput()); + assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]); + assertEquals("test-tenant", deserializedRequest.getTenantId()); + } + + @Test + public void testWriteToWithNullSearchRequest() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build(); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput()); + assertNull(deserializedRequest.getSearchRequest()); + assertEquals("test-tenant", deserializedRequest.getTenantId()); + } + + @Test + public void testFromActionRequestWithMLSearchActionRequest() { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request); + assertSame(result, request); + } + + @Test + public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + + MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]); + assertEquals(request.getTenantId(), result.getTenantId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLSearchActionRequest.fromActionRequest(actionRequest); + } + + @Test + public void testBackwardCompatibility() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_2_18_0); // Older version + request.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_2_18_0); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in); + assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored + } + + @Test + public void testFromActionRequestWithValidRequest() { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + + MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request); + assertSame(request, result); + } + + @Test + public void testMixedVersionCompatibility() throws IOException { + MLSearchActionRequest originalRequest = MLSearchActionRequest + .builder() + .searchRequest(searchRequest) + .tenantId("test-tenant") + .build(); + + // Serialize with a newer version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_2_19_0); + originalRequest.writeTo(out); + + // Deserialize with an older version + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_2_18_0); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in); + assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java index b598d63e25..8a85f4a234 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java @@ -30,6 +30,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.memory.ConversationalMemoryHandler; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; @@ -38,7 +39,7 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchConversationsTransportAction extends HandledTransportAction { +public class SearchConversationsTransportAction extends HandledTransportAction { private ConversationalMemoryHandler cmHandler; private Client client; @@ -61,7 +62,7 @@ public SearchConversationsTransportAction( Client client, ClusterService clusterService ) { - super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new); + super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.cmHandler = cmHandler; this.client = client; this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); @@ -71,13 +72,14 @@ public SearchConversationsTransportAction( } @Override - public void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener actionListener) { + SearchRequest request = mlSearchActionRequest.getSearchRequest(); if (!featureIsEnabled) { actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE)); return; } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { - ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener internalListener = ActionListener.runBefore(actionListener, context::restore); cmHandler.searchConversations(request, internalListener); } catch (Exception e) { log.error("Failed to search memories", e); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java index 1cee8c2ddf..86aad3be9e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java @@ -44,6 +44,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.memory.MemoryTestUtil; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.test.OpenSearchTestCase; @@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase @Mock SearchRequest request; + MLSearchActionRequest mlSearchActionRequest; + SearchConversationsTransportAction action; ThreadContext threadContext; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + mlSearchActionRequest = new MLSearchActionRequest(request, null); Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); @@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() { listener.onResponse(response); return null; }).when(cmHandler).searchConversations(any(), any()); - action.doExecute(null, request, actionListener); + action.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); verify(actionListener, times(1)).onResponse(argCaptor.capture()); assert (argCaptor.getValue().equals(response)); @@ -114,7 +118,7 @@ public void testDisabled_ThenFail() { clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled(); this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); - action.doExecute(null, request, actionListener); + action.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE); diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java index 3153394061..ad693da3cc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.agents; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -20,28 +21,46 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.transport.agent.MLSearchAgentAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class TransportSearchAgentAction extends HandledTransportAction { +public class TransportSearchAgentAction extends HandledTransportAction { private final Client client; + private final SdkClient sdkClient; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject - public TransportSearchAgentAction(TransportService transportService, ActionFilters actionFilters, Client client) { - super(MLSearchAgentAction.NAME, transportService, actionFilters, SearchRequest::new); + public TransportSearchAgentAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SdkClient sdkClient, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLSearchAgentAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - request.indices(CommonValue.ML_AGENT_INDEX); - search(request, actionListener); + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { + request.getSearchRequest().indices(CommonValue.ML_AGENT_INDEX); + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + search(request.getSearchRequest(), tenantId, actionListener); } - private void search(SearchRequest request, ActionListener actionListener) { + private void search(SearchRequest request, String tenantId, ActionListener actionListener) { ActionListener listener = wrapRestActionListener(actionListener, "Fail to search agent"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); @@ -57,6 +76,11 @@ private void search(SearchRequest request, ActionListener action // Add a should clause to include documents where IS_HIDDEN_FIELD is false shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false)); + // For multi-tenancy + if (tenantId != null) { + shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); + } + // Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD))); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java index 25fcf310f7..8d0813f76b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.connector; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -13,8 +14,10 @@ import java.util.Optional; import java.util.stream.Collectors; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -22,44 +25,65 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchConnectorTransportAction extends HandledTransportAction { +public class SearchConnectorTransportAction extends HandledTransportAction { private final Client client; + private final SdkClient sdkClient; private final ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject public SearchConnectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, - ConnectorAccessControlHelper connectorAccessControlHelper + SdkClient sdkClient, + ConnectorAccessControlHelper connectorAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { - super(MLConnectorSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + super(MLConnectorSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - request.indices(CommonValue.ML_CONNECTOR_INDEX); - search(request, actionListener); + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { + request.getSearchRequest().indices(CommonValue.ML_CONNECTOR_INDEX); + + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + search(request.getSearchRequest(), tenantId, actionListener); } - private void search(SearchRequest request, ActionListener actionListener) { + private void search(SearchRequest request, String tenantId, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); @@ -84,6 +108,15 @@ private void search(SearchRequest request, ActionListener action final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); + if (tenantId != null) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); // Replace with your tenant_id field + request.source().query(queryBuilder); + } + if (connectorAccessControlHelper.skipConnectorAccessControl(user)) { client.search(request, doubleWrappedListener); } else { @@ -96,4 +129,26 @@ private void search(SearchRequest request, ActionListener action actionListener.onFailure(e); } } + + @VisibleForTesting + public static void wrapListenerToHandleConnectorIndexNotFound(Exception e, ActionListener listener) { + if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) { + log.debug("Connectors index not created yet, therefore we will swallow the exception and return an empty search result"); + final InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + final SearchResponse emptySearchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[] {}, + SearchResponse.Clusters.EMPTY, + null + ); + listener.onResponse(emptySearchResponse); + } else { + listener.onFailure(e); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 80e47fd9e0..d7b2bded27 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -7,6 +7,7 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -79,11 +80,11 @@ public MLSearchHandler( * @param request * @param actionListener */ - public void search(SearchRequest request, ActionListener actionListener) { + public void search(SearchRequest request, String tenantId, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search model version"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); List excludes = Optional .ofNullable(request.source()) .map(SearchSourceBuilder::fetchSource) @@ -113,6 +114,16 @@ public void search(SearchRequest request, ActionListener actionL // Add a should clause to include documents where IS_HIDDEN_FIELD is false shouldQuery.should(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, false)); + // For multi-tenancy + if (tenantId != null) { + shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); + } + + // For multi-tenancy + if (tenantId != null) { + shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); + } + // Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLModel.IS_HIDDEN_FIELD))); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index 4fa95fafa6..af2d78070e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import org.opensearch.action.search.SearchRequest; @@ -18,19 +19,27 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchModelGroupTransportAction extends HandledTransportAction { +public class SearchModelGroupTransportAction extends HandledTransportAction { Client client; + SdkClient sdkClient; ClusterService clusterService; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; ModelAccessControlHelper modelAccessControlHelper; @@ -39,36 +48,64 @@ public SearchModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { - super(MLModelGroupSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + super(MLModelGroupSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search"); - request.indices(CommonValue.ML_MODEL_GROUP_INDEX); - preProcessRoleAndPerformSearch(request, user, listener); + request.getSearchRequest().indices(CommonValue.ML_MODEL_GROUP_INDEX); + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + preProcessRoleAndPerformSearch(request.getSearchRequest(), tenantId, user, listener); } - private void preProcessRoleAndPerformSearch(SearchRequest request, User user, ActionListener listener) { + private void preProcessRoleAndPerformSearch( + SearchRequest request, + String tenantId, + User user, + ActionListener listener + ) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); + // Modify the query to include tenant ID filtering + if (tenantId != null) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + + // Preserve existing query if present + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + // Add tenancy filter + queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); // Replace 'tenant_id_field' with actual field name + + // Update the request's source with the new query + request.source().query(queryBuilder); + } + if (modelAccessControlHelper.skipModelAccessControl(user)) { client.search(request, doubleWrappedListener); } else { // Security is enabled, filter is enabled and user isn't admin modelAccessControlHelper.addUserBackendRolesFilter(user, request.source()); - log.debug("Filtering result by " + user.getBackendRoles()); + log.debug("Filtering result by {}", user.getBackendRoles()); client.search(request, doubleWrappedListener); } } catch (Exception e) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java index 64d4913d5f..5be1d0a06d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java @@ -5,7 +5,6 @@ package org.opensearch.ml.action.models; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -14,24 +13,42 @@ import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchModelTransportAction extends HandledTransportAction { +public class SearchModelTransportAction extends HandledTransportAction { private final MLSearchHandler mlSearchHandler; + private final SdkClient sdkClient; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject - public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) { - super(MLModelSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + public SearchModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + SdkClient sdkClient, + MLSearchHandler mlSearchHandler, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLModelSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); + this.sdkClient = sdkClient; this.mlSearchHandler = mlSearchHandler; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - request.indices(CommonValue.ML_MODEL_INDEX); - mlSearchHandler.search(request, actionListener); + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { + request.getSearchRequest().indices(CommonValue.ML_MODEL_INDEX); + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + mlSearchHandler.search(request.getSearchRequest(), tenantId, actionListener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java index c87e4087c3..527a60d8e4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.tasks; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import org.opensearch.action.search.SearchRequest; @@ -15,28 +16,65 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchTaskTransportAction extends HandledTransportAction { +public class SearchTaskTransportAction extends HandledTransportAction { private Client client; + private final SdkClient sdkClient; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject - public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { - super(MLTaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + public SearchTaskTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SdkClient sdkClient, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLTaskSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener actionListener) { + String tenantId = mlSearchActionRequest.getTenantId(); + SearchRequest request = mlSearchActionRequest.getSearchRequest(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { final ActionListener wrappedListener = ActionListener .wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener)); - client.search(request, ActionListener.runBefore(wrappedListener, () -> context.restore())); + + // Modify the query to include tenant ID filtering + if (tenantId != null) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + + // Preserve existing query if present + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + // Add tenancy filter + queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); // Replace 'tenant_id_field' with actual field name + + // Update the request's source with the new query + request.source().query(queryBuilder); + } + client.search(request, ActionListener.runBefore(wrappedListener, context::restore)); } catch (Exception e) { log.error(e.getMessage(), e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index b57ee724ca..2a3ab5eefa 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -760,10 +760,10 @@ public List getRestHandlers( RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting); RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); - RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(); + RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting); RestMLGetTaskAction restMLGetTaskAction = new RestMLGetTaskAction(mlFeatureEnabledSetting); RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction(mlFeatureEnabledSetting); - RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction(); + RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction(mlFeatureEnabledSetting); RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService); RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction( clusterService, @@ -784,14 +784,16 @@ public List getRestHandlers( RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(clusterService, settings, mlFeatureEnabledSetting); RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting); - RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); + RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting); RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction(); RestMemoryCreateInteractionAction restCreateInteractionAction = new RestMemoryCreateInteractionAction(); RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); - RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction(); + RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction( + mlFeatureEnabledSetting + ); RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java b/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java index 4c94f8267a..d18594a228 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.ml.utils.RestActionUtils.getSourceContext; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.ArrayList; @@ -18,6 +19,8 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; @@ -32,12 +35,20 @@ public abstract class AbstractMLSearchAction extends protected final String index; protected final Class clazz; protected final ActionType actionType; + MLFeatureEnabledSetting mlFeatureEnabledSetting; - public AbstractMLSearchAction(List urlPaths, String index, Class clazz, ActionType actionType) { + public AbstractMLSearchAction( + List urlPaths, + String index, + Class clazz, + ActionType actionType, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { this.urlPaths = urlPaths; this.index = index; this.clazz = clazz; this.actionType = actionType; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override @@ -46,12 +57,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); - return channel -> client.execute(actionType, searchRequest, search(channel)); + MLSearchActionRequest mlSearchActionRequest = MLSearchActionRequest + .builder() + .searchRequest(searchRequest) + .tenantId(tenantId) + .build(); + return channel -> client.execute(actionType, mlSearchActionRequest, search(channel)); } protected RestResponseListener search(RestChannel channel) { - return new RestResponseListener(channel) { + return new RestResponseListener<>(channel) { @Override public RestResponse buildResponse(SearchResponse response) throws Exception { if (response.isTimedOut()) { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java index 45576f5cbb..61e7af7a30 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java @@ -24,10 +24,9 @@ public class RestMLSearchAgentAction extends AbstractMLSearchAction { private static final String ML_SEARCH_AGENT_ACTION = "ml_search_agent_action"; private static final String SEARCH_AGENT_PATH = ML_BASE_URI + "/agents/_search"; - private final MLFeatureEnabledSetting mlFeatureEnabledSetting; public RestMLSearchAgentAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { - super(ImmutableList.of(SEARCH_AGENT_PATH), ML_AGENT_INDEX, MLAgent.class, MLSearchAgentAction.INSTANCE); + super(ImmutableList.of(SEARCH_AGENT_PATH), ML_AGENT_INDEX, MLAgent.class, MLSearchAgentAction.INSTANCE, mlFeatureEnabledSetting); this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java index 517882a7a3..c8f9afcb9c 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java @@ -10,6 +10,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import com.google.common.collect.ImmutableList; @@ -17,8 +18,14 @@ public class RestMLSearchConnectorAction extends AbstractMLSearchAction { private static final String ML_SEARCH_MODEL_ACTION = "ml_search_model_action"; private static final String SEARCH_MODEL_PATH = ML_BASE_URI + "/models/_search"; - public RestMLSearchModelAction() { - super(ImmutableList.of(SEARCH_MODEL_PATH), ML_MODEL_INDEX, MLModel.class, MLModelSearchAction.INSTANCE); + public RestMLSearchModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + super(ImmutableList.of(SEARCH_MODEL_PATH), ML_MODEL_INDEX, MLModel.class, MLModelSearchAction.INSTANCE, mlFeatureEnabledSetting); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java index d6123c57d7..c4f5b5f48e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java @@ -20,12 +20,15 @@ public class RestMLSearchModelGroupAction extends AbstractMLSearchAction { private static final String ML_SEARCH_MODEL_GROUP_ACTION = "ml_search_model_group_action"; private static final String SEARCH_MODEL_GROUP_PATH = ML_BASE_URI + "/model_groups/_search"; - private final MLFeatureEnabledSetting mlFeatureEnabledSetting; public RestMLSearchModelGroupAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { - super(ImmutableList.of(SEARCH_MODEL_GROUP_PATH), ML_MODEL_GROUP_INDEX, MLModelGroup.class, MLModelGroupSearchAction.INSTANCE); - - this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + super( + ImmutableList.of(SEARCH_MODEL_GROUP_PATH), + ML_MODEL_GROUP_INDEX, + MLModelGroup.class, + MLModelGroupSearchAction.INSTANCE, + mlFeatureEnabledSetting + ); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java index 25f69e23b8..70bf4f7894 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java @@ -10,6 +10,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import com.google.common.collect.ImmutableList; @@ -17,8 +18,8 @@ public class RestMLSearchTaskAction extends AbstractMLSearchAction { private static final String ML_SEARCH_Task_ACTION = "ml_search_task_action"; private static final String SEARCH_TASK_PATH = ML_BASE_URI + "/tasks/_search"; - public RestMLSearchTaskAction() { - super(ImmutableList.of(SEARCH_TASK_PATH), ML_TASK_INDEX, MLTask.class, MLTaskSearchAction.INSTANCE); + public RestMLSearchTaskAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + super(ImmutableList.of(SEARCH_TASK_PATH), ML_TASK_INDEX, MLTask.class, MLTaskSearchAction.INSTANCE, mlFeatureEnabledSetting); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java index 5beee29c42..02c2ba8c70 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java @@ -21,18 +21,20 @@ import org.opensearch.ml.common.conversation.ConversationMeta; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import com.google.common.collect.ImmutableList; public class RestMemorySearchConversationsAction extends AbstractMLSearchAction { private static final String SEARCH_CONVERSATIONS_NAME = "conversation_memory_search_conversations"; - public RestMemorySearchConversationsAction() { + public RestMemorySearchConversationsAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { super( ImmutableList.of(ActionConstants.SEARCH_CONVERSATIONS_REST_PATH), ConversationalIndexConstants.META_INDEX_NAME, ConversationMeta.class, - SearchConversationsAction.INSTANCE + SearchConversationsAction.INSTANCE, + mlFeatureEnabledSetting ); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java index 9853578bfa..0bdb8f862b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java @@ -9,6 +9,9 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -16,13 +19,23 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -30,6 +43,8 @@ public class TransportSearchAgentActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; + SearchResponse searchResponse; @Mock TransportService transportService; @@ -37,22 +52,51 @@ public class TransportSearchAgentActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock ActionListener actionListener; TransportSearchAgentAction transportSearchAgentAction; - @Mock - SearchResponse mockedSearchResponse; - @Before public void setup() { MockitoAnnotations.openMocks(this); - transportSearchAgentAction = new TransportSearchAgentAction(transportService, actionFilters, client); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + transportSearchAgentAction = new TransportSearchAgentAction( + transportService, + actionFilters, + client, + sdkClient, + mlFeatureEnabledSetting + ); ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(threadPool.getThreadContext()).thenReturn(threadContext); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } @Test @@ -62,14 +106,23 @@ public void testDoExecuteWithEmptyQuery() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockedSearchResponse); + listener.onResponse(searchResponse); return null; }).when(client).search(eq(request), any()); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); - transportSearchAgentAction.doExecute(null, request, actionListener); - + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); - verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); + } @Test @@ -77,30 +130,39 @@ public void testDoExecuteWithNonEmptyQuery() { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(QueryBuilders.matchAllQuery()); SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockedSearchResponse); + listener.onResponse(searchResponse); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); - verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); } @Test public void testDoExecuteOnFailure() { SearchRequest request = new SearchRequest("my_index"); - + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new Exception("test exception")); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); verify(actionListener, times(1)).onFailure(any(Exception.class)); @@ -111,17 +173,25 @@ public void testSearchWithHiddenField() { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); - + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockedSearchResponse); + listener.onResponse(searchResponse); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); - verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); } @Test @@ -129,18 +199,18 @@ public void testSearchException() { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); - + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new Exception("failed to search the agent index")); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Fail to search agent", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to search indices [.plugins-ml-agent]", argumentCaptor.getValue().getMessage()); } @Test @@ -150,9 +220,10 @@ public void testSearchThrowsException() { // Create a search request SearchRequest searchRequest = new SearchRequest(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); // Execute the action - transportSearchAgentAction.doExecute(null, searchRequest, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); // Verify that the actionListener's onFailure method was called verify(actionListener, times(1)).onFailure(any(RuntimeException.class)); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java index 079206e621..5ec1fd8853 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java @@ -11,11 +11,18 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; @@ -23,9 +30,19 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -35,6 +52,7 @@ public class SearchConnectorTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; @Mock TransportService transportService; @@ -42,9 +60,12 @@ public class SearchConnectorTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + MLSearchActionRequest mlSearchActionRequest; + + SearchResponse searchResponse; + SearchSourceBuilder searchSourceBuilder; @Mock @@ -63,14 +84,21 @@ public class SearchConnectorTransportActionTests extends OpenSearchTestCase { @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); searchConnectorTransportAction = new SearchConnectorTransportAction( transportService, actionFilters, client, - connectorAccessControlHelper + sdkClient, + connectorAccessControlHelper, + mlFeatureEnabledSetting ); Settings settings = Settings.builder().build(); @@ -80,41 +108,100 @@ public void setup() { searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.fetchSource(fetchSourceContext); - when(searchRequest.source()).thenReturn(searchSourceBuilder); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); when(fetchSourceContext.includes()).thenReturn(new String[] {}); when(fetchSourceContext.excludes()).thenReturn(new String[] {}); + + when(connectorAccessControlHelper.addUserBackendRolesFilter(any(), any(SearchSourceBuilder.class))).thenReturn(searchSourceBuilder); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } + @Test public void test_doExecute_connectorAccessControlNotEnabled_searchSuccess() { String userString = "admin|role-1|all_access"; threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString); when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(true); - SearchResponse searchResponse = mock(SearchResponse.class); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(searchResponse); return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); verify(actionListener).onResponse(any(SearchResponse.class)); } + @Test public void test_doExecute_connectorAccessControlEnabled_searchSuccess() { when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(false); - SearchResponse searchResponse = mock(SearchResponse.class); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(searchResponse); return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); verify(actionListener).onResponse(any(SearchResponse.class)); } + @Test public void test_doExecute_exception() { - when(searchRequest.source()).thenThrow(new RuntimeException("runtime exception")); - searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); verify(actionListener).onFailure(any(RuntimeException.class)); } + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(captor.capture()); + OpenSearchStatusException exception = captor.getValue(); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + assertEquals("You don't have permission to access this resource", exception.getMessage()); + } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, "123456"); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java index ecfb10f221..7cba5c02e0 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java @@ -6,15 +6,21 @@ package org.opensearch.ml.action.model_group; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -23,7 +29,17 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -31,6 +47,7 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; @Mock NamedXContentRegistry namedXContentRegistry; @@ -41,9 +58,17 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + SearchResponse searchResponse; + + SearchSourceBuilder searchSourceBuilder; + + MLSearchActionRequest mlSearchActionRequest; + + @Mock + FetchSourceContext fetchSourceContext; + @Mock ActionListener actionListener; @@ -56,17 +81,22 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; ThreadContext threadContext; @Before public void setup() { MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); searchModelGroupTransportAction = new SearchModelGroupTransportAction( transportService, actionFilters, client, + sdkClient, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ); Settings settings = Settings.builder().build(); @@ -74,19 +104,70 @@ public void setup() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + when(fetchSourceContext.includes()).thenReturn(new String[] {}); + when(fetchSourceContext.excludes()).thenReturn(new String[] {}); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); } public void test_DoExecute() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); + verify(client).search(any(), any()); + } + + public void test_DoExecute_Exception() throws InterruptedException { + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("search failed")); + return null; + }).when(client).search(any(), any()); + + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); verify(client).search(any(), any()); + verify(actionListener).onFailure(any(RuntimeException.class)); } public void test_skipModelAccessControlTrue() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client).search(any(), any()); } @@ -94,7 +175,7 @@ public void test_skipModelAccessControlTrue() { public void test_ThreadContextError() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenThrow(new RuntimeException("thread context error")); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Fail to search", argumentCaptor.getValue().getMessage()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index 9bbf98e139..b347597ea4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -27,6 +27,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; @@ -110,7 +111,8 @@ private void test_empty_body_search() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER))); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -121,7 +123,8 @@ private void test_matchAll_search() { searchRequest .source() .query(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)).must(QueryBuilders.matchAllQuery())); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -141,7 +144,8 @@ private void test_bool_search() { .must(QueryBuilders.termQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt")) ) ); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -154,7 +158,8 @@ private void test_term_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.termQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -167,7 +172,8 @@ private void test_terms_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.termsQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt", "test_model_group_name")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -180,7 +186,8 @@ private void test_range_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.rangeQuery("created_time").gte("now-1d")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -193,7 +200,8 @@ private void test_matchPhrase_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.matchPhraseQuery("description", "desc")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 9687306241..1a81a749b5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -16,6 +16,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collections; import java.util.Map; import org.apache.lucene.search.TotalHits; @@ -43,13 +44,18 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -57,6 +63,7 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; @Mock NamedXContentRegistry namedXContentRegistry; @@ -67,9 +74,12 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + MLSearchActionRequest mlSearchActionRequest; + + SearchResponse searchResponse; + @Mock ActionListener actionListener; @@ -91,36 +101,76 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock private FetchSourceContext fetchSourceContext; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException thrown = ExpectedException.none(); @Before public void setup() { MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService)); - searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, mlSearchHandler); + searchModelTransportAction = new SearchModelTransportAction( + transportService, + actionFilters, + sdkClient, + mlSearchHandler, + mlFeatureEnabledSetting + ); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); when(fetchSourceContext.includes()).thenReturn(new String[] {}); when(fetchSourceContext.excludes()).thenReturn(new String[] {}); - searchSourceBuilder.fetchSource(fetchSourceContext); - when(searchRequest.source()).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); Metadata metadata = mock(Metadata.class); when(metadata.hasIndex(anyString())).thenReturn(true); ClusterState testState = new ClusterState(new ClusterName("mock"), 123l, "111111", metadata, null, null, null, Map.of(), 0, false); when(clusterService.state()).thenReturn(testState); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } public void test_DoExecute_admin() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); } @@ -130,25 +180,22 @@ public void test_DoExecute_addBackendRoles() throws IOException { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; - }).when(client).search(any(), isA(ActionListener.class)); + }).when(client).search(any(), any()); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } public void test_DoExecute_addBackendRoles_without_groupIds() { - SearchResponse searchResponse = mock(SearchResponse.class); - SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -159,8 +206,8 @@ public void test_DoExecute_addBackendRoles_exception() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); } @@ -171,8 +218,8 @@ public void test_DoExecute_searchModel_before_model_creation_no_exception() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); } @@ -204,8 +251,8 @@ public void test_DoExecute_searchModel_before_model_creation_empty_search() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); verify(actionListener, times(1)).onResponse(any(SearchResponse.class)); @@ -218,8 +265,8 @@ public void test_DoExecute_searchModel_MLResourceNotFoundException_exception() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } @@ -233,8 +280,8 @@ public void test_DoExecute_addBackendRoles_boolQuery() throws IOException { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.matchQuery("name", "model_IT"))); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -247,13 +294,12 @@ public void test_DoExecute_addBackendRoles_termQuery() throws IOException { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(sdkClient, searchRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } private SearchResponse createModelGroupSearchResponse() throws IOException { - SearchResponse searchResponse = mock(SearchResponse.class); String modelContent = "{\n" + " \"created_time\": 1684981986069,\n" + " \"access\": \"public\",\n" @@ -264,7 +310,25 @@ private SearchResponse createModelGroupSearchResponse() throws IOException { + " }"; SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); - return searchResponse; + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + hits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + return new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java index 3ad05f337a..cbba3ea011 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java @@ -7,22 +7,38 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -30,7 +46,9 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; + SearchResponse searchResponse; @Mock NamedXContentRegistry namedXContentRegistry; @@ -40,6 +58,9 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock SearchRequest searchRequest; @@ -54,15 +75,62 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, client); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + searchTaskTransportAction = new SearchTaskTransportAction( + transportService, + actionFilters, + client, + sdkClient, + mlFeatureEnabledSetting + ); ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(threadPool.getThreadContext()).thenReturn(threadContext); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); + } public void test_DoExecute() { - searchTaskTransportAction.doExecute(null, searchRequest, actionListener); - verify(client).search(eq(searchRequest), any()); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(request), any()); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + searchTaskTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(client, times(1)).search(eq(request), any()); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java index 97ed37230a..0d922494e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java @@ -20,10 +20,12 @@ import org.apache.lucene.search.TotalHits; import org.hamcrest.Matchers; +import org.junit.Assert; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchWrapperException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -35,7 +37,11 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; @@ -52,6 +58,9 @@ public class RestMLSearchConnectorActionTests extends OpenSearchTestCase { private RestMLSearchConnectorAction restMLSearchConnectorAction; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + NodeClient client; private ThreadPool threadPool; @Mock @@ -60,7 +69,7 @@ public class RestMLSearchConnectorActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - restMLSearchConnectorAction = new RestMLSearchConnectorAction(); + restMLSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -106,7 +115,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLSearchConnectorAction mlSearchConnectorAction = new RestMLSearchConnectorAction(); + RestMLSearchConnectorAction mlSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting); assertNotNull(mlSearchConnectorAction); } @@ -130,11 +139,12 @@ public void testPrepareRequest() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchConnectorAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices); assertEquals( @@ -176,11 +186,12 @@ public void testPrepareRequest_timeout() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchConnectorAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices); assertEquals( @@ -190,4 +201,51 @@ public void testPrepareRequest_timeout() throws Exception { RestResponse restResponse = responseCaptor.getValue(); assertEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); } + + public void testDoubleWrapper_handleIndexNotFound() { + final IndexNotFoundException indexNotFoundException = new IndexNotFoundException("Index not found", ML_CONNECTOR_INDEX); + final DummyActionListener actionListener = new DummyActionListener(); + + SearchConnectorTransportAction.wrapListenerToHandleConnectorIndexNotFound(indexNotFoundException, actionListener); + Assert.assertTrue(actionListener.success); + } + + public void testDoubleWrapper_handleIndexNotFoundWrappedException() { + final WrappedException wrappedException = new WrappedException(); + final DummyActionListener actionListener = new DummyActionListener(); + + SearchConnectorTransportAction.wrapListenerToHandleConnectorIndexNotFound(wrappedException, actionListener); + Assert.assertTrue(actionListener.success); + } + + public void testDoubleWrapper_notRelatedException() { + final RuntimeException exception = new RuntimeException("some random exception"); + final DummyActionListener actionListener = new DummyActionListener(); + + SearchConnectorTransportAction.wrapListenerToHandleConnectorIndexNotFound(exception, actionListener); + Assert.assertFalse(actionListener.success); + } + + public class DummyActionListener implements ActionListener { + public boolean success = false; + + @Override + public void onResponse(SearchResponse searchResponse) { + logger.info("success"); + this.success = true; + } + + @Override + public void onFailure(Exception e) { + logger.error("failure", e); + this.success = false; + } + } + + public static class WrappedException extends Exception implements OpenSearchWrapperException { + @Override + public synchronized Throwable getCause() { + return new IndexNotFoundException("Index not found", ML_CONNECTOR_INDEX); + } + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java index b679b73585..b9e20bf474 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.getSearchAllRestRequest; @@ -36,6 +37,8 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; @@ -52,6 +55,9 @@ public class RestMLSearchModelActionTests extends OpenSearchTestCase { private RestMLSearchModelAction restMLSearchModelAction; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + NodeClient client; private ThreadPool threadPool; @Mock @@ -60,7 +66,7 @@ public class RestMLSearchModelActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - restMLSearchModelAction = new RestMLSearchModelAction(); + restMLSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -106,7 +112,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLSearchModelAction mlSearchModelAction = new RestMLSearchModelAction(); + RestMLSearchModelAction mlSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting); assertNotNull(mlSearchModelAction); } @@ -130,11 +136,12 @@ public void testPrepareRequest() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchModelAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); assertEquals( @@ -145,6 +152,27 @@ public void testPrepareRequest() throws Exception { assertNotEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); } + public void testPrepareRequest_multiTenancy() throws Exception { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + RestRequest request = getSearchAllRestRequest(); + restMLSearchModelAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); + verify(channel, times(1)).sendResponse(responseCaptor.capture()); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); + String[] indices = searchRequest.indices(); + assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); + assertEquals( + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + searchRequest.source().toString() + ); + RestResponse restResponse = responseCaptor.getValue(); + assertNotEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); + } + public void testPrepareRequest_timeout() throws Exception { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -176,11 +204,12 @@ public void testPrepareRequest_timeout() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchModelAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); assertEquals( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java index 58713cb277..3c917a9568 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java @@ -9,7 +9,9 @@ import org.hamcrest.Matchers; import org.junit.Before; +import org.mockito.Mock; import org.opensearch.core.common.Strings; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; @@ -18,13 +20,16 @@ public class RestMLSearchTaskActionTests extends OpenSearchTestCase { private RestMLSearchTaskAction restMLSearchTaskAction; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { - restMLSearchTaskAction = new RestMLSearchTaskAction(); + restMLSearchTaskAction = new RestMLSearchTaskAction(mlFeatureEnabledSetting); } public void testConstructor() { - RestMLSearchTaskAction mlSearchTaskAction = new RestMLSearchTaskAction(); + RestMLSearchTaskAction mlSearchTaskAction = new RestMLSearchTaskAction(mlFeatureEnabledSetting); assertNotNull(mlSearchTaskAction); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java index 294e3deac4..adf2b37417 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java @@ -27,15 +27,19 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; -import org.opensearch.action.search.SearchRequest; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; import org.opensearch.test.OpenSearchTestCase; import com.google.gson.Gson; @@ -44,13 +48,17 @@ public class RestMemorySearchConversationsActionTests extends OpenSearchTestCase Gson gson; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { + MockitoAnnotations.openMocks(this); gson = new Gson(); } public void testBasics() { - RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(mlFeatureEnabledSetting); assert (action.getName().equals("conversation_memory_search_conversations")); List routes = action.routes(); assert (routes.size() == 2); @@ -59,15 +67,17 @@ public void testBasics() { } public void testPreprareRequest() throws Exception { - RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(mlFeatureEnabledSetting); RestRequest request = TestHelper.getSearchAllRestRequest(); NodeClient client = mock(NodeClient.class); RestChannel channel = mock(RestChannel.class); action.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(SearchConversationsAction.INSTANCE), argumentCaptor.capture(), any()); - assert (argumentCaptor.getValue().source().query() instanceof MatchAllQueryBuilder); + assert (argumentCaptor.getValue().getSearchRequest().source().query() instanceof MatchAllQueryBuilder); } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 0796fff279..83a38d9c49 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -375,7 +375,11 @@ public static RestRequest getExecuteAgentRestRequest() { } public static RestRequest getSearchAllRestRequest() { + String tenantId = "test-tenant"; + Map> headers = new HashMap<>(); + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withHeaders(headers) .withContent(new BytesArray(TestData.matchAllSearchQuery()), XContentType.JSON) .build(); return request;