diff --git a/auth.go b/auth.go
index ae214bb3f..46690c9bd 100644
--- a/auth.go
+++ b/auth.go
@@ -13,6 +13,7 @@ import (
"io"
"net/http"
"net/url"
+ "os"
"runtime"
"strconv"
"strings"
@@ -49,6 +50,8 @@ const (
AuthTypeTokenAccessor
// AuthTypeUsernamePasswordMFA is to use username and password with mfa
AuthTypeUsernamePasswordMFA
+ // AuthTypeOAuthAuthorizationCode is to use browser-based OAuth2 flow
+ AuthTypeOAuthAuthorizationCode
)
func determineAuthenticatorType(cfg *Config, value string) error {
@@ -72,6 +75,12 @@ func determineAuthenticatorType(cfg *Config, value string) error {
} else if upperCaseValue == AuthTypeTokenAccessor.String() {
cfg.Authenticator = AuthTypeTokenAccessor
return nil
+ } else if upperCaseValue == AuthTypeOAuthAuthorizationCode.String() {
+ if experimentalAuthEnabled() {
+ cfg.Authenticator = AuthTypeOAuthAuthorizationCode
+ return nil
+ }
+ return errors.New("OAuth2 authorization code is not yet enabled")
} else {
// possibly Okta case
oktaURLString, err := url.QueryUnescape(lowerCaseValue)
@@ -121,6 +130,8 @@ func (authType AuthType) String() string {
return "TOKENACCESSOR"
case AuthTypeUsernamePasswordMFA:
return "USERNAME_PASSWORD_MFA"
+ case AuthTypeOAuthAuthorizationCode:
+ return "OAUTH_AUTHORIZATION_CODE"
default:
return "UNKNOWN"
}
@@ -441,7 +452,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
}
requestMain.Token = jwtTokenString
case AuthTypeSnowflake:
- logger.WithContext(sc.ctx).Info("Username and password")
+ logger.WithContext(sc.ctx).Debug("Username and password")
requestMain.LoginName = sc.cfg.User
requestMain.Password = sc.cfg.Password
switch {
@@ -452,7 +463,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
requestMain.ExtAuthnDuoMethod = "passcode"
}
case AuthTypeUsernamePasswordMFA:
- logger.WithContext(sc.ctx).Info("Username and password MFA")
+ logger.WithContext(sc.ctx).Debug("Username and password MFA")
requestMain.LoginName = sc.cfg.User
requestMain.Password = sc.cfg.Password
switch {
@@ -464,6 +475,21 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
requestMain.Passcode = sc.cfg.Passcode
requestMain.ExtAuthnDuoMethod = "passcode"
}
+ case AuthTypeOAuthAuthorizationCode:
+ logger.WithContext(sc.ctx).Debug("OAuth authorization code")
+ if !experimentalAuthEnabled() {
+ return nil, errors.New("OAuth2 is not yet enabled")
+ }
+ oauthClient, err := newOauthClient(sc.ctx, sc.cfg)
+ if err != nil {
+ return nil, err
+ }
+ token, err := oauthClient.authenticateByOAuthAuthorizationCode()
+ if err != nil {
+ return nil, err
+ }
+ requestMain.LoginName = sc.cfg.User
+ requestMain.Token = token
}
authRequest := authRequest{
@@ -571,3 +597,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
return nil
}
+
+func experimentalAuthEnabled() bool {
+ val, ok := os.LookupEnv("ENABLE_EXPERIMENTAL_AUTHENTICATION")
+ return ok && strings.EqualFold(val, "true")
+}
diff --git a/auth_oauth.go b/auth_oauth.go
new file mode 100644
index 000000000..9804e863e
--- /dev/null
+++ b/auth_oauth.go
@@ -0,0 +1,257 @@
+package gosnowflake
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "golang.org/x/oauth2"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+)
+
+const (
+ oauthSuccessHTML = `
+OAuth for Snowflake
+
+OAuth authentication completed successfully.
+`
+)
+
+type oauthClient struct {
+ ctx context.Context
+ cfg *Config
+
+ port int
+ redirectUriTemplate string
+
+ authorizationCodeProviderFactory func() authorizationCodeProvider
+}
+
+func newOauthClient(ctx context.Context, cfg *Config) (*oauthClient, error) {
+ port := 0
+ if cfg.OauthRedirectURI != "" {
+ uri, err := url.Parse(cfg.OauthRedirectURI)
+ if err != nil {
+ return nil, err
+ }
+ portStr := uri.Port()
+ if portStr != "" {
+ if port, err = strconv.Atoi(portStr); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ redirectUriTemplate := ""
+ if cfg.OauthRedirectURI == "" {
+ redirectUriTemplate = "http://localhost:%v/"
+ }
+
+ client := &http.Client{
+ Transport: getTransport(cfg),
+ }
+ return &oauthClient{
+ ctx: context.WithValue(ctx, oauth2.HTTPClient, client),
+ cfg: cfg,
+ port: port,
+ redirectUriTemplate: redirectUriTemplate,
+ authorizationCodeProviderFactory: func() authorizationCodeProvider {
+ return &browserBasedAuthorizationCodeProvider{}
+ },
+ }, nil
+}
+
+func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode() (string, error) {
+ authCodeProvider := oauthClient.authorizationCodeProviderFactory()
+
+ // TODO timeouts
+ // TODO retries
+ successChan := make(chan []byte)
+ errChan := make(chan error)
+ responseBodyChan := make(chan string, 2)
+ closeListenerChan := make(chan bool, 2)
+
+ defer func() {
+ closeListenerChan <- true
+ close(successChan)
+ close(errChan)
+ close(responseBodyChan)
+ close(closeListenerChan)
+ }()
+
+ tcpListener, callbackPort, err := oauthClient.setupListener()
+ if err != nil {
+ return "", err
+ }
+ defer func(tcpListener *net.TCPListener) {
+ <-closeListenerChan
+ if err := tcpListener.Close(); err != nil {
+ logger.Warnf("error while closing TCP listener. %v", err)
+ }
+ }(tcpListener)
+
+ go handleOAuthSocket(tcpListener, successChan, errChan, responseBodyChan, closeListenerChan)
+
+ oauth2cfg := oauthClient.buildOauthConfig(callbackPort)
+ codeVerifier := authCodeProvider.createCodeVerifier()
+ state := authCodeProvider.createState()
+ authorizationUrl := oauth2cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(codeVerifier))
+ if err = authCodeProvider.run(authorizationUrl); err != nil {
+ responseBodyChan <- err.Error()
+ closeListenerChan <- true
+ return "", err
+ }
+
+ err = <-errChan
+ if err != nil {
+ responseBodyChan <- err.Error()
+ return "", err
+ }
+ codeReqBytes := <-successChan
+
+ codeReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(codeReqBytes)))
+ if err != nil {
+ responseBodyChan <- err.Error()
+ return "", err
+ }
+
+ return oauthClient.exchangeAccessToken(codeReq, state, err, oauth2cfg, codeVerifier, responseBodyChan, errChan)
+}
+
+func (oauthClient *oauthClient) setupListener() (*net.TCPListener, int, error) {
+ tcpListener, err := createLocalTCPListener(oauthClient.port)
+ if err != nil {
+ return nil, 0, err
+ }
+ callbackPort := tcpListener.Addr().(*net.TCPAddr).Port
+ return tcpListener, callbackPort, nil
+}
+
+func (oauthClient *oauthClient) exchangeAccessToken(codeReq *http.Request, state string, err error, oauth2cfg *oauth2.Config, codeVerifier string, responseBodyChan chan string, errChan chan error) (string, error) {
+ queryParams := codeReq.URL.Query()
+ errorMsg := queryParams.Get("error")
+ if errorMsg != "" {
+ errorDesc := queryParams.Get("error_description")
+ errMsg := fmt.Sprintf("error while getting authentication from oauth: %v. Details: %v", errorMsg, errorDesc)
+ responseBodyChan <- errMsg
+ return "", errors.New(errMsg)
+ }
+
+ receivedState := queryParams.Get("state")
+ if state != receivedState {
+ errMsg := "invalid oauth state received"
+ responseBodyChan <- errMsg
+ return "", errors.New(errMsg)
+ }
+
+ code := queryParams.Get("code")
+ token, err := oauth2cfg.Exchange(oauthClient.ctx, code, oauth2.VerifierOption(codeVerifier))
+ if err != nil {
+ responseBodyChan <- err.Error()
+ return "", err
+ }
+ responseBodyChan <- oauthSuccessHTML
+ return token.AccessToken, nil
+}
+
+func (oauthClient *oauthClient) buildOauthConfig(callbackPort int) *oauth2.Config {
+ return &oauth2.Config{
+ ClientID: oauthClient.cfg.OauthClientID,
+ ClientSecret: oauthClient.cfg.OauthClientSecret,
+ RedirectURL: oauthClient.buildRedirectUri(callbackPort),
+ Scopes: oauthClient.buildScopes(),
+ Endpoint: oauth2.Endpoint{
+ AuthURL: oauthClient.cfg.OauthAuthorizationURL,
+ TokenURL: oauthClient.cfg.OauthTokenRequestURL,
+ },
+ }
+}
+
+func (oauthClient *oauthClient) buildRedirectUri(port int) string {
+ if oauthClient.cfg.OauthRedirectURI != "" {
+ return oauthClient.cfg.OauthRedirectURI
+ }
+ return fmt.Sprintf(oauthClient.redirectUriTemplate, port)
+}
+
+func (oauthClient *oauthClient) buildScopes() []string {
+ if oauthClient.cfg.OauthScope == "" {
+ return []string{"session:role:" + oauthClient.cfg.Role}
+ }
+ scopes := strings.Split(oauthClient.cfg.OauthScope, ",")
+ for i, scope := range scopes {
+ scopes[i] = strings.TrimSpace(scope)
+ }
+ return scopes
+}
+
+func handleOAuthSocket(tcpListener *net.TCPListener, successChan chan []byte, errChan chan error, responseBodyChan chan string, closeListenerChan chan bool) {
+ conn, err := tcpListener.AcceptTCP()
+ if err != nil {
+ logger.Warnf("error creating socket. %v", err)
+ return
+ }
+ defer conn.Close()
+ var buf [bufSize]byte
+ codeResp := bytes.NewBuffer(nil)
+ for {
+ readBytes, err := conn.Read(buf[:])
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ errChan <- err
+ return
+ }
+ codeResp.Write(buf[0:readBytes])
+ if readBytes < bufSize {
+ break
+ }
+ }
+
+ errChan <- nil
+ successChan <- codeResp.Bytes()
+
+ responseBody := <-responseBodyChan
+ respToBrowser, err := buildResponse(responseBody)
+ if err != nil {
+ logger.Warnf("cannot create response to browser. %v", err)
+ }
+ _, err = conn.Write(respToBrowser.Bytes())
+ if err != nil {
+ logger.Warnf("cannot write response to browser. %v", err)
+ }
+ closeListenerChan <- true
+}
+
+type authorizationCodeProvider interface {
+ run(authorizationUrl string) error
+ createState() string
+ createCodeVerifier() string
+}
+
+type browserBasedAuthorizationCodeProvider struct {
+ conn *net.TCPConn
+}
+
+func newBrowserBasedAuthorizationCodeProvider() authorizationCodeProvider {
+ return &browserBasedAuthorizationCodeProvider{}
+}
+
+func (provider *browserBasedAuthorizationCodeProvider) run(authorizationUrl string) error {
+ return openBrowser(authorizationUrl)
+}
+
+func (provider *browserBasedAuthorizationCodeProvider) createState() string {
+ return NewUUID().String()
+}
+
+func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() string {
+ return oauth2.GenerateVerifier()
+}
diff --git a/auth_oauth_test.go b/auth_oauth_test.go
new file mode 100644
index 000000000..275aafe55
--- /dev/null
+++ b/auth_oauth_test.go
@@ -0,0 +1,123 @@
+package gosnowflake
+
+import (
+ "context"
+ "errors"
+ "golang.org/x/oauth2"
+ "io"
+ "net/http"
+ "testing"
+ "time"
+)
+
+func TestUnitOAuthAuthorizationCode(t *testing.T) {
+ client, err := newOauthClient(context.Background(), &Config{
+ Role: "ANALYST",
+ OauthClientID: "testClientId",
+ OauthClientSecret: "testClientSecret",
+ OauthAuthorizationURL: wiremock.baseUrl() + "/oauth/authorize",
+ OauthTokenRequestURL: wiremock.baseUrl() + "/oauth/token",
+ OauthRedirectURI: "http://localhost:1234/snowflake/oauth-redirect",
+ })
+ assertNilF(t, err)
+
+ t.Run("Success", func(t *testing.T) {
+ wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json"))
+ authCodeProvider := &nonInteractiveAuthorizationCodeProvider{}
+ client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
+ return authCodeProvider
+ }
+ token, err := client.authenticateByOAuthAuthorizationCode()
+ assertNilF(t, err)
+ assertEqualE(t, token, "access-token-123")
+ time.Sleep(100 * time.Millisecond)
+ assertStringContainsE(t, authCodeProvider.responseBody, "OAuth authentication completed successfully.")
+ })
+
+ t.Run("InvalidState", func(t *testing.T) {
+ wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json"))
+ authCodeProvider := &nonInteractiveAuthorizationCodeProvider{
+ tamperWithState: true,
+ }
+ client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
+ return authCodeProvider
+ }
+ _, err = client.authenticateByOAuthAuthorizationCode()
+ assertEqualE(t, err.Error(), "invalid oauth state received")
+ time.Sleep(100 * time.Millisecond)
+ assertStringContainsE(t, authCodeProvider.responseBody, "invalid oauth state received")
+ })
+
+ t.Run("ErrorFromIdPWhileGettingCode", func(t *testing.T) {
+ wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/error_from_idp.json"))
+ authCodeProvider := &nonInteractiveAuthorizationCodeProvider{}
+ client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
+ return authCodeProvider
+ }
+ _, err = client.authenticateByOAuthAuthorizationCode()
+ assertEqualE(t, err.Error(), "error while getting authentication from oauth: some error. Details: some error desc")
+ time.Sleep(100 * time.Millisecond)
+ assertStringContainsE(t, authCodeProvider.responseBody, "error while getting authentication from oauth: some error. Details: some error desc")
+ })
+
+ t.Run("ErrorFromProviderWhileGettingCode", func(t *testing.T) {
+ authCodeProvider := &nonInteractiveAuthorizationCodeProvider{
+ triggerError: "test error",
+ }
+ client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
+ return authCodeProvider
+ }
+ _, err = client.authenticateByOAuthAuthorizationCode()
+ assertEqualE(t, err.Error(), "test error")
+ })
+
+ t.Run("InvalidCode", func(t *testing.T) {
+ wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/invalid_code.json"))
+ authCodeProvider := &nonInteractiveAuthorizationCodeProvider{}
+ client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
+ return authCodeProvider
+ }
+ _, err = client.authenticateByOAuthAuthorizationCode()
+ assertNotNilE(t, err)
+ assertEqualE(t, err.(*oauth2.RetrieveError).ErrorCode, "invalid_grant")
+ assertEqualE(t, err.(*oauth2.RetrieveError).ErrorDescription, "The authorization code is invalid or has expired.")
+ time.Sleep(100 * time.Millisecond)
+ assertStringContainsE(t, authCodeProvider.responseBody, "invalid_grant")
+ })
+}
+
+type nonInteractiveAuthorizationCodeProvider struct {
+ t *testing.T
+ callbackUrl string
+ tamperWithState bool
+ error string
+ errorDescription string
+ triggerError string
+ responseBody string
+}
+
+func (provider *nonInteractiveAuthorizationCodeProvider) run(authorizationUrl string) error {
+ if provider.triggerError != "" {
+ return errors.New(provider.triggerError)
+ }
+ go func() {
+ resp, err := http.Get(authorizationUrl)
+ assertNilF(provider.t, err)
+ assertEqualE(provider.t, resp.StatusCode, http.StatusOK)
+ respBody, err := io.ReadAll(resp.Body)
+ assertNilF(provider.t, err)
+ provider.responseBody = string(respBody)
+ }()
+ return nil
+}
+
+func (provider *nonInteractiveAuthorizationCodeProvider) createState() string {
+ if provider.tamperWithState {
+ return "invalidState"
+ }
+ return "testState"
+}
+
+func (provider *nonInteractiveAuthorizationCodeProvider) createCodeVerifier() string {
+ return "testCodeVerifier"
+}
diff --git a/auth_test.go b/auth_test.go
index c559a2407..76bcbbb62 100644
--- a/auth_test.go
+++ b/auth_test.go
@@ -1003,3 +1003,22 @@ func TestContextPropagatedToAuthWhenUsingOpenDB(t *testing.T) {
assertStringContainsE(t, err.Error(), "context deadline exceeded")
cancel()
}
+
+func TestWithOauthAuthorizationCodeFlowManual(t *testing.T) {
+ t.Skip("manual test")
+ cfg, err := GetConfigFromEnv([]*ConfigParam{
+ {"OAuthClientId", "SNOWFLAKE_TEST_OAUTH_CLIENT_ID", true},
+ {"OAuthClientSecret", "SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET", true},
+ {"OAuthAuthorizationURL", "SNOWFLAKE_TEST_OAUTH_AUTHORIZATION_URL", true},
+ {"OAuthTokenRequestURL", "SNOWFLAKE_TEST_OAUTH_TOKEN_REQUEST_URL", true},
+ {"OAuthRedirectURI", "SNOWFLAKE_TEST_OAUTH_REDIRECT_URI", true},
+ {"Role", "SNOWFLAKE_TEST_OAUTH_ROLE", true},
+ {"Account", "SNOWFLAKE_TEST_ACCOUNT", true},
+ })
+ assertNilF(t, err)
+ cfg.Authenticator = AuthTypeOAuthAuthorizationCode
+ connector := NewConnector(&SnowflakeDriver{}, *cfg)
+ db := sql.OpenDB(connector)
+ defer db.Close()
+ runSmokeQuery(t, db)
+}
diff --git a/authexternalbrowser.go b/authexternalbrowser.go
index 04a91fd97..67f69fb6d 100644
--- a/authexternalbrowser.go
+++ b/authexternalbrowser.go
@@ -22,22 +22,19 @@ import (
)
const (
- successHTML = `
+ samlSuccessHTML = `
SAML Response for Snowflake
Your identity was confirmed and propagated to Snowflake %v.
You can close this window now and go back where you started from.
`
-)
-const (
bufSize = 8192
)
// Builds a response to show to the user after successfully
// getting a response from Snowflake.
-func buildResponse(application string) (bytes.Buffer, error) {
- body := fmt.Sprintf(successHTML, application)
+func buildResponse(body string) (bytes.Buffer, error) {
t := &http.Response{
Status: "200 OK",
StatusCode: 200,
@@ -57,8 +54,8 @@ func buildResponse(application string) (bytes.Buffer, error) {
// This opens a socket that listens on all available unicast
// and any anycast IP addresses locally. By specifying "0", we are
// able to bind to a free port.
-func createLocalTCPListener() (*net.TCPListener, error) {
- l, err := net.Listen("tcp", "localhost:0")
+func createLocalTCPListener(port int) (*net.TCPListener, error) {
+ l, err := net.Listen("tcp", fmt.Sprintf("localhost:%v", port))
if err != nil {
return nil, err
}
@@ -243,7 +240,7 @@ func doAuthenticateByExternalBrowser(
password string,
disableConsoleLogin ConfigBool,
) authenticateByExternalBrowserResult {
- l, err := createLocalTCPListener()
+ l, err := createLocalTCPListener(0)
if err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}
@@ -310,7 +307,8 @@ func doAuthenticateByExternalBrowser(
buf.Grow(bufSize)
}
if encodedSamlResponse != "" {
- httpResponse, err := buildResponse(application)
+ body := fmt.Sprintf(samlSuccessHTML, application)
+ httpResponse, err := buildResponse(body)
if err != nil && errAccept == nil {
errAccept = err
}
diff --git a/authexternalbrowser_test.go b/authexternalbrowser_test.go
index 4546da461..4c9765d1e 100644
--- a/authexternalbrowser_test.go
+++ b/authexternalbrowser_test.go
@@ -135,7 +135,7 @@ func TestAuthenticationTimeout(t *testing.T) {
}
func Test_createLocalTCPListener(t *testing.T) {
- listener, err := createLocalTCPListener()
+ listener, err := createLocalTCPListener(0)
if err != nil {
t.Fatalf("createLocalTCPListener() failed: %v", err)
}
diff --git a/driver_test.go b/driver_test.go
index ba465dd5a..daf34af99 100644
--- a/driver_test.go
+++ b/driver_test.go
@@ -2111,3 +2111,14 @@ func TestTimePrecision(t *testing.T) {
}
})
}
+
+func runSmokeQuery(t *testing.T, db *sql.DB) {
+ rows, err := db.Query("SELECT 1")
+ assertNilF(t, err)
+ defer rows.Close()
+ assertTrueF(t, rows.Next())
+ var v int
+ err = rows.Scan(&v)
+ assertNilF(t, err)
+ assertEqualE(t, v, 1)
+}
\ No newline at end of file
diff --git a/dsn.go b/dsn.go
index d63eccc84..810c4fcd2 100644
--- a/dsn.go
+++ b/dsn.go
@@ -601,14 +601,16 @@ func buildHostFromAccountAndRegion(account, region string) string {
func authRequiresUser(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
- cfg.Authenticator != AuthTypeExternalBrowser
+ cfg.Authenticator != AuthTypeExternalBrowser &&
+ cfg.Authenticator != AuthTypeOAuthAuthorizationCode
}
func authRequiresPassword(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
cfg.Authenticator != AuthTypeExternalBrowser &&
- cfg.Authenticator != AuthTypeJwt
+ cfg.Authenticator != AuthTypeJwt &&
+ cfg.Authenticator != AuthTypeOAuthAuthorizationCode
}
// transformAccountToHost transforms account to host
@@ -943,6 +945,7 @@ type ConfigParam struct {
// GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config
func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string
+ var oauthClientId, oauthClientSecret, oauthAuthorizationUrl, oauthTokenRequestUrl, oauthRedirectUri string
var privateKey *rsa.PrivateKey
var err error
if len(properties) == 0 || properties == nil {
@@ -985,6 +988,18 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
if err != nil {
return nil, err
}
+ case "OAuthClientId":
+ oauthClientId = value
+ case "OAuthClientSecret":
+ oauthClientSecret = value
+ case "OAuthAuthorizationURL":
+ oauthAuthorizationUrl = value
+ case "OAuthTokenRequestURL":
+ oauthTokenRequestUrl = value
+ case "OAuthRedirectURI":
+ oauthRedirectUri = value
+ default:
+ return nil, errors.New("unknown property: " + prop.Name)
}
}
@@ -997,21 +1012,26 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
}
cfg := &Config{
- Account: account,
- User: user,
- Password: password,
- Role: role,
- Host: host,
- Port: port,
- Protocol: protocol,
- Warehouse: warehouse,
- Database: database,
- Schema: schema,
- PrivateKey: privateKey,
- Region: region,
- Passcode: passcode,
- Application: application,
- Params: map[string]*string{},
+ Account: account,
+ User: user,
+ Password: password,
+ Role: role,
+ Host: host,
+ Port: port,
+ Protocol: protocol,
+ Warehouse: warehouse,
+ Database: database,
+ Schema: schema,
+ PrivateKey: privateKey,
+ Region: region,
+ Passcode: passcode,
+ Application: application,
+ OauthClientID: oauthClientId,
+ OauthClientSecret: oauthClientSecret,
+ OauthAuthorizationURL: oauthAuthorizationUrl,
+ OauthTokenRequestURL: oauthTokenRequestUrl,
+ OauthRedirectURI: oauthRedirectUri,
+ Params: map[string]*string{},
}
return cfg, nil
}
diff --git a/dsn_test.go b/dsn_test.go
index 2f63f9acd..6fae4f89f 100644
--- a/dsn_test.go
+++ b/dsn_test.go
@@ -466,6 +466,22 @@ func TestParseDSN(t *testing.T) {
ocspMode: ocspModeFailOpen,
err: nil,
},
+ {
+ dsn: "snowflake.local:9876?account=a&protocol=http&authenticator=OAUTH_AUTHORIZATION_CODE",
+ config: &Config{
+ Account: "a", Authenticator: AuthTypeOAuthAuthorizationCode,
+ Protocol: "http", Host: "snowflake.local", Port: 9876,
+ OCSPFailOpen: OCSPFailOpenTrue,
+ ValidateDefaultParameters: ConfigBoolTrue,
+ ClientTimeout: defaultClientTimeout,
+ JWTClientTimeout: defaultJWTClientTimeout,
+ ExternalBrowserTimeout: defaultExternalBrowserTimeout,
+ CloudStorageTimeout: defaultCloudStorageTimeout,
+ IncludeRetryReason: ConfigBoolTrue,
+ },
+ ocspMode: ocspModeFailOpen,
+ err: nil,
+ },
{
dsn: "u:@a.snowflake.local:9876?account=a&protocol=http&authenticator=SNOWFLAKE_JWT",
config: &Config{
@@ -1515,6 +1531,16 @@ func TestDSN(t *testing.T) {
},
dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&clientStoreTemporaryCredential=false&ocspFailOpen=true&validateDefaultParameters=true",
},
+ {
+ cfg: &Config{
+ User: "u",
+ Password: "p",
+ Account: "a",
+ Authenticator: AuthTypeOAuthAuthorizationCode,
+ ClientStoreTemporaryCredential: ConfigBoolFalse,
+ },
+ dsn: "u:p@a.snowflakecomputing.com:443?authenticator=oauth_authorization_code&clientStoreTemporaryCredential=false&ocspFailOpen=true&validateDefaultParameters=true",
+ },
{
cfg: &Config{
User: "u",
diff --git a/go.mod b/go.mod
index efb9a2e06..9b4850344 100644
--- a/go.mod
+++ b/go.mod
@@ -49,6 +49,7 @@ require (
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.33.0 // indirect
+ golang.org/x/oauth2 v0.26.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
diff --git a/go.sum b/go.sum
index 88392ef77..45156069b 100644
--- a/go.sum
+++ b/go.sum
@@ -137,6 +137,8 @@ golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
+golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE=
+golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
diff --git a/test_data/wiremock/mappings/oauth2/authorization_code/error_from_idp.json b/test_data/wiremock/mappings/oauth2/authorization_code/error_from_idp.json
new file mode 100644
index 000000000..1e3e352c5
--- /dev/null
+++ b/test_data/wiremock/mappings/oauth2/authorization_code/error_from_idp.json
@@ -0,0 +1,39 @@
+{
+ "mappings": [
+ {
+ "request": {
+ "urlPathPattern": "/oauth/authorize",
+ "queryParameters": {
+ "response_type": {
+ "equalTo": "code"
+ },
+ "scope": {
+ "equalTo": "session:role:ANALYST"
+ },
+ "code_challenge_method": {
+ "equalTo": "S256"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:1234/snowflake/oauth-redirect"
+ },
+ "code_challenge": {
+ "matches": ".+"
+ },
+ "state": {
+ "matches": "testState|invalidState"
+ },
+ "client_id": {
+ "equalTo": "testClientId"
+ }
+ },
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:1234/snowflake/oauth-redirect?error=some+error&error_description=some+error+desc"
+ }
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/test_data/wiremock/mappings/oauth2/authorization_code/invalid_code.json b/test_data/wiremock/mappings/oauth2/authorization_code/invalid_code.json
new file mode 100644
index 000000000..41329e16a
--- /dev/null
+++ b/test_data/wiremock/mappings/oauth2/authorization_code/invalid_code.json
@@ -0,0 +1,78 @@
+{
+ "mappings": [
+ {
+ "request": {
+ "urlPathPattern": "/oauth/authorize",
+ "queryParameters": {
+ "response_type": {
+ "equalTo": "code"
+ },
+ "scope": {
+ "equalTo": "session:role:ANALYST"
+ },
+ "code_challenge_method": {
+ "equalTo": "S256"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:1234/snowflake/oauth-redirect"
+ },
+ "code_challenge": {
+ "matches": ".+"
+ },
+ "state": {
+ "matches": "testState"
+ },
+ "client_id": {
+ "equalTo": "testClientId"
+ }
+ },
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:1234/snowflake/oauth-redirect?code=testCode&state=testState"
+ }
+ }
+ },
+ {
+ "scenarioName": "Successful token exchange",
+ "request": {
+ "urlPathPattern": "/oauth/token",
+ "method": "POST",
+ "headers": {
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded"
+ }
+ },
+ "formParameters": {
+ "grant_type": {
+ "equalTo": "authorization_code"
+ },
+ "client_id": {
+ "equalTo": "testClientId"
+ },
+ "client_secret": {
+ "equalTo": "testClientSecret"
+ },
+ "code_verifier": {
+ "matches": "[a-zA-Z0-9\\-_]+"
+ },
+ "code": {
+ "equalTo": "testCode"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:1234/snowflake/oauth-redirect"
+ }
+ }
+ },
+ "response": {
+ "status": 400,
+ "jsonBody": {
+ "error" : "invalid_grant",
+ "error_description" : "The authorization code is invalid or has expired."
+ }
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/test_data/wiremock/mappings/oauth2/authorization_code/successful_flow.json b/test_data/wiremock/mappings/oauth2/authorization_code/successful_flow.json
new file mode 100644
index 000000000..3715a4752
--- /dev/null
+++ b/test_data/wiremock/mappings/oauth2/authorization_code/successful_flow.json
@@ -0,0 +1,83 @@
+{
+ "mappings": [
+ {
+ "request": {
+ "urlPathPattern": "/oauth/authorize",
+ "queryParameters": {
+ "response_type": {
+ "equalTo": "code"
+ },
+ "scope": {
+ "equalTo": "session:role:ANALYST"
+ },
+ "code_challenge_method": {
+ "equalTo": "S256"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:1234/snowflake/oauth-redirect"
+ },
+ "code_challenge": {
+ "matches": "JZpN_-zfNduuWm-zUo-D-m7vMw_pgUGv8wGDGqBR8PM"
+ },
+ "state": {
+ "matches": "testState|invalidState"
+ },
+ "client_id": {
+ "equalTo": "testClientId"
+ }
+ },
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:1234/snowflake/oauth-redirect?code=testCode&state=testState"
+ }
+ }
+ },
+ {
+ "request": {
+ "urlPathPattern": "/oauth/token",
+ "method": "POST",
+ "headers": {
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded"
+ }
+ },
+ "formParameters": {
+ "grant_type": {
+ "equalTo": "authorization_code"
+ },
+ "client_id": {
+ "equalTo": "testClientId"
+ },
+ "client_secret": {
+ "equalTo": "testClientSecret"
+ },
+ "code_verifier": {
+ "matches": "testCodeVerifier"
+ },
+ "code": {
+ "equalTo": "testCode"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:1234/snowflake/oauth-redirect"
+ }
+ }
+ },
+ "response": {
+ "status": 200,
+ "jsonBody": {
+ "access_token": "access-token-123",
+ "refresh_token": "123",
+ "token_type": "Bearer",
+ "username": "test-user",
+ "scope": "refresh_token session:role:ANALYST",
+ "expires_in": 600,
+ "refresh_token_expires_in": 86399,
+ "idpInitiated": false
+ }
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/wiremock_test.go b/wiremock_test.go
index ed37c0bdb..81e23106f 100644
--- a/wiremock_test.go
+++ b/wiremock_test.go
@@ -58,6 +58,10 @@ type wiremockMapping struct {
params map[string]string
}
+func newWiremockMapping(filePath string) wiremockMapping {
+ return wiremockMapping{filePath: filePath}
+}
+
func (wm *wiremockClient) registerMappings(t *testing.T, mappings ...wiremockMapping) {
for _, mapping := range wm.enrichWithTelemetry(mappings) {
f, err := os.Open("test_data/wiremock/mappings/" + mapping.filePath)
@@ -92,7 +96,11 @@ func (wm *wiremockClient) enrichWithTelemetry(mappings []wiremockMapping) []wire
}
func (wm *wiremockClient) mappingsURL() string {
- return fmt.Sprintf("%v://%v:%v/__admin/mappings", wm.protocol, wm.host, wm.port)
+ return fmt.Sprintf("%v/__admin/mappings", wm.baseUrl())
+}
+
+func (wm *wiremockClient) baseUrl() string {
+ return fmt.Sprintf("%v://%v:%v", wm.protocol, wm.host, wm.port)
}
// just to satisfy not used private variables and functions