Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Use nullable types in PatchExperiment [DET-6486] #3497

Merged
merged 6 commits into from
Feb 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:orphan:

**API Change**

- The `PATCH /api/v1/experiments/:id` route no longer uses a field mask.
When you include a field in the body (e.g. notes or labels) that field will be updated, if it
is excluded then it will remain unchanged.
36 changes: 20 additions & 16 deletions harness/determined/cli/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,40 +527,44 @@ def pause(args: Namespace) -> None:
@authentication.required
def set_description(args: Namespace) -> None:
session = setup_session(args)
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
experiment.description = args.description
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
exp_patch.description = args.description
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
print("Set description of experiment {} to '{}'".format(args.experiment_id, args.description))


@authentication.required
def set_name(args: Namespace) -> None:
session = setup_session(args)
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
experiment.name = args.name
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
exp_patch.name = args.name
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
print("Set name of experiment {} to '{}'".format(args.experiment_id, args.name))


@authentication.required
def add_label(args: Namespace) -> None:
session = setup_session(args)
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
if experiment.labels is None:
experiment.labels = []
if args.label not in experiment.labels:
experiment.labels = list(experiment.labels) + [args.label]
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
if exp_patch.labels is None:
exp_patch.labels = []
if args.label not in exp_patch.labels:
exp_patch.labels = list(exp_patch.labels) + [args.label]
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
print("Added label '{}' to experiment {}".format(args.label, args.experiment_id))


@authentication.required
def remove_label(args: Namespace) -> None:
session = setup_session(args)
experiment = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
if (experiment.labels is not None) and (args.label in experiment.labels):
experiment.labels = [label for label in experiment.labels if label != args.label]
bindings.patch_PatchExperiment(session, body=experiment, experiment_id=args.experiment_id)
exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment
exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json())
if (exp_patch.labels is not None) and (args.label in exp_patch.labels):
exp_patch.labels = [label for label in exp_patch.labels if label != args.label]
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id)
print("Removed label '{}' from experiment {}".format(args.label, args.experiment_id))


Expand Down
54 changes: 35 additions & 19 deletions harness/determined/common/api/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,24 +220,6 @@ def to_json(self) -> typing.Any:
"value": self.value if self.value is not None else None,
}

class protobufFieldMask:
def __init__(
self,
paths: "typing.Optional[typing.Sequence[str]]" = None,
):
self.paths = paths

@classmethod
def from_json(cls, obj: Json) -> "protobufFieldMask":
return cls(
paths=obj.get("paths", None),
)

def to_json(self) -> typing.Any:
return {
"paths": self.paths if self.paths is not None else None,
}

class protobufNullValue(enum.Enum):
NULL_VALUE = "NULL_VALUE"

Expand Down Expand Up @@ -3063,6 +3045,40 @@ def to_json(self) -> typing.Any:
"limit": self.limit if self.limit is not None else None,
}

class v1PatchExperiment:
def __init__(
self,
id: int,
description: "typing.Optional[str]" = None,
labels: "typing.Optional[typing.Sequence[typing.Dict[str, typing.Any]]]" = None,
name: "typing.Optional[str]" = None,
notes: "typing.Optional[str]" = None,
):
self.id = id
self.description = description
self.labels = labels
self.name = name
self.notes = notes

@classmethod
def from_json(cls, obj: Json) -> "v1PatchExperiment":
return cls(
id=obj["id"],
description=obj.get("description", None),
labels=obj.get("labels", None),
name=obj.get("name", None),
notes=obj.get("notes", None),
)

def to_json(self) -> typing.Any:
return {
"id": self.id,
"description": self.description if self.description is not None else None,
"labels": self.labels if self.labels is not None else None,
"name": self.name if self.name is not None else None,
"notes": self.notes if self.notes is not None else None,
}

