Skip to content

Commit c7baf31

Browse files
authored
chore: Use nullable types in PatchExperiment [DET-6486] (#3497)
* release note
1 parent b55a61b commit c7baf31

File tree

11 files changed

+167
-109
lines changed

11 files changed

+167
-109
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:orphan:
2+
3+
**API Change**
4+
5+
- The `PATCH /api/v1/experiments/:id` route no longer uses a field mask.
6+
When you include a field in the body (e.g. notes or labels) that field will be updated, if it
7+
is excluded then it will remain unchanged.

harness/determined/cli/experiment.py

+20-16
Original file line numberDiff line numberDiff line change
@@ -527,40 +527,44 @@ def pause(args: Namespace) -> None:
527527
@authentication.required
528528
def set_description(args: Namespace) -> None:
529529
session = setup_session(args)
530-
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
531-
experiment.description = args.description
532-
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
530+
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
531+
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
532+
exp_patch.description = args.description
533+
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
533534
print("Set description of experiment {} to '{}'".format(args.experiment_id, args.description))
534535

535536

536537
@authentication.required
537538
def set_name(args: Namespace) -> None:
538539
session = setup_session(args)
539-
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
540-
experiment.name = args.name
541-
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
540+
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
541+
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
542+
exp_patch.name = args.name
543+
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
542544
print("Set name of experiment {} to '{}'".format(args.experiment_id, args.name))
543545

544546

545547
@authentication.required
546548
def add_label(args: Namespace) -> None:
547549
session = setup_session(args)
548-
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
549-
if experiment.labels is None:
550-
experiment.labels = []
551-
if args.label not in experiment.labels:
552-
experiment.labels = list(experiment.labels) + [args.label]
553-
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
550+
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
551+
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
552+
if exp_patch.labels is None:
553+
exp_patch.labels = []
554+
if args.label not in exp_patch.labels:
555+
exp_patch.labels = list(exp_patch.labels) + [args.label]
556+
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
554557
print("Added label '{}' to experiment {}".format(args.label, args.experiment_id))
555558

556559

557560
@authentication.required
558561
def remove_label(args: Namespace) -> None:
559562
session = setup_session(args)
560-
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
561-
if (experiment.labels is not None) and (args.label in experiment.labels):
562-
experiment.labels = [label for label in experiment.labels if label != args.label]
563-
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
563+
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
564+
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
565+
if (exp_patch.labels is not None) and (args.label in exp_patch.labels):
566+
exp_patch.labels = [label for label in exp_patch.labels if label != args.label]
567+
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
564568
print("Removed label '{}' from experiment {}".format(args.label, args.experiment_id))
565569

566570

harness/determined/common/api/bindings.py

+35-19
Original file line numberDiff line numberDiff line change
@@ -220,24 +220,6 @@ def to_json(self) -> typing.Any:
220220
"value": self.value if self.value is not None else None,
221221
}
222222

223-
class protobufFieldMask:
224-
def __init__(
225-
self,
226-
paths: "typing.Optional[typing.Sequence[str]]" = None,
227-
):
228-
self.paths = paths
229-
230-
@classmethod
231-
def from_json(cls, obj: Json) -> "protobufFieldMask":
232-
return cls(
233-
paths=obj.get("paths", None),
234-
)
235-
236-
def to_json(self) -> typing.Any:
237-
return {
238-
"paths": self.paths if self.paths is not None else None,
239-
}
240-
241223
class protobufNullValue(enum.Enum):
242224
NULL_VALUE = "NULL_VALUE"
243225

@@ -3063,6 +3045,40 @@ def to_json(self) -> typing.Any:
30633045
"limit": self.limit if self.limit is not None else None,
30643046
}
30653047

