Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add queue_capacity setting to start deployment API #79433

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
public static final ParseField WAIT_FOR = new ParseField("wait_for");
public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;

public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

Expand All @@ -69,6 +70,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
}

public static Request parseRequest(String modelId, XContentParser parser) {
Expand All @@ -87,6 +89,7 @@ public static Request parseRequest(String modelId, XContentParser parser) {
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
private int modelThreads = 1;
private int inferenceThreads = 1;
private int queueCapacity = 1024;

private Request() {}

Expand All @@ -101,6 +104,7 @@ public Request(StreamInput in) throws IOException {
waitForState = in.readEnum(AllocationStatus.State.class);
modelThreads = in.readVInt();
inferenceThreads = in.readVInt();
queueCapacity = in.readVInt();
}

public final void setModelId(String modelId) {
Expand Down Expand Up @@ -144,6 +148,14 @@ public void setInferenceThreads(int inferenceThreads) {
this.inferenceThreads = inferenceThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}

public void setQueueCapacity(int queueCapacity) {
this.queueCapacity = queueCapacity;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand All @@ -152,6 +164,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(waitForState);
out.writeVInt(modelThreads);
out.writeVInt(inferenceThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -162,6 +175,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(WAIT_FOR.getPreferredName(), waitForState);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}
Expand All @@ -183,12 +197,15 @@ public ActionRequestValidationException validate() {
if (inferenceThreads < 1) {
validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
}
if (queueCapacity < 1) {
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
}
return validationException.validationErrors().isEmpty() ? null : validationException;
}

@Override
public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads);
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
}

@Override
Expand All @@ -204,7 +221,8 @@ public boolean equals(Object obj) {
&& Objects.equals(timeout, other.timeout)
&& Objects.equals(waitForState, other.waitForState)
&& modelThreads == other.modelThreads
&& inferenceThreads == other.inferenceThreads;
&& inferenceThreads == other.inferenceThreads
&& queueCapacity == other.queueCapacity;
}

@Override
Expand All @@ -226,16 +244,20 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
public static final ParseField MODEL_THREADS = new ParseField("model_threads");
public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");

private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
true,
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3])
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4])
);

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
}

public static TaskParams fromXContent(XContentParser parser) {
Expand All @@ -253,28 +275,22 @@ public static TaskParams fromXContent(XContentParser parser) {
private final long modelBytes;
private final int inferenceThreads;
private final int modelThreads;
private final int queueCapacity;

public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) {
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
this.modelId = Objects.requireNonNull(modelId);
this.modelBytes = modelBytes;
if (modelBytes < 0) {
throw new IllegalArgumentException("modelBytes must be non-negative");
}
this.inferenceThreads = inferenceThreads;
if (inferenceThreads < 1) {
throw new IllegalArgumentException(INFERENCE_THREADS + " must be positive");
}
this.modelThreads = modelThreads;
if (modelThreads < 1) {
throw new IllegalArgumentException(MODEL_THREADS + " must be positive");
}
Comment on lines -260 to -270
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these validations should remain. Especially modelBytes as negative bytes will break a ton of logic down stream (node allocation, etc.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We validate those elsewhere. For example, we fetch model bytes from TrainedModelDefinitionDoc where we check the value is positive. Same goes for the threading and queue params which are validated on the start request. The benefit of having no validation here is that this object is persisted in the cluster state and deciding to change the range of valid values in the future may result in the cluster not being able to start. We have adequate validations around those in case someone changes the cluster state directly (e.g. native process will fail to launch for invalid values).

this.queueCapacity = queueCapacity;
}

public TaskParams(StreamInput in) throws IOException {
this.modelId = in.readString();
this.modelBytes = in.readVLong();
this.modelBytes = in.readLong();
this.inferenceThreads = in.readVInt();
this.modelThreads = in.readVInt();
this.queueCapacity = in.readVInt();
}

public String getModelId() {
Expand All @@ -293,9 +309,10 @@ public Version getMinimalSupportedVersion() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeVLong(modelBytes);
out.writeLong(modelBytes);
out.writeVInt(inferenceThreads);
out.writeVInt(modelThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -305,13 +322,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads);
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
}

@Override
Expand All @@ -323,7 +341,8 @@ public boolean equals(Object o) {
return Objects.equals(modelId, other.modelId)
&& modelBytes == other.modelBytes
&& inferenceThreads == other.inferenceThreads
&& modelThreads == other.modelThreads;
&& modelThreads == other.modelThreads
&& queueCapacity == other.queueCapacity;
}

@Override
Expand All @@ -342,6 +361,15 @@ public int getInferenceThreads() {
public int getModelThreads() {
return modelThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}

@Override
public String toString() {
return Strings.toString(this);
}
}

public interface TaskMatcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,7 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire

@Override
protected Request createTestInstance() {
return new Request(
new StartTrainedModelDeploymentAction.TaskParams(
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8)
)
);
return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.io.IOException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -53,6 +54,9 @@ public static Request createRandom() {
if (randomBoolean()) {
request.setModelThreads(randomIntBetween(1, 8));
}
if (randomBoolean()) {
request.setQueueCapacity(randomIntBetween(1, 10000));
}
return request;
}

Expand Down Expand Up @@ -95,4 +99,33 @@ public void testValidate_GivenModelThreadsIsNegative() {
assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
}

public void testValidate_GivenQueueCapacityIsZero() {
Request request = createRandom();
request.setQueueCapacity(0);

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer"));
}

public void testValidate_GivenQueueCapacityIsNegative() {
Request request = createRandom();
request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1));

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer"));
}