class v1PatchExperimentResponse:
def __init__(
self,
Expand Down Expand Up @@ -6536,7 +6552,7 @@ def post_MarkAllocationReservationDaemon(
def patch_PatchExperiment(
session: "client.Session",
*,
body: "v1Experiment",
body: "v1PatchExperiment",
experiment_id: int,
) -> "v1PatchExperimentResponse":
_params = None
Expand Down
76 changes: 36 additions & 40 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"strings"
"time"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/internal/prom"

"github.com/google/uuid"
Expand Down Expand Up @@ -39,6 +38,8 @@ import (
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
"github.com/determined-ai/determined/proto/pkg/experimentv1"
"github.com/determined-ai/determined/proto/pkg/jobv1"

structpb "github.com/golang/protobuf/ptypes/struct"
)

var experimentsAddr = actor.Addr("experiments")
Expand Down Expand Up @@ -539,46 +540,43 @@ func (a *apiServer) PatchExperiment(
return nil, errors.Wrapf(err, "error fetching experiment from database: %d", req.Experiment.Id)
}

paths := req.UpdateMask.GetPaths()
shouldUpdateNotes := false
shouldUpdateConfig := false
for _, path := range paths {
switch {
case path == "name":
if len(strings.TrimSpace(req.Experiment.Name)) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "`name` is required.")
}
exp.Name = req.Experiment.Name
patch := config.ExperimentConfigPatch{
Name: &req.Experiment.Name,
}
a.m.system.TellAt(actor.Addr("experiments", req.Experiment.Id), patch)
case path == "notes":
shouldUpdateNotes = true
exp.Notes = req.Experiment.Notes
case path == "labels":
exp.Labels = req.Experiment.Labels
prom.AssociateExperimentIDLabels(strconv.Itoa(int(req.Experiment.Id)),
req.Experiment.Labels)
case path == "description":
exp.Description = req.Experiment.Description
madeChanges := false
if req.Experiment.Name != nil && exp.Name != req.Experiment.Name.Value {
madeChanges = true
if len(strings.TrimSpace(req.Experiment.Name.Value)) == 0 {
return nil, status.Errorf(codes.InvalidArgument,
"`name` must not be an empty or whitespace string.")
}
exp.Name = req.Experiment.Name.Value
}
shouldUpdateConfig = (shouldUpdateNotes && len(paths) > 1) ||
(!shouldUpdateNotes && len(paths) > 0)

if shouldUpdateNotes {
_, err := a.m.db.RawQuery(
"patch_experiment_notes",
req.Experiment.Id,
req.Experiment.Notes,
)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update experiment")
if req.Experiment.Notes != nil && exp.Notes != req.Experiment.Notes.Value {
madeChanges = true
exp.Notes = req.Experiment.Notes.Value
}

if req.Experiment.Description != nil && exp.Description != req.Experiment.Description.Value {
madeChanges = true
exp.Description = req.Experiment.Description.Value
}

if req.Experiment.Labels != nil {
var reqLabelList []string
for _, el := range req.Experiment.Labels.Values {
if _, ok := el.GetKind().(*structpb.Value_StringValue); ok {
reqLabelList = append(reqLabelList, el.GetStringValue())
}
}
reqLabels := strings.Join(reqLabelList, ",")
if strings.Join(exp.Labels, ",") != reqLabels {
madeChanges = true
exp.Labels = reqLabelList
prom.AssociateExperimentIDLabels(strconv.Itoa(int(req.Experiment.Id)),
exp.Labels)
}
}

if shouldUpdateConfig {
if madeChanges {
type experimentPatch struct {
Labels []string `json:"labels"`
Description string `json:"description"`
Expand All @@ -593,13 +591,11 @@ func (a *apiServer) PatchExperiment(
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to marshal experiment patch")
}

_, err = a.m.db.RawQuery(
"patch_experiment_config",
req.Experiment.Id,
marshalledPatches,
)
"patch_experiment", exp.Id, marshalledPatches, exp.Notes)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update experiment")
return nil, errors.Wrapf(err, "error updating experiment in database: %d", req.Experiment.Id)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
UPDATE experiments e
SET config = config || $2
SET config = config || $2, notes = $3
WHERE e.id = $1
RETURNING e.id
4 changes: 0 additions & 4 deletions master/static/srv/patch_experiment_notes.sql

This file was deleted.

Binary file modified proto/buf.image.bin
Binary file not shown.
5 changes: 1 addition & 4 deletions proto/src/determined/api/v1/experiment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ option go_package = "github.com/determined-ai/determined/proto/pkg/apiv1";

import "google/protobuf/wrappers.proto";
import "google/protobuf/struct.proto";
import "google/protobuf/field_mask.proto";
import "protoc-gen-swagger/options/annotations.proto";

import "determined/api/v1/pagination.proto";
Expand Down Expand Up @@ -188,9 +187,7 @@ message UnarchiveExperimentResponse {}
// others will be ignored.
message PatchExperimentRequest {
// Patched experiment attributes.
determined.experiment.v1.Experiment experiment = 1;
// Update mask.
google.protobuf.FieldMask update_mask = 2;
determined.experiment.v1.PatchExperiment experiment = 2;
}

// Response to PatchExperimentRequest.
Expand Down
18 changes: 18 additions & 0 deletions proto/src/determined/experiment/v1/experiment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ syntax = "proto3";
package determined.experiment.v1;
option go_package = "github.com/determined-ai/determined/proto/pkg/experimentv1";

import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
import "protoc-gen-swagger/options/annotations.proto";
Expand Down Expand Up @@ -90,6 +91,23 @@ message Experiment {
google.protobuf.Int32Value forked_from = 16;
}

// PatchExperiment is a partial update to an experiment with only id required.
message PatchExperiment {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏼

option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = {
json_schema: { required: [ "id" ] }
};
// The id of the experiment.
int32 id = 1;
// The description of the experiment.
google.protobuf.StringValue description = 2;
// Labels attached to the experiment.
google.protobuf.ListValue labels = 3;
// The experiment name.
google.protobuf.StringValue name = 4;
// The experiment notes.
google.protobuf.StringValue notes = 5;
}

// ValidationHistoryEntry is a single entry for a validation history for an
// experiment.
message ValidationHistoryEntry {
Expand Down
2 changes: 1 addition & 1 deletion proto/src/determined/model/v1/model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ message Model {
string notes = 12;
}

// PatchModel is a partial update to a model with only id required
// PatchModel is a partial update to a model with only id required.
message PatchModel {
// An updated name for the model.
google.protobuf.StringValue name = 2
Expand Down
Loading