From 2ad8d08f918e1ebb2c97081761cc9753c60921fb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 22 Apr 2023 07:16:28 +0800 Subject: [PATCH] External Plugin Service (grpc) (#330) * Add fastapi plugin Signed-off-by: Kevin Su * Add dummy plugin Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * test Signed-off-by: Kevin Su * test Signed-off-by: Kevin Su * test Signed-off-by: Kevin Su * test Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * grpc plugin Signed-off-by: Kevin Su * updated idl version Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * add grpc plugin Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * update idl Signed-off-by: Kevin Su * more tests Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * write output Signed-off-by: Kevin Su * remove prev state Signed-off-by: Kevin Su * bump idl Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * test Signed-off-by: Kevin Su * update idl Signed-off-by: Kevin Su * rename Signed-off-by: Kevin Su * more test Signed-off-by: Kevin Su * remove grpcTokenKey Signed-off-by: Kevin Su * Add SupportedTaskTypes Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * cache connection Signed-off-by: Kevin Su * more tests Signed-off-by: Kevin Su * fixes tests Signed-off-by: Kevin Su * fixes tests Signed-off-by: Kevin Su * more tests Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * remove bigquery_query_job_task Signed-off-by: Kevin Su * add bigquery_query_job_task Signed-off-by: Kevin Su * set random value for SupportedTaskTypes Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su --- flyteplugins/go.mod | 3 +- flyteplugins/go.sum | 5 +- .../go/tasks/plugins/webapi/grpc/config.go | 72 +++++ .../tasks/plugins/webapi/grpc/config_test.go | 17 ++ .../plugins/webapi/grpc/integration_test.go | 268 ++++++++++++++++++ .../go/tasks/plugins/webapi/grpc/plugin.go | 206 ++++++++++++++ .../tasks/plugins/webapi/grpc/plugin_test.go | 60 ++++ 7 files changed, 628 insertions(+), 3 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/webapi/grpc/config.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/grpc/config_test.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/grpc/integration_test.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/grpc/plugin.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/grpc/plugin_test.go diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index 4d54cbdeae..e5863add0d 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -12,7 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v1.3.14 + github.com/flyteorg/flyteidl v1.3.16 github.com/flyteorg/flytestdlib v1.0.15 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.2 @@ -86,6 +86,7 @@ require ( github.com/google/gofuzz v1.2.0 // indirect github.com/googleapis/gax-go/v2 v2.3.0 // indirect github.com/googleapis/go-type-adapters v1.0.0 // indirect + github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect diff --git a/flyteplugins/go.sum b/flyteplugins/go.sum index 179728decd..09a4ffa144 100644 --- a/flyteplugins/go.sum +++ b/flyteplugins/go.sum @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8= -github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= +github.com/flyteorg/flyteidl v1.3.16 h1:mRq1VeUl5LP12dezbGHLQcrLuAmO9kawK9X7arqCInM= +github.com/flyteorg/flyteidl v1.3.16/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= @@ -443,6 +443,7 @@ github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:Fecb github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= diff --git a/flyteplugins/go/tasks/plugins/webapi/grpc/config.go b/flyteplugins/go/tasks/plugins/webapi/grpc/config.go new file mode 100644 index 0000000000..2fef371a0e --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/grpc/config.go @@ -0,0 +1,72 @@ +package grpc + +import ( + "time" + + pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/config" +) + +var ( + defaultConfig = Config{ + WebAPI: webapi.PluginConfig{ + ResourceQuotas: map[core.ResourceNamespace]int{ + "default": 1000, + }, + ReadRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + WriteRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + Caching: webapi.CachingConfig{ + Size: 500000, + ResyncInterval: config.Duration{Duration: 30 * time.Second}, + Workers: 10, + MaxSystemFailures: 5, + }, + ResourceMeta: nil, + }, + ResourceConstraints: core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: &core.ResourceConstraint{ + Value: 100, + }, + NamespaceScopeResourceConstraint: &core.ResourceConstraint{ + Value: 50, + }, + }, + DefaultGrpcEndpoint: "dns:///external-plugin-service.flyte.svc.cluster.local:80", + SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, + } + + configSection = pluginsConfig.MustRegisterSubSection("external-plugin-service", &defaultConfig) +) + +// Config is config for 'databricks' plugin +type Config struct { + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` + + // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time + ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` + + DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of external plugin service."` + + // Maps endpoint to their plugin handler. {TaskType: Endpoint} + EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"` + + // SupportedTaskTypes is a list of task types that are supported by this plugin. + SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/grpc/config_test.go b/flyteplugins/go/tasks/plugins/webapi/grpc/config_test.go new file mode 100644 index 0000000000..9e994f07fb --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/grpc/config_test.go @@ -0,0 +1,17 @@ +package grpc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetAndSetConfig(t *testing.T) { + cfg := defaultConfig + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + err := SetConfig(&cfg) + assert.NoError(t, err) + assert.Equal(t, &cfg, GetConfig()) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/grpc/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/grpc/integration_test.go new file mode 100644 index 0000000000..a4f1d42846 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/grpc/integration_test.go @@ -0,0 +1,268 @@ +package grpc + +import ( + "context" + "encoding/json" + "fmt" + "sync/atomic" + "testing" + "time" + + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + + "github.com/flyteorg/flyteidl/clients/go/coreutils" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flyteplugins/tests" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + "github.com/flyteorg/flytestdlib/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + "k8s.io/apimachinery/pkg/util/rand" +) + +type MockPlugin struct { + Plugin +} + +type MockClient struct { +} + +func (m *MockClient) CreateTask(_ context.Context, _ *service.TaskCreateRequest, _ ...grpc.CallOption) (*service.TaskCreateResponse, error) { + return &service.TaskCreateResponse{JobId: "job-id"}, nil +} + +func (m *MockClient) GetTask(_ context.Context, _ *service.TaskGetRequest, _ ...grpc.CallOption) (*service.TaskGetResponse, error) { + return &service.TaskGetResponse{State: service.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), + }, + }}, nil +} + +func (m *MockClient) DeleteTask(_ context.Context, _ *service.TaskDeleteRequest, _ ...grpc.CallOption) (*service.TaskDeleteResponse, error) { + return &service.TaskDeleteResponse{}, nil +} + +func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) { + return &MockClient{}, nil +} + +func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) { + return nil, fmt.Errorf("error") +} + +func TestEndToEnd(t *testing.T) { + iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { + return nil + } + + cfg := defaultConfig + cfg.WebAPI.ResourceQuotas = map[core.ResourceNamespace]int{} + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + err := SetConfig(&cfg) + assert.NoError(t, err) + + databricksConfDict := map[string]interface{}{ + "name": "flytekit databricks plugin example", + "new_cluster": map[string]string{ + "spark_version": "11.0.x-scala2.12", + "node_type_id": "r3.xlarge", + "num_workers": "4", + }, + "timeout_seconds": 3600, + "max_retries": 1, + } + databricksConfig, err := utils.MarshalObjToStruct(databricksConfDict) + assert.NoError(t, err) + sparkJob := plugins.SparkJob{DatabricksConf: databricksConfig, DatabricksToken: "token", SparkConf: map[string]string{"spark.driver.bindAddress": "127.0.0.1"}} + st, err := utils.MarshalPbToStruct(&sparkJob) + assert.NoError(t, err) + + inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) + template := flyteIdlCore.TaskTemplate{ + Type: "bigquery_query_job_task", + Custom: st, + } + basePrefix := storage.DataReference("fake://bucket/prefix/") + + t.Run("run a job", func(t *testing.T) { + pluginEntry := pluginmachinery.CreateRemotePlugin(newMockGrpcPlugin()) + plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test1")) + assert.NoError(t, err) + + phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) + assert.Equal(t, true, phase.Phase().IsSuccess()) + }) + + t.Run("failed to create a job", func(t *testing.T) { + grpcPlugin := newMockGrpcPlugin() + grpcPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + return &MockPlugin{ + Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + getClient: mockGetBadClientFunc, + }, + }, nil + } + pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin) + plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test2")) + assert.NoError(t, err) + + tCtx := getTaskContext(t) + tr := &pluginCoreMocks.TaskReader{} + tr.OnRead(context.Background()).Return(&template, nil) + tCtx.OnTaskReader().Return(tr) + inputReader := &ioMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return(basePrefix) + inputReader.OnGetInputPath().Return(basePrefix + "/inputs.pb") + inputReader.OnGetMatch(mock.Anything).Return(inputs, nil) + tCtx.OnInputReader().Return(inputReader) + + trns, err := plugin.Handle(context.Background(), tCtx) + assert.Error(t, err) + assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined) + err = plugin.Abort(context.Background(), tCtx) + assert.Nil(t, err) + }) + + t.Run("failed to read task template", func(t *testing.T) { + tCtx := getTaskContext(t) + tr := &pluginCoreMocks.TaskReader{} + tr.OnRead(context.Background()).Return(nil, fmt.Errorf("read fail")) + tCtx.OnTaskReader().Return(tr) + + grpcPlugin := newMockGrpcPlugin() + pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin) + plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test3")) + assert.NoError(t, err) + + trns, err := plugin.Handle(context.Background(), tCtx) + assert.Error(t, err) + assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined) + }) + + t.Run("failed to read inputs", func(t *testing.T) { + tCtx := getTaskContext(t) + tr := &pluginCoreMocks.TaskReader{} + tr.OnRead(context.Background()).Return(&template, nil) + tCtx.OnTaskReader().Return(tr) + inputReader := &ioMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return(basePrefix) + inputReader.OnGetInputPath().Return(basePrefix + "/inputs.pb") + inputReader.OnGetMatch(mock.Anything).Return(nil, fmt.Errorf("read fail")) + tCtx.OnInputReader().Return(inputReader) + + grpcPlugin := newMockGrpcPlugin() + pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin) + plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test4")) + assert.NoError(t, err) + + trns, err := plugin.Handle(context.Background(), tCtx) + assert.Error(t, err) + assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined) + }) +} + +func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { + latestKnownState := atomic.Value{} + pluginStateReader := &pluginCoreMocks.PluginStateReader{} + pluginStateReader.OnGetMatch(mock.Anything).Return(0, nil).Run(func(args mock.Arguments) { + o := args.Get(0) + x, err := json.Marshal(latestKnownState.Load()) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(x, &o)) + }) + pluginStateWriter := &pluginCoreMocks.PluginStateWriter{} + pluginStateWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + latestKnownState.Store(args.Get(1)) + }) + + pluginStateWriter.OnReset().Return(nil).Run(func(args mock.Arguments) { + latestKnownState.Store(nil) + }) + + execID := rand.String(3) + tID := &pluginCoreMocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return(execID + "-my-task-1") + tID.OnGetID().Return(flyteIdlCore.TaskExecutionIdentifier{ + TaskId: &flyteIdlCore.Identifier{ + ResourceType: flyteIdlCore.ResourceType_TASK, + Project: "a", + Domain: "d", + Name: "n", + Version: "abc", + }, + NodeExecutionId: &flyteIdlCore.NodeExecutionIdentifier{ + NodeId: "node1", + ExecutionId: &flyteIdlCore.WorkflowExecutionIdentifier{ + Project: "a", + Domain: "d", + Name: "exec", + }, + }, + RetryAttempt: 0, + }) + tMeta := &pluginCoreMocks.TaskExecutionMetadata{} + tMeta.OnGetTaskExecutionID().Return(tID) + resourceManager := &pluginCoreMocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(pluginCore.AllocationStatusGranted, nil) + resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + basePrefix := storage.DataReference("fake://bucket/prefix/" + execID) + outputWriter := &ioMocks.OutputWriter{} + outputWriter.OnGetRawOutputPrefix().Return("/sandbox/") + outputWriter.OnGetOutputPrefixPath().Return(basePrefix) + outputWriter.OnGetErrorPath().Return(basePrefix + "/error.pb") + outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb") + outputWriter.OnGetCheckpointPrefix().Return("/checkpoint") + outputWriter.OnGetPreviousCheckpointsPrefix().Return("/prev") + + tCtx := &pluginCoreMocks.TaskExecutionContext{} + tCtx.OnOutputWriter().Return(outputWriter) + tCtx.OnResourceManager().Return(resourceManager) + tCtx.OnPluginStateReader().Return(pluginStateReader) + tCtx.OnPluginStateWriter().Return(pluginStateWriter) + tCtx.OnTaskExecutionMetadata().Return(tMeta) + return tCtx +} + +func newMockGrpcPlugin() webapi.PluginEntry { + return webapi.PluginEntry{ + ID: "external-plugin-service", + SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task"}, + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + return &MockPlugin{ + Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + getClient: mockGetClientFunc, + }, + }, nil + }, + } +} + +func newFakeSetupContext(name string) *pluginCoreMocks.SetupContext { + fakeResourceRegistrar := pluginCoreMocks.ResourceRegistrar{} + fakeResourceRegistrar.On("RegisterResourceQuota", mock.Anything, mock.Anything, mock.Anything).Return(nil) + labeled.SetMetricKeys(contextutils.NamespaceKey) + + fakeSetupContext := pluginCoreMocks.SetupContext{} + fakeSetupContext.OnMetricsScope().Return(promutils.NewScope(name)) + fakeSetupContext.OnResourceRegistrar().Return(&fakeResourceRegistrar) + + return &fakeSetupContext +} diff --git a/flyteplugins/go/tasks/plugins/webapi/grpc/plugin.go b/flyteplugins/go/tasks/plugins/webapi/grpc/plugin.go new file mode 100644 index 0000000000..b180fdb882 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/grpc/plugin.go @@ -0,0 +1,206 @@ +package grpc + +import ( + "context" + "encoding/gob" + "fmt" + + "google.golang.org/grpc/grpclog" + + flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/promutils" + "google.golang.org/grpc" +) + +type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) + +type Plugin struct { + metricScope promutils.Scope + cfg *Config + getClient GetClientFunc + connectionCache map[string]*grpc.ClientConn +} + +type ResourceWrapper struct { + State service.State + Outputs *flyteIdl.LiteralMap +} + +type ResourceMetaWrapper struct { + OutputPrefix string + Token string + JobID string + TaskType string +} + +func (p Plugin) GetConfig() webapi.PluginConfig { + return GetConfig().WebAPI +} + +func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( + namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { + + // Resource requirements are assumed to be the same. + return "default", p.cfg.ResourceConstraints, nil +} + +func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, + webapi.Resource, error) { + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, nil, err + } + inputs, err := taskCtx.InputReader().Get(ctx) + if err != nil { + return nil, nil, err + } + + outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() + + endpoint := getFinalEndpoint(taskTemplate.Type, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + client, err := p.getClient(ctx, endpoint, p.connectionCache) + if err != nil { + return nil, nil, fmt.Errorf("failed to connect external plugin service with error: %v", err) + } + + res, err := client.CreateTask(ctx, &service.TaskCreateRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix}) + if err != nil { + return nil, nil, err + } + + return &ResourceMetaWrapper{ + OutputPrefix: outputPrefix, + JobID: res.GetJobId(), + Token: "", + TaskType: taskTemplate.Type, + }, &ResourceWrapper{State: service.State_RUNNING}, nil +} + +func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { + metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper) + + endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + client, err := p.getClient(ctx, endpoint, p.connectionCache) + if err != nil { + return nil, fmt.Errorf("failed to connect external plugin service with error: %v", err) + } + + res, err := client.GetTask(ctx, &service.TaskGetRequest{TaskType: metadata.TaskType, JobId: metadata.JobID}) + if err != nil { + return nil, err + } + + return &ResourceWrapper{ + State: res.State, + Outputs: res.Outputs, + }, nil +} + +func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { + if taskCtx.ResourceMeta() == nil { + return nil + } + metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) + + endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + client, err := p.getClient(ctx, endpoint, p.connectionCache) + if err != nil { + return fmt.Errorf("failed to connect external plugin service with error: %v", err) + } + + _, err = client.DeleteTask(ctx, &service.TaskDeleteRequest{TaskType: metadata.TaskType, JobId: metadata.JobID}) + return err +} + +func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { + resource := taskCtx.Resource().(*ResourceWrapper) + taskInfo := &core.TaskInfo{} + + switch resource.State { + case service.State_RUNNING: + return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil + case service.State_PERMANENT_FAILURE: + return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil + case service.State_RETRYABLE_FAILURE: + return core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil + case service.State_SUCCEEDED: + if resource.Outputs != nil { + err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)) + if err != nil { + return core.PhaseInfoUndefined, err + } + } + return core.PhaseInfoSuccess(taskInfo), nil + } + return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State) +} + +func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map[string]string) string { + if t, exists := endpointForTaskTypes[taskType]; exists { + return t + } + + return defaultEndpoint +} + +func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) { + conn, ok := connectionCache[endpoint] + if ok { + return service.NewExternalPluginServiceClient(conn), nil + } + var opts []grpc.DialOption + var err error + + opts = append(opts, grpc.WithInsecure()) + conn, err = grpc.Dial(endpoint, opts...) + if err != nil { + return nil, err + } + connectionCache[endpoint] = conn + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + return service.NewExternalPluginServiceClient(conn), nil +} + +func newGrpcPlugin() webapi.PluginEntry { + supportedTaskTypes := GetConfig().SupportedTaskTypes + + return webapi.PluginEntry{ + ID: "external-plugin-service", + SupportedTaskTypes: supportedTaskTypes, + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + return &Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + getClient: getClientFunc, + connectionCache: make(map[string]*grpc.ClientConn), + }, nil + }, + } +} + +func init() { + gob.Register(ResourceMetaWrapper{}) + gob.Register(ResourceWrapper{}) + + pluginmachinery.PluginRegistry().RegisterRemotePlugin(newGrpcPlugin()) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/grpc/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/grpc/plugin_test.go new file mode 100644 index 0000000000..93bf11ec92 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/grpc/plugin_test.go @@ -0,0 +1,60 @@ +package grpc + +import ( + "context" + "testing" + "time" + + "google.golang.org/grpc" + + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestPlugin(t *testing.T) { + fakeSetupContext := pluginCoreMocks.SetupContext{} + fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) + + plugin := Plugin{ + metricScope: fakeSetupContext.MetricsScope(), + cfg: GetConfig(), + } + t.Run("get config", func(t *testing.T) { + cfg := defaultConfig + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + cfg.DefaultGrpcEndpoint = "test-service.flyte.svc.cluster.local:80" + cfg.EndpointForTaskTypes = map[string]string{"spark": "localhost:80"} + err := SetConfig(&cfg) + assert.NoError(t, err) + assert.Equal(t, cfg.WebAPI, plugin.GetConfig()) + }) + t.Run("get ResourceRequirements", func(t *testing.T) { + namespace, constraints, err := plugin.ResourceRequirements(context.TODO(), nil) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.ResourceNamespace("default"), namespace) + assert.Equal(t, plugin.cfg.ResourceConstraints, constraints) + }) + + t.Run("tet newGrpcPlugin", func(t *testing.T) { + p := newGrpcPlugin() + assert.NotNil(t, p) + assert.Equal(t, p.ID, "external-plugin-service") + assert.NotNil(t, p.PluginLoader) + }) + + t.Run("test getFinalEndpoint", func(t *testing.T) { + endpoint := getFinalEndpoint("spark", "localhost:8080", map[string]string{"spark": "localhost:80"}) + assert.Equal(t, endpoint, "localhost:80") + endpoint = getFinalEndpoint("spark", "localhost:8080", map[string]string{}) + assert.Equal(t, endpoint, "localhost:8080") + }) + + t.Run("test getClientFunc", func(t *testing.T) { + client, err := getClientFunc(context.Background(), "localhost:80", map[string]*grpc.ClientConn{}) + assert.NoError(t, err) + assert.NotNil(t, client) + }) +}