Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
beyang committed Nov 18, 2018
1 parent 4e0e699 commit 642b990
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
20 changes: 10 additions & 10 deletions github/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func githubHandler(config *oauth2.Config, success, failure http.Handler) http.Ha
httpClient := config.Client(ctx, token)
githubClient, err := githubClientFromAuthURL(config.Endpoint.AuthURL, httpClient)
if err != nil {
ctx = gologin.WithError(ctx, fmt.Errorf("could not parse AuthURL %s", config.Endpoint.AuthURL))
ctx = gologin.WithError(ctx, fmt.Errorf("github: error creating Client: %v", err))
failure.ServeHTTP(w, req.WithContext(ctx))
return
}
Expand Down Expand Up @@ -94,17 +94,17 @@ func validateResponse(user *github.User, resp *github.Response, err error) error
}

func githubClientFromAuthURL(authURL string, httpClient *http.Client) (*github.Client, error) {
if strings.HasPrefix(authURL, "https://github.com/") {
return github.NewClient(httpClient), nil
} else {
client := github.NewClient(httpClient)
if !strings.HasPrefix(authURL, "https://github.com/") {
// convert authURL to GHE baseURL https://<mycompany>.github.jparrowsec.cn/api/v3/
baseURL, err := url.Parse(authURL)
if err != nil {
return nil, fmt.Errorf("could not parse AuthURL %s", authURL)
return nil, fmt.Errorf("github: error parsing Endoint.AuthURL: %s", authURL)
}
baseURL.Path = ""
baseURL.RawQuery = ""
baseURL.Fragment = ""
baseURLStr := strings.TrimSuffix(baseURL.String(), "/") + "/api/v3/"
return github.NewEnterpriseClient(baseURLStr, baseURLStr, httpClient)
baseURL.Path = "/api/v3/"

client.BaseURL = baseURL
client.UploadURL = baseURL
}
return client, nil
}
21 changes: 21 additions & 0 deletions github/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func TestGithubHandler(t *testing.T) {
ctx = oauth2Login.WithToken(ctx, anyToken)

config := &oauth2.Config{}
config.Endpoint.AuthURL = "https://github.com/login/oauth/authorize"
success := func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
githubUser, err := UserFromContext(ctx)
Expand Down Expand Up @@ -108,3 +109,23 @@ func TestValidateResponse(t *testing.T) {
assert.Equal(t, ErrUnableToGetGithubUser, validateResponse(validUser, invalidResponse, nil))
assert.Equal(t, ErrUnableToGetGithubUser, validateResponse(&github.User{}, validResponse, nil))
}

func Test_githubClientFromAuthURL(t *testing.T) {
for _, test := range []struct {
authURL string
expClientBaseURL string
}{
{authURL: "https://github.com/login/oauth/authorize/", expClientBaseURL: "https://api.github.com/"},
{authURL: "https://github.com/login/oauth/authorize", expClientBaseURL: "https://api.github.com/"},
{authURL: "https://github.mycompany.com/login/oauth/authorize", expClientBaseURL: "https://github.mycompany.com/api/v3/"},
{authURL: "http://github.mycompany.com/login/oauth/authorize", expClientBaseURL: "http://github.mycompany.com/api/v3/"},
} {
client, err := githubClientFromAuthURL(test.authURL, nil)
if err != nil {
t.Fatal(err)
}
if got, want := client.BaseURL.String(), test.expClientBaseURL; got != want {
t.Errorf("For authorization URL %q, expected client URL %q, but got %q", test.authURL, want, got)
}
}
}

0 comments on commit 642b990

Please sign in to comment.