From d2d8ab658f610baffb3c042db05630c7740eb1c2 Mon Sep 17 00:00:00 2001 From: database64128 Date: Sat, 15 Feb 2025 16:27:18 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=BF=EF=B8=8F=20service:=20do=20not=20h?= =?UTF-8?q?ardcode=20port=20numbers=20in=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While here, clean up the tests a little bit. --- service/client_server_test.go | 199 +++++++++++++++------------------- 1 file changed, 88 insertions(+), 111 deletions(-) diff --git a/service/client_server_test.go b/service/client_server_test.go index 3be086b..c21894a 100644 --- a/service/client_server_test.go +++ b/service/client_server_test.go @@ -6,6 +6,7 @@ import ( "log/slog" "net" "net/netip" + "slices" "testing" "time" @@ -23,15 +24,15 @@ var cases = []struct { name: "ZeroOverhead", serverConfig: ServerConfig{ Name: "wg0", - ProxyListenAddress: "[::1]:20200", + ProxyListenAddress: "[::1]:", ProxyMode: "zero-overhead", - WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20201)), + WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)), MTU: 1500, }, clientConfig: ClientConfig{ Name: "wg0", - WgListenAddress: "[::1]:20202", - ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20200)), + WgListenAddress: "[::1]:", + ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)), ProxyMode: "zero-overhead", MTU: 1500, }, @@ -40,119 +41,131 @@ var cases = []struct { name: "Paranoid", serverConfig: ServerConfig{ Name: "wg0", - ProxyListenAddress: "[::1]:20200", + ProxyListenAddress: "[::1]:", ProxyMode: "paranoid", - WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20201)), + WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)), MTU: 1500, }, clientConfig: ClientConfig{ Name: "wg0", - WgListenAddress: "[::1]:20202", - ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20200)), + WgListenAddress: "[::1]:", + ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 0)), ProxyMode: "paranoid", MTU: 1500, }, }, } -func init() { - for i := range cases { - psk := generateTestPSK() - cases[i].serverConfig.ProxyPSK = psk - cases[i].clientConfig.ProxyPSK = psk - } -} - -func generateTestPSK() []byte { +func testClientServerConn( + t *testing.T, + logger *tslog.Logger, + serverConfig ServerConfig, + clientConfig ClientConfig, + f func(t *testing.T, clientConn, serverConn *net.UDPConn), +) { psk := make([]byte, 32) rand.Read(psk) - return psk -} + serverConfig.ProxyPSK = psk + clientConfig.ProxyPSK = psk -func testClientServerHandshake(t *testing.T, logger *tslog.Logger, serverConfig ServerConfig, clientConfig ClientConfig) { - sc := Config{ - Servers: []ServerConfig{serverConfig}, - Clients: []ClientConfig{clientConfig}, - } - m, err := sc.Manager(logger) + listenConfigCache := conn.NewListenConfigCache() + ctx := t.Context() + + s, err := serverConfig.Server(logger, listenConfigCache) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to create server service %q: %v", serverConfig.Name, err) } - ctx := t.Context() - if err = m.Start(ctx); err != nil { - t.Fatal(err) + if err = s.Start(ctx); err != nil { + t.Fatalf("Failed to start server service %q: %v", serverConfig.Name, err) } - defer m.Stop() + defer s.Stop() - // Make packets. - handshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation) - handshakeInitiationPacket[0] = packet.WireGuardMessageTypeHandshakeInitiation - rand.Read(handshakeInitiationPacket[1:]) - expectedHandshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation) - copy(expectedHandshakeInitiationPacket, handshakeInitiationPacket) - receivedHandshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation+1) + proxyAddrPort := s.proxyConn.LocalAddr().(*net.UDPAddr).AddrPort() + clientConfig.ProxyEndpointAddress = conn.AddrFromIPPort(proxyAddrPort) - handshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse) - handshakeResponsePacket[0] = packet.WireGuardMessageTypeHandshakeResponse - rand.Read(handshakeResponsePacket[1:]) - expectedHandshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse) - copy(expectedHandshakeResponsePacket, handshakeResponsePacket) - receivedHandshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse+1) + c, err := clientConfig.Client(logger, listenConfigCache) + if err != nil { + t.Fatalf("Failed to create client service %q: %v", clientConfig.Name, err) + } + if err = c.Start(ctx); err != nil { + t.Fatalf("Failed to start client service %q: %v", clientConfig.Name, err) + } + defer c.Stop() - // Start client and server conns. - clientConn, err := net.Dial("udp", clientConfig.WgListenAddress) + clientConn, err := net.Dial("udp", c.wgListenAddress) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to dial client connection: %v", err) } defer clientConn.Close() - serverConn, _, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String()) + serverConn, _, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", "[::1]:") if err != nil { - t.Fatal(err) + t.Fatalf("Failed to listen server connection: %v", err) } defer serverConn.Close() + serverWgAddrPort := serverConn.LocalAddr().(*net.UDPAddr).AddrPort() + s.wgAddr = conn.AddrFromIPPort(serverWgAddrPort) + // Set read/write deadlines to make the test fail fast. deadline := time.Now().Add(3 * time.Second) if err = clientConn.SetDeadline(deadline); err != nil { - t.Fatal(err) + t.Fatalf("Failed to set client connection deadline: %v", err) } if err = serverConn.SetDeadline(deadline); err != nil { - t.Fatal(err) + t.Fatalf("Failed to set server connection deadline: %v", err) } + f(t, clientConn.(*net.UDPConn), serverConn) +} + +func testClientServerHandshake(t *testing.T, clientConn, serverConn *net.UDPConn) { + // Make packets. + handshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation) + handshakeInitiationPacket[0] = packet.WireGuardMessageTypeHandshakeInitiation + rand.Read(handshakeInitiationPacket[1:]) + expectedHandshakeInitiationPacket := slices.Clone(handshakeInitiationPacket) + receivedHandshakeInitiationPacket := make([]byte, packet.WireGuardMessageLengthHandshakeInitiation+1) + + handshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse) + handshakeResponsePacket[0] = packet.WireGuardMessageTypeHandshakeResponse + rand.Read(handshakeResponsePacket[1:]) + expectedHandshakeResponsePacket := slices.Clone(handshakeResponsePacket) + receivedHandshakeResponsePacket := make([]byte, packet.WireGuardMessageLengthHandshakeResponse+1) + // Client sends handshake initiation. - _, err = clientConn.Write(handshakeInitiationPacket) - if err != nil { - t.Fatal(err) + if _, err := clientConn.Write(handshakeInitiationPacket); err != nil { + t.Fatalf("Failed to write handshake initiation packet: %v", err) } // Server receives handshake initiation. n, addr, err := serverConn.ReadFromUDPAddrPort(receivedHandshakeInitiationPacket) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to read handshake initiation packet: %v", err) } + receivedHandshakeInitiationPacket = receivedHandshakeInitiationPacket[:n] // Server verifies handshake initiation. - if !bytes.Equal(receivedHandshakeInitiationPacket[:n], expectedHandshakeInitiationPacket) { - t.Error("Received handshake initiation packet does not match expectation.") + if !bytes.Equal(receivedHandshakeInitiationPacket, expectedHandshakeInitiationPacket) { + t.Errorf("receivedHandshakeInitiationPacket = %v, want %v", receivedHandshakeInitiationPacket, expectedHandshakeInitiationPacket) } // Server sends handshake response. _, err = serverConn.WriteToUDPAddrPort(handshakeResponsePacket, addr) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to write handshake response packet: %v", err) } // Client receives handshake response. n, err = clientConn.Read(receivedHandshakeResponsePacket) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to read handshake response packet: %v", err) } + receivedHandshakeResponsePacket = receivedHandshakeResponsePacket[:n] // Client verifies handshake response. - if !bytes.Equal(receivedHandshakeResponsePacket[:n], expectedHandshakeResponsePacket) { - t.Error("Received handshake response packet does not match expectation.") + if !bytes.Equal(receivedHandshakeResponsePacket, expectedHandshakeResponsePacket) { + t.Errorf("receivedHandshakeResponsePacket = %v, want %v", receivedHandshakeResponsePacket, expectedHandshakeResponsePacket) } } @@ -162,88 +175,52 @@ func TestClientServerHandshake(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - testClientServerHandshake(t, logger, c.serverConfig, c.clientConfig) + testClientServerConn(t, logger, c.serverConfig, c.clientConfig, testClientServerHandshake) }) } } -func testClientServerDataPackets(t *testing.T, logger *tslog.Logger, serverConfig ServerConfig, clientConfig ClientConfig) { - sc := Config{ - Servers: []ServerConfig{serverConfig}, - Clients: []ClientConfig{clientConfig}, - } - m, err := sc.Manager(logger) - if err != nil { - t.Fatal(err) - } - ctx := t.Context() - if err = m.Start(ctx); err != nil { - t.Fatal(err) - } - defer m.Stop() - +func testClientServerDataPackets(t *testing.T, clientConn, serverConn *net.UDPConn) { // Make packets. smallDataPacket := make([]byte, 1024) smallDataPacket[0] = packet.WireGuardMessageTypeData rand.Read(smallDataPacket[1:]) - expectedSmallDataPacket := make([]byte, 1024) - copy(expectedSmallDataPacket, smallDataPacket) + expectedSmallDataPacket := slices.Clone(smallDataPacket) receivedSmallDataPacket := make([]byte, 1024+1) - // Start client and server conns. - clientConn, err := net.Dial("udp", clientConfig.WgListenAddress) - if err != nil { - t.Fatal(err) - } - defer clientConn.Close() - - serverConn, _, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String()) - if err != nil { - t.Fatal(err) - } - defer serverConn.Close() - - // Set read/write deadlines to make the test fail fast. - deadline := time.Now().Add(3 * time.Second) - if err = clientConn.SetDeadline(deadline); err != nil { - t.Fatal(err) - } - if err = serverConn.SetDeadline(deadline); err != nil { - t.Fatal(err) - } - // Client sends small data packet. - _, err = clientConn.Write(smallDataPacket) - if err != nil { - t.Fatal(err) + if _, err := clientConn.Write(smallDataPacket); err != nil { + t.Fatalf("Failed to write small data packet: %v", err) } // Server receives small data packet. n, addr, err := serverConn.ReadFromUDPAddrPort(receivedSmallDataPacket) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to read small data packet: %v", err) } + receivedSmallDataPacket = receivedSmallDataPacket[:n] // Server verifies small data packet. - if !bytes.Equal(receivedSmallDataPacket[:n], expectedSmallDataPacket) { - t.Error("Received small data packet does not match expectation.") + if !bytes.Equal(receivedSmallDataPacket, expectedSmallDataPacket) { + t.Errorf("receivedSmallDataPacket = %v, want %v", receivedSmallDataPacket, expectedSmallDataPacket) } // Server sends small data packet. _, err = serverConn.WriteToUDPAddrPort(smallDataPacket, addr) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to write small data packet: %v", err) } // Client receives small data packet. n, err = clientConn.Read(receivedSmallDataPacket) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to read small data packet: %v", err) } + receivedSmallDataPacket = receivedSmallDataPacket[:n] // Client verifies small data packet. - if !bytes.Equal(receivedSmallDataPacket[:n], expectedSmallDataPacket) { - t.Error("Received small data packet does not match expectation.") + if !bytes.Equal(receivedSmallDataPacket, expectedSmallDataPacket) { + t.Errorf("receivedSmallDataPacket = %v, want %v", receivedSmallDataPacket, expectedSmallDataPacket) } } @@ -253,7 +230,7 @@ func TestClientServerDataPackets(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - testClientServerDataPackets(t, logger, c.serverConfig, c.clientConfig) + testClientServerConn(t, logger, c.serverConfig, c.clientConfig, testClientServerDataPackets) }) } }