diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index 432451d6d..487f2cc42 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -690,7 +690,7 @@ func (d *Daemon) rebootDelay() (time.Duration, error) { // see whether a reboot had already been scheduled var rebootAt time.Time err := d.state.Get("daemon-system-restart-at", &rebootAt) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return 0, err } rebootDelay := 1 * time.Minute @@ -832,7 +832,7 @@ var errExpectedReboot = errors.New("expected reboot did not happen") func (d *Daemon) RebootIsMissing(st *state.State) error { var nTentative int err := st.Get("daemon-system-restart-tentative", &nTentative) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } nTentative++ diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index 2a3d18abb..e3784dc96 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -991,8 +991,8 @@ func (s *daemonSuite) TestRestartExpectedRebootOK(c *C) { defer st.Unlock() var v interface{} // these were cleared - c.Check(st.Get("daemon-system-restart-at", &v), Equals, state.ErrNoState) - c.Check(st.Get("system-restart-from-boot-id", &v), Equals, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState) } func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *C) { @@ -1015,9 +1015,9 @@ func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *C) { defer st.Unlock() var v interface{} // these were cleared - c.Check(st.Get("daemon-system-restart-at", &v), Equals, state.ErrNoState) - c.Check(st.Get("system-restart-from-boot-id", &v), Equals, state.ErrNoState) - c.Check(st.Get("daemon-system-restart-tentative", &v), Equals, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-tentative", &v), testutil.ErrorIs, state.ErrNoState) } func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *C) { diff --git a/internals/overlord/export_test.go b/internals/overlord/export_test.go index 225297029..372780eb8 100644 --- a/internals/overlord/export_test.go +++ b/internals/overlord/export_test.go @@ -40,6 +40,14 @@ func FakePruneInterval(prunei, prunew, abortw time.Duration) (restore func()) { } } +func FakePruneTicker(f func(t *time.Ticker) <-chan time.Time) (restore func()) { + old := pruneTickerC + pruneTickerC = f + return func() { + pruneTickerC = old + } +} + // FakeEnsureNext sets o.ensureNext for tests. func FakeEnsureNext(o *Overlord, t time.Time) { o.ensureNext = t @@ -49,3 +57,11 @@ func FakeEnsureNext(o *Overlord, t time.Time) { func (o *Overlord) Engine() *StateEngine { return o.stateEng } + +func FakeTimeNow(f func() time.Time) (restore func()) { + old := timeNow + timeNow = f + return func() { + timeNow = old + } +} diff --git a/internals/overlord/overlord.go b/internals/overlord/overlord.go index d6406e243..6c068233f 100644 --- a/internals/overlord/overlord.go +++ b/internals/overlord/overlord.go @@ -16,6 +16,7 @@ package overlord import ( + "errors" "fmt" "io" "os" @@ -49,6 +50,10 @@ var ( defaultCachedDownloads = 5 ) +var pruneTickerC = func(t *time.Ticker) <-chan time.Time { + return t.C +} + // Extension represents an extension of the Overlord. type Extension interface { // ExtraManagers allows additional StateManagers to be used. @@ -81,6 +86,8 @@ type Overlord struct { ensureRun int32 pruneTicker *time.Ticker + startOfOperationTime time.Time + // managers inited bool startedUp bool @@ -258,6 +265,15 @@ func (o *Overlord) StartUp() error { return nil } o.startedUp = true + + var err error + st := o.State() + st.Lock() + o.startOfOperationTime, err = o.StartOfOperationTime() + st.Unlock() + if err != nil { + return fmt.Errorf("cannot get start of operation time: %s", err) + } return o.stateEng.StartUp() } @@ -314,14 +330,15 @@ func (o *Overlord) Loop() { // continue to the next Ensure() try for now o.stateEng.Ensure() o.ensureDidRun() + pruneC := pruneTickerC(o.pruneTicker) select { case <-o.loopTomb.Dying(): return nil case <-o.ensureTimer.C: - case <-o.pruneTicker.C: + case <-pruneC: st := o.State() st.Lock() - st.Prune(pruneWait, abortWait, pruneMaxChanges) + st.Prune(o.startOfOperationTime, pruneWait, abortWait, pruneMaxChanges) st.Unlock() } } @@ -499,6 +516,23 @@ func (o *Overlord) AddManager(mgr StateManager) { o.stateEng.AddManager(mgr) } +var timeNow = time.Now + +func (m *Overlord) StartOfOperationTime() (time.Time, error) { + var opTime time.Time + err := m.State().Get("start-of-operation-time", &opTime) + if err == nil { + return opTime, nil + } + if err != nil && !errors.Is(err, state.ErrNoState) { + return opTime, err + } + opTime = timeNow() + + m.State().Set("start-of-operation-time", opTime) + return opTime, nil +} + type fakeBackend struct { o *Overlord } diff --git a/internals/overlord/overlord_test.go b/internals/overlord/overlord_test.go index 57765a049..bef3e53ea 100644 --- a/internals/overlord/overlord_test.go +++ b/internals/overlord/overlord_test.go @@ -46,6 +46,26 @@ type overlordSuite struct { var _ = Suite(&overlordSuite{}) +type ticker struct { + tickerChannel chan time.Time +} + +func (w *ticker) tick(n int) { + for i := 0; i < n; i++ { + w.tickerChannel <- time.Now() + } +} + +func fakePruneTicker() (w *ticker, restore func()) { + w = &ticker{ + tickerChannel: make(chan time.Time), + } + restore = overlord.FakePruneTicker(func(t *time.Ticker) <-chan time.Time { + return w.tickerChannel + }) + return w, restore +} + func (ovs *overlordSuite) SetUpTest(c *C) { ovs.dir = c.MkDir() ovs.statePath = filepath.Join(ovs.dir, ".pebble.state") @@ -517,6 +537,137 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) { c.Assert(err, IsNil) } +func (ovs *overlordSuite) TestOverlordStartUpSetsStartOfOperation(c *C) { + restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour) + defer restoreIntv() + + // use real overlord, we need device manager to be there + o, err := overlord.New(&overlord.Options{PebbleDir: ovs.dir}) + c.Assert(err, IsNil) + + st := o.State() + st.Lock() + defer st.Unlock() + + // validity check, not set + var opTime time.Time + c.Assert(st.Get("start-of-operation-time", &opTime), testutil.ErrorIs, state.ErrNoState) + st.Unlock() + + c.Assert(o.StartUp(), IsNil) + + st.Lock() + c.Assert(st.Get("start-of-operation-time", &opTime), IsNil) +} + +func (ovs *overlordSuite) TestEnsureLoopPruneDoesntAbortShortlyAfterStartOfOperation(c *C) { + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + + // use real overlord, we need device manager to be there + o, err := overlord.New(&overlord.Options{PebbleDir: ovs.dir}) + c.Assert(err, IsNil) + + // avoid immediate transition to Done due to unknown kind + o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error { + return &state.Retry{} + }, nil) + + st := o.State() + st.Lock() + + // start of operation time is 50min ago, this is less then abort limit + opTime := time.Now().Add(-50 * time.Minute) + st.Set("start-of-operation-time", opTime) + + // spawn time one month ago + spawnTime := time.Now().AddDate(0, -1, 0) + restoreTimeNow := state.FakeTime(spawnTime) + + t := st.NewTask("bar", "...") + chg := st.NewChange("other-change", "...") + chg.AddTask(t) + + restoreTimeNow() + + // validity + c.Check(st.Changes(), HasLen, 1) + + st.Unlock() + c.Assert(o.StartUp(), IsNil) + + // start the loop that runs the prune ticker + o.Loop() + w.tick(2) + + c.Assert(o.Stop(), IsNil) + + st.Lock() + defer st.Unlock() + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.DoingStatus) +} + +func (ovs *overlordSuite) TestEnsureLoopPruneAbortsOld(c *C) { + // Ensure interval is not relevant for this test + restoreEnsureIntv := overlord.FakeEnsureInterval(10 * time.Hour) + defer restoreEnsureIntv() + + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + + // use real overlord, we need device manager to be there + o, err := overlord.New(&overlord.Options{PebbleDir: ovs.dir}) + c.Assert(err, IsNil) + + // avoid immediate transition to Done due to having unknown kind + o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error { + return &state.Retry{} + }, nil) + + st := o.State() + st.Lock() + + // start of operation time is a year ago + opTime := time.Now().AddDate(-1, 0, 0) + st.Set("start-of-operation-time", opTime) + + st.Unlock() + c.Assert(o.StartUp(), IsNil) + st.Lock() + + // spawn time one month ago + spawnTime := time.Now().AddDate(0, -1, 0) + restoreTimeNow := state.FakeTime(spawnTime) + t := st.NewTask("bar", "...") + chg := st.NewChange("other-change", "...") + chg.AddTask(t) + + restoreTimeNow() + + // validity + c.Check(st.Changes(), HasLen, 1) + st.Unlock() + + // start the loop that runs the prune ticker + o.Loop() + w.tick(2) + + c.Assert(o.Stop(), IsNil) + + st.Lock() + defer st.Unlock() + + // validity + op, err := o.StartOfOperationTime() + c.Assert(err, IsNil) + c.Check(op.Equal(opTime), Equals, true) + + c.Assert(st.Changes(), HasLen, 1) + // change was aborted + c.Check(chg.Status(), Equals, state.HoldStatus) +} + func (ovs *overlordSuite) TestCheckpoint(c *C) { oldUmask := syscall.Umask(0) defer syscall.Umask(oldUmask) @@ -907,3 +1058,40 @@ func (ovs *overlordSuite) TestOverlordCanStandby(c *C) { c.Assert(o.CanStandby(), Equals, true) } + +func (ovs *overlordSuite) TestStartOfOperationTimeAlreadySet(c *C) { + o := overlord.Fake() + st := o.State() + st.Lock() + defer st.Unlock() + + op := time.Now().AddDate(0, -1, 0) + st.Set("start-of-operation-time", op) + + operationTime, err := o.StartOfOperationTime() + c.Assert(err, IsNil) + c.Check(operationTime.Equal(op), Equals, true) +} + +func (s *overlordSuite) TestStartOfOperationSetTime(c *C) { + o := overlord.Fake() + st := o.State() + st.Lock() + defer st.Unlock() + + now := time.Now().Add(-1 * time.Second) + overlord.FakeTimeNow(func() time.Time { + return now + }) + + operationTime, err := o.StartOfOperationTime() + c.Assert(err, IsNil) + c.Check(operationTime.Equal(now), Equals, true) + + // repeated call returns already set time + prev := now + now = time.Now().Add(-10 * time.Hour) + operationTime, err = o.StartOfOperationTime() + c.Assert(err, IsNil) + c.Check(operationTime.Equal(prev), Equals, true) +} diff --git a/internals/overlord/patch/patch.go b/internals/overlord/patch/patch.go index afc93186e..58514a8a4 100644 --- a/internals/overlord/patch/patch.go +++ b/internals/overlord/patch/patch.go @@ -20,6 +20,7 @@ package patch import ( + "errors" "fmt" "github.com/canonical/pebble/internals/logger" @@ -43,12 +44,12 @@ var patches = make(map[int][]PatchFunc) func Init(s *state.State) { s.Lock() defer s.Unlock() - if s.Get("patch-level", new(int)) != state.ErrNoState { + if err := s.Get("patch-level", new(int)); !errors.Is(err, state.ErrNoState) { panic("internal error: expected empty state, attempting to override patch-level without actual patching") } s.Set("patch-level", Level) - if s.Get("patch-sublevel", new(int)) != state.ErrNoState { + if err := s.Get("patch-sublevel", new(int)); !errors.Is(err, state.ErrNoState) { panic("internal error: expected empty state, attempting to override patch-sublevel without actual patching") } s.Set("patch-sublevel", Sublevel) @@ -76,12 +77,12 @@ func Apply(s *state.State) error { var stateLevel, stateSublevel int s.Lock() err := s.Get("patch-level", &stateLevel) - if err == nil || err == state.ErrNoState { + if err == nil || errors.Is(err, state.ErrNoState) { err = s.Get("patch-sublevel", &stateSublevel) } s.Unlock() - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } diff --git a/internals/overlord/restart/restart.go b/internals/overlord/restart/restart.go index 1e545b8cb..f52f09174 100644 --- a/internals/overlord/restart/restart.go +++ b/internals/overlord/restart/restart.go @@ -16,6 +16,8 @@ package restart import ( + "errors" + "github.com/canonical/pebble/internals/overlord/state" ) @@ -63,7 +65,7 @@ func Init(st *state.State, curBootID string, h Handler) error { } var fromBootID string err := st.Get("system-restart-from-boot-id", &fromBootID) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } st.Cache(restartStateKey{}, rs) diff --git a/internals/overlord/restart/restart_test.go b/internals/overlord/restart/restart_test.go index 8a8617ed3..74fc07646 100644 --- a/internals/overlord/restart/restart_test.go +++ b/internals/overlord/restart/restart_test.go @@ -21,6 +21,7 @@ import ( "github.com/canonical/pebble/internals/overlord/restart" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) func TestRestart(t *testing.T) { TestingT(t) } @@ -134,5 +135,5 @@ func (s *restartSuite) TestRequestRestartSystemAndVerifyReboot(c *C) { err = restart.Init(st, "boot-id-2", h2) c.Assert(err, IsNil) c.Check(h2.rebootAsExpected, Equals, true) - c.Check(st.Get("system-restart-from-boot-id", &fromBootID), Equals, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &fromBootID), testutil.ErrorIs, state.ErrNoState) } diff --git a/internals/overlord/servstate/handlers.go b/internals/overlord/servstate/handlers.go index 1e38dad04..f753e8ea7 100644 --- a/internals/overlord/servstate/handlers.go +++ b/internals/overlord/servstate/handlers.go @@ -1,6 +1,7 @@ package servstate import ( + "errors" "fmt" "io" "os" @@ -28,7 +29,7 @@ import ( func TaskServiceRequest(task *state.Task) (*ServiceRequest, error) { req := &ServiceRequest{} err := task.Get("service-request", req) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return nil, err } if err == nil { diff --git a/internals/overlord/state/change.go b/internals/overlord/state/change.go index bfb3b42db..3ab11ddfa 100644 --- a/internals/overlord/state/change.go +++ b/internals/overlord/state/change.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state @@ -23,8 +18,11 @@ import ( "bytes" "encoding/json" "fmt" + "sort" "strings" "time" + + "github.com/canonical/pebble/internals/logger" ) // Status is used for status values for changes and tasks. @@ -37,7 +35,8 @@ const ( // to an aggregation of its tasks' statuses. See Change.Status for details. DefaultStatus Status = 0 - // HoldStatus means the task should not run, perhaps as a consequence of an error on another task. + // HoldStatus means the task should not run for the moment, perhaps as a + // consequence of an error on another task. HoldStatus Status = 1 // DoStatus means the change or task is ready to start. @@ -65,6 +64,10 @@ const ( // ErrorStatus means the change or task has errored out while running or being undone. ErrorStatus Status = 9 + // WaitStatus means the task was accomplished successfully but some + // external event needs to happen before work can progress further. + WaitStatus Status = 10 + nStatuses = iota ) @@ -88,6 +91,8 @@ func (s Status) String() string { return "Doing" case DoneStatus: return "Done" + case WaitStatus: + return "Wait" case AbortStatus: return "Abort" case UndoStatus: @@ -104,6 +109,18 @@ func (s Status) String() string { panic(fmt.Sprintf("internal error: unknown task status code: %d", s)) } +// taskWaitComputeStatus is used while computing the wait status of a +// change. It keeps track of whether a task is waiting or not waiting, or the +// computation for it is still in-progress to detect cyclic dependencies. +type taskWaitComputeStatus int + +const ( + taskWaitStatusNotComputed taskWaitComputeStatus = iota + taskWaitStatusComputing + taskWaitStatusNotWaiting + taskWaitStatusWaiting +) + // Change represents a tracked modification to the system state. // // The Change provides both the justification for individual tasks @@ -115,16 +132,16 @@ func (s Status) String() string { // while the individual Task values would track the running of // the hooks themselves. type Change struct { - state *State - id string - kind string - summary string - status Status - clean bool - data customData - taskIDs []string - lanes int - ready chan struct{} + state *State + id string + kind string + summary string + status Status + clean bool + data customData + taskIDs []string + ready chan struct{} + lastObservedStatus Status spawnTime time.Time readyTime time.Time @@ -157,7 +174,6 @@ type marshalledChange struct { Clean bool `json:"clean,omitempty"` Data map[string]*json.RawMessage `json:"data,omitempty"` TaskIDs []string `json:"task-ids,omitempty"` - Lanes int `json:"lanes,omitempty"` SpawnTime time.Time `json:"spawn-time"` ReadyTime *time.Time `json:"ready-time,omitempty"` @@ -178,7 +194,6 @@ func (c *Change) MarshalJSON() ([]byte, error) { Clean: c.clean, Data: c.data, TaskIDs: c.taskIDs, - Lanes: c.lanes, SpawnTime: c.spawnTime, ReadyTime: readyTime, @@ -206,7 +221,6 @@ func (c *Change) UnmarshalJSON(data []byte) error { } c.data = custData c.taskIDs = unmarshalled.TaskIDs - c.lanes = unmarshalled.Lanes c.ready = make(chan struct{}) c.spawnTime = unmarshalled.SpawnTime if unmarshalled.ReadyTime != nil { @@ -251,12 +265,19 @@ func (c *Change) Get(key string, value interface{}) error { return c.data.get(key, value) } +// Has returns whether the provided key has an associated value. +func (c *Change) Has(key string) bool { + c.state.reading() + return c.data.has(key) +} + var statusOrder = []Status{ AbortStatus, UndoingStatus, UndoStatus, DoingStatus, DoStatus, + WaitStatus, ErrorStatus, UndoneStatus, DoneStatus, @@ -269,32 +290,138 @@ func init() { } } +func (c *Change) isTaskWaiting(visited map[string]taskWaitComputeStatus, t *Task, deps []*Task) bool { + taskID := t.ID() + // Retrieve the compute status of the wait for the task, if not + // computed this defaults to 0 (taskWaitStatusNotComputed). + computeStatus := visited[taskID] + switch computeStatus { + case taskWaitStatusComputing: + // Cyclic dependency detected, return false to short-circuit. + logger.Noticef("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind()) + // Make sure errors show up in "pebble change " too + t.Logf("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind()) + return false + case taskWaitStatusWaiting, taskWaitStatusNotWaiting: + return computeStatus == taskWaitStatusWaiting + } + visited[taskID] = taskWaitStatusComputing + + var isWaiting bool +depscheck: + for _, wt := range deps { + switch wt.Status() { + case WaitStatus: + isWaiting = true + // States that can be valid when waiting + // - Done, Undone, ErrorStatus, HoldStatus + case DoneStatus, UndoneStatus, ErrorStatus, HoldStatus: + continue + // For 'Do' and 'Undo' we have to check whether the task is waiting + // for any dependencies. The logic is the same, but the set of tasks + // varies. + case DoStatus: + isWaiting = c.isTaskWaiting(visited, wt, wt.WaitTasks()) + if !isWaiting { + // Cancel early if we detect something is runnable. + break depscheck + } + case UndoStatus: + isWaiting = c.isTaskWaiting(visited, wt, wt.HaltTasks()) + if !isWaiting { + // Cancel early if we detect something is runnable. + break depscheck + } + default: + // When we determine the change can not be in a wait-state then + // break early. + isWaiting = false + break depscheck + } + } + if isWaiting { + visited[taskID] = taskWaitStatusWaiting + } else { + visited[taskID] = taskWaitStatusNotWaiting + } + return isWaiting +} + +// isChangeWaiting should only ever return true iff it determines all tasks in Do/Undo +// are blocked by tasks in either of three states: 'DoneStatus', 'UndoneStatus' or 'WaitStatus', +// if this fails, we default to the normal status ordering logic. +func (c *Change) isChangeWaiting() bool { + // Since we might visit tasks more than once, we store results to avoid recomputing them. + visited := make(map[string]taskWaitComputeStatus) + for _, t := range c.Tasks() { + switch t.Status() { + case WaitStatus, DoneStatus, UndoneStatus, ErrorStatus, HoldStatus: + continue + case DoStatus: + if !c.isTaskWaiting(visited, t, t.WaitTasks()) { + return false + } + case UndoStatus: + if !c.isTaskWaiting(visited, t, t.HaltTasks()) { + return false + } + default: + return false + } + } + // If we end up here, then return true as we know we + // have at least one waiter in this change. + return true +} + // Status returns the current status of the change. // If the status was not explicitly set the result is derived from the status // of the individual tasks related to the change, according to the following // decision sequence: // +// - With all pending tasks blocked by other tasks in WaitStatus, return WaitStatus // - With at least one task in DoStatus, return DoStatus // - With at least one task in ErrorStatus, return ErrorStatus // - Otherwise, return DoneStatus func (c *Change) Status() Status { c.state.reading() - if c.status == DefaultStatus { - if len(c.taskIDs) == 0 { - return HoldStatus - } - statusStats := make([]int, nStatuses) - for _, tid := range c.taskIDs { - statusStats[c.state.tasks[tid].Status()]++ + if c.status != DefaultStatus { + return c.status + } + + if len(c.taskIDs) == 0 { + return HoldStatus + } + + statusStats := make([]int, nStatuses) + for _, tid := range c.taskIDs { + statusStats[c.state.tasks[tid].Status()]++ + } + + // If the change has any waiters, check for any runnable tasks + // or whether it's completely blocked by waiters. + if statusStats[WaitStatus] > 0 { + // Only if the change has all tasks blocked we return WaitStatus. + if c.isChangeWaiting() { + return WaitStatus } - for _, s := range statusOrder { - if statusStats[s] > 0 { - return s - } + } + + // Otherwise we return the current status with the highest priority. + for _, s := range statusOrder { + if statusStats[s] > 0 { + return s } - panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats)) } - return c.status + panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats)) +} + +func (c *Change) notifyStatusChange(new Status) { + if c.lastObservedStatus == new { + return + } + c.state.notifyChangeStatusChangedHandlers(c, c.lastObservedStatus, new) + c.lastObservedStatus = new } // SetStatus sets the change status, overriding the default behavior (see Status method). @@ -304,6 +431,7 @@ func (c *Change) SetStatus(s Status) { if s.Ready() { c.markReady() } + c.notifyStatusChange(c.Status()) } func (c *Change) markReady() { @@ -322,15 +450,10 @@ func (c *Change) Ready() <-chan struct{} { return c.ready } -// taskStatusChanged is called by tasks when their status is changed, -// to give the opportunity for the change to close its ready channel. -func (c *Change) taskStatusChanged(t *Task, old, new Status) { - if old.Ready() == new.Ready() { - return - } +func (c *Change) detectChangeReady(excludeTask *Task) { for _, tid := range c.taskIDs { task := c.state.tasks[tid] - if task != t && !task.status.Ready() { + if task != excludeTask && !task.status.Ready() { return } } @@ -343,6 +466,21 @@ func (c *Change) taskStatusChanged(t *Task, old, new Status) { c.markReady() } +// taskStatusChanged is called by tasks when their status is changed, +// to give the opportunity for the change to close its ready channel, and +// notify observers of Change changes. +func (c *Change) taskStatusChanged(t *Task, old, new Status) { + cs := c.Status() + // If the task changes from ready => unready or unready => ready, + // update the ready status for the change. + if old.Ready() == new.Ready() { + c.notifyStatusChange(cs) + return + } + c.detectChangeReady(t) + c.notifyStatusChange(cs) +} + // IsClean returns whether all tasks in the change have been cleaned. See SetClean. func (c *Change) IsClean() bool { c.state.reading() @@ -519,6 +657,44 @@ func (c *Change) AbortLanes(lanes []int) { c.abortLanes(lanes, make(map[int]bool), make(map[string]bool)) } +// AbortUnreadyLanes aborts the tasks from lanes that aren't fully ready, where +// a ready lane is one in which all tasks are ready. +func (c *Change) AbortUnreadyLanes() { + c.state.writing() + c.abortUnreadyLanes() +} + +func (c *Change) abortUnreadyLanes() { + lanesWithLiveTasks := map[int]bool{} + + for _, tid := range c.taskIDs { + t := c.state.tasks[tid] + if !t.Status().Ready() { + for _, tlane := range t.Lanes() { + lanesWithLiveTasks[tlane] = true + } + } + } + + abortLanes := []int{} + for lane := range lanesWithLiveTasks { + abortLanes = append(abortLanes, lane) + } + c.abortLanes(abortLanes, make(map[int]bool), make(map[string]bool)) +} + +// taskEffectiveStatus returns the 'effective' status. This means it accounts +// for tasks being in WaitStatus, and instead of returning the WaitStatus we +// return the actual status. (The status after the wait). +func taskEffectiveStatus(t *Task) Status { + status := t.Status() + if status == WaitStatus { + // If the task is waiting, then use the effective status instead. + status = t.WaitedStatus() + } + return status +} + func (c *Change) abortLanes(lanes []int, abortedLanes map[int]bool, seenTasks map[string]bool) { var hasLive = make(map[int]bool) var hasDead = make(map[int]bool) @@ -528,7 +704,7 @@ NextChangeTask: t := c.state.tasks[tid] var live bool - switch t.Status() { + switch taskEffectiveStatus(t) { case DoStatus, DoingStatus, DoneStatus: live = true } @@ -579,7 +755,7 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks continue } seenTasks[t.id] = true - switch t.Status() { + switch taskEffectiveStatus(t) { case DoStatus: // Still pending so don't even start. t.SetStatus(HoldStatus) @@ -607,3 +783,87 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks c.abortLanes(lanes, abortedLanes, seenTasks) } } + +type TaskDependencyCycleError struct { + IDs []string + msg string +} + +func (e *TaskDependencyCycleError) Error() string { return e.msg } + +func (e *TaskDependencyCycleError) Is(err error) bool { + _, ok := err.(*TaskDependencyCycleError) + return ok +} + +// CheckTaskDependencies checks the tasks in the change for cyclic dependencies +// and returns an error in such case. +func (c *Change) CheckTaskDependencies() error { + tasks := c.Tasks() + // count how many tasks any given non-independent task waits for + predecessors := make(map[string]int, len(tasks)) + + taskByID := map[string]*Task{} + for _, t := range tasks { + taskByID[t.id] = t + if l := len(t.waitTasks); l > 0 { + // only add an entry if the task is not independent + predecessors[t.id] = l + } + } + + // Kahn topological sort: make our way starting with tasks that are + // independent (their predecessors count is 0), then visit their direct + // successors (halt tasks), and for each reduce their predecessors + // count; once the count drops to 0, all direct dependencies of a given + // task have been accounted for and the task becomes independent. + + // queue of tasks to check + queue := make([]string, 0, len(tasks)) + // identify all independent tasks + for _, t := range tasks { + if predecessors[t.id] == 0 { + queue = append(queue, t.id) + } + } + + for len(queue) > 0 { + // take the first independent task + id := queue[0] + queue = queue[1:] + // reduce the incoming edge of its successors + for _, successor := range taskByID[id].haltTasks { + predecessors[successor]-- + if predecessors[successor] == 0 { + // a task that was a successor has become + // independent + delete(predecessors, successor) + queue = append(queue, successor) + } + } + } + + if len(predecessors) != 0 { + // tasks that are left cannot have their dependencies satisfied + var unsatisfiedTasks []string + for id := range predecessors { + unsatisfiedTasks = append(unsatisfiedTasks, id) + } + sort.Strings(unsatisfiedTasks) + msg := strings.Builder{} + msg.WriteString("dependency cycle involving tasks [") + for i, id := range unsatisfiedTasks { + t := taskByID[id] + msg.WriteString(fmt.Sprintf("%v:%v", t.id, t.kind)) + if i < len(unsatisfiedTasks)-1 { + msg.WriteRune(' ') + } + } + msg.WriteRune(']') + return &TaskDependencyCycleError{ + IDs: unsatisfiedTasks, + msg: msg.String(), + } + } + return nil +} diff --git a/internals/overlord/state/change_test.go b/internals/overlord/state/change_test.go index 339c94344..273c09052 100644 --- a/internals/overlord/state/change_test.go +++ b/internals/overlord/state/change_test.go @@ -1,25 +1,21 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state_test import ( + "errors" "fmt" "sort" "strconv" @@ -68,7 +64,7 @@ func (cs *changeSuite) TestReadyTime(c *C) { } func (cs *changeSuite) TestStatusString(c *C) { - for s := state.Status(0); s < state.ErrorStatus+1; s++ { + for s := state.Status(0); s < state.WaitStatus+1; s++ { c.Assert(s.String(), Matches, ".+") } } @@ -88,6 +84,21 @@ func (cs *changeSuite) TestGetSet(c *C) { c.Check(v, Equals, 1) } +func (cs *changeSuite) TestHas(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("install", "...") + c.Check(chg.Has("a"), Equals, false) + + chg.Set("a", 1) + c.Check(chg.Has("a"), Equals, true) + + chg.Set("a", nil) + c.Check(chg.Has("a"), Equals, false) +} + // TODO Better testing of full change roundtripping via JSON. func (cs *changeSuite) TestNewTaskAddTaskAndTasks(c *C) { @@ -227,9 +238,13 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { tasks := make(map[state.Status]*state.Task) - for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ { + for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ { t := st.NewTask("download", s.String()) - t.SetStatus(s) + if s == state.WaitStatus { + t.SetToWait(state.DoneStatus) + } else { + t.SetStatus(s) + } chg.AddTask(t) tasks[s] = t } @@ -240,6 +255,7 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { state.UndoStatus, state.DoingStatus, state.DoStatus, + state.WaitStatus, state.ErrorStatus, state.UndoneStatus, state.DoneStatus, @@ -252,7 +268,11 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { if s == s2 { break } - tasks[s2].SetStatus(s) + if s == state.WaitStatus { + tasks[s2].SetToWait(state.DoneStatus) + } else { + tasks[s2].SetStatus(s) + } } c.Assert(chg.Status(), Equals, s) } @@ -432,9 +452,13 @@ func (cs *changeSuite) TestAbort(c *C) { chg := st.NewChange("install", "...") - for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ { + for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ { t := st.NewTask("download", s.String()) - t.SetStatus(s) + if s == state.WaitStatus { + t.SetToWait(state.DoneStatus) + } else { + t.SetStatus(s) + } t.Set("old-status", s) chg.AddTask(t) } @@ -451,7 +475,7 @@ func (cs *changeSuite) TestAbort(c *C) { switch s { case state.DoStatus: c.Assert(t.Status(), Equals, state.HoldStatus) - case state.DoneStatus: + case state.DoneStatus, state.WaitStatus: c.Assert(t.Status(), Equals, state.UndoStatus) case state.DoingStatus: c.Assert(t.Status(), Equals, state.AbortStatus) @@ -526,11 +550,13 @@ func (cs *changeSuite) TestAbortKⁿ(c *C) { // Task wait order: // -// => t21 => t22 -// / \ -// t11 => t12 => t41 => t42 -// \ / -// => t31 => t32 +// => t21 => t22 +// / \ +// +// t11 => t12 => t41 => t42 +// +// \ / +// => t31 => t32 // // setup and result lines are :[:,...] // @@ -577,6 +603,10 @@ var abortLanesTests = []struct { setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", abort: []int{2}, result: "t21:hold t22:hold t41:hold t42:hold *:do", + }, { + setup: "t11:done:1 t12:wait:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + abort: []int{2}, + result: "t21:hold t22:hold t41:hold t42:hold t11:done t12:wait *:do", }, { setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", abort: []int{3}, @@ -660,7 +690,7 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { c.Logf("Testing setup: %s", test.setup) statuses := make(map[string]state.Status) - for s := state.DefaultStatus; s <= state.ErrorStatus; s++ { + for s := state.DefaultStatus; s <= state.WaitStatus; s++ { statuses[strings.ToLower(s.String())] = s } @@ -680,7 +710,11 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { } seen[parts[0]] = true task := tasks[parts[0]] - task.SetStatus(statuses[parts[1]]) + if statuses[parts[1]] == state.WaitStatus { + task.SetToWait(state.DoneStatus) + } else { + task.SetStatus(statuses[parts[1]]) + } if len(parts) > 2 { lanes := strings.Split(parts[2], ",") for _, lane := range lanes { @@ -723,3 +757,691 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup)) } } + +// setup and result lines are :[:,...] +// order is -> (implies task2 waits for task 1) +// "*" as task name means "all remaining". +var abortUnreadyLanesTests = []struct { + setup string + order string + result string +}{ + + // Some basics. + { + setup: "*:do", + result: "*:hold", + }, { + setup: "*:wait", + result: "*:undo", + }, { + setup: "*:done", + result: "*:done", + }, { + setup: "*:error", + result: "*:error", + }, + + // t11 (1) => t12 (1) => t21 (1) => t22 (1) + // t31 (2) => t32 (2) => t41 (2) => t42 (2) + { + setup: "t11:do:1 t12:do:1 t21:do:1 t22:do:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "*:hold", + }, { + setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:hold t32:hold t41:hold t42:hold", + }, { + setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:done:2 t32:done:2 t41:done:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:undo t32:undo t41:undo t42:hold", + }, + // => t21 (2) => t22 (2) + // / \ + // t11 (2,3) => t12 (2,3) => t41 (4) => t42 (4) + // \ / + // => t31 (3) => t32 (3) + { + setup: "t11:do:2,3 t12:do:2,3 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + result: "*:hold", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + // lane 2 is fully complete so it does not get aborted + result: "t11:done t12:done t21:done t22:done t31:abort t32:hold t41:hold t42:hold *:undo", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:wait:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + // lane 2 is fully complete so it does not get aborted + result: "t11:done t12:done t21:done t22:done t31:undo t32:hold t41:hold t42:hold *:undo", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:doing:2 t22:do:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + result: "t21:abort t22:hold t31:abort t32:hold t41:hold t42:hold *:undo", + }, + + // t11 (1) => t12 (1) + // t21 (2) => t22 (2) + // t31 (3) => t32 (3) + // t41 (4) => t42 (4) + { + setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + result: "*:hold", + }, { + setup: "t11:do:1 t12:do:1 t21:doing:2 t22:do:2 t31:done:3 t32:doing:3 t41:undone:4 t42:error:4", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + result: "t11:hold t12:hold t21:abort t22:hold t31:undo t32:abort t41:undone t42:error", + }, + // auto refresh like arrangement + // + // (apps) + // => t31 (3) => t32 (3) + // (snapd) (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11:done:1 t12:done:1 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:abort *:hold", + }, { + // + setup: "t11:done:1 t12:done:1 t21:done:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + result: "t11:done t12:done t21:undo *:hold", + }, + // arrangement with a cyclic dependency between tasks + // + // /-----------------------------------------\ + // | | + // | => t31 (3) => t32 (3) / + // (snapd) v (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11:done:1 t12:done:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21", + result: "t11:done t12:done *:hold", + }, +} + +func (ts *taskRunnerSuite) TestAbortUnreadyLanes(c *C) { + + names := strings.Fields("t11 t12 t21 t22 t31 t32 t41 t42") + + for i, test := range abortUnreadyLanesTests { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + + c.Logf("----- %v", i) + c.Logf("Testing setup: %s", test.setup) + + for _, wp := range strings.Fields(test.order) { + pair := strings.Split(wp, "->") + c.Assert(pair, HasLen, 2) + // task 2 waits for task 1 is denoted as: + // task1->task2 + tasks[pair[1]].WaitFor(tasks[pair[0]]) + } + + statuses := make(map[string]state.Status) + for s := state.DefaultStatus; s <= state.WaitStatus; s++ { + statuses[strings.ToLower(s.String())] = s + } + + items := strings.Fields(test.setup) + seen := make(map[string]bool) + for i := 0; i < len(items); i++ { + item := items[i] + parts := strings.Split(item, ":") + if parts[0] == "*" { + c.Assert(i, Equals, len(items)-1, Commentf("*: can only be used as the last entry")) + for _, name := range names { + if !seen[name] { + parts[0] = name + items = append(items, strings.Join(parts, ":")) + } + } + continue + } + seen[parts[0]] = true + task := tasks[parts[0]] + if statuses[parts[1]] == state.WaitStatus { + task.SetToWait(state.DoneStatus) + } else { + task.SetStatus(statuses[parts[1]]) + } + if len(parts) > 2 { + lanes := strings.Split(parts[2], ",") + for _, lane := range lanes { + n, err := strconv.Atoi(lane) + c.Assert(err, IsNil) + task.JoinLane(n) + } + } + } + + c.Logf("Aborting") + + chg.AbortUnreadyLanes() + + c.Logf("Expected result: %s", test.result) + + seen = make(map[string]bool) + var expected = strings.Fields(test.result) + var obtained []string + for i := 0; i < len(expected); i++ { + item := expected[i] + parts := strings.Split(item, ":") + if parts[0] == "*" { + c.Assert(i, Equals, len(expected)-1, Commentf("*: can only be used as the last entry")) + var expanded []string + for _, name := range names { + if !seen[name] { + parts[0] = name + expanded = append(expanded, strings.Join(parts, ":")) + } + } + expected = append(expected[:i], append(expanded, expected[i+1:]...)...) + i-- + continue + } + name := parts[0] + seen[parts[0]] = true + obtained = append(obtained, name+":"+strings.ToLower(tasks[name].Status().String())) + } + + c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup)) + } +} + +// setup is a list of tasks " ", order is -> +// (implies task2 waits for task 1) +var cyclicDependencyTests = []struct { + setup string + order string + err string + errIDs []string +}{ + + // Some basics. + { + setup: "t1", + }, { + setup: "", + }, { + // independent tasks + setup: "t1 t2 t3", + }, { + // some independent and some ordered tasks + setup: "t1 t2 t3 t4", + order: "t2->t3", + }, + // some independent, dependencies as if added by WaitAll() + // t1 => t2 + // t1,t2 => t3 + // t1,t2,t3 => t4 + { + setup: "t1 t2 t3 t4", + order: "t1->t2 t1->t3 t2->t3 t1->t4 t2->t4 t3->t4", + }, { + // simple loop + setup: "t1 t2", + order: "t1->t2 t2->t1", + err: `dependency cycle involving tasks \[1:t1 2:t2\]`, + errIDs: []string{"1", "2"}, + }, + + // t1 => t2 => t3 => t4 + // t5 => t6 => t7 => t8 + { + setup: "t1 t2 t3 t4 t5 t6 t7 t8", + order: "t1->t2 t2->t3 t3->t4 t5->t6 t6->t7 t7->t8", + }, + // => t21 => t22 + // / \ + // t11 => t12 => t41 => t42 + // \ / + // => t31 => t32 + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + }, + // t11 (1) => t12 (1) + // t21 (2) => t22 (2) + // t31 (3) => t32 (3) + // t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + }, + // auto refresh like arrangement + // + // (apps) + // => t31 (3) => t32 (3) + // (snapd) (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + }, + // arrangement with a cyclic dependency between tasks + // + // /-----------------------------------------\ + // | | + // | => t31 (3) => t32 (3) / + // (snapd) v (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21", + err: `dependency cycle involving tasks \[3:t21 4:t22 5:t31 6:t32 7:t41 8:t42\]`, + errIDs: []string{"3", "4", "5", "6", "7", "8"}, + }, + // t1 => t2 => t3 => t4 --> t6 + // t5 => t6 => t7 => t8 --> t2 + { + setup: "t1 t2 t3 t4 t5 t6 t7 t8", + order: "t1->t2 t2->t3 t3->t4 t4->t6 t5->t6 t6->t7 t7->t8 t8->t2", + err: `dependency cycle involving tasks \[2:t2 3:t3 4:t4 6:t6 7:t7 8:t8\]`, + errIDs: []string{"2", "3", "4", "6", "7", "8"}, + }, +} + +func (ts *taskRunnerSuite) TestCheckTaskDependencies(c *C) { + + for i, test := range cyclicDependencyTests { + names := strings.Fields(test.setup) + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask(name, name) + chg.AddTask(tasks[name]) + } + + c.Logf("----- %v", i) + c.Logf("Testing setup: %s", test.setup) + + for _, wp := range strings.Fields(test.order) { + pair := strings.Split(wp, "->") + c.Assert(pair, HasLen, 2) + // task 2 waits for task 1 is denoted as: + // task1->task2 + tasks[pair[1]].WaitFor(tasks[pair[0]]) + } + + err := chg.CheckTaskDependencies() + + if test.err != "" { + c.Assert(err, ErrorMatches, test.err) + c.Assert(errors.Is(err, &state.TaskDependencyCycleError{}), Equals, true) + errTasksDepCycle := err.(*state.TaskDependencyCycleError) + c.Assert(errTasksDepCycle.IDs, DeepEquals, test.errIDs) + } else { + c.Assert(err, IsNil) + } + } +} + +func (cs *changeSuite) TestIsWaitingStatusOrderWithWaits(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + t1.WaitFor(t2) + t1.WaitFor(t3) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // Set the wait-task into WaitStatus, to ensure we trigger the isWaiting + // logic and that it doesn't return WaitStatus for statuses which are in + // higher order + t4.SetToWait(state.DoneStatus) + + // Test the following sequences: + // task1 (do) => task2 (done) => task3 (doing) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.DoingStatus) + c.Check(chg.Status(), Equals, state.DoingStatus) + + // task1 (done) => task2 (done) => task3 (undoing) + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoingStatus) + c.Check(chg.Status(), Equals, state.UndoingStatus) + + // task1 (done) => task2 (done) => task3 (abort) + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.AbortStatus) + c.Check(chg.Status(), Equals, state.AbortStatus) +} + +func (cs *changeSuite) TestIsWaitingSingle(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + + chg.AddTask(t1) + c.Check(chg.Status(), Equals, state.DoStatus) + + t1.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingTwoTasks(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("wait-task", "...") + t2.WaitFor(t1) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + + // Put t3 into wait-status to trigger the isWaiting logic each time + // for the change. + t3.SetToWait(state.DoneStatus) + + // task1 (do) => task2 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) => task2 (do) no reboot + t1.SetStatus(state.DoneStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (wait) => task2 (do) means need a reboot + t1.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) => task2 (wait) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingCircularDependency(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + + // Setup circular dependency between t1,t2 and t3, they should + // still act normally. + t2.WaitFor(t1) + t3.WaitFor(t2) + t1.WaitFor(t3) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // To trigger the cyclic dependency check, we must trigger the isWaiting logic + // and we do this by putting t4 into WaitStatus. + t4.SetToWait(state.DoneStatus) + + // task1 (do) => task2 (do) => task3 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) => task2 (do) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoingStatus) + c.Check(chg.Status(), Equals, state.DoingStatus) + + // task1 (wait) => task2 (do) => task3 (do) means need a reboot + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) => task2 (wait) => task3 (do) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingMultipleDependencies(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + t3.WaitFor(t1) + t3.WaitFor(t2) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // Put t4 into wait-status to trigger the isWaiting logic each time + // for the change. + t4.SetToWait(state.DoneStatus) + + // task1 (do) + task2 (do) => task3 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) + task2 (done) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) + task2 (do) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // For the next two cases we are testing that a task with dependencies + // which have completed, but in a non-successful way is handled correctly. + // task1 (error) + task2 (wait) => task3 (do) means need reboot + // to finalize task2 + t1.SetStatus(state.ErrorStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (error) => task3 (do) means need reboot + // to finalize task1 + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.ErrorStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) + task2 (wait) => task3 (do) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) => task3 (do) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) + task2 (done) => task3 (wait) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + t3.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (abort) => task3 (do) + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.AbortStatus) + t3.SetStatus(state.DoStatus) + c.Check(chg.Status(), Equals, state.AbortStatus) +} + +func (cs *changeSuite) TestIsWaitingUndoTwoTasks(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("wait-task", "...") + t2.WaitFor(t1) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + + // Put t3 into wait-status to trigger the isWaiting logic each time + // for the change. + t3.SetToWait(state.DoneStatus) + + // we use <=| to denote the reverse dependence relationship + // followed by undo logic + + // task1 (undo) <=| task2 (undo) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) <=| task2 (undone) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) <=| task2 (wait) means need a reboot + t1.SetStatus(state.UndoStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) <=| task2 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingUndoMultipleDependencies(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("task4", "...") + t5 := st.NewTask("wait-task", "...") + t3.WaitFor(t1) + t3.WaitFor(t2) + t4.WaitFor(t1) + t4.WaitFor(t2) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + chg.AddTask(t5) + + // Put t5 into wait-status to trigger the isWaiting logic each time + // for the change. + t5.SetToWait(state.DoneStatus) + + // task1 (undo) + task2 (undo) <=| task3 (undo) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + t3.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) + task2 (undo) <=| task3 (undone) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + t3.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) + task2 (undo) <=| task3 (wait) + task4 (error) means + // need reboot to continue undoing 1 and 2 + t3.SetStatus(state.ErrorStatus) + t4.SetToWait(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (undo) + task2 (undo) => task3 (error) + task4 (wait) means + // need reboot to continue undoing 1 and 2 + t3.SetToWait(state.UndoneStatus) + t4.SetStatus(state.ErrorStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undo) no reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (wait) + task2 (done) <=| task3 (undone) + task4 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} diff --git a/internals/overlord/state/export_test.go b/internals/overlord/state/export_test.go index 46d22f4d6..b307dc5c6 100644 --- a/internals/overlord/state/export_test.go +++ b/internals/overlord/state/export_test.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state diff --git a/internals/overlord/state/notices_test.go b/internals/overlord/state/notices_test.go index 1a1e1eeb5..88840fe17 100644 --- a/internals/overlord/state/notices_test.go +++ b/internals/overlord/state/notices_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Canonical Ltd +// Copyright (c) 2024 Canonical Ltd // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License version 3 as @@ -11,7 +11,6 @@ // // You should have received a copy of the GNU General Public License // along with this program. If not, see . - package state_test import ( @@ -448,7 +447,7 @@ func (s *noticesSuite) TestDeleteExpired(c *C) { addNotice(c, st, nil, state.CustomNotice, "foo.com/z", nil) c.Assert(st.NumNotices(), Equals, 4) - st.Prune(0, 0, 0) + st.Prune(time.Now(), 0, 0, 0) c.Assert(st.NumNotices(), Equals, 2) notices := st.Notices(nil) diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index b848ecb69..e72dc3ef8 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . // Package state implements the representation of system state. package state @@ -46,7 +41,7 @@ type customData map[string]*json.RawMessage func (data customData) get(key string, value interface{}) error { entryJSON := data[key] if entryJSON == nil { - return ErrNoState + return &NoStateError{Key: key} } err := json.Unmarshal(*entryJSON, value) if err != nil { @@ -88,6 +83,9 @@ type State struct { lastChangeId int lastLaneId int lastNoticeId int + // lastHandlerId is not serialized, it's only used during runtime + // for registering runtime callbacks + lastHandlerId int backend Backend data customData @@ -101,19 +99,28 @@ type State struct { modified bool cache map[interface{}]interface{} + + pendingChangeByAttr map[string]func(*Change) bool + + // task/changes observing + taskHandlers map[int]func(t *Task, old, new Status) + changeHandlers map[int]func(chg *Change, old, new Status) } // New returns a new empty state. func New(backend Backend) *State { st := &State{ - backend: backend, - data: make(customData), - changes: make(map[string]*Change), - tasks: make(map[string]*Task), - warnings: make(map[string]*Warning), - notices: make(map[noticeKey]*Notice), - modified: true, - cache: make(map[interface{}]interface{}), + backend: backend, + data: make(customData), + changes: make(map[string]*Change), + tasks: make(map[string]*Task), + warnings: make(map[string]*Warning), + notices: make(map[noticeKey]*Notice), + modified: true, + cache: make(map[interface{}]interface{}), + pendingChangeByAttr: make(map[string]func(*Change) bool), + taskHandlers: make(map[int]func(t *Task, old Status, new Status)), + changeHandlers: make(map[int]func(chg *Change, old Status, new Status)), } st.noticeCond = sync.NewCond(st) // use State.Lock and State.Unlock return st @@ -221,6 +228,15 @@ var ( unlockCheckpointRetryInterval = 3 * time.Second ) +// Unlocker returns a closure that will unlock and checkpoint the state and +// in turn return a function to relock it. +func (s *State) Unlocker() (unlock func() (relock func())) { + return func() func() { + s.Unlock() + return s.Lock + } +} + // Unlock releases the state lock and checkpoints the state. // It does not return until the state is correctly checkpointed. // After too many unsuccessful checkpoint attempts, it panics. @@ -254,6 +270,28 @@ func (s *State) EnsureBefore(d time.Duration) { // ErrNoState represents the case of no state entry for a given key. var ErrNoState = errors.New("no state entry for key") +// NoStateError represents the case where no state could be found for a given key. +type NoStateError struct { + // Key is the key for which no state could be found. + Key string +} + +func (e *NoStateError) Error() string { + var keyMsg string + if e.Key != "" { + keyMsg = fmt.Sprintf(" %q", e.Key) + } + + return fmt.Sprintf("no state entry for key%s", keyMsg) +} + +// Is returns true if the error is of type *NoStateError or equal to ErrNoState. +// NoStateError's key isn't compared between errors. +func (e *NoStateError) Is(err error) bool { + _, ok := err.(*NoStateError) + return ok || errors.Is(err, ErrNoState) +} + // Get unmarshals the stored value associated with the provided key // into the value parameter. // It returns ErrNoState if there is no entry for key. @@ -262,6 +300,12 @@ func (s *State) Get(key string, value interface{}) error { return s.data.get(key, value) } +// Has returns whether the provided key has an associated value. +func (s *State) Has(key string) bool { + s.reading() + return s.data.has(key) +} + // Set associates value with key for future consulting by managers. // The provided value must properly marshal and unmarshal with encoding/json. func (s *State) Set(key string, value interface{}) { @@ -370,15 +414,25 @@ func (s *State) tasksIn(tids []string) []*Task { return res } +// RegisterPendingChangeByAttr registers predicates that will be invoked by +// Prune on changes with the specified attribute set to check whether even if +// they meet the time criteria they must not be aborted yet. +func (s *State) RegisterPendingChangeByAttr(attr string, f func(*Change) bool) { + s.pendingChangeByAttr[attr] = f +} + // Prune does several cleanup tasks to the in-memory state: // // - it removes changes that became ready for more than pruneWait and aborts -// tasks spawned for more than abortWait. +// tasks spawned for more than abortWait unless prevented by predicates +// registered with RegisterPendingChangeByAttr. +// // - it removes tasks unlinked to changes after pruneWait. When there are more // changes than the limit set via "maxReadyChanges" those changes in ready // state will also removed even if they are below the pruneWait duration. -// - it removes expired warnings and notices -func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { +// +// - it removes expired warnings and notices. +func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) { now := time.Now() pruneLimit := now.Add(-pruneWait) abortLimit := now.Add(-abortWait) @@ -411,15 +465,24 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { } } +NextChange: for _, chg := range changes { - spawnTime := chg.SpawnTime() readyTime := chg.ReadyTime() + spawnTime := chg.SpawnTime() + if spawnTime.Before(startOfOperation) { + spawnTime = startOfOperation + } if readyTime.IsZero() { if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 { chg.Abort() delete(s.changes, chg.ID()) } else if spawnTime.Before(abortLimit) { - chg.Abort() + for attr, pending := range s.pendingChangeByAttr { + if chg.Has(attr) && pending(chg) { + continue NextChange + } + } + chg.AbortUnreadyLanes() } continue } @@ -443,6 +506,75 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { } } +// GetMaybeTimings implements timings.GetSaver +func (s *State) GetMaybeTimings(timings interface{}) error { + if err := s.Get("timings", timings); err != nil && !errors.Is(err, ErrNoState) { + return err + } + return nil +} + +// AddTaskStatusChangedHandler adds a callback function that will be invoked +// whenever tasks change status. +// NOTE: Callbacks registered this way may be invoked in the context +// of the taskrunner, so the callbacks should be as simple as possible, and return +// as quickly as possible, and should avoid the use of i/o code or blocking, as this +// will stop the entire task system. +func (s *State) AddTaskStatusChangedHandler(f func(t *Task, old, new Status)) (id int) { + // We are reading here as we want to ensure access to the state is serialized, + // and not writing as we are not changing the part of state that goes on the disk. + s.reading() + id = s.lastHandlerId + s.lastHandlerId++ + s.taskHandlers[id] = f + return id +} + +func (s *State) RemoveTaskStatusChangedHandler(id int) { + s.reading() + delete(s.taskHandlers, id) +} + +func (s *State) notifyTaskStatusChangedHandlers(t *Task, old, new Status) { + s.reading() + for _, f := range s.taskHandlers { + f(t, old, new) + } +} + +// AddChangeStatusChangedHandler adds a callback function that will be invoked +// whenever a Change changes status. +// NOTE: Callbacks registered this way may be invoked in the context +// of the taskrunner, so the callbacks should be as simple as possible, and return +// as quickly as possible, and should avoid the use of i/o code or blocking, as this +// will stop the entire task system. +func (s *State) AddChangeStatusChangedHandler(f func(chg *Change, old, new Status)) (id int) { + // We are reading here as we want to ensure access to the state is serialized, + // and not writing as we are not changing the part of state that goes on the disk. + s.reading() + id = s.lastHandlerId + s.lastHandlerId++ + s.changeHandlers[id] = f + return id +} + +func (s *State) RemoveChangeStatusChangedHandler(id int) { + s.reading() + delete(s.changeHandlers, id) +} + +func (s *State) notifyChangeStatusChangedHandlers(chg *Change, old, new Status) { + s.reading() + for _, f := range s.changeHandlers { + f(chg, old, new) + } +} + +// SaveTimings implements timings.GetSaver +func (s *State) SaveTimings(timings interface{}) { + s.Set("timings", timings) +} + // ReadState returns the state deserialized from r. func ReadState(backend Backend, r io.Reader) (*State, error) { s := new(State) @@ -454,8 +586,11 @@ func ReadState(backend Backend, r io.Reader) (*State, error) { return nil, fmt.Errorf("cannot read state: %s", err) } s.backend = backend + s.noticeCond = sync.NewCond(s) s.modified = false s.cache = make(map[interface{}]interface{}) - s.noticeCond = sync.NewCond(s) + s.pendingChangeByAttr = make(map[string]func(*Change) bool) + s.changeHandlers = make(map[int]func(chg *Change, old Status, new Status)) + s.taskHandlers = make(map[int]func(t *Task, old Status, new Status)) return s, err } diff --git a/internals/overlord/state/state_test.go b/internals/overlord/state/state_test.go index e1e8a7076..b1cf534c6 100644 --- a/internals/overlord/state/state_test.go +++ b/internals/overlord/state/state_test.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state_test @@ -23,12 +18,14 @@ import ( "bytes" "errors" "fmt" + "reflect" "testing" "time" . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) func TestState(t *testing.T) { TestingT(t) } @@ -55,6 +52,17 @@ func (ss *stateSuite) TestLockUnlock(c *C) { st.Unlock() } +func (ss *stateSuite) TestUnlocker(c *C) { + st := state.New(nil) + unlocker := st.Unlocker() + st.Lock() + defer st.Unlock() + relock := unlocker() + st.Lock() + st.Unlock() + relock() +} + func (ss *stateSuite) TestGetAndSet(c *C) { st := state.New(nil) st.Lock() @@ -76,6 +84,37 @@ func (ss *stateSuite) TestGetAndSet(c *C) { c.Check(&mSt2B, DeepEquals, mSt2) } +func (ss *stateSuite) TestHas(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + c.Check(st.Has("a"), Equals, false) + + st.Set("a", 1) + c.Check(st.Has("a"), Equals, true) + + st.Set("a", nil) + c.Check(st.Has("a"), Equals, false) +} + +func (ss *stateSuite) TestStrayTaskWithNoChange(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + _ = st.NewTask("bar", "...") + + // only the task with associate change is returned + c.Assert(st.Tasks(), HasLen, 1) + c.Assert(st.Tasks()[0].ID(), Equals, t1.ID()) + // but count includes all tasks + c.Assert(st.TaskCount(), Equals, 2) +} + func (ss *stateSuite) TestSetPanic(c *C) { st := state.New(nil) st.Lock() @@ -94,7 +133,7 @@ func (ss *stateSuite) TestGetNoState(c *C) { var mSt1B mgrState1 err := st.Get("mgr9", &mSt1B) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) } func (ss *stateSuite) TestSetToNilDeletes(c *C) { @@ -112,7 +151,7 @@ func (ss *stateSuite) TestSetToNilDeletes(c *C) { var v1 map[string]int err = st.Get("a", &v1) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) c.Check(v1, HasLen, 0) } @@ -126,7 +165,7 @@ func (ss *stateSuite) TestNullMeansNoState(c *C) { var v1 map[string]int err = st.Get("a", &v1) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) c.Check(v1, HasLen, 0) } @@ -514,6 +553,30 @@ func (ss *stateSuite) TestEmptyStateDataAndCheckpointReadAndSet(c *C) { // no crash st2.Set("a", 1) + + // ensure all maps of state are correctly initialized by ReadState + val := reflect.ValueOf(st2) + typ := val.Elem().Type() + var maps []string + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + if f.Type.Kind() == reflect.Map { + maps = append(maps, f.Name) + fv := val.Elem().Field(i) + c.Check(fv.IsNil(), Equals, false, Commentf("Map field %s of state was not initialized by ReadState", f.Name)) + } + } + c.Check(maps, DeepEquals, []string{ + "data", + "changes", + "tasks", + "warnings", + "notices", + "cache", + "pendingChangeByAttr", + "taskHandlers", + "changeHandlers", + }) } func (ss *stateSuite) TestEmptyTaskAndChangeDataAndCheckpointReadAndSet(c *C) { @@ -702,7 +765,7 @@ func (ss *stateSuite) TestMethodEntrance(c *C) { func() { st.Tasks() }, func() { st.Task("foo") }, func() { st.MarshalJSON() }, - func() { st.Prune(time.Hour, time.Hour, 100) }, + func() { st.Prune(time.Now(), time.Hour, time.Hour, 100) }, func() { st.TaskCount() }, func() { st.AllWarnings() }, func() { st.PendingWarnings() }, @@ -768,7 +831,8 @@ func (ss *stateSuite) TestPrune(c *C) { st.AddWarning("hello", now, never, time.Nanosecond, state.DefaultRepeatAfter) st.Warnf("hello again") - st.Prune(pruneWait, abortWait, 100) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) c.Assert(st.Change(chg1.ID()), Equals, chg1) c.Assert(st.Change(chg2.ID()), IsNil) @@ -793,6 +857,55 @@ func (ss *stateSuite) TestPrune(c *C) { c.Check(st.AllWarnings(), HasLen, 1) } +func (ss *stateSuite) TestRegisterPendingChangeByAttr(c *C) { + st := state.New(&fakeStateBackend{}) + st.Lock() + defer st.Unlock() + + now := time.Now() + pruneWait := 1 * time.Hour + abortWait := 3 * time.Hour + + unset := time.Time{} + + t1 := st.NewTask("foo", "...") + t2 := st.NewTask("foo", "...") + t3 := st.NewTask("foo", "...") + t4 := st.NewTask("foo", "...") + + chg1 := st.NewChange("abort", "...") + chg1.AddTask(t1) + chg1.AddTask(t2) + state.FakeChangeTimes(chg1, now.Add(-abortWait), unset) + + chg2 := st.NewChange("pending", "...") + chg2.AddTask(t3) + chg2.AddTask(t4) + state.FakeChangeTimes(chg2, now.Add(-abortWait), unset) + chg2.Set("pending-flag", true) + t3.SetStatus(state.HoldStatus) + + st.RegisterPendingChangeByAttr("pending-flag", func(chg *state.Change) bool { + c.Check(chg.ID(), Equals, chg2.ID()) + return true + }) + + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) + + c.Assert(st.Change(chg1.ID()), Equals, chg1) + c.Assert(st.Change(chg2.ID()), Equals, chg2) + c.Assert(st.Task(t1.ID()), Equals, t1) + c.Assert(st.Task(t2.ID()), Equals, t2) + c.Assert(st.Task(t3.ID()), Equals, t3) + c.Assert(st.Task(t4.ID()), Equals, t4) + + c.Assert(t1.Status(), Equals, state.HoldStatus) + c.Assert(t2.Status(), Equals, state.HoldStatus) + c.Assert(t3.Status(), Equals, state.HoldStatus) + c.Assert(t4.Status(), Equals, state.DoStatus) +} + func (ss *stateSuite) TestPruneEmptyChange(c *C) { // Empty changes are a bit special because they start out on Hold // which is a Ready status, but the change itself is not considered Ready @@ -809,7 +922,8 @@ func (ss *stateSuite) TestPruneEmptyChange(c *C) { chg := st.NewChange("abort", "...") state.FakeChangeTimes(chg, now.Add(-pruneWait), time.Time{}) - st.Prune(pruneWait, abortWait, 100) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) c.Assert(st.Change(chg.ID()), IsNil) } @@ -844,13 +958,14 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) { // test that nothing is done when we are within pruneWait and // maxReadyChanges + past := time.Now().AddDate(-1, 0, 0) maxReadyChanges := 100 - st.Prune(pruneWait, abortWait, maxReadyChanges) + st.Prune(past, pruneWait, abortWait, maxReadyChanges) c.Assert(st.Changes(), HasLen, 15) // but with maxReadyChanges we remove the ready ones maxReadyChanges = 5 - st.Prune(pruneWait, abortWait, maxReadyChanges) + st.Prune(past, pruneWait, abortWait, maxReadyChanges) c.Assert(st.Changes(), HasLen, 10) remaining := map[string]bool{} for _, chg := range st.Changes() { @@ -886,8 +1001,9 @@ func (ss *stateSuite) TestPruneMaxChangesSomeNotReady(c *C) { c.Assert(st.Changes(), HasLen, 10) // nothing can be pruned + past := time.Now().AddDate(-1, 0, 0) maxChanges := 5 - st.Prune(1*time.Hour, 3*time.Hour, maxChanges) + st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges) c.Assert(st.Changes(), HasLen, 10) } @@ -905,7 +1021,7 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) { c.Assert(st.Changes(), HasLen, 10) // one extra change that just now entered ready state - chg := st.NewChange(fmt.Sprintf("chg99"), "so-ready") + chg := st.NewChange("chg99", "so-ready") t := st.NewTask("foo", "so-ready") when := 1 * time.Second state.FakeChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when)) @@ -916,11 +1032,45 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) { // // this test we do not purge the freshly ready change maxChanges := 10 - st.Prune(1*time.Hour, 3*time.Hour, maxChanges) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges) c.Assert(st.Changes(), HasLen, 11) } -func (ss *stateSuite) TestReadStateInitsCache(c *C) { +func (ss *stateSuite) TestPruneHonorsStartOperationTime(c *C) { + st := state.New(&fakeStateBackend{}) + st.Lock() + defer st.Unlock() + + now := time.Now() + + startTime := 2 * time.Hour + spawnTime := 10 * time.Hour + pruneWait := 1 * time.Hour + abortWait := 3 * time.Hour + + chg := st.NewChange("change", "...") + t := st.NewTask("foo", "") + chg.AddTask(t) + // change spawned 10h ago + state.FakeChangeTimes(chg, now.Add(-spawnTime), time.Time{}) + + // start operation time is 2h ago, change is not aborted because + // it's less than abortWait limit. + opTime := now.Add(-startTime) + st.Prune(opTime, pruneWait, abortWait, 100) + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.DoStatus) + + // start operation time is 9h ago, change is aborted. + startTime = 9 * time.Hour + opTime = time.Now().Add(-startTime) + st.Prune(opTime, pruneWait, abortWait, 100) + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.HoldStatus) +} + +func (ss *stateSuite) TestReadStateInitsTransientMapFields(c *C) { st, err := state.ReadState(nil, bytes.NewBufferString("{}")) c.Assert(err, IsNil) st.Lock() @@ -928,4 +1078,195 @@ func (ss *stateSuite) TestReadStateInitsCache(c *C) { st.Cache("key", "value") c.Assert(st.Cached("key"), Equals, "value") + st.RegisterPendingChangeByAttr("attr", func(*state.Change) bool { return false }) +} + +func (ss *stateSuite) TestTimingsSupport(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var tims []int + + err := st.GetMaybeTimings(&tims) + c.Assert(err, IsNil) + c.Check(tims, IsNil) + + st.SaveTimings([]int{1, 2, 3}) + + err = st.GetMaybeTimings(&tims) + c.Assert(err, IsNil) + c.Check(tims, DeepEquals, []int{1, 2, 3}) +} + +func (ss *stateSuite) TestNoStateErrorIs(c *C) { + err := &state.NoStateError{Key: "foo"} + c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) + c.Assert(err, testutil.ErrorIs, &state.NoStateError{Key: "bar"}) + c.Assert(err, testutil.ErrorIs, state.ErrNoState) +} + +func (ss *stateSuite) TestNoStateErrorString(c *C) { + err := &state.NoStateError{} + c.Assert(err.Error(), Equals, `no state entry for key`) + err.Key = "foo" + c.Assert(err.Error(), Equals, `no state entry for key "foo"`) +} + +type taskAndStatus struct { + t *state.Task + old, new state.Status +} + +func (ss *stateSuite) TestTaskChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var taskObservedChanges []taskAndStatus + oId := st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) { + taskObservedChanges = append(taskObservedChanges, taskAndStatus{ + t: t, + old: old, + new: new, + }) + }) + + t1 := st.NewTask("foo", "...") + + t1.SetStatus(state.DoingStatus) + + // Set task status to identical status, we don't want + // task events when task don't actually change status. + t1.SetStatus(state.DoingStatus) + + // Set task to done. + t1.SetStatus(state.DoneStatus) + + // Unregister us, and make sure we do not receive more events. + st.RemoveTaskStatusChangedHandler(oId) + + // must not appear in list. + t1.SetStatus(state.DoingStatus) + + c.Check(taskObservedChanges, DeepEquals, []taskAndStatus{ + { + t: t1, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + t: t1, + old: state.DoingStatus, + new: state.DoneStatus, + }, + }) +} + +type changeAndStatus struct { + chg *state.Change + old, new state.Status +} + +func (ss *stateSuite) TestChangeChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var observedChanges []changeAndStatus + oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) { + observedChanges = append(observedChanges, changeAndStatus{ + chg: chg, + old: old, + new: new, + }) + }) + + chg := st.NewChange("test-chg", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + + t1.SetStatus(state.DoingStatus) + + // Set task status to identical status, we don't want + // change events when changes don't actually change status. + t1.SetStatus(state.DoingStatus) + + // Set task to waiting + t1.SetToWait(state.DoneStatus) + + // Unregister us, and make sure we do not receive more events. + st.RemoveChangeStatusChangedHandler(oId) + + // must not appear in list. + t1.SetStatus(state.DoneStatus) + + c.Check(observedChanges, DeepEquals, []changeAndStatus{ + { + chg: chg, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + chg: chg, + old: state.DoingStatus, + new: state.WaitStatus, + }, + }) +} + +func (ss *stateSuite) TestChangeSetStatusChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var observedChanges []changeAndStatus + oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) { + observedChanges = append(observedChanges, changeAndStatus{ + chg: chg, + old: old, + new: new, + }) + }) + + chg := st.NewChange("test-chg", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + + t1.SetStatus(state.DoingStatus) + + // We have a single task in Doing, now we manipulate the status + // of the change to ensure we are receiving correct events + chg.SetStatus(state.WaitStatus) + + // Change to a new status + chg.SetStatus(state.ErrorStatus) + + // Now return the status back to Default, which should result + // in the change reporting Doing + chg.SetStatus(state.DefaultStatus) + st.RemoveChangeStatusChangedHandler(oId) + + c.Check(observedChanges, DeepEquals, []changeAndStatus{ + { + chg: chg, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + chg: chg, + old: state.DoingStatus, + new: state.WaitStatus, + }, + { + chg: chg, + old: state.WaitStatus, + new: state.ErrorStatus, + }, + { + chg: chg, + old: state.ErrorStatus, + new: state.DoingStatus, + }, + }) } diff --git a/internals/overlord/state/task.go b/internals/overlord/state/task.go index a3e2a4864..39a85a20a 100644 --- a/internals/overlord/state/task.go +++ b/internals/overlord/state/task.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state @@ -38,19 +33,22 @@ type progress struct { // // See Change for more details. type Task struct { - state *State - id string - kind string - summary string - status Status - clean bool - progress *progress - data customData - waitTasks []string - haltTasks []string - lanes []int - log []string - change string + state *State + id string + kind string + summary string + status Status + // waitedStatus is the Status that should be used instead of + // WaitStatus once the wait is complete (i.e post reboot). + waitedStatus Status + clean bool + progress *progress + data customData + waitTasks []string + haltTasks []string + lanes []int + log []string + change string spawnTime time.Time readyTime time.Time @@ -77,18 +75,19 @@ func newTask(state *State, id, kind, summary string) *Task { } type marshalledTask struct { - ID string `json:"id"` - Kind string `json:"kind"` - Summary string `json:"summary"` - Status Status `json:"status"` - Clean bool `json:"clean,omitempty"` - Progress *progress `json:"progress,omitempty"` - Data map[string]*json.RawMessage `json:"data,omitempty"` - WaitTasks []string `json:"wait-tasks,omitempty"` - HaltTasks []string `json:"halt-tasks,omitempty"` - Lanes []int `json:"lanes,omitempty"` - Log []string `json:"log,omitempty"` - Change string `json:"change"` + ID string `json:"id"` + Kind string `json:"kind"` + Summary string `json:"summary"` + Status Status `json:"status"` + WaitedStatus Status `json:"waited-status"` + Clean bool `json:"clean,omitempty"` + Progress *progress `json:"progress,omitempty"` + Data map[string]*json.RawMessage `json:"data,omitempty"` + WaitTasks []string `json:"wait-tasks,omitempty"` + HaltTasks []string `json:"halt-tasks,omitempty"` + Lanes []int `json:"lanes,omitempty"` + Log []string `json:"log,omitempty"` + Change string `json:"change"` SpawnTime time.Time `json:"spawn-time"` ReadyTime *time.Time `json:"ready-time,omitempty"` @@ -111,18 +110,19 @@ func (t *Task) MarshalJSON() ([]byte, error) { atTime = &t.atTime } return json.Marshal(marshalledTask{ - ID: t.id, - Kind: t.kind, - Summary: t.summary, - Status: t.status, - Clean: t.clean, - Progress: t.progress, - Data: t.data, - WaitTasks: t.waitTasks, - HaltTasks: t.haltTasks, - Lanes: t.lanes, - Log: t.log, - Change: t.change, + ID: t.id, + Kind: t.kind, + Summary: t.summary, + Status: t.status, + WaitedStatus: t.waitedStatus, + Clean: t.clean, + Progress: t.progress, + Data: t.data, + WaitTasks: t.waitTasks, + HaltTasks: t.haltTasks, + Lanes: t.lanes, + Log: t.log, + Change: t.change, SpawnTime: t.spawnTime, ReadyTime: readyTime, @@ -148,6 +148,13 @@ func (t *Task) UnmarshalJSON(data []byte) error { t.kind = unmarshalled.Kind t.summary = unmarshalled.Summary t.status = unmarshalled.Status + t.waitedStatus = unmarshalled.WaitedStatus + if t.waitedStatus == DefaultStatus { + // For backwards-compatibility, default the waitStatus, which is + // the result status after a wait, to DoneStatus to keep any previous + // behaviour before any upgrade. + t.waitedStatus = DoneStatus + } t.clean = unmarshalled.Clean t.progress = unmarshalled.Progress custData := unmarshalled.Data @@ -188,6 +195,40 @@ func (t *Task) Summary() string { } // Status returns the current task status. +// +// Possible state transitions: +// +// /----aborting lane--Do +// | | +// V V +// Hold Doing-->Wait +// ^ / | \ +// | abort / V V +// no undo / Done Error +// | V | +// \----------Abort aborting lane +// / | | +// | finished or | +// running not running | +// V \------->| +// kill goroutine | +// | V +// / \ ----->Undo +// / no error / | +// | from goroutine | +// error | +// from goroutine | +// | V +// | Undoing-->Wait +// V | \ +// Error V V +// Undone Error +// +// Do -> Doing -> Done is the direct succcess scenario. +// +// Wait can transition to its waited status, +// usually Done|Undone or back to Doing. +// See Wait struct, SetToWait and WaitedStatus. func (t *Task) Status() Status { t.state.reading() if t.status == DefaultStatus { @@ -196,10 +237,10 @@ func (t *Task) Status() Status { return t.status } -// SetStatus sets the task status, overriding the default behavior (see Status method). -func (t *Task) SetStatus(new Status) { - t.state.writing() - old := t.status +func (t *Task) changeStatus(old, new Status) { + if old == new { + return + } t.status = new if !old.Ready() && new.Ready() { t.readyTime = timeNow() @@ -208,6 +249,55 @@ func (t *Task) SetStatus(new Status) { if chg != nil { chg.taskStatusChanged(t, old, new) } + t.state.notifyTaskStatusChangedHandlers(t, old, new) +} + +// SetStatus sets the task status, overriding the default behavior (see Status method). +func (t *Task) SetStatus(new Status) { + if new == WaitStatus { + panic("Task.SetStatus() called with WaitStatus, which is not allowed. Use SetToWait() instead") + } + + t.state.writing() + old := t.status + if new == DoneStatus && old == AbortStatus { + // if the task is in AbortStatus (because some other task ran + // in parallel and had an error so the change is aborted) and + // DoneStatus was requested (which can happen if the + // task handler sets its status explicitly) then keep it at + // aborted so it can transition to Undo. + return + } + t.changeStatus(old, new) +} + +// SetToWait puts the task into WaitStatus, and sets the status the task should be restored +// to after the SetToWait. +func (t *Task) SetToWait(resultStatus Status) { + switch resultStatus { + case DefaultStatus, WaitStatus: + panic("Task.SetToWait() cannot be invoked with either of DefaultStatus or WaitStatus") + } + + t.state.writing() + old := t.status + if old == AbortStatus { + // if the task is in AbortStatus (because some other task ran + // in parallel and had an error so the change is aborted) and + // WaitStatus was requested (which can happen if the + // task handler sets its status explicitly) then keep it at + // aborted so it can transition to Undo. + return + } + t.waitedStatus = resultStatus + t.changeStatus(old, WaitStatus) +} + +// WaitedStatus returns the status the Task should return to once the current WaitStatus +// has been resolved. +func (t *Task) WaitedStatus() Status { + t.state.reading() + return t.waitedStatus } // IsClean returns whether the task has been cleaned. See SetClean. @@ -332,7 +422,7 @@ func (t *Task) addLog(kind, format string, args []interface{}) { } tstr := timeNow().Format(time.RFC3339) - msg := fmt.Sprintf(tstr+" "+kind+" "+format, args...) + msg := tstr + " " + kind + " " + fmt.Sprintf(format, args...) t.log = append(t.log, msg) logger.Debugf(msg) } @@ -484,10 +574,16 @@ func NewTaskSet(tasks ...*Task) *TaskSet { return &TaskSet{tasks, nil} } -// Edge returns the task marked with the given edge name. +// MaybeEdge returns the task marked with the given edge name or nil if no such +// task exists. +func (ts TaskSet) MaybeEdge(e TaskSetEdge) *Task { + return ts.edges[e] +} + +// Edge returns the task marked with the given edge name or an error. func (ts TaskSet) Edge(e TaskSetEdge) (*Task, error) { - t, ok := ts.edges[e] - if !ok { + t := ts.MaybeEdge(e) + if t == nil { return nil, fmt.Errorf("internal error: missing %q edge in task set", e) } return t, nil @@ -509,7 +605,7 @@ func (ts *TaskSet) WaitAll(anotherTs *TaskSet) { } } -// AddTask adds the the task to the task set. +// AddTask adds the task to the task set. func (ts *TaskSet) AddTask(task *Task) { for _, t := range ts.tasks { if t == task { @@ -522,6 +618,9 @@ func (ts *TaskSet) AddTask(task *Task) { // MarkEdge marks the given task as a specific edge. Any pre-existing // edge mark will be overridden. func (ts *TaskSet) MarkEdge(task *Task, edge TaskSetEdge) { + if task == nil { + panic(fmt.Sprintf("cannot set edge %q with nil task", edge)) + } if ts.edges == nil { ts.edges = make(map[TaskSetEdge]*Task) } diff --git a/internals/overlord/state/task_test.go b/internals/overlord/state/task_test.go index 9c073d806..5a5c4f9b1 100644 --- a/internals/overlord/state/task_test.go +++ b/internals/overlord/state/task_test.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state_test @@ -127,7 +122,7 @@ func (ts *taskSuite) TestClear(c *C) { t.Clear("a") - c.Check(t.Get("a", &v), Equals, state.ErrNoState) + c.Check(t.Get("a", &v), testutil.ErrorIs, state.ErrNoState) } func (ts *taskSuite) TestStatusAndSetStatus(c *C) { @@ -144,6 +139,60 @@ func (ts *taskSuite) TestStatusAndSetStatus(c *C) { c.Check(t.Status(), Equals, state.DoneStatus) } +func (ts *taskSuite) TestSetDoneAfterAbortNoop(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetStatus(state.AbortStatus) + c.Check(t.Status(), Equals, state.AbortStatus) + t.SetStatus(state.DoneStatus) + c.Check(t.Status(), Equals, state.AbortStatus) +} + +func (ts *taskSuite) TestSetWaitAfterAbortNoop(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetStatus(state.AbortStatus) + c.Check(t.Status(), Equals, state.AbortStatus) + t.SetToWait(state.DoneStatus) // noop + c.Check(t.Status(), Equals, state.AbortStatus) + c.Check(t.WaitedStatus(), Equals, state.DefaultStatus) +} + +func (ts *taskSuite) TestSetWait(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetToWait(state.DoneStatus) + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, state.DoneStatus) + t.SetToWait(state.UndoStatus) + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, state.UndoStatus) +} + +func (ts *taskSuite) TestTaskMarshalsWaitStatus(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t1 := st.NewTask("download", "1...") + t1.SetToWait(state.UndoStatus) + + d, err := t1.MarshalJSON() + c.Assert(err, IsNil) + + needle := fmt.Sprintf(`"waited-status":%d`, t1.WaitedStatus()) + c.Assert(string(d), testutil.Contains, needle) +} + func (ts *taskSuite) TestIsCleanAndSetClean(c *C) { st := state.New(nil) st.Lock() @@ -561,6 +610,10 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) { // edges are just typed strings edge1 := state.TaskSetEdge("on-edge") edge2 := state.TaskSetEdge("eddie") + edge3 := state.TaskSetEdge("not-found") + + // nil task causes panic + c.Check(func() { ts.MarkEdge(nil, edge1) }, PanicMatches, `cannot set edge "on-edge" with nil task`) // no edge marked yet t, err := ts.Edge(edge1) @@ -590,6 +643,12 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) { t, err = ts.Edge(edge1) c.Assert(t, Equals, t3) c.Assert(err, IsNil) + + // it is possible to check if edge exists without failing + t = ts.MaybeEdge(edge1) + c.Assert(t, Equals, t3) + t = ts.MaybeEdge(edge3) + c.Assert(t, IsNil) } func (cs *taskSuite) TestTaskAddAllWithEdges(c *C) { diff --git a/internals/overlord/state/taskrunner.go b/internals/overlord/state/taskrunner.go index 5d9ff2440..cfed27bd2 100644 --- a/internals/overlord/state/taskrunner.go +++ b/internals/overlord/state/taskrunner.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state @@ -46,6 +41,23 @@ func (r *Retry) Error() string { return "task should be retried" } +// Wait is returned from a handler to signal that the task cannot +// proceed at the moment maybe because some manual action from the +// user required at this point or because of errors. The task +// will be set to WaitStatus, and it's wait complete status will be +// set to WaitedStatus. +type Wait struct { + Reason string + // If not explicitly set, then WaitedStatus will default to + // DoneStatus, meaning that the task will be set to DoneStatus + // after the wait has resolved. + WaitedStatus Status +} + +func (r *Wait) Error() string { + return "task set to wait, manual action required" +} + type blockedFunc func(t *Task, running []*Task) bool // TaskRunner controls the running of goroutines to execute known task kinds. @@ -62,6 +74,9 @@ type TaskRunner struct { blocked []blockedFunc someBlocked bool + // optional callback executed on task errors + taskErrorCallback func(err error) + // go-routines lifecycle tombs map[string]*tomb.Tomb } @@ -85,6 +100,11 @@ func NewTaskRunner(s *State) *TaskRunner { } } +// OnTaskError sets an error callback executed when any task errors out. +func (r *TaskRunner) OnTaskError(f func(err error)) { + r.taskErrorCallback = f +} + // AddHandler registers the functions to concurrently call for doing and // undoing tasks of the given kind. The undo handler may be nil. func (r *TaskRunner) AddHandler(kind string, do, undo HandlerFunc) { @@ -214,7 +234,7 @@ func (r *TaskRunner) run(t *Task) { switch err.(type) { case nil: // we are ok - case *Retry: + case *Retry, *Wait: // preserve default: if r.stopped { @@ -227,13 +247,24 @@ func (r *TaskRunner) run(t *Task) { switch x := err.(type) { case *Retry: // Handler asked to be called again later. - // TODO Allow postponing retries past the next Ensure. if t.Status() == AbortStatus { // Would work without it but might take two ensures. r.tryUndo(t) } else if x.After != 0 { t.At(timeNow().Add(x.After)) } + case *Wait: + if t.Status() == AbortStatus { + // Would work without it but might take two ensures. + r.tryUndo(t) + } else { + // Default to DoneStatus if no status is set in Wait + waitedStatus := x.WaitedStatus + if waitedStatus == DefaultStatus { + waitedStatus = DoneStatus + } + t.SetToWait(waitedStatus) + } case nil: var next []*Task switch t.Status() { @@ -259,6 +290,11 @@ func (r *TaskRunner) run(t *Task) { r.abortLanes(t.Change(), t.Lanes()) t.SetStatus(ErrorStatus) t.Errorf("%s", err) + // ensure the error is available in the global log too + logger.Noticef("[change %s %q task] failed: %v", t.Change().ID(), t.Summary(), err) + if r.taskErrorCallback != nil { + r.taskErrorCallback(err) + } } return nil @@ -388,6 +424,10 @@ ConsiderTasks: } continue } + if status == WaitStatus { + // nothing more to run + continue + } if mustWait(t) { // Dependencies still unhandled. diff --git a/internals/overlord/state/taskrunner_test.go b/internals/overlord/state/taskrunner_test.go index 38c5685a2..979d05e96 100644 --- a/internals/overlord/state/taskrunner_test.go +++ b/internals/overlord/state/taskrunner_test.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2016 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state_test @@ -31,11 +26,14 @@ import ( . "gopkg.in/check.v1" "gopkg.in/tomb.v2" - "github.com/canonical/pebble/internals/overlord/restart" + "github.com/canonical/pebble/internals/logger" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) -type taskRunnerSuite struct{} +type taskRunnerSuite struct { + testutil.BaseTest +} var _ = Suite(&taskRunnerSuite{}) @@ -58,8 +56,6 @@ func (b *stateBackend) EnsureBefore(d time.Duration) { } } -func (b *stateBackend) RequestRestart(t restart.RestartType) {} - func ensureChange(c *C, r *state.TaskRunner, sb *stateBackend, chg *state.Change) { for i := 0; i < 20; i++ { sb.ensureBefore = time.Hour @@ -142,6 +138,10 @@ var sequenceTests = []struct{ setup, result string }{{ result: "t31:undo t32:do t32:do-error t21:undo", }} +func (ts *taskRunnerSuite) SetUpTest(c *C) { + ts.BaseTest.SetUpTest(c) +} + func (ts *taskRunnerSuite) TestSequenceTests(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -185,11 +185,12 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) { r.AddHandler("do", fn("do"), nil) r.AddHandler("do-undo", fn("do"), fn("undo")) + past := time.Now().AddDate(-1, 0, 0) for _, test := range sequenceTests { st.Lock() // Delete previous changes. - st.Prune(1, 1, 1) + st.Prune(past, 1, 1, 1) chg := st.NewChange("install", "...") tasks := make(map[string]*state.Task) @@ -342,6 +343,206 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) { } } +func (ts *taskRunnerSuite) TestAbortAcrossLanesDescendantTask(c *C) { + + // () + // t11(1) -> t12(1) => t15(1) + // \ / + // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2) + // / \ + // t21(2) -> t22(2) => t25(2) + // + names := strings.Fields("t11 t12 t13 t14 t15 t21 t22 t23 t24 t25") + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + tasks["t12"].WaitFor(tasks["t11"]) + tasks["t13"].WaitFor(tasks["t12"]) + tasks["t14"].WaitFor(tasks["t13"]) + tasks["t15"].WaitFor(tasks["t14"]) + for lane, names := range map[int][]string{ + 1: {"t11", "t12", "t13", "t14", "t15", "t23", "t24"}, + 2: {"t21", "t22", "t23", "t24", "t25", "t13", "t14"}, + } { + for _, name := range names { + tasks[name].JoinLane(lane) + } + } + + tasks["t22"].WaitFor(tasks["t21"]) + tasks["t23"].WaitFor(tasks["t22"]) + tasks["t24"].WaitFor(tasks["t23"]) + tasks["t25"].WaitFor(tasks["t24"]) + + tasks["t13"].WaitFor(tasks["t22"]) + tasks["t15"].WaitFor(tasks["t24"]) + tasks["t23"].WaitFor(tasks["t14"]) + + ch := make(chan string, 256) + do := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("do %q", task.Summary()) + label := task.Summary() + if label == "t15" { + ch <- "t15:error" + return fmt.Errorf("mock error") + } + ch <- fmt.Sprintf("%s:do", label) + return nil + } + undo := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("undo %q", task.Summary()) + label := task.Summary() + ch <- fmt.Sprintf("%s:undo", label) + return nil + } + r.AddHandler("do", do, undo) + + c.Logf("-----") + + st.Unlock() + ensureChange(c, r, sb, chg) + st.Lock() + close(ch) + var sequence []string + for event := range ch { + sequence = append(sequence, event) + } + for _, name := range names { + task := tasks[name] + c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status()) + } + c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{ + "t11:do", "t12:do", + "t21:do", "t22:do", + }) + c.Assert(sequence[4:8], DeepEquals, []string{ + "t13:do", "t14:do", "t23:do", "t24:do", + }) + c.Assert(sequence[8:10], testutil.DeepUnsortedMatches, []string{ + "t25:do", + "t15:error", + }) + c.Assert(sequence[10:11], testutil.DeepUnsortedMatches, []string{ + "t25:undo", + }) + c.Assert(sequence[11:15], DeepEquals, []string{ + "t24:undo", "t23:undo", "t14:undo", "t13:undo", + }) + c.Assert(sequence[15:19], testutil.DeepUnsortedMatches, []string{ + "t21:undo", "t22:undo", + "t12:undo", "t11:undo", + }) +} + +func (ts *taskRunnerSuite) TestAbortAcrossLanesStriclyOrderedTasks(c *C) { + + // () + // t11(1) -> t12(1) + // \ + // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2) + // / + // t21(2) -> t22(2) + // + names := strings.Fields("t11 t12 t13 t14 t21 t22 t23 t24") + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + tasks["t12"].WaitFor(tasks["t11"]) + tasks["t13"].WaitFor(tasks["t12"]) + tasks["t14"].WaitFor(tasks["t13"]) + for lane, names := range map[int][]string{ + 1: {"t11", "t12", "t13", "t14", "t23", "t24"}, + 2: {"t21", "t22", "t23", "t24", "t13", "t14"}, + } { + for _, name := range names { + tasks[name].JoinLane(lane) + } + } + + tasks["t22"].WaitFor(tasks["t21"]) + tasks["t23"].WaitFor(tasks["t22"]) + tasks["t24"].WaitFor(tasks["t23"]) + + tasks["t13"].WaitFor(tasks["t22"]) + tasks["t23"].WaitFor(tasks["t14"]) + + ch := make(chan string, 256) + do := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("do %q", task.Summary()) + label := task.Summary() + if label == "t24" { + ch <- "t24:error" + return fmt.Errorf("mock error") + } + ch <- fmt.Sprintf("%s:do", label) + return nil + } + undo := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("undo %q", task.Summary()) + label := task.Summary() + ch <- fmt.Sprintf("%s:undo", label) + return nil + } + r.AddHandler("do", do, undo) + + c.Logf("-----") + + st.Unlock() + ensureChange(c, r, sb, chg) + st.Lock() + close(ch) + var sequence []string + for event := range ch { + sequence = append(sequence, event) + } + for _, name := range names { + task := tasks[name] + c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status()) + } + c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{ + "t11:do", "t12:do", + "t21:do", "t22:do", + }) + c.Assert(sequence[4:8], DeepEquals, []string{ + "t13:do", "t14:do", "t23:do", "t24:error", + }) + c.Assert(sequence[8:11], DeepEquals, []string{ + "t23:undo", "t14:undo", "t13:undo", + }) + c.Assert(sequence[11:], testutil.DeepUnsortedMatches, []string{ + "t21:undo", "t22:undo", + "t12:undo", "t11:undo", + }) +} + func (ts *taskRunnerSuite) TestExternalAbort(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -372,6 +573,101 @@ func (ts *taskRunnerSuite) TestExternalAbort(c *C) { ensureChange(c, r, sb, chg) } +func (ts *taskRunnerSuite) TestUndoSingleLane(c *C) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + r.AddHandler("noop", func(t *state.Task, tb *tomb.Tomb) error { + return nil + }, func(t *state.Task, tb *tomb.Tomb) error { + return nil + }) + + r.AddHandler("noop-slow", func(t *state.Task, tb *tomb.Tomb) error { + time.Sleep(10 * time.Millisecond) + t.State().Lock() + defer t.State().Unlock() + // critical + t.SetStatus(state.DoneStatus) + return nil + }, func(t *state.Task, tb *tomb.Tomb) error { + return nil + }) + + r.AddHandler("fail", func(t *state.Task, tb *tomb.Tomb) error { + return fmt.Errorf("fail") + }, nil) + + st.Lock() + + lane := st.NewLane() + chg := st.NewChange("install", "...") + + // first taskset + var prev *state.Task + for i := 0; i < 10; i++ { + t := st.NewTask("noop-slow", "...") + if prev != nil { + t.WaitFor(prev) + } + chg.AddTask(t) + t.JoinLane(lane) + + prev = t + } + + // second taskset with a failing task that triggers undo of the change + prev = nil + for i := 0; i < 10; i++ { + t := st.NewTask("noop", "...") + if prev != nil { + t.WaitFor(prev) + } + chg.AddTask(t) + t.JoinLane(lane) + prev = t + } + + // error trigger + t := st.NewTask("fail", "...") + t.WaitFor(prev) + chg.AddTask(t) + t.JoinLane(lane) + + st.Unlock() + + var done bool + for !done { + c.Assert(r.Ensure(), Equals, nil) + st.Lock() + done = chg.IsReady() && chg.IsClean() + st.Unlock() + } + + st.Lock() + defer st.Unlock() + + // make sure all tasks are either undone or on hold (except for "fail" task which + // is in error). + for _, t := range st.Tasks() { + switch t.Kind() { + case "fail": + c.Assert(t.Status(), Equals, state.ErrorStatus) + case "noop", "noop-slow": + if t.Status() != state.UndoneStatus && t.Status() != state.HoldStatus { + for _, tsk := range st.Tasks() { + fmt.Printf("%s -> %s\n", tsk.Kind(), tsk.Status()) + } + c.Fatalf("unexpected status: %s", t.Status()) + } + default: + c.Fatalf("unexpected kind: %s", t.Kind()) + } + } +} + func (ts *taskRunnerSuite) TestStopHandlerJustFinishing(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -508,6 +804,55 @@ func (ts *taskRunnerSuite) TestStopAskForRetry(c *C) { c.Check(t.AtTime().IsZero(), Equals, false) } +func (ts *taskRunnerSuite) testTaskReturningWait(c *C, waitedStatus, expectedStatus state.Status) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + r.AddHandler("ask-for-wait", func(t *state.Task, tb *tomb.Tomb) error { + // ask for wait + return &state.Wait{WaitedStatus: waitedStatus} + }, nil) + + st.Lock() + chg := st.NewChange("install", "...") + t := st.NewTask("ask-for-wait", "...") + chg.AddTask(t) + st.Unlock() + + r.Ensure() + // wait for handler to finish + r.Wait() + + st.Lock() + defer st.Unlock() + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, expectedStatus) + c.Check(chg.Status().Ready(), Equals, false) + + st.Unlock() + defer st.Lock() + // does nothing + r.Ensure() + + // state is unchanged + st.Lock() + defer st.Unlock() + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(chg.Status().Ready(), Equals, false) +} + +func (ts *taskRunnerSuite) TestTaskReturningWaitNormal(c *C) { + ts.testTaskReturningWait(c, state.UndoneStatus, state.UndoneStatus) +} + +func (ts *taskRunnerSuite) TestTaskReturningWaitDefaultStatus(c *C) { + // If no state was set (DefaultStatus), then it should default to + // DoneStatus instead. + ts.testTaskReturningWait(c, state.DefaultStatus, state.DoneStatus) +} + func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { ensureBeforeTick := make(chan bool, 1) sb := &stateBackend{ @@ -823,7 +1168,7 @@ func (ts *taskRunnerSuite) TestUndoSequence(c *C) { terr.WaitFor(prev) chg.AddTask(terr) - c.Check(chg.Tasks(), HasLen, 9) // sanity check + c.Check(chg.Tasks(), HasLen, 9) // validity check st.Unlock() @@ -910,3 +1255,71 @@ func (ts *taskRunnerSuite) TestCleanup(c *C) { c.Assert(chgIsClean(), Equals, true) c.Assert(called, Equals, 2) } + +func (ts *taskRunnerSuite) TestErrorCallbackCalledOnError(c *C) { + logbuf, restore := logger.MockLogger("TASKRUNNER: ") + defer restore() + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + + var called bool + r.OnTaskError(func(err error) { + called = true + }) + + r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error { + return fmt.Errorf("handler error for %q", t.Kind()) + }, nil) + + st.Lock() + chg := st.NewChange("install", "change summary") + t1 := st.NewTask("foo", "task summary") + chg.AddTask(t1) + st.Unlock() + + // Mark tasks as done. + ensureChange(c, r, sb, chg) + r.Stop() + + st.Lock() + defer st.Unlock() + + c.Check(t1.Status(), Equals, state.ErrorStatus) + c.Check(strings.Join(t1.Log(), ""), Matches, `.*handler error for "foo"`) + c.Check(called, Equals, true) + + c.Check(logbuf.String(), Matches, `(?m).*: \[change 1 "task summary" task\] failed: handler error for "foo".*`) +} + +func (ts *taskRunnerSuite) TestErrorCallbackNotCalled(c *C) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + + var called bool + r.OnTaskError(func(err error) { + called = true + }) + + r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error { + return nil + }, nil) + + st.Lock() + chg := st.NewChange("install", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + st.Unlock() + + // Mark tasks as done. + ensureChange(c, r, sb, chg) + r.Stop() + + st.Lock() + defer st.Unlock() + + c.Check(t1.Status(), Equals, state.DoneStatus) + c.Check(called, Equals, false) +} diff --git a/internals/overlord/state/warning.go b/internals/overlord/state/warning.go index 69dac7c84..2b348b58c 100644 --- a/internals/overlord/state/warning.go +++ b/internals/overlord/state/warning.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2018 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state diff --git a/internals/overlord/state/warning_test.go b/internals/overlord/state/warning_test.go index bf30f69e9..4cc701bb6 100644 --- a/internals/overlord/state/warning_test.go +++ b/internals/overlord/state/warning_test.go @@ -1,21 +1,16 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (c) 2018 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ +// Copyright (c) 2024 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . package state_test @@ -99,7 +94,7 @@ func (stateSuite) TestUnmarshalErrors(c *check.C) { } for _, t := range []T1{ - // sanity check + // validity check {`{"message": "x", "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, nil}, // remove one field at a time: {`{ "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, state.ErrNoWarningMessage},