forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
applying multi-tenancy in search [model, model group, agent, connecto…
…r] (opensearch-project#3433) * applying multi-tenancy in search Signed-off-by: Dhrubo Saha <[email protected]> * addressed comments Signed-off-by: Dhrubo Saha <[email protected]> --------- Signed-off-by: Dhrubo Saha <[email protected]>
- Loading branch information
Showing
32 changed files
with
1,187 additions
and
172 deletions.
There are no files selected for viewing
92 changes: 92 additions & 0 deletions
92
common/src/main/java/org/opensearch/ml/common/transport/search/MLSearchActionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package org.opensearch.ml.common.transport.search; | ||
|
||
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import org.opensearch.Version; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
|
||
/** | ||
* Represents an extended search action request that includes a tenant ID. | ||
* This class allows OpenSearch to include a tenant ID in search requests, | ||
* which is not natively supported in the standard {@link SearchRequest}. | ||
*/ | ||
@Getter | ||
public class MLSearchActionRequest extends SearchRequest { | ||
SearchRequest searchRequest; | ||
String tenantId; | ||
|
||
/** | ||
* Constructor for building an MLSearchActionRequest. | ||
* | ||
* @param searchRequest The original {@link SearchRequest} to be wrapped. | ||
* @param tenantId The tenant ID associated with the request. | ||
*/ | ||
@Builder | ||
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) { | ||
this.searchRequest = searchRequest; | ||
this.tenantId = tenantId; | ||
} | ||
|
||
/** | ||
* Deserializes an {@link MLSearchActionRequest} from a {@link StreamInput}. | ||
* | ||
* @param input The stream input to read from. | ||
* @throws IOException If an I/O error occurs during deserialization. | ||
*/ | ||
public MLSearchActionRequest(StreamInput input) throws IOException { | ||
super(input); | ||
Version streamInputVersion = input.getVersion(); | ||
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; | ||
} | ||
|
||
/** | ||
* Serializes this {@link MLSearchActionRequest} to a {@link StreamOutput}. | ||
* | ||
* @param output The stream output to write to. | ||
* @throws IOException If an I/O error occurs during serialization. | ||
*/ | ||
@Override | ||
public void writeTo(StreamOutput output) throws IOException { | ||
super.writeTo(output); | ||
Version streamOutputVersion = output.getVersion(); | ||
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { | ||
output.writeOptionalString(tenantId); | ||
} | ||
} | ||
|
||
/** | ||
* Converts a generic {@link ActionRequest} into an {@link MLSearchActionRequest}. | ||
* This is useful when handling requests that may need to be converted for compatibility. | ||
* | ||
* @param actionRequest The original {@link ActionRequest}. | ||
* @return The converted {@link MLSearchActionRequest}. | ||
* @throws UncheckedIOException If the conversion fails due to an I/O error. | ||
*/ | ||
public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof MLSearchActionRequest) { | ||
return (MLSearchActionRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLSearchActionRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e); | ||
} | ||
} | ||
} |
147 changes: 147 additions & 0 deletions
147
...on/src/test/java/org/opensearch/ml/common/transport/search/MLSearchActionRequestTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package org.opensearch.ml.common.transport.search; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertNotSame; | ||
import static org.junit.Assert.assertNull; | ||
import static org.junit.Assert.assertSame; | ||
|
||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.opensearch.Version; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
public class MLSearchActionRequestTest { | ||
|
||
private SearchRequest searchRequest; | ||
|
||
@Before | ||
public void setUp() { | ||
searchRequest = new SearchRequest("test-index"); | ||
} | ||
|
||
@Test | ||
public void testConstructorAndGetters() { | ||
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); | ||
assertEquals("test-index", request.getSearchRequest().indices()[0]); | ||
assertEquals("test-tenant", request.getTenantId()); | ||
} | ||
|
||
@Test | ||
public void testStreamConstructorAndWriteTo() throws IOException { | ||
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); | ||
BytesStreamOutput out = new BytesStreamOutput(); | ||
request.writeTo(out); | ||
|
||
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput()); | ||
assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]); | ||
assertEquals("test-tenant", deserializedRequest.getTenantId()); | ||
} | ||
|
||
@Test | ||
public void testWriteToWithNullSearchRequest() throws IOException { | ||
MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build(); | ||
BytesStreamOutput out = new BytesStreamOutput(); | ||
request.writeTo(out); | ||
|
||
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput()); | ||
assertNull(deserializedRequest.getSearchRequest()); | ||
assertEquals("test-tenant", deserializedRequest.getTenantId()); | ||
} | ||
|
||
@Test | ||
public void testFromActionRequestWithMLSearchActionRequest() { | ||
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); | ||
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request); | ||
assertSame(result, request); | ||
} | ||
|
||
@Test | ||
public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException { | ||
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); | ||
ActionRequest actionRequest = new ActionRequest() { | ||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
request.writeTo(out); | ||
} | ||
}; | ||
|
||
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest); | ||
assertNotSame(result, request); | ||
assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]); | ||
assertEquals(request.getTenantId(), result.getTenantId()); | ||
} | ||
|
||
@Test(expected = UncheckedIOException.class) | ||
public void testFromActionRequestIOException() { | ||
ActionRequest actionRequest = new ActionRequest() { | ||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
throw new IOException("test"); | ||
} | ||
}; | ||
MLSearchActionRequest.fromActionRequest(actionRequest); | ||
} | ||
|
||
@Test | ||
public void testBackwardCompatibility() throws IOException { | ||
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); | ||
|
||
BytesStreamOutput out = new BytesStreamOutput(); | ||
out.setVersion(Version.V_2_18_0); // Older version | ||
request.writeTo(out); | ||
|
||
StreamInput in = out.bytes().streamInput(); | ||
in.setVersion(Version.V_2_18_0); | ||
|
||
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in); | ||
assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored | ||
} | ||
|
||
@Test | ||
public void testFromActionRequestWithValidRequest() { | ||
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build(); | ||
|
||
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request); | ||
assertSame(request, result); | ||
} | ||
|
||
@Test | ||
public void testMixedVersionCompatibility() throws IOException { | ||
MLSearchActionRequest originalRequest = MLSearchActionRequest | ||
.builder() | ||
.searchRequest(searchRequest) | ||
.tenantId("test-tenant") | ||
.build(); | ||
|
||
// Serialize with a newer version | ||
BytesStreamOutput out = new BytesStreamOutput(); | ||
out.setVersion(Version.V_2_19_0); | ||
originalRequest.writeTo(out); | ||
|
||
// Deserialize with an older version | ||
StreamInput in = out.bytes().streamInput(); | ||
in.setVersion(Version.V_2_18_0); | ||
|
||
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in); | ||
assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.