diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index aab62c12f98..226a9c1e4e7 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -409,16 +409,17 @@ func (m *ExecutionManager) getClusterAssignment(ctx context.Context, req *admin. return nil, err } - if req.GetSpec().GetClusterAssignment() == nil { + reqAssignment := req.GetSpec().GetClusterAssignment() + reqPool := reqAssignment.GetClusterPoolName() + storedPool := storedAssignment.GetClusterPoolName() + if reqPool == "" { return storedAssignment, nil } - if storedAssignment == nil { - return req.GetSpec().GetClusterAssignment(), nil + if storedPool == "" { + return reqAssignment, nil } - reqPool := req.Spec.ClusterAssignment.GetClusterPoolName() - storedPool := storedAssignment.GetClusterPoolName() if reqPool != storedPool { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "execution with project %q and domain %q cannot run on cluster pool %q, because its configured to run on pool %q", req.Project, req.Domain, reqPool, storedPool) } diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index 1c8c2b9f60c..5646d192ae2 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -5519,6 +5519,32 @@ func TestGetClusterAssignment(t *testing.T) { assert.NoError(t, err) assert.True(t, proto.Equal(ca, &reqClusterAssignment)) }) + t.Run("empty value in DB, takes value from request", func(t *testing.T) { + clusterPoolAsstProvider := &runtimeIFaceMocks.ClusterPoolAssignmentConfiguration{} + clusterPoolAsstProvider.OnGetClusterPoolAssignments().Return(runtimeInterfaces.ClusterPoolAssignments{ + workflowIdentifier.GetDomain(): runtimeInterfaces.ClusterPoolAssignment{ + Pool: "", + }, + }) + mockConfig := getMockExecutionsConfigProvider() + mockConfig.(*runtimeMocks.MockConfigurationProvider).AddClusterPoolAssignmentConfiguration(clusterPoolAsstProvider) + + executionManager := ExecutionManager{ + resourceManager: &managerMocks.MockResourceManager{}, + config: mockConfig, + } + + reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"} + ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{ + ClusterAssignment: &reqClusterAssignment, + }, + }) + assert.NoError(t, err) + assert.True(t, proto.Equal(ca, &reqClusterAssignment)) + }) t.Run("value from request doesn't match value from config", func(t *testing.T) { reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "swimming-pool"} _, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{