Skip to content

Commit c2dc1c5

Browse files
authored
full task cleanup when alloc prerun hook fails (#17104)
To avoid leaking task resources (e.g. containers, iptables) if the allocRunner prerun fails during restore on client restart.

Now, if prerun fails, TaskRunner.MarkFailedKill() will only emit an event, mark the task as failed, and cancel the tr's killCtx, so that ar.runTasks() -> tr.Run() can take care of the actual cleanup.

Removed from (formerly) tr.MarkFailedDead(), now handled by tr.Run():
* set task state as dead
* save task runner local state
* task stop hooks

Also done in tr.Run() now that it's not skipped:
* handleKill() to kill tasks while respecting their shutdown delay, and retrying as needed
  * also includes task preKill hooks
* clearDriverHandle() to destroy the task and associated resources
* task exited hooks
1 parent 58a7d40 commit c2dc1c5

File tree

10 files changed

+334
-35
lines changed

10 files changed

+334
-35
lines changed

.changelog/17104.txt

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
```release-note:bug
2+
client: clean up resources upon failure to restore task during client restart
3+
```

client/allocrunner/alloc_runner.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
log "github.com/hashicorp/go-hclog"
1313
multierror "github.com/hashicorp/go-multierror"
14+
1415
"github.com/hashicorp/nomad/client/allocdir"
1516
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
1617
"github.com/hashicorp/nomad/client/allocrunner/state"
@@ -347,17 +348,15 @@ func (ar *allocRunner) Run() {
347348
ar.logger.Error("prerun failed", "error", err)
348349

349350
for _, tr := range ar.tasks {
350-
tr.MarkFailedDead(fmt.Sprintf("failed to setup alloc: %v", err))
351+
// emit event and mark task to be cleaned up during runTasks()
352+
tr.MarkFailedKill(fmt.Sprintf("failed to setup alloc: %v", err))
351353
}
352-
353-
goto POST
354354
}
355355
}
356356

357357
// Run the runners (blocks until they exit)
358358
ar.runTasks()
359359

360-
POST:
361360
if ar.isShuttingDown() {
362361
return
363362
}

client/allocrunner/alloc_runner_hooks.go

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"time"
99

1010
multierror "github.com/hashicorp/go-multierror"
11+
1112
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
1213
clientconfig "github.com/hashicorp/nomad/client/config"
1314
"github.com/hashicorp/nomad/client/taskenv"
@@ -138,6 +139,9 @@ func (ar *allocRunner) initRunnerHooks(config *clientconfig.Config) error {
138139
newCSIHook(alloc, hookLogger, ar.csiManager, ar.rpcClient, ar, ar.hookResources, ar.clientConfig.Node.SecretID),
139140
newChecksHook(hookLogger, alloc, ar.checkStore, ar),
140141
}
142+
if config.ExtraAllocHooks != nil {
143+
ar.runnerHooks = append(ar.runnerHooks, config.ExtraAllocHooks...)
144+
}
141145

142146
return nil
143147
}

