Skip to content

Commit

Permalink
SNOW-1825476 Implement programmatic access token (PAT) (#1298)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus authored Feb 25, 2025
1 parent 11bb91f commit c55b379
Show file tree
Hide file tree
Showing 14 changed files with 358 additions and 28 deletions.
8 changes: 3 additions & 5 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ concurrency:
jobs:
lint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
name: Check linter
steps:
- uses: actions/checkout@v4
Expand All @@ -54,7 +52,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 11
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
Expand Down Expand Up @@ -85,7 +83,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 11
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
Expand Down Expand Up @@ -115,7 +113,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 11
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
Expand Down
24 changes: 24 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -49,6 +50,8 @@ const (
AuthTypeTokenAccessor
// AuthTypeUsernamePasswordMFA is to use username and password with mfa
AuthTypeUsernamePasswordMFA
// AuthTypePat is to use programmatic access token
AuthTypePat
)

func determineAuthenticatorType(cfg *Config, value string) error {
Expand All @@ -72,6 +75,9 @@ func determineAuthenticatorType(cfg *Config, value string) error {
} else if upperCaseValue == AuthTypeTokenAccessor.String() {
cfg.Authenticator = AuthTypeTokenAccessor
return nil
} else if upperCaseValue == AuthTypePat.String() && experimentalAuthEnabled() {
cfg.Authenticator = AuthTypePat
return nil
} else {
// possibly Okta case
oktaURLString, err := url.QueryUnescape(lowerCaseValue)
Expand Down Expand Up @@ -121,6 +127,8 @@ func (authType AuthType) String() string {
return "TOKENACCESSOR"
case AuthTypeUsernamePasswordMFA:
return "USERNAME_PASSWORD_MFA"
case AuthTypePat:
return "PROGRAMMATIC_ACCESS_TOKEN"
default:
return "UNKNOWN"
}
Expand Down Expand Up @@ -440,6 +448,17 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
return nil, err
}
requestMain.Token = jwtTokenString
case AuthTypePat:
if !experimentalAuthEnabled() {
return nil, errors.New("programmatic access tokens are not ready to use")
}
logger.WithContext(sc.ctx).Info("Programmatic access token")
requestMain.Authenticator = AuthTypePat.String()
requestMain.LoginName = sc.cfg.User
requestMain.Token = sc.cfg.Token
if sc.cfg.Password != "" && sc.cfg.Token == "" {
requestMain.Token = sc.cfg.Password
}
case AuthTypeSnowflake:
logger.WithContext(sc.ctx).Info("Username and password")
requestMain.LoginName = sc.cfg.User
Expand Down Expand Up @@ -574,3 +593,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
return nil
}

func experimentalAuthEnabled() bool {
val, ok := os.LookupEnv("ENABLE_EXPERIMENTAL_AUTHENTICATION")
return ok && strings.EqualFold(val, "true")
}
57 changes: 57 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1003,3 +1003,60 @@ func TestContextPropagatedToAuthWhenUsingOpenDB(t *testing.T) {
assertStringContainsE(t, err.Error(), "context deadline exceeded")
cancel()
}

func TestPatSuccessfulFlow(t *testing.T) {
cfg := wiremock.connectionConfig()
cfg.Authenticator = AuthTypePat
cfg.Token = "some PAT"
testPatSuccessfulFlow(t, cfg)
}

func testPatSuccessfulFlow(t *testing.T, cfg *Config) {
skipOnJenkins(t, "wiremock is not enabled")
enableExperimentalAuth(t)
wiremock.registerMappings(t,
wiremockMapping{filePath: "auth/pat/successful_flow.json"},
wiremockMapping{filePath: "select1.json", params: map[string]string{
"%AUTHORIZATION_HEADER%": "Snowflake Token=\\\"session token\\\""},
},
)
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
rows, err := db.Query("SELECT 1")
assertNilF(t, err)
var v int
assertTrueE(t, rows.Next())
assertNilF(t, rows.Scan(&v))
assertEqualE(t, v, 1)
}

func enableExperimentalAuth(t *testing.T) {
err := os.Setenv("ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
assertNilF(t, err)
}

func TestPatSuccessfulFlowWithPatAsPasswordWithPatAuthenticator(t *testing.T) {
cfg := wiremock.connectionConfig()
cfg.Authenticator = AuthTypePat
cfg.Password = "some PAT"
testPatSuccessfulFlow(t, cfg)
}

func TestPatInvalidToken(t *testing.T) {
skipOnJenkins(t, "wiremock is not enabled")
enableExperimentalAuth(t)
wiremock.registerMappings(t,
wiremockMapping{filePath: "auth/pat/invalid_token.json"},
)
cfg := wiremock.connectionConfig()
cfg.Authenticator = AuthTypePat
cfg.Token = "some PAT"
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
_, err := db.Query("SELECT 1")
assertNotNilF(t, err)
var se *SnowflakeError
assertTrueF(t, errors.As(err, &se))
assertEqualE(t, se.Number, 394400)
assertEqualE(t, se.Message, "Programmatic access token is invalid.")
}
1 change: 1 addition & 0 deletions cmd/programmatic_access_token/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pat
16 changes: 16 additions & 0 deletions cmd/programmatic_access_token/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
include ../../gosnowflake.mak
CMD_TARGET=pat

## Install
install: cinstall

## Run
run: crun

## Lint
lint: clint

## Format source codes
fmt: cfmt

.PHONY: install run lint fmt
52 changes: 52 additions & 0 deletions cmd/programmatic_access_token/pat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// you have to configure PAT on your user

package main

import (
"database/sql"
"flag"
"fmt"
sf "github.com/snowflakedb/gosnowflake"
"log"
)

func main() {
if !flag.Parsed() {
flag.Parse()
}

cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{
{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
{Name: "Token", EnvName: "SNOWFLAKE_TEST_PAT", FailOnMissing: true},
{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
})
cfg.Authenticator = sf.AuthTypePat
if err != nil {
log.Fatalf("cannot build config. %v", err)
}

connector := sf.NewConnector(sf.SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
defer db.Close()

query := "SELECT 1"
rows, err := db.Query(query)
if err != nil {
log.Fatalf("failed to run a query. %v, err: %v", query, err)
}
defer rows.Close()
var v int
if !rows.Next() {
log.Fatalf("no rows returned")
}
if err = rows.Scan(&v); err != nil {
log.Fatalf("failed to scan rows. %v", err)
}
if v != 1 {
log.Fatalf("unexpected result, expected 1, got %v", v)
}
fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!\n", query)
}
3 changes: 3 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (dri
if err := config.Validate(); err != nil {
return nil, err
}
if config.Params == nil {
config.Params = make(map[string]*string)
}
if config.Tracing != "" {
if err := logger.SetLogLevel(config.Tracing); err != nil {
return nil, err
Expand Down
20 changes: 17 additions & 3 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ func fillMissingConfigParameters(cfg *Config) error {
if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" {
return errEmptyPassword()
}

if authRequiresEitherPasswordOrToken(cfg) && strings.TrimSpace(cfg.Password) == "" && strings.TrimSpace(cfg.Token) == "" {
return errEmptyPasswordAndToken()
}

if strings.Trim(cfg.Protocol, " ") == "" {
cfg.Protocol = "https"
}
Expand Down Expand Up @@ -576,14 +581,20 @@ func buildHostFromAccountAndRegion(account, region string) string {
func authRequiresUser(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
cfg.Authenticator != AuthTypeExternalBrowser
cfg.Authenticator != AuthTypeExternalBrowser &&
cfg.Authenticator != AuthTypePat
}

func authRequiresPassword(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
cfg.Authenticator != AuthTypeExternalBrowser &&
cfg.Authenticator != AuthTypeJwt
cfg.Authenticator != AuthTypeJwt &&
cfg.Authenticator != AuthTypePat
}

func authRequiresEitherPasswordOrToken(cfg *Config) bool {
return cfg.Authenticator == AuthTypePat
}

// transformAccountToHost transforms account to host
Expand Down Expand Up @@ -905,7 +916,7 @@ type ConfigParam struct {

// GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config
func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string
var account, user, password, token, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string
var privateKey *rsa.PrivateKey
var err error
if len(properties) == 0 || properties == nil {
Expand All @@ -923,6 +934,8 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
user = value
case "Password":
password = value
case "Token":
token = value
case "Role":
role = value
case "Host":
Expand Down Expand Up @@ -963,6 +976,7 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
Account: account,
User: user,
Password: password,
Token: token,
Role: role,
Host: host,
Port: port,
Expand Down
Loading

0 comments on commit c55b379

Please sign in to comment.