Skip to content

Commit

Permalink
feat: lazy websocket connection (#393)
Browse files Browse the repository at this point in the history
* feat: lazy websocket connection

only create actual websocket connection when there are pending swaps and
close it if there are no more pending swaps

* chore: add lock to websocket `swapIds`
  • Loading branch information
jackstar12 authored Feb 25, 2025
1 parent 91b5ab1 commit 4442b44
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 76 deletions.
165 changes: 102 additions & 63 deletions pkg/boltz/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"slices"
"sync"
"time"

"github.com/BoltzExchange/boltz-client/v2/internal/logger"
Expand All @@ -24,15 +25,17 @@ type SwapUpdate struct {
}

type Websocket struct {
Updates chan SwapUpdate

apiUrl string
subscriptions chan bool
conn *websocket.Conn
closed bool
reconnect bool
dialer *websocket.Dialer
swapIds []string
Updates chan SwapUpdate
updatesLock sync.Mutex

apiUrl string
subscriptions chan bool
conn *websocket.Conn
closed bool
dialer *websocket.Dialer
swapIds []string
swapIdsLock sync.Mutex
reconnectInterval time.Duration
}

type wsResponse struct {
Expand All @@ -51,14 +54,18 @@ func (boltz *Api) NewWebsocket() *Websocket {
}

return &Websocket{
apiUrl: boltz.URL,
subscriptions: make(chan bool),
dialer: &dialer,
Updates: make(chan SwapUpdate),
apiUrl: boltz.URL,
subscriptions: make(chan bool),
dialer: &dialer,
Updates: make(chan SwapUpdate),
reconnectInterval: reconnectInterval,
}
}

func (boltz *Websocket) Connect() error {
if boltz.closed {
return errors.New("websocket is closed")
}
wsUrl, err := url.Parse(boltz.apiUrl)
if err != nil {
return err
Expand Down Expand Up @@ -108,12 +115,13 @@ func (boltz *Websocket) Connect() error {
for {
msgType, message, err := conn.ReadMessage()
if err != nil {
if boltz.closed {
close(boltz.Updates)
return
}
if !boltz.reconnect {
// if `close` was intentionally called, `Connected` will return false
// since the connection has already been set to nil.
if boltz.Connected() {
boltz.conn = nil
logger.Error("could not receive message: " + err.Error())
} else {
return
}
break
}
Expand All @@ -122,58 +130,56 @@ func (boltz *Websocket) Connect() error {

switch msgType {
case websocket.TextMessage:
var response wsResponse
if err := json.Unmarshal(message, &response); err != nil {
logger.Errorf("could not parse websocket response: %s", err)
continue
}
if response.Error != "" {
logger.Errorf("boltz websocket error: %s", response.Error)
continue
}

switch response.Event {
case "update":
switch response.Channel {
case "swap.update":
for _, arg := range response.Args {
var update SwapUpdate
if err := mapstructure.Decode(arg, &update); err != nil {
logger.Errorf("invalid boltz response: %v", err)
}
boltz.Updates <- update
}
default:
logger.Warnf("unknown update channel: %s", response.Channel)
}
case "subscribe":
boltz.subscriptions <- true
continue
default:
logger.Warnf("unknown event: %s", response.Event)
if err := boltz.handleTextMessage(message); err != nil {
logger.Errorf("could not handle websocket message: %s", err)
}
default:
logger.Warnf("unknown message type: %v", msgType)
}
}
pingTicker.Stop()
for {
pingTicker.Stop()
if boltz.reconnect {
boltz.reconnect = false
return
} else {
logger.Errorf("lost connection to boltz ws, reconnecting in %s", reconnectInterval)
time.Sleep(reconnectInterval)
}
logger.Errorf("lost connection to boltz ws, reconnecting in %s", boltz.reconnectInterval)
time.Sleep(boltz.reconnectInterval)
err := boltz.Connect()
if err == nil {
return
}
}
}()
return nil
}

if len(boltz.swapIds) > 0 {
return boltz.subscribe(boltz.swapIds)
func (boltz *Websocket) handleTextMessage(data []byte) error {
boltz.updatesLock.Lock()
defer boltz.updatesLock.Unlock()
var response wsResponse
if err := json.Unmarshal(data, &response); err != nil {
return fmt.Errorf("invalid json: %s", err)
}
if response.Error != "" {
return fmt.Errorf("boltz error: %s", response.Error)
}

switch response.Event {
case "update":
switch response.Channel {
case "swap.update":
for _, arg := range response.Args {
var update SwapUpdate
if err := mapstructure.Decode(arg, &update); err != nil {
return fmt.Errorf("invalid boltz response: %v", err)
}
boltz.Updates <- update
}
default:
logger.Warnf("unknown update channel: %s", response.Channel)
}
case "subscribe":
boltz.subscriptions <- true
default:
logger.Warnf("unknown ws event: %s", response.Event)
}
return nil
}

Expand Down Expand Up @@ -201,6 +207,14 @@ func (boltz *Websocket) subscribe(swapIds []string) error {
}

func (boltz *Websocket) Subscribe(swapIds []string) error {
if len(swapIds) == 0 {
return nil
}
if !boltz.Connected() {
if err := boltz.Connect(); err != nil {
return fmt.Errorf("could not connect boltz ws: %w", err)
}
}
if err := boltz.subscribe(swapIds); err != nil {
// the connection might be dead, so forcefully reconnect
if err := boltz.Reconnect(); err != nil {
Expand All @@ -210,29 +224,54 @@ func (boltz *Websocket) Subscribe(swapIds []string) error {
return err
}
}
boltz.swapIdsLock.Lock()
boltz.swapIds = append(boltz.swapIds, swapIds...)
boltz.swapIdsLock.Unlock()
return nil
}

func (boltz *Websocket) Unsubscribe(swapId string) {
boltz.swapIdsLock.Lock()
defer boltz.swapIdsLock.Unlock()
boltz.swapIds = slices.DeleteFunc(boltz.swapIds, func(id string) bool {
return id == swapId
})
logger.Debugf("Unsubscribed from swap %s", swapId)
if len(boltz.swapIds) == 0 {
logger.Debugf("No more pending swaps, disconnecting websocket")
if err := boltz.close(); err != nil {
logger.Warnf("could not close boltz ws: %v", err)
}
}
}

func (boltz *Websocket) close() error {
if conn := boltz.conn; conn != nil {
boltz.conn = nil
return conn.Close()
}
return nil
}

func (boltz *Websocket) Close() error {
// setting this flag will cause the `Updates` channel to be closed
// in the receiving goroutine. this isn't done here to avoid a situation
// where we close the channel while the receiving routine is processing an incoming message
// and then tries to send on the closed channel.
boltz.updatesLock.Lock()
defer boltz.updatesLock.Unlock()
close(boltz.Updates)
boltz.closed = true
return boltz.conn.Close()
return boltz.close()
}

func (boltz *Websocket) Connected() bool {
return boltz.conn != nil
}

func (boltz *Websocket) Reconnect() error {
if boltz.closed {
return errors.New("websocket is closed")
}
logger.Infof("Force reconnecting to Boltz ws")
boltz.reconnect = true
if err := boltz.conn.Close(); err != nil {
if err := boltz.close(); err != nil {
logger.Warnf("could not close boltz ws: %v", err)
}
return boltz.Connect()
Expand Down
95 changes: 82 additions & 13 deletions pkg/boltz/ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,96 @@ import (
"github.com/BoltzExchange/boltz-client/v2/internal/logger"
"github.com/stretchr/testify/require"
"testing"
"time"
)

func TestWebsocketReconnect(t *testing.T) {
func setupWs(t *testing.T) *Websocket {
logger.Init(logger.Options{Level: "debug"})
api := Api{URL: "http://localhost:9001"}

ws := api.NewWebsocket()
err := ws.Connect()
require.NoError(t, err)
require.False(t, ws.Connected())
return ws
}

func TestWebsocketLazy(t *testing.T) {
ws := setupWs(t)

firstId := "swapId"
err = ws.Subscribe([]string{firstId})
secondId := "anotherSwapId"
err := ws.Subscribe([]string{firstId, secondId})
require.NoError(t, err)

firstConn := ws.conn
err = firstConn.Close()
require.NoError(t, err)
require.Equal(t, []string{firstId, secondId}, ws.swapIds)
require.True(t, ws.Connected())
require.NotNil(t, ws.conn)

anotherId := "anotherSwapId"
err = ws.Subscribe([]string{anotherId})
require.NoError(t, err)
require.NotEqual(t, firstConn, ws.conn, "subscribe should reconnect forcefully")
require.Equal(t, []string{firstId, anotherId}, ws.swapIds)
ws.Unsubscribe(firstId)
require.True(t, ws.Connected())
require.Equal(t, []string{secondId}, ws.swapIds)
ws.Unsubscribe(secondId)

require.False(t, ws.Connected())
require.Nil(t, ws.conn)
}

func TestWebsocketReconnect(t *testing.T) {
setup := func(t *testing.T) *Websocket {
ws := setupWs(t)
require.NoError(t, ws.Connect())
require.True(t, ws.Connected())
return ws
}

t.Run("Automatic", func(t *testing.T) {
ws := setup(t)
ws.reconnectInterval = 50 * time.Millisecond
firstConn := ws.conn
require.NoError(t, ws.conn.Close())

waitFor := time.Second

require.Eventually(t, func() bool {
return !ws.Connected()
}, waitFor, ws.reconnectInterval)

require.Eventually(t, func() bool {
return ws.Connected()
}, waitFor, ws.reconnectInterval)

newConn := ws.conn
require.NotNil(t, newConn)
require.NotEqual(t, firstConn, newConn)
})

t.Run("Force", func(t *testing.T) {
ws := setup(t)
firstConn := ws.conn

err := ws.Subscribe([]string{"swapId"})
require.NoError(t, err)

require.NoError(t, ws.conn.Close())

err = ws.Subscribe([]string{"anotherSwapId"})
require.NoError(t, err)
require.NotEqual(t, firstConn, ws.conn, "subscribe should reconnect forcefully")
require.True(t, ws.Connected())
})

}

func TestWebsocketShutdown(t *testing.T) {
ws := setupWs(t)
require.NoError(t, ws.Connect())
require.True(t, ws.Connected())

require.NoError(t, ws.Close())
require.True(t, ws.closed)
require.False(t, ws.Connected())
require.Eventually(t, func() bool {
_, ok := <-ws.Updates
return !ok
}, time.Second, 10*time.Millisecond)

require.Error(t, ws.Subscribe([]string{"swapId"}))
}

0 comments on commit 4442b44

Please sign in to comment.