public void testDefaults() {
Request request = new Request(randomAlphaOfLength(10));
assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
assertThat(request.getInferenceThreads(), equalTo(1));
assertThat(request.getModelThreads(), equalTo(1));
assertThat(request.getQueueCapacity(), equalTo(1024));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8)
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests;

import java.io.IOException;
import java.util.List;
Expand All @@ -31,9 +32,7 @@
public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {

public static TrainedModelAllocation randomInstance() {
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1)
);
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams());
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
for (String node : nodes) {
if (randomBoolean()) {
Expand Down Expand Up @@ -249,7 +248,7 @@ private static DiscoveryNode buildNode() {
}

private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1);
return StartTrainedModelDeploymentTaskParamsTests.createRandom();
}

private static void assertUnchanged(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
Expand All @@ -35,6 +34,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
Expand Down Expand Up @@ -161,7 +161,8 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ
trainedModelConfig.getModelId(),
modelBytes,
request.getInferenceThreads(),
request.getModelThreads()
request.getModelThreads(),
request.getQueueCapacity()
);
PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(
PersistentTasksCustomMetadata.TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ TrainedModelDeploymentTask getTask(String modelId) {
}

void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) {
logger.debug(() -> new ParameterizedMessage("[{}] preparing to load model with task params: {}",
taskParams.getModelId(), taskParams));
TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) taskManager.register(
TRAINED_MODEL_ALLOCATION_TASK_TYPE,
TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + taskParams.getModelId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,11 @@ class ProcessContext {
this.task = Objects.requireNonNull(task);
resultProcessor = new PyTorchResultProcessor(task.getModelId());
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
this.executorService = new ProcessWorkerExecutorService(
threadPool.getThreadContext(),
"pytorch_inference",
task.getParams().getQueueCapacity()
);
}

PyTorchResultProcessor getResultProcessor() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ public class ProcessWorkerExecutorService extends AbstractExecutorService {
/**
* @param contextHolder the thread context holder
* @param processName the name of the process to be used in logging
* @param queueSize the size of the queue holding operations. If an operation is added
* @param queueCapacity the capacity of the queue holding operations. If an operation is added
* for execution when the queue is full a 429 error is thrown.
*/
@SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueSize) {
public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueCapacity) {
this.contextHolder = Objects.requireNonNull(contextHolder);
this.processName = Objects.requireNonNull(processName);
this.queue = new LinkedBlockingQueue<>(queueSize);
this.queue = new LinkedBlockingQueue<>(queueCapacity);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;

Expand Down Expand Up @@ -59,6 +60,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
));
request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
}

return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
Expand Down
Loading