Skip to content

Commit

Permalink
Add support and tests for Unix domain socket connections
Browse files Browse the repository at this point in the history
This commit introduces functionality to handle Unix domain socket (UDS) connections. It includes updates to the client and test suite to validate UDS support and extends the test server to support Unix sockets. These changes ensure compatibility for environments requiring UDS communication.
  • Loading branch information
wneessen committed Jan 8, 2025
1 parent 149ec4a commit fa129a2
Showing 1 changed file with 86 additions and 8 deletions.
94 changes: 86 additions & 8 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,19 @@ func TestNewClient(t *testing.T) {
})
}
})
t.Run("NewClient on Unix Domain Socket", func(t *testing.T) {
client, err := NewClient("unix:///tmp/mail.sock")
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
if !client.useUnixSocket {
t.Error("Expected useUnixSocket flag to be set to true")
}
if !strings.EqualFold(client.host, "/tmp/mail.sock") {
t.Errorf("expected host to be set to unix socket path, expected: %s, got: %s", "/tmp/mail.sock",
client.host)
}
})
}

func TestClient_TLSPolicy(t *testing.T) {
Expand Down Expand Up @@ -2816,6 +2829,64 @@ func TestClient_DialToSMTPClientWithContext(t *testing.T) {
t.Fatal("expected connection to fake to fail")
}
})
t.Run("dial to Unix domain socket", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
props := &serverProps{
FeatureSet: featureSet,
ListenPort: serverPort,
UnixSocket: true,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)

ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial)
t.Cleanup(func() {
if err := os.RemoveAll(props.UnixSocketPath); err != nil {
t.Errorf("failed to remove unix socket: %s", err)
}
})

client, err := NewClient("unix://"+props.UnixSocketPath+"/server.sock", WithTLSPolicy(NoTLS))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
smtpClient, err := client.DialToSMTPClientWithContext(ctxDial)
if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
t.Skip("failed to connect to the test server due to timeout")
}
t.Fatalf("failed to connect to test server: %s", err)
}
t.Cleanup(func() {
if err := client.CloseWithSMTPClient(smtpClient); err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
t.Skip("failed to close the test server connection due to timeout")
}
t.Errorf("failed to close client: %s", err)
}
})
if smtpClient == nil {
t.Fatal("expected SMTP client, got nil")
}
if !smtpClient.HasConnection() {
t.Fatal("expected connection on smtp client")
}
if ok, _ := smtpClient.Extension("DSN"); !ok {
t.Error("expected DSN extension but it was not found")
}
})
}

func TestClient_sendSingleMsg(t *testing.T) {
Expand Down Expand Up @@ -3837,6 +3908,8 @@ type serverProps struct {
SSLListener bool
IsTLS bool
SupportDSN bool
UnixSocket bool
UnixSocketPath string
}

// simpleSMTPServer starts a simple TCP server that resonds to SMTP commands.
Expand All @@ -3850,18 +3923,23 @@ func simpleSMTPServer(ctx context.Context, t *testing.T, props *serverProps) err

var listener net.Listener
var err error
if props.SSLListener {
keypair, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
return fmt.Errorf("failed to read TLS keypair: %w", err)
switch {
case props.UnixSocket:
path, perr := os.MkdirTemp("", "go-mail-server-*")
if perr != nil {
return fmt.Errorf("failed to create temp directory: %w", perr)
}
listener, err = net.Listen("unix", path+"/server.sock")
props.UnixSocketPath = path
case props.SSLListener:
keypair, kerr := tls.X509KeyPair(localhostCert, localhostKey)
if kerr != nil {
return fmt.Errorf("failed to read TLS keypair: %w", kerr)
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}}
listener, err = tls.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort),
tlsConfig)
if err != nil {
t.Fatalf("failed to create TLS listener: %s", err)
}
} else {
default:
listener, err = net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort))
}
if err != nil {
Expand Down

0 comments on commit fa129a2

Please sign in to comment.