diff --git a/common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java new file mode 100644 index 0000000000..ec05fed236 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java @@ -0,0 +1,92 @@ +package org.opensearch.ml.common.transport.search; + +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.Version; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; + +/** + * Represents an extended search action request that includes a tenant ID. + * This class allows OpenSearch to include a tenant ID in search requests, + * which is not natively supported in the standard {@link SearchRequest}. + */ +@Getter +public class MLSearchActionRequest extends SearchRequest { + SearchRequest searchRequest; + String tenantId; + + /** + * Constructor for building an MLSearchActionRequest. + * + * @param searchRequest The original {@link SearchRequest} to be wrapped. + * @param tenantId The tenant ID associated with the request. + */ + @Builder + public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) { + this.searchRequest = searchRequest; + this.tenantId = tenantId; + } + + /** + * Deserializes an {@link MLSearchActionRequest} from a {@link StreamInput}. + * + * @param input The stream input to read from. + * @throws IOException If an I/O error occurs during deserialization. + */ + public MLSearchActionRequest(StreamInput input) throws IOException { + super(input); + Version streamInputVersion = input.getVersion(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; + } + + /** + * Serializes this {@link MLSearchActionRequest} to a {@link StreamOutput}. + * + * @param output The stream output to write to. + * @throws IOException If an I/O error occurs during serialization. + */ + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + Version streamOutputVersion = output.getVersion(); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + output.writeOptionalString(tenantId); + } + } + + /** + * Converts a generic {@link ActionRequest} into an {@link MLSearchActionRequest}. + * This is useful when handling requests that may need to be converted for compatibility. + * + * @param actionRequest The original {@link ActionRequest}. + * @return The converted {@link MLSearchActionRequest}. + * @throws UncheckedIOException If the conversion fails due to an I/O error. + */ + public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLSearchActionRequest) { + return (MLSearchActionRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLSearchActionRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java new file mode 100644 index 0000000000..f4c9abd431 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java @@ -0,0 +1,147 @@ +package org.opensearch.ml.common.transport.search; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.Version; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLSearchActionRequestTest { + + private SearchRequest searchRequest; + + @Before + public void setUp() { + searchRequest = new SearchRequest("test-index"); + } + + @Test + public void testConstructorAndGetters() { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + assertEquals("test-index", request.getSearchRequest().indices()[0]); + assertEquals("test-tenant", request.getTenantId()); + } + + @Test + public void testStreamConstructorAndWriteTo() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput()); + assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]); + assertEquals("test-tenant", deserializedRequest.getTenantId()); + } + + @Test + public void testWriteToWithNullSearchRequest() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build(); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput()); + assertNull(deserializedRequest.getSearchRequest()); + assertEquals("test-tenant", deserializedRequest.getTenantId()); + } + + @Test + public void testFromActionRequestWithMLSearchActionRequest() { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request); + assertSame(result, request); + } + + @Test + public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + + MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]); + assertEquals(request.getTenantId(), result.getTenantId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLSearchActionRequest.fromActionRequest(actionRequest); + } + + @Test + public void testBackwardCompatibility() throws IOException { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_2_18_0); // Older version + request.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_2_18_0); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in); + assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored + } + + @Test + public void testFromActionRequestWithValidRequest() { + MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); + + MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request); + assertSame(request, result); + } + + @Test + public void testMixedVersionCompatibility() throws IOException { + MLSearchActionRequest originalRequest = MLSearchActionRequest + .builder() + .searchRequest(searchRequest) + .tenantId("test-tenant") + .build(); + + // Serialize with a newer version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_2_19_0); + originalRequest.writeTo(out); + + // Deserialize with an older version + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_2_18_0); + + MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in); + assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java index b598d63e25..8a85f4a234 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java @@ -30,6 +30,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.memory.ConversationalMemoryHandler; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; @@ -38,7 +39,7 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchConversationsTransportAction extends HandledTransportAction { +public class SearchConversationsTransportAction extends HandledTransportAction { private ConversationalMemoryHandler cmHandler; private Client client; @@ -61,7 +62,7 @@ public SearchConversationsTransportAction( Client client, ClusterService clusterService ) { - super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new); + super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.cmHandler = cmHandler; this.client = client; this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); @@ -71,13 +72,14 @@ public SearchConversationsTransportAction( } @Override - public void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener actionListener) { + SearchRequest request = mlSearchActionRequest.getSearchRequest(); if (!featureIsEnabled) { actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE)); return; } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { - ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener internalListener = ActionListener.runBefore(actionListener, context::restore); cmHandler.searchConversations(request, internalListener); } catch (Exception e) { log.error("Failed to search memories", e); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java index 1cee8c2ddf..86aad3be9e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java @@ -44,6 +44,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.memory.MemoryTestUtil; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.test.OpenSearchTestCase; @@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase @Mock SearchRequest request; + MLSearchActionRequest mlSearchActionRequest; + SearchConversationsTransportAction action; ThreadContext threadContext; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + mlSearchActionRequest = new MLSearchActionRequest(request, null); Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); @@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() { listener.onResponse(response); return null; }).when(cmHandler).searchConversations(any(), any()); - action.doExecute(null, request, actionListener); + action.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); verify(actionListener, times(1)).onResponse(argCaptor.capture()); assert (argCaptor.getValue().equals(response)); @@ -114,7 +118,7 @@ public void testDisabled_ThenFail() { clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled(); this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); - action.doExecute(null, request, actionListener); + action.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE); diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java index 3153394061..ad693da3cc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.agents; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -20,28 +21,46 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.transport.agent.MLSearchAgentAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class TransportSearchAgentAction extends HandledTransportAction { +public class TransportSearchAgentAction extends HandledTransportAction { private final Client client; + private final SdkClient sdkClient; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject - public TransportSearchAgentAction(TransportService transportService, ActionFilters actionFilters, Client client) { - super(MLSearchAgentAction.NAME, transportService, actionFilters, SearchRequest::new); + public TransportSearchAgentAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SdkClient sdkClient, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLSearchAgentAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - request.indices(CommonValue.ML_AGENT_INDEX); - search(request, actionListener); + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { + request.getSearchRequest().indices(CommonValue.ML_AGENT_INDEX); + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + search(request.getSearchRequest(), tenantId, actionListener); } - private void search(SearchRequest request, ActionListener actionListener) { + private void search(SearchRequest request, String tenantId, ActionListener actionListener) { ActionListener listener = wrapRestActionListener(actionListener, "Fail to search agent"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); @@ -57,6 +76,11 @@ private void search(SearchRequest request, ActionListener action // Add a should clause to include documents where IS_HIDDEN_FIELD is false shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false)); + // For multi-tenancy + if (tenantId != null) { + shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); + } + // Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD))); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java index 514c56e209..8d0813f76b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.connector; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -25,11 +26,17 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.internal.InternalSearchResponse; @@ -41,31 +48,42 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchConnectorTransportAction extends HandledTransportAction { +public class SearchConnectorTransportAction extends HandledTransportAction { private final Client client; + private final SdkClient sdkClient; private final ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject public SearchConnectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, - ConnectorAccessControlHelper connectorAccessControlHelper + SdkClient sdkClient, + ConnectorAccessControlHelper connectorAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { - super(MLConnectorSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + super(MLConnectorSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - request.indices(CommonValue.ML_CONNECTOR_INDEX); - search(request, actionListener); + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { + request.getSearchRequest().indices(CommonValue.ML_CONNECTOR_INDEX); + + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + search(request.getSearchRequest(), tenantId, actionListener); } - private void search(SearchRequest request, ActionListener actionListener) { + private void search(SearchRequest request, String tenantId, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); @@ -90,6 +108,15 @@ private void search(SearchRequest request, ActionListener action final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); + if (tenantId != null) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); // Replace with your tenant_id field + request.source().query(queryBuilder); + } + if (connectorAccessControlHelper.skipConnectorAccessControl(user)) { client.search(request, doubleWrappedListener); } else { diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 52be1dd608..5e02caa14a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -7,6 +7,7 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -79,11 +80,11 @@ public MLSearchHandler( * @param request * @param actionListener */ - public void search(SearchRequest request, ActionListener actionListener) { + public void search(SearchRequest request, String tenantId, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search model version"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); List excludes = Optional .ofNullable(request.source()) .map(SearchSourceBuilder::fetchSource) @@ -113,6 +114,11 @@ public void search(SearchRequest request, ActionListener actionL // Add a should clause to include documents where IS_HIDDEN_FIELD is false shouldQuery.should(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, false)); + // For multi-tenancy + if (tenantId != null) { + shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); + } + // Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLModel.IS_HIDDEN_FIELD))); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index 4fa95fafa6..b9e001e57b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import org.opensearch.action.search.SearchRequest; @@ -18,19 +19,27 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchModelGroupTransportAction extends HandledTransportAction { +public class SearchModelGroupTransportAction extends HandledTransportAction { Client client; + SdkClient sdkClient; ClusterService clusterService; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; ModelAccessControlHelper modelAccessControlHelper; @@ -39,36 +48,64 @@ public SearchModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { - super(MLModelGroupSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + super(MLModelGroupSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search"); - request.indices(CommonValue.ML_MODEL_GROUP_INDEX); - preProcessRoleAndPerformSearch(request, user, listener); + request.getSearchRequest().indices(CommonValue.ML_MODEL_GROUP_INDEX); + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + preProcessRoleAndPerformSearch(request.getSearchRequest(), tenantId, user, listener); } - private void preProcessRoleAndPerformSearch(SearchRequest request, User user, ActionListener listener) { + private void preProcessRoleAndPerformSearch( + SearchRequest request, + String tenantId, + User user, + ActionListener listener + ) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); + // Modify the query to include tenant ID filtering + if (tenantId != null) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + + // Preserve existing query if present + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + // Add tenancy filter + queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); + + // Update the request's source with the new query + request.source().query(queryBuilder); + } + if (modelAccessControlHelper.skipModelAccessControl(user)) { client.search(request, doubleWrappedListener); } else { // Security is enabled, filter is enabled and user isn't admin modelAccessControlHelper.addUserBackendRolesFilter(user, request.source()); - log.debug("Filtering result by " + user.getBackendRoles()); + log.debug("Filtering result by {}", user.getBackendRoles()); client.search(request, doubleWrappedListener); } } catch (Exception e) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java index 64d4913d5f..5be1d0a06d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java @@ -5,7 +5,6 @@ package org.opensearch.ml.action.models; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -14,24 +13,42 @@ import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchModelTransportAction extends HandledTransportAction { +public class SearchModelTransportAction extends HandledTransportAction { private final MLSearchHandler mlSearchHandler; + private final SdkClient sdkClient; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject - public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) { - super(MLModelSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + public SearchModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + SdkClient sdkClient, + MLSearchHandler mlSearchHandler, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLModelSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); + this.sdkClient = sdkClient; this.mlSearchHandler = mlSearchHandler; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - request.indices(CommonValue.ML_MODEL_INDEX); - mlSearchHandler.search(request, actionListener); + protected void doExecute(Task task, MLSearchActionRequest request, ActionListener actionListener) { + request.getSearchRequest().indices(CommonValue.ML_MODEL_INDEX); + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + mlSearchHandler.search(request.getSearchRequest(), tenantId, actionListener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java index c87e4087c3..527a60d8e4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.tasks; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import org.opensearch.action.search.SearchRequest; @@ -15,28 +16,65 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @Log4j2 -public class SearchTaskTransportAction extends HandledTransportAction { +public class SearchTaskTransportAction extends HandledTransportAction { private Client client; + private final SdkClient sdkClient; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject - public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { - super(MLTaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + public SearchTaskTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SdkClient sdkClient, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLTaskSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.sdkClient = sdkClient; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override - protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener actionListener) { + String tenantId = mlSearchActionRequest.getTenantId(); + SearchRequest request = mlSearchActionRequest.getSearchRequest(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { final ActionListener wrappedListener = ActionListener .wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener)); - client.search(request, ActionListener.runBefore(wrappedListener, () -> context.restore())); + + // Modify the query to include tenant ID filtering + if (tenantId != null) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + + // Preserve existing query if present + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + // Add tenancy filter + queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); // Replace 'tenant_id_field' with actual field name + + // Update the request's source with the new query + request.source().query(queryBuilder); + } + client.search(request, ActionListener.runBefore(wrappedListener, context::restore)); } catch (Exception e) { log.error(e.getMessage(), e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index ae0c890d61..e8c3a936b3 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -774,10 +774,10 @@ public List getRestHandlers( RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting); RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); - RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(); + RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting); RestMLGetTaskAction restMLGetTaskAction = new RestMLGetTaskAction(mlFeatureEnabledSetting); RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction(mlFeatureEnabledSetting); - RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction(); + RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction(mlFeatureEnabledSetting); RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService); RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction( clusterService, @@ -802,14 +802,16 @@ public List getRestHandlers( RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(clusterService, settings, mlFeatureEnabledSetting); RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting); - RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); + RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting); RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction(); RestMemoryCreateInteractionAction restCreateInteractionAction = new RestMemoryCreateInteractionAction(); RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); - RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction(); + RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction( + mlFeatureEnabledSetting + ); RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java b/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java index 4c94f8267a..d18594a228 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/AbstractMLSearchAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.ml.utils.RestActionUtils.getSourceContext; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.ArrayList; @@ -18,6 +19,8 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; @@ -32,12 +35,20 @@ public abstract class AbstractMLSearchAction extends protected final String index; protected final Class clazz; protected final ActionType actionType; + MLFeatureEnabledSetting mlFeatureEnabledSetting; - public AbstractMLSearchAction(List urlPaths, String index, Class clazz, ActionType actionType) { + public AbstractMLSearchAction( + List urlPaths, + String index, + Class clazz, + ActionType actionType, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { this.urlPaths = urlPaths; this.index = index; this.clazz = clazz; this.actionType = actionType; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override @@ -46,12 +57,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); - return channel -> client.execute(actionType, searchRequest, search(channel)); + MLSearchActionRequest mlSearchActionRequest = MLSearchActionRequest + .builder() + .searchRequest(searchRequest) + .tenantId(tenantId) + .build(); + return channel -> client.execute(actionType, mlSearchActionRequest, search(channel)); } protected RestResponseListener search(RestChannel channel) { - return new RestResponseListener(channel) { + return new RestResponseListener<>(channel) { @Override public RestResponse buildResponse(SearchResponse response) throws Exception { if (response.isTimedOut()) { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java index 45576f5cbb..61e7af7a30 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchAgentAction.java @@ -24,10 +24,9 @@ public class RestMLSearchAgentAction extends AbstractMLSearchAction { private static final String ML_SEARCH_AGENT_ACTION = "ml_search_agent_action"; private static final String SEARCH_AGENT_PATH = ML_BASE_URI + "/agents/_search"; - private final MLFeatureEnabledSetting mlFeatureEnabledSetting; public RestMLSearchAgentAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { - super(ImmutableList.of(SEARCH_AGENT_PATH), ML_AGENT_INDEX, MLAgent.class, MLSearchAgentAction.INSTANCE); + super(ImmutableList.of(SEARCH_AGENT_PATH), ML_AGENT_INDEX, MLAgent.class, MLSearchAgentAction.INSTANCE, mlFeatureEnabledSetting); this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java index 517882a7a3..c8f9afcb9c 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java @@ -10,6 +10,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import com.google.common.collect.ImmutableList; @@ -17,8 +18,14 @@ public class RestMLSearchConnectorAction extends AbstractMLSearchAction { private static final String ML_SEARCH_MODEL_ACTION = "ml_search_model_action"; private static final String SEARCH_MODEL_PATH = ML_BASE_URI + "/models/_search"; - public RestMLSearchModelAction() { - super(ImmutableList.of(SEARCH_MODEL_PATH), ML_MODEL_INDEX, MLModel.class, MLModelSearchAction.INSTANCE); + public RestMLSearchModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + super(ImmutableList.of(SEARCH_MODEL_PATH), ML_MODEL_INDEX, MLModel.class, MLModelSearchAction.INSTANCE, mlFeatureEnabledSetting); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java index d6123c57d7..c4f5b5f48e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java @@ -20,12 +20,15 @@ public class RestMLSearchModelGroupAction extends AbstractMLSearchAction { private static final String ML_SEARCH_MODEL_GROUP_ACTION = "ml_search_model_group_action"; private static final String SEARCH_MODEL_GROUP_PATH = ML_BASE_URI + "/model_groups/_search"; - private final MLFeatureEnabledSetting mlFeatureEnabledSetting; public RestMLSearchModelGroupAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { - super(ImmutableList.of(SEARCH_MODEL_GROUP_PATH), ML_MODEL_GROUP_INDEX, MLModelGroup.class, MLModelGroupSearchAction.INSTANCE); - - this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + super( + ImmutableList.of(SEARCH_MODEL_GROUP_PATH), + ML_MODEL_GROUP_INDEX, + MLModelGroup.class, + MLModelGroupSearchAction.INSTANCE, + mlFeatureEnabledSetting + ); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java index 25f69e23b8..70bf4f7894 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchTaskAction.java @@ -10,6 +10,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import com.google.common.collect.ImmutableList; @@ -17,8 +18,8 @@ public class RestMLSearchTaskAction extends AbstractMLSearchAction { private static final String ML_SEARCH_Task_ACTION = "ml_search_task_action"; private static final String SEARCH_TASK_PATH = ML_BASE_URI + "/tasks/_search"; - public RestMLSearchTaskAction() { - super(ImmutableList.of(SEARCH_TASK_PATH), ML_TASK_INDEX, MLTask.class, MLTaskSearchAction.INSTANCE); + public RestMLSearchTaskAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + super(ImmutableList.of(SEARCH_TASK_PATH), ML_TASK_INDEX, MLTask.class, MLTaskSearchAction.INSTANCE, mlFeatureEnabledSetting); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java index 5beee29c42..02c2ba8c70 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java @@ -21,18 +21,20 @@ import org.opensearch.ml.common.conversation.ConversationMeta; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import com.google.common.collect.ImmutableList; public class RestMemorySearchConversationsAction extends AbstractMLSearchAction { private static final String SEARCH_CONVERSATIONS_NAME = "conversation_memory_search_conversations"; - public RestMemorySearchConversationsAction() { + public RestMemorySearchConversationsAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { super( ImmutableList.of(ActionConstants.SEARCH_CONVERSATIONS_REST_PATH), ConversationalIndexConstants.META_INDEX_NAME, ConversationMeta.class, - SearchConversationsAction.INSTANCE + SearchConversationsAction.INSTANCE, + mlFeatureEnabledSetting ); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java index 9853578bfa..70f53d882f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java @@ -8,21 +8,37 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -30,6 +46,8 @@ public class TransportSearchAgentActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; + SearchResponse searchResponse; @Mock TransportService transportService; @@ -37,22 +55,51 @@ public class TransportSearchAgentActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock ActionListener actionListener; TransportSearchAgentAction transportSearchAgentAction; - @Mock - SearchResponse mockedSearchResponse; - @Before public void setup() { MockitoAnnotations.openMocks(this); - transportSearchAgentAction = new TransportSearchAgentAction(transportService, actionFilters, client); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + transportSearchAgentAction = new TransportSearchAgentAction( + transportService, + actionFilters, + client, + sdkClient, + mlFeatureEnabledSetting + ); ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(threadPool.getThreadContext()).thenReturn(threadContext); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } @Test @@ -62,14 +109,23 @@ public void testDoExecuteWithEmptyQuery() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockedSearchResponse); + listener.onResponse(searchResponse); return null; }).when(client).search(eq(request), any()); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); - transportSearchAgentAction.doExecute(null, request, actionListener); - + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); - verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); + } @Test @@ -77,30 +133,39 @@ public void testDoExecuteWithNonEmptyQuery() { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(QueryBuilders.matchAllQuery()); SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockedSearchResponse); + listener.onResponse(searchResponse); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); - verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); } @Test public void testDoExecuteOnFailure() { SearchRequest request = new SearchRequest("my_index"); - + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new Exception("test exception")); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); verify(actionListener, times(1)).onFailure(any(Exception.class)); @@ -111,17 +176,25 @@ public void testSearchWithHiddenField() { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); - + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockedSearchResponse); + listener.onResponse(searchResponse); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client, times(1)).search(eq(request), any()); - verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); } @Test @@ -129,14 +202,14 @@ public void testSearchException() { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); - + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new Exception("failed to search the agent index")); return null; }).when(client).search(eq(request), any()); - transportSearchAgentAction.doExecute(null, request, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -150,11 +223,53 @@ public void testSearchThrowsException() { // Create a search request SearchRequest searchRequest = new SearchRequest(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); // Execute the action - transportSearchAgentAction.doExecute(null, searchRequest, actionListener); + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); // Verify that the actionListener's onFailure method was called verify(actionListener, times(1)).onFailure(any(RuntimeException.class)); } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + + // Execute the action + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(captor.capture()); + OpenSearchStatusException exception = captor.getValue(); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + assertEquals("You don't have permission to access this resource", exception.getMessage()); + } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + sourceBuilder.query(QueryBuilders.termQuery(TENANT_ID_FIELD, "123456")); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, "123456"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + // Execute the action + transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java index 079206e621..5ec1fd8853 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/SearchConnectorTransportActionTests.java @@ -11,11 +11,18 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; @@ -23,9 +30,19 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -35,6 +52,7 @@ public class SearchConnectorTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; @Mock TransportService transportService; @@ -42,9 +60,12 @@ public class SearchConnectorTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + MLSearchActionRequest mlSearchActionRequest; + + SearchResponse searchResponse; + SearchSourceBuilder searchSourceBuilder; @Mock @@ -63,14 +84,21 @@ public class SearchConnectorTransportActionTests extends OpenSearchTestCase { @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); searchConnectorTransportAction = new SearchConnectorTransportAction( transportService, actionFilters, client, - connectorAccessControlHelper + sdkClient, + connectorAccessControlHelper, + mlFeatureEnabledSetting ); Settings settings = Settings.builder().build(); @@ -80,41 +108,100 @@ public void setup() { searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.fetchSource(fetchSourceContext); - when(searchRequest.source()).thenReturn(searchSourceBuilder); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); when(fetchSourceContext.includes()).thenReturn(new String[] {}); when(fetchSourceContext.excludes()).thenReturn(new String[] {}); + + when(connectorAccessControlHelper.addUserBackendRolesFilter(any(), any(SearchSourceBuilder.class))).thenReturn(searchSourceBuilder); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } + @Test public void test_doExecute_connectorAccessControlNotEnabled_searchSuccess() { String userString = "admin|role-1|all_access"; threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString); when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(true); - SearchResponse searchResponse = mock(SearchResponse.class); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(searchResponse); return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); verify(actionListener).onResponse(any(SearchResponse.class)); } + @Test public void test_doExecute_connectorAccessControlEnabled_searchSuccess() { when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(false); - SearchResponse searchResponse = mock(SearchResponse.class); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(searchResponse); return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); verify(actionListener).onResponse(any(SearchResponse.class)); } + @Test public void test_doExecute_exception() { - when(searchRequest.source()).thenThrow(new RuntimeException("runtime exception")); - searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); verify(actionListener).onFailure(any(RuntimeException.class)); } + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(captor.capture()); + OpenSearchStatusException exception = captor.getValue(); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + assertEquals("You don't have permission to access this resource", exception.getMessage()); + } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, "123456"); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + searchConnectorTransportAction.doExecute(task, mlSearchActionRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java index 2e3a7a84b1..94a88d1b84 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java @@ -7,6 +7,7 @@ import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -17,6 +18,7 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; @@ -47,81 +49,97 @@ private void registerModelGroup() { this.modelGroupId = response.getModelGroupId(); } + @Test public void test_empty_body_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + @Test public void test_matchAll_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.matchAllQuery()); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + @Test public void test_bool_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.termQuery("name.keyword", "mock_model_group_name"))); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + @Test public void test_term_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name.keyword", "mock_model_group_name")); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + @Test public void test_terms_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termsQuery("name.keyword", "mock_model_group_name", "test_model_group_name")); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + @Test public void test_range_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.rangeQuery("created_time").gte("now-1d")); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + @Test public void test_matchPhrase_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.matchPhraseQuery("description", "desc")); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + @Test public void test_queryString_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.queryStringQuery("name: mock_model_group_*")); - SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java index ecfb10f221..48ddb8c3c6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java @@ -6,15 +6,23 @@ package org.opensearch.ml.action.model_group; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -22,8 +30,20 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -31,6 +51,7 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; @Mock NamedXContentRegistry namedXContentRegistry; @@ -41,9 +62,17 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + SearchResponse searchResponse; + + SearchSourceBuilder searchSourceBuilder; + + MLSearchActionRequest mlSearchActionRequest; + + @Mock + FetchSourceContext fetchSourceContext; + @Mock ActionListener actionListener; @@ -56,17 +85,22 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; ThreadContext threadContext; @Before public void setup() { MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); searchModelGroupTransportAction = new SearchModelGroupTransportAction( transportService, actionFilters, client, + sdkClient, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ); Settings settings = Settings.builder().build(); @@ -74,29 +108,122 @@ public void setup() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + when(fetchSourceContext.includes()).thenReturn(new String[] {}); + when(fetchSourceContext.excludes()).thenReturn(new String[] {}); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); } + @Test public void test_DoExecute() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); verify(client).search(any(), any()); } + @Test + public void test_DoExecute_Exception() throws InterruptedException { + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("search failed")); + return null; + }).when(client).search(any(), any()); + + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); + verify(client).search(any(), any()); + verify(actionListener).onFailure(any(RuntimeException.class)); + } + + @Test public void test_skipModelAccessControlTrue() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(client).search(any(), any()); } + @Test public void test_ThreadContextError() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenThrow(new RuntimeException("thread context error")); - searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Fail to search", argumentCaptor.getValue().getMessage()); } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + + mlSearchActionRequest = new MLSearchActionRequest(request, null); + + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(captor.capture()); + OpenSearchStatusException exception = captor.getValue(); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + assertEquals("You don't have permission to access this resource", exception.getMessage()); + } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(request, "123456"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index 9bbf98e139..b347597ea4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -27,6 +27,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; @@ -110,7 +111,8 @@ private void test_empty_body_search() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER))); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -121,7 +123,8 @@ private void test_matchAll_search() { searchRequest .source() .query(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)).must(QueryBuilders.matchAllQuery())); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -141,7 +144,8 @@ private void test_bool_search() { .must(QueryBuilders.termQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt")) ) ); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -154,7 +158,8 @@ private void test_term_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.termQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -167,7 +172,8 @@ private void test_terms_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.termsQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt", "test_model_group_name")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -180,7 +186,8 @@ private void test_range_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.rangeQuery("created_time").gte("now-1d")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } @@ -193,7 +200,8 @@ private void test_matchPhrase_search() { .mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)) .must(QueryBuilders.matchPhraseQuery("description", "desc")); searchRequest.source().query(boolQueryBuilder); - SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); + SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, mlSearchActionRequest).actionGet(); assertEquals(1, response.getHits().getTotalHits().value); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 9687306241..4c6b473a60 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -16,12 +16,15 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collections; import java.util.Map; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; @@ -38,18 +41,24 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -57,6 +66,7 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; @Mock NamedXContentRegistry namedXContentRegistry; @@ -67,9 +77,12 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock SearchRequest searchRequest; + MLSearchActionRequest mlSearchActionRequest; + + SearchResponse searchResponse; + @Mock ActionListener actionListener; @@ -91,36 +104,76 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock private FetchSourceContext fetchSourceContext; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException thrown = ExpectedException.none(); @Before public void setup() { MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService)); - searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, mlSearchHandler); + searchModelTransportAction = new SearchModelTransportAction( + transportService, + actionFilters, + sdkClient, + mlSearchHandler, + mlFeatureEnabledSetting + ); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + searchRequest = new SearchRequest(new String[0], searchSourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); when(fetchSourceContext.includes()).thenReturn(new String[] {}); when(fetchSourceContext.excludes()).thenReturn(new String[] {}); - searchSourceBuilder.fetchSource(fetchSourceContext); - when(searchRequest.source()).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); Metadata metadata = mock(Metadata.class); when(metadata.hasIndex(anyString())).thenReturn(true); ClusterState testState = new ClusterState(new ClusterName("mock"), 123l, "111111", metadata, null, null, null, Map.of(), 0, false); when(clusterService.state()).thenReturn(testState); + + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } public void test_DoExecute_admin() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); } @@ -130,25 +183,22 @@ public void test_DoExecute_addBackendRoles() throws IOException { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; - }).when(client).search(any(), isA(ActionListener.class)); + }).when(client).search(any(), any()); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } public void test_DoExecute_addBackendRoles_without_groupIds() { - SearchResponse searchResponse = mock(SearchResponse.class); - SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -159,8 +209,8 @@ public void test_DoExecute_addBackendRoles_exception() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); } @@ -171,8 +221,8 @@ public void test_DoExecute_searchModel_before_model_creation_no_exception() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); } @@ -204,8 +254,8 @@ public void test_DoExecute_searchModel_before_model_creation_empty_search() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); verify(actionListener, times(1)).onResponse(any(SearchResponse.class)); @@ -218,8 +268,8 @@ public void test_DoExecute_searchModel_MLResourceNotFoundException_exception() { return null; }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } @@ -233,8 +283,8 @@ public void test_DoExecute_addBackendRoles_boolQuery() throws IOException { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.matchQuery("name", "model_IT"))); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -247,13 +297,50 @@ public void test_DoExecute_addBackendRoles_termQuery() throws IOException { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); - searchModelTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, null, actionListener); + verify(client, times(2)).search(any(), any()); + } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + mlSearchActionRequest = new MLSearchActionRequest(request, null); + + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(captor.capture()); + OpenSearchStatusException exception = captor.getValue(); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + assertEquals("You don't have permission to access this resource", exception.getMessage()); + } + + @Test + public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws InterruptedException, IOException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); + mlSearchActionRequest = new MLSearchActionRequest(searchRequest, "123456"); + + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + verify(mlSearchHandler).search(searchRequest, "123456", actionListener); verify(client, times(2)).search(any(), any()); } private SearchResponse createModelGroupSearchResponse() throws IOException { - SearchResponse searchResponse = mock(SearchResponse.class); String modelContent = "{\n" + " \"created_time\": 1684981986069,\n" + " \"access\": \"public\",\n" @@ -264,7 +351,25 @@ private SearchResponse createModelGroupSearchResponse() throws IOException { + " }"; SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); - return searchResponse; + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + hits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + return new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java index 3ad05f337a..cbba3ea011 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java @@ -7,22 +7,38 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -30,7 +46,9 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; + SearchResponse searchResponse; @Mock NamedXContentRegistry namedXContentRegistry; @@ -40,6 +58,9 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock SearchRequest searchRequest; @@ -54,15 +75,62 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, client); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + searchTaskTransportAction = new SearchTaskTransportAction( + transportService, + actionFilters, + client, + sdkClient, + mlFeatureEnabledSetting + ); ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(threadPool.getThreadContext()).thenReturn(threadContext); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); + } public void test_DoExecute() { - searchTaskTransportAction.doExecute(null, searchRequest, actionListener); - verify(client).search(eq(searchRequest), any()); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(request), any()); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + searchTaskTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + verify(client, times(1)).search(eq(request), any()); + // Use ArgumentCaptor to capture the SearchResponse + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + // Capture the response passed to actionListener.onResponse + verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchAgentActionTests.java index 22c1a9ee48..7f0d2d658a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchAgentActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchAgentActionTests.java @@ -37,6 +37,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.transport.agent.MLSearchAgentAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; @@ -136,11 +137,12 @@ public void testPrepareRequest() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchAgentAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLSearchAgentAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_AGENT_INDEX }, indices); assertEquals( @@ -189,11 +191,12 @@ public void testPrepareRequest_timeout() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchAgentAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLSearchAgentAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_AGENT_INDEX }, indices); assertEquals( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java index 6ff65715ba..0d922494e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java @@ -40,6 +40,8 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; @@ -56,6 +58,9 @@ public class RestMLSearchConnectorActionTests extends OpenSearchTestCase { private RestMLSearchConnectorAction restMLSearchConnectorAction; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + NodeClient client; private ThreadPool threadPool; @Mock @@ -64,7 +69,7 @@ public class RestMLSearchConnectorActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - restMLSearchConnectorAction = new RestMLSearchConnectorAction(); + restMLSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -110,7 +115,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLSearchConnectorAction mlSearchConnectorAction = new RestMLSearchConnectorAction(); + RestMLSearchConnectorAction mlSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting); assertNotNull(mlSearchConnectorAction); } @@ -134,11 +139,12 @@ public void testPrepareRequest() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchConnectorAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices); assertEquals( @@ -180,11 +186,12 @@ public void testPrepareRequest_timeout() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchConnectorAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices); assertEquals( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java index b679b73585..b9e20bf474 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.getSearchAllRestRequest; @@ -36,6 +37,8 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; @@ -52,6 +55,9 @@ public class RestMLSearchModelActionTests extends OpenSearchTestCase { private RestMLSearchModelAction restMLSearchModelAction; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + NodeClient client; private ThreadPool threadPool; @Mock @@ -60,7 +66,7 @@ public class RestMLSearchModelActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - restMLSearchModelAction = new RestMLSearchModelAction(); + restMLSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -106,7 +112,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLSearchModelAction mlSearchModelAction = new RestMLSearchModelAction(); + RestMLSearchModelAction mlSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting); assertNotNull(mlSearchModelAction); } @@ -130,11 +136,12 @@ public void testPrepareRequest() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchModelAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); assertEquals( @@ -145,6 +152,27 @@ public void testPrepareRequest() throws Exception { assertNotEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); } + public void testPrepareRequest_multiTenancy() throws Exception { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + RestRequest request = getSearchAllRestRequest(); + restMLSearchModelAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); + verify(channel, times(1)).sendResponse(responseCaptor.capture()); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); + String[] indices = searchRequest.indices(); + assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); + assertEquals( + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + searchRequest.source().toString() + ); + RestResponse restResponse = responseCaptor.getValue(); + assertNotEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); + } + public void testPrepareRequest_timeout() throws Exception { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -176,11 +204,12 @@ public void testPrepareRequest_timeout() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchModelAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); assertEquals( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java index 0aa45c4ea3..7959881da2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java @@ -36,6 +36,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; @@ -134,11 +135,12 @@ public void testPrepareRequest() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchModelGroupAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices); assertEquals( @@ -180,11 +182,12 @@ public void testPrepareRequest_timeout() throws Exception { RestRequest request = getSearchAllRestRequest(); restMLSearchModelGroupAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), argumentCaptor.capture(), any()); verify(channel, times(1)).sendResponse(responseCaptor.capture()); - SearchRequest searchRequest = argumentCaptor.getValue(); + MLSearchActionRequest mlSearchActionRequest = argumentCaptor.getValue(); + SearchRequest searchRequest = mlSearchActionRequest.getSearchRequest(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices); assertEquals( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java index 58713cb277..3c917a9568 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchTaskActionTests.java @@ -9,7 +9,9 @@ import org.hamcrest.Matchers; import org.junit.Before; +import org.mockito.Mock; import org.opensearch.core.common.Strings; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; @@ -18,13 +20,16 @@ public class RestMLSearchTaskActionTests extends OpenSearchTestCase { private RestMLSearchTaskAction restMLSearchTaskAction; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { - restMLSearchTaskAction = new RestMLSearchTaskAction(); + restMLSearchTaskAction = new RestMLSearchTaskAction(mlFeatureEnabledSetting); } public void testConstructor() { - RestMLSearchTaskAction mlSearchTaskAction = new RestMLSearchTaskAction(); + RestMLSearchTaskAction mlSearchTaskAction = new RestMLSearchTaskAction(mlFeatureEnabledSetting); assertNotNull(mlSearchTaskAction); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java index 294e3deac4..adf2b37417 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java @@ -27,15 +27,19 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; -import org.opensearch.action.search.SearchRequest; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; import org.opensearch.test.OpenSearchTestCase; import com.google.gson.Gson; @@ -44,13 +48,17 @@ public class RestMemorySearchConversationsActionTests extends OpenSearchTestCase Gson gson; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { + MockitoAnnotations.openMocks(this); gson = new Gson(); } public void testBasics() { - RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(mlFeatureEnabledSetting); assert (action.getName().equals("conversation_memory_search_conversations")); List routes = action.routes(); assert (routes.size() == 2); @@ -59,15 +67,17 @@ public void testBasics() { } public void testPreprareRequest() throws Exception { - RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(mlFeatureEnabledSetting); RestRequest request = TestHelper.getSearchAllRestRequest(); NodeClient client = mock(NodeClient.class); RestChannel channel = mock(RestChannel.class); action.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLSearchActionRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(SearchConversationsAction.INSTANCE), argumentCaptor.capture(), any()); - assert (argumentCaptor.getValue().source().query() instanceof MatchAllQueryBuilder); + assert (argumentCaptor.getValue().getSearchRequest().source().query() instanceof MatchAllQueryBuilder); } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 860574838a..21d3018032 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -399,7 +399,11 @@ public static RestRequest getExecuteAgentRestRequest() { } public static RestRequest getSearchAllRestRequest() { + String tenantId = "test-tenant"; + Map> headers = new HashMap<>(); + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withHeaders(headers) .withContent(new BytesArray(TestData.matchAllSearchQuery()), XContentType.JSON) .build(); return request;