Skip to content

Commit

Permalink
Add CA reloading dialer
Browse files Browse the repository at this point in the history
Signed-off-by: Marc Lopez Rubio <[email protected]>
  • Loading branch information
marclop committed Feb 12, 2025
1 parent 1eaf70f commit ab849a4
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 24 deletions.
84 changes: 73 additions & 11 deletions kafka/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
Expand Down Expand Up @@ -217,19 +219,14 @@ func (cfg *CommonConfig) finalize() error {
cfg.TLS.InsecureSkipVerify = true
}
if caCertPath := os.Getenv("KAFKA_TLS_CA_CERT_PATH"); caCertPath != "" {
caCert, err := os.ReadFile(caCertPath)
// Auto-configure a dialer that reloads the CA cert when the file
// changes.
dialFn, err := newCertReloadingDialer(caCertPath, cfg.TLS)
if err != nil {
errs = append(errs, fmt.Errorf("kafka: failed to read CA cert: %w", err))
} else {
rootCa, err := x509.SystemCertPool()
if err != nil {
errs = append(errs, fmt.Errorf("kafka: failed to load system cert pool: %w", err))
}
if !rootCa.AppendCertsFromPEM(caCert) {
errs = append(errs, errors.New("kafka: failed to append CA cert"))
}
cfg.TLS.RootCAs = rootCa
errs = append(errs, fmt.Errorf("kafka: error creating dialer with CA cert: %w", err))
}
cfg.Dialer = dialFn
cfg.TLS = nil
}
}
if cfg.SASL == nil {
Expand Down Expand Up @@ -371,3 +368,68 @@ func topicFieldFunc(f TopicLogFieldFunc) TopicLogFieldFunc {
return zap.Skip()
}
}

// newCertReloadingDialer returns a TLS dial function that trusts only the CA
// certificates found in the PEM file at caPath, and that reloads that file
// whenever its modification time changes.
//
// tlsCfg is used as the base TLS configuration; it is cloned, so the caller's
// value is never mutated. A nil tlsCfg is treated as an empty configuration.
//
// An error is returned if caPath cannot be stat'ed or read, or if it contains
// no parseable PEM certificate. The returned dial function fails the dial when
// a changed CA file cannot be read or parsed; in that case the previously
// loaded pool is kept and the reload is retried on the next dial.
func newCertReloadingDialer(caPath string, tlsCfg *tls.Config) (func(ctx context.Context, network, address string) (net.Conn, error), error) {
	p, err := os.Stat(caPath)
	if err != nil {
		return nil, err
	}
	dialer := &net.Dialer{Timeout: 10 * time.Second}
	cfg := tlsCfg.Clone()
	if cfg == nil {
		// Clone returns nil for a nil receiver; assigning RootCAs below would
		// otherwise panic.
		cfg = &tls.Config{}
	}
	caCert, err := os.ReadFile(caPath)
	if err != nil {
		return nil, fmt.Errorf("kafka: failed to read CA cert: %w", err)
	}
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(caCert) {
		return nil, errors.New("kafka: failed to append CA cert")
	}
	cfg.RootCAs = pool

	var certModTS atomic.Int64
	certModTS.Store(p.ModTime().UnixNano())
	var mu sync.RWMutex // guards cfg.RootCAs

	// reload replaces cfg.RootCAs with the current contents of caPath. The
	// new pool is built off to the side and only published on success, so a
	// failed reload never discards the last good pool.
	reload := func(currentModTS int64) error {
		mu.Lock()
		defer mu.Unlock()
		if certModTS.Load() == currentModTS {
			// Another goroutine reloaded while we waited for the lock.
			return nil
		}
		pemBytes, err := os.ReadFile(caPath)
		if err != nil {
			return fmt.Errorf(
				"failed to read CA cert on reload: %w", err,
			)
		}
		reloaded := x509.NewCertPool()
		if !reloaded.AppendCertsFromPEM(pemBytes) {
			return errors.New("failed to append CA cert on reload")
		}
		cfg.RootCAs = reloaded
		certModTS.Store(currentModTS)
		return nil
	}

	return func(ctx context.Context, network, host string) (net.Conn, error) {
		// A failing Stat is deliberately ignored: the dial proceeds with the
		// pool that was loaded last.
		if info, err := os.Stat(caPath); err == nil {
			if modTS := info.ModTime().UnixNano(); modTS != certModTS.Load() {
				if err := reload(modTS); err != nil {
					return nil, fmt.Errorf("kafka: %w", err)
				}
			}
		}
		mu.RLock()
		c := cfg.Clone()
		mu.RUnlock()
		// Copied this pattern from franz-go client.go.
		// https://github.com/twmb/franz-go/blob/f30c518d6b727b9169a90b8c10e2127301822a3a/pkg/kgo/client.go#L440-L453
		if c.ServerName == "" {
			server, _, err := net.SplitHostPort(host)
			if err != nil {
				return nil, fmt.Errorf("dialer: unable to split host:port for dialing: %w", err)
			}
			c.ServerName = server
		}
		return (&tls.Dialer{
			NetDialer: dialer,
			Config:    c,
		}).DialContext(ctx, network, host)
	}, nil
}
131 changes: 118 additions & 13 deletions kafka/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
"net"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -326,37 +330,40 @@ func TestTopicFieldFunc(t *testing.T) {
}

