Skip to content

Commit

Permalink
🐿️ service: do not hardcode port numbers in tests
Browse files Browse the repository at this point in the history
While here, clean up the tests a little bit.
  • Loading branch information
database64128 committed Feb 15, 2025
1 parent 74980f3 commit d2d8ab6
Showing 1 changed file with 88 additions and 111 deletions.
199 changes: 88 additions & 111 deletions service/client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log/slog"
"net"
"net/netip"
"slices"
"testing"
"time"

Expand All @@ -23,15 +24,15 @@ var cases = []struct {
name: "ZeroOverhead",
serverConfig: ServerConfig{
Name: "wg0",
ProxyListenAddress: "[::1]:20200",
ProxyListenAddress: "[::1]:",
ProxyMode: "zero-overhead",
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20201)),
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)),
MTU: 1500,
},
clientConfig: ClientConfig{
Name: "wg0",
WgListenAddress: "[::1]:20202",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20200)),
WgListenAddress: "[::1]:",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)),
ProxyMode: "zero-overhead",
MTU: 1500,
},
Expand All @@ -40,119 +41,131 @@ var cases = []struct {
name: "Paranoid",
serverConfig: ServerConfig{
Name: "wg0",
ProxyListenAddress: "[::1]:20200",
ProxyListenAddress: "[::1]:",
ProxyMode: "paranoid",
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20201)),
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)),
MTU: 1500,
},
clientConfig: ClientConfig{
Name: "wg0",
WgListenAddress: "[::1]:20202",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20200)),
WgListenAddress: "[::1]:",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)),
ProxyMode: "paranoid",
MTU: 1500,
},
},
}

func init() {
for i := range cases {
psk := generateTestPSK()
cases[i].serverConfig.ProxyPSK = psk
cases[i].clientConfig.ProxyPSK = psk
}
}

func generateTestPSK() []byte {
func testClientServerConn(
t *testing.T,
logger *tslog.Logger,
serverConfig ServerConfig,
clientConfig ClientConfig,
f func(t *testing.T, clientConn, serverConn *net.UDPConn),
) {
psk := make([]byte, 32)
rand.Read(psk)
return psk
}
serverConfig.ProxyPSK = psk
clientConfig.ProxyPSK = psk

func testClientServerHandshake(t *testing.T, logger *tslog.Logger, serverConfig ServerConfig, clientConfig ClientConfig) {
sc := Config{
Servers: []ServerConfig{serverConfig},
Clients: []ClientConfig{clientConfig},
}
m, err := sc.Manager(logger)
listenConfigCache := conn.NewListenConfigCache()
ctx := t.Context()

s, err := serverConfig.Server(logger, listenConfigCache)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to create server service %q: %v", serverConfig.Name, err)
}
ctx := t.Context()
if err = m.Start(ctx); err != nil {
t.Fatal(err)
if err = s.Start(ctx); err != nil {
t.Fatalf("Failed to start server service %q: %v", serverConfig.Name, err)
}
defer m.Stop()
defer s.Stop()

// Make packets.
handshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation)
handshakeInitiationPacket[0] = packet.WireGuardMessageTypeHandshakeInitiation
rand.Read(handshakeInitiationPacket[1:])
expectedHandshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation)
copy(expectedHandshakeInitiationPacket, handshakeInitiationPacket)
receivedHandshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation+1)
proxyAddrPort := s.proxyConn.LocalAddr().(*net.UDPAddr).AddrPort()
clientConfig.ProxyEndpointAddress = conn.AddrFromIPPort(proxyAddrPort)

handshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse)
handshakeResponsePacket[0] = packet.WireGuardMessageTypeHandshakeResponse
rand.Read(handshakeResponsePacket[1:])
expectedHandshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse)
copy(expectedHandshakeResponsePacket, handshakeResponsePacket)
receivedHandshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse+1)
c, err := clientConfig.Client(logger, listenConfigCache)
if err != nil {
t.Fatalf("Failed to create client service %q: %v", clientConfig.Name, err)
}
if err = c.Start(ctx); err != nil {
t.Fatalf("Failed to start client service %q: %v", clientConfig.Name, err)
}
defer c.Stop()

