Skip to content

Commit

Permalink
[ML][HLRC] Add data frame analytics regression analysis (#46024)
Browse files: browse the repository at this point in the history
  • Loading branch information
dimitris-athanasiou authored Aug 28, 2019
1 parent fd3488d commit eab6425
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
new NamedXContentRegistry.Entry(
DataFrameAnalysis.class,
OutlierDetection.NAME,
(p, c) -> OutlierDetection.fromXContent(p)));
(p, c) -> OutlierDetection.fromXContent(p)),
new NamedXContentRegistry.Entry(
DataFrameAnalysis.class,
Regression.NAME,
(p, c) -> Regression.fromXContent(p)));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

/**
 * Client-side representation of the {@code regression} data frame analysis.
 *
 * <p>An instance always carries a non-null dependent variable. Every other setting is
 * optional; a setting left as {@code null} is omitted when the object is written by
 * {@link #toXContent(XContentBuilder, Params)}.
 *
 * <p>Instances are created either through {@link #builder(String)} or by parsing with
 * {@link #fromXContent(XContentParser)}.
 */
public class Regression implements DataFrameAnalysis {

    /** Name of this analysis; also used as the parser name and returned by {@link #getName()}. */
    public static final ParseField NAME = new ParseField("regression");

    static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
    static final ParseField LAMBDA = new ParseField("lambda");
    static final ParseField GAMMA = new ParseField("gamma");
    static final ParseField ETA = new ParseField("eta");
    static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
    static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
    static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
    static final ParseField TRAINING_PERCENT = new ParseField("training_percent");

    // The positions in the argument array follow the order of the declare* calls in the
    // static initializer below, so the two must be kept in sync.
    private static final ConstructingObjectParser<Regression, Void> PARSER =
        new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            true,
            args -> new Regression(
                (String) args[0],
                (Double) args[1],
                (Double) args[2],
                (Double) args[3],
                (Integer) args[4],
                (Double) args[5],
                (String) args[6],
                (Double) args[7]));

    static {
        // Only the dependent variable is a required constructor argument.
        PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
    }

    /**
     * Parses a {@link Regression} from the given parser.
     *
     * @param parser the parser positioned at the regression object
     * @return the parsed analysis
     */
    public static Regression fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Starts building a {@link Regression} for the given dependent variable.
     *
     * @param dependentVariable the dependent variable; must not be {@code null}
     * @return a new builder
     */
    public static Builder builder(String dependentVariable) {
        return new Builder(dependentVariable);
    }

    private final String dependentVariable;
    private final Double lambda;
    private final Double gamma;
    private final Double eta;
    private final Integer maximumNumberTrees;
    private final Double featureBagFraction;
    private final String predictionFieldName;
    private final Double trainingPercent;

    private Regression(String dependentVariable,
                       @Nullable Double lambda,
                       @Nullable Double gamma,
                       @Nullable Double eta,
                       @Nullable Integer maximumNumberTrees,
                       @Nullable Double featureBagFraction,
                       @Nullable String predictionFieldName,
                       @Nullable Double trainingPercent) {
        this.dependentVariable = Objects.requireNonNull(dependentVariable);
        this.lambda = lambda;
        this.gamma = gamma;
        this.eta = eta;
        this.maximumNumberTrees = maximumNumberTrees;
        this.featureBagFraction = featureBagFraction;
        this.predictionFieldName = predictionFieldName;
        this.trainingPercent = trainingPercent;
    }

    /** @return the preferred name of {@link #NAME} */
    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /** @return the dependent variable; never {@code null} */
    public String getDependentVariable() {
        return dependentVariable;
    }

    /** @return the {@code lambda} setting, or {@code null} if it was not set */
    public Double getLambda() {
        return lambda;
    }

    /** @return the {@code gamma} setting, or {@code null} if it was not set */
    public Double getGamma() {
        return gamma;
    }

    /** @return the {@code eta} setting, or {@code null} if it was not set */
    public Double getEta() {
        return eta;
    }

    /** @return the {@code maximum_number_trees} setting, or {@code null} if it was not set */
    public Integer getMaximumNumberTrees() {
        return maximumNumberTrees;
    }

    /** @return the {@code feature_bag_fraction} setting, or {@code null} if it was not set */
    public Double getFeatureBagFraction() {
        return featureBagFraction;
    }

    /** @return the {@code prediction_field_name} setting, or {@code null} if it was not set */
    public String getPredictionFieldName() {
        return predictionFieldName;
    }

    /** @return the {@code training_percent} setting, or {@code null} if it was not set */
    public Double getTrainingPercent() {
        return trainingPercent;
    }

    /**
     * Writes this analysis as an object. The dependent variable is always written;
     * each optional setting is written only when it is non-null.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        if (lambda != null) {
            builder.field(LAMBDA.getPreferredName(), lambda);
        }
        if (gamma != null) {
            builder.field(GAMMA.getPreferredName(), gamma);
        }
        if (eta != null) {
            builder.field(ETA.getPreferredName(), eta);
        }
        if (maximumNumberTrees != null) {
            builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
        }
        if (featureBagFraction != null) {
            builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        }
        if (predictionFieldName != null) {
            builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        if (trainingPercent != null) {
            builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
        }

        builder.endObject();
        return builder;
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            dependentVariable,
            lambda,
            gamma,
            eta,
            maximumNumberTrees,
            featureBagFraction,
            predictionFieldName,
            trainingPercent);
    }

    /** Two instances are equal when they have the same class and all eight settings are equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null) {
            return false;
        }
        if (o.getClass() != getClass()) {
            return false;
        }
        Regression other = (Regression) o;
        return Objects.equals(dependentVariable, other.dependentVariable)
            && Objects.equals(lambda, other.lambda)
            && Objects.equals(gamma, other.gamma)
            && Objects.equals(eta, other.eta)
            && Objects.equals(maximumNumberTrees, other.maximumNumberTrees)
            && Objects.equals(featureBagFraction, other.featureBagFraction)
            && Objects.equals(predictionFieldName, other.predictionFieldName)
            && Objects.equals(trainingPercent, other.trainingPercent);
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    /**
     * Fluent builder for {@link Regression}. The dependent variable is fixed at creation;
     * every setter accepts {@code null} to leave the corresponding setting unset.
     */
    public static class Builder {
        private final String dependentVariable;
        private Double lambda;
        private Double gamma;
        private Double eta;
        private Integer maximumNumberTrees;
        private Double featureBagFraction;
        private String predictionFieldName;
        private Double trainingPercent;

        private Builder(String dependentVariable) {
            this.dependentVariable = Objects.requireNonNull(dependentVariable);
        }

        public Builder setLambda(Double lambda) {
            this.lambda = lambda;
            return this;
        }

        public Builder setGamma(Double gamma) {
            this.gamma = gamma;
            return this;
        }

        public Builder setEta(Double eta) {
            this.eta = eta;
            return this;
        }

        public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
            this.maximumNumberTrees = maximumNumberTrees;
            return this;
        }

        public Builder setFeatureBagFraction(Double featureBagFraction) {
            this.featureBagFraction = featureBagFraction;
            return this;
        }

        public Builder setPredictionFieldName(String predictionFieldName) {
            this.predictionFieldName = predictionFieldName;
            return this;
        }

        public Builder setTrainingPercent(Double trainingPercent) {
            this.trainingPercent = trainingPercent;
            return this;
        }

        /** @return a new {@link Regression} holding the values currently set on this builder */
        public Regression build() {
            return new Regression(
                dependentVariable,
                lambda,
                gamma,
                eta,
                maximumNumberTrees,
                featureBagFraction,
                predictionFieldName,
                trainingPercent);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1215,9 +1215,9 @@ public void testDeleteCalendarEvent() throws IOException {
assertThat(remainingIds, not(hasItem(deletedEvent)));
}

public void testPutDataFrameAnalyticsConfig() throws Exception {
public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "put-test-config";
String configId = "test-put-df-analytics-outlier-detection";
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
.setId(configId)
.setSource(DataFrameAnalyticsSource.builder()
Expand Down Expand Up @@ -1247,6 +1247,41 @@ public void testPutDataFrameAnalyticsConfig() throws Exception {
assertThat(createdConfig.getDescription(), equalTo("some description"));
}

/**
 * Puts a data frame analytics config whose analysis is a regression and checks that the
 * config returned by the put call matches what was sent, with the expected defaults
 * for the settings that were left unset.
 */
public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
    MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
    String configId = "test-put-df-analytics-regression";

    DataFrameAnalyticsSource source = DataFrameAnalyticsSource.builder()
        .setIndex("put-test-source-index")
        .build();
    DataFrameAnalyticsDest dest = DataFrameAnalyticsDest.builder()
        .setIndex("put-test-dest-index")
        .build();
    // Fully qualified, as in the rest of this test, to pick the data frame analysis class.
    org.elasticsearch.client.ml.dataframe.Regression regression =
        org.elasticsearch.client.ml.dataframe.Regression
            .builder("my_dependent_variable")
            .setTrainingPercent(80.0)
            .build();

    DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
        .setId(configId)
        .setSource(source)
        .setDest(dest)
        .setAnalysis(regression)
        .setDescription("this is a regression")
        .build();

    // Create the source index before sending the put request.
    createIndex("put-test-source-index", defaultMappingForTest());

    PutDataFrameAnalyticsRequest request = new PutDataFrameAnalyticsRequest(config);
    PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
        request,
        machineLearningClient::putDataFrameAnalytics,
        machineLearningClient::putDataFrameAnalyticsAsync);
    DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig();

    // Identity, source and destination.
    assertThat(createdConfig.getId(), equalTo(config.getId()));
    assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex()));
    assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value
    assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex()));
    assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value

    // Analysis and remaining settings.
    assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis()));
    assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
    assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value
    assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
}

