Skip to content

Commit

Permalink
Optimize desktop backend reads
Browse files Browse the repository at this point in the history
Updates #52062
  • Loading branch information
zmb3 committed Feb 12, 2025
1 parent 7352a80 commit 0dc4e59
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 28 deletions.
113 changes: 95 additions & 18 deletions lib/services/local/desktops.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,61 @@ func NewWindowsDesktopService(backend backend.Backend) *WindowsDesktopService {
return &WindowsDesktopService{Backend: backend}
}

func (s *WindowsDesktopService) getWindowsDesktop(ctx context.Context, name, hostID string) (types.WindowsDesktop, error) {
key := backend.ExactKey(windowsDesktopsPrefix, hostID, name)
item, err := s.Get(ctx, key)
if err != nil {
return nil, trace.Wrap(err)
}
desktop, err := itemToWindowsDesktop(item)
if err != nil {
return nil, trace.Wrap(err)
}
return desktop, nil
}

func itemToWindowsDesktop(item *backend.Item) (types.WindowsDesktop, error) {
desktop, err := services.UnmarshalWindowsDesktop(
item.Value,
services.WithExpires(item.Expires),
services.WithRevision(item.Revision),
)
return desktop, trace.Wrap(err)
}

func (s *WindowsDesktopService) getWindowsDesktopsForHostID(ctx context.Context, hostID string, limit int) ([]types.WindowsDesktop, error) {
startKey := backend.ExactKey(windowsDesktopsPrefix, hostID)
result, err := s.GetRange(ctx, startKey, backend.RangeEnd(startKey), limit)
if err != nil {
return nil, trace.Wrap(err)
}

var desktops []types.WindowsDesktop
for _, item := range result.Items {
desktop, err := itemToWindowsDesktop(&item)
if err != nil {
return nil, trace.Wrap(err)
}
desktops = append(desktops, desktop)
}

return desktops, nil
}

