Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Hoey <[email protected]>
  • Loading branch information
snuggie12 committed Oct 28, 2022
1 parent 151a053 commit 6e750a8
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions connector/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
var groups []string
if s.Groups && c.adminSrv != nil {
checkedGroups := make(map[string]struct{})
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
groups, err = getGroups(c.getGroupsList, claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
}
Expand All @@ -252,15 +252,22 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
return identity, nil
}

// getGroupsList returns a list of Groups from google
func (c *googleConnector) getGroupsList(email string, nextPageToken string) (*admin.Groups, error) {
groupsList, err := c.adminSrv.Groups.List().
UserKey(email).PageToken(nextPageToken).Do()
return groupsList, err
}

// getGroups creates a connection to the admin directory service and lists
// all groups the user is a member of
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
// to test functionality, first parameter is the function you want to run to fetch groups
func getGroups(getGroupsListFunc func(string, string) (*admin.Groups, error), email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &admin.Groups{}
for {
groupsList, err = c.adminSrv.Groups.List().
UserKey(email).PageToken(groupsList.NextPageToken).Do()
groupsList, err = getGroupsListFunc(email, groupsList.NextPageToken)
if err != nil {
return nil, fmt.Errorf("could not list groups: %v", err)
}
Expand All @@ -279,7 +286,7 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership
}

// getGroups takes a user's email/alias as well as a group's email/alias
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
transitiveGroups, err := getGroups(getGroupsListFunc, group.Email, fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}
Expand Down

0 comments on commit 6e750a8

Please sign in to comment.