Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

applying multi-tenancy in search [model, model group, agent, connector] #3433

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.opensearch.ml.common.transport.search;

import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.Builder;
import lombok.Getter;

@Getter
public class MLSearchActionRequest extends SearchRequest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add java documentation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

SearchRequest searchRequest;
String tenantId;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A new MLSearchActionRequest class was added to wrap the original SearchRequest and include a tenantId field.So I want to understand does it impact other functionality using SearchRequest?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please give me an example which functionality are you referring to?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the order matters, when did the SearchRequest converts to MLSearchActionRequest? I am worry about the tools and search pipelines that are using the search requests. if the transform is after them, it should be fine, but if that's before tools and search pipelines, does it impact them?
for example

private SearchRequest getSearchRequest(String index, String query) throws IOException {


@Builder
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) {
this.searchRequest = searchRequest;
this.tenantId = tenantId;
}

public MLSearchActionRequest(StreamInput input) throws IOException {
super(input);
Version streamInputVersion = input.getVersion();
if (input.readBoolean()) {
searchRequest = new SearchRequest(input);
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

@Override
public void writeTo(StreamOutput output) throws IOException {
super.writeTo(output);
Version streamOutputVersion = output.getVersion();
if (searchRequest != null) {
output.writeBoolean(true); // user exists
searchRequest.writeTo(output);
} else {
output.writeBoolean(false); // user does not exist
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
output.writeOptionalString(tenantId);
}
}

public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLSearchActionRequest) {
return (MLSearchActionRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLSearchActionRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this exception handling, I think IllegalArgumentException is more proper in this case

The method is expecting an ActionRequest that can be converted to MLSearchActionRequest.
If the conversion fails, it's likely because the input ActionRequest is not of the expected type or format.
This is more of a logical/argument error than an I/O error.

I think if using IllegalArgumentException, it provides clearer feedback to the caller about what went wrong.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is literally the same method we use for all other request classes. Example

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package org.opensearch.ml.common.transport.search;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class MLSearchActionRequestTest {

private SearchRequest searchRequest;

@Before
public void setUp() {
searchRequest = new SearchRequest("test-index");
}

@Test
public void testConstructorAndGetters() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
assertEquals("test-index", request.getSearchRequest().indices()[0]);
assertEquals("test-tenant", request.getTenantId());
}

@Test
public void testStreamConstructorAndWriteTo() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
BytesStreamOutput out = new BytesStreamOutput();
request.writeTo(out);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]);
assertEquals("test-tenant", deserializedRequest.getTenantId());
}

@Test
public void testWriteToWithNullSearchRequest() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build();
BytesStreamOutput out = new BytesStreamOutput();
request.writeTo(out);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
assertNull(deserializedRequest.getSearchRequest());
assertEquals("test-tenant", deserializedRequest.getTenantId());
}

@Test
public void testFromActionRequestWithMLSearchActionRequest() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
assertSame(result, request);
}

@Test
public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
request.writeTo(out);
}
};

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest);
assertNotSame(result, request);
assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]);
assertEquals(request.getTenantId(), result.getTenantId());
}

@Test(expected = UncheckedIOException.class)
public void testFromActionRequestIOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException("test");
}
};
MLSearchActionRequest.fromActionRequest(actionRequest);
}

@Test
public void testBackwardCompatibility() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();

BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_18_0); // Older version
request.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_18_0);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored
}

@Test
public void testFromActionRequestWithValidRequest() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
assertSame(request, result);
}

@Test
public void testMixedVersionCompatibility() throws IOException {
MLSearchActionRequest originalRequest = MLSearchActionRequest
.builder()
.searchRequest(searchRequest)
.tenantId("test-tenant")
.build();

// Serialize with a newer version
BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_19_0);
originalRequest.writeTo(out);

// Deserialize with an older version
StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_18_0);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -38,7 +39,7 @@
import lombok.extern.log4j.Log4j2;

@Log4j2
public class SearchConversationsTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
public class SearchConversationsTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {

private ConversationalMemoryHandler cmHandler;
private Client client;
Expand All @@ -61,7 +62,7 @@ public SearchConversationsTransportAction(
Client client,
ClusterService clusterService
) {
super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new);
super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
Expand All @@ -71,13 +72,14 @@ public SearchConversationsTransportAction(
}

@Override
public void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
SearchRequest request = mlSearchActionRequest.getSearchRequest();
if (!featureIsEnabled) {
actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE));
return;
} else {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, context::restore);
cmHandler.searchConversations(request, internalListener);
} catch (Exception e) {
log.error("Failed to search memories", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.memory.MemoryTestUtil;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase
@Mock
SearchRequest request;

MLSearchActionRequest mlSearchActionRequest;

SearchConversationsTransportAction action;
ThreadContext threadContext;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
mlSearchActionRequest = new MLSearchActionRequest(request, null);

Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
this.threadContext = new ThreadContext(settings);
Expand All @@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() {
listener.onResponse(response);
return null;
}).when(cmHandler).searchConversations(any(), any());
action.doExecute(null, request, actionListener);
action.doExecute(null, mlSearchActionRequest, actionListener);
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);
verify(actionListener, times(1)).onResponse(argCaptor.capture());
assert (argCaptor.getValue().equals(response));
Expand All @@ -114,7 +118,7 @@ public void testDisabled_ThenFail() {
clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled();
this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

action.doExecute(null, request, actionListener);
action.doExecute(null, mlSearchActionRequest, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.agents;

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -20,28 +21,46 @@
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLSearchAgentAction;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class TransportSearchAgentAction extends HandledTransportAction<SearchRequest, SearchResponse> {
public class TransportSearchAgentAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
private final Client client;
private final SdkClient sdkClient;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportSearchAgentAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(MLSearchAgentAction.NAME, transportService, actionFilters, SearchRequest::new);
public TransportSearchAgentAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
SdkClient sdkClient,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLSearchAgentAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
this.client = client;
this.sdkClient = sdkClient;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(CommonValue.ML_AGENT_INDEX);
search(request, actionListener);
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
request.getSearchRequest().indices(CommonValue.ML_AGENT_INDEX);
String tenantId = request.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
search(request.getSearchRequest(), tenantId, actionListener);
}

private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
private void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search agent");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
Expand All @@ -57,6 +76,11 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false));

// For multi-tenancy
if (tenantId != null) {
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
}

// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));

Expand Down
Loading
Loading