client/allocrunner/fail_hook.go

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
// FailHook is designed to fail for testing purposes,
5+
// so should never be included in a release.
6+
//go:build !release
7+
8+
package allocrunner
9+
10+
import (
11+
"errors"
12+
"fmt"
13+
"os"
14+
15+
"github.com/hashicorp/go-hclog"
16+
"github.com/hashicorp/hcl/v2/hclsimple"
17+
18+
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
19+
)
20+
21+
var ErrFailHookError = errors.New("failed successfully")
22+
23+
func NewFailHook(l hclog.Logger, name string) *FailHook {
24+
return &FailHook{
25+
name: name,
26+
logger: l.Named(name),
27+
}
28+
}
29+
30+
type FailHook struct {
31+
name string
32+
logger hclog.Logger
33+
Fail struct {
34+
Prerun bool `hcl:"prerun,optional"`
35+
PreKill bool `hcl:"prekill,optional"`
36+
Postrun bool `hcl:"postrun,optional"`
37+
Destroy bool `hcl:"destroy,optional"`
38+
Update bool `hcl:"update,optional"`
39+
PreTaskRestart bool `hcl:"pretaskrestart,optional"`
40+
Shutdown bool `hcl:"shutdown,optional"`
41+
}
42+
}
43+
44+
func (h *FailHook) Name() string {
45+
return h.name
46+
}
47+
48+
func (h *FailHook) LoadConfig(path string) *FailHook {
49+
if _, err := os.Stat(path); os.IsNotExist(err) {
50+
h.logger.Error("couldn't load config", "error", err)
51+
return h
52+
}
53+
if err := hclsimple.DecodeFile(path, nil, &h.Fail); err != nil {
54+
h.logger.Error("error parsing config", "path", path, "error", err)
55+
}
56+
return h
57+
}
58+
59+
var _ interfaces.RunnerPrerunHook = &FailHook{}
60+
61+
func (h *FailHook) Prerun() error {
62+
if h.Fail.Prerun {
63+
return fmt.Errorf("prerun %w", ErrFailHookError)
64+
}
65+
return nil
66+
}
67+
68+
var _ interfaces.RunnerPreKillHook = &FailHook{}
69+
70+
func (h *FailHook) PreKill() {
71+
if h.Fail.PreKill {
72+
h.logger.Error("prekill", "error", ErrFailHookError)
73+
}
74+
}
75+
76+
var _ interfaces.RunnerPostrunHook = &FailHook{}
77+
78+
func (h *FailHook) Postrun() error {
79+
if h.Fail.Postrun {
80+
return fmt.Errorf("postrun %w", ErrFailHookError)
81+
}
82+
return nil
83+
}
84+
85+
var _ interfaces.RunnerDestroyHook = &FailHook{}
86+
87+
func (h *FailHook) Destroy() error {
88+
if h.Fail.Destroy {
89+
return fmt.Errorf("destroy %w", ErrFailHookError)
90+
}
91+
return nil
92+
}
93+
94+
var _ interfaces.RunnerUpdateHook = &FailHook{}
95+
96+
func (h *FailHook) Update(request *interfaces.RunnerUpdateRequest) error {
97+
if h.Fail.Update {
98+
return fmt.Errorf("update %w", ErrFailHookError)
99+
}
100+
return nil
101+
}
102+
103+
var _ interfaces.RunnerTaskRestartHook = &FailHook{}
104+
105+
func (h *FailHook) PreTaskRestart() error {
106+
if h.Fail.PreTaskRestart {
107+
return fmt.Errorf("destroy %w", ErrFailHookError)
108+
}
109+
return nil
110+
}
111+
112+
var _ interfaces.ShutdownHook = &FailHook{}
113+
114+
func (h *FailHook) Shutdown() {
115+
if h.Fail.Shutdown {
116+
h.logger.Error("shutdown", "error", ErrFailHookError)
117+
}
118+
}

client/allocrunner/taskrunner/task_runner.go

