Skip to content

Commit

Permalink
Merge pull request #9 from lyft/fix-unmarshal
Browse files Browse the repository at this point in the history
Add validators to where we call unmarshal func
  • Loading branch information
EngHabu authored Sep 11, 2019
2 parents aea1f39 + a4b807e commit e8dbad8
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 13 deletions.
25 changes: 23 additions & 2 deletions go/tasks/v1/k8splugins/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,35 @@ func validateAndFinalizeContainers(
return &pod, nil
}

func validateSidecarJob(sidecarJob *plugins.SidecarJob) error {
if sidecarJob == nil {
return fmt.Errorf("empty sidecarjob")
}

if sidecarJob.PodSpec == nil {
return fmt.Errorf("empty podspec")
}

if len(sidecarJob.PodSpec.Containers) == 0 {
return fmt.Errorf("empty containers")
}

return nil
}

func (sidecarResourceHandler) BuildResource(
ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (
flytek8s.K8sResource, error) {
sidecarJob := plugins.SidecarJob{}
err := utils.UnmarshalStruct(task.GetCustom(), &sidecarJob)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error())
return nil, errors.Wrapf(errors.BadTaskSpecification, err,
"invalid TaskSpecification [%v], failed to unmarshal", task.GetCustom())
}

if err = validateSidecarJob(&sidecarJob); err != nil {
return nil, errors.Wrapf(errors.BadTaskSpecification, err,
"invalid TaskSpecification [%v]", task.GetCustom())
}

pod := flytek8s.BuildPodWithSpec(sidecarJob.PodSpec)
Expand Down
20 changes: 18 additions & 2 deletions go/tasks/v1/k8splugins/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,29 @@ func setSparkConfig(cfg *SparkConfig) error {
type sparkResourceHandler struct {
}

func validateSparkJob(sparkJob *plugins.SparkJob) error {
if sparkJob == nil {
return fmt.Errorf("empty sparkJob")
}

if len(sparkJob.MainApplicationFile) == 0 && len(sparkJob.MainClass) == 0 {
return fmt.Errorf("either MainApplicationFile or MainClass must be set")
}

return nil
}

// Creates a new Job that will execute the main container as well as any generated types the result from the execution.
func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (flytek8s.K8sResource, error) {

sparkJob := plugins.SparkJob{}
err := utils.UnmarshalStruct(task.GetCustom(), &sparkJob)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error())
return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v], failed to unmarshal", task.GetCustom())
}

if err = validateSparkJob(&sparkJob); err != nil {
return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v].", task.GetCustom())
}

annotations := flytek8s.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.GetAnnotations()))
Expand Down Expand Up @@ -147,7 +163,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.Tas
HadoopConf: sparkJob.GetHadoopConf(),
// SubmissionFailures handled here. Task Failures handled at Propeller/Job level.
RestartPolicy: sparkOp.RestartPolicy{
Type: sparkOp.OnFailure,
Type: sparkOp.OnFailure,
OnSubmissionFailureRetries: &submissionFailureRetries,
},
},
Expand Down
25 changes: 25 additions & 0 deletions go/tasks/v1/k8splugins/waitable_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ func discoverWaitableInputs(l *core.Literal) (literals []*core.Literal, waitable
return []*core.Literal{}, []*waitableWrapper{}
}

if err = validateWaitable(waitable); err != nil {
// skip, it's just a different type?
return []*core.Literal{}, []*waitableWrapper{}
}

return []*core.Literal{l}, []*waitableWrapper{{Waitable: waitable}}
}
}
Expand Down Expand Up @@ -309,6 +314,22 @@ func (w waitableTaskExecutor) getUpdatedWaitables(ctx context.Context, taskCtx t
return updatedWaitables, allDone, hasChanged, nil
}

func validateWaitable(waitable *plugins.Waitable) error {
if waitable == nil {
return fmt.Errorf("empty waitable")
}

if waitable.WfExecId == nil {
return fmt.Errorf("empty executionID")
}

if len(waitable.WfExecId.Name) == 0 {
return fmt.Errorf("empty executionID Name")
}

return nil
}

