Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Zhongnan Su <[email protected]>
  • Loading branch information
zhongnansu committed May 10, 2022
1 parent aafe4a3 commit e6b6cee
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 68 deletions.
2 changes: 1 addition & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ dependencyLicenses.enabled = false
// enable testingConventions check will cause errors like: "Classes ending with [Tests] must subclass [LuceneTestCase]"
testingConventions.enabled = false

// TODO: need to verify the thirdParty
// TODO: need to verify the thirdPartyAudit
// currently it complains missing classes like ibatis, mysql etc, should not be a problem
thirdPartyAudit.enabled = false

Expand Down
22 changes: 14 additions & 8 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.commons.sql.action.SQLActions;
import org.opensearch.commons.sql.action.TransportQueryResponse;
import org.opensearch.commons.sql.action.TransportPPLQueryResponse;
import org.opensearch.commons.sql.action.TransportSQLQueryResponse;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.plugins.ActionPlugin;
Expand All @@ -60,9 +61,10 @@
import org.opensearch.sql.plugin.rest.RestPPLQueryAction;
import org.opensearch.sql.plugin.rest.RestPPLStatsAction;
import org.opensearch.sql.plugin.rest.RestQuerySettingsAction;
import org.opensearch.sql.plugin.transport.PPLQueryHelper;
import org.opensearch.sql.plugin.transport.SQLQueryHelper;
import org.opensearch.sql.plugin.transport.TransportQueryAction;
import org.opensearch.sql.plugin.transport.TransportPPLService;
import org.opensearch.sql.plugin.transport.TransportPPLQueryAction;
import org.opensearch.sql.plugin.transport.TransportSQLService;
import org.opensearch.sql.plugin.transport.TransportSQLQueryAction;
import org.opensearch.sql.ppl.PPLService;
import org.opensearch.sql.ppl.config.PPLServiceConfig;
import org.opensearch.sql.sql.SQLService;
Expand Down Expand Up @@ -120,8 +122,12 @@ public List<RestHandler> getRestHandlers(Settings settings, RestController restC
return Arrays
.asList(
new ActionHandler<>(
new ActionType<>(SQLActions.SEND_SQL_QUERY_NAME, TransportQueryResponse::new),
TransportQueryAction.class
new ActionType<>(SQLActions.SEND_SQL_QUERY_NAME, TransportSQLQueryResponse::new),
TransportSQLQueryAction.class
),
new ActionHandler<>(
new ActionType<>(SQLActions.SEND_PPL_QUERY_NAME, TransportPPLQueryResponse::new),
TransportPPLQueryAction.class
)
);
}
Expand All @@ -145,8 +151,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);

PPLQueryHelper.getInstance().setPplService(createPPLService((NodeClient) client));
SQLQueryHelper.getInstance().setSqlService(createSQLService((NodeClient) client));
TransportPPLService.getInstance().setPplService(createPPLService((NodeClient) client));
TransportSQLService.getInstance().setSqlService(createSQLService((NodeClient) client));

return super
.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.plugin.transport;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.commons.sql.action.SQLActions;
import org.opensearch.commons.sql.action.TransportPPLQueryRequest;
import org.opensearch.commons.sql.action.TransportPPLQueryResponse;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.commons.utils.TransportHelpersKt;

import java.io.IOException;

/**
* Send PPL query transport action
*/
public class TransportPPLQueryAction extends HandledTransportAction<ActionRequest, TransportPPLQueryResponse> {
private final Client client;

@Inject
public TransportPPLQueryAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(SQLActions.SEND_PPL_QUERY_NAME, transportService, actionFilters, TransportPPLQueryRequest::new);
this.client = client;
}