public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "get-test-config";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.elasticsearch.client;

import com.fasterxml.jackson.core.JsonParseException;

import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
Expand Down Expand Up @@ -677,7 +676,7 @@ public void testDefaultNamedXContents() {

public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(36, namedXContents.size());
assertEquals(37, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
Expand Down Expand Up @@ -711,8 +710,9 @@ public void testProvidedNamedXContents() {
assertTrue(names.contains(ShrinkAction.NAME));
assertTrue(names.contains(FreezeAction.NAME));
assertTrue(names.contains(SetPriorityAction.NAME));
assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class));
assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
Expand Down Expand Up @@ -2923,16 +2924,28 @@ public void testPutDataFrameAnalytics() throws Exception {
.build();
// end::put-data-frame-analytics-dest-config

// tag::put-data-frame-analytics-analysis-default
// tag::put-data-frame-analytics-outlier-detection-default
DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1>
// end::put-data-frame-analytics-analysis-default
// end::put-data-frame-analytics-outlier-detection-default

// tag::put-data-frame-analytics-analysis-customized
// tag::put-data-frame-analytics-outlier-detection-customized
DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1>
.setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2>
.setNNeighbors(5) // <3>
.build();
// end::put-data-frame-analytics-analysis-customized
// end::put-data-frame-analytics-outlier-detection-customized

// tag::put-data-frame-analytics-regression
DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
.setMaximumNumberTrees(50) // <5>
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.build();
// end::put-data-frame-analytics-regression

// tag::put-data-frame-analytics-analyzed-fields
FetchSourceContext analyzedFields =
Expand Down
Loading

0 comments on commit eab6425

Please sign in to comment.