diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java index 7a33b869d9..294a672bbc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java @@ -24,6 +24,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.FunctionName; @@ -161,13 +162,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod .build(); MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput); - transportService - .sendRequest( - getNodeById(coordinatingNodeId), - MLForwardAction.NAME, - deployModelDoneMessage, - new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) - ); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + transportService + .sendRequest( + getNodeById(coordinatingNodeId), + MLForwardAction.NAME, + deployModelDoneMessage, + new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) + ); + } }, e -> { MLForwardInput mlForwardInput = MLForwardInput .builder() @@ -179,13 +182,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod .build(); MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput); - transportService - .sendRequest( - getNodeById(coordinatingNodeId), - MLForwardAction.NAME, - deployModelDoneMessage, - new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) - ); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + transportService + .sendRequest( + getNodeById(coordinatingNodeId), + MLForwardAction.NAME, + deployModelDoneMessage, + new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) + ); + } }) );