Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Take more care that memory estimation uses unique named pipes #60395

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
Expand All @@ -22,6 +24,8 @@
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
Expand Down Expand Up @@ -127,6 +131,49 @@ public void testTrainingPercentageIsApplied() throws IOException {
lessThanOrEqualTo(allDataUsedForTraining));
}

public void testSimultaneousExplainSameConfig() throws IOException {

final int simultaneousInvocationCount = 10;

String sourceIndex = "test-simultaneous-explain";
RegressionIT.indexData(sourceIndex, 100, 0);

DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("dfa-simultaneous-explain-" + sourceIndex)
.setSource(new DataFrameAnalyticsSource(new String[]{sourceIndex},
QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()),
null))
.setAnalysis(new Regression(RegressionIT.DEPENDENT_VARIABLE_FIELD,
BoostedTreeParams.builder().build(),
null,
100.0,
null,
null,
null))
.buildForExplain();

List<ActionFuture<ExplainDataFrameAnalyticsAction.Response>> futures = new ArrayList<>();

for (int i = 0; i < simultaneousInvocationCount; ++i) {
futures.add(client().execute(ExplainDataFrameAnalyticsAction.INSTANCE, new PutDataFrameAnalyticsAction.Request(config)));
}

ExplainDataFrameAnalyticsAction.Response previous = null;
for (ActionFuture<ExplainDataFrameAnalyticsAction.Response> future : futures) {
// The main purpose of this test is that actionGet() here will throw an exception
// if any of the simultaneous calls returns an error due to interaction between
// the many estimation processes that get run
ExplainDataFrameAnalyticsAction.Response current = future.actionGet(10000);
if (previous != null) {
// A secondary check the test can perform is that the multiple invocations
// return the same result (but it was failures due to unwanted interactions
// that caused this test to be written)
assertEquals(previous, current);
}
previous = current;
}
}

@Override
boolean supportsInference() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProcessFactory<MemoryUsageEstimationResult> {
Expand All @@ -39,11 +40,13 @@ public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProce

private final Environment env;
private final NativeController nativeController;
private final AtomicLong counter;
private volatile Duration processConnectTimeout;

public NativeMemoryUsageEstimationProcessFactory(Environment env, NativeController nativeController, ClusterService clusterService) {
this.env = Objects.requireNonNull(env);
this.nativeController = Objects.requireNonNull(nativeController);
this.counter = new AtomicLong(0);
setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(
MachineLearning.PROCESS_CONNECT_TIMEOUT, this::setProcessConnectTimeout);
Expand All @@ -61,8 +64,13 @@ public NativeMemoryUsageEstimationProcess createAnalyticsProcess(
ExecutorService executorService,
Consumer<String> onProcessCrash) {
List<Path> filesToDelete = new ArrayList<>();
// The config ID passed to the process pipes is only used to make the file names unique. Since memory estimation can be
// called many times in quick succession for the same config the config ID alone is not sufficient to guarantee that the
// memory estimation process pipe names are unique. Therefore an increasing counter value is appended to the config ID
// to ensure uniqueness between calls.
ProcessPipes processPipes = new ProcessPipes(
env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, config.getId(), false, false, true, false, false);
env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, config.getId() + "_" + counter.incrementAndGet(),
false, false, true, false, false);

createNativeProcess(config.getId(), analyticsProcessConfig, filesToDelete, processPipes);

Expand Down