diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go
index 432451d6d..487f2cc42 100644
--- a/internals/daemon/daemon.go
+++ b/internals/daemon/daemon.go
@@ -690,7 +690,7 @@ func (d *Daemon) rebootDelay() (time.Duration, error) {
// see whether a reboot had already been scheduled
var rebootAt time.Time
err := d.state.Get("daemon-system-restart-at", &rebootAt)
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return 0, err
}
rebootDelay := 1 * time.Minute
@@ -832,7 +832,7 @@ var errExpectedReboot = errors.New("expected reboot did not happen")
func (d *Daemon) RebootIsMissing(st *state.State) error {
var nTentative int
err := st.Get("daemon-system-restart-tentative", &nTentative)
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return err
}
nTentative++
diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go
index 2a3d18abb..e3784dc96 100644
--- a/internals/daemon/daemon_test.go
+++ b/internals/daemon/daemon_test.go
@@ -991,8 +991,8 @@ func (s *daemonSuite) TestRestartExpectedRebootOK(c *C) {
defer st.Unlock()
var v interface{}
// these were cleared
- c.Check(st.Get("daemon-system-restart-at", &v), Equals, state.ErrNoState)
- c.Check(st.Get("system-restart-from-boot-id", &v), Equals, state.ErrNoState)
+ c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState)
+ c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState)
}
func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *C) {
@@ -1015,9 +1015,9 @@ func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *C) {
defer st.Unlock()
var v interface{}
// these were cleared
- c.Check(st.Get("daemon-system-restart-at", &v), Equals, state.ErrNoState)
- c.Check(st.Get("system-restart-from-boot-id", &v), Equals, state.ErrNoState)
- c.Check(st.Get("daemon-system-restart-tentative", &v), Equals, state.ErrNoState)
+ c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState)
+ c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState)
+ c.Check(st.Get("daemon-system-restart-tentative", &v), testutil.ErrorIs, state.ErrNoState)
}
func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *C) {
diff --git a/internals/overlord/export_test.go b/internals/overlord/export_test.go
index 225297029..372780eb8 100644
--- a/internals/overlord/export_test.go
+++ b/internals/overlord/export_test.go
@@ -40,6 +40,14 @@ func FakePruneInterval(prunei, prunew, abortw time.Duration) (restore func()) {
}
}
+// FakePruneTicker replaces the pruneTickerC hook with f for the duration of
+// a test. The returned restore function reinstates the previous hook.
+func FakePruneTicker(f func(t *time.Ticker) <-chan time.Time) (restore func()) {
+	saved := pruneTickerC
+	pruneTickerC = f
+	restore = func() { pruneTickerC = saved }
+	return restore
+}
+
// FakeEnsureNext sets o.ensureNext for tests.
func FakeEnsureNext(o *Overlord, t time.Time) {
o.ensureNext = t
@@ -49,3 +57,11 @@ func FakeEnsureNext(o *Overlord, t time.Time) {
func (o *Overlord) Engine() *StateEngine {
return o.stateEng
}
+
+// FakeTimeNow replaces the timeNow hook with f for the duration of a test.
+// The returned restore function reinstates the previous hook.
+func FakeTimeNow(f func() time.Time) (restore func()) {
+	saved := timeNow
+	timeNow = f
+	restore = func() { timeNow = saved }
+	return restore
+}
diff --git a/internals/overlord/overlord.go b/internals/overlord/overlord.go
index d6406e243..6c068233f 100644
--- a/internals/overlord/overlord.go
+++ b/internals/overlord/overlord.go
@@ -16,6 +16,7 @@
package overlord
import (
+ "errors"
"fmt"
"io"
"os"
@@ -49,6 +50,10 @@ var (
defaultCachedDownloads = 5
)
+// pruneTickerC returns the channel on which ticks of t are delivered. It is
+// a variable rather than a plain function so that tests can substitute
+// their own channel (see FakePruneTicker).
+var pruneTickerC = func(t *time.Ticker) <-chan time.Time {
+ return t.C
+}
+
// Extension represents an extension of the Overlord.
type Extension interface {
// ExtraManagers allows additional StateManagers to be used.
@@ -81,6 +86,8 @@ type Overlord struct {
ensureRun int32
pruneTicker *time.Ticker
+ startOfOperationTime time.Time
+
// managers
inited bool
startedUp bool
@@ -258,6 +265,15 @@ func (o *Overlord) StartUp() error {
return nil
}
o.startedUp = true
+
+ var err error
+ st := o.State()
+ st.Lock()
+ o.startOfOperationTime, err = o.StartOfOperationTime()
+ st.Unlock()
+ if err != nil {
+ return fmt.Errorf("cannot get start of operation time: %s", err)
+ }
return o.stateEng.StartUp()
}
@@ -314,14 +330,15 @@ func (o *Overlord) Loop() {
// continue to the next Ensure() try for now
o.stateEng.Ensure()
o.ensureDidRun()
+ pruneC := pruneTickerC(o.pruneTicker)
select {
case <-o.loopTomb.Dying():
return nil
case <-o.ensureTimer.C:
- case <-o.pruneTicker.C:
+ case <-pruneC:
st := o.State()
st.Lock()
- st.Prune(pruneWait, abortWait, pruneMaxChanges)
+ st.Prune(o.startOfOperationTime, pruneWait, abortWait, pruneMaxChanges)
st.Unlock()
}
}
@@ -499,6 +516,23 @@ func (o *Overlord) AddManager(mgr StateManager) {
o.stateEng.AddManager(mgr)
}
+// timeNow is the clock used by StartOfOperationTime; it is a variable so
+// that tests can replace it (see FakeTimeNow).
+var timeNow = time.Now
+
+// StartOfOperationTime returns the time recorded in the state under
+// "start-of-operation-time". If no value is recorded yet, the current time
+// (as reported by timeNow) is stored under that key and returned.
+//
+// Any error from the state other than state.ErrNoState is returned as is,
+// together with the zero time.
+//
+// NOTE(review): callers in this package hold the state lock around this
+// call; it presumably must be held, since the state is read and written.
+func (o *Overlord) StartOfOperationTime() (time.Time, error) {
+	st := o.State()
+
+	var opTime time.Time
+	err := st.Get("start-of-operation-time", &opTime)
+	if err == nil {
+		// Already recorded by an earlier call.
+		return opTime, nil
+	}
+	if !errors.Is(err, state.ErrNoState) {
+		return time.Time{}, err
+	}
+
+	// First call: record the current time for later calls to reuse.
+	opTime = timeNow()
+	st.Set("start-of-operation-time", opTime)
+	return opTime, nil
+}
+
type fakeBackend struct {
o *Overlord
}
diff --git a/internals/overlord/overlord_test.go b/internals/overlord/overlord_test.go
index 57765a049..bef3e53ea 100644
--- a/internals/overlord/overlord_test.go
+++ b/internals/overlord/overlord_test.go
@@ -46,6 +46,26 @@ type overlordSuite struct {
var _ = Suite(&overlordSuite{})
+// ticker is a manually driven source of ticks used in place of the real
+// prune ticker channel.
+type ticker struct {
+	tickerChannel chan time.Time
+}
+
+// tick sends n ticks, one after the other, on the ticker channel.
+func (tk *ticker) tick(n int) {
+	for remaining := n; remaining > 0; remaining-- {
+		tk.tickerChannel <- time.Now()
+	}
+}
+
+// fakePruneTicker makes the overlord take its prune ticks from the channel
+// of the returned ticker instead of the real one. The returned restore
+// function puts the original ticker channel hook back.
+func fakePruneTicker() (w *ticker, restore func()) {
+	w = new(ticker)
+	w.tickerChannel = make(chan time.Time)
+	restore = overlord.FakePruneTicker(func(_ *time.Ticker) <-chan time.Time {
+		return w.tickerChannel
+	})
+	return w, restore
+}
+
func (ovs *overlordSuite) SetUpTest(c *C) {
ovs.dir = c.MkDir()
ovs.statePath = filepath.Join(ovs.dir, ".pebble.state")
@@ -517,6 +537,137 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) {
c.Assert(err, IsNil)
}
+// TestOverlordStartUpSetsStartOfOperation checks that StartUp records
+// "start-of-operation-time" in the state when it was not set before.
+func (ovs *overlordSuite) TestOverlordStartUpSetsStartOfOperation(c *C) {
+	restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour)
+	defer restoreIntv()
+
+	// use a real overlord
+	o, err := overlord.New(&overlord.Options{PebbleDir: ovs.dir})
+	c.Assert(err, IsNil)
+
+	st := o.State()
+	var opTime time.Time
+
+	// validity check, not set
+	//
+	// The lock is held only around each state access and never via defer:
+	// StartUp below takes the state lock itself, and a deferred Unlock
+	// would run on an already unlocked state if an assertion failed while
+	// the lock was released.
+	st.Lock()
+	err = st.Get("start-of-operation-time", &opTime)
+	st.Unlock()
+	c.Assert(err, testutil.ErrorIs, state.ErrNoState)
+
+	c.Assert(o.StartUp(), IsNil)
+
+	st.Lock()
+	err = st.Get("start-of-operation-time", &opTime)
+	st.Unlock()
+	c.Assert(err, IsNil)
+}
+
+// TestEnsureLoopPruneDoesntAbortShortlyAfterStartOfOperation checks that a
+// change spawned a month ago is left in Doing status by pruning when the
+// recorded start of operation is only 50 minutes in the past.
+func (ovs *overlordSuite) TestEnsureLoopPruneDoesntAbortShortlyAfterStartOfOperation(c *C) {
+ w, restoreTicker := fakePruneTicker()
+ defer restoreTicker()
+
+ // use a real overlord (created via overlord.New)
+ o, err := overlord.New(&overlord.Options{PebbleDir: ovs.dir})
+ c.Assert(err, IsNil)
+
+ // avoid immediate transition to Done due to unknown kind: the handler
+ // for "bar" tasks always asks for a retry
+ o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error {
+ return &state.Retry{}
+ }, nil)
+
+ st := o.State()
+ st.Lock()
+
+ // start of operation time is 50min ago, this is less than the abort limit
+ opTime := time.Now().Add(-50 * time.Minute)
+ st.Set("start-of-operation-time", opTime)
+
+ // spawn time one month ago
+ spawnTime := time.Now().AddDate(0, -1, 0)
+ restoreTimeNow := state.FakeTime(spawnTime)
+
+ t := st.NewTask("bar", "...")
+ chg := st.NewChange("other-change", "...")
+ chg.AddTask(t)
+
+ restoreTimeNow()
+
+ // validity check: exactly one change exists
+ c.Check(st.Changes(), HasLen, 1)
+
+ st.Unlock()
+ c.Assert(o.StartUp(), IsNil)
+
+ // start the loop that runs the prune ticker
+ o.Loop()
+ // deliver two ticks through the fake prune ticker channel
+ w.tick(2)
+
+ c.Assert(o.Stop(), IsNil)
+
+ st.Lock()
+ defer st.Unlock()
+ // the change is still there and was not aborted
+ c.Assert(st.Changes(), HasLen, 1)
+ c.Check(chg.Status(), Equals, state.DoingStatus)
+}
+
+// TestEnsureLoopPruneAbortsOld checks that a change spawned a month ago ends
+// up in Hold status after pruning when the recorded start of operation is a
+// year in the past.
+func (ovs *overlordSuite) TestEnsureLoopPruneAbortsOld(c *C) {
+ // Ensure interval is not relevant for this test
+ restoreEnsureIntv := overlord.FakeEnsureInterval(10 * time.Hour)
+ defer restoreEnsureIntv()
+
+ w, restoreTicker := fakePruneTicker()
+ defer restoreTicker()
+
+ // use a real overlord (created via overlord.New)
+ o, err := overlord.New(&overlord.Options{PebbleDir: ovs.dir})
+ c.Assert(err, IsNil)
+
+ // avoid immediate transition to Done due to having unknown kind: the
+ // handler for "bar" tasks always asks for a retry
+ o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error {
+ return &state.Retry{}
+ }, nil)
+
+ st := o.State()
+ st.Lock()
+
+ // start of operation time is a year ago
+ opTime := time.Now().AddDate(-1, 0, 0)
+ st.Set("start-of-operation-time", opTime)
+
+ // StartUp is called with the state unlocked
+ st.Unlock()
+ c.Assert(o.StartUp(), IsNil)
+ st.Lock()
+
+ // spawn time one month ago
+ spawnTime := time.Now().AddDate(0, -1, 0)
+ restoreTimeNow := state.FakeTime(spawnTime)
+ t := st.NewTask("bar", "...")
+ chg := st.NewChange("other-change", "...")
+ chg.AddTask(t)
+
+ restoreTimeNow()
+
+ // validity check: exactly one change exists
+ c.Check(st.Changes(), HasLen, 1)
+ st.Unlock()
+
+ // start the loop that runs the prune ticker
+ o.Loop()
+ // deliver two ticks through the fake prune ticker channel
+ w.tick(2)
+
+ c.Assert(o.Stop(), IsNil)
+
+ st.Lock()
+ defer st.Unlock()
+
+ // validity check: the start of operation time set above was kept
+ op, err := o.StartOfOperationTime()
+ c.Assert(err, IsNil)
+ c.Check(op.Equal(opTime), Equals, true)
+
+ c.Assert(st.Changes(), HasLen, 1)
+ // change was aborted
+ c.Check(chg.Status(), Equals, state.HoldStatus)
+}
+
func (ovs *overlordSuite) TestCheckpoint(c *C) {
oldUmask := syscall.Umask(0)
defer syscall.Umask(oldUmask)
@@ -907,3 +1058,40 @@ func (ovs *overlordSuite) TestOverlordCanStandby(c *C) {
c.Assert(o.CanStandby(), Equals, true)
}
+
+// TestStartOfOperationTimeAlreadySet checks that a start of operation time
+// already stored in the state is returned unchanged.
+func (ovs *overlordSuite) TestStartOfOperationTimeAlreadySet(c *C) {
+	o := overlord.Fake()
+	st := o.State()
+	st.Lock()
+	defer st.Unlock()
+
+	stored := time.Now().AddDate(0, -1, 0)
+	st.Set("start-of-operation-time", stored)
+
+	got, err := o.StartOfOperationTime()
+	c.Assert(err, IsNil)
+	c.Check(got.Equal(stored), Equals, true)
+}
+
+// TestStartOfOperationSetTime checks that the first call to
+// StartOfOperationTime records the current (faked) time, and that later
+// calls keep returning the recorded time even after the clock has moved.
+func (ovs *overlordSuite) TestStartOfOperationSetTime(c *C) {
+	o := overlord.Fake()
+	st := o.State()
+	st.Lock()
+	defer st.Unlock()
+
+	now := time.Now().Add(-1 * time.Second)
+	// Restore the real clock when the test ends so that the fake does not
+	// leak into other tests.
+	restoreTimeNow := overlord.FakeTimeNow(func() time.Time {
+		return now
+	})
+	defer restoreTimeNow()
+
+	operationTime, err := o.StartOfOperationTime()
+	c.Assert(err, IsNil)
+	c.Check(operationTime.Equal(now), Equals, true)
+
+	// repeated call returns already set time
+	prev := now
+	now = time.Now().Add(-10 * time.Hour)
+	operationTime, err = o.StartOfOperationTime()
+	c.Assert(err, IsNil)
+	c.Check(operationTime.Equal(prev), Equals, true)
+}
diff --git a/internals/overlord/patch/patch.go b/internals/overlord/patch/patch.go
index afc93186e..58514a8a4 100644
--- a/internals/overlord/patch/patch.go
+++ b/internals/overlord/patch/patch.go
@@ -20,6 +20,7 @@
package patch
import (
+ "errors"
"fmt"
"github.com/canonical/pebble/internals/logger"
@@ -43,12 +44,12 @@ var patches = make(map[int][]PatchFunc)
func Init(s *state.State) {
s.Lock()
defer s.Unlock()
- if s.Get("patch-level", new(int)) != state.ErrNoState {
+ if err := s.Get("patch-level", new(int)); !errors.Is(err, state.ErrNoState) {
panic("internal error: expected empty state, attempting to override patch-level without actual patching")
}
s.Set("patch-level", Level)
- if s.Get("patch-sublevel", new(int)) != state.ErrNoState {
+ if err := s.Get("patch-sublevel", new(int)); !errors.Is(err, state.ErrNoState) {
panic("internal error: expected empty state, attempting to override patch-sublevel without actual patching")
}
s.Set("patch-sublevel", Sublevel)
@@ -76,12 +77,12 @@ func Apply(s *state.State) error {
var stateLevel, stateSublevel int
s.Lock()
err := s.Get("patch-level", &stateLevel)
- if err == nil || err == state.ErrNoState {
+ if err == nil || errors.Is(err, state.ErrNoState) {
err = s.Get("patch-sublevel", &stateSublevel)
}
s.Unlock()
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return err
}
diff --git a/internals/overlord/restart/restart.go b/internals/overlord/restart/restart.go
index 1e545b8cb..f52f09174 100644
--- a/internals/overlord/restart/restart.go
+++ b/internals/overlord/restart/restart.go
@@ -16,6 +16,8 @@
package restart
import (
+ "errors"
+
"github.com/canonical/pebble/internals/overlord/state"
)
@@ -63,7 +65,7 @@ func Init(st *state.State, curBootID string, h Handler) error {
}
var fromBootID string
err := st.Get("system-restart-from-boot-id", &fromBootID)
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return err
}
st.Cache(restartStateKey{}, rs)
diff --git a/internals/overlord/restart/restart_test.go b/internals/overlord/restart/restart_test.go
index 8a8617ed3..74fc07646 100644
--- a/internals/overlord/restart/restart_test.go
+++ b/internals/overlord/restart/restart_test.go
@@ -21,6 +21,7 @@ import (
"github.com/canonical/pebble/internals/overlord/restart"
"github.com/canonical/pebble/internals/overlord/state"
+ "github.com/canonical/pebble/internals/testutil"
)
func TestRestart(t *testing.T) { TestingT(t) }
@@ -134,5 +135,5 @@ func (s *restartSuite) TestRequestRestartSystemAndVerifyReboot(c *C) {
err = restart.Init(st, "boot-id-2", h2)
c.Assert(err, IsNil)
c.Check(h2.rebootAsExpected, Equals, true)
- c.Check(st.Get("system-restart-from-boot-id", &fromBootID), Equals, state.ErrNoState)
+ c.Check(st.Get("system-restart-from-boot-id", &fromBootID), testutil.ErrorIs, state.ErrNoState)
}
diff --git a/internals/overlord/servstate/handlers.go b/internals/overlord/servstate/handlers.go
index 1e38dad04..f753e8ea7 100644
--- a/internals/overlord/servstate/handlers.go
+++ b/internals/overlord/servstate/handlers.go
@@ -1,6 +1,7 @@
package servstate
import (
+ "errors"
"fmt"
"io"
"os"
@@ -28,7 +29,7 @@ import (
func TaskServiceRequest(task *state.Task) (*ServiceRequest, error) {
req := &ServiceRequest{}
err := task.Get("service-request", req)
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return nil, err
}
if err == nil {
diff --git a/internals/overlord/state/change.go b/internals/overlord/state/change.go
index bfb3b42db..3ab11ddfa 100644
--- a/internals/overlord/state/change.go
+++ b/internals/overlord/state/change.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see .
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program.  If not, see <http://www.gnu.org/licenses/>.
package state
@@ -23,8 +18,11 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "sort"
"strings"
"time"
+
+ "github.com/canonical/pebble/internals/logger"
)
// Status is used for status values for changes and tasks.
@@ -37,7 +35,8 @@ const (
// to an aggregation of its tasks' statuses. See Change.Status for details.
DefaultStatus Status = 0
- // HoldStatus means the task should not run, perhaps as a consequence of an error on another task.
+ // HoldStatus means the task should not run for the moment, perhaps as a
+ // consequence of an error on another task.
HoldStatus Status = 1
// DoStatus means the change or task is ready to start.
@@ -65,6 +64,10 @@ const (
// ErrorStatus means the change or task has errored out while running or being undone.
ErrorStatus Status = 9
+ // WaitStatus means the task was accomplished successfully but some
+ // external event needs to happen before work can progress further.
+ WaitStatus Status = 10
+
nStatuses = iota
)
@@ -88,6 +91,8 @@ func (s Status) String() string {
return "Doing"
case DoneStatus:
return "Done"
+ case WaitStatus:
+ return "Wait"
case AbortStatus:
return "Abort"
case UndoStatus:
@@ -104,6 +109,18 @@ func (s Status) String() string {
panic(fmt.Sprintf("internal error: unknown task status code: %d", s))
}
+// taskWaitComputeStatus is used while computing the wait status of a
+// change. It keeps track of whether a task is waiting or not waiting, or the
+// computation for it is still in-progress to detect cyclic dependencies.
+type taskWaitComputeStatus int
+
+const (
+ // taskWaitStatusNotComputed is the zero value, i.e. what a lookup in the
+ // visited map yields for a task that has not been examined yet.
+ taskWaitStatusNotComputed taskWaitComputeStatus = iota
+ // taskWaitStatusComputing marks a task whose dependencies are currently
+ // being walked; meeting it again indicates a dependency cycle.
+ taskWaitStatusComputing
+ // taskWaitStatusNotWaiting records that the task was found not to be
+ // waiting.
+ taskWaitStatusNotWaiting
+ // taskWaitStatusWaiting records that the task was found to be waiting.
+ taskWaitStatusWaiting
+)
+
// Change represents a tracked modification to the system state.
//
// The Change provides both the justification for individual tasks
@@ -115,16 +132,16 @@ func (s Status) String() string {
// while the individual Task values would track the running of
// the hooks themselves.
type Change struct {
- state *State
- id string
- kind string
- summary string
- status Status
- clean bool
- data customData
- taskIDs []string
- lanes int
- ready chan struct{}
+ state *State
+ id string
+ kind string
+ summary string
+ status Status
+ clean bool
+ data customData
+ taskIDs []string
+ ready chan struct{}
+ lastObservedStatus Status
spawnTime time.Time
readyTime time.Time
@@ -157,7 +174,6 @@ type marshalledChange struct {
Clean bool `json:"clean,omitempty"`
Data map[string]*json.RawMessage `json:"data,omitempty"`
TaskIDs []string `json:"task-ids,omitempty"`
- Lanes int `json:"lanes,omitempty"`
SpawnTime time.Time `json:"spawn-time"`
ReadyTime *time.Time `json:"ready-time,omitempty"`
@@ -178,7 +194,6 @@ func (c *Change) MarshalJSON() ([]byte, error) {
Clean: c.clean,
Data: c.data,
TaskIDs: c.taskIDs,
- Lanes: c.lanes,
SpawnTime: c.spawnTime,
ReadyTime: readyTime,
@@ -206,7 +221,6 @@ func (c *Change) UnmarshalJSON(data []byte) error {
}
c.data = custData
c.taskIDs = unmarshalled.TaskIDs
- c.lanes = unmarshalled.Lanes
c.ready = make(chan struct{})
c.spawnTime = unmarshalled.SpawnTime
if unmarshalled.ReadyTime != nil {
@@ -251,12 +265,19 @@ func (c *Change) Get(key string, value interface{}) error {
return c.data.get(key, value)
}
+// Has returns whether the provided key has an associated value in the
+// custom data of the change.
+func (c *Change) Has(key string) bool {
+ // NOTE(review): presumably checks that the state lock is held, as in the
+ // other accessors of Change — confirm against State.reading.
+ c.state.reading()
+ return c.data.has(key)
+}
+
var statusOrder = []Status{
AbortStatus,
UndoingStatus,
UndoStatus,
DoingStatus,
DoStatus,
+ WaitStatus,
ErrorStatus,
UndoneStatus,
DoneStatus,
@@ -269,32 +290,138 @@ func init() {
}
}
+// isTaskWaiting reports whether task t is blocked by tasks in WaitStatus.
+//
+// deps are the tasks t depends on; callers pass t.WaitTasks() for a task in
+// DoStatus and t.HaltTasks() for a task in UndoStatus. visited memoizes the
+// result per task ID and is also used to detect dependency cycles: a task
+// that is reached again while its own computation is still in progress is
+// logged as cyclic and reported as not waiting.
+func (c *Change) isTaskWaiting(visited map[string]taskWaitComputeStatus, t *Task, deps []*Task) bool {
+ taskID := t.ID()
+ // Retrieve the compute status of the wait for the task, if not
+ // computed this defaults to 0 (taskWaitStatusNotComputed).
+ computeStatus := visited[taskID]
+ switch computeStatus {
+ case taskWaitStatusComputing:
+ // Cyclic dependency detected, return false to short-circuit.
+ logger.Noticef("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind())
+ // Make sure errors show up in "pebble change <id>" too
+ t.Logf("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind())
+ return false
+ case taskWaitStatusWaiting, taskWaitStatusNotWaiting:
+ // Already computed, reuse the stored result.
+ return computeStatus == taskWaitStatusWaiting
+ }
+ // Mark the task as in-progress before walking its dependencies so that
+ // reaching it again from below is recognized as a cycle.
+ visited[taskID] = taskWaitStatusComputing
+
+ var isWaiting bool
+depscheck:
+ for _, wt := range deps {
+ switch wt.Status() {
+ case WaitStatus:
+ isWaiting = true
+ // States that can be valid when waiting
+ // - Done, Undone, ErrorStatus, HoldStatus
+ case DoneStatus, UndoneStatus, ErrorStatus, HoldStatus:
+ continue
+ // For 'Do' and 'Undo' we have to check whether the task is waiting
+ // for any dependencies. The logic is the same, but the set of tasks
+ // varies.
+ case DoStatus:
+ isWaiting = c.isTaskWaiting(visited, wt, wt.WaitTasks())
+ if !isWaiting {
+ // Cancel early if we detect something is runnable.
+ break depscheck
+ }
+ case UndoStatus:
+ isWaiting = c.isTaskWaiting(visited, wt, wt.HaltTasks())
+ if !isWaiting {
+ // Cancel early if we detect something is runnable.
+ break depscheck
+ }
+ default:
+ // When we determine the change can not be in a wait-state then
+ // break early.
+ isWaiting = false
+ break depscheck
+ }
+ }
+ // Store the final result so that later visits do not recompute it.
+ if isWaiting {
+ visited[taskID] = taskWaitStatusWaiting
+ } else {
+ visited[taskID] = taskWaitStatusNotWaiting
+ }
+ return isWaiting
+}
+
+// isChangeWaiting should only ever return true iff it determines all tasks in Do/Undo
+// are blocked by tasks in either of three states: 'DoneStatus', 'UndoneStatus' or 'WaitStatus',
+// if this fails, we default to the normal status ordering logic.
+func (c *Change) isChangeWaiting() bool {
+ // Since we might visit tasks more than once, we store results to avoid recomputing them.
+ visited := make(map[string]taskWaitComputeStatus)
+ for _, t := range c.Tasks() {
+ switch t.Status() {
+ case WaitStatus, DoneStatus, UndoneStatus, ErrorStatus, HoldStatus:
+ continue
+ case DoStatus:
+ if !c.isTaskWaiting(visited, t, t.WaitTasks()) {
+ return false
+ }
+ case UndoStatus:
+ if !c.isTaskWaiting(visited, t, t.HaltTasks()) {
+ return false
+ }
+ default:
+ return false
+ }
+ }
+ // If we end up here, then return true as we know we
+ // have at least one waiter in this change.
+ return true
+}
+
// Status returns the current status of the change.
// If the status was not explicitly set the result is derived from the status
// of the individual tasks related to the change, according to the following
// decision sequence:
//
+// - With all pending tasks blocked by other tasks in WaitStatus, return WaitStatus
// - With at least one task in DoStatus, return DoStatus
// - With at least one task in ErrorStatus, return ErrorStatus
// - Otherwise, return DoneStatus
func (c *Change) Status() Status {
c.state.reading()
- if c.status == DefaultStatus {
- if len(c.taskIDs) == 0 {
- return HoldStatus
- }
- statusStats := make([]int, nStatuses)
- for _, tid := range c.taskIDs {
- statusStats[c.state.tasks[tid].Status()]++
+ if c.status != DefaultStatus {
+ return c.status
+ }
+
+ if len(c.taskIDs) == 0 {
+ return HoldStatus
+ }
+
+ statusStats := make([]int, nStatuses)
+ for _, tid := range c.taskIDs {
+ statusStats[c.state.tasks[tid].Status()]++
+ }
+
+ // If the change has any waiters, check for any runnable tasks
+ // or whether it's completely blocked by waiters.
+ if statusStats[WaitStatus] > 0 {
+ // Only if the change has all tasks blocked we return WaitStatus.
+ if c.isChangeWaiting() {
+ return WaitStatus
}
- for _, s := range statusOrder {
- if statusStats[s] > 0 {
- return s
- }
+ }
+
+ // Otherwise we return the current status with the highest priority.
+ for _, s := range statusOrder {
+ if statusStats[s] > 0 {
+ return s
}
- panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats))
}
- return c.status
+ panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats))
+}
+
+// notifyStatusChange informs the state's change-status handlers that the
+// change moved from the last observed status to new, and remembers new as
+// the last observed status. It does nothing when the status is unchanged.
+func (c *Change) notifyStatusChange(new Status) {
+	if prev := c.lastObservedStatus; prev != new {
+		c.state.notifyChangeStatusChangedHandlers(c, prev, new)
+		c.lastObservedStatus = new
+	}
+}
// SetStatus sets the change status, overriding the default behavior (see Status method).
@@ -304,6 +431,7 @@ func (c *Change) SetStatus(s Status) {
if s.Ready() {
c.markReady()
}
+ c.notifyStatusChange(c.Status())
}
func (c *Change) markReady() {
@@ -322,15 +450,10 @@ func (c *Change) Ready() <-chan struct{} {
return c.ready
}
-// taskStatusChanged is called by tasks when their status is changed,
-// to give the opportunity for the change to close its ready channel.
-func (c *Change) taskStatusChanged(t *Task, old, new Status) {
- if old.Ready() == new.Ready() {
- return
- }
+func (c *Change) detectChangeReady(excludeTask *Task) {
for _, tid := range c.taskIDs {
task := c.state.tasks[tid]
- if task != t && !task.status.Ready() {
+ if task != excludeTask && !task.status.Ready() {
return
}
}
@@ -343,6 +466,21 @@ func (c *Change) taskStatusChanged(t *Task, old, new Status) {
c.markReady()
}
+// taskStatusChanged is called by tasks when their status is changed. It
+// gives the change the opportunity to close its ready channel and notifies
+// observers about the status of the change.
+func (c *Change) taskStatusChanged(t *Task, old, new Status) {
+	// The change status is captured first, before any readiness handling.
+	changeStatus := c.Status()
+	// Readiness of the change only has to be re-evaluated when the task
+	// went from ready => unready or unready => ready.
+	if old.Ready() != new.Ready() {
+		c.detectChangeReady(t)
+	}
+	c.notifyStatusChange(changeStatus)
+}
+
// IsClean returns whether all tasks in the change have been cleaned. See SetClean.
func (c *Change) IsClean() bool {
c.state.reading()
@@ -519,6 +657,44 @@ func (c *Change) AbortLanes(lanes []int) {
c.abortLanes(lanes, make(map[int]bool), make(map[string]bool))
}
+// AbortUnreadyLanes aborts the tasks of every lane that is not fully ready,
+// a ready lane being one in which all tasks are ready.
+func (c *Change) AbortUnreadyLanes() {
+	c.state.writing()
+	c.abortUnreadyLanes()
+}
+
+// abortUnreadyLanes collects the lanes of all tasks of the change that are
+// not ready and aborts those lanes.
+func (c *Change) abortUnreadyLanes() {
+	unready := make(map[int]bool)
+	for _, tid := range c.taskIDs {
+		task := c.state.tasks[tid]
+		if task.Status().Ready() {
+			continue
+		}
+		for _, lane := range task.Lanes() {
+			unready[lane] = true
+		}
+	}
+
+	lanes := make([]int, 0, len(unready))
+	for lane := range unready {
+		lanes = append(lanes, lane)
+	}
+	c.abortLanes(lanes, make(map[int]bool), make(map[string]bool))
+}
+
+// taskEffectiveStatus returns the 'effective' status of t: the status of
+// the task itself, unless that is WaitStatus, in which case the waited
+// status (the status after the wait) is returned instead.
+func taskEffectiveStatus(t *Task) Status {
+	if current := t.Status(); current != WaitStatus {
+		return current
+	}
+	// The task is waiting, report what it is waiting to become.
+	return t.WaitedStatus()
+}
+
func (c *Change) abortLanes(lanes []int, abortedLanes map[int]bool, seenTasks map[string]bool) {
var hasLive = make(map[int]bool)
var hasDead = make(map[int]bool)
@@ -528,7 +704,7 @@ NextChangeTask:
t := c.state.tasks[tid]
var live bool
- switch t.Status() {
+ switch taskEffectiveStatus(t) {
case DoStatus, DoingStatus, DoneStatus:
live = true
}
@@ -579,7 +755,7 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks
continue
}
seenTasks[t.id] = true
- switch t.Status() {
+ switch taskEffectiveStatus(t) {
case DoStatus:
// Still pending so don't even start.
t.SetStatus(HoldStatus)
@@ -607,3 +783,87 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks
c.abortLanes(lanes, abortedLanes, seenTasks)
}
}
+
+// TaskDependencyCycleError is the error returned by CheckTaskDependencies
+// when the tasks of a change have dependencies that cannot be satisfied.
+type TaskDependencyCycleError struct {
+	// IDs holds the IDs of the tasks left with unsatisfied dependencies.
+	IDs []string
+	msg string
+}
+
+// Error returns the message describing the cycle.
+func (e *TaskDependencyCycleError) Error() string {
+	return e.msg
+}
+
+// Is reports whether err is a *TaskDependencyCycleError, whatever its
+// contents, so that errors.Is can match on the error type alone.
+func (e *TaskDependencyCycleError) Is(err error) bool {
+	if _, sameType := err.(*TaskDependencyCycleError); sameType {
+		return true
+	}
+	return false
+}
+
+// CheckTaskDependencies checks the tasks in the change for cyclic
+// dependencies. It returns a *TaskDependencyCycleError listing the affected
+// tasks when some are left with dependencies that cannot be satisfied, and
+// nil otherwise.
+func (c *Change) CheckTaskDependencies() error {
+	tasks := c.Tasks()
+
+	byID := make(map[string]*Task, len(tasks))
+	// pending counts, for every non-independent task, how many tasks it
+	// still waits for; independent tasks get no entry at all.
+	pending := make(map[string]int, len(tasks))
+	for _, t := range tasks {
+		byID[t.id] = t
+		if n := len(t.waitTasks); n > 0 {
+			pending[t.id] = n
+		}
+	}
+
+	// Kahn topological sort: begin with the independent tasks, then walk
+	// their direct successors (halt tasks) and lower each successor's
+	// count. A successor whose count reaches 0 has had all its direct
+	// dependencies accounted for and becomes independent itself.
+	ready := make([]string, 0, len(tasks))
+	for _, t := range tasks {
+		if pending[t.id] == 0 {
+			ready = append(ready, t.id)
+		}
+	}
+	// ready is consumed in order while it keeps growing at the tail.
+	for next := 0; next < len(ready); next++ {
+		for _, successor := range byID[ready[next]].haltTasks {
+			pending[successor]--
+			if pending[successor] == 0 {
+				delete(pending, successor)
+				ready = append(ready, successor)
+			}
+		}
+	}
+
+	if len(pending) == 0 {
+		return nil
+	}
+
+	// Whatever is left cannot have its dependencies satisfied.
+	var unsatisfiedTasks []string
+	for id := range pending {
+		unsatisfiedTasks = append(unsatisfiedTasks, id)
+	}
+	sort.Strings(unsatisfiedTasks)
+
+	described := make([]string, 0, len(unsatisfiedTasks))
+	for _, id := range unsatisfiedTasks {
+		t := byID[id]
+		described = append(described, fmt.Sprintf("%v:%v", t.id, t.kind))
+	}
+	return &TaskDependencyCycleError{
+		IDs: unsatisfiedTasks,
+		msg: "dependency cycle involving tasks [" + strings.Join(described, " ") + "]",
+	}
+}
diff --git a/internals/overlord/state/change_test.go b/internals/overlord/state/change_test.go
index 339c94344..273c09052 100644
--- a/internals/overlord/state/change_test.go
+++ b/internals/overlord/state/change_test.go
@@ -1,25 +1,21 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see .
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program.  If not, see <http://www.gnu.org/licenses/>.
package state_test
import (
+ "errors"
"fmt"
"sort"
"strconv"
@@ -68,7 +64,7 @@ func (cs *changeSuite) TestReadyTime(c *C) {
}
func (cs *changeSuite) TestStatusString(c *C) {
- for s := state.Status(0); s < state.ErrorStatus+1; s++ {
+ for s := state.Status(0); s < state.WaitStatus+1; s++ {
c.Assert(s.String(), Matches, ".+")
}
}
@@ -88,6 +84,21 @@ func (cs *changeSuite) TestGetSet(c *C) {
c.Check(v, Equals, 1)
}
+func (cs *changeSuite) TestHas(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("install", "...")
+ c.Check(chg.Has("a"), Equals, false)
+
+ chg.Set("a", 1)
+ c.Check(chg.Has("a"), Equals, true)
+
+ chg.Set("a", nil)
+ c.Check(chg.Has("a"), Equals, false)
+}
+
// TODO Better testing of full change roundtripping via JSON.
func (cs *changeSuite) TestNewTaskAddTaskAndTasks(c *C) {
@@ -227,9 +238,13 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) {
tasks := make(map[state.Status]*state.Task)
- for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ {
+ for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ {
t := st.NewTask("download", s.String())
- t.SetStatus(s)
+ if s == state.WaitStatus {
+ t.SetToWait(state.DoneStatus)
+ } else {
+ t.SetStatus(s)
+ }
chg.AddTask(t)
tasks[s] = t
}
@@ -240,6 +255,7 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) {
state.UndoStatus,
state.DoingStatus,
state.DoStatus,
+ state.WaitStatus,
state.ErrorStatus,
state.UndoneStatus,
state.DoneStatus,
@@ -252,7 +268,11 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) {
if s == s2 {
break
}
- tasks[s2].SetStatus(s)
+ if s == state.WaitStatus {
+ tasks[s2].SetToWait(state.DoneStatus)
+ } else {
+ tasks[s2].SetStatus(s)
+ }
}
c.Assert(chg.Status(), Equals, s)
}
@@ -432,9 +452,13 @@ func (cs *changeSuite) TestAbort(c *C) {
chg := st.NewChange("install", "...")
- for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ {
+ for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ {
t := st.NewTask("download", s.String())
- t.SetStatus(s)
+ if s == state.WaitStatus {
+ t.SetToWait(state.DoneStatus)
+ } else {
+ t.SetStatus(s)
+ }
t.Set("old-status", s)
chg.AddTask(t)
}
@@ -451,7 +475,7 @@ func (cs *changeSuite) TestAbort(c *C) {
switch s {
case state.DoStatus:
c.Assert(t.Status(), Equals, state.HoldStatus)
- case state.DoneStatus:
+ case state.DoneStatus, state.WaitStatus:
c.Assert(t.Status(), Equals, state.UndoStatus)
case state.DoingStatus:
c.Assert(t.Status(), Equals, state.AbortStatus)
@@ -526,11 +550,13 @@ func (cs *changeSuite) TestAbortKⁿ(c *C) {
// Task wait order:
//
-// => t21 => t22
-// / \
-// t11 => t12 => t41 => t42
-// \ /
-// => t31 => t32
+// => t21 => t22
+// / \
+//
+// t11 => t12 => t41 => t42
+//
+// \ /
+// => t31 => t32
//
// setup and result lines are <task>:<status>[:<lane>,...]
//
@@ -577,6 +603,10 @@ var abortLanesTests = []struct {
setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
abort: []int{2},
result: "t21:hold t22:hold t41:hold t42:hold *:do",
+ }, {
+ setup: "t11:done:1 t12:wait:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ abort: []int{2},
+ result: "t21:hold t22:hold t41:hold t42:hold t11:done t12:wait *:do",
}, {
setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
abort: []int{3},
@@ -660,7 +690,7 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) {
c.Logf("Testing setup: %s", test.setup)
statuses := make(map[string]state.Status)
- for s := state.DefaultStatus; s <= state.ErrorStatus; s++ {
+ for s := state.DefaultStatus; s <= state.WaitStatus; s++ {
statuses[strings.ToLower(s.String())] = s
}
@@ -680,7 +710,11 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) {
}
seen[parts[0]] = true
task := tasks[parts[0]]
- task.SetStatus(statuses[parts[1]])
+ if statuses[parts[1]] == state.WaitStatus {
+ task.SetToWait(state.DoneStatus)
+ } else {
+ task.SetStatus(statuses[parts[1]])
+ }
if len(parts) > 2 {
lanes := strings.Split(parts[2], ",")
for _, lane := range lanes {
@@ -723,3 +757,691 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) {
c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup))
}
}
+
+// setup and result lines are <task>:<status>[:<lane>,...]
+// order is <task1>-><task2> (implies task2 waits for task 1)
+// "*" as task name means "all remaining".
+var abortUnreadyLanesTests = []struct {
+ setup string
+ order string
+ result string
+}{
+
+ // Some basics.
+ {
+ setup: "*:do",
+ result: "*:hold",
+ }, {
+ setup: "*:wait",
+ result: "*:undo",
+ }, {
+ setup: "*:done",
+ result: "*:done",
+ }, {
+ setup: "*:error",
+ result: "*:error",
+ },
+
+ // t11 (1) => t12 (1) => t21 (1) => t22 (1)
+ // t31 (2) => t32 (2) => t41 (2) => t42 (2)
+ {
+ setup: "t11:do:1 t12:do:1 t21:do:1 t22:do:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2",
+ order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42",
+ result: "*:hold",
+ }, {
+ setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2",
+ order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42",
+ result: "t11:done t12:done t21:done t22:done t31:hold t32:hold t41:hold t42:hold",
+ }, {
+ setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:done:2 t32:done:2 t41:done:2 t42:do:2",
+ order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42",
+ result: "t11:done t12:done t21:done t22:done t31:undo t32:undo t41:undo t42:hold",
+ },
+ // => t21 (2) => t22 (2)
+ // / \
+ // t11 (2,3) => t12 (2,3) => t41 (4) => t42 (4)
+ // \ /
+ // => t31 (3) => t32 (3)
+ {
+ setup: "t11:do:2,3 t12:do:2,3 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ result: "*:hold",
+ }, {
+ setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ // lane 2 is fully complete so it does not get aborted
+ result: "t11:done t12:done t21:done t22:done t31:abort t32:hold t41:hold t42:hold *:undo",
+ }, {
+ setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:wait:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ // lane 2 is fully complete so it does not get aborted
+ result: "t11:done t12:done t21:done t22:done t31:undo t32:hold t41:hold t42:hold *:undo",
+ }, {
+ setup: "t11:done:2,3 t12:done:2,3 t21:doing:2 t22:do:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ result: "t21:abort t22:hold t31:abort t32:hold t41:hold t42:hold *:undo",
+ },
+
+ // t11 (1) => t12 (1)
+ // t21 (2) => t22 (2)
+ // t31 (3) => t32 (3)
+ // t41 (4) => t42 (4)
+ {
+ setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t21->t22 t31->t32 t41->t42",
+ result: "*:hold",
+ }, {
+ setup: "t11:do:1 t12:do:1 t21:doing:2 t22:do:2 t31:done:3 t32:doing:3 t41:undone:4 t42:error:4",
+ order: "t11->t12 t21->t22 t31->t32 t41->t42",
+ result: "t11:hold t12:hold t21:abort t22:hold t31:undo t32:abort t41:undone t42:error",
+ },
+ // auto refresh like arrangement
+ //
+ // (apps)
+ // => t31 (3) => t32 (3)
+ // (snapd) (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11:done:1 t12:done:1 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42",
+ result: "t11:done t12:done t21:done t22:done t31:abort *:hold",
+ }, {
+ //
+ setup: "t11:done:1 t12:done:1 t21:done:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42",
+ result: "t11:done t12:done t21:undo *:hold",
+ },
+ // arrangement with a cyclic dependency between tasks
+ //
+ // /-----------------------------------------\
+ // | |
+ // | => t31 (3) => t32 (3) /
+ // (snapd) v (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11:done:1 t12:done:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21",
+ result: "t11:done t12:done *:hold",
+ },
+}
+
+func (ts *taskRunnerSuite) TestAbortUnreadyLanes(c *C) {
+
+ names := strings.Fields("t11 t12 t21 t22 t31 t32 t41 t42")
+
+ for i, test := range abortUnreadyLanesTests {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask("do", name)
+ chg.AddTask(tasks[name])
+ }
+
+ c.Logf("----- %v", i)
+ c.Logf("Testing setup: %s", test.setup)
+
+ for _, wp := range strings.Fields(test.order) {
+ pair := strings.Split(wp, "->")
+ c.Assert(pair, HasLen, 2)
+ // task 2 waits for task 1 is denoted as:
+ // task1->task2
+ tasks[pair[1]].WaitFor(tasks[pair[0]])
+ }
+
+ statuses := make(map[string]state.Status)
+ for s := state.DefaultStatus; s <= state.WaitStatus; s++ {
+ statuses[strings.ToLower(s.String())] = s
+ }
+
+ items := strings.Fields(test.setup)
+ seen := make(map[string]bool)
+ for i := 0; i < len(items); i++ {
+ item := items[i]
+ parts := strings.Split(item, ":")
+ if parts[0] == "*" {
+ c.Assert(i, Equals, len(items)-1, Commentf("*: can only be used as the last entry"))
+ for _, name := range names {
+ if !seen[name] {
+ parts[0] = name
+ items = append(items, strings.Join(parts, ":"))
+ }
+ }
+ continue
+ }
+ seen[parts[0]] = true
+ task := tasks[parts[0]]
+ if statuses[parts[1]] == state.WaitStatus {
+ task.SetToWait(state.DoneStatus)
+ } else {
+ task.SetStatus(statuses[parts[1]])
+ }
+ if len(parts) > 2 {
+ lanes := strings.Split(parts[2], ",")
+ for _, lane := range lanes {
+ n, err := strconv.Atoi(lane)
+ c.Assert(err, IsNil)
+ task.JoinLane(n)
+ }
+ }
+ }
+
+ c.Logf("Aborting")
+
+ chg.AbortUnreadyLanes()
+
+ c.Logf("Expected result: %s", test.result)
+
+ seen = make(map[string]bool)
+ var expected = strings.Fields(test.result)
+ var obtained []string
+ for i := 0; i < len(expected); i++ {
+ item := expected[i]
+ parts := strings.Split(item, ":")
+ if parts[0] == "*" {
+ c.Assert(i, Equals, len(expected)-1, Commentf("*: can only be used as the last entry"))
+ var expanded []string
+ for _, name := range names {
+ if !seen[name] {
+ parts[0] = name
+ expanded = append(expanded, strings.Join(parts, ":"))
+ }
+ }
+ expected = append(expected[:i], append(expanded, expected[i+1:]...)...)
+ i--
+ continue
+ }
+ name := parts[0]
+ seen[parts[0]] = true
+ obtained = append(obtained, name+":"+strings.ToLower(tasks[name].Status().String()))
+ }
+
+ c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup))
+ }
+}
+
+// setup is a list of tasks "<task1> <task2>", order is <task1>-><task2>
+// (implies task2 waits for task 1)
+var cyclicDependencyTests = []struct {
+ setup string
+ order string
+ err string
+ errIDs []string
+}{
+
+ // Some basics.
+ {
+ setup: "t1",
+ }, {
+ setup: "",
+ }, {
+ // independent tasks
+ setup: "t1 t2 t3",
+ }, {
+ // some independent and some ordered tasks
+ setup: "t1 t2 t3 t4",
+ order: "t2->t3",
+ },
+ // some independent, dependencies as if added by WaitAll()
+ // t1 => t2
+ // t1,t2 => t3
+ // t1,t2,t3 => t4
+ {
+ setup: "t1 t2 t3 t4",
+ order: "t1->t2 t1->t3 t2->t3 t1->t4 t2->t4 t3->t4",
+ }, {
+ // simple loop
+ setup: "t1 t2",
+ order: "t1->t2 t2->t1",
+ err: `dependency cycle involving tasks \[1:t1 2:t2\]`,
+ errIDs: []string{"1", "2"},
+ },
+
+ // t1 => t2 => t3 => t4
+ // t5 => t6 => t7 => t8
+ {
+ setup: "t1 t2 t3 t4 t5 t6 t7 t8",
+ order: "t1->t2 t2->t3 t3->t4 t5->t6 t6->t7 t7->t8",
+ },
+ // => t21 => t22
+ // / \
+ // t11 => t12 => t41 => t42
+ // \ /
+ // => t31 => t32
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ },
+ // t11 (1) => t12 (1)
+ // t21 (2) => t22 (2)
+ // t31 (3) => t32 (3)
+ // t41 (4) => t42 (4)
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t21->t22 t31->t32 t41->t42",
+ },
+ // auto refresh like arrangement
+ //
+ // (apps)
+ // => t31 (3) => t32 (3)
+ // (snapd) (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42",
+ },
+ // arrangement with a cyclic dependency between tasks
+ //
+ // /-----------------------------------------\
+ // | |
+ // | => t31 (3) => t32 (3) /
+ // (snapd) v (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21",
+ err: `dependency cycle involving tasks \[3:t21 4:t22 5:t31 6:t32 7:t41 8:t42\]`,
+ errIDs: []string{"3", "4", "5", "6", "7", "8"},
+ },
+ // t1 => t2 => t3 => t4 --> t6
+ // t5 => t6 => t7 => t8 --> t2
+ {
+ setup: "t1 t2 t3 t4 t5 t6 t7 t8",
+ order: "t1->t2 t2->t3 t3->t4 t4->t6 t5->t6 t6->t7 t7->t8 t8->t2",
+ err: `dependency cycle involving tasks \[2:t2 3:t3 4:t4 6:t6 7:t7 8:t8\]`,
+ errIDs: []string{"2", "3", "4", "6", "7", "8"},
+ },
+}
+
+func (ts *taskRunnerSuite) TestCheckTaskDependencies(c *C) {
+
+ for i, test := range cyclicDependencyTests {
+ names := strings.Fields(test.setup)
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask(name, name)
+ chg.AddTask(tasks[name])
+ }
+
+ c.Logf("----- %v", i)
+ c.Logf("Testing setup: %s", test.setup)
+
+ for _, wp := range strings.Fields(test.order) {
+ pair := strings.Split(wp, "->")
+ c.Assert(pair, HasLen, 2)
+ // task 2 waits for task 1 is denoted as:
+ // task1->task2
+ tasks[pair[1]].WaitFor(tasks[pair[0]])
+ }
+
+ err := chg.CheckTaskDependencies()
+
+ if test.err != "" {
+ c.Assert(err, ErrorMatches, test.err)
+ c.Assert(errors.Is(err, &state.TaskDependencyCycleError{}), Equals, true)
+ errTasksDepCycle := err.(*state.TaskDependencyCycleError)
+ c.Assert(errTasksDepCycle.IDs, DeepEquals, test.errIDs)
+ } else {
+ c.Assert(err, IsNil)
+ }
+ }
+}
+
+func (cs *changeSuite) TestIsWaitingStatusOrderWithWaits(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("wait-task", "...")
+ t1.WaitFor(t2)
+ t1.WaitFor(t3)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+
+ // Set the wait-task into WaitStatus, to ensure we trigger the isWaiting
+ // logic and that it doesn't return WaitStatus for statuses which are in
+ // higher order
+ t4.SetToWait(state.DoneStatus)
+
+ // Test the following sequences:
+ // task1 (do) => task2 (done) => task3 (doing)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.DoingStatus)
+ c.Check(chg.Status(), Equals, state.DoingStatus)
+
+ // task1 (done) => task2 (done) => task3 (undoing)
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.UndoingStatus)
+ c.Check(chg.Status(), Equals, state.UndoingStatus)
+
+ // task1 (done) => task2 (done) => task3 (abort)
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.AbortStatus)
+ c.Check(chg.Status(), Equals, state.AbortStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingSingle(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+
+ chg.AddTask(t1)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ t1.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingTwoTasks(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("wait-task", "...")
+ t2.WaitFor(t1)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+
+ // Put t3 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t3.SetToWait(state.DoneStatus)
+
+ // task1 (do) => task2 (do) no reboot
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) => task2 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (wait) => task2 (do) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) => task2 (wait) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingCircularDependency(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("wait-task", "...")
+
+ // Setup circular dependency between t1,t2 and t3, they should
+ // still act normally.
+ t2.WaitFor(t1)
+ t3.WaitFor(t2)
+ t1.WaitFor(t3)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+
+ // To trigger the cyclic dependency check, we must trigger the isWaiting logic
+ // and we do this by putting t4 into WaitStatus.
+ t4.SetToWait(state.DoneStatus)
+
+ // task1 (do) => task2 (do) => task3 (do) no reboot
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) => task2 (do) => task3 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoingStatus)
+ c.Check(chg.Status(), Equals, state.DoingStatus)
+
+ // task1 (wait) => task2 (do) => task3 (do) means need a reboot
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) => task2 (wait) => task3 (do) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingMultipleDependencies(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("wait-task", "...")
+ t3.WaitFor(t1)
+ t3.WaitFor(t2)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+
+ // Put t4 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t4.SetToWait(state.DoneStatus)
+
+ // task1 (do) + task2 (do) => task3 (do) no reboot
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) + task2 (done) => task3 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) + task2 (do) => task3 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoStatus)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // For the next two cases we are testing that a task with dependencies
+ // which have completed, but in a non-successful way is handled correctly.
+ // task1 (error) + task2 (wait) => task3 (do) means need reboot
+ // to finalize task2
+ t1.SetStatus(state.ErrorStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (error) => task3 (do) means need reboot
+ // to finalize task1
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.ErrorStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) + task2 (wait) => task3 (do) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (wait) => task3 (do) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) + task2 (done) => task3 (wait) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoneStatus)
+ t3.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (abort) => task3 (do)
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.AbortStatus)
+ t3.SetStatus(state.DoStatus)
+ c.Check(chg.Status(), Equals, state.AbortStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingUndoTwoTasks(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("wait-task", "...")
+ t2.WaitFor(t1)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+
+ // Put t3 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t3.SetToWait(state.DoneStatus)
+
+ // we use <=| to denote the reverse dependence relationship
+ // followed by undo logic
+
+ // task1 (undo) <=| task2 (undo) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) <=| task2 (undone) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) <=| task2 (wait) means need a reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) <=| task2 (undone) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingUndoMultipleDependencies(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("task4", "...")
+ t5 := st.NewTask("wait-task", "...")
+ t3.WaitFor(t1)
+ t3.WaitFor(t2)
+ t4.WaitFor(t1)
+ t4.WaitFor(t2)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+ chg.AddTask(t5)
+
+ // Put t5 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t5.SetToWait(state.DoneStatus)
+
+ // task1 (undo) + task2 (undo) <=| task3 (undo) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoStatus)
+ t3.SetStatus(state.UndoStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) + task2 (undo) <=| task3 (undone) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoStatus)
+ t3.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) + task2 (undo) <=| task3 (error) + task4 (wait) means
+ // need reboot to continue undoing 1 and 2
+ t3.SetStatus(state.ErrorStatus)
+ t4.SetToWait(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (undo) + task2 (undo) <=| task3 (wait) + task4 (error) means
+ // need reboot to continue undoing 1 and 2
+ t3.SetToWait(state.UndoneStatus)
+ t4.SetStatus(state.ErrorStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undo) no reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.UndoneStatus)
+ t4.SetStatus(state.UndoStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (wait) + task2 (done) <=| task3 (undone) + task4 (undone) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.DoneStatus)
+ t3.SetStatus(state.UndoneStatus)
+ t4.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undone) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.UndoneStatus)
+ t4.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
diff --git a/internals/overlord/state/export_test.go b/internals/overlord/state/export_test.go
index 46d22f4d6..b307dc5c6 100644
--- a/internals/overlord/state/export_test.go
+++ b/internals/overlord/state/export_test.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see .
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state
diff --git a/internals/overlord/state/notices_test.go b/internals/overlord/state/notices_test.go
index 1a1e1eeb5..88840fe17 100644
--- a/internals/overlord/state/notices_test.go
+++ b/internals/overlord/state/notices_test.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2023 Canonical Ltd
+// Copyright (c) 2024 Canonical Ltd
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 3 as
@@ -11,7 +11,6 @@
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
-
package state_test
import (
@@ -448,7 +447,7 @@ func (s *noticesSuite) TestDeleteExpired(c *C) {
addNotice(c, st, nil, state.CustomNotice, "foo.com/z", nil)
c.Assert(st.NumNotices(), Equals, 4)
- st.Prune(0, 0, 0)
+ st.Prune(time.Now(), 0, 0, 0)
c.Assert(st.NumNotices(), Equals, 2)
notices := st.Notices(nil)
diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go
index b848ecb69..e72dc3ef8 100644
--- a/internals/overlord/state/state.go
+++ b/internals/overlord/state/state.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see .
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
// Package state implements the representation of system state.
package state
@@ -46,7 +41,7 @@ type customData map[string]*json.RawMessage
func (data customData) get(key string, value interface{}) error {
entryJSON := data[key]
if entryJSON == nil {
- return ErrNoState
+ return &NoStateError{Key: key}
}
err := json.Unmarshal(*entryJSON, value)
if err != nil {
@@ -88,6 +83,9 @@ type State struct {
lastChangeId int
lastLaneId int
lastNoticeId int
+ // lastHandlerId is not serialized, it's only used during runtime
+ // for registering runtime callbacks
+ lastHandlerId int
backend Backend
data customData
@@ -101,19 +99,28 @@ type State struct {
modified bool
cache map[interface{}]interface{}
+
+ pendingChangeByAttr map[string]func(*Change) bool
+
+ // task/changes observing
+ taskHandlers map[int]func(t *Task, old, new Status)
+ changeHandlers map[int]func(chg *Change, old, new Status)
}
// New returns a new empty state.
func New(backend Backend) *State {
st := &State{
- backend: backend,
- data: make(customData),
- changes: make(map[string]*Change),
- tasks: make(map[string]*Task),
- warnings: make(map[string]*Warning),
- notices: make(map[noticeKey]*Notice),
- modified: true,
- cache: make(map[interface{}]interface{}),
+ backend: backend,
+ data: make(customData),
+ changes: make(map[string]*Change),
+ tasks: make(map[string]*Task),
+ warnings: make(map[string]*Warning),
+ notices: make(map[noticeKey]*Notice),
+ modified: true,
+ cache: make(map[interface{}]interface{}),
+ pendingChangeByAttr: make(map[string]func(*Change) bool),
+ taskHandlers: make(map[int]func(t *Task, old Status, new Status)),
+ changeHandlers: make(map[int]func(chg *Change, old Status, new Status)),
}
st.noticeCond = sync.NewCond(st) // use State.Lock and State.Unlock
return st
@@ -221,6 +228,15 @@ var (
unlockCheckpointRetryInterval = 3 * time.Second
)
+// Unlocker returns a closure that will unlock and checkpoint the state and
+// in turn return a function to relock it.
+func (s *State) Unlocker() (unlock func() (relock func())) {
+ return func() func() {
+ s.Unlock()
+ return s.Lock
+ }
+}
+
// Unlock releases the state lock and checkpoints the state.
// It does not return until the state is correctly checkpointed.
// After too many unsuccessful checkpoint attempts, it panics.
@@ -254,6 +270,28 @@ func (s *State) EnsureBefore(d time.Duration) {
// ErrNoState represents the case of no state entry for a given key.
var ErrNoState = errors.New("no state entry for key")
+// NoStateError represents the case where no state could be found for a given key.
+type NoStateError struct {
+ // Key is the key for which no state could be found.
+ Key string
+}
+
+func (e *NoStateError) Error() string {
+ var keyMsg string
+ if e.Key != "" {
+ keyMsg = fmt.Sprintf(" %q", e.Key)
+ }
+
+ return fmt.Sprintf("no state entry for key%s", keyMsg)
+}
+
+// Is returns true if the error is of type *NoStateError or equal to ErrNoState.
+// NoStateError's key isn't compared between errors.
+func (e *NoStateError) Is(err error) bool {
+ _, ok := err.(*NoStateError)
+ return ok || errors.Is(err, ErrNoState)
+}
+
// Get unmarshals the stored value associated with the provided key
// into the value parameter.
// It returns ErrNoState if there is no entry for key.
@@ -262,6 +300,12 @@ func (s *State) Get(key string, value interface{}) error {
return s.data.get(key, value)
}
+// Has returns whether the provided key has an associated value.
+func (s *State) Has(key string) bool {
+ s.reading()
+ return s.data.has(key)
+}
+
// Set associates value with key for future consulting by managers.
// The provided value must properly marshal and unmarshal with encoding/json.
func (s *State) Set(key string, value interface{}) {
@@ -370,15 +414,25 @@ func (s *State) tasksIn(tids []string) []*Task {
return res
}
+// RegisterPendingChangeByAttr registers predicates that will be invoked by
+// Prune on changes with the specified attribute set to check whether even if
+// they meet the time criteria they must not be aborted yet.
+func (s *State) RegisterPendingChangeByAttr(attr string, f func(*Change) bool) {
+ s.pendingChangeByAttr[attr] = f
+}
+
// Prune does several cleanup tasks to the in-memory state:
//
// - it removes changes that became ready for more than pruneWait and aborts
-// tasks spawned for more than abortWait.
+// tasks spawned for more than abortWait unless prevented by predicates
+// registered with RegisterPendingChangeByAttr.
+//
// - it removes tasks unlinked to changes after pruneWait. When there are more
// changes than the limit set via "maxReadyChanges" those changes in ready
// state will also removed even if they are below the pruneWait duration.
-// - it removes expired warnings and notices
-func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) {
+//
+// - it removes expired warnings and notices.
+func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) {
now := time.Now()
pruneLimit := now.Add(-pruneWait)
abortLimit := now.Add(-abortWait)
@@ -411,15 +465,24 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) {
}
}
+NextChange:
for _, chg := range changes {
- spawnTime := chg.SpawnTime()
readyTime := chg.ReadyTime()
+ spawnTime := chg.SpawnTime()
+ if spawnTime.Before(startOfOperation) {
+ spawnTime = startOfOperation
+ }
if readyTime.IsZero() {
if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 {
chg.Abort()
delete(s.changes, chg.ID())
} else if spawnTime.Before(abortLimit) {
- chg.Abort()
+ for attr, pending := range s.pendingChangeByAttr {
+ if chg.Has(attr) && pending(chg) {
+ continue NextChange
+ }
+ }
+ chg.AbortUnreadyLanes()
}
continue
}
@@ -443,6 +506,75 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) {
}
}
+// GetMaybeTimings implements timings.GetSaver
+func (s *State) GetMaybeTimings(timings interface{}) error {
+ if err := s.Get("timings", timings); err != nil && !errors.Is(err, ErrNoState) {
+ return err
+ }
+ return nil
+}
+
+// AddTaskStatusChangedHandler adds a callback function that will be invoked
+// whenever tasks change status.
+// NOTE: Callbacks registered this way may be invoked in the context
+// of the taskrunner, so the callbacks should be as simple as possible, and return
+// as quickly as possible, and should avoid the use of i/o code or blocking, as this
+// will stop the entire task system.
+func (s *State) AddTaskStatusChangedHandler(f func(t *Task, old, new Status)) (id int) {
+ // We are reading here as we want to ensure access to the state is serialized,
+ // and not writing as we are not changing the part of state that goes on the disk.
+ s.reading()
+ id = s.lastHandlerId
+ s.lastHandlerId++
+ s.taskHandlers[id] = f
+ return id
+}
+
+func (s *State) RemoveTaskStatusChangedHandler(id int) {
+ s.reading()
+ delete(s.taskHandlers, id)
+}
+
+func (s *State) notifyTaskStatusChangedHandlers(t *Task, old, new Status) {
+ s.reading()
+ for _, f := range s.taskHandlers {
+ f(t, old, new)
+ }
+}
+
+// AddChangeStatusChangedHandler adds a callback function that will be invoked
+// whenever a Change changes status.
+// NOTE: Callbacks registered this way may be invoked in the context
+// of the taskrunner, so the callbacks should be as simple as possible, and return
+// as quickly as possible, and should avoid the use of i/o code or blocking, as this
+// will stop the entire task system.
+func (s *State) AddChangeStatusChangedHandler(f func(chg *Change, old, new Status)) (id int) {
+ // We are reading here as we want to ensure access to the state is serialized,
+ // and not writing as we are not changing the part of state that goes on the disk.
+ s.reading()
+ id = s.lastHandlerId
+ s.lastHandlerId++
+ s.changeHandlers[id] = f
+ return id
+}
+
+func (s *State) RemoveChangeStatusChangedHandler(id int) {
+ s.reading()
+ delete(s.changeHandlers, id)
+}
+
+func (s *State) notifyChangeStatusChangedHandlers(chg *Change, old, new Status) {
+ s.reading()
+ for _, f := range s.changeHandlers {
+ f(chg, old, new)
+ }
+}
+
+// SaveTimings implements timings.GetSaver
+func (s *State) SaveTimings(timings interface{}) {
+ s.Set("timings", timings)
+}
+
// ReadState returns the state deserialized from r.
func ReadState(backend Backend, r io.Reader) (*State, error) {
s := new(State)
@@ -454,8 +586,11 @@ func ReadState(backend Backend, r io.Reader) (*State, error) {
return nil, fmt.Errorf("cannot read state: %s", err)
}
s.backend = backend
+ s.noticeCond = sync.NewCond(s)
s.modified = false
s.cache = make(map[interface{}]interface{})
- s.noticeCond = sync.NewCond(s)
+ s.pendingChangeByAttr = make(map[string]func(*Change) bool)
+ s.changeHandlers = make(map[int]func(chg *Change, old Status, new Status))
+ s.taskHandlers = make(map[int]func(t *Task, old Status, new Status))
return s, err
}
diff --git a/internals/overlord/state/state_test.go b/internals/overlord/state/state_test.go
index e1e8a7076..b1cf534c6 100644
--- a/internals/overlord/state/state_test.go
+++ b/internals/overlord/state/state_test.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state_test
@@ -23,12 +18,14 @@ import (
"bytes"
"errors"
"fmt"
+ "reflect"
"testing"
"time"
. "gopkg.in/check.v1"
"github.com/canonical/pebble/internals/overlord/state"
+ "github.com/canonical/pebble/internals/testutil"
)
func TestState(t *testing.T) { TestingT(t) }
@@ -55,6 +52,17 @@ func (ss *stateSuite) TestLockUnlock(c *C) {
st.Unlock()
}
+func (ss *stateSuite) TestUnlocker(c *C) {
+ st := state.New(nil)
+ unlocker := st.Unlocker()
+ st.Lock()
+ defer st.Unlock()
+ relock := unlocker()
+ st.Lock()
+ st.Unlock()
+ relock()
+}
+
func (ss *stateSuite) TestGetAndSet(c *C) {
st := state.New(nil)
st.Lock()
@@ -76,6 +84,37 @@ func (ss *stateSuite) TestGetAndSet(c *C) {
c.Check(&mSt2B, DeepEquals, mSt2)
}
+func (ss *stateSuite) TestHas(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ c.Check(st.Has("a"), Equals, false)
+
+ st.Set("a", 1)
+ c.Check(st.Has("a"), Equals, true)
+
+ st.Set("a", nil)
+ c.Check(st.Has("a"), Equals, false)
+}
+
+func (ss *stateSuite) TestStrayTaskWithNoChange(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+ _ = st.NewTask("bar", "...")
+
+ // only the task with an associated change is returned
+ c.Assert(st.Tasks(), HasLen, 1)
+ c.Assert(st.Tasks()[0].ID(), Equals, t1.ID())
+ // but count includes all tasks
+ c.Assert(st.TaskCount(), Equals, 2)
+}
+
func (ss *stateSuite) TestSetPanic(c *C) {
st := state.New(nil)
st.Lock()
@@ -94,7 +133,7 @@ func (ss *stateSuite) TestGetNoState(c *C) {
var mSt1B mgrState1
err := st.Get("mgr9", &mSt1B)
- c.Check(err, Equals, state.ErrNoState)
+ c.Check(err, testutil.ErrorIs, state.ErrNoState)
}
func (ss *stateSuite) TestSetToNilDeletes(c *C) {
@@ -112,7 +151,7 @@ func (ss *stateSuite) TestSetToNilDeletes(c *C) {
var v1 map[string]int
err = st.Get("a", &v1)
- c.Check(err, Equals, state.ErrNoState)
+ c.Check(err, testutil.ErrorIs, state.ErrNoState)
c.Check(v1, HasLen, 0)
}
@@ -126,7 +165,7 @@ func (ss *stateSuite) TestNullMeansNoState(c *C) {
var v1 map[string]int
err = st.Get("a", &v1)
- c.Check(err, Equals, state.ErrNoState)
+ c.Check(err, testutil.ErrorIs, state.ErrNoState)
c.Check(v1, HasLen, 0)
}
@@ -514,6 +553,30 @@ func (ss *stateSuite) TestEmptyStateDataAndCheckpointReadAndSet(c *C) {
// no crash
st2.Set("a", 1)
+
+ // ensure all maps of state are correctly initialized by ReadState
+ val := reflect.ValueOf(st2)
+ typ := val.Elem().Type()
+ var maps []string
+ for i := 0; i < typ.NumField(); i++ {
+ f := typ.Field(i)
+ if f.Type.Kind() == reflect.Map {
+ maps = append(maps, f.Name)
+ fv := val.Elem().Field(i)
+ c.Check(fv.IsNil(), Equals, false, Commentf("Map field %s of state was not initialized by ReadState", f.Name))
+ }
+ }
+ c.Check(maps, DeepEquals, []string{
+ "data",
+ "changes",
+ "tasks",
+ "warnings",
+ "notices",
+ "cache",
+ "pendingChangeByAttr",
+ "taskHandlers",
+ "changeHandlers",
+ })
}
func (ss *stateSuite) TestEmptyTaskAndChangeDataAndCheckpointReadAndSet(c *C) {
@@ -702,7 +765,7 @@ func (ss *stateSuite) TestMethodEntrance(c *C) {
func() { st.Tasks() },
func() { st.Task("foo") },
func() { st.MarshalJSON() },
- func() { st.Prune(time.Hour, time.Hour, 100) },
+ func() { st.Prune(time.Now(), time.Hour, time.Hour, 100) },
func() { st.TaskCount() },
func() { st.AllWarnings() },
func() { st.PendingWarnings() },
@@ -768,7 +831,8 @@ func (ss *stateSuite) TestPrune(c *C) {
st.AddWarning("hello", now, never, time.Nanosecond, state.DefaultRepeatAfter)
st.Warnf("hello again")
- st.Prune(pruneWait, abortWait, 100)
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, pruneWait, abortWait, 100)
c.Assert(st.Change(chg1.ID()), Equals, chg1)
c.Assert(st.Change(chg2.ID()), IsNil)
@@ -793,6 +857,55 @@ func (ss *stateSuite) TestPrune(c *C) {
c.Check(st.AllWarnings(), HasLen, 1)
}
+func (ss *stateSuite) TestRegisterPendingChangeByAttr(c *C) {
+ st := state.New(&fakeStateBackend{})
+ st.Lock()
+ defer st.Unlock()
+
+ now := time.Now()
+ pruneWait := 1 * time.Hour
+ abortWait := 3 * time.Hour
+
+ unset := time.Time{}
+
+ t1 := st.NewTask("foo", "...")
+ t2 := st.NewTask("foo", "...")
+ t3 := st.NewTask("foo", "...")
+ t4 := st.NewTask("foo", "...")
+
+ chg1 := st.NewChange("abort", "...")
+ chg1.AddTask(t1)
+ chg1.AddTask(t2)
+ state.FakeChangeTimes(chg1, now.Add(-abortWait), unset)
+
+ chg2 := st.NewChange("pending", "...")
+ chg2.AddTask(t3)
+ chg2.AddTask(t4)
+ state.FakeChangeTimes(chg2, now.Add(-abortWait), unset)
+ chg2.Set("pending-flag", true)
+ t3.SetStatus(state.HoldStatus)
+
+ st.RegisterPendingChangeByAttr("pending-flag", func(chg *state.Change) bool {
+ c.Check(chg.ID(), Equals, chg2.ID())
+ return true
+ })
+
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, pruneWait, abortWait, 100)
+
+ c.Assert(st.Change(chg1.ID()), Equals, chg1)
+ c.Assert(st.Change(chg2.ID()), Equals, chg2)
+ c.Assert(st.Task(t1.ID()), Equals, t1)
+ c.Assert(st.Task(t2.ID()), Equals, t2)
+ c.Assert(st.Task(t3.ID()), Equals, t3)
+ c.Assert(st.Task(t4.ID()), Equals, t4)
+
+ c.Assert(t1.Status(), Equals, state.HoldStatus)
+ c.Assert(t2.Status(), Equals, state.HoldStatus)
+ c.Assert(t3.Status(), Equals, state.HoldStatus)
+ c.Assert(t4.Status(), Equals, state.DoStatus)
+}
+
func (ss *stateSuite) TestPruneEmptyChange(c *C) {
// Empty changes are a bit special because they start out on Hold
// which is a Ready status, but the change itself is not considered Ready
@@ -809,7 +922,8 @@ func (ss *stateSuite) TestPruneEmptyChange(c *C) {
chg := st.NewChange("abort", "...")
state.FakeChangeTimes(chg, now.Add(-pruneWait), time.Time{})
- st.Prune(pruneWait, abortWait, 100)
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, pruneWait, abortWait, 100)
c.Assert(st.Change(chg.ID()), IsNil)
}
@@ -844,13 +958,14 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) {
// test that nothing is done when we are within pruneWait and
// maxReadyChanges
+ past := time.Now().AddDate(-1, 0, 0)
maxReadyChanges := 100
- st.Prune(pruneWait, abortWait, maxReadyChanges)
+ st.Prune(past, pruneWait, abortWait, maxReadyChanges)
c.Assert(st.Changes(), HasLen, 15)
// but with maxReadyChanges we remove the ready ones
maxReadyChanges = 5
- st.Prune(pruneWait, abortWait, maxReadyChanges)
+ st.Prune(past, pruneWait, abortWait, maxReadyChanges)
c.Assert(st.Changes(), HasLen, 10)
remaining := map[string]bool{}
for _, chg := range st.Changes() {
@@ -886,8 +1001,9 @@ func (ss *stateSuite) TestPruneMaxChangesSomeNotReady(c *C) {
c.Assert(st.Changes(), HasLen, 10)
// nothing can be pruned
+ past := time.Now().AddDate(-1, 0, 0)
maxChanges := 5
- st.Prune(1*time.Hour, 3*time.Hour, maxChanges)
+ st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges)
c.Assert(st.Changes(), HasLen, 10)
}
@@ -905,7 +1021,7 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) {
c.Assert(st.Changes(), HasLen, 10)
// one extra change that just now entered ready state
- chg := st.NewChange(fmt.Sprintf("chg99"), "so-ready")
+ chg := st.NewChange("chg99", "so-ready")
t := st.NewTask("foo", "so-ready")
when := 1 * time.Second
state.FakeChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when))
@@ -916,11 +1032,45 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) {
//
// this test we do not purge the freshly ready change
maxChanges := 10
- st.Prune(1*time.Hour, 3*time.Hour, maxChanges)
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges)
c.Assert(st.Changes(), HasLen, 11)
}
-func (ss *stateSuite) TestReadStateInitsCache(c *C) {
+func (ss *stateSuite) TestPruneHonorsStartOperationTime(c *C) {
+ st := state.New(&fakeStateBackend{})
+ st.Lock()
+ defer st.Unlock()
+
+ now := time.Now()
+
+ startTime := 2 * time.Hour
+ spawnTime := 10 * time.Hour
+ pruneWait := 1 * time.Hour
+ abortWait := 3 * time.Hour
+
+ chg := st.NewChange("change", "...")
+ t := st.NewTask("foo", "")
+ chg.AddTask(t)
+ // change spawned 10h ago
+ state.FakeChangeTimes(chg, now.Add(-spawnTime), time.Time{})
+
+ // start operation time is 2h ago, change is not aborted because
+ // it's less than abortWait limit.
+ opTime := now.Add(-startTime)
+ st.Prune(opTime, pruneWait, abortWait, 100)
+ c.Assert(st.Changes(), HasLen, 1)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // start operation time is 9h ago, change is aborted.
+ startTime = 9 * time.Hour
+ opTime = time.Now().Add(-startTime)
+ st.Prune(opTime, pruneWait, abortWait, 100)
+ c.Assert(st.Changes(), HasLen, 1)
+ c.Check(chg.Status(), Equals, state.HoldStatus)
+}
+
+func (ss *stateSuite) TestReadStateInitsTransientMapFields(c *C) {
st, err := state.ReadState(nil, bytes.NewBufferString("{}"))
c.Assert(err, IsNil)
st.Lock()
@@ -928,4 +1078,195 @@ func (ss *stateSuite) TestReadStateInitsCache(c *C) {
st.Cache("key", "value")
c.Assert(st.Cached("key"), Equals, "value")
+ st.RegisterPendingChangeByAttr("attr", func(*state.Change) bool { return false })
+}
+
+func (ss *stateSuite) TestTimingsSupport(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var tims []int
+
+ err := st.GetMaybeTimings(&tims)
+ c.Assert(err, IsNil)
+ c.Check(tims, IsNil)
+
+ st.SaveTimings([]int{1, 2, 3})
+
+ err = st.GetMaybeTimings(&tims)
+ c.Assert(err, IsNil)
+ c.Check(tims, DeepEquals, []int{1, 2, 3})
+}
+
+func (ss *stateSuite) TestNoStateErrorIs(c *C) {
+ err := &state.NoStateError{Key: "foo"}
+ c.Assert(err, testutil.ErrorIs, &state.NoStateError{})
+ c.Assert(err, testutil.ErrorIs, &state.NoStateError{Key: "bar"})
+ c.Assert(err, testutil.ErrorIs, state.ErrNoState)
+}
+
+func (ss *stateSuite) TestNoStateErrorString(c *C) {
+ err := &state.NoStateError{}
+ c.Assert(err.Error(), Equals, `no state entry for key`)
+ err.Key = "foo"
+ c.Assert(err.Error(), Equals, `no state entry for key "foo"`)
+}
+
+type taskAndStatus struct {
+ t *state.Task
+ old, new state.Status
+}
+
+func (ss *stateSuite) TestTaskChangedHandler(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var taskObservedChanges []taskAndStatus
+ oId := st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) {
+ taskObservedChanges = append(taskObservedChanges, taskAndStatus{
+ t: t,
+ old: old,
+ new: new,
+ })
+ })
+
+ t1 := st.NewTask("foo", "...")
+
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task status to identical status, we don't want
+ // task events when the task doesn't actually change status.
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task to done.
+ t1.SetStatus(state.DoneStatus)
+
+ // Unregister us, and make sure we do not receive more events.
+ st.RemoveTaskStatusChangedHandler(oId)
+
+ // must not appear in list.
+ t1.SetStatus(state.DoingStatus)
+
+ c.Check(taskObservedChanges, DeepEquals, []taskAndStatus{
+ {
+ t: t1,
+ old: state.DefaultStatus,
+ new: state.DoingStatus,
+ },
+ {
+ t: t1,
+ old: state.DoingStatus,
+ new: state.DoneStatus,
+ },
+ })
+}
+
+type changeAndStatus struct {
+ chg *state.Change
+ old, new state.Status
+}
+
+func (ss *stateSuite) TestChangeChangedHandler(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var observedChanges []changeAndStatus
+ oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) {
+ observedChanges = append(observedChanges, changeAndStatus{
+ chg: chg,
+ old: old,
+ new: new,
+ })
+ })
+
+ chg := st.NewChange("test-chg", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task status to identical status, we don't want
+ // change events when changes don't actually change status.
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task to waiting
+ t1.SetToWait(state.DoneStatus)
+
+ // Unregister us, and make sure we do not receive more events.
+ st.RemoveChangeStatusChangedHandler(oId)
+
+ // must not appear in list.
+ t1.SetStatus(state.DoneStatus)
+
+ c.Check(observedChanges, DeepEquals, []changeAndStatus{
+ {
+ chg: chg,
+ old: state.DefaultStatus,
+ new: state.DoingStatus,
+ },
+ {
+ chg: chg,
+ old: state.DoingStatus,
+ new: state.WaitStatus,
+ },
+ })
+}
+
+func (ss *stateSuite) TestChangeSetStatusChangedHandler(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var observedChanges []changeAndStatus
+ oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) {
+ observedChanges = append(observedChanges, changeAndStatus{
+ chg: chg,
+ old: old,
+ new: new,
+ })
+ })
+
+ chg := st.NewChange("test-chg", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+
+ t1.SetStatus(state.DoingStatus)
+
+ // We have a single task in Doing, now we manipulate the status
+ // of the change to ensure we are receiving correct events
+ chg.SetStatus(state.WaitStatus)
+
+ // Change to a new status
+ chg.SetStatus(state.ErrorStatus)
+
+ // Now return the status back to Default, which should result
+ // in the change reporting Doing
+ chg.SetStatus(state.DefaultStatus)
+ st.RemoveChangeStatusChangedHandler(oId)
+
+ c.Check(observedChanges, DeepEquals, []changeAndStatus{
+ {
+ chg: chg,
+ old: state.DefaultStatus,
+ new: state.DoingStatus,
+ },
+ {
+ chg: chg,
+ old: state.DoingStatus,
+ new: state.WaitStatus,
+ },
+ {
+ chg: chg,
+ old: state.WaitStatus,
+ new: state.ErrorStatus,
+ },
+ {
+ chg: chg,
+ old: state.ErrorStatus,
+ new: state.DoingStatus,
+ },
+ })
}
diff --git a/internals/overlord/state/task.go b/internals/overlord/state/task.go
index a3e2a4864..39a85a20a 100644
--- a/internals/overlord/state/task.go
+++ b/internals/overlord/state/task.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state
@@ -38,19 +33,22 @@ type progress struct {
//
// See Change for more details.
type Task struct {
- state *State
- id string
- kind string
- summary string
- status Status
- clean bool
- progress *progress
- data customData
- waitTasks []string
- haltTasks []string
- lanes []int
- log []string
- change string
+ state *State
+ id string
+ kind string
+ summary string
+ status Status
+ // waitedStatus is the Status that should be used instead of
+ // WaitStatus once the wait is complete (i.e. post reboot).
+ waitedStatus Status
+ clean bool
+ progress *progress
+ data customData
+ waitTasks []string
+ haltTasks []string
+ lanes []int
+ log []string
+ change string
spawnTime time.Time
readyTime time.Time
@@ -77,18 +75,19 @@ func newTask(state *State, id, kind, summary string) *Task {
}
type marshalledTask struct {
- ID string `json:"id"`
- Kind string `json:"kind"`
- Summary string `json:"summary"`
- Status Status `json:"status"`
- Clean bool `json:"clean,omitempty"`
- Progress *progress `json:"progress,omitempty"`
- Data map[string]*json.RawMessage `json:"data,omitempty"`
- WaitTasks []string `json:"wait-tasks,omitempty"`
- HaltTasks []string `json:"halt-tasks,omitempty"`
- Lanes []int `json:"lanes,omitempty"`
- Log []string `json:"log,omitempty"`
- Change string `json:"change"`
+ ID string `json:"id"`
+ Kind string `json:"kind"`
+ Summary string `json:"summary"`
+ Status Status `json:"status"`
+ WaitedStatus Status `json:"waited-status"`
+ Clean bool `json:"clean,omitempty"`
+ Progress *progress `json:"progress,omitempty"`
+ Data map[string]*json.RawMessage `json:"data,omitempty"`
+ WaitTasks []string `json:"wait-tasks,omitempty"`
+ HaltTasks []string `json:"halt-tasks,omitempty"`
+ Lanes []int `json:"lanes,omitempty"`
+ Log []string `json:"log,omitempty"`
+ Change string `json:"change"`
SpawnTime time.Time `json:"spawn-time"`
ReadyTime *time.Time `json:"ready-time,omitempty"`
@@ -111,18 +110,19 @@ func (t *Task) MarshalJSON() ([]byte, error) {
atTime = &t.atTime
}
return json.Marshal(marshalledTask{
- ID: t.id,
- Kind: t.kind,
- Summary: t.summary,
- Status: t.status,
- Clean: t.clean,
- Progress: t.progress,
- Data: t.data,
- WaitTasks: t.waitTasks,
- HaltTasks: t.haltTasks,
- Lanes: t.lanes,
- Log: t.log,
- Change: t.change,
+ ID: t.id,
+ Kind: t.kind,
+ Summary: t.summary,
+ Status: t.status,
+ WaitedStatus: t.waitedStatus,
+ Clean: t.clean,
+ Progress: t.progress,
+ Data: t.data,
+ WaitTasks: t.waitTasks,
+ HaltTasks: t.haltTasks,
+ Lanes: t.lanes,
+ Log: t.log,
+ Change: t.change,
SpawnTime: t.spawnTime,
ReadyTime: readyTime,
@@ -148,6 +148,13 @@ func (t *Task) UnmarshalJSON(data []byte) error {
t.kind = unmarshalled.Kind
t.summary = unmarshalled.Summary
t.status = unmarshalled.Status
+ t.waitedStatus = unmarshalled.WaitedStatus
+ if t.waitedStatus == DefaultStatus {
+ // For backwards-compatibility, default the waitedStatus, which is
+ // the result status after a wait, to DoneStatus to keep any previous
+ // behaviour before any upgrade.
+ t.waitedStatus = DoneStatus
+ }
t.clean = unmarshalled.Clean
t.progress = unmarshalled.Progress
custData := unmarshalled.Data
@@ -188,6 +195,40 @@ func (t *Task) Summary() string {
}
// Status returns the current task status.
+//
+// Possible state transitions:
+//
+// /----aborting lane--Do
+// | |
+// V V
+// Hold Doing-->Wait
+// ^ / | \
+// | abort / V V
+// no undo / Done Error
+// | V |
+// \----------Abort aborting lane
+// / | |
+// | finished or |
+// running not running |
+// V \------->|
+// kill goroutine |
+// | V
+// / \ ----->Undo
+// / no error / |
+// | from goroutine |
+// error |
+// from goroutine |
+// | V
+// | Undoing-->Wait
+// V | \
+// Error V V
+// Undone Error
+//
+// Do -> Doing -> Done is the direct success scenario.
+//
+// Wait can transition to its waited status,
+// usually Done|Undone or back to Doing.
+// See Wait struct, SetToWait and WaitedStatus.
func (t *Task) Status() Status {
t.state.reading()
if t.status == DefaultStatus {
@@ -196,10 +237,10 @@ func (t *Task) Status() Status {
return t.status
}
-// SetStatus sets the task status, overriding the default behavior (see Status method).
-func (t *Task) SetStatus(new Status) {
- t.state.writing()
- old := t.status
+func (t *Task) changeStatus(old, new Status) {
+ if old == new {
+ return
+ }
t.status = new
if !old.Ready() && new.Ready() {
t.readyTime = timeNow()
@@ -208,6 +249,55 @@ func (t *Task) SetStatus(new Status) {
if chg != nil {
chg.taskStatusChanged(t, old, new)
}
+ t.state.notifyTaskStatusChangedHandlers(t, old, new)
+}
+
+// SetStatus sets the task status, overriding the default behavior (see Status method).
+func (t *Task) SetStatus(new Status) {
+ if new == WaitStatus {
+ panic("Task.SetStatus() called with WaitStatus, which is not allowed. Use SetToWait() instead")
+ }
+
+ t.state.writing()
+ old := t.status
+ if new == DoneStatus && old == AbortStatus {
+ // if the task is in AbortStatus (because some other task ran
+ // in parallel and had an error so the change is aborted) and
+ // DoneStatus was requested (which can happen if the
+ // task handler sets its status explicitly) then keep it at
+ // aborted so it can transition to Undo.
+ return
+ }
+ t.changeStatus(old, new)
+}
+
+// SetToWait puts the task into WaitStatus and records resultStatus as the
+// status the task should be restored to once the wait is over.
+func (t *Task) SetToWait(resultStatus Status) {
+ switch resultStatus {
+ case DefaultStatus, WaitStatus:
+ panic("Task.SetToWait() cannot be invoked with either of DefaultStatus or WaitStatus")
+ }
+
+ t.state.writing()
+ old := t.status
+ if old == AbortStatus {
+ // if the task is in AbortStatus (because some other task ran
+ // in parallel and had an error so the change is aborted) and
+ // WaitStatus was requested (which can happen if the
+ // task handler sets its status explicitly) then keep it at
+ // aborted so it can transition to Undo.
+ return
+ }
+ t.waitedStatus = resultStatus
+ t.changeStatus(old, WaitStatus)
+}
+
+// WaitedStatus returns the status the Task should return to once the current WaitStatus
+// has been resolved.
+func (t *Task) WaitedStatus() Status {
+ t.state.reading()
+ return t.waitedStatus
}
// IsClean returns whether the task has been cleaned. See SetClean.
@@ -332,7 +422,7 @@ func (t *Task) addLog(kind, format string, args []interface{}) {
}
tstr := timeNow().Format(time.RFC3339)
- msg := fmt.Sprintf(tstr+" "+kind+" "+format, args...)
+ msg := tstr + " " + kind + " " + fmt.Sprintf(format, args...)
t.log = append(t.log, msg)
logger.Debugf(msg)
}
@@ -484,10 +574,16 @@ func NewTaskSet(tasks ...*Task) *TaskSet {
return &TaskSet{tasks, nil}
}
-// Edge returns the task marked with the given edge name.
+// MaybeEdge returns the task marked with the given edge name or nil if no such
+// task exists.
+func (ts TaskSet) MaybeEdge(e TaskSetEdge) *Task {
+ return ts.edges[e]
+}
+
+// Edge returns the task marked with the given edge name or an error.
func (ts TaskSet) Edge(e TaskSetEdge) (*Task, error) {
- t, ok := ts.edges[e]
- if !ok {
+ t := ts.MaybeEdge(e)
+ if t == nil {
return nil, fmt.Errorf("internal error: missing %q edge in task set", e)
}
return t, nil
@@ -509,7 +605,7 @@ func (ts *TaskSet) WaitAll(anotherTs *TaskSet) {
}
}
-// AddTask adds the the task to the task set.
+// AddTask adds the task to the task set.
func (ts *TaskSet) AddTask(task *Task) {
for _, t := range ts.tasks {
if t == task {
@@ -522,6 +618,9 @@ func (ts *TaskSet) AddTask(task *Task) {
// MarkEdge marks the given task as a specific edge. Any pre-existing
// edge mark will be overridden.
func (ts *TaskSet) MarkEdge(task *Task, edge TaskSetEdge) {
+ if task == nil {
+ panic(fmt.Sprintf("cannot set edge %q with nil task", edge))
+ }
if ts.edges == nil {
ts.edges = make(map[TaskSetEdge]*Task)
}
diff --git a/internals/overlord/state/task_test.go b/internals/overlord/state/task_test.go
index 9c073d806..5a5c4f9b1 100644
--- a/internals/overlord/state/task_test.go
+++ b/internals/overlord/state/task_test.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state_test
@@ -127,7 +122,7 @@ func (ts *taskSuite) TestClear(c *C) {
t.Clear("a")
- c.Check(t.Get("a", &v), Equals, state.ErrNoState)
+ c.Check(t.Get("a", &v), testutil.ErrorIs, state.ErrNoState)
}
func (ts *taskSuite) TestStatusAndSetStatus(c *C) {
@@ -144,6 +139,60 @@ func (ts *taskSuite) TestStatusAndSetStatus(c *C) {
c.Check(t.Status(), Equals, state.DoneStatus)
}
+func (ts *taskSuite) TestSetDoneAfterAbortNoop(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t := st.NewTask("download", "1...")
+ t.SetStatus(state.AbortStatus)
+ c.Check(t.Status(), Equals, state.AbortStatus)
+ t.SetStatus(state.DoneStatus)
+ c.Check(t.Status(), Equals, state.AbortStatus)
+}
+
+func (ts *taskSuite) TestSetWaitAfterAbortNoop(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t := st.NewTask("download", "1...")
+ t.SetStatus(state.AbortStatus)
+ c.Check(t.Status(), Equals, state.AbortStatus)
+ t.SetToWait(state.DoneStatus) // noop
+ c.Check(t.Status(), Equals, state.AbortStatus)
+ c.Check(t.WaitedStatus(), Equals, state.DefaultStatus)
+}
+
+func (ts *taskSuite) TestSetWait(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t := st.NewTask("download", "1...")
+ t.SetToWait(state.DoneStatus)
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(t.WaitedStatus(), Equals, state.DoneStatus)
+ t.SetToWait(state.UndoStatus)
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(t.WaitedStatus(), Equals, state.UndoStatus)
+}
+
+func (ts *taskSuite) TestTaskMarshalsWaitStatus(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t1 := st.NewTask("download", "1...")
+ t1.SetToWait(state.UndoStatus)
+
+ d, err := t1.MarshalJSON()
+ c.Assert(err, IsNil)
+
+ needle := fmt.Sprintf(`"waited-status":%d`, t1.WaitedStatus())
+ c.Assert(string(d), testutil.Contains, needle)
+}
+
func (ts *taskSuite) TestIsCleanAndSetClean(c *C) {
st := state.New(nil)
st.Lock()
@@ -561,6 +610,10 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) {
// edges are just typed strings
edge1 := state.TaskSetEdge("on-edge")
edge2 := state.TaskSetEdge("eddie")
+ edge3 := state.TaskSetEdge("not-found")
+
+ // nil task causes panic
+ c.Check(func() { ts.MarkEdge(nil, edge1) }, PanicMatches, `cannot set edge "on-edge" with nil task`)
// no edge marked yet
t, err := ts.Edge(edge1)
@@ -590,6 +643,12 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) {
t, err = ts.Edge(edge1)
c.Assert(t, Equals, t3)
c.Assert(err, IsNil)
+
+ // it is possible to check if edge exists without failing
+ t = ts.MaybeEdge(edge1)
+ c.Assert(t, Equals, t3)
+ t = ts.MaybeEdge(edge3)
+ c.Assert(t, IsNil)
}
func (cs *taskSuite) TestTaskAddAllWithEdges(c *C) {
diff --git a/internals/overlord/state/taskrunner.go b/internals/overlord/state/taskrunner.go
index 5d9ff2440..cfed27bd2 100644
--- a/internals/overlord/state/taskrunner.go
+++ b/internals/overlord/state/taskrunner.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state
@@ -46,6 +41,23 @@ func (r *Retry) Error() string {
return "task should be retried"
}
+// Wait is returned from a handler to signal that the task cannot
+// proceed at the moment, perhaps because some manual action from the
+// user is required at this point or because of errors. The task
+// will be set to WaitStatus, and its wait-complete status will be
+// set to WaitedStatus.
+type Wait struct {
+ Reason string
+ // If not explicitly set, then WaitedStatus will default to
+ // DoneStatus, meaning that the task will be set to DoneStatus
+ // after the wait has resolved.
+ WaitedStatus Status
+}
+
+func (r *Wait) Error() string {
+ return "task set to wait, manual action required"
+}
+
type blockedFunc func(t *Task, running []*Task) bool
// TaskRunner controls the running of goroutines to execute known task kinds.
@@ -62,6 +74,9 @@ type TaskRunner struct {
blocked []blockedFunc
someBlocked bool
+ // optional callback executed on task errors
+ taskErrorCallback func(err error)
+
// go-routines lifecycle
tombs map[string]*tomb.Tomb
}
@@ -85,6 +100,11 @@ func NewTaskRunner(s *State) *TaskRunner {
}
}
+// OnTaskError sets an error callback executed when any task errors out.
+func (r *TaskRunner) OnTaskError(f func(err error)) {
+ r.taskErrorCallback = f
+}
+
// AddHandler registers the functions to concurrently call for doing and
// undoing tasks of the given kind. The undo handler may be nil.
func (r *TaskRunner) AddHandler(kind string, do, undo HandlerFunc) {
@@ -214,7 +234,7 @@ func (r *TaskRunner) run(t *Task) {
switch err.(type) {
case nil:
// we are ok
- case *Retry:
+ case *Retry, *Wait:
// preserve
default:
if r.stopped {
@@ -227,13 +247,24 @@ func (r *TaskRunner) run(t *Task) {
switch x := err.(type) {
case *Retry:
// Handler asked to be called again later.
- // TODO Allow postponing retries past the next Ensure.
if t.Status() == AbortStatus {
// Would work without it but might take two ensures.
r.tryUndo(t)
} else if x.After != 0 {
t.At(timeNow().Add(x.After))
}
+ case *Wait:
+ if t.Status() == AbortStatus {
+ // Would work without it but might take two ensures.
+ r.tryUndo(t)
+ } else {
+ // Default to DoneStatus if no status is set in Wait
+ waitedStatus := x.WaitedStatus
+ if waitedStatus == DefaultStatus {
+ waitedStatus = DoneStatus
+ }
+ t.SetToWait(waitedStatus)
+ }
case nil:
var next []*Task
switch t.Status() {
@@ -259,6 +290,11 @@ func (r *TaskRunner) run(t *Task) {
r.abortLanes(t.Change(), t.Lanes())
t.SetStatus(ErrorStatus)
t.Errorf("%s", err)
+ // ensure the error is available in the global log too
+ logger.Noticef("[change %s %q task] failed: %v", t.Change().ID(), t.Summary(), err)
+ if r.taskErrorCallback != nil {
+ r.taskErrorCallback(err)
+ }
}
return nil
@@ -388,6 +424,10 @@ ConsiderTasks:
}
continue
}
+ if status == WaitStatus {
+ // nothing more to run
+ continue
+ }
if mustWait(t) {
// Dependencies still unhandled.
diff --git a/internals/overlord/state/taskrunner_test.go b/internals/overlord/state/taskrunner_test.go
index 38c5685a2..979d05e96 100644
--- a/internals/overlord/state/taskrunner_test.go
+++ b/internals/overlord/state/taskrunner_test.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2016 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state_test
@@ -31,11 +26,14 @@ import (
. "gopkg.in/check.v1"
"gopkg.in/tomb.v2"
- "github.com/canonical/pebble/internals/overlord/restart"
+ "github.com/canonical/pebble/internals/logger"
"github.com/canonical/pebble/internals/overlord/state"
+ "github.com/canonical/pebble/internals/testutil"
)
-type taskRunnerSuite struct{}
+type taskRunnerSuite struct {
+ testutil.BaseTest
+}
var _ = Suite(&taskRunnerSuite{})
@@ -58,8 +56,6 @@ func (b *stateBackend) EnsureBefore(d time.Duration) {
}
}
-func (b *stateBackend) RequestRestart(t restart.RestartType) {}
-
func ensureChange(c *C, r *state.TaskRunner, sb *stateBackend, chg *state.Change) {
for i := 0; i < 20; i++ {
sb.ensureBefore = time.Hour
@@ -142,6 +138,10 @@ var sequenceTests = []struct{ setup, result string }{{
result: "t31:undo t32:do t32:do-error t21:undo",
}}
+func (ts *taskRunnerSuite) SetUpTest(c *C) {
+ ts.BaseTest.SetUpTest(c)
+}
+
func (ts *taskRunnerSuite) TestSequenceTests(c *C) {
sb := &stateBackend{}
st := state.New(sb)
@@ -185,11 +185,12 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) {
r.AddHandler("do", fn("do"), nil)
r.AddHandler("do-undo", fn("do"), fn("undo"))
+ past := time.Now().AddDate(-1, 0, 0)
for _, test := range sequenceTests {
st.Lock()
// Delete previous changes.
- st.Prune(1, 1, 1)
+ st.Prune(past, 1, 1, 1)
chg := st.NewChange("install", "...")
tasks := make(map[string]*state.Task)
@@ -342,6 +343,206 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) {
}
}
+func (ts *taskRunnerSuite) TestAbortAcrossLanesDescendantTask(c *C) {
+
+ // ()
+ // t11(1) -> t12(1) => t15(1)
+ // \ /
+ // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2)
+ // / \
+ // t21(2) -> t22(2) => t25(2)
+ //
+ names := strings.Fields("t11 t12 t13 t14 t15 t21 t22 t23 t24 t25")
+
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask("do", name)
+ chg.AddTask(tasks[name])
+ }
+ tasks["t12"].WaitFor(tasks["t11"])
+ tasks["t13"].WaitFor(tasks["t12"])
+ tasks["t14"].WaitFor(tasks["t13"])
+ tasks["t15"].WaitFor(tasks["t14"])
+ for lane, names := range map[int][]string{
+ 1: {"t11", "t12", "t13", "t14", "t15", "t23", "t24"},
+ 2: {"t21", "t22", "t23", "t24", "t25", "t13", "t14"},
+ } {
+ for _, name := range names {
+ tasks[name].JoinLane(lane)
+ }
+ }
+
+ tasks["t22"].WaitFor(tasks["t21"])
+ tasks["t23"].WaitFor(tasks["t22"])
+ tasks["t24"].WaitFor(tasks["t23"])
+ tasks["t25"].WaitFor(tasks["t24"])
+
+ tasks["t13"].WaitFor(tasks["t22"])
+ tasks["t15"].WaitFor(tasks["t24"])
+ tasks["t23"].WaitFor(tasks["t14"])
+
+ ch := make(chan string, 256)
+ do := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("do %q", task.Summary())
+ label := task.Summary()
+ if label == "t15" {
+ ch <- "t15:error"
+ return fmt.Errorf("mock error")
+ }
+ ch <- fmt.Sprintf("%s:do", label)
+ return nil
+ }
+ undo := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("undo %q", task.Summary())
+ label := task.Summary()
+ ch <- fmt.Sprintf("%s:undo", label)
+ return nil
+ }
+ r.AddHandler("do", do, undo)
+
+ c.Logf("-----")
+
+ st.Unlock()
+ ensureChange(c, r, sb, chg)
+ st.Lock()
+ close(ch)
+ var sequence []string
+ for event := range ch {
+ sequence = append(sequence, event)
+ }
+ for _, name := range names {
+ task := tasks[name]
+ c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status())
+ }
+ c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{
+ "t11:do", "t12:do",
+ "t21:do", "t22:do",
+ })
+ c.Assert(sequence[4:8], DeepEquals, []string{
+ "t13:do", "t14:do", "t23:do", "t24:do",
+ })
+ c.Assert(sequence[8:10], testutil.DeepUnsortedMatches, []string{
+ "t25:do",
+ "t15:error",
+ })
+ c.Assert(sequence[10:11], testutil.DeepUnsortedMatches, []string{
+ "t25:undo",
+ })
+ c.Assert(sequence[11:15], DeepEquals, []string{
+ "t24:undo", "t23:undo", "t14:undo", "t13:undo",
+ })
+ c.Assert(sequence[15:19], testutil.DeepUnsortedMatches, []string{
+ "t21:undo", "t22:undo",
+ "t12:undo", "t11:undo",
+ })
+}
+
+func (ts *taskRunnerSuite) TestAbortAcrossLanesStriclyOrderedTasks(c *C) {
+
+ // ()
+ // t11(1) -> t12(1)
+ // \
+ // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2)
+ // /
+ // t21(2) -> t22(2)
+ //
+ names := strings.Fields("t11 t12 t13 t14 t21 t22 t23 t24")
+
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask("do", name)
+ chg.AddTask(tasks[name])
+ }
+ tasks["t12"].WaitFor(tasks["t11"])
+ tasks["t13"].WaitFor(tasks["t12"])
+ tasks["t14"].WaitFor(tasks["t13"])
+ for lane, names := range map[int][]string{
+ 1: {"t11", "t12", "t13", "t14", "t23", "t24"},
+ 2: {"t21", "t22", "t23", "t24", "t13", "t14"},
+ } {
+ for _, name := range names {
+ tasks[name].JoinLane(lane)
+ }
+ }
+
+ tasks["t22"].WaitFor(tasks["t21"])
+ tasks["t23"].WaitFor(tasks["t22"])
+ tasks["t24"].WaitFor(tasks["t23"])
+
+ tasks["t13"].WaitFor(tasks["t22"])
+ tasks["t23"].WaitFor(tasks["t14"])
+
+ ch := make(chan string, 256)
+ do := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("do %q", task.Summary())
+ label := task.Summary()
+ if label == "t24" {
+ ch <- "t24:error"
+ return fmt.Errorf("mock error")
+ }
+ ch <- fmt.Sprintf("%s:do", label)
+ return nil
+ }
+ undo := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("undo %q", task.Summary())
+ label := task.Summary()
+ ch <- fmt.Sprintf("%s:undo", label)
+ return nil
+ }
+ r.AddHandler("do", do, undo)
+
+ c.Logf("-----")
+
+ st.Unlock()
+ ensureChange(c, r, sb, chg)
+ st.Lock()
+ close(ch)
+ var sequence []string
+ for event := range ch {
+ sequence = append(sequence, event)
+ }
+ for _, name := range names {
+ task := tasks[name]
+ c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status())
+ }
+ c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{
+ "t11:do", "t12:do",
+ "t21:do", "t22:do",
+ })
+ c.Assert(sequence[4:8], DeepEquals, []string{
+ "t13:do", "t14:do", "t23:do", "t24:error",
+ })
+ c.Assert(sequence[8:11], DeepEquals, []string{
+ "t23:undo", "t14:undo", "t13:undo",
+ })
+ c.Assert(sequence[11:], testutil.DeepUnsortedMatches, []string{
+ "t21:undo", "t22:undo",
+ "t12:undo", "t11:undo",
+ })
+}
+
func (ts *taskRunnerSuite) TestExternalAbort(c *C) {
sb := &stateBackend{}
st := state.New(sb)
@@ -372,6 +573,101 @@ func (ts *taskRunnerSuite) TestExternalAbort(c *C) {
ensureChange(c, r, sb, chg)
}
+func (ts *taskRunnerSuite) TestUndoSingleLane(c *C) {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ r.AddHandler("noop", func(t *state.Task, tb *tomb.Tomb) error {
+ return nil
+ }, func(t *state.Task, tb *tomb.Tomb) error {
+ return nil
+ })
+
+ r.AddHandler("noop-slow", func(t *state.Task, tb *tomb.Tomb) error {
+ time.Sleep(10 * time.Millisecond)
+ t.State().Lock()
+ defer t.State().Unlock()
+ // critical
+ t.SetStatus(state.DoneStatus)
+ return nil
+ }, func(t *state.Task, tb *tomb.Tomb) error {
+ return nil
+ })
+
+ r.AddHandler("fail", func(t *state.Task, tb *tomb.Tomb) error {
+ return fmt.Errorf("fail")
+ }, nil)
+
+ st.Lock()
+
+ lane := st.NewLane()
+ chg := st.NewChange("install", "...")
+
+ // first taskset
+ var prev *state.Task
+ for i := 0; i < 10; i++ {
+ t := st.NewTask("noop-slow", "...")
+ if prev != nil {
+ t.WaitFor(prev)
+ }
+ chg.AddTask(t)
+ t.JoinLane(lane)
+
+ prev = t
+ }
+
+ // second taskset with a failing task that triggers undo of the change
+ prev = nil
+ for i := 0; i < 10; i++ {
+ t := st.NewTask("noop", "...")
+ if prev != nil {
+ t.WaitFor(prev)
+ }
+ chg.AddTask(t)
+ t.JoinLane(lane)
+ prev = t
+ }
+
+ // error trigger
+ t := st.NewTask("fail", "...")
+ t.WaitFor(prev)
+ chg.AddTask(t)
+ t.JoinLane(lane)
+
+ st.Unlock()
+
+ var done bool
+ for !done {
+ c.Assert(r.Ensure(), Equals, nil)
+ st.Lock()
+ done = chg.IsReady() && chg.IsClean()
+ st.Unlock()
+ }
+
+ st.Lock()
+ defer st.Unlock()
+
+ // make sure all tasks are either undone or on hold (except for "fail" task which
+ // is in error).
+ for _, t := range st.Tasks() {
+ switch t.Kind() {
+ case "fail":
+ c.Assert(t.Status(), Equals, state.ErrorStatus)
+ case "noop", "noop-slow":
+ if t.Status() != state.UndoneStatus && t.Status() != state.HoldStatus {
+ for _, tsk := range st.Tasks() {
+ fmt.Printf("%s -> %s\n", tsk.Kind(), tsk.Status())
+ }
+ c.Fatalf("unexpected status: %s", t.Status())
+ }
+ default:
+ c.Fatalf("unexpected kind: %s", t.Kind())
+ }
+ }
+}
+
func (ts *taskRunnerSuite) TestStopHandlerJustFinishing(c *C) {
sb := &stateBackend{}
st := state.New(sb)
@@ -508,6 +804,55 @@ func (ts *taskRunnerSuite) TestStopAskForRetry(c *C) {
c.Check(t.AtTime().IsZero(), Equals, false)
}
+func (ts *taskRunnerSuite) testTaskReturningWait(c *C, waitedStatus, expectedStatus state.Status) {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ r.AddHandler("ask-for-wait", func(t *state.Task, tb *tomb.Tomb) error {
+ // ask for wait
+ return &state.Wait{WaitedStatus: waitedStatus}
+ }, nil)
+
+ st.Lock()
+ chg := st.NewChange("install", "...")
+ t := st.NewTask("ask-for-wait", "...")
+ chg.AddTask(t)
+ st.Unlock()
+
+ r.Ensure()
+ // wait for handler to finish
+ r.Wait()
+
+ st.Lock()
+ defer st.Unlock()
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(t.WaitedStatus(), Equals, expectedStatus)
+ c.Check(chg.Status().Ready(), Equals, false)
+
+ st.Unlock()
+ defer st.Lock()
+ // does nothing
+ r.Ensure()
+
+ // state is unchanged
+ st.Lock()
+ defer st.Unlock()
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(chg.Status().Ready(), Equals, false)
+}
+
+func (ts *taskRunnerSuite) TestTaskReturningWaitNormal(c *C) {
+ ts.testTaskReturningWait(c, state.UndoneStatus, state.UndoneStatus)
+}
+
+func (ts *taskRunnerSuite) TestTaskReturningWaitDefaultStatus(c *C) {
+ // If no waited status was set (DefaultStatus), then it should default to
+ // DoneStatus instead.
+ ts.testTaskReturningWait(c, state.DefaultStatus, state.DoneStatus)
+}
+
func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) {
ensureBeforeTick := make(chan bool, 1)
sb := &stateBackend{
@@ -823,7 +1168,7 @@ func (ts *taskRunnerSuite) TestUndoSequence(c *C) {
terr.WaitFor(prev)
chg.AddTask(terr)
- c.Check(chg.Tasks(), HasLen, 9) // sanity check
+ c.Check(chg.Tasks(), HasLen, 9) // validity check
st.Unlock()
@@ -910,3 +1255,71 @@ func (ts *taskRunnerSuite) TestCleanup(c *C) {
c.Assert(chgIsClean(), Equals, true)
c.Assert(called, Equals, 2)
}
+
+func (ts *taskRunnerSuite) TestErrorCallbackCalledOnError(c *C) {
+ logbuf, restore := logger.MockLogger("TASKRUNNER: ")
+ defer restore()
+
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+
+ var called bool
+ r.OnTaskError(func(err error) {
+ called = true
+ })
+
+ r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error {
+ return fmt.Errorf("handler error for %q", t.Kind())
+ }, nil)
+
+ st.Lock()
+ chg := st.NewChange("install", "change summary")
+ t1 := st.NewTask("foo", "task summary")
+ chg.AddTask(t1)
+ st.Unlock()
+
+ // Drive the change; the task is expected to end in error.
+ ensureChange(c, r, sb, chg)
+ r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Check(t1.Status(), Equals, state.ErrorStatus)
+ c.Check(strings.Join(t1.Log(), ""), Matches, `.*handler error for "foo"`)
+ c.Check(called, Equals, true)
+
+ c.Check(logbuf.String(), Matches, `(?m).*: \[change 1 "task summary" task\] failed: handler error for "foo".*`)
+}
+
+func (ts *taskRunnerSuite) TestErrorCallbackNotCalled(c *C) {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+
+ var called bool
+ r.OnTaskError(func(err error) {
+ called = true
+ })
+
+ r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error {
+ return nil
+ }, nil)
+
+ st.Lock()
+ chg := st.NewChange("install", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+ st.Unlock()
+
+ // Mark tasks as done.
+ ensureChange(c, r, sb, chg)
+ r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Check(t1.Status(), Equals, state.DoneStatus)
+ c.Check(called, Equals, false)
+}
diff --git a/internals/overlord/state/warning.go b/internals/overlord/state/warning.go
index 69dac7c84..2b348b58c 100644
--- a/internals/overlord/state/warning.go
+++ b/internals/overlord/state/warning.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2018 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state
diff --git a/internals/overlord/state/warning_test.go b/internals/overlord/state/warning_test.go
index bf30f69e9..4cc701bb6 100644
--- a/internals/overlord/state/warning_test.go
+++ b/internals/overlord/state/warning_test.go
@@ -1,21 +1,16 @@
-// -*- Mode: Go; indent-tabs-mode: t -*-
-
-/*
- * Copyright (c) 2018 Canonical Ltd
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 3 as
- * published by the Free Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
+// Copyright (c) 2024 Canonical Ltd
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License version 3 as
+// published by the Free Software Foundation.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
package state_test
@@ -99,7 +94,7 @@ func (stateSuite) TestUnmarshalErrors(c *check.C) {
}
for _, t := range []T1{
- // sanity check
+ // validity check
{`{"message": "x", "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, nil},
// remove one field at a time:
{`{ "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, state.ErrNoWarningMessage},