-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SNOW-1346233] Tests for authentication methods (external browser, oa…
…uth, okta, keypair) (#1264)
- Loading branch information
1 parent
220e36e
commit 80c18ea
Showing
14 changed files
with
498 additions
and
6 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package gosnowflake | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
) | ||
|
||
func getAuthTestConfigFromEnv() (*Config, error) { | ||
return GetConfigFromEnv([]*ConfigParam{ | ||
{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, | ||
{Name: "User", EnvName: "SNOWFLAKE_AUTH_TEST_OKTA_USER", FailOnMissing: true}, | ||
{Name: "Password", EnvName: "SNOWFLAKE_AUTH_TEST_OKTA_PASS", FailOnMissing: true}, | ||
{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, | ||
{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, | ||
{Name: "Protocol", EnvName: "SNOWFLAKE_AUTH_TEST_PROTOCOL", FailOnMissing: false}, | ||
{Name: "Role", EnvName: "SNOWFLAKE_TEST_ROLE", FailOnMissing: false}, | ||
}) | ||
} | ||
|
||
func getAuthTestsConfig(t *testing.T, authMethod AuthType) (*Config, error) { | ||
cfg, err := getAuthTestConfigFromEnv() | ||
assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err)) | ||
|
||
cfg.Authenticator = authMethod | ||
|
||
return cfg, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
package gosnowflake | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"errors" | ||
"fmt" | ||
"log" | ||
"os/exec" | ||
"sync" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestExternalBrowserSuccessful(t *testing.T) { | ||
cfg := setupExternalBrowserTest(t) | ||
var wg sync.WaitGroup | ||
wg.Add(2) | ||
go func() { | ||
defer wg.Done() | ||
provideExternalBrowserCredentials(t, externalBrowserType.Success, cfg.User, cfg.Password) | ||
}() | ||
go func() { | ||
defer wg.Done() | ||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err)) | ||
}() | ||
wg.Wait() | ||
} | ||
|
||
func TestExternalBrowserFailed(t *testing.T) { | ||
cfg := setupExternalBrowserTest(t) | ||
cfg.ExternalBrowserTimeout = time.Duration(10) * time.Second | ||
var wg sync.WaitGroup | ||
wg.Add(2) | ||
go func() { | ||
defer wg.Done() | ||
provideExternalBrowserCredentials(t, externalBrowserType.Fail, "FakeAccount", "NotARealPassword") | ||
}() | ||
go func() { | ||
defer wg.Done() | ||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
assertEqualE(t, err.Error(), "authentication timed out") | ||
}() | ||
wg.Wait() | ||
} | ||
|
||
func TestExternalBrowserTimeout(t *testing.T) { | ||
cfg := setupExternalBrowserTest(t) | ||
cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second | ||
var wg sync.WaitGroup | ||
wg.Add(2) | ||
go func() { | ||
defer wg.Done() | ||
provideExternalBrowserCredentials(t, externalBrowserType.Timeout, cfg.User, cfg.Password) | ||
}() | ||
go func() { | ||
defer wg.Done() | ||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
assertEqualE(t, err.Error(), "authentication timed out") | ||
}() | ||
wg.Wait() | ||
} | ||
|
||
func TestExternalBrowserMismatchUser(t *testing.T) { | ||
cfg := setupExternalBrowserTest(t) | ||
correctUsername := cfg.User | ||
cfg.User = "fakeAccount" | ||
var wg sync.WaitGroup | ||
|
||
wg.Add(2) | ||
go func() { | ||
defer wg.Done() | ||
provideExternalBrowserCredentials(t, externalBrowserType.Success, correctUsername, cfg.Password) | ||
}() | ||
go func() { | ||
defer wg.Done() | ||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
var snowflakeErr *SnowflakeError | ||
assertTrueF(t, errors.As(err, &snowflakeErr)) | ||
assertEqualE(t, snowflakeErr.Number, 390191, fmt.Sprintf("Expected 390191, but got %v", snowflakeErr.Number)) | ||
}() | ||
wg.Wait() | ||
} | ||
|
||
func TestClientStoreCredentials(t *testing.T) { | ||
cfg := setupExternalBrowserTest(t) | ||
cfg.ClientStoreTemporaryCredential = 1 | ||
cfg.ExternalBrowserTimeout = time.Duration(10) * time.Second | ||
|
||
t.Run("Obtains the ID token from the server and saves it on the local storage", func(t *testing.T) { | ||
cleanupBrowserProcesses(t) | ||
var wg sync.WaitGroup | ||
wg.Add(2) | ||
go func() { | ||
defer wg.Done() | ||
provideExternalBrowserCredentials(t, externalBrowserType.Success, cfg.User, cfg.Password) | ||
}() | ||
go func() { | ||
defer wg.Done() | ||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
assertNilE(t, err, fmt.Sprintf("Connection failed: err %v", err)) | ||
}() | ||
wg.Wait() | ||
}) | ||
|
||
t.Run("Verify validation of ID token if option enabled", func(t *testing.T) { | ||
cleanupBrowserProcesses(t) | ||
cfg.ClientStoreTemporaryCredential = 1 | ||
db := getDbHandlerFromConfig(t, cfg) | ||
conn, err := db.Conn(context.Background()) | ||
assertNilE(t, err, fmt.Sprintf("Failed to connect to Snowflake. err: %v", err)) | ||
defer conn.Close() | ||
rows, err := conn.QueryContext(context.Background(), "SELECT 1") | ||
assertNilE(t, err, fmt.Sprintf("Failed to run a query. err: %v", err)) | ||
rows.Close() | ||
}) | ||
|
||
t.Run("Verify validation of IDToken if option disabled", func(t *testing.T) { | ||
cleanupBrowserProcesses(t) | ||
cfg.ClientStoreTemporaryCredential = 0 | ||
db := getDbHandlerFromConfig(t, cfg) | ||
_, err := db.Conn(context.Background()) | ||
assertEqualE(t, err.Error(), "authentication timed out", fmt.Sprintf("Expected timeout, but got %v", err)) | ||
}) | ||
} | ||
|
||
type ExternalBrowserProcessResult struct { | ||
Success string | ||
Fail string | ||
Timeout string | ||
} | ||
|
||
var externalBrowserType = ExternalBrowserProcessResult{ | ||
Success: "success", | ||
Fail: "fail", | ||
Timeout: "timeout", | ||
} | ||
|
||
func cleanupBrowserProcesses(t *testing.T) { | ||
const cleanBrowserProcessesPath = "/externalbrowser/cleanBrowserProcesses.js" | ||
_, err := exec.Command("node", cleanBrowserProcessesPath).Output() | ||
assertNilE(t, err, fmt.Sprintf("failed to execute command: %v", err)) | ||
} | ||
|
||
func provideExternalBrowserCredentials(t *testing.T, ExternalBrowserProcess string, user string, password string) { | ||
const provideBrowserCredentialsPath = "/externalbrowser/provideBrowserCredentials.js" | ||
_, err := exec.Command("node", provideBrowserCredentialsPath, ExternalBrowserProcess, user, password).Output() | ||
assertNilE(t, err, fmt.Sprintf("failed to execute command: %v", err)) | ||
} | ||
|
||
func verifyConnectionToSnowflakeAuthTests(t *testing.T, cfg *Config) (err error) { | ||
dsn, err := DSN(cfg) | ||
assertNilE(t, err, "failed to create DSN from Config") | ||
|
||
db, err := sql.Open("snowflake", dsn) | ||
assertNilE(t, err, "failed to open Snowflake DB connection") | ||
defer db.Close() | ||
|
||
rows, err := db.Query("SELECT 1") | ||
if err != nil { | ||
log.Printf("failed to run a query. 'SELECT 1', err: %v", err) | ||
return err | ||
} | ||
|
||
defer rows.Close() | ||
assertTrueE(t, rows.Next(), "failed to get result", "There were no results for query: ") | ||
|
||
return err | ||
} | ||
|
||
func setupExternalBrowserTest(t *testing.T) *Config { | ||
runOnlyOnDockerContainer(t, "Running only on Docker container") | ||
cleanupBrowserProcesses(t) | ||
cfg, err := getAuthTestsConfig(t, AuthTypeExternalBrowser) | ||
assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err)) | ||
return cfg | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package gosnowflake | ||
|
||
import ( | ||
"crypto/rsa" | ||
"errors" | ||
"fmt" | ||
"golang.org/x/crypto/ssh" | ||
"os" | ||
"testing" | ||
) | ||
|
||
func TestKeypairSuccessful(t *testing.T) { | ||
cfg := setupKeyPairTest(t) | ||
cfg.PrivateKey = loadRsaPrivateKeyForKeyPair(t, "SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH") | ||
|
||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
assertNilE(t, err, fmt.Sprintf("failed to connect. err: %v", err)) | ||
} | ||
|
||
func TestKeypairInvalidKey(t *testing.T) { | ||
cfg := setupKeyPairTest(t) | ||
cfg.PrivateKey = loadRsaPrivateKeyForKeyPair(t, "SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH") | ||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
var snowflakeErr *SnowflakeError | ||
assertTrueF(t, errors.As(err, &snowflakeErr)) | ||
assertEqualE(t, snowflakeErr.Number, 390144, fmt.Sprintf("Expected 390144, but got %v", snowflakeErr.Number)) | ||
} | ||
|
||
func setupKeyPairTest(t *testing.T) *Config { | ||
runOnlyOnDockerContainer(t, "Running only on Docker container") | ||
|
||
cfg, err := getAuthTestsConfig(t, AuthTypeJwt) | ||
assertEqualE(t, err, nil, fmt.Sprintf("failed to get config: %v", err)) | ||
|
||
return cfg | ||
} | ||
|
||
func loadRsaPrivateKeyForKeyPair(t *testing.T, envName string) *rsa.PrivateKey { | ||
filePath, err := GetFromEnv(envName, true) | ||
assertNilF(t, err, fmt.Sprintf("failed to get env: %v", err)) | ||
|
||
bytes, err := os.ReadFile(filePath) | ||
assertNilF(t, err, fmt.Sprintf("failed to read file: %v", err)) | ||
|
||
key, err := ssh.ParseRawPrivateKey(bytes) | ||
assertNilF(t, err, fmt.Sprintf("failed to parse private key: %v", err)) | ||
|
||
return key.(*rsa.PrivateKey) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package gosnowflake | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"net/url" | ||
"strings" | ||
"testing" | ||
) | ||
|
||
func TestOauthSuccessful(t *testing.T) { | ||
cfg := setupOauthTest(t) | ||
token, err := getOauthTestToken(t, cfg) | ||
assertNilE(t, err, fmt.Sprintf("failed to get token. err: %v", err)) | ||
cfg.Token = token | ||
err = verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
assertNilE(t, err, fmt.Sprintf("failed to connect. err: %v", err)) | ||
} | ||
|
||
func TestOauthInvalidToken(t *testing.T) { | ||
cfg := setupOauthTest(t) | ||
cfg.Token = "invalid_token" | ||
|
||
err := verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
|
||
var snowflakeErr *SnowflakeError | ||
assertTrueF(t, errors.As(err, &snowflakeErr)) | ||
assertEqualE(t, snowflakeErr.Number, 390303, fmt.Sprintf("Expected 390303, but got %v", snowflakeErr.Number)) | ||
} | ||
|
||
func TestOauthMismatchedUser(t *testing.T) { | ||
cfg := setupOauthTest(t) | ||
token, err := getOauthTestToken(t, cfg) | ||
assertNilE(t, err, fmt.Sprintf("failed to get token. err: %v", err)) | ||
cfg.Token = token | ||
cfg.User = "fakeaccount" | ||
|
||
err = verifyConnectionToSnowflakeAuthTests(t, cfg) | ||
|
||
var snowflakeErr *SnowflakeError | ||
assertTrueF(t, errors.As(err, &snowflakeErr)) | ||
assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number)) | ||
} | ||
|
||
func setupOauthTest(t *testing.T) *Config { | ||
runOnlyOnDockerContainer(t, "Running only on Docker container") | ||
|
||
cfg, err := getAuthTestsConfig(t, AuthTypeOAuth) | ||
assertNilF(t, err, fmt.Sprintf("failed to connect. err: %v", err)) | ||
|
||
return cfg | ||
} | ||
|
||
func getOauthTestToken(t *testing.T, cfg *Config) (string, error) { | ||
|
||
client := &http.Client{} | ||
|
||
authURL, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_URL", true) | ||
assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_URL is not set") | ||
|
||
oauthClientID, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID", true) | ||
assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID is not set") | ||
|
||
oauthClientSecret, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET", true) | ||
assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET is not set") | ||
|
||
inputData := formData(cfg) | ||
|
||
req, err := http.NewRequest("POST", authURL, strings.NewReader(inputData.Encode())) | ||
assertNilF(t, err, fmt.Sprintf("Request failed %v", err)) | ||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8") | ||
req.SetBasicAuth(oauthClientID, oauthClientSecret) | ||
resp, err := client.Do(req) | ||
|
||
assertNilF(t, err, fmt.Sprintf("Response failed %v", err)) | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
return "", fmt.Errorf("failed to get access token, status code: %d", resp.StatusCode) | ||
} | ||
|
||
defer resp.Body.Close() | ||
|
||
var response OAuthTokenResponse | ||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { | ||
return "", fmt.Errorf("failed to decode response: %v", err) | ||
} | ||
|
||
return response.Token, err | ||
} | ||
|
||
func formData(cfg *Config) url.Values { | ||
data := url.Values{} | ||
data.Set("username", cfg.User) | ||
data.Set("password", cfg.Password) | ||
data.Set("grant_type", "password") | ||
data.Set("scope", fmt.Sprintf("session:role:%s", strings.ToLower(cfg.Role))) | ||
|
||
return data | ||
|
||
} | ||
|
||
type OAuthTokenResponse struct { | ||
Type string `json:"token_type"` | ||
Expiration int `json:"expires_in"` | ||
Token string `json:"access_token"` | ||
Scope string `json:"scope"` | ||
} |
Oops, something went wrong.