Skip to content

Commit

Permalink
[v13] tctl resource selection ux
Browse files Browse the repository at this point in the history
backports #30081 to branch/v13.

* prefix matching for tctl get discovery resources:
  * kube_cluster
  * kube_server
  * db
  * db_server
  * skip 500ms wait for 0 databases in tests
  • Loading branch information
GavinFrazar committed Sep 16, 2023
1 parent 02e0da0 commit 56159a8
Show file tree
Hide file tree
Showing 4 changed files with 638 additions and 269 deletions.
13 changes: 7 additions & 6 deletions tool/tctl/common/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api"
Expand Down Expand Up @@ -132,10 +133,10 @@ func testKubeServerCollection_writeText(t *testing.T) {
types.DiscoveredNameLabel: "cluster3",
}
kubeServers := []types.KubeServer{
mustCreateNewKubeServer(t, "cluster1", nil),
mustCreateNewKubeServer(t, "cluster2", longLabelFixture),
mustCreateNewKubeServer(t, "afirstCluster", longLabelFixture),
mustCreateNewKubeServer(t, "cluster3-eks-us-west-1-123456789012", eksDiscoveredNameLabel),
mustCreateNewKubeServer(t, "cluster1", "_", nil),
mustCreateNewKubeServer(t, "cluster2", "_", longLabelFixture),
mustCreateNewKubeServer(t, "afirstCluster", "_", longLabelFixture),
mustCreateNewKubeServer(t, "cluster3-eks-us-west-1-123456789012", "_", eksDiscoveredNameLabel),
}
test := writeTextTest{
collection: &kubeServerCollection{servers: kubeServers},
Expand Down Expand Up @@ -307,10 +308,10 @@ func mustCreateNewKubeCluster(t *testing.T, name string, extraStaticLabels map[s
return cluster
}

func mustCreateNewKubeServer(t *testing.T, name string, extraStaticLabels map[string]string) types.KubeServer {
func mustCreateNewKubeServer(t *testing.T, name, hostname string, extraStaticLabels map[string]string) *types.KubernetesServerV3 {
t.Helper()
cluster := mustCreateNewKubeCluster(t, name, extraStaticLabels)
kubeServer, err := types.NewKubernetesServerV3FromCluster(cluster, "some-host", "some-hostid")
kubeServer, err := types.NewKubernetesServerV3FromCluster(cluster, hostname, uuid.New().String())
require.NoError(t, err)
return kubeServer
}
Expand Down
3 changes: 3 additions & 0 deletions tool/tctl/common/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ func makeAndRunTestAuthServer(t *testing.T, opts ...testServerOptionFunc) (auth
}

func waitForDatabases(t *testing.T, auth *service.TeleportProcess, dbs []servicecfg.Database) {
if len(dbs) == 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for {
Expand Down
239 changes: 180 additions & 59 deletions tool/tctl/common/resource_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"math"
"os"
"sort"
"strings"
"time"

"github.com/alecthomas/kingpin/v2"
Expand Down Expand Up @@ -1093,23 +1094,23 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err
}
fmt.Printf("lock %q has been deleted\n", name)
case types.KindDatabaseServer:
dbServers, err := client.GetDatabaseServers(ctx, apidefaults.Namespace)
servers, err := client.GetDatabaseServers(ctx, apidefaults.Namespace)
if err != nil {
return trace.Wrap(err)
}
deleted := false
for _, server := range dbServers {
if server.GetName() == rc.ref.Name {
if err := client.DeleteDatabaseServer(ctx, apidefaults.Namespace, server.GetHostID(), server.GetName()); err != nil {
return trace.Wrap(err)
}
deleted = true
}
resDesc := "database server"
servers = filterByNameOrPrefix(servers, rc.ref.Name)
name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if !deleted {
return trace.NotFound("database server %q not found", rc.ref.Name)
for _, s := range servers {
err := client.DeleteDatabaseServer(ctx, apidefaults.Namespace, s.GetHostID(), name)
if err != nil {
return trace.Wrap(err)
}
}
fmt.Printf("database server %q has been deleted\n", rc.ref.Name)
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindNetworkRestrictions:
if err = resetNetworkRestrictions(ctx, client); err != nil {
return trace.Wrap(err)
Expand All @@ -1121,15 +1122,35 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err
}
fmt.Printf("application %q has been deleted\n", rc.ref.Name)
case types.KindDatabase:
if err = client.DeleteDatabase(ctx, rc.ref.Name); err != nil {
databases, err := client.GetDatabases(ctx)
if err != nil {
return trace.Wrap(err)
}
resDesc := "database"
databases = filterByNameOrPrefix(databases, rc.ref.Name)
name, err := getOneResourceNameToDelete(databases, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if err := client.DeleteDatabase(ctx, name); err != nil {
return trace.Wrap(err)
}
fmt.Printf("database %q has been deleted\n", rc.ref.Name)
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindKubernetesCluster:
if err = client.DeleteKubernetesCluster(ctx, rc.ref.Name); err != nil {
clusters, err := client.GetKubernetesClusters(ctx)
if err != nil {
return trace.Wrap(err)
}
fmt.Printf("kubernetes cluster %q has been deleted\n", rc.ref.Name)
resDesc := "kubernetes cluster"
clusters = filterByNameOrPrefix(clusters, rc.ref.Name)
name, err := getOneResourceNameToDelete(clusters, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if err := client.DeleteKubernetesCluster(ctx, name); err != nil {
return trace.Wrap(err)
}
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindWindowsDesktopService:
if err = client.DeleteWindowsDesktopService(ctx, rc.ref.Name); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -1182,23 +1203,23 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err
}
fmt.Printf("%s '%s/%s' has been deleted\n", types.KindCertAuthority, rc.ref.SubKind, rc.ref.Name)
case types.KindKubeServer:
kubeServers, err := client.GetKubernetesServers(ctx)
servers, err := client.GetKubernetesServers(ctx)
if err != nil {
return trace.Wrap(err)
}
deleted := false
for _, server := range kubeServers {
if server.GetName() == rc.ref.Name {
if err := client.DeleteKubernetesServer(ctx, server.GetHostID(), server.GetName()); err != nil {
return trace.Wrap(err)
}
deleted = true
}
resDesc := "kubernetes server"
servers = filterByNameOrPrefix(servers, rc.ref.Name)
name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if !deleted {
return trace.NotFound("kubernetes server %q not found", rc.ref.Name)
for _, s := range servers {
err := client.DeleteKubernetesServer(ctx, s.GetHostID(), name)
if err != nil {
return trace.Wrap(err)
}
}
fmt.Printf("kubernetes server %q has been deleted\n", rc.ref.Name)
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindUIConfig:
err := client.DeleteUIConfig(ctx)
if err != nil {
Expand Down Expand Up @@ -1658,16 +1679,11 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client
return &databaseServerCollection{servers: servers}, nil
}

var out []types.DatabaseServer
for _, server := range servers {
if server.GetName() == rc.ref.Name {
out = append(out, server)
}
}
if len(out) == 0 {
servers = filterByNameOrPrefix(servers, rc.ref.Name)
if len(servers) == 0 {
return nil, trace.NotFound("database server %q not found", rc.ref.Name)
}
return &databaseServerCollection{servers: out}, nil
return &databaseServerCollection{servers: servers}, nil
case types.KindKubeServer:
servers, err := client.GetKubernetesServers(ctx)
if err != nil {
Expand All @@ -1676,17 +1692,14 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client
if rc.ref.Name == "" {
return &kubeServerCollection{servers: servers}, nil
}

var out []types.KubeServer
for _, server := range servers {
if server.GetName() == rc.ref.Name || server.GetHostname() == rc.ref.Name {
out = append(out, server)
}
altNameFn := func(r types.KubeServer) string {
return r.GetHostname()
}
if len(out) == 0 {
servers = filterByNameOrPrefix(servers, rc.ref.Name, altNameFn)
if len(servers) == 0 {
return nil, trace.NotFound("kubernetes server %q not found", rc.ref.Name)
}
return &kubeServerCollection{servers: out}, nil
return &kubeServerCollection{servers: servers}, nil

case types.KindAppServer:
servers, err := client.GetApplicationServers(ctx, rc.namespace)
Expand Down Expand Up @@ -1727,31 +1740,31 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client
}
return &appCollection{apps: []types.Application{app}}, nil
case types.KindDatabase:
databases, err := client.GetDatabases(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
if rc.ref.Name == "" {
databases, err := client.GetDatabases(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return &databaseCollection{databases: databases}, nil
}
database, err := client.GetDatabase(ctx, rc.ref.Name)
databases = filterByNameOrPrefix(databases, rc.ref.Name)
if len(databases) == 0 {
return nil, trace.NotFound("database %q not found", rc.ref.Name)
}
return &databaseCollection{databases: databases}, nil
case types.KindKubernetesCluster:
clusters, err := client.GetKubernetesClusters(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return &databaseCollection{databases: []types.Database{database}}, nil
case types.KindKubernetesCluster:
if rc.ref.Name == "" {
clusters, err := client.GetKubernetesClusters(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return &kubeClusterCollection{clusters: clusters}, nil
}
cluster, err := client.GetKubernetesCluster(ctx, rc.ref.Name)
if err != nil {
return nil, trace.Wrap(err)
clusters = filterByNameOrPrefix(clusters, rc.ref.Name)
if len(clusters) == 0 {
return nil, trace.NotFound("kubernetes cluster %q not found", rc.ref.Name)
}
return &kubeClusterCollection{clusters: []types.KubeCluster{cluster}}, nil
return &kubeClusterCollection{clusters: clusters}, nil
case types.KindWindowsDesktopService:
services, err := client.GetWindowsDesktopServices(ctx)
if err != nil {
Expand Down Expand Up @@ -2130,3 +2143,111 @@ func findDeviceByIDOrTag(ctx context.Context, remote devicepb.DeviceTrustService

return nil, trace.BadParameter("found multiple devices for asset tag %q, please retry using the device ID instead", idOrTag)
}

// keepFn is a predicate function that returns true if a resource should be
// retained by filterResources.
type keepFn[T types.ResourceWithLabels] func(T) bool

// filterResources takes a list of resources and returns a filtered list of
// resources for which the `keep` predicate function returns true.
func filterResources[T types.ResourceWithLabels](resources []T, keep keepFn[T]) []T {
out := make([]T, 0, len(resources))
for _, r := range resources {
if keep(r) {
out = append(out, r)
}
}
return out
}

// altNameFn is a func that returns an alternative name for a resource.
type altNameFn[T types.ResourceWithLabels] func(T) string

// filterByNameOrPrefix filters resources by name or a prefix of the name.
// It prefers exact name filtering first - if none of the resource names match
// exactly (i.e. all of the resources are filtered out), then it retries and
// filters the resources by prefix of resource name instead.
// This is to avoid an annoying UX, for example:
// resources: [foo, foobar]
// $ tctl rm foo <- should select foo by exact name instead of matching both by
// prefix "foo".
func filterByNameOrPrefix[T types.ResourceWithLabels](resources []T, prefixOrName string, extra ...altNameFn[T]) []T {
// prefer exact names
out := filterByName(resources, prefixOrName, extra...)
if len(out) == 0 {
// fallback to looking for prefixes
out = filterByPrefix(resources, prefixOrName, extra...)
}
return out
}

// filterByName filters resources by exact name match.
func filterByName[T types.ResourceWithLabels](resources []T, name string, altNameFns ...altNameFn[T]) []T {
return filterResources(resources, func(r T) bool {
if r.GetName() == name {
return true
}
for _, altName := range altNameFns {
if altName(r) == name {
return true
}
}
return false
})
}

// filterByPrefix filters resources by a prefix of the resource name.
func filterByPrefix[T types.ResourceWithLabels](resources []T, prefix string, altNameFns ...altNameFn[T]) []T {
return filterResources(resources, func(r T) bool {
if strings.HasPrefix(r.GetName(), prefix) {
return true
}
for _, altName := range altNameFns {
if strings.HasPrefix(altName(r), prefix) {
return true
}
}
return false
})
}

// getOneResourceNameToDelete checks a list of resources to ensure there is
// exactly one resource name among them, and returns that name or an error.
// Heartbeat resources can have the same name but different host ID, so this
// still allows a user to delete multiple heartbeats of the same name, for
// example `$ tctl rm db_server/someDB`.
func getOneResourceNameToDelete[T types.ResourceWithLabels](rs []T, ref services.Ref, resDesc string) (string, error) {
seen := make(map[string]struct{})
for _, r := range rs {
seen[r.GetName()] = struct{}{}
}
switch len(seen) {
case 1: // need exactly one.
return rs[0].GetName(), nil
case 0:
return "", trace.NotFound("%v %q not found", resDesc, ref.Name)
default:
names := make([]string, 0, len(rs))
for _, r := range rs {
names = append(names, r.GetName())
}
msg := formatAmbiguousDeleteMessage(ref, resDesc, names)
return "", trace.BadParameter(msg)
}
}

// formatAmbiguousDeleteMessage returns a formatted message when a user is
// attempting to delete multiple resources by an ambiguous prefix of the
// resource names.
func formatAmbiguousDeleteMessage(ref services.Ref, resDesc string, names []string) string {
slices.Sort(names)
// choose an actual resource for the example in the error.
exampleRef := ref
exampleRef.Name = names[0]
return fmt.Sprintf(`%s matches multiple %vs as a name prefix:
%v
Use either a full resource name or an unambiguous prefix, for example:
$ tctl rm %s`,
ref.String(), resDesc, strings.Join(names, "\n"), exampleRef.String())
}
Loading

0 comments on commit 56159a8

Please sign in to comment.