diff --git a/auth.go b/auth.go index ae214bb3f..46690c9bd 100644 --- a/auth.go +++ b/auth.go @@ -13,6 +13,7 @@ import ( "io" "net/http" "net/url" + "os" "runtime" "strconv" "strings" @@ -49,6 +50,8 @@ const ( AuthTypeTokenAccessor // AuthTypeUsernamePasswordMFA is to use username and password with mfa AuthTypeUsernamePasswordMFA + // AuthTypeOAuthAuthorizationCode is to use browser-based OAuth2 flow + AuthTypeOAuthAuthorizationCode ) func determineAuthenticatorType(cfg *Config, value string) error { @@ -72,6 +75,12 @@ func determineAuthenticatorType(cfg *Config, value string) error { } else if upperCaseValue == AuthTypeTokenAccessor.String() { cfg.Authenticator = AuthTypeTokenAccessor return nil + } else if upperCaseValue == AuthTypeOAuthAuthorizationCode.String() { + if experimentalAuthEnabled() { + cfg.Authenticator = AuthTypeOAuthAuthorizationCode + return nil + } + return errors.New("OAuth2 authorization code is not yet enabled") } else { // possibly Okta case oktaURLString, err := url.QueryUnescape(lowerCaseValue) @@ -121,6 +130,8 @@ func (authType AuthType) String() string { return "TOKENACCESSOR" case AuthTypeUsernamePasswordMFA: return "USERNAME_PASSWORD_MFA" + case AuthTypeOAuthAuthorizationCode: + return "OAUTH_AUTHORIZATION_CODE" default: return "UNKNOWN" } @@ -441,7 +452,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface } requestMain.Token = jwtTokenString case AuthTypeSnowflake: - logger.WithContext(sc.ctx).Info("Username and password") + logger.WithContext(sc.ctx).Debug("Username and password") requestMain.LoginName = sc.cfg.User requestMain.Password = sc.cfg.Password switch { @@ -452,7 +463,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface requestMain.ExtAuthnDuoMethod = "passcode" } case AuthTypeUsernamePasswordMFA: - logger.WithContext(sc.ctx).Info("Username and password MFA") + logger.WithContext(sc.ctx).Debug("Username and password MFA") requestMain.LoginName = sc.cfg.User requestMain.Password = sc.cfg.Password switch { @@ -464,6 +475,21 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface requestMain.Passcode = sc.cfg.Passcode requestMain.ExtAuthnDuoMethod = "passcode" } + case AuthTypeOAuthAuthorizationCode: + logger.WithContext(sc.ctx).Debug("OAuth authorization code") + if !experimentalAuthEnabled() { + return nil, errors.New("OAuth2 is not yet enabled") + } + oauthClient, err := newOauthClient(sc.ctx, sc.cfg) + if err != nil { + return nil, err + } + token, err := oauthClient.authenticateByOAuthAuthorizationCode() + if err != nil { + return nil, err + } + requestMain.LoginName = sc.cfg.User + requestMain.Token = token } authRequest := authRequest{ @@ -571,3 +597,8 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID) return nil } + +func experimentalAuthEnabled() bool { + val, ok := os.LookupEnv("ENABLE_EXPERIMENTAL_AUTHENTICATION") + return ok && strings.EqualFold(val, "true") +} diff --git a/auth_oauth.go b/auth_oauth.go new file mode 100644 index 000000000..9804e863e --- /dev/null +++ b/auth_oauth.go @@ -0,0 +1,257 @@ +package gosnowflake + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "golang.org/x/oauth2" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" +) + +const ( + oauthSuccessHTML = ` +OAuth for Snowflake + +OAuth authentication completed successfully. +` +) + +type oauthClient struct { + ctx context.Context + cfg *Config + + port int + redirectUriTemplate string + + authorizationCodeProviderFactory func() authorizationCodeProvider +} + +func newOauthClient(ctx context.Context, cfg *Config) (*oauthClient, error) { + port := 0 + if cfg.OauthRedirectURI != "" { + uri, err := url.Parse(cfg.OauthRedirectURI) + if err != nil { + return nil, err + } + portStr := uri.Port() + if portStr != "" { + if port, err = strconv.Atoi(portStr); err != nil { + return nil, err + } + } + } + + redirectUriTemplate := "" + if cfg.OauthRedirectURI == "" { + redirectUriTemplate = "http://localhost:%v/" + } + + client := &http.Client{ + Transport: getTransport(cfg), + } + return &oauthClient{ + ctx: context.WithValue(ctx, oauth2.HTTPClient, client), + cfg: cfg, + port: port, + redirectUriTemplate: redirectUriTemplate, + authorizationCodeProviderFactory: func() authorizationCodeProvider { + return &browserBasedAuthorizationCodeProvider{} + }, + }, nil +} + +func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode() (string, error) { + authCodeProvider := oauthClient.authorizationCodeProviderFactory() + + // TODO timeouts + // TODO retries + successChan := make(chan []byte) + errChan := make(chan error) + responseBodyChan := make(chan string, 2) + closeListenerChan := make(chan bool, 2) + + defer func() { + closeListenerChan <- true + close(successChan) + close(errChan) + close(responseBodyChan) + close(closeListenerChan) + }() + + tcpListener, callbackPort, err := oauthClient.setupListener() + if err != nil { + return "", err + } + defer func(tcpListener *net.TCPListener) { + <-closeListenerChan + if err := tcpListener.Close(); err != nil { + logger.Warnf("error while closing TCP listener. %v", err) + } + }(tcpListener) + + go handleOAuthSocket(tcpListener, successChan, errChan, responseBodyChan, closeListenerChan) + + oauth2cfg := oauthClient.buildOauthConfig(callbackPort) + codeVerifier := authCodeProvider.createCodeVerifier() + state := authCodeProvider.createState() + authorizationUrl := oauth2cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(codeVerifier)) + if err = authCodeProvider.run(authorizationUrl); err != nil { + responseBodyChan <- err.Error() + closeListenerChan <- true + return "", err + } + + err = <-errChan + if err != nil { + responseBodyChan <- err.Error() + return "", err + } + codeReqBytes := <-successChan + + codeReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(codeReqBytes))) + if err != nil { + responseBodyChan <- err.Error() + return "", err + } + + return oauthClient.exchangeAccessToken(codeReq, state, err, oauth2cfg, codeVerifier, responseBodyChan, errChan) +} + +func (oauthClient *oauthClient) setupListener() (*net.TCPListener, int, error) { + tcpListener, err := createLocalTCPListener(oauthClient.port) + if err != nil { + return nil, 0, err + } + callbackPort := tcpListener.Addr().(*net.TCPAddr).Port + return tcpListener, callbackPort, nil +} + +func (oauthClient *oauthClient) exchangeAccessToken(codeReq *http.Request, state string, err error, oauth2cfg *oauth2.Config, codeVerifier string, responseBodyChan chan string, errChan chan error) (string, error) { + queryParams := codeReq.URL.Query() + errorMsg := queryParams.Get("error") + if errorMsg != "" { + errorDesc := queryParams.Get("error_description") + errMsg := fmt.Sprintf("error while getting authentication from oauth: %v. Details: %v", errorMsg, errorDesc) + responseBodyChan <- errMsg + return "", errors.New(errMsg) + } + + receivedState := queryParams.Get("state") + if state != receivedState { + errMsg := "invalid oauth state received" + responseBodyChan <- errMsg + return "", errors.New(errMsg) + } + + code := queryParams.Get("code") + token, err := oauth2cfg.Exchange(oauthClient.ctx, code, oauth2.VerifierOption(codeVerifier)) + if err != nil { + responseBodyChan <- err.Error() + return "", err + } + responseBodyChan <- oauthSuccessHTML + return token.AccessToken, nil +} + +func (oauthClient *oauthClient) buildOauthConfig(callbackPort int) *oauth2.Config { + return &oauth2.Config{ + ClientID: oauthClient.cfg.OauthClientID, + ClientSecret: oauthClient.cfg.OauthClientSecret, + RedirectURL: oauthClient.buildRedirectUri(callbackPort), + Scopes: oauthClient.buildScopes(), + Endpoint: oauth2.Endpoint{ + AuthURL: oauthClient.cfg.OauthAuthorizationURL, + TokenURL: oauthClient.cfg.OauthTokenRequestURL, + }, + } +} + +func (oauthClient *oauthClient) buildRedirectUri(port int) string { + if oauthClient.cfg.OauthRedirectURI != "" { + return oauthClient.cfg.OauthRedirectURI + } + return fmt.Sprintf(oauthClient.redirectUriTemplate, port) +} + +func (oauthClient *oauthClient) buildScopes() []string { + if oauthClient.cfg.OauthScope == "" { + return []string{"session:role:" + oauthClient.cfg.Role} + } + scopes := strings.Split(oauthClient.cfg.OauthScope, ",") + for i, scope := range scopes { + scopes[i] = strings.TrimSpace(scope) + } + return scopes +} + +func handleOAuthSocket(tcpListener *net.TCPListener, successChan chan []byte, errChan chan error, responseBodyChan chan string, closeListenerChan chan bool) { + conn, err := tcpListener.AcceptTCP() + if err != nil { + logger.Warnf("error creating socket. %v", err) + return + } + defer conn.Close() + var buf [bufSize]byte + codeResp := bytes.NewBuffer(nil) + for { + readBytes, err := conn.Read(buf[:]) + if err == io.EOF { + break + } + if err != nil { + errChan <- err + return + } + codeResp.Write(buf[0:readBytes]) + if readBytes < bufSize { + break + } + } + + errChan <- nil + successChan <- codeResp.Bytes() + + responseBody := <-responseBodyChan + respToBrowser, err := buildResponse(responseBody) + if err != nil { + logger.Warnf("cannot create response to browser. %v", err) + } + _, err = conn.Write(respToBrowser.Bytes()) + if err != nil { + logger.Warnf("cannot write response to browser. %v", err) + } + closeListenerChan <- true +} + +type authorizationCodeProvider interface { + run(authorizationUrl string) error + createState() string + createCodeVerifier() string +} + +type browserBasedAuthorizationCodeProvider struct { + conn *net.TCPConn +} + +func newBrowserBasedAuthorizationCodeProvider() authorizationCodeProvider { + return &browserBasedAuthorizationCodeProvider{} +} + +func (provider *browserBasedAuthorizationCodeProvider) run(authorizationUrl string) error { + return openBrowser(authorizationUrl) +} + +func (provider *browserBasedAuthorizationCodeProvider) createState() string { + return NewUUID().String() +} + +func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() string { + return oauth2.GenerateVerifier() +} diff --git a/auth_oauth_test.go b/auth_oauth_test.go new file mode 100644 index 000000000..275aafe55 --- /dev/null +++ b/auth_oauth_test.go @@ -0,0 +1,123 @@ +package gosnowflake + +import ( + "context" + "errors" + "golang.org/x/oauth2" + "io" + "net/http" + "testing" + "time" +) + +func TestUnitOAuthAuthorizationCode(t *testing.T) { + client, err := newOauthClient(context.Background(), &Config{ + Role: "ANALYST", + OauthClientID: "testClientId", + OauthClientSecret: "testClientSecret", + OauthAuthorizationURL: wiremock.baseUrl() + "/oauth/authorize", + OauthTokenRequestURL: wiremock.baseUrl() + "/oauth/token", + OauthRedirectURI: "http://localhost:1234/snowflake/oauth-redirect", + }) + assertNilF(t, err) + + t.Run("Success", func(t *testing.T) { + wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json")) + authCodeProvider := &nonInteractiveAuthorizationCodeProvider{} + client.authorizationCodeProviderFactory = func() authorizationCodeProvider { + return authCodeProvider + } + token, err := client.authenticateByOAuthAuthorizationCode() + assertNilF(t, err) + assertEqualE(t, token, "access-token-123") + time.Sleep(100 * time.Millisecond) + assertStringContainsE(t, authCodeProvider.responseBody, "OAuth authentication completed successfully.") + }) + + t.Run("InvalidState", func(t *testing.T) { + wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json")) + authCodeProvider := &nonInteractiveAuthorizationCodeProvider{ + tamperWithState: true, + } + client.authorizationCodeProviderFactory = func() authorizationCodeProvider { + return authCodeProvider + } + _, err = client.authenticateByOAuthAuthorizationCode() + assertEqualE(t, err.Error(), "invalid oauth state received") + time.Sleep(100 * time.Millisecond) + assertStringContainsE(t, authCodeProvider.responseBody, "invalid oauth state received") + }) + + t.Run("ErrorFromIdPWhileGettingCode", func(t *testing.T) { + wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/error_from_idp.json")) + authCodeProvider := &nonInteractiveAuthorizationCodeProvider{} + client.authorizationCodeProviderFactory = func() authorizationCodeProvider { + return authCodeProvider + } + _, err = client.authenticateByOAuthAuthorizationCode() + assertEqualE(t, err.Error(), "error while getting authentication from oauth: some error. Details: some error desc") + time.Sleep(100 * time.Millisecond) + assertStringContainsE(t, authCodeProvider.responseBody, "error while getting authentication from oauth: some error. Details: some error desc") + }) + + t.Run("ErrorFromProviderWhileGettingCode", func(t *testing.T) { + authCodeProvider := &nonInteractiveAuthorizationCodeProvider{ + triggerError: "test error", + } + client.authorizationCodeProviderFactory = func() authorizationCodeProvider { + return authCodeProvider + } + _, err = client.authenticateByOAuthAuthorizationCode() + assertEqualE(t, err.Error(), "test error") + }) + + t.Run("InvalidCode", func(t *testing.T) { + wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/invalid_code.json")) + authCodeProvider := &nonInteractiveAuthorizationCodeProvider{} + client.authorizationCodeProviderFactory = func() authorizationCodeProvider { + return authCodeProvider + } + _, err = client.authenticateByOAuthAuthorizationCode() + assertNotNilE(t, err) + assertEqualE(t, err.(*oauth2.RetrieveError).ErrorCode, "invalid_grant") + assertEqualE(t, err.(*oauth2.RetrieveError).ErrorDescription, "The authorization code is invalid or has expired.") + time.Sleep(100 * time.Millisecond) + assertStringContainsE(t, authCodeProvider.responseBody, "invalid_grant") + }) +} + +type nonInteractiveAuthorizationCodeProvider struct { + t *testing.T + callbackUrl string + tamperWithState bool + error string + errorDescription string + triggerError string + responseBody string +} + +func (provider *nonInteractiveAuthorizationCodeProvider) run(authorizationUrl string) error { + if provider.triggerError != "" { + return errors.New(provider.triggerError) + } + go func() { + resp, err := http.Get(authorizationUrl) + assertNilF(provider.t, err) + assertEqualE(provider.t, resp.StatusCode, http.StatusOK) + respBody, err := io.ReadAll(resp.Body) + assertNilF(provider.t, err) + provider.responseBody = string(respBody) + }() + return nil +} + +func (provider *nonInteractiveAuthorizationCodeProvider) createState() string { + if provider.tamperWithState { + return "invalidState" + } + return "testState" +} + +func (provider *nonInteractiveAuthorizationCodeProvider) createCodeVerifier() string { + return "testCodeVerifier" +} diff --git a/auth_test.go b/auth_test.go index c559a2407..76bcbbb62 100644 --- a/auth_test.go +++ b/auth_test.go @@ -1003,3 +1003,22 @@ func TestContextPropagatedToAuthWhenUsingOpenDB(t *testing.T) { assertStringContainsE(t, err.Error(), "context deadline exceeded") cancel() } + +func TestWithOauthAuthorizationCodeFlowManual(t *testing.T) { + t.Skip("manual test") + cfg, err := GetConfigFromEnv([]*ConfigParam{ + {"OAuthClientId", "SNOWFLAKE_TEST_OAUTH_CLIENT_ID", true}, + {"OAuthClientSecret", "SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET", true}, + {"OAuthAuthorizationURL", "SNOWFLAKE_TEST_OAUTH_AUTHORIZATION_URL", true}, + {"OAuthTokenRequestURL", "SNOWFLAKE_TEST_OAUTH_TOKEN_REQUEST_URL", true}, + {"OAuthRedirectURI", "SNOWFLAKE_TEST_OAUTH_REDIRECT_URI", true}, + {"Role", "SNOWFLAKE_TEST_OAUTH_ROLE", true}, + {"Account", "SNOWFLAKE_TEST_ACCOUNT", true}, + }) + assertNilF(t, err) + cfg.Authenticator = AuthTypeOAuthAuthorizationCode + connector := NewConnector(&SnowflakeDriver{}, *cfg) + db := sql.OpenDB(connector) + defer db.Close() + runSmokeQuery(t, db) +} diff --git a/authexternalbrowser.go b/authexternalbrowser.go index 04a91fd97..67f69fb6d 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -22,22 +22,19 @@ import ( ) const ( - successHTML = ` + samlSuccessHTML = ` SAML Response for Snowflake Your identity was confirmed and propagated to Snowflake %v. You can close this window now and go back where you started from. ` -) -const ( bufSize = 8192 ) // Builds a response to show to the user after successfully // getting a response from Snowflake. -func buildResponse(application string) (bytes.Buffer, error) { - body := fmt.Sprintf(successHTML, application) +func buildResponse(body string) (bytes.Buffer, error) { t := &http.Response{ Status: "200 OK", StatusCode: 200, @@ -57,8 +54,8 @@ func buildResponse(application string) (bytes.Buffer, error) { // This opens a socket that listens on all available unicast // and any anycast IP addresses locally. By specifying "0", we are // able to bind to a free port. -func createLocalTCPListener() (*net.TCPListener, error) { - l, err := net.Listen("tcp", "localhost:0") +func createLocalTCPListener(port int) (*net.TCPListener, error) { + l, err := net.Listen("tcp", fmt.Sprintf("localhost:%v", port)) if err != nil { return nil, err } @@ -243,7 +240,7 @@ func doAuthenticateByExternalBrowser( password string, disableConsoleLogin ConfigBool, ) authenticateByExternalBrowserResult { - l, err := createLocalTCPListener() + l, err := createLocalTCPListener(0) if err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } @@ -310,7 +307,8 @@ func doAuthenticateByExternalBrowser( buf.Grow(bufSize) } if encodedSamlResponse != "" { - httpResponse, err := buildResponse(application) + body := fmt.Sprintf(samlSuccessHTML, application) + httpResponse, err := buildResponse(body) if err != nil && errAccept == nil { errAccept = err } diff --git a/authexternalbrowser_test.go b/authexternalbrowser_test.go index 4546da461..4c9765d1e 100644 --- a/authexternalbrowser_test.go +++ b/authexternalbrowser_test.go @@ -135,7 +135,7 @@ func TestAuthenticationTimeout(t *testing.T) { } func Test_createLocalTCPListener(t *testing.T) { - listener, err := createLocalTCPListener() + listener, err := createLocalTCPListener(0) if err != nil { t.Fatalf("createLocalTCPListener() failed: %v", err) } diff --git a/driver_test.go b/driver_test.go index ba465dd5a..daf34af99 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2111,3 +2111,14 @@ func TestTimePrecision(t *testing.T) { } }) } + +func runSmokeQuery(t *testing.T, db *sql.DB) { + rows, err := db.Query("SELECT 1") + assertNilF(t, err) + defer rows.Close() + assertTrueF(t, rows.Next()) + var v int + err = rows.Scan(&v) + assertNilF(t, err) + assertEqualE(t, v, 1) +} \ No newline at end of file diff --git a/dsn.go b/dsn.go index d63eccc84..810c4fcd2 100644 --- a/dsn.go +++ b/dsn.go @@ -601,14 +601,16 @@ func buildHostFromAccountAndRegion(account, region string) string { func authRequiresUser(cfg *Config) bool { return cfg.Authenticator != AuthTypeOAuth && cfg.Authenticator != AuthTypeTokenAccessor && - cfg.Authenticator != AuthTypeExternalBrowser + cfg.Authenticator != AuthTypeExternalBrowser && + cfg.Authenticator != AuthTypeOAuthAuthorizationCode } func authRequiresPassword(cfg *Config) bool { return cfg.Authenticator != AuthTypeOAuth && cfg.Authenticator != AuthTypeTokenAccessor && cfg.Authenticator != AuthTypeExternalBrowser && - cfg.Authenticator != AuthTypeJwt + cfg.Authenticator != AuthTypeJwt && + cfg.Authenticator != AuthTypeOAuthAuthorizationCode } // transformAccountToHost transforms account to host @@ -943,6 +945,7 @@ type ConfigParam struct { // GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) { var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string + var oauthClientId, oauthClientSecret, oauthAuthorizationUrl, oauthTokenRequestUrl, oauthRedirectUri string var privateKey *rsa.PrivateKey var err error if len(properties) == 0 || properties == nil { @@ -985,6 +988,18 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) { if err != nil { return nil, err } + case "OAuthClientId": + oauthClientId = value + case "OAuthClientSecret": + oauthClientSecret = value + case "OAuthAuthorizationURL": + oauthAuthorizationUrl = value + case "OAuthTokenRequestURL": + oauthTokenRequestUrl = value + case "OAuthRedirectURI": + oauthRedirectUri = value + default: + return nil, errors.New("unknown property: " + prop.Name) } } @@ -997,21 +1012,26 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) { } cfg := &Config{ - Account: account, - User: user, - Password: password, - Role: role, - Host: host, - Port: port, - Protocol: protocol, - Warehouse: warehouse, - Database: database, - Schema: schema, - PrivateKey: privateKey, - Region: region, - Passcode: passcode, - Application: application, - Params: map[string]*string{}, + Account: account, + User: user, + Password: password, + Role: role, + Host: host, + Port: port, + Protocol: protocol, + Warehouse: warehouse, + Database: database, + Schema: schema, + PrivateKey: privateKey, + Region: region, + Passcode: passcode, + Application: application, + OauthClientID: oauthClientId, + OauthClientSecret: oauthClientSecret, + OauthAuthorizationURL: oauthAuthorizationUrl, + OauthTokenRequestURL: oauthTokenRequestUrl, + OauthRedirectURI: oauthRedirectUri, + Params: map[string]*string{}, } return cfg, nil } diff --git a/dsn_test.go b/dsn_test.go index 2f63f9acd..6fae4f89f 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -466,6 +466,22 @@ func TestParseDSN(t *testing.T) { ocspMode: ocspModeFailOpen, err: nil, }, + { + dsn: "snowflake.local:9876?account=a&protocol=http&authenticator=OAUTH_AUTHORIZATION_CODE", + config: &Config{ + Account: "a", Authenticator: AuthTypeOAuthAuthorizationCode, + Protocol: "http", Host: "snowflake.local", Port: 9876, + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, + IncludeRetryReason: ConfigBoolTrue, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, { dsn: "u:@a.snowflake.local:9876?account=a&protocol=http&authenticator=SNOWFLAKE_JWT", config: &Config{ @@ -1515,6 +1531,16 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&clientStoreTemporaryCredential=false&ocspFailOpen=true&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a", + Authenticator: AuthTypeOAuthAuthorizationCode, + ClientStoreTemporaryCredential: ConfigBoolFalse, + }, + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=oauth_authorization_code&clientStoreTemporaryCredential=false&ocspFailOpen=true&validateDefaultParameters=true", + }, { cfg: &Config{ User: "u", diff --git a/go.mod b/go.mod index efb9a2e06..9b4850344 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.33.0 // indirect + golang.org/x/oauth2 v0.26.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/go.sum b/go.sum index 88392ef77..45156069b 100644 --- a/go.sum +++ b/go.sum @@ -137,6 +137,8 @@ golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= +golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/test_data/wiremock/mappings/oauth2/authorization_code/error_from_idp.json b/test_data/wiremock/mappings/oauth2/authorization_code/error_from_idp.json new file mode 100644 index 000000000..1e3e352c5 --- /dev/null +++ b/test_data/wiremock/mappings/oauth2/authorization_code/error_from_idp.json @@ -0,0 +1,39 @@ +{ + "mappings": [ + { + "request": { + "urlPathPattern": "/oauth/authorize", + "queryParameters": { + "response_type": { + "equalTo": "code" + }, + "scope": { + "equalTo": "session:role:ANALYST" + }, + "code_challenge_method": { + "equalTo": "S256" + }, + "redirect_uri": { + "equalTo": "http://localhost:1234/snowflake/oauth-redirect" + }, + "code_challenge": { + "matches": ".+" + }, + "state": { + "matches": "testState|invalidState" + }, + "client_id": { + "equalTo": "testClientId" + } + }, + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:1234/snowflake/oauth-redirect?error=some+error&error_description=some+error+desc" + } + } + } + ] +} \ No newline at end of file diff --git a/test_data/wiremock/mappings/oauth2/authorization_code/invalid_code.json b/test_data/wiremock/mappings/oauth2/authorization_code/invalid_code.json new file mode 100644 index 000000000..41329e16a --- /dev/null +++ b/test_data/wiremock/mappings/oauth2/authorization_code/invalid_code.json @@ -0,0 +1,78 @@ +{ + "mappings": [ + { + "request": { + "urlPathPattern": "/oauth/authorize", + "queryParameters": { + "response_type": { + "equalTo": "code" + }, + "scope": { + "equalTo": "session:role:ANALYST" + }, + "code_challenge_method": { + "equalTo": "S256" + }, + "redirect_uri": { + "equalTo": "http://localhost:1234/snowflake/oauth-redirect" + }, + "code_challenge": { + "matches": ".+" + }, + "state": { + "matches": "testState" + }, + "client_id": { + "equalTo": "testClientId" + } + }, + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:1234/snowflake/oauth-redirect?code=testCode&state=testState" + } + } + }, + { + "scenarioName": "Successful token exchange", + "request": { + "urlPathPattern": "/oauth/token", + "method": "POST", + "headers": { + "Content-Type": { + "contains": "application/x-www-form-urlencoded" + } + }, + "formParameters": { + "grant_type": { + "equalTo": "authorization_code" + }, + "client_id": { + "equalTo": "testClientId" + }, + "client_secret": { + "equalTo": "testClientSecret" + }, + "code_verifier": { + "matches": "[a-zA-Z0-9\\-_]+" + }, + "code": { + "equalTo": "testCode" + }, + "redirect_uri": { + "equalTo": "http://localhost:1234/snowflake/oauth-redirect" + } + } + }, + "response": { + "status": 400, + "jsonBody": { + "error" : "invalid_grant", + "error_description" : "The authorization code is invalid or has expired." + } + } + } + ] +} \ No newline at end of file diff --git a/test_data/wiremock/mappings/oauth2/authorization_code/successful_flow.json b/test_data/wiremock/mappings/oauth2/authorization_code/successful_flow.json new file mode 100644 index 000000000..3715a4752 --- /dev/null +++ b/test_data/wiremock/mappings/oauth2/authorization_code/successful_flow.json @@ -0,0 +1,83 @@ +{ + "mappings": [ + { + "request": { + "urlPathPattern": "/oauth/authorize", + "queryParameters": { + "response_type": { + "equalTo": "code" + }, + "scope": { + "equalTo": "session:role:ANALYST" + }, + "code_challenge_method": { + "equalTo": "S256" + }, + "redirect_uri": { + "equalTo": "http://localhost:1234/snowflake/oauth-redirect" + }, + "code_challenge": { + "matches": "JZpN_-zfNduuWm-zUo-D-m7vMw_pgUGv8wGDGqBR8PM" + }, + "state": { + "matches": "testState|invalidState" + }, + "client_id": { + "equalTo": "testClientId" + } + }, + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:1234/snowflake/oauth-redirect?code=testCode&state=testState" + } + } + }, + { + "request": { + "urlPathPattern": "/oauth/token", + "method": "POST", + "headers": { + "Content-Type": { + "contains": "application/x-www-form-urlencoded" + } + }, + "formParameters": { + "grant_type": { + "equalTo": "authorization_code" + }, + "client_id": { + "equalTo": "testClientId" + }, + "client_secret": { + "equalTo": "testClientSecret" + }, + "code_verifier": { + "matches": "testCodeVerifier" + }, + "code": { + "equalTo": "testCode" + }, + "redirect_uri": { + "equalTo": "http://localhost:1234/snowflake/oauth-redirect" + } + } + }, + "response": { + "status": 200, + "jsonBody": { + "access_token": "access-token-123", + "refresh_token": "123", + "token_type": "Bearer", + "username": "test-user", + "scope": "refresh_token session:role:ANALYST", + "expires_in": 600, + "refresh_token_expires_in": 86399, + "idpInitiated": false + } + } + } + ] +} \ No newline at end of file diff --git a/wiremock_test.go b/wiremock_test.go index ed37c0bdb..81e23106f 100644 --- a/wiremock_test.go +++ b/wiremock_test.go @@ -58,6 +58,10 @@ type wiremockMapping struct { params map[string]string } +func newWiremockMapping(filePath string) wiremockMapping { + return wiremockMapping{filePath: filePath} +} + func (wm *wiremockClient) registerMappings(t *testing.T, mappings ...wiremockMapping) { for _, mapping := range wm.enrichWithTelemetry(mappings) { f, err := os.Open("test_data/wiremock/mappings/" + mapping.filePath) @@ -92,7 +96,11 @@ func (wm *wiremockClient) enrichWithTelemetry(mappings []wiremockMapping) []wire } func (wm *wiremockClient) mappingsURL() string { - return fmt.Sprintf("%v://%v:%v/__admin/mappings", wm.protocol, wm.host, wm.port) + return fmt.Sprintf("%v/__admin/mappings", wm.baseUrl()) +} + +func (wm *wiremockClient) baseUrl() string { + return fmt.Sprintf("%v://%v:%v", wm.protocol, wm.host, wm.port) } // just to satisfy not used private variables and functions