Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/extensions] Pass full RestResponse to user from Extension #4356

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,104 @@

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.transport.TransportResponse;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
* Response to execute REST Actions on the extension node.
* Response to execute REST Actions on the extension node. Wraps the components of a {@link RestResponse}.
*
* @opensearch.internal
*/
public class RestExecuteOnExtensionResponse extends TransportResponse {
private String response;

public RestExecuteOnExtensionResponse(String response) {
this.response = response;
private RestStatus status;
private String contentType;
private byte[] content;
private Map<String, List<String>> headers;

/**
* Instantiate this object with a status and response string.
*
* @param status The REST status.
* @param responseString The response content as a String.
*/
public RestExecuteOnExtensionResponse(RestStatus status, String responseString) {
this(status, BytesRestResponse.TEXT_CONTENT_TYPE, responseString.getBytes(StandardCharsets.UTF_8), Collections.emptyMap());
}

/**
* Instantiate this object with the components of a {@link RestResponse}.
*
* @param status The REST status.
* @param contentType The type of the content.
* @param content The content.
* @param headers The headers.
*/
public RestExecuteOnExtensionResponse(RestStatus status, String contentType, byte[] content, Map<String, List<String>> headers) {
setStatus(status);
setContentType(contentType);
setContent(content);
setHeaders(headers);
}

/**
* Instantiate this object from a Transport Stream
*
* @param in The stream input.
* @throws IOException on transport failure.
*/
public RestExecuteOnExtensionResponse(StreamInput in) throws IOException {
response = in.readString();
setStatus(RestStatus.readFrom(in));
setContentType(in.readString());
setContent(in.readByteArray());
setHeaders(in.readMapOfLists(StreamInput::readString, StreamInput::readString));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(response);
RestStatus.writeTo(out, status);
out.writeString(contentType);
out.writeByteArray(content);
out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString);
}

public RestStatus getStatus() {
return status;
}

public void setStatus(RestStatus status) {
this.status = status;
}

public String getContentType() {
return contentType;
}

public void setContentType(String contentType) {
this.contentType = contentType;
}

public byte[] getContent() {
return content;
}

public void setContent(byte[] content) {
this.content = content;
}

public Map<String, List<String>> getHeaders() {
return headers;
}

public String getResponse() {
return response;
public void setHeaders(Map<String, List<String>> headers) {
this.headers = Map.copyOf(headers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.extensions.DiscoveryExtension;
Expand All @@ -26,11 +25,14 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableList;

/**
Expand Down Expand Up @@ -97,8 +99,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
}
String message = "Forwarding the request " + method + " " + uri + " to " + discoveryExtension;
logger.info(message);
// Hack to pass a final class in to fetch the response string
final StringBuilder responseBuilder = new StringBuilder();
// Initialize response. Values will be changed in the handler.
final RestExecuteOnExtensionResponse restExecuteOnExtensionResponse = new RestExecuteOnExtensionResponse(
RestStatus.INTERNAL_SERVER_ERROR,
BytesRestResponse.TEXT_CONTENT_TYPE,
message.getBytes(StandardCharsets.UTF_8),
emptyMap()
);
final CountDownLatch inProgressLatch = new CountDownLatch(1);
final TransportResponseHandler<RestExecuteOnExtensionResponse> restExecuteOnExtensionResponseHandler = new TransportResponseHandler<
RestExecuteOnExtensionResponse>() {
Expand All @@ -110,15 +117,20 @@ public RestExecuteOnExtensionResponse read(StreamInput in) throws IOException {

@Override
public void handleResponse(RestExecuteOnExtensionResponse response) {
responseBuilder.append(response.getResponse());
logger.info("Received response from extension: {}", response.getResponse());
logger.info("Received response from extension: {}", response.getStatus());
restExecuteOnExtensionResponse.setStatus(response.getStatus());
restExecuteOnExtensionResponse.setContentType(response.getContentType());
restExecuteOnExtensionResponse.setContent(response.getContent());
restExecuteOnExtensionResponse.setHeaders(response.getHeaders());
inProgressLatch.countDown();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to figure out a way to remove latches. I think OpenSearch does it via Listeners.

}

@Override
public void handleException(TransportException exp) {
responseBuilder.append("FAILED: ").append(exp);
logger.debug(new ParameterizedMessage("REST request failed"), exp);
logger.debug("REST request failed", exp);
// Status is already defaulted to 500 (INTERNAL_SERVER_ERROR)
byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8);
restExecuteOnExtensionResponse.setContent(responseBytes);
inProgressLatch.countDown();
}

Expand All @@ -144,12 +156,18 @@ public String executor() {
} catch (Exception e) {
logger.info("Failed to send REST Actions to extension " + discoveryExtension.getName(), e);
}
String response = responseBuilder.toString();
if (response.isBlank() || response.startsWith("FAILED")) {
return channel -> channel.sendResponse(
new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, response.isBlank() ? "Request Failed" : response)
);

BytesRestResponse restResponse = new BytesRestResponse(
restExecuteOnExtensionResponse.getStatus(),
restExecuteOnExtensionResponse.getContentType(),
restExecuteOnExtensionResponse.getContent()
);
for (Entry<String, List<String>> headerEntry : restExecuteOnExtensionResponse.getHeaders().entrySet()) {
for (String value : headerEntry.getValue()) {
restResponse.addHeader(headerEntry.getKey(), value);
}
}
return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.OK, response));

return channel -> channel.sendResponse(restResponse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
package org.opensearch.extensions.rest;

import java.util.List;

import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.BytesStreamInput;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.test.OpenSearchTestCase;

public class RegisterRestActionsTests extends OpenSearchTestCase {
Expand All @@ -17,11 +21,42 @@ public void testRegisterRestActionsRequest() throws Exception {
String uniqueIdStr = "uniqueid1";
List<String> expected = List.of("GET /foo", "PUT /bar", "POST /baz");
RegisterRestActionsRequest registerRestActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, expected);
assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId());

assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId());
List<String> restActions = registerRestActionsRequest.getRestActions();
assertEquals(expected.size(), restActions.size());
assertTrue(restActions.containsAll(expected));
assertTrue(expected.containsAll(restActions));

try (BytesStreamOutput out = new BytesStreamOutput()) {
registerRestActionsRequest.writeTo(out);
out.flush();
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
registerRestActionsRequest = new RegisterRestActionsRequest(in);

assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId());
restActions = registerRestActionsRequest.getRestActions();
assertEquals(expected.size(), restActions.size());
assertTrue(restActions.containsAll(expected));
assertTrue(expected.containsAll(restActions));
}
}
}

public void testRegisterRestActionsResponse() throws Exception {
String response = "This is a response";
RegisterRestActionsResponse registerRestActionsResponse = new RegisterRestActionsResponse(response);

assertEquals(response, registerRestActionsResponse.getResponse());

try (BytesStreamOutput out = new BytesStreamOutput()) {
registerRestActionsResponse.writeTo(out);
out.flush();
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
registerRestActionsResponse = new RegisterRestActionsResponse(in);

assertEquals(response, registerRestActionsResponse.getResponse());
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.extensions.rest;

import org.opensearch.rest.RestStatus;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.BytesStreamInput;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest.Method;
import org.opensearch.test.OpenSearchTestCase;

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

public class RestExecuteOnExtensionTests extends OpenSearchTestCase {

public void testRestExecuteOnExtensionRequest() throws Exception {
Method expectedMethod = Method.GET;
String expectedUri = "/test/uri";
RestExecuteOnExtensionRequest request = new RestExecuteOnExtensionRequest(expectedMethod, expectedUri);

assertEquals(expectedMethod, request.getMethod());
assertEquals(expectedUri, request.getUri());

try (BytesStreamOutput out = new BytesStreamOutput()) {
request.writeTo(out);
out.flush();
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
request = new RestExecuteOnExtensionRequest(in);

assertEquals(expectedMethod, request.getMethod());
assertEquals(expectedUri, request.getUri());
}
}
}

public void testRestExecuteOnExtensionResponse() throws Exception {
RestStatus expectedStatus = RestStatus.OK;
String expectedContentType = BytesRestResponse.TEXT_CONTENT_TYPE;
String expectedResponse = "Test response";
byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8);

RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(expectedStatus, expectedResponse);

assertEquals(expectedStatus, response.getStatus());
assertEquals(expectedContentType, response.getContentType());
assertArrayEquals(expectedResponseBytes, response.getContent());
assertEquals(0, response.getHeaders().size());

String headerKey = "foo";
List<String> headerValueList = List.of("bar", "baz");
Map<String, List<String>> expectedHeaders = Map.of(headerKey, headerValueList);

response = new RestExecuteOnExtensionResponse(expectedStatus, expectedContentType, expectedResponseBytes, expectedHeaders);

assertEquals(expectedStatus, response.getStatus());
assertEquals(expectedContentType, response.getContentType());
assertArrayEquals(expectedResponseBytes, response.getContent());

assertEquals(1, expectedHeaders.keySet().size());
assertTrue(expectedHeaders.containsKey(headerKey));

List<String> fooList = expectedHeaders.get(headerKey);
assertEquals(2, fooList.size());
assertTrue(fooList.containsAll(headerValueList));

try (BytesStreamOutput out = new BytesStreamOutput()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great. Now we are testing input/output byte stream as well. Can we create an issue to test the other request/response as well we have for extensibility?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a great thing that we're testing it, as when my test initially failed I discovered write(byte[]) is different than writeByteArray(byte[]). :-)

Yes, we should test all the other request/response code similarly. There's a small list of other fixes I'd like (compiler warnings, etc.) that I'll put in an issue later this week.

response.writeTo(out);
out.flush();
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
response = new RestExecuteOnExtensionResponse(in);

assertEquals(expectedStatus, response.getStatus());
assertEquals(expectedContentType, response.getContentType());
assertArrayEquals(expectedResponseBytes, response.getContent());

assertEquals(1, expectedHeaders.keySet().size());
assertTrue(expectedHeaders.containsKey(headerKey));

fooList = expectedHeaders.get(headerKey);
assertEquals(2, fooList.size());
assertTrue(fooList.containsAll(headerValueList));
}
}
}
}