// generateValidCACert creates a self-signed CA certificate backed by a freshly
// generated 2048-bit RSA key. It returns the PEM-encoded certificate, the
// certificate template that was signed, and the private key. Any failure
// aborts the test through require.NoError.
func generateValidCACert(t testing.TB) []byte {
func generateValidCACert(t testing.TB) ([]byte, *x509.Certificate, *rsa.PrivateKey) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
IsCA: true,
SerialNumber: big.NewInt(int64(time.Now().Year())),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
// The template is passed as both the certificate and its parent, which makes
// the result self-signed with key.
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
require.NoError(t, err)

return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), &template, key
}

func TestTLSCACertPath(t *testing.T) {
t.Run("valid cert", func(t *testing.T) {
t.Setenv("KAFKA_PLAINTEXT", "") // clear plaintext mode

tempFile := filepath.Join(t.TempDir(), "ca_cert.pem")
err := os.WriteFile(tempFile, generateValidCACert(t), 0644)
validCert, _, _ := generateValidCACert(t)
err := os.WriteFile(tempFile, validCert, 0644)
require.NoError(t, err)

t.Setenv("KAFKA_TLS_CA_CERT_PATH", tempFile)
cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()}
require.NoError(t, cfg.finalize())
require.NotNil(t, cfg.TLS)
require.NotNil(t, cfg.TLS.RootCAs)
require.NotNil(t, cfg.Dialer)
require.Nil(t, cfg.TLS)
})
t.Run("missing file", func(t *testing.T) {
t.Setenv("KAFKA_PLAINTEXT", "")
Expand All @@ -365,7 +372,8 @@ func TestTLSCACertPath(t *testing.T) {
cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()}
err := cfg.finalize()
require.Error(t, err)
require.Contains(t, err.Error(), "failed to read CA cert")
require.Contains(t, err.Error(), "kafka: error creating dialer with CA cert")
require.Contains(t, err.Error(), "no such file or directory")
})
t.Run("invalid cert", func(t *testing.T) {
t.Setenv("KAFKA_PLAINTEXT", "")
Expand All @@ -377,6 +385,103 @@ func TestTLSCACertPath(t *testing.T) {
cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()}
err = cfg.finalize()
require.Error(t, err)
require.Contains(t, err.Error(), "failed to append CA cert")
require.Contains(t, err.Error(), "kafka: error creating dialer with CA cert")
})
}

func TestTLSHotReload(t *testing.T) {
tempFile := filepath.Join(t.TempDir(), fmt.Sprintf("%s.pem", rand.Text()))

Check failure on line 393 in kafka/common_test.go

View workflow job for this annotation

GitHub Actions / run-tests

undefined: rand.Text

Check failure on line 393 in kafka/common_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: rand.Text (compile)

Check failure on line 393 in kafka/common_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: rand.Text (compile)

Check failure on line 393 in kafka/common_test.go

View workflow job for this annotation

GitHub Actions / run-tests

undefined: rand.Text
caCertBytes, ca, key := generateValidCACert(t)
err := os.WriteFile(tempFile, caCertBytes, 0644)
require.NoError(t, err)

dialFunc, err := newCertReloadingDialer(tempFile, &tls.Config{})
require.NoError(t, err)
require.NotNil(t, dialFunc)

tmpl := x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}

derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, ca, &key.PublicKey, key)
require.NoError(t, err)

certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})

tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
require.NoError(t, err)

// Spawn a TLS TCP server using the generated certificate.
ln, err := tls.Listen("tcp", "localhost:0", &tls.Config{Certificates: []tls.Certificate{tlsCert}})
addr := ln.Addr()
require.NoError(t, err)
timeout := time.After(time.Second)
go func() {
defer ln.Close()
for {
select {
case <-timeout:
return
default:
}
conn, err := ln.Accept()
if err != nil {
return
}
if _, err := conn.Read(make([]byte, 5)); err != nil {
continue
}
conn.Close()
}
}()

var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Millisecond):
conn, err := dialFunc(ctx, addr.Network(), addr.String())
if !errors.Is(err, io.EOF) {
select {
case <-ctx.Done():
continue
default:
}
// Ensure no TLS errors occur.
require.NoError(t, err)
}
if conn != nil {
_, err := conn.Write([]byte("hello"))
require.NoError(t, err)
}
conn.Close()
}
}
}()
}

<-time.After(200 * time.Millisecond)

for i := 0; i < runtime.GOMAXPROCS(0); i++ {
// Update the file, so that the CA cert is reloaded when dialer is called again.
err = os.WriteFile(tempFile, caCertBytes, 0644)
require.NoError(t, err)
<-time.After(50 * time.Millisecond)
}

cancel() // allow go routines to exit
wg.Wait() // wait for all go routines to finish
}

0 comments on commit ab849a4

Please sign in to comment.