Skip to content

Commit

Permalink
applying multi-tenancy in search
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os committed Jan 28, 2025
1 parent 06a2b40 commit 57f9afa
Show file tree
Hide file tree
Showing 29 changed files with 1,079 additions and 157 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.opensearch.ml.common.transport.search;

import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.Builder;
import lombok.Getter;

@Getter
public class MLSearchActionRequest extends SearchRequest {
SearchRequest searchRequest;
String tenantId;

@Builder
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) {
this.searchRequest = searchRequest;
this.tenantId = tenantId;
}

public MLSearchActionRequest(StreamInput input) throws IOException {
super(input);
Version streamInputVersion = input.getVersion();
if (input.readBoolean()) {
searchRequest = new SearchRequest(input);
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

@Override
public void writeTo(StreamOutput output) throws IOException {
super.writeTo(output);
Version streamOutputVersion = output.getVersion();
if (searchRequest != null) {
output.writeBoolean(true); // user exists
searchRequest.writeTo(output);
} else {
output.writeBoolean(false); // user does not exist
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
output.writeOptionalString(tenantId);
}
}

public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLSearchActionRequest) {
return (MLSearchActionRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLSearchActionRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package org.opensearch.ml.common.transport.search;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class MLSearchActionRequestTest {

private SearchRequest searchRequest;

@Before
public void setUp() {
searchRequest = new SearchRequest("test-index");
}

@Test
public void testConstructorAndGetters() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
assertEquals("test-index", request.getSearchRequest().indices()[0]);
assertEquals("test-tenant", request.getTenantId());
}

@Test
public void testStreamConstructorAndWriteTo() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
BytesStreamOutput out = new BytesStreamOutput();
request.writeTo(out);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]);
assertEquals("test-tenant", deserializedRequest.getTenantId());
}

@Test
public void testWriteToWithNullSearchRequest() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build();
BytesStreamOutput out = new BytesStreamOutput();
request.writeTo(out);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
assertNull(deserializedRequest.getSearchRequest());
assertEquals("test-tenant", deserializedRequest.getTenantId());
}

@Test
public void testFromActionRequestWithMLSearchActionRequest() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
assertSame(result, request);
}

@Test
public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
request.writeTo(out);
}
};

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest);
assertNotSame(result, request);
assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]);
assertEquals(request.getTenantId(), result.getTenantId());
}

@Test(expected = UncheckedIOException.class)
public void testFromActionRequestIOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException("test");
}
};
MLSearchActionRequest.fromActionRequest(actionRequest);
}

@Test
public void testBackwardCompatibility() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();

BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_18_0); // Older version
request.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_18_0);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored
}

@Test
public void testFromActionRequestWithValidRequest() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
assertSame(request, result);
}

@Test
public void testMixedVersionCompatibility() throws IOException {
MLSearchActionRequest originalRequest = MLSearchActionRequest
.builder()
.searchRequest(searchRequest)
.tenantId("test-tenant")
.build();

// Serialize with a newer version
BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_19_0);
originalRequest.writeTo(out);

// Deserialize with an older version
StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_18_0);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -38,7 +39,7 @@
import lombok.extern.log4j.Log4j2;

@Log4j2
public class SearchConversationsTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
public class SearchConversationsTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {

private ConversationalMemoryHandler cmHandler;
private Client client;
Expand All @@ -61,7 +62,7 @@ public SearchConversationsTransportAction(
Client client,
ClusterService clusterService
) {
super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new);
super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
Expand All @@ -71,13 +72,14 @@ public SearchConversationsTransportAction(
}

@Override
public void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
SearchRequest request = mlSearchActionRequest.getSearchRequest();
if (!featureIsEnabled) {
actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE));
return;
} else {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, context::restore);
cmHandler.searchConversations(request, internalListener);
} catch (Exception e) {
log.error("Failed to search memories", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.memory.MemoryTestUtil;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase
@Mock
SearchRequest request;

MLSearchActionRequest mlSearchActionRequest;

SearchConversationsTransportAction action;
ThreadContext threadContext;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
mlSearchActionRequest = new MLSearchActionRequest(request, null);

Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
this.threadContext = new ThreadContext(settings);
Expand All @@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() {
listener.onResponse(response);
return null;
}).when(cmHandler).searchConversations(any(), any());
action.doExecute(null, request, actionListener);
action.doExecute(null, mlSearchActionRequest, actionListener);
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);
verify(actionListener, times(1)).onResponse(argCaptor.capture());
assert (argCaptor.getValue().equals(response));
Expand All @@ -114,7 +118,7 @@ public void testDisabled_ThenFail() {
clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled();
this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

action.doExecute(null, request, actionListener);
action.doExecute(null, mlSearchActionRequest, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.agents;

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -20,28 +21,46 @@
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLSearchAgentAction;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class TransportSearchAgentAction extends HandledTransportAction<SearchRequest, SearchResponse> {
public class TransportSearchAgentAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
private final Client client;
private final SdkClient sdkClient;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportSearchAgentAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(MLSearchAgentAction.NAME, transportService, actionFilters, SearchRequest::new);
public TransportSearchAgentAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
SdkClient sdkClient,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLSearchAgentAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
this.client = client;
this.sdkClient = sdkClient;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(CommonValue.ML_AGENT_INDEX);
search(request, actionListener);
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
request.getSearchRequest().indices(CommonValue.ML_AGENT_INDEX);
String tenantId = request.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
search(request.getSearchRequest(), tenantId, actionListener);
}

private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
private void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search agent");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
Expand All @@ -57,6 +76,11 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false));

// For multi-tenancy
if (tenantId != null) {
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
}

// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));

Expand Down
Loading

0 comments on commit 57f9afa

Please sign in to comment.