+12-21
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ import (
1111
"sync"
1212
"time"
1313

14-
"github.com/hashicorp/nomad/client/lib/cgutil"
1514
"golang.org/x/exp/slices"
1615

1716
metrics "github.com/armon/go-metrics"
1817
log "github.com/hashicorp/go-hclog"
1918
multierror "github.com/hashicorp/go-multierror"
2019
"github.com/hashicorp/hcl/v2/hcldec"
20+
2121
"github.com/hashicorp/nomad/client/allocdir"
2222
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
2323
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/restarts"
@@ -27,6 +27,7 @@ import (
2727
"github.com/hashicorp/nomad/client/devicemanager"
2828
"github.com/hashicorp/nomad/client/dynamicplugins"
2929
cinterfaces "github.com/hashicorp/nomad/client/interfaces"
30+
"github.com/hashicorp/nomad/client/lib/cgutil"
3031
"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
3132
"github.com/hashicorp/nomad/client/pluginmanager/drivermanager"
3233
"github.com/hashicorp/nomad/client/serviceregistration"
@@ -495,30 +496,20 @@ func (tr *TaskRunner) initLabels() {
495496
}
496497
}
497498

498-
// MarkFailedDead marks a task as failed and not to run. Aimed to be invoked
499-
// when alloc runner prestart hooks failed. Should never be called with Run().
500-
func (tr *TaskRunner) MarkFailedDead(reason string) {
501-
defer close(tr.waitCh)
502-
503-
tr.stateLock.Lock()
504-
if err := tr.stateDB.PutTaskRunnerLocalState(tr.allocID, tr.taskName, tr.localState); err != nil {
505-
//TODO Nomad will be unable to restore this task; try to kill
506-
// it now and fail? In general we prefer to leave running
507-
// tasks running even if the agent encounters an error.
508-
tr.logger.Warn("error persisting local failed task state; may be unable to restore after a Nomad restart",
509-
"error", err)
510-
}
511-
tr.stateLock.Unlock()
512-
499+
// MarkFailedKill marks a task as failed and should be killed.
500+
// It should be invoked when alloc runner prestart hooks fail.
501+
// Afterwards, Run() will perform any necessary cleanup.
502+
func (tr *TaskRunner) MarkFailedKill(reason string) {
503+
// Emit an event that fails the task and gives reasons for humans.
513504
event := structs.NewTaskEvent(structs.TaskSetupFailure).
505+
SetKillReason(structs.TaskRestoreFailed).
514506
SetDisplayMessage(reason).
515507
SetFailsTask()
516-
tr.UpdateState(structs.TaskStateDead, event)
508+
tr.EmitEvent(event)
517509

518-
// Run the stop hooks in case task was a restored task that failed prestart
519-
if err := tr.stop(); err != nil {
520-
tr.logger.Error("stop failed while marking task dead", "error", err)
521-
}
510+
// Cancel kill context, so later when allocRunner runs tr.Run(),
511+
// we'll follow the usual kill path and do all the appropriate cleanup steps.
512+
tr.killCtxCancel()
522513
}
523514

524515
// Run the TaskRunner. Starts the user's task or reattaches to a restored task.

client/allocrunner/taskrunner/task_runner_test.go

+61-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ import (
1616
"time"
1717

1818
"github.com/golang/snappy"
19+
"github.com/kr/pretty"
20+
"github.com/shoenig/test"
21+
"github.com/shoenig/test/must"
22+
"github.com/stretchr/testify/assert"
23+
"github.com/stretchr/testify/require"
24+
1925
"github.com/hashicorp/nomad/ci"
2026
"github.com/hashicorp/nomad/client/allocdir"
2127
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
@@ -41,10 +47,6 @@ import (
4147
"github.com/hashicorp/nomad/plugins/device"
4248
"github.com/hashicorp/nomad/plugins/drivers"
4349
"github.com/hashicorp/nomad/testutil"
44-
"github.com/kr/pretty"
45-
"github.com/shoenig/test/must"
46-
"github.com/stretchr/testify/assert"
47-
"github.com/stretchr/testify/require"
4850
)
4951

5052
type MockTaskStateUpdater struct {
@@ -662,6 +664,61 @@ func TestTaskRunner_Restore_System(t *testing.T) {
662664
})
663665
}
664666

667+
// TestTaskRunner_MarkFailedKill asserts that MarkFailedKill marks the task as failed
668+
// and cancels the killCtx so a subsequent Run() will do any necessary task cleanup.
669+
func TestTaskRunner_MarkFailedKill(t *testing.T) {
670+
ci.Parallel(t)
671+
672+
// set up some taskrunner
673+
alloc := mock.MinAlloc()
674+
task := alloc.Job.TaskGroups[0].Tasks[0]
675+
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
676+
t.Cleanup(cleanup)
677+
tr, err := NewTaskRunner(conf)
678+
must.NoError(t, err)
679+
680+
// side quest: set this lifecycle coordination channel,
681+
// so early in tr MAIN, it doesn't randomly follow that route.
682+
// test config creates this already closed, but not so in real life.
683+
startCh := make(chan struct{})
684+
t.Cleanup(func() { close(startCh) })
685+
tr.startConditionMetCh = startCh
686+
687+
// function under test: should mark the task as failed and cancel kill context
688+
reason := "because i said so"
689+
tr.MarkFailedKill(reason)
690+
691+
// explicitly check kill context.
692+
select {
693+
case <-tr.killCtx.Done():
694+
default:
695+
t.Fatal("kill context should be done")
696+
}
697+
698+
// Run() should now follow the kill path.
699+
go tr.Run()
700+
701+
select { // it should finish up very quickly
702+
case <-tr.WaitCh():
703+
case <-time.After(time.Second):
704+
t.Error("task not killed (or not as fast as expected)")
705+
}
706+
707+
// check state for expected values and events
708+
state := tr.TaskState()
709+
710+
// this gets set directly by MarkFailedKill()
711+
test.True(t, state.Failed, test.Sprint("task should have failed"))
712+
// this is set in Run()
713+
test.Eq(t, structs.TaskStateDead, state.State, test.Sprint("task should be dead"))
714+
// reason "because i said so" should be a task event message
715+
foundMessages := make(map[string]bool)
716+
for _, event := range state.Events {
717+
foundMessages[event.DisplayMessage] = true
718+
}
719+
test.True(t, foundMessages[reason], test.Sprintf("expected '%s' in events: %#v", reason, foundMessages))
720+
}
721+
665722
// TestTaskRunner_TaskEnv_Interpolated asserts driver configurations are
666723
// interpolated.
667724
func TestTaskRunner_TaskEnv_Interpolated(t *testing.T) {

0 commit comments

Comments
 (0)