func updateWaitableLiterals(literals []*core.Literal, waitables []*waitableWrapper) error {
index := make(map[string]*plugins.Waitable, len(waitables))
for _, w := range waitables {
Expand All @@ -321,6 +342,10 @@ func updateWaitableLiterals(literals []*core.Literal, waitables []*waitableWrapp
return err
}

if err := validateWaitable(orig); err != nil {
return err
}

newW, found := index[orig.WfExecId.String()]
if !found {
return fmt.Errorf("couldn't find a waitable corresponding to literal WfID: %v", orig.WfExecId.String())
Expand Down
24 changes: 21 additions & 3 deletions go/tasks/v1/qubole/hive_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package qubole
import (
"context"
"fmt"
"github.com/go-redis/redis"
"strconv"
"time"

"github.com/go-redis/redis"

eventErrors "github.com/lyft/flyteidl/clients/go/events/errors"
"github.com/lyft/flyteplugins/go/tasks/v1/events"

Expand Down Expand Up @@ -126,6 +127,18 @@ func (h HiveExecutor) getUniqueCacheKey(taskCtx types.TaskContext, idx int) stri
return fmt.Sprintf("%s_%d", taskCtx.GetTaskExecutionID().GetGeneratedName(), idx)
}

func validateQuboleHiveJob(job *plugins.QuboleHiveJob) error {
if job == nil {
return fmt.Errorf("empty job")
}

if job.Query == nil && job.QueryCollection == nil {
return fmt.Errorf("either query or queryCollection must be set")
}

return nil
}

// This function is only ever called once, assuming it doesn't return in error.
// Essentially, what this function does is translate the task's custom field into the TaskContext's CustomState
// that's stored back into etcd
Expand All @@ -135,8 +148,13 @@ func (h HiveExecutor) StartTask(ctx context.Context, taskCtx types.TaskContext,
hiveJob := plugins.QuboleHiveJob{}
err := utils.UnmarshalStruct(task.GetCustom(), &hiveJob)
if err != nil {
return types.TaskStatusPermanentFailure(errors.Errorf(errors.BadTaskSpecification,
"Invalid Job Specification in task: [%v]. Err: [%v]", task.GetCustom(), err)), nil
return types.TaskStatusPermanentFailure(errors.Wrapf(errors.BadTaskSpecification, err,
"Invalid Job Specification in task: [%v], failed to unmarshal", task.GetCustom())), nil
}

if err = validateQuboleHiveJob(&hiveJob); err != nil {
return types.TaskStatusPermanentFailure(errors.Wrapf(errors.BadTaskSpecification, err,
"Invalid Job Specification in task: [%v]", task.GetCustom())), nil
}

// TODO: Asserts around queries, like len > 0 or something.
Expand Down
14 changes: 8 additions & 6 deletions go/tasks/v1/qubole/qubole_work_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package qubole

import (
"encoding/json"
"strings"
"testing"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
tasksMocks "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks"
"github.com/stretchr/testify/assert"
"strings"
"testing"
)

func getMockTaskContext() *tasksMocks.TaskContext {
Expand All @@ -31,6 +32,7 @@ func TestConstructEventInfoFromQuboleWorkItems(t *testing.T) {
Status: QuboleWorkSucceeded,
ClusterLabel: "default",
Tags: []string{},
CommandUri: "https://api.qubole.com/command/",
},
}

Expand Down Expand Up @@ -156,12 +158,12 @@ func TestInterfaceConverter(t *testing.T) {
// This is a complicated step to reproduce what will ultimately be given to the function at runtime, the values
// inside the CustomState
item := QuboleWorkItem{
Status: QuboleWorkRunning,
CommandId: "123456",
Query: "",
Status: QuboleWorkRunning,
CommandId: "123456",
Query: "",
UniqueWorkCacheKey: "fjdsakfjd",
}
raw, err := json.Marshal(map[string]interface{}{"":item})
raw, err := json.Marshal(map[string]interface{}{"": item})
assert.NoError(t, err)

// We can't unmarshal into a interface{} but we can unmarhsal into a interface{} if it's the value of a map.
Expand Down

0 comments on commit e8dbad8

Please sign in to comment.