@@ -11,7 +11,6 @@ import (
11
11
"strings"
12
12
"time"
13
13
14
- "github.com/determined-ai/determined/master/internal/config"
15
14
"github.com/determined-ai/determined/master/internal/prom"
16
15
17
16
"github.com/google/uuid"
@@ -39,6 +38,8 @@ import (
39
38
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
40
39
"github.com/determined-ai/determined/proto/pkg/experimentv1"
41
40
"github.com/determined-ai/determined/proto/pkg/jobv1"
41
+
42
+ structpb "github.com/golang/protobuf/ptypes/struct"
42
43
)
43
44
44
45
var experimentsAddr = actor .Addr ("experiments" )
@@ -539,46 +540,43 @@ func (a *apiServer) PatchExperiment(
539
540
return nil , errors .Wrapf (err , "error fetching experiment from database: %d" , req .Experiment .Id )
540
541
}
541
542
542
- paths := req .UpdateMask .GetPaths ()
543
- shouldUpdateNotes := false
544
- shouldUpdateConfig := false
545
- for _ , path := range paths {
546
- switch {
547
- case path == "name" :
548
- if len (strings .TrimSpace (req .Experiment .Name )) == 0 {
549
- return nil , status .Errorf (codes .InvalidArgument , "`name` is required." )
550
- }
551
- exp .Name = req .Experiment .Name
552
- patch := config.ExperimentConfigPatch {
553
- Name : & req .Experiment .Name ,
554
- }
555
- a .m .system .TellAt (actor .Addr ("experiments" , req .Experiment .Id ), patch )
556
- case path == "notes" :
557
- shouldUpdateNotes = true
558
- exp .Notes = req .Experiment .Notes
559
- case path == "labels" :
560
- exp .Labels = req .Experiment .Labels
561
- prom .AssociateExperimentIDLabels (strconv .Itoa (int (req .Experiment .Id )),
562
- req .Experiment .Labels )
563
- case path == "description" :
564
- exp .Description = req .Experiment .Description
543
+ madeChanges := false
544
+ if req .Experiment .Name != nil && exp .Name != req .Experiment .Name .Value {
545
+ madeChanges = true
546
+ if len (strings .TrimSpace (req .Experiment .Name .Value )) == 0 {
547
+ return nil , status .Errorf (codes .InvalidArgument ,
548
+ "`name` must not be an empty or whitespace string." )
565
549
}
550
+ exp .Name = req .Experiment .Name .Value
566
551
}
567
- shouldUpdateConfig = (shouldUpdateNotes && len (paths ) > 1 ) ||
568
- (! shouldUpdateNotes && len (paths ) > 0 )
569
552
570
- if shouldUpdateNotes {
571
- _ , err := a .m .db .RawQuery (
572
- "patch_experiment_notes" ,
573
- req .Experiment .Id ,
574
- req .Experiment .Notes ,
575
- )
576
- if err != nil {
577
- return nil , status .Errorf (codes .Internal , "failed to update experiment" )
553
+ if req .Experiment .Notes != nil && exp .Notes != req .Experiment .Notes .Value {
554
+ madeChanges = true
555
+ exp .Notes = req .Experiment .Notes .Value
556
+ }
557
+
558
+ if req .Experiment .Description != nil && exp .Description != req .Experiment .Description .Value {
559
+ madeChanges = true
560
+ exp .Description = req .Experiment .Description .Value
561
+ }
562
+
563
+ if req .Experiment .Labels != nil {
564
+ var reqLabelList []string
565
+ for _ , el := range req .Experiment .Labels .Values {
566
+ if _ , ok := el .GetKind ().(* structpb.Value_StringValue ); ok {
567
+ reqLabelList = append (reqLabelList , el .GetStringValue ())
568
+ }
569
+ }
570
+ reqLabels := strings .Join (reqLabelList , "," )
571
+ if strings .Join (exp .Labels , "," ) != reqLabels {
572
+ madeChanges = true
573
+ exp .Labels = reqLabelList
574
+ prom .AssociateExperimentIDLabels (strconv .Itoa (int (req .Experiment .Id )),
575
+ exp .Labels )
578
576
}
579
577
}
580
578
581
- if shouldUpdateConfig {
579
+ if madeChanges {
582
580
type experimentPatch struct {
583
581
Labels []string `json:"labels"`
584
582
Description string `json:"description"`
@@ -593,13 +591,11 @@ func (a *apiServer) PatchExperiment(
593
591
if err != nil {
594
592
return nil , status .Errorf (codes .Internal , "failed to marshal experiment patch" )
595
593
}
594
+
596
595
_ , err = a .m .db .RawQuery (
597
- "patch_experiment_config" ,
598
- req .Experiment .Id ,
599
- marshalledPatches ,
600
- )
596
+ "patch_experiment" , exp .Id , marshalledPatches , exp .Notes )
601
597
if err != nil {
602
- return nil , status . Errorf ( codes . Internal , "failed to update experiment" )
598
+ return nil , errors . Wrapf ( err , "error updating experiment in database: %d" , req . Experiment . Id )
603
599
}
604
600
}
605
601
0 commit comments