Skip to content

Commit

Permalink
SNOW-1825500 Implement authorzation code flow
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Feb 17, 2025
1 parent f6efcbd commit 7c843f8
Show file tree
Hide file tree
Showing 15 changed files with 726 additions and 30 deletions.
35 changes: 33 additions & 2 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -49,6 +50,8 @@ const (
AuthTypeTokenAccessor
// AuthTypeUsernamePasswordMFA is to use username and password with mfa
AuthTypeUsernamePasswordMFA
// AuthTypeOAuthAuthorizationCode is to use browser-based OAuth2 flow
AuthTypeOAuthAuthorizationCode
)

func determineAuthenticatorType(cfg *Config, value string) error {
Expand All @@ -72,6 +75,12 @@ func determineAuthenticatorType(cfg *Config, value string) error {
} else if upperCaseValue == AuthTypeTokenAccessor.String() {
cfg.Authenticator = AuthTypeTokenAccessor
return nil
} else if upperCaseValue == AuthTypeOAuthAuthorizationCode.String() {
if experimentalAuthEnabled() {
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
return nil
}
return errors.New("OAuth2 authorization code is not yet enabled")
} else {
// possibly Okta case
oktaURLString, err := url.QueryUnescape(lowerCaseValue)
Expand Down Expand Up @@ -121,6 +130,8 @@ func (authType AuthType) String() string {
return "TOKENACCESSOR"
case AuthTypeUsernamePasswordMFA:
return "USERNAME_PASSWORD_MFA"
case AuthTypeOAuthAuthorizationCode:
return "OAUTH_AUTHORIZATION_CODE"
default:
return "UNKNOWN"
}
Expand Down Expand Up @@ -441,7 +452,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
}
requestMain.Token = jwtTokenString
case AuthTypeSnowflake:
logger.WithContext(sc.ctx).Info("Username and password")
logger.WithContext(sc.ctx).Debug("Username and password")
requestMain.LoginName = sc.cfg.User
requestMain.Password = sc.cfg.Password
switch {
Expand All @@ -452,7 +463,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
requestMain.ExtAuthnDuoMethod = "passcode"
}
case AuthTypeUsernamePasswordMFA:
logger.WithContext(sc.ctx).Info("Username and password MFA")
logger.WithContext(sc.ctx).Debug("Username and password MFA")
requestMain.LoginName = sc.cfg.User
requestMain.Password = sc.cfg.Password
switch {
Expand All @@ -464,6 +475,21 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
requestMain.Passcode = sc.cfg.Passcode
requestMain.ExtAuthnDuoMethod = "passcode"
}
case AuthTypeOAuthAuthorizationCode:
logger.WithContext(sc.ctx).Debug("OAuth authorization code")
if !experimentalAuthEnabled() {
return nil, errors.New("OAuth2 is not yet enabled")
}
oauthClient, err := newOauthClient(sc.ctx, sc.cfg)
if err != nil {
return nil, err
}
token, err := oauthClient.authenticateByOAuthAuthorizationCode()
if err != nil {
return nil, err
}
requestMain.LoginName = sc.cfg.User
requestMain.Token = token
}

authRequest := authRequest{
Expand Down Expand Up @@ -571,3 +597,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
return nil
}

func experimentalAuthEnabled() bool {
val, ok := os.LookupEnv("ENABLE_EXPERIMENTAL_AUTHENTICATION")
return ok && strings.EqualFold(val, "true")
}
257 changes: 257 additions & 0 deletions auth_oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
package gosnowflake

import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"golang.org/x/oauth2"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
)

const (
oauthSuccessHTML = `<!DOCTYPE html><html><head><meta charset="UTF-8"/>
<title>OAuth for Snowflake</title></head>
<body>
OAuth authentication completed successfully.
</body></html>`
)

type oauthClient struct {
ctx context.Context
cfg *Config

port int
redirectUriTemplate string

authorizationCodeProviderFactory func() authorizationCodeProvider
}

func newOauthClient(ctx context.Context, cfg *Config) (*oauthClient, error) {
port := 0
if cfg.OauthRedirectURI != "" {
uri, err := url.Parse(cfg.OauthRedirectURI)
if err != nil {
return nil, err
}
portStr := uri.Port()
if portStr != "" {
if port, err = strconv.Atoi(portStr); err != nil {
return nil, err
}
}
}

redirectUriTemplate := ""
if cfg.OauthRedirectURI == "" {
redirectUriTemplate = "http://localhost:%v/"
}

client := &http.Client{
Transport: getTransport(cfg),
}
return &oauthClient{
ctx: context.WithValue(ctx, oauth2.HTTPClient, client),
cfg: cfg,
port: port,
redirectUriTemplate: redirectUriTemplate,
authorizationCodeProviderFactory: func() authorizationCodeProvider {
return &browserBasedAuthorizationCodeProvider{}
},
}, nil
}

func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode() (string, error) {
authCodeProvider := oauthClient.authorizationCodeProviderFactory()

// TODO timeouts
// TODO retries
successChan := make(chan []byte)
errChan := make(chan error)
responseBodyChan := make(chan string, 2)
closeListenerChan := make(chan bool, 2)

defer func() {
closeListenerChan <- true
close(successChan)
close(errChan)
close(responseBodyChan)
close(closeListenerChan)
}()

tcpListener, callbackPort, err := oauthClient.setupListener()
if err != nil {
return "", err
}
defer func(tcpListener *net.TCPListener) {
<-closeListenerChan
if err := tcpListener.Close(); err != nil {
logger.Warnf("error while closing TCP listener. %v", err)
}
}(tcpListener)

go handleOAuthSocket(tcpListener, successChan, errChan, responseBodyChan, closeListenerChan)

oauth2cfg := oauthClient.buildOauthConfig(callbackPort)
codeVerifier := authCodeProvider.createCodeVerifier()
state := authCodeProvider.createState()
authorizationUrl := oauth2cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(codeVerifier))
if err = authCodeProvider.run(authorizationUrl); err != nil {
responseBodyChan <- err.Error()
closeListenerChan <- true
return "", err
}

