diff --git a/pkg/boltz/ws.go b/pkg/boltz/ws.go index e9fa854b..d9be8285 100644 --- a/pkg/boltz/ws.go +++ b/pkg/boltz/ws.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "slices" + "sync" "time" "github.com/BoltzExchange/boltz-client/v2/internal/logger" @@ -24,15 +25,17 @@ type SwapUpdate struct { } type Websocket struct { - Updates chan SwapUpdate - - apiUrl string - subscriptions chan bool - conn *websocket.Conn - closed bool - reconnect bool - dialer *websocket.Dialer - swapIds []string + Updates chan SwapUpdate + updatesLock sync.Mutex + + apiUrl string + subscriptions chan bool + conn *websocket.Conn + closed bool + dialer *websocket.Dialer + swapIds []string + swapIdsLock sync.Mutex + reconnectInterval time.Duration } type wsResponse struct { @@ -51,14 +54,18 @@ func (boltz *Api) NewWebsocket() *Websocket { } return &Websocket{ - apiUrl: boltz.URL, - subscriptions: make(chan bool), - dialer: &dialer, - Updates: make(chan SwapUpdate), + apiUrl: boltz.URL, + subscriptions: make(chan bool), + dialer: &dialer, + Updates: make(chan SwapUpdate), + reconnectInterval: reconnectInterval, } } func (boltz *Websocket) Connect() error { + if boltz.closed { + return errors.New("websocket is closed") + } wsUrl, err := url.Parse(boltz.apiUrl) if err != nil { return err @@ -108,12 +115,13 @@ func (boltz *Websocket) Connect() error { for { msgType, message, err := conn.ReadMessage() if err != nil { - if boltz.closed { - close(boltz.Updates) - return - } - if !boltz.reconnect { + // if `close` was intentionally called, `Connected` will return false + // since the connection has already been set to nil. + if boltz.Connected() { + boltz.conn = nil logger.Error("could not receive message: " + err.Error()) + } else { + return } break } @@ -122,58 +130,56 @@ func (boltz *Websocket) Connect() error { switch msgType { case websocket.TextMessage: - var response wsResponse - if err := json.Unmarshal(message, &response); err != nil { - logger.Errorf("could not parse websocket response: %s", err) - continue - } - if response.Error != "" { - logger.Errorf("boltz websocket error: %s", response.Error) - continue - } - - switch response.Event { - case "update": - switch response.Channel { - case "swap.update": - for _, arg := range response.Args { - var update SwapUpdate - if err := mapstructure.Decode(arg, &update); err != nil { - logger.Errorf("invalid boltz response: %v", err) - } - boltz.Updates <- update - } - default: - logger.Warnf("unknown update channel: %s", response.Channel) - } - case "subscribe": - boltz.subscriptions <- true - continue - default: - logger.Warnf("unknown event: %s", response.Event) + if err := boltz.handleTextMessage(message); err != nil { + logger.Errorf("could not handle websocket message: %s", err) } + default: + logger.Warnf("unknown message type: %v", msgType) } } + pingTicker.Stop() for { - pingTicker.Stop() - if boltz.reconnect { - boltz.reconnect = false - return - } else { - logger.Errorf("lost connection to boltz ws, reconnecting in %s", reconnectInterval) - time.Sleep(reconnectInterval) - } + logger.Errorf("lost connection to boltz ws, reconnecting in %s", boltz.reconnectInterval) + time.Sleep(boltz.reconnectInterval) err := boltz.Connect() if err == nil { return } } }() + return nil +} - if len(boltz.swapIds) > 0 { - return boltz.subscribe(boltz.swapIds) +func (boltz *Websocket) handleTextMessage(data []byte) error { + boltz.updatesLock.Lock() + defer boltz.updatesLock.Unlock() + var response wsResponse + if err := json.Unmarshal(data, &response); err != nil { + return fmt.Errorf("invalid json: %s", err) + } + if response.Error != "" { + return fmt.Errorf("boltz error: %s", response.Error) } + switch response.Event { + case "update": + switch response.Channel { + case "swap.update": + for _, arg := range response.Args { + var update SwapUpdate + if err := mapstructure.Decode(arg, &update); err != nil { + return fmt.Errorf("invalid boltz response: %v", err) + } + boltz.Updates <- update + } + default: + logger.Warnf("unknown update channel: %s", response.Channel) + } + case "subscribe": + boltz.subscriptions <- true + default: + logger.Warnf("unknown ws event: %s", response.Event) + } return nil } @@ -201,6 +207,14 @@ func (boltz *Websocket) subscribe(swapIds []string) error { } func (boltz *Websocket) Subscribe(swapIds []string) error { + if len(swapIds) == 0 { + return nil + } + if !boltz.Connected() { + if err := boltz.Connect(); err != nil { + return fmt.Errorf("could not connect boltz ws: %w", err) + } + } if err := boltz.subscribe(swapIds); err != nil { // the connection might be dead, so forcefully reconnect if err := boltz.Reconnect(); err != nil { @@ -210,29 +224,54 @@ func (boltz *Websocket) Subscribe(swapIds []string) error { return err } } + boltz.swapIdsLock.Lock() boltz.swapIds = append(boltz.swapIds, swapIds...) + boltz.swapIdsLock.Unlock() return nil } func (boltz *Websocket) Unsubscribe(swapId string) { + boltz.swapIdsLock.Lock() + defer boltz.swapIdsLock.Unlock() boltz.swapIds = slices.DeleteFunc(boltz.swapIds, func(id string) bool { return id == swapId }) logger.Debugf("Unsubscribed from swap %s", swapId) + if len(boltz.swapIds) == 0 { + logger.Debugf("No more pending swaps, disconnecting websocket") + if err := boltz.close(); err != nil { + logger.Warnf("could not close boltz ws: %v", err) + } + } +} + +func (boltz *Websocket) close() error { + if conn := boltz.conn; conn != nil { + boltz.conn = nil + return conn.Close() + } + return nil } func (boltz *Websocket) Close() error { + // setting this flag will cause the `Updates` channel to be closed + // in the receiving goroutine. this isn't done here to avoid a situation + // where we close the channel while the receiving routine is processing an incoming message + // and then tries to send on the closed channel. + boltz.updatesLock.Lock() + defer boltz.updatesLock.Unlock() + close(boltz.Updates) boltz.closed = true - return boltz.conn.Close() + return boltz.close() +} + +func (boltz *Websocket) Connected() bool { + return boltz.conn != nil } func (boltz *Websocket) Reconnect() error { - if boltz.closed { - return errors.New("websocket is closed") - } logger.Infof("Force reconnecting to Boltz ws") - boltz.reconnect = true - if err := boltz.conn.Close(); err != nil { + if err := boltz.close(); err != nil { logger.Warnf("could not close boltz ws: %v", err) } return boltz.Connect() diff --git a/pkg/boltz/ws_test.go b/pkg/boltz/ws_test.go index 702a3f72..70a7f44d 100644 --- a/pkg/boltz/ws_test.go +++ b/pkg/boltz/ws_test.go @@ -6,27 +6,96 @@ import ( "github.com/BoltzExchange/boltz-client/v2/internal/logger" "github.com/stretchr/testify/require" "testing" + "time" ) -func TestWebsocketReconnect(t *testing.T) { +func setupWs(t *testing.T) *Websocket { logger.Init(logger.Options{Level: "debug"}) api := Api{URL: "http://localhost:9001"} - ws := api.NewWebsocket() - err := ws.Connect() - require.NoError(t, err) + require.False(t, ws.Connected()) + return ws +} + +func TestWebsocketLazy(t *testing.T) { + ws := setupWs(t) firstId := "swapId" - err = ws.Subscribe([]string{firstId}) + secondId := "anotherSwapId" + err := ws.Subscribe([]string{firstId, secondId}) require.NoError(t, err) - firstConn := ws.conn - err = firstConn.Close() - require.NoError(t, err) + require.Equal(t, []string{firstId, secondId}, ws.swapIds) + require.True(t, ws.Connected()) + require.NotNil(t, ws.conn) - anotherId := "anotherSwapId" - err = ws.Subscribe([]string{anotherId}) - require.NoError(t, err) - require.NotEqual(t, firstConn, ws.conn, "subscribe should reconnect forcefully") - require.Equal(t, []string{firstId, anotherId}, ws.swapIds) + ws.Unsubscribe(firstId) + require.True(t, ws.Connected()) + require.Equal(t, []string{secondId}, ws.swapIds) + ws.Unsubscribe(secondId) + + require.False(t, ws.Connected()) + require.Nil(t, ws.conn) +} + +func TestWebsocketReconnect(t *testing.T) { + setup := func(t *testing.T) *Websocket { + ws := setupWs(t) + require.NoError(t, ws.Connect()) + require.True(t, ws.Connected()) + return ws + } + + t.Run("Automatic", func(t *testing.T) { + ws := setup(t) + ws.reconnectInterval = 50 * time.Millisecond + firstConn := ws.conn + require.NoError(t, ws.conn.Close()) + + waitFor := time.Second + + require.Eventually(t, func() bool { + return !ws.Connected() + }, waitFor, ws.reconnectInterval) + + require.Eventually(t, func() bool { + return ws.Connected() + }, waitFor, ws.reconnectInterval) + + newConn := ws.conn + require.NotNil(t, newConn) + require.NotEqual(t, firstConn, newConn) + }) + + t.Run("Force", func(t *testing.T) { + ws := setup(t) + firstConn := ws.conn + + err := ws.Subscribe([]string{"swapId"}) + require.NoError(t, err) + + require.NoError(t, ws.conn.Close()) + + err = ws.Subscribe([]string{"anotherSwapId"}) + require.NoError(t, err) + require.NotEqual(t, firstConn, ws.conn, "subscribe should reconnect forcefully") + require.True(t, ws.Connected()) + }) + +} + +func TestWebsocketShutdown(t *testing.T) { + ws := setupWs(t) + require.NoError(t, ws.Connect()) + require.True(t, ws.Connected()) + + require.NoError(t, ws.Close()) + require.True(t, ws.closed) + require.False(t, ws.Connected()) + require.Eventually(t, func() bool { + _, ok := <-ws.Updates + return !ok + }, time.Second, 10*time.Millisecond) + + require.Error(t, ws.Subscribe([]string{"swapId"})) }