Skip to content

Commit e2d2778

Browse files
authored
if model version fails to register, update model group accordingly (#1463)
* if model version fails to register, update model group accordingly Signed-off-by: Bhavana Ramaram <[email protected]>
1 parent 6b15702 commit e2d2778

File tree

8 files changed

+117
-44
lines changed

8 files changed

+117
-44
lines changed

common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java

+20-4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
5353
public static final String ACCESS_MODE_FIELD = "access_mode";
5454
public static final String BACKEND_ROLES_FIELD = "backend_roles";
5555
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
56-
56+
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";
5757
private FunctionName functionName;
5858
private String modelName;
5959
private String modelGroupId;
@@ -73,6 +73,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
7373
private List<String> backendRoles;
7474
private Boolean addAllBackendRoles;
7575
private AccessMode accessMode;
76+
private Boolean doesVersionCreateModelGroup;
7677

7778
@Builder(toBuilder = true)
7879
public MLRegisterModelInput(FunctionName functionName,
@@ -90,7 +91,8 @@ public MLRegisterModelInput(FunctionName functionName,
9091
String connectorId,
9192
List<String> backendRoles,
9293
Boolean addAllBackendRoles,
93-
AccessMode accessMode
94+
AccessMode accessMode,
95+
Boolean doesVersionCreateModelGroup
9496
) {
9597
if (functionName == null) {
9698
this.functionName = FunctionName.TEXT_EMBEDDING;
@@ -123,6 +125,7 @@ public MLRegisterModelInput(FunctionName functionName,
123125
this.backendRoles = backendRoles;
124126
this.addAllBackendRoles = addAllBackendRoles;
125127
this.accessMode = accessMode;
128+
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
126129
}
127130

128131

@@ -157,6 +160,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
157160
if (in.readBoolean()) {
158161
this.accessMode = in.readEnum(AccessMode.class);
159162
}
163+
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
160164
}
161165

162166
@Override
@@ -202,6 +206,7 @@ public void writeTo(StreamOutput out) throws IOException {
202206
} else {
203207
out.writeBoolean(false);
204208
}
209+
out.writeOptionalBoolean(doesVersionCreateModelGroup);
205210
}
206211

207212
@Override
@@ -249,6 +254,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
249254
if (accessMode != null) {
250255
builder.field(ACCESS_MODE_FIELD, accessMode);
251256
}
257+
if (doesVersionCreateModelGroup != null) {
258+
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
259+
}
252260
builder.endObject();
253261
return builder;
254262
}
@@ -267,6 +275,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
267275
List<String> backendRoles = new ArrayList<>();
268276
Boolean addAllBackendRoles = null;
269277
AccessMode accessMode = null;
278+
Boolean doesVersionCreateModelGroup = null;
270279

271280
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
272281
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -318,12 +327,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
318327
case ACCESS_MODE_FIELD:
319328
accessMode = AccessMode.from(parser.text());
320329
break;
330+
case DOES_VERSION_CREATE_MODEL_GROUP:
331+
doesVersionCreateModelGroup = parser.booleanValue();
332+
break;
321333
default:
322334
parser.skipChildren();
323335
break;
324336
}
325337
}
326-
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
338+
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
327339
}
328340

329341
public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
@@ -342,6 +354,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
342354
List<String> backendRoles = new ArrayList<>();
343355
AccessMode accessMode = null;
344356
Boolean addAllBackendRoles = null;
357+
Boolean doesVersionCreateModelGroup = null;
345358

346359
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
347360
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -400,11 +413,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
400413
case ACCESS_MODE_FIELD:
401414
accessMode = AccessMode.from(parser.text());
402415
break;
416+
case DOES_VERSION_CREATE_MODEL_GROUP:
417+
doesVersionCreateModelGroup = parser.booleanValue();
418+
break;
403419
default:
404420
parser.skipChildren();
405421
break;
406422
}
407423
}
408-
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
424+
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
409425
}
410426
}

common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java

+16-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
4646
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
4747
public static final String ACCESS_MODE = "access_mode"; //optional
4848
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional
49+
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";
50+
4951

5052
private FunctionName functionName;
5153
private String name;
@@ -65,11 +67,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
6567
private List<String> backendRoles;
6668
private AccessMode accessMode;
6769
private Boolean isAddAllBackendRoles;
70+
private Boolean doesVersionCreateModelGroup;
6871

6972
@Builder(toBuilder = true)
7073
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List<String> backendRoles,
7174
AccessMode accessMode,
72-
Boolean isAddAllBackendRoles) {
75+
Boolean isAddAllBackendRoles,
76+
Boolean doesVersionCreateModelGroup) {
7377
if (name == null) {
7478
throw new IllegalArgumentException("model name is null");
7579
}
@@ -103,6 +107,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
103107
this.backendRoles = backendRoles;
104108
this.accessMode = accessMode;
105109
this.isAddAllBackendRoles = isAddAllBackendRoles;
110+
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
106111
}
107112

