Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A Generic ML Command in PPL #971

Merged
merged 5 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.MODELID;
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.STATUS;
import static org.opensearch.sql.utils.MLCommonsConstants.TASKID;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT;
import static org.opensearch.sql.utils.SystemIndexUtils.CATALOGS_TABLE_NAME;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -50,6 +57,7 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.ML;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareTopN;
Expand Down Expand Up @@ -83,6 +91,7 @@
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalML;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalProject;
Expand Down Expand Up @@ -502,6 +511,19 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) {
return new LogicalAD(child, options);
}

/**
* Build {@link LogicalML} for ml command.
*/
@Override
public LogicalPlan visitML(ML node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
TypeEnvironment currentEnv = context.peek();
node.getOutputSchema(currentEnv).entrySet().stream()
.forEach(v -> currentEnv.define(new Symbol(Namespace.FIELD_NAME, v.getKey()), v.getValue()));

return new LogicalML(child, node.getArguments());
}

/**
* The first argument is always "asc", others are optional.
* Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package org.opensearch.sql.analysis;

import static org.opensearch.sql.analysis.symbol.Namespace.FIELD_NAME;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -82,7 +84,7 @@ public void define(Symbol symbol, ExprType type) {
* @param ref {@link ReferenceExpression}
*/
public void define(ReferenceExpression ref) {
define(new Symbol(Namespace.FIELD_NAME, ref.getAttr()), ref.type());
define(new Symbol(FIELD_NAME, ref.getAttr()), ref.type());
}

public void remove(Symbol symbol) {
Expand All @@ -93,6 +95,14 @@ public void remove(Symbol symbol) {
* Remove ref.
*/
public void remove(ReferenceExpression ref) {
remove(new Symbol(Namespace.FIELD_NAME, ref.getAttr()));
remove(new Symbol(FIELD_NAME, ref.getAttr()));
}

/**
* Clear all fields in the current environment.
*/
public void clearAllFields() {
lookupAllFields(FIELD_NAME).keySet().stream()
.forEach(v -> remove(new Symbol(Namespace.FIELD_NAME, v)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.ML;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareTopN;
Expand Down Expand Up @@ -266,6 +267,10 @@ public T visitAD(AD node, C context) {
return visitChildren(node, context);
}

public T visitML(ML node, C context) {
return visitChildren(node, context);
}

public T visitHighlightFunction(HighlightFunction node, C context) {
return visitChildren(node, context);
}
Expand Down
135 changes: 135 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/ML.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.ALGO;
import static org.opensearch.sql.utils.MLCommonsConstants.ASYNC;
import static org.opensearch.sql.utils.MLCommonsConstants.CLUSTERID;
import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS;
import static org.opensearch.sql.utils.MLCommonsConstants.MODELID;
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.STATUS;
import static org.opensearch.sql.utils.MLCommonsConstants.TASKID;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT;

import com.google.common.collect.ImmutableList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.analysis.TypeEnvironment;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.data.type.ExprCoreType;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class ML extends UnresolvedPlan {
private UnresolvedPlan child;

private final Map<String, Literal> arguments;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitML(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}

private String getAction() {
return (String) arguments.get(ACTION).getValue();
}

/**
* Generate the ml output schema.
*
* @param env the current environment
* @return the schema
*/
public Map<String, ExprCoreType> getOutputSchema(TypeEnvironment env) {
switch (getAction()) {
case TRAIN:
env.clearAllFields();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why clean all fields?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ml train command will return only model/task id, and status, so remove all fields from input fields.

return getTrainOutputSchema();
case PREDICT:
case TRAINANDPREDICT:
return getPredictOutputSchema();
default:
throw new IllegalArgumentException(
"Action error. Please indicate train, predict or trainandpredict.");
}
}

/**
* Generate the ml predict output schema.
*
* @return the schema
*/
public Map<String, ExprCoreType> getPredictOutputSchema() {
HashMap<String, ExprCoreType> res = new HashMap<>();
String algo = arguments.containsKey(ALGO) ? (String) arguments.get(ALGO).getValue() : null;
switch (algo) {
case KMEANS:
res.put(CLUSTERID, ExprCoreType.INTEGER);
break;
case RCF:
res.put(RCF_SCORE, ExprCoreType.DOUBLE);
if (arguments.containsKey(RCF_TIME_FIELD)) {
res.put(RCF_ANOMALY_GRADE, ExprCoreType.DOUBLE);
res.put((String) arguments.get(RCF_TIME_FIELD).getValue(), ExprCoreType.TIMESTAMP);
} else {
res.put(RCF_ANOMALOUS, ExprCoreType.BOOLEAN);
}
break;
default:
throw new IllegalArgumentException("Unsupported algorithm: " + algo);
}
return res;
}

/**
* Generate the ml train output schema.
*
* @return the schema
*/
public Map<String, ExprCoreType> getTrainOutputSchema() {
boolean isAsync = arguments.containsKey(ASYNC)
? (boolean) arguments.get(ASYNC).getValue() : false;
Map<String, ExprCoreType> res = new HashMap<>(Map.of(STATUS, ExprCoreType.STRING));
if (isAsync) {
res.put(TASKID, ExprCoreType.STRING);
} else {
res.put(MODELID, ExprCoreType.STRING);
}
return res;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;

/**
* ML logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalML extends LogicalPlan {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you plan deprecated Kmeans/AD in Logical plan?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we will deprecate them in the following version like 2.5, then remove them in 3.0.

private final Map<String, Literal> arguments;

/**
* Constructor of LogicalML.
* @param child child logical plan
* @param arguments arguments of the algorithm
*/
public LogicalML(LogicalPlan child, Map<String, Literal> arguments) {
super(Collections.singletonList(child));
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitML(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ public R visitMLCommons(LogicalMLCommons plan, C context) {
return visitNode(plan, context);
}

public R visitML(LogicalML plan, C context) {
return visitNode(plan, context);
}

public R visitAD(LogicalAD plan, C context) {
return visitNode(plan, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,8 @@ public R visitMLCommons(PhysicalPlan node, C context) {
public R visitAD(PhysicalPlan node, C context) {
return visitNode(node, context);
}

public R visitML(PhysicalPlan node, C context) {
return visitNode(node, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,20 @@ public class MLCommonsConstants {
public static final String CENTROIDS = "centroids";
public static final String ITERATIONS = "iterations";
public static final String DISTANCE_TYPE = "distance_type";

public static final String ACTION = "action";
public static final String TRAIN = "train";
public static final String PREDICT = "predict";
public static final String TRAINANDPREDICT = "trainandpredict";
public static final String ASYNC = "async";
public static final String ALGO = "algorithm";
public static final String KMEANS = "kmeans";
public static final String CLUSTERID = "ClusterID";
public static final String RCF = "rcf";
public static final String RCF_TIME_FIELD = "timeField";
public static final String MODELID = "model_id";
public static final String TASKID = "task_id";
public static final String STATUS = "status";
public static final String LIR = "linear_regression";
public static final String LIR_TARGET = "target";
}
Loading