diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java index 8b826fab9c4bb..39661bd78d996 100644 --- a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java +++ b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java @@ -10,32 +10,104 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.RestStatus; import org.opensearch.transport.TransportResponse; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; /** - * Response to execute REST Actions on the extension node. + * Response to execute REST Actions on the extension node. Wraps the components of a {@link RestResponse}. * * @opensearch.internal */ public class RestExecuteOnExtensionResponse extends TransportResponse { - private String response; - public RestExecuteOnExtensionResponse(String response) { - this.response = response; + private RestStatus status; + private String contentType; + private byte[] content; + private Map> headers; + + /** + * Instantiate this object with a status and response string. + * + * @param status The REST status. + * @param responseString The response content as a String. + */ + public RestExecuteOnExtensionResponse(RestStatus status, String responseString) { + this(status, BytesRestResponse.TEXT_CONTENT_TYPE, responseString.getBytes(StandardCharsets.UTF_8), Collections.emptyMap()); + } + + /** + * Instantiate this object with the components of a {@link RestResponse}. + * + * @param status The REST status. + * @param contentType The type of the content. + * @param content The content. + * @param headers The headers. + */ + public RestExecuteOnExtensionResponse(RestStatus status, String contentType, byte[] content, Map> headers) { + setStatus(status); + setContentType(contentType); + setContent(content); + setHeaders(headers); } + /** + * Instantiate this object from a Transport Stream + * + * @param in The stream input. + * @throws IOException on transport failure. + */ public RestExecuteOnExtensionResponse(StreamInput in) throws IOException { - response = in.readString(); + setStatus(RestStatus.readFrom(in)); + setContentType(in.readString()); + setContent(in.readByteArray()); + setHeaders(in.readMapOfLists(StreamInput::readString, StreamInput::readString)); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(response); + RestStatus.writeTo(out, status); + out.writeString(contentType); + out.writeByteArray(content); + out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString); + } + + public RestStatus getStatus() { + return status; + } + + public void setStatus(RestStatus status) { + this.status = status; + } + + public String getContentType() { + return contentType; + } + + public void setContentType(String contentType) { + this.contentType = contentType; + } + + public byte[] getContent() { + return content; + } + + public void setContent(byte[] content) { + this.content = content; + } + + public Map> getHeaders() { + return headers; } - public String getResponse() { - return response; + public void setHeaders(Map> headers) { + this.headers = Map.copyOf(headers); } } diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java index 57213023a3d4c..1a5c0351bdaf8 100644 --- a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java +++ b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.client.node.NodeClient; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.extensions.DiscoveryExtension; @@ -26,11 +25,14 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Map.Entry; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static java.util.Collections.emptyMap; import static java.util.Collections.unmodifiableList; /** @@ -97,8 +99,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC } String message = "Forwarding the request " + method + " " + uri + " to " + discoveryExtension; logger.info(message); - // Hack to pass a final class in to fetch the response string - final StringBuilder responseBuilder = new StringBuilder(); + // Initialize response. Values will be changed in the handler. + final RestExecuteOnExtensionResponse restExecuteOnExtensionResponse = new RestExecuteOnExtensionResponse( + RestStatus.INTERNAL_SERVER_ERROR, + BytesRestResponse.TEXT_CONTENT_TYPE, + message.getBytes(StandardCharsets.UTF_8), + emptyMap() + ); final CountDownLatch inProgressLatch = new CountDownLatch(1); final TransportResponseHandler restExecuteOnExtensionResponseHandler = new TransportResponseHandler< RestExecuteOnExtensionResponse>() { @@ -110,15 +117,20 @@ public RestExecuteOnExtensionResponse read(StreamInput in) throws IOException { @Override public void handleResponse(RestExecuteOnExtensionResponse response) { - responseBuilder.append(response.getResponse()); - logger.info("Received response from extension: {}", response.getResponse()); + logger.info("Received response from extension: {}", response.getStatus()); + restExecuteOnExtensionResponse.setStatus(response.getStatus()); + restExecuteOnExtensionResponse.setContentType(response.getContentType()); + restExecuteOnExtensionResponse.setContent(response.getContent()); + restExecuteOnExtensionResponse.setHeaders(response.getHeaders()); inProgressLatch.countDown(); } @Override public void handleException(TransportException exp) { - responseBuilder.append("FAILED: ").append(exp); - logger.debug(new ParameterizedMessage("REST request failed"), exp); + logger.debug("REST request failed", exp); + // Status is already defaulted to 500 (INTERNAL_SERVER_ERROR) + byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); + restExecuteOnExtensionResponse.setContent(responseBytes); inProgressLatch.countDown(); } @@ -144,12 +156,18 @@ public String executor() { } catch (Exception e) { logger.info("Failed to send REST Actions to extension " + discoveryExtension.getName(), e); } - String response = responseBuilder.toString(); - if (response.isBlank() || response.startsWith("FAILED")) { - return channel -> channel.sendResponse( - new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, response.isBlank() ? "Request Failed" : response) - ); + + BytesRestResponse restResponse = new BytesRestResponse( + restExecuteOnExtensionResponse.getStatus(), + restExecuteOnExtensionResponse.getContentType(), + restExecuteOnExtensionResponse.getContent() + ); + for (Entry> headerEntry : restExecuteOnExtensionResponse.getHeaders().entrySet()) { + for (String value : headerEntry.getValue()) { + restResponse.addHeader(headerEntry.getKey(), value); + } } - return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.OK, response)); + + return channel -> channel.sendResponse(restResponse); } } diff --git a/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java b/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java index cce5bbb21490a..a8f1739ce82f2 100644 --- a/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java +++ b/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java @@ -9,6 +9,10 @@ package org.opensearch.extensions.rest; import java.util.List; + +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.test.OpenSearchTestCase; public class RegisterRestActionsTests extends OpenSearchTestCase { @@ -17,11 +21,42 @@ public void testRegisterRestActionsRequest() throws Exception { String uniqueIdStr = "uniqueid1"; List expected = List.of("GET /foo", "PUT /bar", "POST /baz"); RegisterRestActionsRequest registerRestActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, expected); - assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId()); + assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId()); List restActions = registerRestActionsRequest.getRestActions(); assertEquals(expected.size(), restActions.size()); assertTrue(restActions.containsAll(expected)); assertTrue(expected.containsAll(restActions)); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + registerRestActionsRequest.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + registerRestActionsRequest = new RegisterRestActionsRequest(in); + + assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId()); + restActions = registerRestActionsRequest.getRestActions(); + assertEquals(expected.size(), restActions.size()); + assertTrue(restActions.containsAll(expected)); + assertTrue(expected.containsAll(restActions)); + } + } + } + + public void testRegisterRestActionsResponse() throws Exception { + String response = "This is a response"; + RegisterRestActionsResponse registerRestActionsResponse = new RegisterRestActionsResponse(response); + + assertEquals(response, registerRestActionsResponse.getResponse()); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + registerRestActionsResponse.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + registerRestActionsResponse = new RegisterRestActionsResponse(in); + + assertEquals(response, registerRestActionsResponse.getResponse()); + } + } } } diff --git a/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java b/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java new file mode 100644 index 0000000000000..98521ddcf1e26 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.rest.RestStatus; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.test.OpenSearchTestCase; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +public class RestExecuteOnExtensionTests extends OpenSearchTestCase { + + public void testRestExecuteOnExtensionRequest() throws Exception { + Method expectedMethod = Method.GET; + String expectedUri = "/test/uri"; + RestExecuteOnExtensionRequest request = new RestExecuteOnExtensionRequest(expectedMethod, expectedUri); + + assertEquals(expectedMethod, request.getMethod()); + assertEquals(expectedUri, request.getUri()); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + request.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + request = new RestExecuteOnExtensionRequest(in); + + assertEquals(expectedMethod, request.getMethod()); + assertEquals(expectedUri, request.getUri()); + } + } + } + + public void testRestExecuteOnExtensionResponse() throws Exception { + RestStatus expectedStatus = RestStatus.OK; + String expectedContentType = BytesRestResponse.TEXT_CONTENT_TYPE; + String expectedResponse = "Test response"; + byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8); + + RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(expectedStatus, expectedResponse); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + assertEquals(0, response.getHeaders().size()); + + String headerKey = "foo"; + List headerValueList = List.of("bar", "baz"); + Map> expectedHeaders = Map.of(headerKey, headerValueList); + + response = new RestExecuteOnExtensionResponse(expectedStatus, expectedContentType, expectedResponseBytes, expectedHeaders); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + + assertEquals(1, expectedHeaders.keySet().size()); + assertTrue(expectedHeaders.containsKey(headerKey)); + + List fooList = expectedHeaders.get(headerKey); + assertEquals(2, fooList.size()); + assertTrue(fooList.containsAll(headerValueList)); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + response.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + response = new RestExecuteOnExtensionResponse(in); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + + assertEquals(1, expectedHeaders.keySet().size()); + assertTrue(expectedHeaders.containsKey(headerKey)); + + fooList = expectedHeaders.get(headerKey); + assertEquals(2, fooList.size()); + assertTrue(fooList.containsAll(headerValueList)); + } + } + } +}