// Start client and server conns.
clientConn, err := net.Dial("udp", clientConfig.WgListenAddress)
clientConn, err := net.Dial("udp", c.wgListenAddress)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to dial client connection: %v", err)
}
defer clientConn.Close()

serverConn, _, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String())
serverConn, _, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", "[::1]:")
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to listen server connection: %v", err)
}
defer serverConn.Close()

serverWgAddrPort := serverConn.LocalAddr().(*net.UDPAddr).AddrPort()
s.wgAddr = conn.AddrFromIPPort(serverWgAddrPort)

// Set read/write deadlines to make the test fail fast.
deadline := time.Now().Add(3 * time.Second)
if err = clientConn.SetDeadline(deadline); err != nil {
t.Fatal(err)
t.Fatalf("Failed to set client connection deadline: %v", err)
}
if err = serverConn.SetDeadline(deadline); err != nil {
t.Fatal(err)
t.Fatalf("Failed to set server connection deadline: %v", err)
}

f(t, clientConn.(*net.UDPConn), serverConn)
}

func testClientServerHandshake(t *testing.T, clientConn, serverConn *net.UDPConn) {
// Make packets.
handshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation)
handshakeInitiationPacket[0] = packet.WireGuardMessageTypeHandshakeInitiation
rand.Read(handshakeInitiationPacket[1:])
expectedHandshakeInitiationPacket := slices.Clone(handshakeInitiationPacket)
receivedHandshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation+1)

handshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse)
handshakeResponsePacket[0] = packet.WireGuardMessageTypeHandshakeResponse
rand.Read(handshakeResponsePacket[1:])
expectedHandshakeResponsePacket := slices.Clone(handshakeResponsePacket)
receivedHandshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse+1)

// Client sends handshake initiation.
_, err = clientConn.Write(handshakeInitiationPacket)
if err != nil {
t.Fatal(err)
if _, err := clientConn.Write(handshakeInitiationPacket); err != nil {
t.Fatalf("Failed to write handshake initiation packet: %v", err)
}

// Server receives handshake initiation.
n, addr, err := serverConn.ReadFromUDPAddrPort(receivedHandshakeInitiationPacket)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to read handshake initiation packet: %v", err)
}
receivedHandshakeInitiationPacket = receivedHandshakeInitiationPacket[:n]

// Server verifies handshake initiation.
if !bytes.Equal(receivedHandshakeInitiationPacket[:n], expectedHandshakeInitiationPacket) {
t.Error("Received handshake initiation packet does not match expectation.")
if !bytes.Equal(receivedHandshakeInitiationPacket, expectedHandshakeInitiationPacket) {
t.Errorf("receivedHandshakeInitiationPacket = %v, want %v", receivedHandshakeInitiationPacket, expectedHandshakeInitiationPacket)
}

// Server sends handshake response.
_, err = serverConn.WriteToUDPAddrPort(handshakeResponsePacket, addr)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to write handshake response packet: %v", err)
}

// Client receives handshake response.
n, err = clientConn.Read(receivedHandshakeResponsePacket)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to read handshake response packet: %v", err)
}
receivedHandshakeResponsePacket = receivedHandshakeResponsePacket[:n]

// Client verifies handshake response.
if !bytes.Equal(receivedHandshakeResponsePacket[:n], expectedHandshakeResponsePacket) {
t.Error("Received handshake response packet does not match expectation.")
if !bytes.Equal(receivedHandshakeResponsePacket, expectedHandshakeResponsePacket) {
t.Errorf("receivedHandshakeResponsePacket = %v, want %v", receivedHandshakeResponsePacket, expectedHandshakeResponsePacket)
}
}