3048+
class v1PatchExperiment:
3049+
def __init__(
3050+
self,
3051+
id: int,
3052+
description: "typing.Optional[str]" = None,
3053+
labels: "typing.Optional[typing.Sequence[typing.Dict[str, typing.Any]]]" = None,
3054+
name: "typing.Optional[str]" = None,
3055+
notes: "typing.Optional[str]" = None,
3056+
):
3057+
self.id = id
3058+
self.description = description
3059+
self.labels = labels
3060+
self.name = name
3061+
self.notes = notes
3062+
3063+
@classmethod
3064+
def from_json(cls, obj: Json) -> "v1PatchExperiment":
3065+
return cls(
3066+
id=obj["id"],
3067+
description=obj.get("description", None),
3068+
labels=obj.get("labels", None),
3069+
name=obj.get("name", None),
3070+
notes=obj.get("notes", None),
3071+
)
3072+
3073+
def to_json(self) -> typing.Any:
3074+
return {
3075+
"id": self.id,
3076+
"description": self.description if self.description is not None else None,
3077+
"labels": self.labels if self.labels is not None else None,
3078+
"name": self.name if self.name is not None else None,
3079+
"notes": self.notes if self.notes is not None else None,
3080+
}
3081+
30663082
class v1PatchExperimentResponse:
30673083
def __init__(
30683084
self,
@@ -6536,7 +6552,7 @@ def post_MarkAllocationReservationDaemon(
65366552
def patch_PatchExperiment(
65376553
session: "client.Session",
65386554
*,
6539-
body: "v1Experiment",
6555+
body: "v1PatchExperiment",
65406556
experiment_id: int,
65416557
) -> "v1PatchExperimentResponse":
65426558
_params = None

master/internal/api_experiment.go

+36-40
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"strings"
1212
"time"
1313

14-
"github.com/determined-ai/determined/master/internal/config"
1514
"github.com/determined-ai/determined/master/internal/prom"
1615

1716
"github.com/google/uuid"
@@ -39,6 +38,8 @@ import (
3938
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
4039
"github.com/determined-ai/determined/proto/pkg/experimentv1"
4140
"github.com/determined-ai/determined/proto/pkg/jobv1"
41+
42+
structpb "github.com/golang/protobuf/ptypes/struct"
4243
)
4344

4445
var experimentsAddr = actor.Addr("experiments")
@@ -539,46 +540,43 @@ func (a *apiServer) PatchExperiment(
539540
return nil, errors.Wrapf(err, "error fetching experiment from database: %d", req.Experiment.Id)
540541
}
541542

542-
paths := req.UpdateMask.GetPaths()
543-
shouldUpdateNotes := false
544-
shouldUpdateConfig := false
545-
for _, path := range paths {
546-
switch {
547-
case path == "name":
548-
if len(strings.TrimSpace(req.Experiment.Name)) == 0 {
549-
return nil, status.Errorf(codes.InvalidArgument, "`name` is required.")
550-
}
551-
exp.Name = req.Experiment.Name
552-
patch := config.ExperimentConfigPatch{
553-
Name: &req.Experiment.Name,
554-
}
555-
a.m.system.TellAt(actor.Addr("experiments", req.Experiment.Id), patch)
556-
case path == "notes":
557-
shouldUpdateNotes = true
558-
exp.Notes = req.Experiment.Notes
559-
case path == "labels":
560-
exp.Labels = req.Experiment.Labels
561-
prom.AssociateExperimentIDLabels(strconv.Itoa(int(req.Experiment.Id)),
562-
req.Experiment.Labels)
563-
case path == "description":
564-
exp.Description = req.Experiment.Description
543+
madeChanges := false
544+
if req.Experiment.Name != nil && exp.Name != req.Experiment.Name.Value {
545+
madeChanges = true
546+
if len(strings.TrimSpace(req.Experiment.Name.Value)) == 0 {
547+
return nil, status.Errorf(codes.InvalidArgument,
548+
"`name` must not be an empty or whitespace string.")
565549
}
550+
exp.Name = req.Experiment.Name.Value
566551
}
567-
shouldUpdateConfig = (shouldUpdateNotes && len(paths) > 1) ||
568-
(!shouldUpdateNotes && len(paths) > 0)
569552

570-
if shouldUpdateNotes {
571-
_, err := a.m.db.RawQuery(
572-
"patch_experiment_notes",
573-
req.Experiment.Id,
574-
req.Experiment.Notes,
575-
)
576-
if err != nil {
577-
return nil, status.Errorf(codes.Internal, "failed to update experiment")
553+
if req.Experiment.Notes != nil && exp.Notes != req.Experiment.Notes.Value {
554+
madeChanges = true
555+
exp.Notes = req.Experiment.Notes.Value
556+
}
557+
558+
if req.Experiment.Description != nil && exp.Description != req.Experiment.Description.Value {
559+
madeChanges = true
560+
exp.Description = req.Experiment.Description.Value
561+
}
562+
563+
if req.Experiment.Labels != nil {
564+
var reqLabelList []string
565+
for _, el := range req.Experiment.Labels.Values {
566+
if _, ok := el.GetKind().(*structpb.Value_StringValue); ok {
567+
reqLabelList = append(reqLabelList, el.GetStringValue())
568+
}
569+
}
570+
reqLabels := strings.Join(reqLabelList, ",")
571+
if strings.Join(exp.Labels, ",") != reqLabels {
572+
madeChanges = true
573+
exp.Labels = reqLabelList
574+
prom.AssociateExperimentIDLabels(strconv.Itoa(int(req.Experiment.Id)),
575+
exp.Labels)
578576
}
579577
}
580578

581-
if shouldUpdateConfig {
579+
if madeChanges {
582580
type experimentPatch struct {
583581
Labels []string `json:"labels"`
584582
Description string `json:"description"`
@@ -593,13 +591,11 @@ func (a *apiServer) PatchExperiment(
593591
if err != nil {
594592
return nil, status.Errorf(codes.Internal, "failed to marshal experiment patch")
595593
}
594+
596595
_, err = a.m.db.RawQuery(
597-
"patch_experiment_config",
598-
req.Experiment.Id,
599-
marshalledPatches,
600-
)
596+
"patch_experiment", exp.Id, marshalledPatches, exp.Notes)
601597
if err != nil {
602-
return nil, status.Errorf(codes.Internal, "failed to update experiment")
598+
return nil, errors.Wrapf(err, "error updating experiment in database: %d", req.Experiment.Id)
603599
}
604600
}
605601

Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
UPDATE experiments e
2-
SET config = config || $2
2+
SET config = config || $2, notes = $3
33
WHERE e.id = $1
44
RETURNING e.id

master/static/srv/patch_experiment_notes.sql

-4
This file was deleted.

proto/buf.image.bin

691 Bytes
Binary file not shown.

proto/src/determined/api/v1/experiment.proto

+1-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ option go_package = "github.com/determined-ai/determined/proto/pkg/apiv1";
55

66
import "google/protobuf/wrappers.proto";
77
import "google/protobuf/struct.proto";
8-
import "google/protobuf/field_mask.proto";
98
import "protoc-gen-swagger/options/annotations.proto";
109

1110
import "determined/api/v1/pagination.proto";
@@ -188,9 +187,7 @@ message UnarchiveExperimentResponse {}
188187
// others will be ignored.
189188
message PatchExperimentRequest {
190189
// Patched experiment attributes.
191-
determined.experiment.v1.Experiment experiment = 1;
192-
// Update mask.
193-
google.protobuf.FieldMask update_mask = 2;
190+
determined.experiment.v1.PatchExperiment experiment = 2;
194191
}
195192

196193
// Response to PatchExperimentRequest.

proto/src/determined/experiment/v1/experiment.proto

+18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ syntax = "proto3";
33
package determined.experiment.v1;
44
option go_package = "github.com/determined-ai/determined/proto/pkg/experimentv1";
55

6+
import "google/protobuf/struct.proto";
67
import "google/protobuf/timestamp.proto";
78
import "google/protobuf/wrappers.proto";
89
import "protoc-gen-swagger/options/annotations.proto";
@@ -90,6 +91,23 @@ message Experiment {
9091
google.protobuf.Int32Value forked_from = 16;
9192
}
9293

94+
// PatchExperiment is a partial update to an experiment with only id required.
95+
message PatchExperiment {
96+
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = {
97+
json_schema: { required: [ "id" ] }
98+
};
99+
// The id of the experiment.
100+
int32 id = 1;
101+
// The description of the experiment.
102+
google.protobuf.StringValue description = 2;
103+
// Labels attached to the experiment.
104+
google.protobuf.ListValue labels = 3;
105+
// The experiment name.
106+
google.protobuf.StringValue name = 4;
107+
// The experiment notes.
108+
google.protobuf.StringValue notes = 5;
109+
}
110+
93111
// ValidationHistoryEntry is a single entry for a validation history for an
94112
// experiment.
95113
message ValidationHistoryEntry {

proto/src/determined/model/v1/model.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ message Model {
5151
string notes = 12;
5252
}
5353

54-
// PatchModel is a partial update to a model with only id required
54+
// PatchModel is a partial update to a model with only id required.
5555
message PatchModel {
5656
// An updated name for the model.
5757
google.protobuf.StringValue name = 2

0 commit comments

Comments
 (0)