From a897cc1d5b0940e2e92a30ad7701114807e5caac Mon Sep 17 00:00:00 2001 From: lplearn Date: Tue, 25 Oct 2022 02:12:31 +0000 Subject: [PATCH] feat: support client stop --- client/client.go | 80 +++++++++++++++++++++++++--------------------- client/cmd/main.go | 8 +++++ 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/client/client.go b/client/client.go index f315d4e..9270581 100644 --- a/client/client.go +++ b/client/client.go @@ -20,6 +20,7 @@ type ClientEndpoint struct { ServerEndpointSocket string TokenSource token.TokenSourcePlugin TlsConfig *tls.Config + DoneCh chan struct{} } func (c *ClientEndpoint) Start() { @@ -37,44 +38,51 @@ func (c *ClientEndpoint) Start() { } defer listener.Close() log.Infow("Client endpoint start up successful", "listen address", listener.Addr()) - for { - // Accept client application connectin request - conn, err := listener.Accept() - if err != nil { - log.Errorw("Client app connect failed", "error", err.Error()) - } else { - logger := log.WithValues(constants.ClientAppAddr, conn.RemoteAddr().String()) - logger.Info("Client connection accepted, prepare to entablish tunnel with server endpint for this connection.") - go func() { - defer func() { - conn.Close() - logger.Info("Tunnel closed") - }() - // Open a quic stream for each client application connection. - stream, err := session.OpenStreamSync(context.Background()) - if err != nil { - logger.Errorw("Failed to open stream to server endpoint.", "error", err.Error()) - return - } - defer stream.Close() - logger = logger.WithValues(constants.StreamID, stream.StreamID()) - // Create a context argument for each new tunnel - ctx := context.WithValue( - logger.WithContext(parent_ctx), - constants.CtxClientAppAddr, conn.RemoteAddr().String()) - hsh := tunnel.NewHandshakeHelper(constants.TokenLength, handshake) - hsh.TokenSource = &c.TokenSource - // Create a new tunnel for the new client application connection. - tun := tunnel.NewTunnel(&stream, constants.ClientEndpoint) - tun.Conn = &conn - tun.Hsh = &hsh - if !tun.HandShake(ctx) { - return + go func() { + for { + // Accept client application connectin request + conn, err := listener.Accept() + if err != nil { + log.Errorw("Client app connect failed", "error", err.Error()) + if oe, ok := err.(*net.OpError); ok && oe.Op == "accept" { + break } - tun.Establish(ctx) - }() + } else { + logger := log.WithValues(constants.ClientAppAddr, conn.RemoteAddr().String()) + logger.Info("Client connection accepted, prepare to entablish tunnel with server endpint for this connection.") + go func() { + defer func() { + conn.Close() + logger.Info("Tunnel closed") + }() + // Open a quic stream for each client application connection. + stream, err := session.OpenStreamSync(context.Background()) + if err != nil { + logger.Errorw("Failed to open stream to server endpoint.", "error", err.Error()) + return + } + defer stream.Close() + logger = logger.WithValues(constants.StreamID, stream.StreamID()) + // Create a context argument for each new tunnel + ctx := context.WithValue( + logger.WithContext(parent_ctx), + constants.CtxClientAppAddr, conn.RemoteAddr().String()) + hsh := tunnel.NewHandshakeHelper(constants.TokenLength, handshake) + hsh.TokenSource = &c.TokenSource + // Create a new tunnel for the new client application connection. + tun := tunnel.NewTunnel(&stream, constants.ClientEndpoint) + tun.Conn = &conn + tun.Hsh = &hsh + if !tun.HandShake(ctx) { + return + } + tun.Establish(ctx) + }() + } } - } + }() + <-c.DoneCh + log.Info("The client is going to close") } func handshake(ctx context.Context, stream *quic.Stream, hsh *tunnel.HandshakeHelper) (bool, *net.Conn) { diff --git a/client/cmd/main.go b/client/cmd/main.go index 2a6839d..c3802f6 100644 --- a/client/cmd/main.go +++ b/client/cmd/main.go @@ -26,6 +26,7 @@ var ( apiOptions *options.RestfulAPIOptions secOptions *options.SecureOptions logOptions *log.Options + doneCh chan struct{} ) func buildCommand(basename string) *cobra.Command { @@ -93,6 +94,7 @@ func runFunc(co *options.ClientOptions, ao *options.RestfulAPIOptions, seco *opt caFile := seco.CaFile verifyServer := seco.VerifyRemoteEndpoint apiListenOn := ao.HttpdListenOn + doneCh = make(chan struct{}) tlsConfig := &tls.Config{ InsecureSkipVerify: !verifyServer, @@ -133,6 +135,7 @@ func runFunc(co *options.ClientOptions, ao *options.RestfulAPIOptions, seco *opt ServerEndpointSocket: serverEndpointSocket, TokenSource: loadTokenSourcePlugin(tokenPlugin, tokenSource), TlsConfig: tlsConfig, + DoneCh: doneCh, } c.Start() } @@ -162,3 +165,8 @@ func main() { os.Exit(1) } } + +//export Stop +func Stop() { + close(doneCh) +}