Skip to content

Commit

Permalink
Merge pull request bitly#466 from clobrano/github-use-login-as-user
Browse files Browse the repository at this point in the history
GitHub use login as user
  • Loading branch information
hlhendy authored Nov 20, 2017
2 parents 6ddbb2c + 731fa9f commit b0c1c85
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 43 deletions.
7 changes: 7 additions & 0 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s)
}

if s.User == "" {
s.User, err = p.provider.GetUserName(s)
if err != nil && err.Error() == "not implemented" {
err = nil
}
}
return
}

Expand Down
47 changes: 45 additions & 2 deletions providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
} else {
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}

log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)

if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
Expand All @@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {

return "", nil
}

func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
var user struct {
Login string `json:"login"`
Email string `json:"email"`
}

endpoint := &url.URL{
Scheme: p.ValidateURL.Scheme,
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"),
}

req, err := http.NewRequest("GET", endpoint.String(), nil)
if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err)
}

req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}

body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return "", err
}

if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}

log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)

if err := json.Unmarshal(body, &user); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}

return user.Login, nil
}
146 changes: 146 additions & 0 deletions providers/github_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package providers

import (
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

func testGitHubProvider(hostname string) *GitHubProvider {
p := NewGitHubProvider(
&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: ""})
if hostname != "" {
updateURL(p.Data().LoginURL, hostname)
updateURL(p.Data().RedeemURL, hostname)
updateURL(p.Data().ProfileURL, hostname)
updateURL(p.Data().ValidateURL, hostname)
}
return p
}

func testGitHubBackend(payload string) *httptest.Server {
pathToQueryMap := map[string]string{
"/user": "",
"/user/emails": "",
}

return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
query, ok := pathToQueryMap[url.Path]
if !ok {
w.WriteHeader(404)
} else if url.RawQuery != query {
w.WriteHeader(404)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}

func TestGitHubProviderDefaults(t *testing.T) {
p := testGitHubProvider("")
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://github.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://github.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.github.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "user:email", p.Data().Scope)
}

func TestGitHubProviderOverrides(t *testing.T) {
p := NewGitHubProvider(
&ProviderData{
LoginURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/authorize"},
RedeemURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/access_token"},
ValidateURL: &url.URL{
Scheme: "https",
Host: "api.example.com",
Path: "/"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://example.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://example.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.example.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "profile", p.Data().Scope)
}

func TestGitHubProviderGetEmailAddress(t *testing.T) {
b := testGitHubBackend(`[ {"email": "[email protected]", "primary": true} ]`)
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "[email protected]", email)
}

// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testGitHubBackend("unused payload")
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}

func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testGitHubBackend("{\"foo\": \"bar\"}")
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}

func TestGitHubProviderGetUserName(t *testing.T) {
b := testGitHubBackend(`{"email": "[email protected]", "login": "mbland"}`)
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetUserName(session)
assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email)
}
5 changes: 5 additions & 0 deletions providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}

// GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}

// ValidateGroup validates that the provided email exists in the configured provider
// email group(s).
func (p *ProviderData) ValidateGroup(email string) bool {
Expand Down
1 change: 1 addition & 0 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
type Provider interface {
Data() *ProviderData
GetEmailAddress(*SessionState) (string, error)
GetUserName(*SessionState) (string, error)
Redeem(string, string) (*SessionState, error)
ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool
Expand Down
76 changes: 40 additions & 36 deletions providers/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool {
}

func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.userOrEmail())
o := fmt.Sprintf("Session{%s", s.accountInfo())
if s.AccessToken != "" {
o += " token:true"
}
Expand All @@ -40,17 +40,13 @@ func (s *SessionState) String() string {

func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" {
return s.userOrEmail(), nil
return s.accountInfo(), nil
}
return s.EncryptedString(c)
}

func (s *SessionState) userOrEmail() string {
u := s.User
if s.Email != "" {
u = s.Email
}
return u
func (s *SessionState) accountInfo() string {
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
}

func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
Expand All @@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
}
a := s.AccessToken
if a != "" {
a, err = c.Encrypt(a)
if err != nil {
if a, err = c.Encrypt(a); err != nil {
return "", err
}
}
r := s.RefreshToken
if r != "" {
r, err = c.Encrypt(r)
if err != nil {
if r, err = c.Encrypt(r); err != nil {
return "", err
}
}
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil
return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
}

func decodeSessionStatePlain(v string) (s *SessionState, err error) {
chunks := strings.Split(v, " ")
if len(chunks) != 2 {
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks))
}

email := strings.TrimPrefix(chunks[0], "email:")
user := strings.TrimPrefix(chunks[1], "user:")
if user == "" {
user = strings.Split(email, "@")[0]
}

return &SessionState{User: user, Email: email}, nil
}

func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
chunks := strings.Split(v, "|")
if len(chunks) == 1 {
if strings.Contains(chunks[0], "@") {
u := strings.Split(v, "@")[0]
return &SessionState{Email: v, User: u}, nil
}
return &SessionState{User: v}, nil
if c == nil {
return decodeSessionStatePlain(v)
}

chunks := strings.Split(v, "|")
if len(chunks) != 4 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
return
}

s = &SessionState{}
if c != nil && chunks[1] != "" {
s.AccessToken, err = c.Decrypt(chunks[1])
if err != nil {
sessionState, err := decodeSessionStatePlain(chunks[0])
if err != nil {
return nil, err
}

if chunks[1] != "" {
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil {
return nil, err
}
}
if c != nil && chunks[3] != "" {
s.RefreshToken, err = c.Decrypt(chunks[3])
if err != nil {

ts, _ := strconv.Atoi(chunks[2])
sessionState.ExpiresOn = time.Unix(int64(ts), 0)

if chunks[3] != "" {
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
return nil, err
}
}
if u := chunks[0]; strings.Contains(u, "@") {
s.Email = u
s.User = strings.Split(u, "@")[0]
} else {
s.User = u
}
ts, _ := strconv.Atoi(chunks[2])
s.ExpiresOn = time.Unix(int64(ts), 0)
return

return sessionState, nil
}
Loading

0 comments on commit b0c1c85

Please sign in to comment.