Skip to content

Commit

Permalink
use model type to check local or remote model
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 28, 2025
1 parent cc630d0 commit 9b21e79
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
}

if (algorithm != null) {
MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request);
MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, functionName.get().name(), algorithm, request);
return channel -> client
.execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
}
Expand All @@ -103,7 +103,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
client
.execute(
MLPredictionTaskAction.INSTANCE,
getRequest(modelId, algoName, request),
getRequest(modelId, functionName.get().name(), algoName, request),
new RestToXContentListener<>(channel)
);
}, e -> {
Expand Down Expand Up @@ -132,12 +132,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
* @return MLPredictionTaskRequest
*/
@VisibleForTesting
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase(Locale.ROOT)))
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
Expand All @@ -148,7 +148,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest

XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLInput mlInput = MLInput.parse(parser, algorithm, actionType);
MLInput mlInput = MLInput.parse(parser, userAlgorithm, actionType);
return new MLPredictionTaskRequest(modelId, mlInput, null, tenantId);
}

Expand Down

0 comments on commit 9b21e79

Please sign in to comment.