// GetWindowsDesktops returns all Windows desktops matching filter.
func (s *WindowsDesktopService) GetWindowsDesktops(ctx context.Context, filter types.WindowsDesktopFilter) ([]types.WindowsDesktop, error) {
// do a point-read instead of a range-read if a filter is provided
if filter.HostID != "" && filter.Name == "" {
desktop, err := s.getWindowsDesktop(ctx, filter.Name, filter.HostID)
if err != nil {
return nil, trace.Wrap(err)
}
return []types.WindowsDesktop{desktop}, nil
}
if filter.HostID != "" && filter.Name == "" {
return s.getWindowsDesktopsForHostID(ctx, filter.HostID, backend.NoLimit)
}

startKey := backend.ExactKey(windowsDesktopsPrefix)
result, err := s.GetRange(ctx, startKey, backend.RangeEnd(startKey), backend.NoLimit)
if err != nil {
Expand All @@ -48,8 +101,7 @@ func (s *WindowsDesktopService) GetWindowsDesktops(ctx context.Context, filter t

var desktops []types.WindowsDesktop
for _, item := range result.Items {
desktop, err := services.UnmarshalWindowsDesktop(item.Value,
services.WithExpires(item.Expires), services.WithRevision(item.Revision))
desktop, err := itemToWindowsDesktop(&item)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -58,10 +110,6 @@ func (s *WindowsDesktopService) GetWindowsDesktops(ctx context.Context, filter t
}
desktops = append(desktops, desktop)
}
// If both HostID and Name are set in the filter only one desktop should be expected
if filter.HostID != "" && filter.Name != "" && len(desktops) == 0 {
return nil, trace.NotFound("windows desktop \"%s/%s\" doesn't exist", filter.HostID, filter.Name)
}

return desktops, nil
}
Expand Down Expand Up @@ -176,14 +224,11 @@ func (s *WindowsDesktopService) ListWindowsDesktops(ctx context.Context, req typ
return nil, trace.BadParameter("nonpositive parameter limit")
}

rangeStart := backend.NewKey(windowsDesktopsPrefix, req.StartKey)
rangeEnd := backend.RangeEnd(backend.ExactKey(windowsDesktopsPrefix))
filter := services.MatchResourceFilter{
ResourceKind: types.KindWindowsDesktop,
Labels: req.Labels,
SearchKeywords: req.SearchKeywords,
}

if req.PredicateExpression != "" {
expression, err := services.NewResourceExpression(req.PredicateExpression)
if err != nil {
Expand All @@ -192,6 +237,41 @@ func (s *WindowsDesktopService) ListWindowsDesktops(ctx context.Context, req typ
filter.PredicateExpression = expression
}

// do a point-read instead of a range-read if a filter is provided
if req.HostID != "" && req.Name != "" {
desktop, err := s.getWindowsDesktop(ctx, req.Name, req.HostID)
if trace.IsNotFound(err) {
return &types.ListWindowsDesktopsResponse{}, nil
} else if err != nil {
return nil, trace.Wrap(err)
}

match, err := services.MatchResourceByFilters(desktop, filter, nil /* ignore dup matches */)
if err != nil {
return nil, trace.Wrap(err)
}

if !match {
return &types.ListWindowsDesktopsResponse{}, nil
}

return &types.ListWindowsDesktopsResponse{
Desktops: []types.WindowsDesktop{desktop},
}, nil
}

var rangeStart, rangeEnd backend.Key
if req.HostID != "" && req.Name == "" {
rangeStart = backend.NewKey(windowsDesktopsPrefix, req.HostID)
if req.StartKey != "" {
rangeStart = backend.NewKey(windowsDesktopsPrefix, req.StartKey)
}
rangeEnd = backend.RangeEnd(rangeStart)
} else {
rangeStart = backend.NewKey(windowsDesktopsPrefix, req.StartKey)
rangeEnd = backend.RangeEnd(backend.ExactKey(windowsDesktopsPrefix))
}

// Get most limit+1 results to determine if there will be a next key.
maxLimit := reqLimit + 1
var desktops []types.WindowsDesktop
Expand All @@ -201,8 +281,7 @@ func (s *WindowsDesktopService) ListWindowsDesktops(ctx context.Context, req typ
break
}

desktop, err := services.UnmarshalWindowsDesktop(item.Value,
services.WithExpires(item.Expires), services.WithRevision(item.Revision))
desktop, err := itemToWindowsDesktop(&item)
if err != nil {
return false, trace.Wrap(err)
}
Expand All @@ -224,11 +303,6 @@ func (s *WindowsDesktopService) ListWindowsDesktops(ctx context.Context, req typ
return nil, trace.Wrap(err)
}

// If both HostID and Name are set in the filter only one desktop should be expected
if req.HostID != "" && req.Name != "" && len(desktops) == 0 {
return nil, trace.NotFound("windows desktop \"%s/%s\" doesn't exist", req.HostID, req.Name)
}

var nextKey string
if len(desktops) > reqLimit {
nextKey = backend.GetPaginationKey(desktops[len(desktops)-1])
Expand Down Expand Up @@ -273,8 +347,11 @@ func (s *WindowsDesktopService) ListWindowsDesktopServices(ctx context.Context,
break
}

desktop, err := services.UnmarshalWindowsDesktopService(item.Value,
services.WithExpires(item.Expires), services.WithRevision(item.Revision))
desktop, err := services.UnmarshalWindowsDesktopService(
item.Value,
services.WithExpires(item.Expires),
services.WithRevision(item.Revision),
)
if err != nil {
return false, trace.Wrap(err)
}
Expand Down
77 changes: 67 additions & 10 deletions lib/services/local/desktops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ func TestListWindowsDesktops(t *testing.T) {

// With label.
testLabel := map[string]string{"env": "test"}
d1, err := types.NewWindowsDesktopV3("apple", testLabel, types.WindowsDesktopSpecV3{Addr: "_"})
d1, err := types.NewWindowsDesktopV3("apple", testLabel, types.WindowsDesktopSpecV3{Addr: "_", HostID: "hostA"})
require.NoError(t, err)
require.NoError(t, service.CreateWindowsDesktop(ctx, d1))

d2, err := types.NewWindowsDesktopV3("banana", testLabel, types.WindowsDesktopSpecV3{Addr: "_"})
d2, err := types.NewWindowsDesktopV3("banana", testLabel, types.WindowsDesktopSpecV3{Addr: "_", HostID: "hostA"})
require.NoError(t, err)
require.NoError(t, service.CreateWindowsDesktop(ctx, d2))

Expand All @@ -71,6 +71,18 @@ func TestListWindowsDesktops(t *testing.T) {
require.NoError(t, err)
require.NoError(t, service.CreateWindowsDesktop(ctx, d3))

// Test fetch by host ID
out, err = service.ListWindowsDesktops(ctx, types.ListWindowsDesktopsRequest{
Limit: 10,
WindowsDesktopFilter: types.WindowsDesktopFilter{
HostID: "test-host-id",
},
})
require.NoError(t, err)
require.Len(t, out.Desktops, 1)
require.Equal(t, "carrot", out.Desktops[0].GetName())
require.Equal(t, "test-host-id", out.Desktops[0].GetHostID())

// Test fetch all.
out, err = service.ListWindowsDesktops(ctx, types.ListWindowsDesktopsRequest{
Limit: 10,
Expand Down Expand Up @@ -111,6 +123,30 @@ func TestListWindowsDesktops(t *testing.T) {
require.Len(t, resp.Desktops, 1)
require.Equal(t, out.Desktops[2], resp.Desktops[0])
require.Empty(t, resp.NextKey)

// Test paginating while filtering by Host ID

resp, err = service.ListWindowsDesktops(ctx, types.ListWindowsDesktopsRequest{
Limit: 1,
WindowsDesktopFilter: types.WindowsDesktopFilter{
HostID: "hostA",
},
})
require.NoError(t, err)
require.Len(t, resp.Desktops, 1)
require.Equal(t, "apple", resp.Desktops[0].GetName())

resp, err = service.ListWindowsDesktops(ctx, types.ListWindowsDesktopsRequest{
Limit: 1,
StartKey: resp.NextKey,
WindowsDesktopFilter: types.WindowsDesktopFilter{
HostID: "hostA",
},
})
require.NoError(t, err)
require.Len(t, resp.Desktops, 1)
require.Equal(t, "banana", resp.Desktops[0].GetName())
require.Empty(t, resp.NextKey)
}

func TestListWindowsDesktops_Filters(t *testing.T) {
Expand Down Expand Up @@ -147,7 +183,8 @@ func TestListWindowsDesktops_Filters(t *testing.T) {
tests := []struct {
name string
filter types.ListWindowsDesktopsRequest
wantErr bool
assert require.ErrorAssertionFunc
wantLen int
}{
{
name: "NOK non-matching host id and name",
Expand All @@ -158,47 +195,69 @@ func TestListWindowsDesktops_Filters(t *testing.T) {
Name: "no-match",
},
},
wantErr: true,
assert: require.NoError,
wantLen: 0,
},
{
name: "NOK invalid limit",
filter: types.ListWindowsDesktopsRequest{},
wantErr: true,
assert: require.Error,
wantLen: 0,
},
{
name: "matching host id",
filter: types.ListWindowsDesktopsRequest{
Limit: 5,
WindowsDesktopFilter: types.WindowsDesktopFilter{HostID: "test-host-id"},
},
assert: require.NoError,
wantLen: 2,
},
{
name: "matching host id, mismatching labels",
filter: types.ListWindowsDesktopsRequest{
Limit: 5,
Labels: map[string]string{"doesnot": "exist"},
WindowsDesktopFilter: types.WindowsDesktopFilter{HostID: "test-host-id"},
},
assert: require.NoError,
wantLen: 0,
},
{
name: "matching name",
filter: types.ListWindowsDesktopsRequest{
Limit: 5,
WindowsDesktopFilter: types.WindowsDesktopFilter{Name: "banana"},
},
assert: require.NoError,
wantLen: 2,
},
{
name: "with search",
filter: types.ListWindowsDesktopsRequest{
Limit: 5,
SearchKeywords: []string{"env", "test"},
},
assert: require.NoError,
wantLen: 2,
},
{
name: "with labels",
filter: types.ListWindowsDesktopsRequest{
Limit: 5,
Labels: testLabel,
},
assert: require.NoError,
wantLen: 2,
},
{
name: "with predicate",
filter: types.ListWindowsDesktopsRequest{
Limit: 5,
PredicateExpression: `labels.env == "test"`,
},
assert: require.NoError,
wantLen: 2,
},
}

Expand All @@ -207,12 +266,10 @@ func TestListWindowsDesktops_Filters(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp, err := service.ListWindowsDesktops(ctx, tc.filter)
tc.assert(t, err)

if tc.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Len(t, resp.Desktops, 2)
if resp != nil {
require.Len(t, resp.Desktops, tc.wantLen)
}
})
}
Expand Down

0 comments on commit 0dc4e59

Please sign in to comment.