108113
public MLRegisterModelMetaInput(StreamInput in) throws IOException{
@@ -128,6 +133,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
128133
accessMode = in.readEnum(AccessMode.class);
129134
}
130135
this.isAddAllBackendRoles = in.readOptionalBoolean();
136+
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
131137
}
132138

133139
@Override
@@ -171,6 +177,7 @@ public void writeTo(StreamOutput out) throws IOException {
171177
out.writeBoolean(false);
172178
}
173179
out.writeOptionalBoolean(isAddAllBackendRoles);
180+
out.writeOptionalBoolean(doesVersionCreateModelGroup);
174181
}
175182

176183
@Override
@@ -206,6 +213,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
206213
if (isAddAllBackendRoles != null) {
207214
builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles);
208215
}
216+
if (doesVersionCreateModelGroup != null) {
217+
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
218+
}
209219
builder.endObject();
210220
return builder;
211221
}
@@ -225,6 +235,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
225235
List<String> backendRoles = null;
226236
AccessMode accessMode = null;
227237
Boolean isAddAllBackendRoles = null;
238+
Boolean doesVersionCreateModelGroup = null;
228239

229240
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
230241
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -277,12 +288,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
277288
case ADD_ALL_BACKEND_ROLES:
278289
isAddAllBackendRoles = parser.booleanValue();
279290
break;
291+
case DOES_VERSION_CREATE_MODEL_GROUP:
292+
doesVersionCreateModelGroup = parser.booleanValue();
293+
break;
280294
default:
281295
parser.skipChildren();
282296
break;
283297
}
284298
}
285-
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles);
299+
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup);
286300
}
287301

288302
}

common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public void setup() {
4343
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
4444
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
4545
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0",
46-
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
46+
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null);
4747
}
4848

4949
@Test

common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void setUp() {
3333
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
3434
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
3535
mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0",
36-
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
36+
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null);
3737
}
3838

3939
@Test

plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java

+2
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,14 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
256256
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput);
257257
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
258258
registerModelInput.setModelGroupId(modelGroupId);
259+
registerModelInput.setDoesVersionCreateModelGroup(true);
259260
registerModel(registerModelInput, listener);
260261
}, e -> {
261262
logException("Failed to create Model Group", e, log);
262263
listener.onFailure(e);
263264
}));
264265
} else {
266+
registerModelInput.setDoesVersionCreateModelGroup(false);
265267
registerModel(registerModelInput, listener);
266268
}
267269
}

plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java

+2
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,14 @@ private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionList
121121
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput);
122122
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
123123
mlUploadInput.setModelGroupId(modelGroupId);
124+
mlUploadInput.setDoesVersionCreateModelGroup(true);
124125
registerModelMeta(mlUploadInput, listener);
125126
}, e -> {
126127
logException("Failed to create Model Group", e, log);
127128
listener.onFailure(e);
128129
}));
129130
} else {
131+
mlUploadInput.setDoesVersionCreateModelGroup(false);
130132
registerModelMeta(mlUploadInput, listener);
131133
}
132134
}

plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java

+30-23
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,13 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
146146
AccessMode modelAccessMode = input.getModelAccessMode();
147147
Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles();
148148
if (modelAccessMode == null) {
149-
if (modelAccessMode == null) {
150-
if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) {
151-
throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
152-
} else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) {
153-
input.setModelAccessMode(AccessMode.RESTRICTED);
154-
modelAccessMode = AccessMode.RESTRICTED;
155-
} else {
156-
input.setModelAccessMode(AccessMode.PRIVATE);
157-
}
149+
if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) {
150+
throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
151+
} else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) {
152+
input.setModelAccessMode(AccessMode.RESTRICTED);
153+
modelAccessMode = AccessMode.RESTRICTED;
154+
} else {
155+
input.setModelAccessMode(AccessMode.PRIVATE);
158156
}
159157
}
160158
if ((AccessMode.PUBLIC == modelAccessMode || AccessMode.PRIVATE == modelAccessMode)
@@ -184,20 +182,29 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
184182
}
185183

186184
public void validateUniqueModelGroupName(String name, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
187-
BoolQueryBuilder query = new BoolQueryBuilder();
188-
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));
189-
190-
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
191-
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);
192-
193-
client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
194-
if (e instanceof IndexNotFoundException) {
195-
listener.onResponse(null);
196-
} else {
197-
log.error("Failed to search model group index", e);
198-
listener.onFailure(e);
199-
}
200-
}));
185+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
186+
BoolQueryBuilder query = new BoolQueryBuilder();
187+
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));
188+
189+
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
190+
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);
191+
192+
client
193+
.search(
194+
searchRequest,
195+
ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
196+
if (e instanceof IndexNotFoundException) {
197+
listener.onResponse(null);
198+
} else {
199+
log.error("Failed to search model group index", e);
200+
listener.onFailure(e);
201+
}
202+
}), () -> context.restore())
203+
);
204+
} catch (Exception e) {
205+
log.error("Failed to search model group index", e);
206+
listener.onFailure(e);
207+
}
201208
}
202209

203210
private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) {

0 commit comments

Comments
 (0)