-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SNOW-1825500 Implement authorzation code flow
- Loading branch information
1 parent
f6efcbd
commit 7c843f8
Showing
15 changed files
with
726 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
package gosnowflake | ||
|
||
import ( | ||
"bufio" | ||
"bytes" | ||
"context" | ||
"errors" | ||
"fmt" | ||
"golang.org/x/oauth2" | ||
"io" | ||
"net" | ||
"net/http" | ||
"net/url" | ||
"strconv" | ||
"strings" | ||
) | ||
|
||
const ( | ||
oauthSuccessHTML = `<!DOCTYPE html><html><head><meta charset="UTF-8"/> | ||
<title>OAuth for Snowflake</title></head> | ||
<body> | ||
OAuth authentication completed successfully. | ||
</body></html>` | ||
) | ||
|
||
type oauthClient struct { | ||
ctx context.Context | ||
cfg *Config | ||
|
||
port int | ||
redirectUriTemplate string | ||
|
||
authorizationCodeProviderFactory func() authorizationCodeProvider | ||
} | ||
|
||
func newOauthClient(ctx context.Context, cfg *Config) (*oauthClient, error) { | ||
port := 0 | ||
if cfg.OauthRedirectURI != "" { | ||
uri, err := url.Parse(cfg.OauthRedirectURI) | ||
if err != nil { | ||
return nil, err | ||
} | ||
portStr := uri.Port() | ||
if portStr != "" { | ||
if port, err = strconv.Atoi(portStr); err != nil { | ||
return nil, err | ||
} | ||
} | ||
} | ||
|
||
redirectUriTemplate := "" | ||
if cfg.OauthRedirectURI == "" { | ||
redirectUriTemplate = "http://localhost:%v/" | ||
} | ||
|
||
client := &http.Client{ | ||
Transport: getTransport(cfg), | ||
} | ||
return &oauthClient{ | ||
ctx: context.WithValue(ctx, oauth2.HTTPClient, client), | ||
cfg: cfg, | ||
port: port, | ||
redirectUriTemplate: redirectUriTemplate, | ||
authorizationCodeProviderFactory: func() authorizationCodeProvider { | ||
return &browserBasedAuthorizationCodeProvider{} | ||
}, | ||
}, nil | ||
} | ||
|
||
func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode() (string, error) { | ||
authCodeProvider := oauthClient.authorizationCodeProviderFactory() | ||
|
||
// TODO timeouts | ||
// TODO retries | ||
successChan := make(chan []byte) | ||
errChan := make(chan error) | ||
responseBodyChan := make(chan string, 2) | ||
closeListenerChan := make(chan bool, 2) | ||
|
||
defer func() { | ||
closeListenerChan <- true | ||
close(successChan) | ||
close(errChan) | ||
close(responseBodyChan) | ||
close(closeListenerChan) | ||
}() | ||
|
||
tcpListener, callbackPort, err := oauthClient.setupListener() | ||
if err != nil { | ||
return "", err | ||
} | ||
defer func(tcpListener *net.TCPListener) { | ||
<-closeListenerChan | ||
if err := tcpListener.Close(); err != nil { | ||
logger.Warnf("error while closing TCP listener. %v", err) | ||
} | ||
}(tcpListener) | ||
|
||
go handleOAuthSocket(tcpListener, successChan, errChan, responseBodyChan, closeListenerChan) | ||
|
||
oauth2cfg := oauthClient.buildOauthConfig(callbackPort) | ||
codeVerifier := authCodeProvider.createCodeVerifier() | ||
state := authCodeProvider.createState() | ||
authorizationUrl := oauth2cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(codeVerifier)) | ||
if err = authCodeProvider.run(authorizationUrl); err != nil { | ||
responseBodyChan <- err.Error() | ||
closeListenerChan <- true | ||
return "", err | ||
} | ||
|
||
err = <-errChan | ||
if err != nil { | ||
responseBodyChan <- err.Error() | ||
return "", err | ||
} | ||
codeReqBytes := <-successChan | ||
|
||
codeReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(codeReqBytes))) | ||
if err != nil { | ||
responseBodyChan <- err.Error() | ||
return "", err | ||
} | ||
|
||
return oauthClient.exchangeAccessToken(codeReq, state, err, oauth2cfg, codeVerifier, responseBodyChan, errChan) | ||
} | ||
|
||
func (oauthClient *oauthClient) setupListener() (*net.TCPListener, int, error) { | ||
tcpListener, err := createLocalTCPListener(oauthClient.port) | ||
if err != nil { | ||
return nil, 0, err | ||
} | ||
callbackPort := tcpListener.Addr().(*net.TCPAddr).Port | ||
return tcpListener, callbackPort, nil | ||
} | ||
|
||
func (oauthClient *oauthClient) exchangeAccessToken(codeReq *http.Request, state string, err error, oauth2cfg *oauth2.Config, codeVerifier string, responseBodyChan chan string, errChan chan error) (string, error) { | ||
queryParams := codeReq.URL.Query() | ||
errorMsg := queryParams.Get("error") | ||
if errorMsg != "" { | ||
errorDesc := queryParams.Get("error_description") | ||
errMsg := fmt.Sprintf("error while getting authentication from oauth: %v. Details: %v", errorMsg, errorDesc) | ||
responseBodyChan <- errMsg | ||
return "", errors.New(errMsg) | ||
} | ||
|
||
receivedState := queryParams.Get("state") | ||
if state != receivedState { | ||
errMsg := "invalid oauth state received" | ||
responseBodyChan <- errMsg | ||
return "", errors.New(errMsg) | ||
} | ||
|
||
code := queryParams.Get("code") | ||
token, err := oauth2cfg.Exchange(oauthClient.ctx, code, oauth2.VerifierOption(codeVerifier)) | ||
if err != nil { | ||
responseBodyChan <- err.Error() | ||
return "", err | ||
} | ||
responseBodyChan <- oauthSuccessHTML | ||
return token.AccessToken, nil | ||
} | ||
|
||
func (oauthClient *oauthClient) buildOauthConfig(callbackPort int) *oauth2.Config { | ||
return &oauth2.Config{ | ||
ClientID: oauthClient.cfg.OauthClientID, | ||
ClientSecret: oauthClient.cfg.OauthClientSecret, | ||
RedirectURL: oauthClient.buildRedirectUri(callbackPort), | ||
Scopes: oauthClient.buildScopes(), | ||
Endpoint: oauth2.Endpoint{ | ||
AuthURL: oauthClient.cfg.OauthAuthorizationURL, | ||
TokenURL: oauthClient.cfg.OauthTokenRequestURL, | ||
}, | ||
} | ||
} | ||
|
||
func (oauthClient *oauthClient) buildRedirectUri(port int) string { | ||
if oauthClient.cfg.OauthRedirectURI != "" { | ||
return oauthClient.cfg.OauthRedirectURI | ||
} | ||
return fmt.Sprintf(oauthClient.redirectUriTemplate, port) | ||
} | ||
|
||
func (oauthClient *oauthClient) buildScopes() []string { | ||
if oauthClient.cfg.OauthScope == "" { | ||
return []string{"session:role:" + oauthClient.cfg.Role} | ||
} | ||
scopes := strings.Split(oauthClient.cfg.OauthScope, ",") | ||
for i, scope := range scopes { | ||
scopes[i] = strings.TrimSpace(scope) | ||
} | ||
return scopes | ||
} | ||
|
||
func handleOAuthSocket(tcpListener *net.TCPListener, successChan chan []byte, errChan chan error, responseBodyChan chan string, closeListenerChan chan bool) { | ||
conn, err := tcpListener.AcceptTCP() | ||
if err != nil { | ||
logger.Warnf("error creating socket. %v", err) | ||
return | ||
} | ||
defer conn.Close() | ||
var buf [bufSize]byte | ||
codeResp := bytes.NewBuffer(nil) | ||
for { | ||
readBytes, err := conn.Read(buf[:]) | ||
if err == io.EOF { | ||
break | ||
} | ||
if err != nil { | ||
errChan <- err | ||
return | ||
} | ||
codeResp.Write(buf[0:readBytes]) | ||
if readBytes < bufSize { | ||
break | ||
} | ||
} | ||
|
||
errChan <- nil | ||
successChan <- codeResp.Bytes() | ||
|
||
responseBody := <-responseBodyChan | ||
respToBrowser, err := buildResponse(responseBody) | ||
if err != nil { | ||
logger.Warnf("cannot create response to browser. %v", err) | ||
} | ||
_, err = conn.Write(respToBrowser.Bytes()) | ||
if err != nil { | ||
logger.Warnf("cannot write response to browser. %v", err) | ||
} | ||
closeListenerChan <- true | ||
} | ||
|
||
type authorizationCodeProvider interface { | ||
run(authorizationUrl string) error | ||
createState() string | ||
createCodeVerifier() string | ||
} | ||
|
||
type browserBasedAuthorizationCodeProvider struct { | ||
conn *net.TCPConn | ||
} | ||
|
||
func newBrowserBasedAuthorizationCodeProvider() authorizationCodeProvider { | ||
return &browserBasedAuthorizationCodeProvider{} | ||
} | ||
|
||
func (provider *browserBasedAuthorizationCodeProvider) run(authorizationUrl string) error { | ||
return openBrowser(authorizationUrl) | ||
} | ||
|
||
func (provider *browserBasedAuthorizationCodeProvider) createState() string { | ||
return NewUUID().String() | ||
} | ||
|
||
func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() string { | ||
return oauth2.GenerateVerifier() | ||
} |
Oops, something went wrong.