package taskrunner

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/hashicorp/consul/api"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
	tinterfaces "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
	"github.com/hashicorp/nomad/client/consul"
	"github.com/hashicorp/nomad/client/taskenv"
	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
	"github.com/hashicorp/nomad/nomad/structs"
)

var _ interfaces.TaskPoststartHook = &scriptCheckHook{}
var _ interfaces.TaskUpdateHook = &scriptCheckHook{}
var _ interfaces.TaskStopHook = &scriptCheckHook{}

// default max amount of time to wait for all scripts on shutdown.
const defaultShutdownWait = time.Minute

type scriptCheckHookConfig struct {
	alloc        *structs.Allocation
	task         *structs.Task
	consul       consul.ConsulServiceAPI
	logger       log.Logger
	shutdownWait time.Duration
}

// scriptCheckHook implements a task runner hook for running script
// checks in the context of a task.
type scriptCheckHook struct {
	consul       consul.ConsulServiceAPI
	allocID      string
	taskName     string
	logger       log.Logger
	shutdownWait time.Duration // max time to wait for scripts to shut down
	shutdownCh   chan struct{} // closed when all scripts should shut down

	// The following fields can be changed by Update()
	driverExec tinterfaces.ScriptExecutor
	taskEnv    *taskenv.TaskEnv

	// These maintain state
	scripts        map[string]*scriptCheck
	runningScripts map[string]*taskletHandle

	// Since Update() may be called concurrently with any other hook, all
	// hook methods must be fully serialized.
	mu sync.Mutex
}

func newScriptCheckHook(c scriptCheckHookConfig) *scriptCheckHook {
	scriptChecks := make(map[string]*scriptCheck)
	for _, service := range c.task.Services {
		for _, check := range service.Checks {
			if check.Type != structs.ServiceCheckScript {
				continue
			}
			sc := newScriptCheck(&scriptCheckConfig{
				allocID:  c.alloc.ID,
				taskName: c.task.Name,
				check:    check,
				service:  service,
				agent:    c.consul,
			})
			scriptChecks[sc.id] = sc
		}
	}

	// Walk back through the task group to see if there are script checks
	// associated with the task. If so, we'll create scriptCheck tasklets
	// for them. The group-level service and any check restart behaviors it
	// needs are entirely encapsulated within the group service hook which
	// watches Consul for status changes.
	tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup)
	for _, service := range tg.Services {
		for _, check := range service.Checks {
			if check.Type != structs.ServiceCheckScript {
				continue
			}
			if check.TaskName != c.task.Name {
				continue
			}
			groupTaskName := "group-" + tg.Name
			sc := newScriptCheck(&scriptCheckConfig{
				allocID:  c.alloc.ID,
				taskName: groupTaskName,
				service:  service,
				check:    check,
				agent:    c.consul,
			})
			scriptChecks[sc.id] = sc
		}
	}

	h := &scriptCheckHook{
		consul:         c.consul,
		allocID:        c.alloc.ID,
		taskName:       c.task.Name,
		scripts:        scriptChecks,
		runningScripts: make(map[string]*taskletHandle),
		shutdownWait:   defaultShutdownWait,
		shutdownCh:     make(chan struct{}),
	}

	if c.shutdownWait != 0 {
		h.shutdownWait = c.shutdownWait // override for testing
	}
	h.logger = c.logger.Named(h.Name())
	return h
}
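
// A minimal usage sketch (hypothetical; alloc, task, consulClient, and
// logger are assumed to be supplied by the surrounding task runner and are
// not defined in this file):
//
//	hook := newScriptCheckHook(scriptCheckHookConfig{
//		alloc:  alloc,
//		task:   task,
//		consul: consulClient,
//		logger: logger,
//	})
//
// Leaving shutdownWait zero falls back to defaultShutdownWait above.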

func (h *scriptCheckHook) Name() string {
	return "script_checks"
}

// Poststart implements interfaces.TaskPoststartHook. It adds the current
// task context (driver and env) to the script checks and starts up the
// scripts.
func (h *scriptCheckHook) Poststart(ctx context.Context, req *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if req.DriverExec == nil {
		return fmt.Errorf("driver doesn't support script checks")
	}

	// Store the TaskEnv for interpolating now and during Update()
	h.driverExec = req.DriverExec
	h.taskEnv = req.TaskEnv
	h.scripts = h.getTaskScriptChecks()

	// Handle starting scripts
	for checkID, script := range h.scripts {
		// If it's already running, cancel and replace
		if oldScript, running := h.runningScripts[checkID]; running {
			oldScript.cancel()
		}
		// Start and store the handle
		h.runningScripts[checkID] = script.run()
	}
	return nil
}

// Update implements interfaces.TaskUpdateHook. It rebuilds the script
// checks with the updated task environment, restarts the checks that are
// still present, and cancels any that have been removed.
func (h *scriptCheckHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequest, _ *interfaces.TaskUpdateResponse) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Get the script checks as they exist before the update; the driver
	// exec handle can't change across updates, so it's reused as-is.
	oldScriptChecks := h.getTaskScriptChecks()

	task := req.Alloc.LookupTask(h.taskName)
	if task == nil {
		return fmt.Errorf("task %q not found in updated alloc", h.taskName)
	}

	// Update the task environment
	h.taskEnv = req.TaskEnv

	// Create new script checks with those new values
	newScriptChecks := h.getTaskScriptChecks()

	// Handle starting scripts
	for checkID, script := range newScriptChecks {
		if _, ok := oldScriptChecks[checkID]; ok {
			// If it's already running, cancel and replace
			if oldScript, running := h.runningScripts[checkID]; running {
				oldScript.cancel()
			}
			// Start and store the handle
			h.runningScripts[checkID] = script.run()
		}
	}

	// Cancel scripts we no longer want
	for checkID := range oldScriptChecks {
		if _, ok := newScriptChecks[checkID]; !ok {
			if oldScript, running := h.runningScripts[checkID]; running {
				oldScript.cancel()
			}
		}
	}
	return nil
}

// Stop implements interfaces.TaskStopHook and blocks waiting for running
// scripts to finish (or for the shutdownWait timeout to expire).
func (h *scriptCheckHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
	h.mu.Lock()
	defer h.mu.Unlock()
	close(h.shutdownCh)
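	// One deadline is shared across all running scripts, so the total
	// wait is bounded by shutdownWait rather than shutdownWait per script.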
	deadline := time.After(h.shutdownWait)
	err := fmt.Errorf("timed out waiting for script checks to exit")
	for _, script := range h.runningScripts {
		select {
		case <-script.wait():
		case <-ctx.Done():
			// the caller is passing the background context, so
			// we should never really see this outside of testing
		case <-deadline:
			// at this point the Consul client has been cleaned
			// up so we don't want to hang onto this.
			return err
		}
	}
	return nil
}

// getTaskScriptChecks returns an interpolated copy of services and checks
// with values from the task's environment.
func (h *scriptCheckHook) getTaskScriptChecks() map[string]*scriptCheck {
	// Guard against not having a valid taskEnv. This can be the case if
	// Update is run before Poststart.
	if h.taskEnv == nil || h.driverExec == nil {
		return nil
	}
	newChecks := make(map[string]*scriptCheck)
	for _, orig := range h.scripts {
		sc := orig.Copy()
		sc.exec = h.driverExec
		sc.logger = h.logger
		sc.shutdownCh = h.shutdownCh
		sc.callback = newScriptCheckCallback(sc)
		sc.Command = h.taskEnv.ReplaceEnv(orig.Command)
		sc.Args = h.taskEnv.ParseAndReplace(orig.Args)
		newChecks[sc.id] = sc
	}
	return newChecks
}
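
// For illustration (hypothetical values, not from this change): a check
// configured with Command = "/usr/local/bin/check-${NOMAD_TASK_NAME}" would
// be interpolated above to "/usr/local/bin/check-redis" for a task named
// "redis".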

// heartbeater is the subset of consul agent functionality needed by script
// checks to heartbeat
type heartbeater interface {
	UpdateTTL(id, output, status string) error
}
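
// noopHeartbeater is a minimal sketch of a heartbeater implementation (an
// illustrative addition, not part of the original change); a test double
// like this is enough to exercise script checks without a running Consul
// agent.
type noopHeartbeater struct{}

// UpdateTTL discards the heartbeat, satisfying the heartbeater interface.
func (noopHeartbeater) UpdateTTL(id, output, status string) error { return nil }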

// scriptCheck runs script checks via an interfaces.ScriptExecutor and
// updates the appropriate check's TTL when the script succeeds.
type scriptCheck struct {
	id          string
	agent       heartbeater
	lastCheckOk bool // true if the last check was ok; otherwise false
	tasklet
}

// scriptCheckConfig is a parameter struct for newScriptCheck
type scriptCheckConfig struct {
	allocID  string
	taskName string
	service  *structs.Service
	check    *structs.ServiceCheck
	agent    heartbeater
}

// newScriptCheck constructs a scriptCheck. Only the immutable fields are
// configured here; the rest are configured during the Poststart hook, once
// the task execution environment is available.
func newScriptCheck(config *scriptCheckConfig) *scriptCheck {
	serviceID := agentconsul.MakeTaskServiceID(
		config.allocID, config.taskName, config.service)
	checkID := agentconsul.MakeCheckID(serviceID, config.check)

	sc := &scriptCheck{
		id:          checkID,
		agent:       config.agent,
		lastCheckOk: true, // start logging on first failure
	}
	// we can't use the promoted fields of tasklet in the struct literal
	sc.Command = config.check.Command
	sc.Args = config.check.Args
	sc.Interval = config.check.Interval
	sc.Timeout = config.check.Timeout
	return sc
}

// Copy does a *shallow* copy of script checks.
func (sc *scriptCheck) Copy() *scriptCheck {
	newSc := new(scriptCheck)
	*newSc = *sc
	return newSc
}
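
// The shallow copy shares slice and interface values with the original;
// that's sufficient here because getTaskScriptChecks immediately overwrites
// the mutable fields (exec, logger, shutdownCh, callback, Command, Args) on
// the copy.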

// newScriptCheckCallback closes over the script check and returns the
// taskletCallback invoked each time the script check executes.
func newScriptCheckCallback(s *scriptCheck) taskletCallback {

	return func(ctx context.Context, params execResult) {
		output := params.output
		code := params.code
		err := params.err

		state := api.HealthCritical
		switch code {
		case 0:
			state = api.HealthPassing
		case 1:
			state = api.HealthWarning
		}

		var outputMsg string
		if err != nil {
			state = api.HealthCritical
			outputMsg = err.Error()
		} else {
			outputMsg = string(output)
		}

		// heartbeat the check to Consul
		err = s.updateTTL(ctx, s.id, outputMsg, state)
		select {
		case <-ctx.Done():
			// check has been removed; don't report errors
			return
		default:
		}

		if err != nil {
			if s.lastCheckOk {
				s.lastCheckOk = false
				s.logger.Warn("updating check failed", "error", err)
			} else {
				s.logger.Debug("updating check still failing", "error", err)
			}

		} else if !s.lastCheckOk {
			// Succeeded for the first time or after failing; log
			s.lastCheckOk = true
			s.logger.Info("updating check succeeded")
		}
	}
}
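
// For reference, the callback above maps script results to Consul health
// states following the script-check convention:
//
//	exit 0 -> api.HealthPassing
//	exit 1 -> api.HealthWarning
//	other  -> api.HealthCritical (as is any execution error)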

const (
	updateTTLBackoffBaseline = 1 * time.Second
	updateTTLBackoffLimit    = 3 * time.Second
)

// updateTTL updates the check's TTL state in Consul, performing an
// exponential backoff in the case where the check isn't registered in
// Consul, to avoid a race between service registration and the first check.
func (s *scriptCheck) updateTTL(ctx context.Context, id, msg, state string) error {
	for attempts := 0; ; attempts++ {
		err := s.agent.UpdateTTL(id, msg, state)
		if err == nil {
			return nil
		}

		// Handle the retry case
		backoff := (1 << (2 * uint64(attempts))) * updateTTLBackoffBaseline
		if backoff > updateTTLBackoffLimit {
			return err
		}

		// Wait before retrying
		select {
		case <-ctx.Done():
			return err
		case <-time.After(backoff):
		}
	}
}
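
// A worked example of the backoff schedule above, assuming every UpdateTTL
// call fails:
//
//	attempt 0: backoff = (1 << 0) * 1s = 1s  -> wait 1s, retry
//	attempt 1: backoff = (1 << 2) * 1s = 4s  -> exceeds the 3s limit, return err
//
// so an unregistered check is retried once (two UpdateTTL calls) before the
// error is surfaced.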