Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Use agent as name where it fits #381

Merged
merged 1 commit into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/tasks/plugins/webapi/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ func TestGetAndSetConfig(t *testing.T) {
}
cfg.DefaultAgent.DefaultTimeout = config.Duration{Duration: 10 * time.Second}
cfg.Agents = map[string]*Agent{
"endpoint_1": {
"agent_1": {
Insecure: cfg.DefaultAgent.Insecure,
DefaultServiceConfig: cfg.DefaultAgent.DefaultServiceConfig,
Timeouts: cfg.DefaultAgent.Timeouts,
},
}
cfg.AgentForTaskTypes = map[string]string{"task_type_1": "endpoint_1"}
cfg.AgentForTaskTypes = map[string]string{"task_type_1": "agent_1"}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, &cfg, GetConfig())
Expand Down
62 changes: 31 additions & 31 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"google.golang.org/grpc"
)

type GetClientFunc func(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)
type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)

type Plugin struct {
metricScope promutils.Scope
Expand Down Expand Up @@ -70,16 +70,16 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR

outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

endpoint, err := getFinalEndpoint(taskTemplate.Type, p.cfg)
agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
if err != nil {
return nil, nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
return nil, nil, fmt.Errorf("failed to find agent agent with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "CreateTask", endpoint)
finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent)
defer cancel()

taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
Expand All @@ -99,16 +99,16 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper)

endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
agent, err := getFinalAgent(metadata.TaskType, p.cfg)
if err != nil {
return nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
return nil, fmt.Errorf("failed to find agent with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "GetTask", endpoint)
finalCtx, cancel := getFinalContext(ctx, "GetTask", agent)
defer cancel()

res, err := client.GetTask(finalCtx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
Expand All @@ -128,16 +128,16 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)

endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
agent, err := getFinalAgent(metadata.TaskType, p.cfg)
if err != nil {
return fmt.Errorf("failed to find agent endpoint with error: %v", err)
return fmt.Errorf("failed to find agent agent with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "DeleteTask", endpoint)
finalCtx, cancel := getFinalContext(ctx, "DeleteTask", agent)
defer cancel()

_, err = client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
Expand Down Expand Up @@ -167,26 +167,26 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func getFinalEndpoint(taskType string, cfg *Config) (*Agent, error) {
func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
if id, exists := cfg.AgentForTaskTypes[taskType]; exists {
if endpoint, exists := cfg.Agents[id]; exists {
return endpoint, nil
if agent, exists := cfg.Agents[id]; exists {
return agent, nil
}
return nil, fmt.Errorf("no endpoint definition found for ID %s that matches task type %s", id, taskType)
return nil, fmt.Errorf("no agent definition found for ID %s that matches task type %s", id, taskType)
}

return &cfg.DefaultAgent, nil
}

func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint]
func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[agent]
if ok {
return service.NewAsyncAgentServiceClient(conn), nil
}

var opts []grpc.DialOption

if endpoint.Insecure {
if agent.Insecure {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
pool, err := x509.SystemCertPool()
Expand All @@ -198,27 +198,27 @@ func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Ag
opts = append(opts, grpc.WithTransportCredentials(creds))
}

if len(endpoint.DefaultServiceConfig) != 0 {
opts = append(opts, grpc.WithDefaultServiceConfig(endpoint.DefaultServiceConfig))
if len(agent.DefaultServiceConfig) != 0 {
opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig))
}

var err error
conn, err = grpc.Dial(endpoint.Endpoint, opts...)
conn, err = grpc.Dial(agent.Endpoint, opts...)
if err != nil {
return nil, err
}
connectionCache[endpoint] = conn
connectionCache[agent] = conn
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
grpclog.Infof("Failed to close conn to %s: %v", agent, cerr)
}
return
}
go func() {
<-ctx.Done()
if cerr := conn.Close(); cerr != nil {
grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
grpclog.Infof("Failed to close conn to %s: %v", agent, cerr)
}
}()
}()
Expand All @@ -237,16 +237,16 @@ func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionM
}
}

func getFinalTimeout(operation string, endpoint *Agent) config.Duration {
if t, exists := endpoint.Timeouts[operation]; exists {
func getFinalTimeout(operation string, agent *Agent) config.Duration {
if t, exists := agent.Timeouts[operation]; exists {
return t
}

return endpoint.DefaultTimeout
return agent.DefaultTimeout
}

func getFinalContext(ctx context.Context, operation string, endpoint *Agent) (context.Context, context.CancelFunc) {
timeout := getFinalTimeout(operation, endpoint).Duration
func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) {
timeout := getFinalTimeout(operation, agent).Duration
if timeout == 0 {
return ctx, func() {}
}
Expand Down
20 changes: 10 additions & 10 deletions go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func TestPlugin(t *testing.T) {
assert.NotNil(t, p.PluginLoader)
})

t.Run("test getFinalEndpoint", func(t *testing.T) {
endpoint, _ := getFinalEndpoint("spark", &cfg)
assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, endpoint.Endpoint)
endpoint, _ = getFinalEndpoint("foo", &cfg)
assert.Equal(t, cfg.DefaultAgent.Endpoint, endpoint.Endpoint)
_, err := getFinalEndpoint("bar", &cfg)
t.Run("test getFinalAgent", func(t *testing.T) {
agent, _ := getFinalAgent("spark", &cfg)
assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, agent.Endpoint)
agent, _ = getFinalAgent("foo", &cfg)
assert.Equal(t, cfg.DefaultAgent.Endpoint, agent.Endpoint)
_, err := getFinalAgent("bar", &cfg)
assert.NotNil(t, err)
})

Expand All @@ -72,14 +72,14 @@ func TestPlugin(t *testing.T) {

t.Run("test getClientFunc cache hit", func(t *testing.T) {
connectionCache := make(map[*Agent]*grpc.ClientConn)
endpoint := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}
agent := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}

client, err := getClientFunc(context.Background(), endpoint, connectionCache)
client, err := getClientFunc(context.Background(), agent, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client, connectionCache[endpoint])
assert.NotNil(t, client, connectionCache[agent])

cachedClient, err := getClientFunc(context.Background(), endpoint, connectionCache)
cachedClient, err := getClientFunc(context.Background(), agent, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, cachedClient)
assert.Equal(t, client, cachedClient)
Expand Down