Skip to content

Commit

Permalink
feat: add interfaces for task runner and basic test to piggyback off …
Browse files Browse the repository at this point in the history
…this
  • Loading branch information
liamstevens committed Feb 7, 2025
1 parent b6d57d1 commit 470f2a8
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 48 deletions.
4 changes: 2 additions & 2 deletions src/aws/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type EcsClientAPI interface {
DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error)
}

type ecsWaiterAPI interface {
type EcsWaiterAPI interface {
WaitForOutput(ctx context.Context, params *ecs.DescribeTasksInput, maxWaitDur time.Duration, optFns ...func(*ecs.TasksStoppedWaiterOptions)) (*ecs.DescribeTasksOutput, error)
}

Expand Down Expand Up @@ -59,7 +59,7 @@ func SubmitTask(ctx context.Context, ecsAPI EcsClientAPI, input *TaskRunnerConfi
return *response.Tasks[0].TaskArn, nil
}

func WaitForCompletion(ctx context.Context, waiter ecsWaiterAPI, taskArn string, timeOut int) (*ecs.DescribeTasksOutput, error) {
func WaitForCompletion(ctx context.Context, waiter EcsWaiterAPI, taskArn string, timeOut int) (*ecs.DescribeTasksOutput, error) {
cluster := ClusterFromTaskArn(taskArn)

maxWaitDuration := time.Duration(timeOut) * time.Second
Expand Down
53 changes: 33 additions & 20 deletions src/aws/ecs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,16 +322,23 @@ func TestFindLogStreamFromTaskNegative(t *testing.T) {
// to allow thing to finish in the background. The return value is used only for when a task fails, and we push
// this to a log.
func TestWaitForCompletion(t *testing.T) {
mockedWaiter := mockECSWaiter{
mockWaitForOutput: func(context.Context, *ecs.DescribeTasksInput, time.Duration, ...func(*ecs.TasksStoppedWaiterOptions)) (*ecs.DescribeTasksOutput, error) {
return &ecs.DescribeTasksOutput{
Failures: []types.Failure{
{
Arn: aws.String("arn:aws:ecs:us-west-2:123456789012:task/test-cluster/07cc583696bd44e0be450bff7314ddaf"),
Detail: aws.String("task stopped"),
Reason: aws.String("computer is full of beanz"),
},
}}, errors.New("task stopped: computer is full of beanz")
mockedWaiter := map[string]mockECSWaiter{
"beans": {
mockWaitForOutput: func(context.Context, *ecs.DescribeTasksInput, time.Duration, ...func(*ecs.TasksStoppedWaiterOptions)) (*ecs.DescribeTasksOutput, error) {
return &ecs.DescribeTasksOutput{
Failures: []types.Failure{
{
Arn: aws.String("arn:aws:ecs:us-west-2:123456789012:task/test-cluster/07cc583696bd44e0be450bff7314ddaf"),
Detail: aws.String("task stopped"),
Reason: aws.String("computer is full of beanz"),
},
}}, nil
},
},
"slowpoke": {
mockWaitForOutput: func(context.Context, *ecs.DescribeTasksInput, time.Duration, ...func(*ecs.TasksStoppedWaiterOptions)) (*ecs.DescribeTasksOutput, error) {
return nil, errors.New("task timed out: computer is full of beanz")
},
},
}

Expand All @@ -344,35 +351,41 @@ func TestWaitForCompletion(t *testing.T) {
tests := []struct {
name string
input string
waiter ecsWaiterAPI
waiter EcsWaiterAPI
expected expectedReturn
}{
{
name: "given a task ARN, it should return the task details",
input: "arn:aws:ecs:us-west-2:123456789012:task/test-cluster/07cc583696bd44e0be450bff7314ddaf",
waiter: mockedWaiter,
waiter: mockedWaiter["beans"],
expected: expectedReturn{&ecs.DescribeTasksOutput{
Failures: []types.Failure{
{
Arn: aws.String("arn:aws:ecs:us-west-2:123456789012:task/test-cluster/07cc583696bd44e0be450bff7314ddaf"),
Detail: aws.String("task stopped"),
Reason: aws.String("computer is full of beanz"),
},
}}, errors.New("task stopped: computer is full of beanz"),
}}, nil,
},
},
{
name: "given a task that times out, it should return an error",
input: "arn:aws:ecs:us-west-2:123456789012:task/test-cluster/07cc583696bd44e0be450bff7314ddaf",
waiter: mockedWaiter["slowpoke"],
expected: expectedReturn{nil, errors.New("task timed out: computer is full of beanz")},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result, err := WaitForCompletion(context.TODO(), tc.waiter, tc.input, 15)
t.Logf("result: '%v'", err)
t.Logf("expected: detail: %v, reason: %v", *tc.expected.Failures[0].Detail, *tc.expected.Failures[0].Reason)

// The function is most-useful when the underlying task fails. i.e. no news is good news in a real-world scenario
// So, we will test the failure cases
require.Error(t, err)
assert.Equal(t, tc.expected.Failures[0], result.Failures[0])
t.Logf("name: %s result: '%v'", tc.name, err)
// Errors are only returned when the waiter times out
if err != nil {
require.Equal(t, tc.expected.Error(), err.Error())
} else {
require.Equal(t, tc.expected.Failures, result.Failures)
}
})
}
}
6 changes: 5 additions & 1 deletion src/buildkite/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ import (
osexec "golang.org/x/sys/execabs"
)

type AgentAPI interface {
Annotate(ctx context.Context, message string, style string, annotationContext string) error
}

type Agent struct {
}

func (a *Agent) Annotate(ctx context.Context, message string, style string, annotationContext string) error {
func (a Agent) Annotate(ctx context.Context, message string, style string, annotationContext string) error {
return execCmd(ctx, "buildkite-agent", &message, "annotate", "--style", style, "--context", annotationContext)
}

Expand Down
3 changes: 2 additions & 1 deletion src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"os"

awsinternal "github.com/cultureamp/ecs-task-runner-buildkite-plugin/aws"
"github.com/cultureamp/ecs-task-runner-buildkite-plugin/buildkite"
"github.com/cultureamp/ecs-task-runner-buildkite-plugin/plugin"
)
Expand All @@ -13,7 +14,7 @@ func main() {
fetcher := plugin.EnvironmentConfigFetcher{}
taskRunnerPlugin := plugin.TaskRunnerPlugin{}

err := taskRunnerPlugin.Run(ctx, fetcher)
err := taskRunnerPlugin.Run(ctx, fetcher, awsinternal.WaitForCompletion)

if err != nil {
buildkite.LogFailuref("plugin execution failed: %s\n", err.Error())
Expand Down
61 changes: 37 additions & 24 deletions src/plugin/task-runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"

awsinternal "github.com/cultureamp/ecs-task-runner-buildkite-plugin/aws"
Expand All @@ -18,13 +19,14 @@ import (
type TaskRunnerPlugin struct {
}

type WaitForCompletion func(ctx context.Context, waiter awsinternal.EcsWaiterAPI, taskArn string, timeOut int) (*ecs.DescribeTasksOutput, error)
type ConfigFetcher interface {
Fetch(config *Config) error
}

func (trp TaskRunnerPlugin) Run(ctx context.Context, fetcher ConfigFetcher) error {
func (trp TaskRunnerPlugin) Run(ctx context.Context, fetcher ConfigFetcher, waiter WaitForCompletion) error {
var config Config
timeoutError := errors.New("exceeded max wait time for TasksStopped waiter")

err := fetcher.Fetch(&config)
if err != nil {
return fmt.Errorf("plugin configuration error: %w", err)
Expand Down Expand Up @@ -60,29 +62,10 @@ func (trp TaskRunnerPlugin) Run(ctx context.Context, fetcher ConfigFetcher) erro
// TODO: This is currently a magic number. If we want this to be configurable, remove the nolint directive and fix it up
o.MaxDelay = 10 * time.Second //nolint:mnd
})
result, err := awsinternal.WaitForCompletion(ctx, waiterClient, taskArn, config.TimeOut)
result, err := waiter(ctx, waiterClient, taskArn, config.TimeOut)
err = trp.HandleResults(ctx, result, err, buildKiteAgent, config)
if err != nil {
if errors.Is(err, timeoutError) {
err := buildKiteAgent.Annotate(ctx, fmt.Sprintf("Task did not complete successfully within timeout (%d seconds)", config.TimeOut), "error", "ecs-task-runner")
if err != nil {
return fmt.Errorf("failed to annotate buildkite with task timeout failure: %w", err)
}
}
bkerr := buildKiteAgent.Annotate(ctx, fmt.Sprintf("failed to wait for task completion: %v\n", err), "error", "ecs-task-runner")
if bkerr != nil {
return fmt.Errorf("failed to annotate buildkite with task wait failure: %w, annotation error: %w", err, bkerr)
}
} else if len(result.Failures) > 0 {
// There is still a scenario where the task could return failures but this isn't handled by the waiter
// This is due to the waiter only returning errors in scenarios where there are issues querying the task
// or scheduling the task. For a list of the Failures that can be returned in this case, see:
// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/api_failures_messages.html
// specifically, under the `DescribeTasks` API.
err := buildKiteAgent.Annotate(ctx, fmt.Sprintf("Task did not complete successfully: %v", result.Failures[0]), "error", "ecs-task-runner")
if err != nil {
return fmt.Errorf("failed to annotate buildkite with task failure: %w", err)
}
return fmt.Errorf("task did not complete successfully: %v", result.Failures[0])
return fmt.Errorf("failed to handle task results: %w", err)
}

// In a successful scenario for task completion, we would have a `tasks` slice with a single element
Expand Down Expand Up @@ -124,3 +107,33 @@ func (trp TaskRunnerPlugin) Run(ctx context.Context, fetcher ConfigFetcher) erro
buildkite.Log("done. \n")
return nil
}

func (trp TaskRunnerPlugin) HandleResults(ctx context.Context, output *ecs.DescribeTasksOutput, err error, bkAgent buildkite.AgentAPI, config Config) error {
if err != nil {
// This comparison is hacky, but is the only way that I could get the wrapped errors surfaced
// from the AWS library to be properly handled. It would be better if this was done using errors.As
if strings.Contains(err.Error(), "exceeded max wait time for TasksStopped waiter") {
err := bkAgent.Annotate(ctx, fmt.Sprintf("Task did not complete successfully within timeout (%d seconds)", config.TimeOut), "error", "ecs-task-runner")
if err != nil {
return fmt.Errorf("failed to annotate buildkite with task timeout failure: %w", err)
}
return errors.New("task did not complete within the time limit")
}
bkerr := bkAgent.Annotate(ctx, fmt.Sprintf("failed to wait for task completion: %v\n", err), "error", "ecs-task-runner")
if bkerr != nil {
return fmt.Errorf("failed to annotate buildkite with task wait failure: %w, annotation error: %w", err, bkerr)
}
} else if len(output.Failures) > 0 {
// There is still a scenario where the task could return failures but this isn't handled by the waiter
// This is due to the waiter only returning errors in scenarios where there are issues querying the task
// or scheduling the task. For a list of the Failures that can be returned in this case, see:
// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/api_failures_messages.html
// specifically, under the `DescribeTasks` API.
err := bkAgent.Annotate(ctx, fmt.Sprintf("Task did not complete successfully: %v", output.Failures[0]), "error", "ecs-task-runner")
if err != nil {
return fmt.Errorf("failed to annotate buildkite with task failure: %w", err)
}
return fmt.Errorf("task did not complete successfully: %v", output.Failures[0])
}
return nil
}
117 changes: 117 additions & 0 deletions src/plugin/task-runner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package plugin_test

import (
"context"
"errors"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/ecs/types"
awsinternal "github.com/cultureamp/ecs-task-runner-buildkite-plugin/aws"
"github.com/cultureamp/ecs-task-runner-buildkite-plugin/plugin"
"github.com/stretchr/testify/require"
)

type MockBuildKiteAgent struct{}

func (m MockBuildKiteAgent) Annotate(ctx context.Context, message string, style string, annotationContext string) error {
return nil
}

func TestRunPluginResponse(t *testing.T) {
buildKiteAgent := MockBuildKiteAgent{}
t.Setenv("BUILDKITE_PLUGIN_ECS_TASK_RUNNER_PARAMETER_NAME", "test-parameter")
t.Setenv("BUILDKITE_PLUGIN_ECS_TASK_RUNNER_SCRIPT", "hello-world")
t.Setenv("BUILDKITE_PLUGIN_ECS_TASK_RUNNER_TIME_OUT", "15")
mockFetcher := plugin.EnvironmentConfigFetcher{}
var config plugin.Config
err := mockFetcher.Fetch(&config)
require.NoError(t, err)

mockContainers := map[string]types.Container{
"success": {
ExitCode: aws.Int32(0),
Image: aws.String("nginx"),
Name: aws.String("gateway"),
Reason: aws.String("Gracefully Terminated"),
},
"failed": {
ExitCode: aws.Int32(1),
Image: aws.String("nginx"),
Name: aws.String("gateway"),
Reason: aws.String("Panicked"),
},
"running": {
Image: aws.String("nginx"),
Name: aws.String("gateway"),
Reason: aws.String("Panicked"),
},
}

mockResponses := map[string]plugin.WaitForCompletion{
"success": func(ctx context.Context, waiter awsinternal.EcsWaiterAPI, taskArn string, timeOut int) (*ecs.DescribeTasksOutput, error) {
return &ecs.DescribeTasksOutput{
Tasks: []types.Task{{
Containers: []types.Container{
mockContainers["success"],
},
LastStatus: aws.String("STOPPED"),
},
},
}, nil
},
"failed": func(ctx context.Context, waiter awsinternal.EcsWaiterAPI, taskArn string, timeOut int) (*ecs.DescribeTasksOutput, error) {
return &ecs.DescribeTasksOutput{
Tasks: []types.Task{{
Containers: []types.Container{
mockContainers["failed"],
},
LastStatus: aws.String("STOPPED"),
},
},
Failures: []types.Failure{
{
Arn: aws.String("test-task-arn"),
Reason: aws.String("Panicked"),
Detail: aws.String("Container gateway panicked with non-zero exit code 1"),
},
},
}, nil
},
"running": func(ctx context.Context, waiter awsinternal.EcsWaiterAPI, taskArn string, timeOut int) (*ecs.DescribeTasksOutput, error) {
return &ecs.DescribeTasksOutput{
Tasks: []types.Task{{
Containers: []types.Container{
mockContainers["running"],
},
LastStatus: aws.String("RUNNING"),
},
},
}, errors.New("exceeded max wait time for TasksStopped waiter")
},
}

expectedString := map[string]string{
"success": "",
"failed": "task did not complete successfully",
"running": "task did not complete within the time limit",
}
// expectedError := map[string]error{
// "success": nil,
// "failed": errors.New(expectedString["failed"]),
// "running": errors.New(expectedString["running"]),
// }

for name, mockResponse := range mockResponses {
t.Run(name, func(t *testing.T) {
result, err := mockResponse(context.TODO(), nil, "test-task-arn", 15)
plugin := plugin.TaskRunnerPlugin{}
err = plugin.HandleResults(context.TODO(), result, err, buildKiteAgent, config)
if err != nil {
require.ErrorContains(t, err, expectedString[name])
t.Logf("expected: %v, actual: %v", expectedString[name], err)
}
})
}
}

0 comments on commit 470f2a8

Please sign in to comment.