Skip to content

Commit

Permalink
s2autil unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xmenxk committed Mar 9, 2023
1 parent f36a8da commit f99c183
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 26 deletions.
10 changes: 9 additions & 1 deletion transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package grpc

import (
"context"
"crypto/tls"
"errors"
"log"
"net"
Expand Down Expand Up @@ -132,7 +133,14 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
if insecure {
transportCreds = grpcinsecure.NewCredentials()
} else {
transportCreds, endpoint = s2autil.ConfigureTransportCreds(clientCertSource, endpoint, o)
transportCreds = credentials.NewTLS(&tls.Config{
GetClientCertificate: clientCertSource,
})
s2aTransportCreds, targetEndpoint := s2autil.ConfigureTransportCreds(clientCertSource, endpoint, o)
if s2aTransportCreds != nil {
transportCreds, endpoint = s2aTransportCreds, targetEndpoint
}

}

// Initialize gRPC dial options with transport-level security options.
Expand Down
6 changes: 4 additions & 2 deletions transport/http/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (*http.Client,
return settings.HTTPClient, endpoint, nil
}

dialTLSContext, endpoint := s2autil.ConfigureHTTPTransport(clientCertSource, endpoint, settings)

dialTLSContext, targetEndpoint := s2autil.ConfigureHTTPTransport(clientCertSource, endpoint, settings)
if dialTLSContext != nil {
endpoint = targetEndpoint
}
trans, err := newTransport(ctx, defaultBaseTransport(ctx, clientCertSource, dialTLSContext), settings)
if err != nil {
return nil, "", err
Expand Down
3 changes: 2 additions & 1 deletion transport/internal/s2a/mtls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
package s2autil

import (
"cloud.google.com/go/compute/metadata"
"encoding/json"
"log"
"sync"
"time"

"cloud.google.com/go/compute/metadata"
)

const configEndpointSuffix = "googleAutoMtlsConfiguration"
Expand Down
12 changes: 7 additions & 5 deletions transport/internal/s2a/mtls_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ func TestGetS2AAddress(t *testing.T) {
if want, got := tc.Want, GetS2AAddress(); got != want {
t.Errorf("%s: want address [%s], got address [%s]", tc.Desc, want, got)
}
// Let the config expire, trigger refresh.
time.Sleep(3 * time.Millisecond)
// Let the config expire at end of each test case.
time.Sleep(2 * time.Millisecond)
}
}

func TestMTLSConfigExpiry(t *testing.T) {
oldHTTPGet := httpGetMetadataMTLSConfig
oldExpiry := configExpiry
configExpiry = 3 * time.Second
configExpiry = 1 * time.Second
defer func() {
httpGetMetadataMTLSConfig = oldHTTPGet
configExpiry = oldExpiry
Expand All @@ -97,10 +97,12 @@ func TestMTLSConfigExpiry(t *testing.T) {
}
httpGetMetadataMTLSConfig = invalidConfigResp
if got, want := GetS2AAddress(), "host:port"; got != want {
t.Errorf("cashed config should still be valid, expected address: [%s], got [%s]", want, got)
t.Errorf("cached config should still be valid, expected address: [%s], got [%s]", want, got)
}
time.Sleep(3 * time.Second)
time.Sleep(1 * time.Second)
if got, want := GetS2AAddress(), ""; got != want {
t.Errorf("config should be refreshed, expected address: [%s], got [%s]", want, got)
}
// Let the config expire before running other tests.
time.Sleep(1 * time.Second)
}
30 changes: 14 additions & 16 deletions transport/internal/s2a/s2a.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,42 @@ package s2autil

import (
"context"
"crypto/tls"
"net"
"net/url"

"github.com/google/s2a-go"
"github.com/google/s2a-go/fallback"
"google.golang.org/api/internal"
"google.golang.org/api/transport/cert"
"google.golang.org/grpc/credentials"
"net"
"net/url"
)

var defaultMTLSEnabled =false
var defaultMTLSEnabled = false

func EndpointIsMTLSEnabled(mtlsEndpoint string) bool {
// TODO(xmenxk): determine this via discovery config setting, plus runtime override.
return defaultMTLSEnabled
}

// ConfigureTransportCreds returns credentials.TransportCredentials, and the endpoint to connect.
// It enables MTLS with S2A only when all the followings are true:
// ConfigureTransportCreds returns an S2A-enabled credentials.TransportCredentials, and the endpoint to connect.
// The following condition must all be true, otherwise returns nil.
// 1. No client cert is specified,
// 2. MTLS endpoint is indeed enabled.
func ConfigureTransportCreds(clientCertSource cert.Source, endpoint string, settings *internal.DialSettings) (credentials.TransportCredentials, string) {
defaultTransportCreds := credentials.NewTLS(&tls.Config{
GetClientCertificate: clientCertSource,
})
if clientCertSource != nil {
return defaultTransportCreds, endpoint
return nil, ""
}

targetMTLSEndpoint := settings.DefaultMTLSEndpoint
if settings.Endpoint != "" {
targetMTLSEndpoint = endpoint
}
if !EndpointIsMTLSEnabled(targetMTLSEndpoint) {
return defaultTransportCreds, endpoint
return nil, ""
}
s2aAddress := GetS2AAddress()
if s2aAddress == "" {
return defaultTransportCreds, endpoint
return nil, ""
}

var fallbackOpts *s2a.FallbackOptions
Expand All @@ -62,7 +60,7 @@ func ConfigureTransportCreds(clientCertSource cert.Source, endpoint string, sett
})
if err != nil {
// Use default if we cannot initialize S2A client transport credentials.
return defaultTransportCreds, endpoint
return nil, ""
}
return transportCreds, targetMTLSEndpoint
}
Expand All @@ -73,19 +71,19 @@ func ConfigureTransportCreds(clientCertSource cert.Source, endpoint string, sett
// 2. MTLS endpoint is indeed enabled.
func ConfigureHTTPTransport(clientCertSource cert.Source, endpoint string, settings *internal.DialSettings) (func(context.Context, string, string) (net.Conn, error), string) {
if clientCertSource != nil {
return nil, endpoint
return nil, ""
}

targetMTLSEndpoint := settings.DefaultMTLSEndpoint
if settings.Endpoint != "" {
targetMTLSEndpoint = endpoint
}
if !EndpointIsMTLSEnabled(targetMTLSEndpoint) {
return nil, endpoint
return nil, ""
}
s2aAddress := GetS2AAddress()
if s2aAddress == "" {
return nil, endpoint
return nil, ""
}

var fallbackOpts *s2a.FallbackOptions
Expand Down
219 changes: 218 additions & 1 deletion transport/internal/s2a/s2a_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,221 @@

package s2autil

// TODO(xmenxk): Add unit tests.
import (
"crypto/tls"
"testing"
"time"

"google.golang.org/api/internal"
"google.golang.org/api/transport/cert"
)

var testCertSource = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return nil, nil
}

func TestConfigureTransportCreds(t *testing.T) {

testCases := []struct {
Desc string
InputClientCert cert.Source
InputEndpoint string
InputSettings *internal.DialSettings
S2ARespFunc func() (string, error)
MTLSEnabled bool
WantEndpoint string
TransportCredsNil bool
}{
{
"no client cert, MTLS enabled, S2A address not empty",
nil,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
validConfigResp,
true,
"mtls_endpoint",
false,
},
{
"has client cert",
testCertSource,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
validConfigResp,
true,
"",
true,
},
{
"no client cert, MTLS not enabled",
nil,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
validConfigResp,
false,
"",
true,
},
{
"no client cert, MTLS enabled, S2A address empty",
nil,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
invalidConfigResp,
true,
"",
true,
},
{
"no client cert, MTLS enabled, S2A address not empty, override endpoint",
nil,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
Endpoint: "override_endpoint",
},
validConfigResp,
true,
"passed_in_endpoint",
false,
},
}
oldDefaultMTLSEnabled := defaultMTLSEnabled
oldHTTPGet := httpGetMetadataMTLSConfig
oldExpiry := configExpiry
defer func() {
httpGetMetadataMTLSConfig = oldHTTPGet
defaultMTLSEnabled = oldDefaultMTLSEnabled
configExpiry = oldExpiry
}()
configExpiry = time.Millisecond
for _, tc := range testCases {
httpGetMetadataMTLSConfig = tc.S2ARespFunc
defaultMTLSEnabled = tc.MTLSEnabled
s2aTransportCreds, endpoint := ConfigureTransportCreds(tc.InputClientCert, tc.InputEndpoint, tc.InputSettings)
if tc.WantEndpoint != endpoint {
t.Errorf("%s: want endpoint: [%s], got [%s]", tc.Desc, tc.WantEndpoint, endpoint)
}
if want, got := tc.TransportCredsNil, s2aTransportCreds == nil; want != got {
t.Errorf("%s: expecting returned transportCreds is nil: [%v], got [%v]", tc.Desc, tc.TransportCredsNil, got)
}
// Let config expire at end of each test case.
time.Sleep(2 * time.Millisecond)
}
}

func TestConfigureHTTPTransport(t *testing.T) {

testCases := []struct {
Desc string
InputClientCert cert.Source
InputEndpoint string
InputSettings *internal.DialSettings
S2ARespFunc func() (string, error)
MTLSEnabled bool
WantEndpoint string
DialFuncNil bool
}{
{
"no client cert, MTLS enabled, S2A address not empty",
nil,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
validConfigResp,
true,
"mtls_endpoint",
false,
},
{
"has client cert",
testCertSource,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
validConfigResp,
true,
"",
true,
},
{
"no client cert, MTLS not enabled",
nil,
"",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
validConfigResp,
false,
"",
true,
},
{
"no client cert, MTLS enabled, S2A address empty",
nil,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
},
invalidConfigResp,
true,
"",
true,
},
{
"no client cert, MTLS enabled, S2A address not empty, override endpoint",
nil,
"passed_in_endpoint",
&internal.DialSettings{
DefaultMTLSEndpoint: "mtls_endpoint",
DefaultEndpoint: "regular_endpoint",
Endpoint: "override_endpoint",
},
validConfigResp,
true,
"passed_in_endpoint",
false,
},
}
oldDefaultMTLSEnabled := defaultMTLSEnabled
oldHTTPGet := httpGetMetadataMTLSConfig
oldExpiry := configExpiry
defer func() {
httpGetMetadataMTLSConfig = oldHTTPGet
defaultMTLSEnabled = oldDefaultMTLSEnabled
configExpiry = oldExpiry
}()
configExpiry = time.Millisecond
for _, tc := range testCases {
httpGetMetadataMTLSConfig = tc.S2ARespFunc
defaultMTLSEnabled = tc.MTLSEnabled
dialFunc, endpoint := ConfigureHTTPTransport(tc.InputClientCert, tc.InputEndpoint, tc.InputSettings)
if tc.WantEndpoint != endpoint {
t.Errorf("%s: want endpoint: [%s], got [%s]", tc.Desc, tc.WantEndpoint, endpoint)
}
if want, got := tc.DialFuncNil, dialFunc == nil; want != got {
t.Errorf("%s: expecting returned dialFunc is nil: [%v], got [%v]", tc.Desc, tc.DialFuncNil, got)
}
// Let config expire at end of each test case.
time.Sleep(2 * time.Millisecond)
}
}

0 comments on commit f99c183

Please sign in to comment.