/**
* {@inheritDoc}
* Transform the request and call super.doExecute() to support call from other plugins.
*/
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<TransportPPLQueryResponse> listener) {
TransportPPLQueryRequest transformedRequest;
if (request instanceof TransportPPLQueryRequest) {
transformedRequest = (TransportPPLQueryRequest) request;
} else {
transformedRequest = TransportHelpersKt.recreateObject(request, streamInput -> {
try {
return new TransportPPLQueryRequest(streamInput);
} catch (IOException e) {
listener.onFailure(e);
}
return null;
}
);
}
TransportPPLService.execute(transformedRequest, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.commons.sql.action.TransportQueryRequest;
import org.opensearch.commons.sql.action.TransportQueryResponse;
import org.opensearch.commons.sql.action.TransportPPLQueryRequest;
import org.opensearch.commons.sql.action.TransportPPLQueryResponse;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.rest.RestStatus;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
Expand All @@ -18,8 +18,6 @@
import org.opensearch.sql.exception.QueryEngineException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.executor.ExecutionEngine;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.ppl.PPLService;
import org.opensearch.sql.ppl.domain.PPLQueryRequest;
import org.opensearch.sql.protocol.response.QueryResult;
Expand All @@ -32,33 +30,36 @@

import static org.opensearch.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.sql.plugin.rest.RestPPLQueryAction.QUERY_API_ENDPOINT;
import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY;

public class PPLQueryHelper {
public class TransportPPLService {
private static PPLService pplService;
private static PPLQueryHelper INSTANCE;
private static TransportPPLService INSTANCE;

public static synchronized PPLQueryHelper getInstance() {
private TransportPPLService() {}

public static synchronized TransportPPLService getInstance() {
if (INSTANCE == null) {
INSTANCE = new PPLQueryHelper();
INSTANCE = new TransportPPLService();
}
return INSTANCE;
}

public void setPplService(PPLService pplService) {
PPLQueryHelper.pplService = pplService;
TransportPPLService.pplService = pplService;
}

public static void execute(TransportQueryRequest request, ActionListener<TransportQueryResponse> listener) {
// convert the TransportQueryRequest request to PPLQueryRequest
public static void execute(TransportPPLQueryRequest request, ActionListener<TransportPPLQueryResponse> listener) {
// convert the TransportPPLQueryRequest request to PPLQueryRequest
PPLQueryRequest pplRequest = createPPLQueryRequest(request);
// execute by ppl service
pplService.execute(pplRequest, createListener(pplRequest, listener));
}

private static ResponseListener<ExecutionEngine.QueryResponse> createListener(
PPLQueryRequest pplRequest,
ActionListener<TransportQueryResponse> listener
ActionListener<TransportPPLQueryResponse> listener
) {
Format format = pplRequest.format();
ResponseFormatter<QueryResult> formatter;
Expand All @@ -83,21 +84,19 @@ public void onResponse(ExecutionEngine.QueryResponse response) {
@Override
public void onFailure(Exception e) {
if (isClientError(e)) {
Metrics.getInstance().getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS).increment();
reportError(listener, e, BAD_REQUEST);
} else {
Metrics.getInstance().getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS).increment();
reportError(listener, e, SERVICE_UNAVAILABLE);
}
}
};
}

private static void sendResponse(String content, ActionListener<TransportQueryResponse> listener) {
listener.onResponse(new TransportQueryResponse(content));
private static void sendResponse(String content, ActionListener<TransportPPLQueryResponse> listener) {
listener.onResponse(new TransportPPLQueryResponse(content));
}

private static void reportError(ActionListener<TransportQueryResponse> listener, final Exception exception, final RestStatus status) {
private static void reportError(ActionListener<TransportPPLQueryResponse> listener, final Exception exception, final RestStatus status) {
listener.onFailure(new OpenSearchStatusException(exception.getMessage(), status));
}

Expand All @@ -112,8 +111,7 @@ private static boolean isClientError(Exception e) {
|| e instanceof SyntaxCheckException;
}

private static PPLQueryRequest createPPLQueryRequest(TransportQueryRequest request) {
return new PPLQueryRequest(request.getQuery(), null,"/_plugins/_ppl");
private static PPLQueryRequest createPPLQueryRequest(TransportPPLQueryRequest request) {
return new PPLQueryRequest(request.getQuery(), null,QUERY_API_ENDPOINT);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,24 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.ConfigConstants;
import org.opensearch.commons.authuser.User;
import org.opensearch.commons.sql.action.SQLActions;
import org.opensearch.commons.sql.action.TransportQueryRequest;
import org.opensearch.commons.sql.action.TransportQueryResponse;
import org.opensearch.commons.sql.model.QueryType;
import org.opensearch.commons.sql.action.TransportSQLQueryRequest;
import org.opensearch.commons.sql.action.TransportSQLQueryResponse;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.commons.utils.TransportHelpersKt;

import java.io.IOException;

/**
* Send SQL/PPL query transport action
* Send SQL query transport action
*/
public class TransportQueryAction extends HandledTransportAction<ActionRequest, TransportQueryResponse> {
public class TransportSQLQueryAction extends HandledTransportAction<ActionRequest, TransportSQLQueryResponse> {
private final Client client;

@Inject
public TransportQueryAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(SQLActions.SEND_SQL_QUERY_NAME, transportService, actionFilters, TransportQueryRequest::new);
public TransportSQLQueryAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(SQLActions.SEND_SQL_QUERY_NAME, transportService, actionFilters, TransportSQLQueryRequest::new);
this.client = client;
}

Expand All @@ -42,25 +38,22 @@ public TransportQueryAction(TransportService transportService, ActionFilters act
* Transform the request and call super.doExecute() to support call from other plugins.
*/
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<TransportQueryResponse> listener) {
TransportQueryRequest transformedRequest;
if (request instanceof TransportQueryRequest) {
transformedRequest = (TransportQueryRequest) request;
protected void doExecute(Task task, ActionRequest request, ActionListener<TransportSQLQueryResponse> listener) {
TransportSQLQueryRequest transformedRequest;
if (request instanceof TransportSQLQueryRequest) {
transformedRequest = (TransportSQLQueryRequest) request;
} else {
transformedRequest = TransportHelpersKt.recreateObject(request, streamInput -> {
try {
return new TransportQueryRequest(streamInput);
return new TransportSQLQueryRequest(streamInput);
} catch (IOException e) {
listener.onFailure(e);
}
return null;
}
);
}
if (transformedRequest.getType() == QueryType.PPL) {
PPLQueryHelper.execute(transformedRequest, listener);
} else if (transformedRequest.getType() == QueryType.SQL) {
SQLQueryHelper.execute(transformedRequest, listener);
}

TransportSQLService.execute(transformedRequest, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import org.json.JSONObject;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.commons.sql.action.TransportQueryRequest;
import org.opensearch.commons.sql.action.TransportQueryResponse;
import org.opensearch.commons.sql.action.TransportSQLQueryRequest;
import org.opensearch.commons.sql.action.TransportSQLQueryResponse;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.rest.RestStatus;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
Expand All @@ -19,8 +19,6 @@
import org.opensearch.sql.exception.QueryEngineException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.executor.ExecutionEngine;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.protocol.response.QueryResult;
import org.opensearch.sql.protocol.response.format.CsvResponseFormatter;
import org.opensearch.sql.protocol.response.format.Format;
Expand All @@ -34,33 +32,36 @@

import static org.opensearch.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT;
import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY;

public class SQLQueryHelper {
public class TransportSQLService {
private static SQLService sqlService;
private static SQLQueryHelper INSTANCE;
private static TransportSQLService INSTANCE;

public static synchronized SQLQueryHelper getInstance() {
private TransportSQLService(){}

public static synchronized TransportSQLService getInstance() {
if (INSTANCE == null) {
INSTANCE = new SQLQueryHelper();
INSTANCE = new TransportSQLService();
}
return INSTANCE;
}

public void setSqlService(SQLService sqlService) {
SQLQueryHelper.sqlService = sqlService;
TransportSQLService.sqlService = sqlService;
}

public static void execute(TransportQueryRequest request, ActionListener<TransportQueryResponse> listener) {
// convert the TransportQueryRequest request to SQLQueryRequest
public static void execute(TransportSQLQueryRequest request, ActionListener<TransportSQLQueryResponse> listener) {
// convert the TransportSQLQueryRequest request to SQLQueryRequest
SQLQueryRequest sqlRequest = createSQLQueryRequest(request);
// execute by sql service
sqlService.execute(sqlRequest, createListener(sqlRequest, listener));
}

private static ResponseListener<ExecutionEngine.QueryResponse> createListener(
SQLQueryRequest sqlRequest,
ActionListener<TransportQueryResponse> listener
ActionListener<TransportSQLQueryResponse> listener
) {
Format format = sqlRequest.format();
ResponseFormatter<QueryResult> formatter;
Expand All @@ -83,21 +84,19 @@ public void onResponse(ExecutionEngine.QueryResponse response) {
@Override
public void onFailure(Exception e) {
if (isClientError(e)) {
Metrics.getInstance().getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS).increment();
reportError(listener, e, BAD_REQUEST);
} else {
Metrics.getInstance().getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS).increment();
reportError(listener, e, SERVICE_UNAVAILABLE);
}
}
};
}

private static void sendResponse(String content, ActionListener<TransportQueryResponse> listener) {
listener.onResponse(new TransportQueryResponse(content));
private static void sendResponse(String content, ActionListener<TransportSQLQueryResponse> listener) {
listener.onResponse(new TransportSQLQueryResponse(content));
}

private static void reportError(ActionListener<TransportQueryResponse> listener, final Exception exception, final RestStatus status) {
private static void reportError(ActionListener<TransportSQLQueryResponse> listener, final Exception exception, final RestStatus status) {
listener.onFailure(new OpenSearchStatusException(exception.getMessage(), status));
}

Expand All @@ -112,9 +111,9 @@ private static boolean isClientError(Exception e) {
|| e instanceof SyntaxCheckException;
}

private static SQLQueryRequest createSQLQueryRequest(TransportQueryRequest request) {
private static SQLQueryRequest createSQLQueryRequest(TransportSQLQueryRequest request) {
String query = request.getQuery();
String jsonContent = "{\"query\": \"" + query + "\"}";
return new SQLQueryRequest(new JSONObject(jsonContent), query ,"/_plugins/_sql", Collections.emptyMap());
return new SQLQueryRequest(new JSONObject(jsonContent), query ,QUERY_API_ENDPOINT, Collections.emptyMap());
}
}

0 comments on commit e6b6cee

Please sign in to comment.