diff --git a/pkg/authn/identityheaders/identityheaders_test.go b/pkg/authn/identityheaders/identityheaders_test.go index 00e7d19bc..80d978b49 100644 --- a/pkg/authn/identityheaders/identityheaders_test.go +++ b/pkg/authn/identityheaders/identityheaders_test.go @@ -43,30 +43,22 @@ func TestWithAuthHeaders(t *testing.T) { groupKey := "Group" groupValue := "utzer" + defaultUserHeader := map[string]string{ + userKey: userValue, + groupKey: groupValue, + } + for _, tt := range []struct { name string cfg *identityheaders.AuthnHeaderConfig req *http.Request - header map[string][]string + header map[string]string }{ { - name: "should pass through", - cfg: &identityheaders.AuthnHeaderConfig{}, - req: func() *http.Request { - req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) - if err != nil { - t.Fatal(err) - } - - req.Header.Set(userKey, userValue) - req.Header.Set(groupKey, groupValue) - - return req - }(), - header: map[string][]string{ - userKey: {userValue}, - groupKey: {groupValue}, - }, + name: "should pass through", + cfg: &identityheaders.AuthnHeaderConfig{}, + req: testRequest(t, withHeader(defaultUserHeader)), + header: defaultUserHeader, }, { name: "should set username in header", @@ -74,26 +66,8 @@ func TestWithAuthHeaders(t *testing.T) { UserFieldName: userKey, GroupsFieldName: groupKey, }, - header: map[string][]string{ - userKey: {userValue}, - groupKey: {groupValue}, - }, - req: func() *http.Request { - req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) - if err != nil { - t.Fatal(err) - } - - return req.WithContext( - request.WithUser( - req.Context(), - &user.DefaultInfo{ - Name: userValue, - Groups: []string{groupValue}, - }, - ), - ) - }(), + header: defaultUserHeader, + req: testRequest(t, withUserContext(userValue, groupValue)), }, { name: "should not pass client header", @@ -101,18 +75,8 @@ func TestWithAuthHeaders(t *testing.T) { UserFieldName: userKey, GroupsFieldName: groupKey, }, - req: func() *http.Request { - req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) - if err != nil { - t.Fatal(err) - } - - req.Header.Set(userKey, "admin") - req.Header.Set(groupKey, "system:admin") - - return req - }(), - header: map[string][]string{}, + req: testRequest(t, withHeader(map[string]string{userKey: "admin", groupKey: "system:admin"})), + header: map[string]string{}, }, } { tt := tt @@ -127,8 +91,8 @@ func TestWithAuthHeaders(t *testing.T) { if len(tt.header) > 0 { for k, v := range tt.header { - if tt.req.Header[k][0] != v[0] { - t.Errorf("want: %s\nhave: %s", v[0], tt.req.Header[k][0]) + if tt.req.Header[k][0] != v { + t.Errorf("want: %s\nhave: %s", v, tt.req.Header[k][0]) } } } @@ -256,3 +220,43 @@ type testCase struct { expected description string } + +func testRequest(t *testing.T, withOpts ...func(*http.Request) (*http.Request, error)) *http.Request { + req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err != nil { + t.Fatal(err) + } + + for _, opt := range withOpts { + req, err = opt(req) + if err != nil { + t.Fatal(err) + } + } + + return req +} + +func withHeader(header map[string]string) func(*http.Request) (*http.Request, error) { + return func(req *http.Request) (*http.Request, error) { + for key, value := range header { + req.Header.Set(key, value) + } + + return req, nil + } +} + +func withUserContext(userValue, groupValue string) func(*http.Request) (*http.Request, error) { + return func(req *http.Request) (*http.Request, error) { + return req.WithContext( + request.WithUser( + req.Context(), + &user.DefaultInfo{ + Name: userValue, + Groups: []string{groupValue}, + }, + ), + ), nil + } +}