err = <-errChan
if err != nil {
responseBodyChan <- err.Error()
return "", err
}
codeReqBytes := <-successChan

codeReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(codeReqBytes)))
if err != nil {
responseBodyChan <- err.Error()
return "", err
}

return oauthClient.exchangeAccessToken(codeReq, state, err, oauth2cfg, codeVerifier, responseBodyChan, errChan)
}

func (oauthClient *oauthClient) setupListener() (*net.TCPListener, int, error) {
tcpListener, err := createLocalTCPListener(oauthClient.port)
if err != nil {
return nil, 0, err
}
callbackPort := tcpListener.Addr().(*net.TCPAddr).Port
return tcpListener, callbackPort, nil
}

func (oauthClient *oauthClient) exchangeAccessToken(codeReq *http.Request, state string, err error, oauth2cfg *oauth2.Config, codeVerifier string, responseBodyChan chan string, errChan chan error) (string, error) {

Check failure on line 136 in auth_oauth.go

View workflow job for this annotation

GitHub Actions / Check linter

SA4009: argument err is overwritten before first use (staticcheck)
queryParams := codeReq.URL.Query()
errorMsg := queryParams.Get("error")
if errorMsg != "" {
errorDesc := queryParams.Get("error_description")
errMsg := fmt.Sprintf("error while getting authentication from oauth: %v. Details: %v", errorMsg, errorDesc)
responseBodyChan <- errMsg
return "", errors.New(errMsg)
}

receivedState := queryParams.Get("state")
if state != receivedState {
errMsg := "invalid oauth state received"
responseBodyChan <- errMsg
return "", errors.New(errMsg)
}

code := queryParams.Get("code")
token, err := oauth2cfg.Exchange(oauthClient.ctx, code, oauth2.VerifierOption(codeVerifier))

Check failure on line 154 in auth_oauth.go

View workflow job for this annotation

GitHub Actions / Check linter

SA4009(related information): assignment to err (staticcheck)
if err != nil {
responseBodyChan <- err.Error()
return "", err
}
responseBodyChan <- oauthSuccessHTML
return token.AccessToken, nil
}

func (oauthClient *oauthClient) buildOauthConfig(callbackPort int) *oauth2.Config {
return &oauth2.Config{
ClientID: oauthClient.cfg.OauthClientID,
ClientSecret: oauthClient.cfg.OauthClientSecret,
RedirectURL: oauthClient.buildRedirectUri(callbackPort),
Scopes: oauthClient.buildScopes(),
Endpoint: oauth2.Endpoint{
AuthURL: oauthClient.cfg.OauthAuthorizationURL,
TokenURL: oauthClient.cfg.OauthTokenRequestURL,
},
}
}

func (oauthClient *oauthClient) buildRedirectUri(port int) string {
if oauthClient.cfg.OauthRedirectURI != "" {
return oauthClient.cfg.OauthRedirectURI
}
return fmt.Sprintf(oauthClient.redirectUriTemplate, port)
}

func (oauthClient *oauthClient) buildScopes() []string {
if oauthClient.cfg.OauthScope == "" {
return []string{"session:role:" + oauthClient.cfg.Role}
}
scopes := strings.Split(oauthClient.cfg.OauthScope, ",")
for i, scope := range scopes {
scopes[i] = strings.TrimSpace(scope)
}
return scopes
}

func handleOAuthSocket(tcpListener *net.TCPListener, successChan chan []byte, errChan chan error, responseBodyChan chan string, closeListenerChan chan bool) {
conn, err := tcpListener.AcceptTCP()
if err != nil {
logger.Warnf("error creating socket. %v", err)
return
}
defer conn.Close()
var buf [bufSize]byte
codeResp := bytes.NewBuffer(nil)
for {
readBytes, err := conn.Read(buf[:])
if err == io.EOF {
break
}
if err != nil {
errChan <- err
return
}
codeResp.Write(buf[0:readBytes])
if readBytes < bufSize {
break
}
}

errChan <- nil
successChan <- codeResp.Bytes()

responseBody := <-responseBodyChan
respToBrowser, err := buildResponse(responseBody)
if err != nil {
logger.Warnf("cannot create response to browser. %v", err)
}
_, err = conn.Write(respToBrowser.Bytes())
if err != nil {
logger.Warnf("cannot write response to browser. %v", err)
}
closeListenerChan <- true
}

type authorizationCodeProvider interface {
run(authorizationUrl string) error
createState() string
createCodeVerifier() string
}

type browserBasedAuthorizationCodeProvider struct {
conn *net.TCPConn

Check failure on line 240 in auth_oauth.go

View workflow job for this annotation

GitHub Actions / Check linter

field `conn` is unused (unused)
}

func newBrowserBasedAuthorizationCodeProvider() authorizationCodeProvider {

Check failure on line 243 in auth_oauth.go

View workflow job for this annotation

GitHub Actions / Check linter

func `newBrowserBasedAuthorizationCodeProvider` is unused (unused)
return &browserBasedAuthorizationCodeProvider{}
}

func (provider *browserBasedAuthorizationCodeProvider) run(authorizationUrl string) error {
return openBrowser(authorizationUrl)
}

func (provider *browserBasedAuthorizationCodeProvider) createState() string {
return NewUUID().String()
}

func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() string {
return oauth2.GenerateVerifier()
}
Loading

0 comments on commit 7c843f8

Please sign in to comment.