Expand All @@ -162,88 +175,52 @@ func TestClientServerHandshake(t *testing.T) {

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
testClientServerHandshake(t, logger, c.serverConfig, c.clientConfig)
testClientServerConn(t, logger, c.serverConfig, c.clientConfig, testClientServerHandshake)
})
}
}

func testClientServerDataPackets(t *testing.T, logger *tslog.Logger, serverConfig ServerConfig, clientConfig ClientConfig) {
sc := Config{
Servers: []ServerConfig{serverConfig},
Clients: []ClientConfig{clientConfig},
}
m, err := sc.Manager(logger)
if err != nil {
t.Fatal(err)
}
ctx := t.Context()
if err = m.Start(ctx); err != nil {
t.Fatal(err)
}
defer m.Stop()

func testClientServerDataPackets(t *testing.T, clientConn, serverConn *net.UDPConn) {
// Make packets.
smallDataPacket := make([]byte, 1024)
smallDataPacket[0] = packet.WireGuardMessageTypeData
rand.Read(smallDataPacket[1:])
expectedSmallDataPacket := make([]byte, 1024)
copy(expectedSmallDataPacket, smallDataPacket)
expectedSmallDataPacket := slices.Clone(smallDataPacket)
receivedSmallDataPacket := make([]byte, 1024+1)

// Start client and server conns.
clientConn, err := net.Dial("udp", clientConfig.WgListenAddress)
if err != nil {
t.Fatal(err)
}
defer clientConn.Close()

serverConn, _, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String())
if err != nil {
t.Fatal(err)
}
defer serverConn.Close()

// Set read/write deadlines to make the test fail fast.
deadline := time.Now().Add(3 * time.Second)
if err = clientConn.SetDeadline(deadline); err != nil {
t.Fatal(err)
}
if err = serverConn.SetDeadline(deadline); err != nil {
t.Fatal(err)
}

// Client sends small data packet.
_, err = clientConn.Write(smallDataPacket)
if err != nil {
t.Fatal(err)
if _, err := clientConn.Write(smallDataPacket); err != nil {
t.Fatalf("Failed to write small data packet: %v", err)
}

// Server receives small data packet.
n, addr, err := serverConn.ReadFromUDPAddrPort(receivedSmallDataPacket)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to read small data packet: %v", err)
}
receivedSmallDataPacket = receivedSmallDataPacket[:n]

// Server verifies small data packet.
if !bytes.Equal(receivedSmallDataPacket[:n], expectedSmallDataPacket) {
t.Error("Received small data packet does not match expectation.")
if !bytes.Equal(receivedSmallDataPacket, expectedSmallDataPacket) {
t.Errorf("receivedSmallDataPacket = %v, want %v", receivedSmallDataPacket, expectedSmallDataPacket)
}

// Server sends small data packet.
_, err = serverConn.WriteToUDPAddrPort(smallDataPacket, addr)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to write small data packet: %v", err)
}

// Client receives small data packet.
n, err = clientConn.Read(receivedSmallDataPacket)
if err != nil {
t.Fatal(err)
t.Fatalf("Failed to read small data packet: %v", err)
}
receivedSmallDataPacket = receivedSmallDataPacket[:n]

// Client verifies small data packet.
if !bytes.Equal(receivedSmallDataPacket[:n], expectedSmallDataPacket) {
t.Error("Received small data packet does not match expectation.")
if !bytes.Equal(receivedSmallDataPacket, expectedSmallDataPacket) {
t.Errorf("receivedSmallDataPacket = %v, want %v", receivedSmallDataPacket, expectedSmallDataPacket)
}
}

Expand All @@ -253,7 +230,7 @@ func TestClientServerDataPackets(t *testing.T) {

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
testClientServerDataPackets(t, logger, c.serverConfig, c.clientConfig)
testClientServerConn(t, logger, c.serverConfig, c.clientConfig, testClientServerDataPackets)
})
}
}

0 comments on commit d2d8ab6

Please sign in to comment.