Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tctl resource selection ux #30081

Merged
merged 5 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions tool/tctl/common/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api"
Expand Down Expand Up @@ -132,10 +133,10 @@ func testKubeServerCollection_writeText(t *testing.T) {
types.DiscoveredNameLabel: "cluster3",
}
kubeServers := []types.KubeServer{
mustCreateNewKubeServer(t, "cluster1", nil),
mustCreateNewKubeServer(t, "cluster2", longLabelFixture),
mustCreateNewKubeServer(t, "afirstCluster", longLabelFixture),
mustCreateNewKubeServer(t, "cluster3-eks-us-west-1-123456789012", eksDiscoveredNameLabel),
mustCreateNewKubeServer(t, "cluster1", "_", nil),
mustCreateNewKubeServer(t, "cluster2", "_", longLabelFixture),
mustCreateNewKubeServer(t, "afirstCluster", "_", longLabelFixture),
mustCreateNewKubeServer(t, "cluster3-eks-us-west-1-123456789012", "_", eksDiscoveredNameLabel),
}
test := writeTextTest{
collection: &kubeServerCollection{servers: kubeServers},
Expand Down Expand Up @@ -307,10 +308,10 @@ func mustCreateNewKubeCluster(t *testing.T, name string, extraStaticLabels map[s
return cluster
}

func mustCreateNewKubeServer(t *testing.T, name string, extraStaticLabels map[string]string) types.KubeServer {
func mustCreateNewKubeServer(t *testing.T, name, hostname string, extraStaticLabels map[string]string) *types.KubernetesServerV3 {
t.Helper()
cluster := mustCreateNewKubeCluster(t, name, extraStaticLabels)
kubeServer, err := types.NewKubernetesServerV3FromCluster(cluster, "some-host", "some-hostid")
kubeServer, err := types.NewKubernetesServerV3FromCluster(cluster, hostname, uuid.New().String())
require.NoError(t, err)
return kubeServer
}
Expand Down
3 changes: 3 additions & 0 deletions tool/tctl/common/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ func makeAndRunTestAuthServer(t *testing.T, opts ...testServerOptionFunc) (auth
}

func waitForDatabases(t *testing.T, auth *service.TeleportProcess, dbs []servicecfg.Database) {
if len(dbs) == 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for {
Expand Down
239 changes: 180 additions & 59 deletions tool/tctl/common/resource_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"math"
"os"
"sort"
"strings"
"time"

"github.com/alecthomas/kingpin/v2"
Expand Down Expand Up @@ -1087,23 +1088,23 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err
}
fmt.Printf("lock %q has been deleted\n", name)
case types.KindDatabaseServer:
dbServers, err := client.GetDatabaseServers(ctx, apidefaults.Namespace)
servers, err := client.GetDatabaseServers(ctx, apidefaults.Namespace)
if err != nil {
return trace.Wrap(err)
}
deleted := false
for _, server := range dbServers {
if server.GetName() == rc.ref.Name {
if err := client.DeleteDatabaseServer(ctx, apidefaults.Namespace, server.GetHostID(), server.GetName()); err != nil {
return trace.Wrap(err)
}
deleted = true
}
resDesc := "database server"
servers = filterByNameOrPrefix(servers, rc.ref.Name)
name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if !deleted {
return trace.NotFound("database server %q not found", rc.ref.Name)
for _, s := range servers {
err := client.DeleteDatabaseServer(ctx, apidefaults.Namespace, s.GetHostID(), name)
if err != nil {
return trace.Wrap(err)
}
}
fmt.Printf("database server %q has been deleted\n", rc.ref.Name)
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindNetworkRestrictions:
if err = resetNetworkRestrictions(ctx, client); err != nil {
return trace.Wrap(err)
Expand All @@ -1115,15 +1116,35 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err
}
fmt.Printf("application %q has been deleted\n", rc.ref.Name)
case types.KindDatabase:
if err = client.DeleteDatabase(ctx, rc.ref.Name); err != nil {
databases, err := client.GetDatabases(ctx)
if err != nil {
return trace.Wrap(err)
}
resDesc := "database"
databases = filterByNameOrPrefix(databases, rc.ref.Name)
name, err := getOneResourceNameToDelete(databases, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if err := client.DeleteDatabase(ctx, name); err != nil {
return trace.Wrap(err)
}
fmt.Printf("database %q has been deleted\n", rc.ref.Name)
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindKubernetesCluster:
if err = client.DeleteKubernetesCluster(ctx, rc.ref.Name); err != nil {
clusters, err := client.GetKubernetesClusters(ctx)
if err != nil {
return trace.Wrap(err)
}
fmt.Printf("kubernetes cluster %q has been deleted\n", rc.ref.Name)
resDesc := "kubernetes cluster"
clusters = filterByNameOrPrefix(clusters, rc.ref.Name)
name, err := getOneResourceNameToDelete(clusters, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if err := client.DeleteKubernetesCluster(ctx, name); err != nil {
return trace.Wrap(err)
}
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindWindowsDesktopService:
if err = client.DeleteWindowsDesktopService(ctx, rc.ref.Name); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -1176,23 +1197,23 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err
}
fmt.Printf("%s '%s/%s' has been deleted\n", types.KindCertAuthority, rc.ref.SubKind, rc.ref.Name)
case types.KindKubeServer:
kubeServers, err := client.GetKubernetesServers(ctx)
servers, err := client.GetKubernetesServers(ctx)
if err != nil {
return trace.Wrap(err)
}
deleted := false
for _, server := range kubeServers {
if server.GetName() == rc.ref.Name {
if err := client.DeleteKubernetesServer(ctx, server.GetHostID(), server.GetName()); err != nil {
return trace.Wrap(err)
}
deleted = true
}
resDesc := "kubernetes server"
servers = filterByNameOrPrefix(servers, rc.ref.Name)
name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if !deleted {
return trace.NotFound("kubernetes server %q not found", rc.ref.Name)
for _, s := range servers {
err := client.DeleteKubernetesServer(ctx, s.GetHostID(), name)
if err != nil {
return trace.Wrap(err)
}
}
fmt.Printf("kubernetes server %q has been deleted\n", rc.ref.Name)
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindUIConfig:
err := client.DeleteUIConfig(ctx)
if err != nil {
Expand Down Expand Up @@ -1652,16 +1673,11 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client
return &databaseServerCollection{servers: servers}, nil
}

var out []types.DatabaseServer
for _, server := range servers {
if server.GetName() == rc.ref.Name {
out = append(out, server)
}
}
if len(out) == 0 {
servers = filterByNameOrPrefix(servers, rc.ref.Name)
if len(servers) == 0 {
return nil, trace.NotFound("database server %q not found", rc.ref.Name)
}
return &databaseServerCollection{servers: out}, nil
return &databaseServerCollection{servers: servers}, nil
case types.KindKubeServer:
servers, err := client.GetKubernetesServers(ctx)
if err != nil {
Expand All @@ -1670,17 +1686,14 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client
if rc.ref.Name == "" {
return &kubeServerCollection{servers: servers}, nil
}

var out []types.KubeServer
for _, server := range servers {
if server.GetName() == rc.ref.Name || server.GetHostname() == rc.ref.Name {
out = append(out, server)
}
altNameFn := func(r types.KubeServer) string {
return r.GetHostname()
}
if len(out) == 0 {
servers = filterByNameOrPrefix(servers, rc.ref.Name, altNameFn)
if len(servers) == 0 {
return nil, trace.NotFound("kubernetes server %q not found", rc.ref.Name)
}
return &kubeServerCollection{servers: out}, nil
return &kubeServerCollection{servers: servers}, nil

case types.KindAppServer:
servers, err := client.GetApplicationServers(ctx, rc.namespace)
Expand Down Expand Up @@ -1721,31 +1734,31 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client
}
return &appCollection{apps: []types.Application{app}}, nil
case types.KindDatabase:
databases, err := client.GetDatabases(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
if rc.ref.Name == "" {
databases, err := client.GetDatabases(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return &databaseCollection{databases: databases}, nil
}
database, err := client.GetDatabase(ctx, rc.ref.Name)
databases = filterByNameOrPrefix(databases, rc.ref.Name)
if len(databases) == 0 {
return nil, trace.NotFound("database %q not found", rc.ref.Name)
}
return &databaseCollection{databases: databases}, nil
case types.KindKubernetesCluster:
clusters, err := client.GetKubernetesClusters(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return &databaseCollection{databases: []types.Database{database}}, nil
case types.KindKubernetesCluster:
if rc.ref.Name == "" {
clusters, err := client.GetKubernetesClusters(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return &kubeClusterCollection{clusters: clusters}, nil
}
cluster, err := client.GetKubernetesCluster(ctx, rc.ref.Name)
if err != nil {
return nil, trace.Wrap(err)
clusters = filterByNameOrPrefix(clusters, rc.ref.Name)
if len(clusters) == 0 {
return nil, trace.NotFound("kubernetes cluster %q not found", rc.ref.Name)
}
return &kubeClusterCollection{clusters: []types.KubeCluster{cluster}}, nil
return &kubeClusterCollection{clusters: clusters}, nil
case types.KindWindowsDesktopService:
services, err := client.GetWindowsDesktopServices(ctx)
if err != nil {
Expand Down Expand Up @@ -2124,3 +2137,111 @@ func findDeviceByIDOrTag(ctx context.Context, remote devicepb.DeviceTrustService

return nil, trace.BadParameter("found multiple devices for asset tag %q, please retry using the device ID instead", idOrTag)
}

// keepFn is a predicate function that returns true if a resource should be
// retained by filterResources.
type keepFn[T types.ResourceWithLabels] func(T) bool

// filterResources takes a list of resources and returns a filtered list of
// resources for which the `keep` predicate function returns true.
func filterResources[T types.ResourceWithLabels](resources []T, keep keepFn[T]) []T {
out := make([]T, 0, len(resources))
for _, r := range resources {
if keep(r) {
out = append(out, r)
}
}
return out
}

// altNameFn is a func that returns an alternative name for a resource.
type altNameFn[T types.ResourceWithLabels] func(T) string

// filterByNameOrPrefix filters resources by name or a prefix of the name.
// It prefers exact name filtering first - if none of the resource names match
// exactly (i.e. all of the resources are filtered out), then it retries and
// filters the resources by prefix of resource name instead.
// This is to avoid an annoying UX, for example:
// resources: [foo, foobar]
// $ tctl rm foo <- should select foo by exact name instead of matching both by
// prefix "foo".
func filterByNameOrPrefix[T types.ResourceWithLabels](resources []T, prefixOrName string, extra ...altNameFn[T]) []T {
// prefer exact names
out := filterByName(resources, prefixOrName, extra...)
if len(out) == 0 {
// fallback to looking for prefixes
out = filterByPrefix(resources, prefixOrName, extra...)
}
return out
}

// filterByName filters resources by exact name match.
func filterByName[T types.ResourceWithLabels](resources []T, name string, altNameFns ...altNameFn[T]) []T {
return filterResources(resources, func(r T) bool {
if r.GetName() == name {
return true
}
for _, altName := range altNameFns {
if altName(r) == name {
return true
}
}
return false
})
}

// filterByPrefix filters resources by a prefix of the resource name.
func filterByPrefix[T types.ResourceWithLabels](resources []T, prefix string, altNameFns ...altNameFn[T]) []T {
return filterResources(resources, func(r T) bool {
if strings.HasPrefix(r.GetName(), prefix) {
return true
}
for _, altName := range altNameFns {
if strings.HasPrefix(altName(r), prefix) {
return true
}
}
return false
})
}

// getOneResourceNameToDelete checks a list of resources to ensure there is
// exactly one resource name among them, and returns that name or an error.
// Heartbeat resources can have the same name but different host ID, so this
// still allows a user to delete multiple heartbeats of the same name, for
// example `$ tctl rm db_server/someDB`.
func getOneResourceNameToDelete[T types.ResourceWithLabels](rs []T, ref services.Ref, resDesc string) (string, error) {
seen := make(map[string]struct{})
for _, r := range rs {
seen[r.GetName()] = struct{}{}
}
switch len(seen) {
case 1: // need exactly one.
return rs[0].GetName(), nil
case 0:
return "", trace.NotFound("%v %q not found", resDesc, ref.Name)
default:
names := make([]string, 0, len(rs))
for _, r := range rs {
names = append(names, r.GetName())
}
msg := formatAmbiguousDeleteMessage(ref, resDesc, names)
return "", trace.BadParameter(msg)
}
}

// formatAmbiguousDeleteMessage returns a formatted message when a user is
// attempting to delete multiple resources by an ambiguous prefix of the
// resource names.
func formatAmbiguousDeleteMessage(ref services.Ref, resDesc string, names []string) string {
slices.Sort(names)
// choose an actual resource for the example in the error.
exampleRef := ref
exampleRef.Name = names[0]
return fmt.Sprintf(`%s matches multiple %vs as a name prefix:
%v

Use either a full resource name or an unambiguous prefix, for example:
$ tctl rm %s`,
ref.String(), resDesc, strings.Join(names, "\n"), exampleRef.String())
}
Loading