diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 16f72ef2f0..6be0daafd3 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -20,6 +20,7 @@ "TLSPort": "10091", "TLSSubjectName": "", "UseHTTPS": false, + "UseMTLS": false, "WireserverIP": "168.63.129.16", "KeyVaultSettings": { "URL": "", diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 44da9c3e87..25b0f19c1b 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -55,6 +55,7 @@ type CNSConfig struct { TLSSubjectName string TelemetrySettings TelemetrySettings UseHTTPS bool + UseMTLS bool WatchPods bool `json:"-"` WireserverIP string } diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index 24f49c5701..1c9b8a5539 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -87,6 +87,7 @@ func TestReadConfigFromFile(t *testing.T) { PopulateHomeAzCacheRetryIntervalSecs: 60, }, UseHTTPS: true, + UseMTLS: true, WireserverIP: "168.63.129.16", }, wantErr: false, diff --git a/cns/configuration/testdata/good.json b/cns/configuration/testdata/good.json index ddc8cfb175..185484ede1 100644 --- a/cns/configuration/testdata/good.json +++ b/cns/configuration/testdata/good.json @@ -30,6 +30,7 @@ "TelemetryBatchSizeBytes": 16384 }, "UseHTTPS": true, + "UseMTLS": true, "WireserverIP": "168.63.129.16", "AZRSettings": { "PopulateHomeAzCacheRetryIntervalSecs": 60 diff --git a/cns/service.go b/cns/service.go index e0cdd75ab7..147c41051d 100644 --- a/cns/service.go +++ b/cns/service.go @@ -6,6 +6,7 @@ package cns import ( "context" "crypto/tls" + "crypto/x509" "fmt" "net" "net/http" @@ -190,6 +191,18 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) }, } + if tlsSettings.UseMTLS { + rootCAs, err := mtlsRootCAsFromCertificate(&tlsCert) + if err != nil { + return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS") + } + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs = rootCAs + tlsConfig.RootCAs = rootCAs + } + + logger.Debugf("TLS configured successfully from file: %+v", tlsSettings) + return tlsConfig, nil } @@ -224,9 +237,51 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e }, } + if tlsSettings.UseMTLS { + tlsCert := cr.GetCertificate() + rootCAs, err := mtlsRootCAsFromCertificate(tlsCert) + if err != nil { + return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS") + } + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs = rootCAs + tlsConfig.RootCAs = rootCAs + } + + logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings) + return &tlsConfig, nil } +// Given a TLS cert, return the root CAs +func mtlsRootCAsFromCertificate(tlsCert *tls.Certificate) (*x509.CertPool, error) { + switch { + case tlsCert == nil || len(tlsCert.Certificate) == 0: + return nil, errors.New("no certificate provided") + case len(tlsCert.Certificate) == 1: + certs := x509.NewCertPool() + cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) + if err != nil { + return nil, errors.Wrap(err, "parsing self signed cert") + } + certs.AddCert(cert) + + return certs, nil + default: + certs := x509.NewCertPool() + // given a fullchain cert, we skip leaf cert at index 0 because + // we only want intermediate and root certs in the cert pool for mTLS + for _, certBytes := range tlsCert.Certificate[1:] { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, errors.Wrap(err, "parsing root certs") + } + certs.AddCert(cert) + } + return certs, nil + } +} + func (service *Service) StartListener(config *common.ServiceConfig) error { log.Debugf("[Azure CNS] Going to start listener: %+v", config) diff --git a/cns/service/main.go b/cns/service/main.go index 6fa20abad4..e183f171ea 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -785,6 +785,7 @@ func main() { KeyVaultCertificateName: cnsconfig.KeyVaultSettings.CertificateName, MSIResourceID: cnsconfig.MSISettings.ResourceID, KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, + UseMTLS: cnsconfig.UseMTLS, } } diff --git a/cns/service_test.go b/cns/service_test.go new file mode 100644 index 0000000000..9bf4af8ce7 --- /dev/null +++ b/cns/service_test.go @@ -0,0 +1,324 @@ +// Copyright 2017 Microsoft. All rights reserved. +// MIT License + +package cns + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/Azure/azure-container-networking/cns/common" + "github.com/Azure/azure-container-networking/cns/logger" + acn "github.com/Azure/azure-container-networking/common" + serverTLS "github.com/Azure/azure-container-networking/server/tls" + "github.com/Azure/azure-container-networking/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewService(t *testing.T) { + logger.InitLogger("azure-cns.log", 0, 0, "/") + mockStore := store.NewMockStore("test") + + config := &common.ServiceConfig{ + Name: "test", + Version: "1.0", + ChannelMode: "Direct", + Store: mockStore, + } + + t.Run("NewService", func(t *testing.T) { + svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) + require.NoError(t, err) + require.IsType(t, &Service{}, svc) + + svc.SetOption(acn.OptCnsURL, "") + svc.SetOption(acn.OptCnsPort, "") + + require.Empty(t, config.TLSSettings) + + err = svc.Initialize(config) + t.Cleanup(func() { + svc.Uninitialize() + }) + require.NoError(t, err) + + err = svc.StartListener(config) + require.NoError(t, err) + + client := &http.Client{} + + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + }) + + t.Run("NewServiceWithTLS", func(t *testing.T) { + testCertFilePath := createTestCertificate(t) + + config.TLSSettings = serverTLS.TlsSettings{ + TLSPort: "10091", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + } + + svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) + require.NoError(t, err) + require.IsType(t, &Service{}, svc) + + svc.SetOption(acn.OptCnsURL, "") + svc.SetOption(acn.OptCnsPort, "") + + err = svc.Initialize(config) + t.Cleanup(func() { + svc.Uninitialize() + }) + require.NoError(t, err) + + err = svc.StartListener(config) + require.NoError(t, err) + + tlsClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + ServerName: config.TLSSettings.TLSSubjectName, + // #nosec G402 for test purposes only + InsecureSkipVerify: true, + }, + }, + } + + // TLS listener + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody) + require.NoError(t, err) + resp, err := tlsClient.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + + // HTTP listener + httpClient := &http.Client{} + req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) + require.NoError(t, err) + resp, err = httpClient.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + }) + + t.Run("NewServiceWithMutualTLS", func(t *testing.T) { + testCertFilePath := createTestCertificate(t) + + config.TLSSettings = serverTLS.TlsSettings{ + TLSPort: "10091", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + } + + svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) + require.NoError(t, err) + require.IsType(t, &Service{}, svc) + + svc.SetOption(acn.OptCnsURL, "") + svc.SetOption(acn.OptCnsPort, "") + + err = svc.Initialize(config) + t.Cleanup(func() { + svc.Uninitialize() + }) + require.NoError(t, err) + + err = svc.StartListener(config) + require.NoError(t, err) + + mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) + require.NoError(t, err) + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: mTLSConfig, + }, + } + + // TLS listener + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + + // HTTP listener + httpClient := &http.Client{} + req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) + require.NoError(t, err) + resp, err = httpClient.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + }) +} + +func TestMtlsRootCAsFromCertificate(t *testing.T) { + testCertFilePath := createTestCertificate(t) + + tlsSettings := serverTLS.TlsSettings{ + TLSCertificatePath: testCertFilePath, + } + tlsCertRetriever, err := serverTLS.GetTlsCertificateRetriever(tlsSettings) + require.NoError(t, err) + + cert, err := tlsCertRetriever.GetCertificate() + require.NoError(t, err) + + key, err := tlsCertRetriever.GetPrivateKey() + require.NoError(t, err) + + tests := []struct { + name string + cert *tls.Certificate + wantErr bool + wantErrMsg string + }{ + { + name: "returns root CA pool when provided a single self-signed CA cert", + cert: &tls.Certificate{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: key, + Leaf: cert, + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "returns root CA pool when provided with a full cert chain", + cert: &tls.Certificate{ + Certificate: [][]byte{cert.Raw, cert.Raw}, + PrivateKey: key, + Leaf: cert, + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "does not return root CA pool when provided with nil", + cert: nil, + wantErr: true, + wantErrMsg: "no certificate provided", + }, + { + name: "does not return root CA pool when provided with empty cert", + cert: &tls.Certificate{}, + wantErr: true, + wantErrMsg: "no certificate provided", + }, + { + name: "does not return root CA pool when provided with single invalid cert", + cert: &tls.Certificate{ + Certificate: [][]byte{[]byte("invalid leaf cert")}, + }, + wantErr: true, + wantErrMsg: "parsing self signed cert", + }, + { + name: "does not return root CA pool when provided with invalid full chain cert", + cert: &tls.Certificate{ + Certificate: [][]byte{[]byte("invalid leaf cert"), []byte("invalid root CA cert")}, + }, + wantErr: true, + wantErrMsg: "parsing root certs", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, err := mtlsRootCAsFromCertificate(tt.cert) + if tt.wantErr { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErrMsg) + assert.Nil(t, r) + } else { + require.NoError(t, err) + assert.NotNil(t, r) + } + }) + } +} + +// createTestCertificate is a test helper that creates a test certificate +// and writes it to a temporary file that is cleaned up after the test. +// Returns the path to the test certificate file +func createTestCertificate(t *testing.T) string { + t.Helper() + + t.Log("Creating test certificate...") + + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "foo.com", + }, + DNSNames: []string{"localhost", "127.0.0.1", "example.com"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(3 * time.Hour), + + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + // Create certificate with the template and keys + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + // Cert PEM + pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + require.NotNil(t, pemCert) + + // Private Key PEM + privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + pemKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + require.NotNil(t, pemKey) + + pemCert = append(pemCert, pemKey...) + + // Write PEM cert and key to a file in a temp dir + testCertFilePath := filepath.Join(t.TempDir(), "dummy.pem") + err = os.WriteFile(testCertFilePath, pemCert, 0o600) + require.NoError(t, err) + + t.Log("Created test certificate file at: ", testCertFilePath) + + return testCertFilePath +} diff --git a/common/listener.go b/common/listener.go index 0214c36b59..745729b233 100644 --- a/common/listener.go +++ b/common/listener.go @@ -100,6 +100,7 @@ func (l *Listener) Stop() { if l.tlsListener != nil { // Stop servicing requests on secure listener _ = l.tlsListener.Close() + log.Printf("[Listener] Stopped listening on tls endpoint %s", l.tlsListener.Addr()) } // Delete the unix socket. @@ -107,7 +108,7 @@ func (l *Listener) Stop() { _ = os.Remove(l.localAddress) } - log.Printf("[Listener] Stopped listening on %s", l.localAddress) + log.Printf("[Listener] Stopped listening on %s", l.listener.Addr()) } // GetMux returns the HTTP mux for the listener. diff --git a/server/tls/tlscertificate_retriever.go b/server/tls/tlscertificate_retriever.go index 28d7f6e952..d3037815be 100644 --- a/server/tls/tlscertificate_retriever.go +++ b/server/tls/tlscertificate_retriever.go @@ -13,6 +13,7 @@ type TlsSettings struct { KeyVaultCertificateName string MSIResourceID string KeyVaultCertificateRefreshInterval time.Duration + UseMTLS bool } func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {