diff --git a/dialer_test.go b/dialer_test.go index 5d697e0b..93a801ba 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -771,6 +771,8 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) { info: cloudsql.NewConnectionInfo( instance.ConnName{}, "", + "GOOGLE_MANAGED_INTERNAL_CA", + "", map[string]string{ // no public IP cloudsql.PrivateIP: "10.0.0.1", diff --git a/e2e_postgres_test.go b/e2e_postgres_test.go index 48a59656..79800b60 100644 --- a/e2e_postgres_test.go +++ b/e2e_postgres_test.go @@ -37,21 +37,27 @@ import ( ) var ( - postgresConnName = os.Getenv("POSTGRES_CONNECTION_NAME") // "Cloud SQL Postgres instance connection name, in the form of 'project:region:instance'. - postgresUser = os.Getenv("POSTGRES_USER") // Name of database user. - postgresPass = os.Getenv("POSTGRES_PASS") // Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history). - postgresDB = os.Getenv("POSTGRES_DB") // Name of the database to connect to. - postgresUserIAM = os.Getenv("POSTGRES_USER_IAM") // Name of database IAM user. + postgresConnName = os.Getenv("POSTGRES_CONNECTION_NAME") // "Cloud SQL Postgres instance connection name, in the form of 'project:region:instance'. + postgresCASConnName = os.Getenv("POSTGRES_CAS_CONNECTION_NAME") // "Cloud SQL Postgres CAS instance connection name, in the form of 'project:region:instance'. + postgresUser = os.Getenv("POSTGRES_USER") // Name of database user. + postgresPass = os.Getenv("POSTGRES_PASS") // Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history). + postgresCASPass = os.Getenv("POSTGRES_CAS_PASS") // Password for the database user for CAS instances; be careful when entering a password on the command line (it may go into your terminal's history). + postgresDB = os.Getenv("POSTGRES_DB") // Name of the database to connect to. + postgresUserIAM = os.Getenv("POSTGRES_USER_IAM") // Name of database IAM user. ) func requirePostgresVars(t *testing.T) { switch "" { case postgresConnName: t.Fatal("'POSTGRES_CONNECTION_NAME' env var not set") + case postgresCASConnName: + t.Fatal("'POSTGRES_CAS_CONNECTION_NAME' env var not set") case postgresUser: t.Fatal("'POSTGRES_USER' env var not set") case postgresPass: t.Fatal("'POSTGRES_PASS' env var not set") + case postgresCASPass: + t.Fatal("'POSTGRES_CAS_PASS' env var not set") case postgresDB: t.Fatal("'POSTGRES_DB' env var not set") case postgresUserIAM: @@ -107,6 +113,54 @@ func TestPostgresPgxPoolConnect(t *testing.T) { t.Log(now) } +func TestPostgresCASConnect(t *testing.T) { + if testing.Short() { + t.Skip("skipping Postgres integration tests") + } + requirePostgresVars(t) + + ctx := context.Background() + + // Configure the driver to connect to the database + dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", postgresUser, postgresCASPass, postgresDB) + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + t.Fatalf("failed to parse pgx config: %v", err) + } + + // Create a new dialer with any options + d, err := cloudsqlconn.NewDialer(ctx) + if err != nil { + t.Fatalf("failed to init Dialer: %v", err) + } + + // call cleanup when you're done with the database connection to close dialer + cleanup := func() error { return d.Close() } + + // Tell the driver to use the Cloud SQL Go Connector to create connections + // postgresConnName takes the form of 'project:region:instance'. + config.ConnConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) { + return d.Dial(ctx, postgresCASConnName) + } + + // Interact with the driver directly as you normally would + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + t.Fatalf("failed to create pool: %s", err) + } + // ... etc + + defer cleanup() + defer pool.Close() + + var now time.Time + err = pool.QueryRow(context.Background(), "SELECT NOW()").Scan(&now) + if err != nil { + t.Fatalf("QueryRow failed: %s", err) + } + t.Log(now) +} + func TestPostgresConnectWithIAMUser(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests") diff --git a/go.mod b/go.mod index e0d0a733..80b35893 100644 --- a/go.mod +++ b/go.mod @@ -12,26 +12,25 @@ require ( golang.org/x/net v0.27.0 golang.org/x/oauth2 v0.21.0 golang.org/x/time v0.5.0 - google.golang.org/api v0.188.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 + google.golang.org/api v0.190.0 + google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf google.golang.org/grpc v1.64.1 ) require ( - cloud.google.com/go/auth v0.7.0 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.4.0 // indirect + cloud.google.com/go/auth v0.7.3 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.3 // indirect + cloud.google.com/go/compute/metadata v0.5.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.4 // indirect - github.com/google/s2a-go v0.1.7 // indirect + github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.5 // indirect + github.com/googleapis/gax-go/v2 v2.13.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgio v1.0.0 // indirect diff --git a/go.sum b/go.sum index b8d8a409..55972515 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,10 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go/auth v0.7.0 h1:kf/x9B3WTbBUHkC+1VS8wwwli9TzhSt0vSTVBmMR8Ts= -cloud.google.com/go/auth v0.7.0/go.mod h1:D+WqdrpcjmiCgWrXmLLxOVq1GACoE36chW6KXoEvuIw= -cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= -cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= -cloud.google.com/go/compute/metadata v0.4.0 h1:vHzJCWaM4g8XIcm8kopr3XmDA4Gy/lblD3EhhSux05c= -cloud.google.com/go/compute/metadata v0.4.0/go.mod h1:SIQh1Kkb4ZJ8zJ874fqVkslA29PRXuleyj6vOzlbK7M= +cloud.google.com/go/auth v0.7.3 h1:98Vr+5jMaCZ5NZk6e/uBgf60phTk/XN84r8QEWB9yjY= +cloud.google.com/go/auth v0.7.3/go.mod h1:HJtWUx1P5eqjy/f6Iq5KeytNpbAcGolPhOgyop2LlzA= +cloud.google.com/go/auth/oauth2adapt v0.2.3 h1:MlxF+Pd3OmSudg/b1yZ5lJwoXCEaeedAguodky1PcKI= +cloud.google.com/go/auth/oauth2adapt v0.2.3/go.mod h1:tMQXOfZzFuNuUxOypHlQEXgdfX5cuhwU+ffUuXRJE8I= +cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= +cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1 h1:lGlwhPtrX6EVml1hO0ivjkUxsSyl4dsiw9qcA1k/3IQ= @@ -36,8 +36,8 @@ github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= @@ -64,8 +64,6 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -74,15 +72,15 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= -github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= +github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= -github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= -github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= +github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s= +github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -288,17 +286,17 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.188.0 h1:51y8fJ/b1AaaBRJr4yWm96fPcuxSo0JcegXE3DaHQHw= -google.golang.org/api v0.188.0/go.mod h1:VR0d+2SIiWOYG3r/jdm7adPW9hI2aRv9ETOSCQ9Beag= +google.golang.org/api v0.190.0 h1:ASM+IhLY1zljNdLu19W1jTmU6A+gMk6M46Wlur61s+Q= +google.golang.org/api v0.190.0/go.mod h1:QIr6I9iedBLnfqoD6L6Vze1UvS5Hzj5r2aUBOaZnLHo= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b h1:dSTjko30weBaMj3eERKc0ZVXW4GudCswM3m+P++ukU0= -google.golang.org/genproto/googleapis/api v0.0.0-20240610135401-a8a62080eff3 h1:QW9+G6Fir4VcRXVH8x3LilNAb6cxBGLa6+GM4hRwexE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 h1:SbSDUWW1PAO24TNpLdeheoYPd7kllICcLU52x6eD4kQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/genproto v0.0.0-20240730163845-b1a4ccb954bf h1:OqdXDEakZCVtDiZTjcxfwbHPCT11ycCEsTKesBVKvyY= +google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d h1:kHjw/5UfflP/L5EbledDrcG4C2597RtymmGRZvHiCuY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf h1:liao9UHurZLtiEwBgT9LMOnKYsHze6eA6w1KQCMVN2Q= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index bb6d52d1..f8e44b3b 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -170,9 +170,13 @@ func (i *RefreshAheadCache) Close() error { type ConnectionInfo struct { ConnectionName instance.ConnName ClientCertificate tls.Certificate - ServerCaCert *x509.Certificate + ServerCACert []*x509.Certificate + ServerCAMode string DBVersion string - Expiration time.Time + // The DNSName is from the ConnectSettings API. + // It is used to validate the server identity of the CAS instances. + DNSName string + Expiration time.Time addrs map[string]string } @@ -180,18 +184,22 @@ type ConnectionInfo struct { // NewConnectionInfo initializes a ConnectionInfo struct. func NewConnectionInfo( cn instance.ConnName, + dnsName string, + serverCAMode string, version string, ipAddrs map[string]string, - serverCaCert *x509.Certificate, + serverCACert []*x509.Certificate, clientCert tls.Certificate, ) ConnectionInfo { return ConnectionInfo{ addrs: ipAddrs, - ServerCaCert: serverCaCert, ClientCertificate: clientCert, + ServerCACert: serverCACert, + ServerCAMode: serverCAMode, Expiration: clientCert.Leaf.NotAfter, DBVersion: version, ConnectionName: cn, + DNSName: dnsName, } } @@ -225,7 +233,18 @@ func (c ConnectionInfo) Addr(ipType string) (string, error) { // TLSConfig constructs a TLS configuration for the given connection info. func (c ConnectionInfo) TLSConfig() *tls.Config { pool := x509.NewCertPool() - pool.AddCert(c.ServerCaCert) + for _, caCert := range c.ServerCACert { + pool.AddCert(caCert) + } + if c.ServerCAMode == "GOOGLE_MANAGED_CAS_CA" { + // For CAS instances, we can rely on the DNS name to verify the server identity. + return &tls.Config{ + ServerName: c.DNSName, + Certificates: []tls.Certificate{c.ClientCertificate}, + RootCAs: pool, + MinVersion: tls.VersionTLS13, + } + } return &tls.Config{ ServerName: c.ConnectionName.String(), Certificates: []tls.Certificate{c.ClientCertificate}, diff --git a/internal/cloudsql/instance_test.go b/internal/cloudsql/instance_test.go index ac239044..c2f7561f 100644 --- a/internal/cloudsql/instance_test.go +++ b/internal/cloudsql/instance_test.go @@ -158,7 +158,7 @@ func TestConnectionInfoTLSConfig(t *testing.T) { t.Fatal(err) } b, _ = pem.Decode(certBytes) - serverCert, err := x509.ParseCertificate(b.Bytes) + serverCACert, err := x509.ParseCertificate(b.Bytes) if err != nil { t.Fatal(err) } @@ -172,7 +172,7 @@ func TestConnectionInfoTLSConfig(t *testing.T) { PrivateKey: RSAKey, Leaf: clientCert, }, - ServerCaCert: serverCert, + ServerCACert: []*x509.Certificate{serverCACert}, DBVersion: "doesn't matter here", Expiration: clientCert.NotAfter, } @@ -198,7 +198,7 @@ func TestConnectionInfoTLSConfig(t *testing.T) { } verifyPeerCert := got.VerifyPeerCertificate - err = verifyPeerCert([][]byte{serverCert.Raw}, nil) + err = verifyPeerCert([][]byte{serverCACert.Raw}, nil) if err != nil { t.Fatalf("expected to verify peer cert, got error: %v", err) } @@ -375,3 +375,82 @@ func TestRefreshDuration(t *testing.T) { }) } } + +func TestConnectionInfoTLSConfigForCAS(t *testing.T) { + cn := testInstanceConnName() + i := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name()) + // Generate a client certificate with the client's public key and signed by + // the server's private key + cert, err := i.ClientCert(&RSAKey.PublicKey) + if err != nil { + t.Fatal(err) + } + // Now parse the bytes back out as structured data + b, _ := pem.Decode(cert) + clientCert, err := x509.ParseCertificate(b.Bytes) + if err != nil { + t.Fatal(err) + } + // Create a PEM with two certificates to pretend the sub CA and root CA. + rootCABytes, err := mock.SelfSign(i.Cert, i.Key) + if err != nil { + t.Fatal(err) + } + b, _ = pem.Decode(rootCABytes) + rootCACert, err := x509.ParseCertificate(b.Bytes) + if err != nil { + t.Fatal(err) + } + subCABytes, err := mock.SelfSign(i.Cert, i.Key) + if err != nil { + t.Fatal(err) + } + b, _ = pem.Decode(subCABytes) + subCACert, err := x509.ParseCertificate(b.Bytes) + if err != nil { + t.Fatal(err) + } + caCerts := []*x509.Certificate{rootCACert, subCACert} + wantRootCAs := x509.NewCertPool() + wantRootCAs.AddCert(rootCACert) + wantRootCAs.AddCert(subCACert) + // Assemble a connection info with the raw and parsed client cert + // and the self-signed server certificate + wantServerName := "testing dns name" + ci := ConnectionInfo{ + DNSName: wantServerName, + ClientCertificate: tls.Certificate{ + Certificate: [][]byte{clientCert.Raw}, + PrivateKey: RSAKey, + Leaf: clientCert, + }, + ServerCACert: caCerts, + DBVersion: "doesn't matter here", + Expiration: clientCert.NotAfter, + ServerCAMode: "GOOGLE_MANAGED_CAS_CA", + } + + got := ci.TLSConfig() + + if got.ServerName != wantServerName { + t.Fatalf( + "ConnectInfo return unexpected server name in TLS Config, "+ + "want = %v, got = %v", + wantServerName, got.ServerName, + ) + } + if got.MinVersion != tls.VersionTLS13 { + t.Fatalf( + "want TLS 1.3, got = %v", got.MinVersion, + ) + } + if got.Certificates[0].Leaf != ci.ClientCertificate.Leaf { + t.Fatal("leaf certificates do not match") + } + if got.InsecureSkipVerify { + t.Fatal("InsecureSkipVerify is true, expected false") + } + if !got.RootCAs.Equal(wantRootCAs) { + t.Fatalf("unexpected root CAs, got %v, want %v", got.RootCAs, wantRootCAs) + } +} diff --git a/internal/cloudsql/refresh.go b/internal/cloudsql/refresh.go index 3d2e6b58..152d99a4 100644 --- a/internal/cloudsql/refresh.go +++ b/internal/cloudsql/refresh.go @@ -48,7 +48,9 @@ const ( // connections. type metadata struct { ipAddrs map[string]string - serverCaCert *x509.Certificate + serverCACert []*x509.Certificate + serverCAMode string + dnsName string version string } @@ -98,7 +100,8 @@ func fetchMetadata( } // resolve DnsName into IP address for PSC - if db.DnsName != "" { + // Note that we have to check for PSC enablement first because CAS instances also set the DnsName. + if db.PscEnabled && db.DnsName != "" { ipAddrs[PSC] = db.DnsName } @@ -110,23 +113,28 @@ func fetchMetadata( } // parse the server-side CA certificate - b, _ := pem.Decode([]byte(db.ServerCaCert.Cert)) - if b == nil { - return metadata{}, errtype.NewRefreshError("failed to decode valid PEM cert", inst.String(), nil) - } - cert, err := x509.ParseCertificate(b.Bytes) - if err != nil { - return metadata{}, errtype.NewRefreshError( - fmt.Sprintf("failed to parse as X.509 certificate: %v", err), - inst.String(), - nil, - ) + caCerts := []*x509.Certificate{} + for b, rest := pem.Decode([]byte(db.ServerCaCert.Cert)); b != nil; b, rest = pem.Decode(rest) { + if b == nil { + return metadata{}, errtype.NewRefreshError("failed to decode valid PEM cert", inst.String(), nil) + } + caCert, err := x509.ParseCertificate(b.Bytes) + if err != nil { + return metadata{}, errtype.NewRefreshError( + fmt.Sprintf("failed to parse as X.509 certificate: %v", err), + inst.String(), + nil, + ) + } + caCerts = append(caCerts, caCert) } m = metadata{ ipAddrs: ipAddrs, - serverCaCert: cert, + serverCACert: caCerts, version: db.DatabaseVersion, + dnsName: db.DnsName, + serverCAMode: db.ServerCaMode, } return m, nil @@ -345,7 +353,7 @@ func (c adminAPIClient) ConnectionInfo( } return NewConnectionInfo( - cn, md.version, md.ipAddrs, md.serverCaCert, ec, + cn, md.dnsName, md.serverCAMode, md.version, md.ipAddrs, md.serverCACert, ec, ), nil } diff --git a/internal/cloudsql/refresh_test.go b/internal/cloudsql/refresh_test.go index 1a23f4db..7d5e75b4 100644 --- a/internal/cloudsql/refresh_test.go +++ b/internal/cloudsql/refresh_test.go @@ -35,7 +35,8 @@ const testDialerID = "some-dialer-id" func TestRefresh(t *testing.T) { wantPublicIP := "127.0.0.1" wantPrivateIP := "10.0.0.1" - wantPSC := "abcde.12345.us-central1.sql.goog" + wantPSC := true + wantDNS := "abcde.12345.us-central1.sql.goog" wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second) cn := testInstanceConnName() inst := mock.NewFakeCSQLInstance( @@ -43,6 +44,7 @@ func TestRefresh(t *testing.T) { mock.WithPublicIP(wantPublicIP), mock.WithPrivateIP(wantPrivateIP), mock.WithPSC(wantPSC), + mock.WithDNS(wantDNS), mock.WithCertExpiry(wantExpiry), ) client, cleanup, err := mock.NewSQLAdminService( @@ -79,17 +81,17 @@ func TestRefresh(t *testing.T) { if wantPrivateIP != gotIP { t.Fatalf("metadata IP mismatch, want = %v, got = %v", wantPrivateIP, gotIP) } - gotPSC, ok := rr.addrs[PSC] + gotPSCDNS, ok := rr.addrs[PSC] if !ok { t.Fatal("metadata IP addresses did not include PSC endpoint") } - if wantPSC != gotPSC { - t.Fatalf("metadata IP mismatch, want = %v. got = %v", wantPSC, gotPSC) + if wantDNS != gotPSCDNS { + t.Fatalf("metadata IP mismatch, want = %v. got = %v", wantDNS, gotPSCDNS) } if cn != rr.ConnectionName { t.Fatalf( "connection name mismatch, want = %v, got = %v", - wantExpiry, rr.Expiration, + cn.Name(), rr.ConnectionName, ) } if wantExpiry != rr.Expiration { @@ -97,6 +99,47 @@ func TestRefresh(t *testing.T) { } } +func TestRefreshForCASInstances(t *testing.T) { + wantDNS := "abcde.12345.us-central1.sql.goog" + cn := testInstanceConnName() + inst := mock.NewFakeCSQLInstance( + cn.Project(), cn.Region(), cn.Name(), + mock.WithPublicIP("127.0.0.1"), + mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"), + mock.WithDNS(wantDNS), + mock.WithCertExpiry(time.Now().Add(time.Hour).UTC().Round(time.Second)), + ) + client, cleanup, err := mock.NewSQLAdminService( + context.Background(), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + ) + if err != nil { + t.Fatalf("failed to create test SQL admin service: %s", err) + } + defer func() { + if err := cleanup(); err != nil { + t.Fatalf("%v", err) + } + }() + + r := newAdminAPIClient(nullLogger{}, client, RSAKey, nil, testDialerID) + rr, err := r.ConnectionInfo(context.Background(), cn, false) + if err != nil { + t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err) + } + + if wantDNS != rr.DNSName { + t.Fatalf("DNS mismatch, want = %v. got = %v", wantDNS, rr.DNSName) + } + if rr.ServerCAMode != "GOOGLE_MANAGED_CAS_CA" { + t.Fatalf("server CA mode mismatch, want = GOOGLE_MANAGED_CAS_CA, got = %v", rr.ServerCAMode) + } + if len(rr.ServerCACert) < 2 { + t.Fatalf("number of server cert mismatch, want 2, got %d", len(rr.ServerCACert)) + } +} + // If a caller has provided a static token source that cannot be refreshed // (e.g., when the Cloud SQL Proxy is invokved with --token), then the // refresher cannot determine the token's expiration (without additional API diff --git a/internal/mock/cloudsql.go b/internal/mock/cloudsql.go index 04f7969c..27fc80a3 100644 --- a/internal/mock/cloudsql.go +++ b/internal/mock/cloudsql.go @@ -51,6 +51,8 @@ type FakeCSQLInstance struct { ipAddrs map[string]string backendType string DNSName string + serverCAMode string + pscEnabled bool signer SignFunc clientSigner ClientSignFunc // Key is the server's private key @@ -92,8 +94,15 @@ func WithPrivateIP(addr string) FakeCSQLInstanceOption { } } -// WithPSC sets the PSC DnsName to addr. -func WithPSC(dns string) FakeCSQLInstanceOption { +// WithPSC sets the PSC enabled. +func WithPSC(enabled bool) FakeCSQLInstanceOption { + return func(f *FakeCSQLInstance) { + f.pscEnabled = enabled + } +} + +// WithDNS sets the DnsName to addr. +func WithDNS(dns string) FakeCSQLInstanceOption { return func(f *FakeCSQLInstance) { f.DNSName = dns } @@ -160,6 +169,13 @@ func WithNoIPAddrs() FakeCSQLInstanceOption { } } +// WithServerCAMode sets the ServerCaMode of the instance. +func WithServerCAMode(serverCAMode string) FakeCSQLInstanceOption { + return func(f *FakeCSQLInstance) { + f.serverCAMode = serverCAMode + } +} + // NewFakeCSQLInstance returns a CloudSQLInst object for configuring mocks. func NewFakeCSQLInstance(project, region, name string, opts ...FakeCSQLInstanceOption) FakeCSQLInstance { // TODO: consider options for this? diff --git a/internal/mock/sqladmin.go b/internal/mock/sqladmin.go index 9913d4e6..b25f042f 100644 --- a/internal/mock/sqladmin.go +++ b/internal/mock/sqladmin.go @@ -112,10 +112,19 @@ func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request { ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"}) } } - certBytes, err := i.signedCert() + certBytes1, err := i.signedCert() if err != nil { panic(err) } + certBytes := certBytes1 + if i.serverCAMode == "GOOGLE_MANAGED_CAS_CA" { + // CAS instances return two CAs in the trust chain. + certBytes2, err := i.signedCert() + if err != nil { + panic(err) + } + certBytes = append(certBytes, certBytes2...) + } db := &sqladmin.ConnectSettings{ BackendType: i.backendType, DatabaseVersion: i.dbVersion, @@ -123,6 +132,8 @@ func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request { IpAddresses: ips, Region: i.region, ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)}, + PscEnabled: i.pscEnabled, + ServerCaMode: i.serverCAMode, } r := &Request{