diff --git a/Makefile b/Makefile index aad31f6bc57..03c96756abc 100644 --- a/Makefile +++ b/Makefile @@ -205,6 +205,7 @@ generate-mocks: install-mock-generators mockery --name 'BlockTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'DataProvider' --dir="./engine/access/rest/websockets/data_providers" --case=underscore --output="./engine/access/rest/websockets/data_providers/mock" --outpkg="mock" mockery --name 'DataProviderFactory' --dir="./engine/access/rest/websockets/data_providers" --case=underscore --output="./engine/access/rest/websockets/data_providers/mock" --outpkg="mock" + mockery --name 'WebsocketConnection' --dir="./engine/access/rest/websockets" --case=underscore --output="./engine/access/rest/websockets/mock" --outpkg="mock" mockery --name 'ExecutionDataTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'ConnectionFactory' --dir="./engine/access/rpc/connection" --case=underscore --output="./engine/access/rpc/connection/mock" --outpkg="mock" mockery --name 'Communicator' --dir="./engine/access/rpc/backend" --case=underscore --output="./engine/access/rpc/backend/mock" --outpkg="mock" diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go index 7f563ba94b9..8236913dd5f 100644 --- a/engine/access/rest/websockets/config.go +++ b/engine/access/rest/websockets/config.go @@ -4,6 +4,33 @@ import ( "time" ) +const ( + // PingPeriod defines the interval at which ping messages are sent to the client. + // This value must be less than pongWait, cause it that case the server ensures it sends a ping well before the PongWait + // timeout elapses. Each new pong message resets the server's read deadline, keeping the connection alive as long as + // the client is responsive. + // + // Example: + // At t=9, the server sends a ping, initial read deadline is t=10 (for the first message) + // At t=10, the client responds with a pong. The server resets its read deadline to t=20. + // At t=18, the server sends another ping. If the client responds with a pong at t=19, the read deadline is extended to t=29. + // + // In case of failure: + // If the client stops responding, the server will send a ping at t=9 but won't receive a pong by t=10. The server then closes the connection. + PingPeriod = (PongWait * 9) / 10 + + // PongWait specifies the maximum time to wait for a pong response message from the peer + // after sending a ping + PongWait = 10 * time.Second + + // WriteWait specifies a timeout for the write operation. If the write + // isn't completed within this duration, it fails with a timeout error. + // SetWriteDeadline ensures the write operation does not block indefinitely + // if the client is slow or unresponsive. This prevents resource exhaustion + // and allows the server to gracefully handle timeouts for delayed writes. + WriteWait = 10 * time.Second +) + type Config struct { MaxSubscriptionsPerConnection uint64 MaxResponsesPerSecond uint64 diff --git a/engine/access/rest/websockets/connection.go b/engine/access/rest/websockets/connection.go index 5e1880f7ce8..5170e917e9f 100644 --- a/engine/access/rest/websockets/connection.go +++ b/engine/access/rest/websockets/connection.go @@ -1,42 +1,57 @@ package websockets import ( + "time" + "github.com/gorilla/websocket" ) -// We wrap gorilla's websocket connection with interface -// to be able to mock it in order to test the types dependent on it - type WebsocketConnection interface { ReadJSON(v interface{}) error WriteJSON(v interface{}) error + WriteControl(messageType int, deadline time.Time) error Close() error + SetReadDeadline(deadline time.Time) error + SetWriteDeadline(deadline time.Time) error + SetPongHandler(h func(string) error) } -type GorillaWebsocketConnection struct { +type WebsocketConnectionImpl struct { conn *websocket.Conn } -func NewGorillaWebsocketConnection(conn *websocket.Conn) *GorillaWebsocketConnection { - return &GorillaWebsocketConnection{ +func NewWebsocketConnection(conn *websocket.Conn) *WebsocketConnectionImpl { + return &WebsocketConnectionImpl{ conn: conn, } } -var _ WebsocketConnection = (*GorillaWebsocketConnection)(nil) +var _ WebsocketConnection = (*WebsocketConnectionImpl)(nil) + +func (c *WebsocketConnectionImpl) ReadJSON(v interface{}) error { + return c.conn.ReadJSON(v) +} + +func (c *WebsocketConnectionImpl) WriteJSON(v interface{}) error { + return c.conn.WriteJSON(v) +} + +func (c *WebsocketConnectionImpl) WriteControl(messageType int, deadline time.Time) error { + return c.conn.WriteControl(messageType, nil, deadline) +} -func (m *GorillaWebsocketConnection) ReadJSON(v interface{}) error { - return m.conn.ReadJSON(v) +func (c *WebsocketConnectionImpl) Close() error { + return c.conn.Close() } -func (m *GorillaWebsocketConnection) WriteJSON(v interface{}) error { - return m.conn.WriteJSON(v) +func (c *WebsocketConnectionImpl) SetReadDeadline(deadline time.Time) error { + return c.conn.SetReadDeadline(deadline) } -func (m *GorillaWebsocketConnection) SetCloseHandler(handler func(code int, text string) error) { - m.conn.SetCloseHandler(handler) +func (c *WebsocketConnectionImpl) SetWriteDeadline(deadline time.Time) error { + return c.conn.SetWriteDeadline(deadline) } -func (m *GorillaWebsocketConnection) Close() error { - return m.conn.Close() +func (c *WebsocketConnectionImpl) SetPongHandler(h func(string) error) { + c.conn.SetPongHandler(h) } diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index fe41580167f..b82878406f8 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -4,10 +4,12 @@ import ( "context" "encoding/json" "fmt" + "time" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" "github.com/onflow/flow-go/engine/access/rest/websockets/models" @@ -15,19 +17,21 @@ import ( ) type Controller struct { - logger zerolog.Logger - config Config - conn WebsocketConnection - communicationChannel chan interface{} - dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] - dataProviderFactory dp.DataProviderFactory + logger zerolog.Logger + config Config + conn WebsocketConnection + + communicationChannel chan interface{} // Channel for sending messages to the client. + + dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] + dataProviderFactory dp.DataProviderFactory } func NewWebSocketController( logger zerolog.Logger, config Config, - dataProviderFactory dp.DataProviderFactory, conn WebsocketConnection, + dataProviderFactory dp.DataProviderFactory, ) *Controller { return &Controller{ logger: logger.With().Str("component", "websocket-controller").Logger(), @@ -39,40 +43,99 @@ func NewWebSocketController( } } -// HandleConnection manages the WebSocket connection, adding context and error handling. +// HandleConnection manages the lifecycle of a WebSocket connection, +// including setup, message processing, and graceful shutdown. +// +// Parameters: +// - ctx: The context for controlling cancellation and timeouts. func (c *Controller) HandleConnection(ctx context.Context) { - //TODO: configure the connection with ping-pong and deadlines - //TODO: spin up a response limit tracker routine defer c.shutdownConnection() - go c.readMessages(ctx) - c.writeMessages(ctx) + + // configuring the connection with appropriate read/write deadlines and handlers. + err := c.configureKeepalive() + if err != nil { + // TODO: add error handling here + c.logger.Error().Err(err).Msg("error configuring keepalive connection") + + return + } + + // for track all goroutines and error handling + g, gCtx := errgroup.WithContext(ctx) + + g.Go(func() error { + return c.readMessagesFromClient(gCtx) + }) + + g.Go(func() error { + return c.keepalive(gCtx) + }) + + g.Go(func() error { + return c.writeMessagesToClient(gCtx) + }) + + if err = g.Wait(); err != nil { + //TODO: add error handling here + c.logger.Error().Err(err).Msg("error detected in one of the goroutines") + } +} + +// configureKeepalive sets up the WebSocket connection with a read deadline +// and a handler for receiving pong messages from the client. +// +// The function does the following: +// 1. Sets an initial read deadline to ensure the server doesn't wait indefinitely +// for a pong message from the client. If no message is received within the +// specified `pongWait` duration, the connection will be closed. +// 2. Establishes a Pong handler that resets the read deadline every time a pong +// message is received from the client, allowing the server to continue waiting +// for further pong messages within the new deadline. +// +// No errors are expected during normal operation. +func (c *Controller) configureKeepalive() error { + // Set the initial read deadline for the first pong message + // The Pong handler itself only resets the read deadline after receiving a Pong. + // It doesn't set an initial deadline. The initial read deadline is crucial to prevent the server from waiting + // forever if the client doesn't send Pongs. + if err := c.conn.SetReadDeadline(time.Now().Add(PongWait)); err != nil { + return fmt.Errorf("failed to set the initial read deadline: %w", err) + } + // Establish a Pong handler which sets the handler for pong messages received from the peer. + c.conn.SetPongHandler(func(string) error { + return c.conn.SetReadDeadline(time.Now().Add(PongWait)) + }) + + return nil } // writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation -func (c *Controller) writeMessages(ctx context.Context) { +// +// No errors are expected during normal operation. All errors are considered benign. +func (c *Controller) writeMessagesToClient(ctx context.Context) error { for { select { case <-ctx.Done(): - return + return nil case msg, ok := <-c.communicationChannel: if !ok { - return + return fmt.Errorf("communication channel closed, no error occurred") } - c.logger.Debug().Msgf("read message from communication channel: %s", msg) - // TODO: handle 'response per second' limits + // Specifies a timeout for the write operation. If the write + // isn't completed within this duration, it fails with a timeout error. + // SetWriteDeadline ensures the write operation does not block indefinitely + // if the client is slow or unresponsive. This prevents resource exhaustion + // and allows the server to gracefully handle timeouts for delayed writes. + if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { + return fmt.Errorf("failed to set the write deadline: %w", err) + } err := c.conn.WriteJSON(msg) if err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) || - websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - return - } - - c.logger.Error().Err(err).Msg("error writing to connection") - return + return fmt.Errorf("failed to write message to connection: %w", err) } c.logger.Debug().Msg("written message to client") @@ -82,30 +145,30 @@ func (c *Controller) writeMessages(ctx context.Context) { // readMessages continuously reads messages from a client WebSocket connection, // processes each message, and handles actions based on the message type. -func (c *Controller) readMessages(ctx context.Context) { +// +// No errors are expected during normal operation. All errors are considered benign. +func (c *Controller) readMessagesFromClient(ctx context.Context) error { for { - msg, err := c.readMessage() - if err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) || - websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - return + select { + case <-ctx.Done(): + return nil + default: + msg, err := c.readMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { + return nil + } + return fmt.Errorf("failed to read message from client: %w", err) } - c.logger.Debug().Err(err).Msg("error reading message from client") - continue - } - - baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) - if err != nil { - c.logger.Debug().Err(err).Msg("error parsing and validating client message") - //TODO: write error to error channel - continue - } + _, validatedMsg, err := c.parseAndValidateMessage(msg) + if err != nil { + return fmt.Errorf("failed to parse and validate client message: %w", err) + } - if err := c.handleAction(ctx, validatedMsg); err != nil { - c.logger.Debug().Err(err).Str("action", baseMsg.Action).Msg("error handling action") - //TODO: write error to error channel - continue + if err := c.handleAction(ctx, validatedMsg); err != nil { + return fmt.Errorf("failed to handle message action: %w", err) + } } } } @@ -150,7 +213,6 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.Ba validatedMsg = listMsg default: - c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") return baseMsg, nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) } @@ -229,6 +291,7 @@ func (c *Controller) shutdownConnection() { if err := c.conn.Close(); err != nil { c.logger.Warn().Err(err).Msg("error closing connection") } + // TODO: safe closing communicationChannel will be included as a part of PR #6642 }() c.logger.Debug().Msg("shutting down connection") @@ -243,3 +306,24 @@ func (c *Controller) shutdownConnection() { c.dataProviders.Clear() } + +// keepalive sends a ping message periodically to keep the WebSocket connection alive +// and avoid timeouts. +// +// No errors are expected during normal operation. All errors are considered benign. +func (c *Controller) keepalive(ctx context.Context) error { + pingTicker := time.NewTicker(PingPeriod) + defer pingTicker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-pingTicker.C: + err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) + if err != nil { + return fmt.Errorf("failed to write ping message: %w", err) + } + } + } +} diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 4e36aacc3bc..72adf511161 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -3,247 +3,274 @@ package websockets import ( "context" "encoding/json" + "fmt" "testing" + "time" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" + + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" dpmock "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers/mock" - connmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" + connectionmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream/backend" - streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) -type WsControllerSuite struct { +// ControllerSuite is a test suite for the WebSocket Controller. +type ControllerSuite struct { suite.Suite - logger zerolog.Logger - wsConfig Config - streamApi *streammock.API - streamConfig backend.Config + logger zerolog.Logger + config Config + + connection *connectionmock.WebsocketConnection + dataProviderFactory *dpmock.DataProviderFactory } -func (s *WsControllerSuite) SetupTest() { +func TestControllerSuite(t *testing.T) { + suite.Run(t, new(ControllerSuite)) +} + +// SetupTest initializes the test suite with required dependencies. +func (s *ControllerSuite) SetupTest() { s.logger = unittest.Logger() - s.wsConfig = NewDefaultWebsocketConfig() - s.streamApi = streammock.NewAPI(s.T()) - s.streamConfig = backend.Config{} + s.config = Config{} + + s.connection = connectionmock.NewWebsocketConnection(s.T()) + s.dataProviderFactory = dpmock.NewDataProviderFactory(s.T()) } -func TestWsControllerSuite(t *testing.T) { - suite.Run(t, new(WsControllerSuite)) +// TestConfigureKeepaliveConnection ensures that the WebSocket connection is configured correctly. +func (s *ControllerSuite) TestConfigureKeepaliveConnection() { + controller := s.initializeController() + + // Mock configureConnection to succeed + s.mockConnectionSetup() + + // Call configureKeepalive and check for errors + err := controller.configureKeepalive() + s.Require().NoError(err, "configureKeepalive should not return an error") + + // Assert expectations + s.connection.AssertExpectations(s.T()) } -// TestSubscribeRequest tests the subscribe to topic flow. -// We emulate a request message from a client, and a response message from a controller. -func (s *WsControllerSuite) TestSubscribeRequest() { - s.T().Run("Happy path", func(t *testing.T) { - conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) +// TestControllerShutdown ensures that HandleConnection shuts down gracefully when an error occurs. +func (s *ControllerSuite) TestControllerShutdown() { + s.T().Run("keepalive routine failed", func(*testing.T) { + controller := s.initializeController() - dataProvider. - On("Run"). - Run(func(args mock.Arguments) {}). - Return(nil). + // Mock configureConnection to succeed + s.mockConnectionSetup() + + // Mock keepalive to return an error + done := make(chan struct{}, 1) + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(func(int, time.Time) error { + close(done) + return websocket.ErrCloseSent + }).Once() + + s.connection. + On("ReadJSON", mock.Anything). + Return(func(interface{}) error { + <-done + return websocket.ErrCloseSent + }). Once() - subscribeRequest := models.SubscribeMessageRequest{ + s.connection.On("Close").Return(nil).Once() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + controller.HandleConnection(ctx) + + // Ensure all expectations are met + s.connection.AssertExpectations(s.T()) + }) + + s.T().Run("read routine failed", func(*testing.T) { + controller := s.initializeController() + // Mock configureConnection to succeed + s.mockConnectionSetup() + + s.connection. + On("ReadJSON", mock.Anything). + Return(func(_ interface{}) error { + return assert.AnError + }). + Once() + + s.connection.On("Close").Return(nil).Once() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + controller.HandleConnection(ctx) + + // Ensure all expectations are met + s.connection.AssertExpectations(s.T()) + }) + + s.T().Run("write routine failed", func(*testing.T) { + controller := s.initializeController() + + // Mock configureConnection to succeed + s.mockConnectionSetup() + blocksDataProvider := s.mockBlockDataProviderSetup(uuid.New()) + + done := make(chan struct{}, 1) + requestMessage := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: "blocks", + Topic: dp.BlocksTopic, Arguments: nil, } + msg, err := json.Marshal(requestMessage) + s.Require().NoError(err) - // Simulate receiving the subscription request from the client - conn. + // Mocks `ReadJSON(v interface{}) error` which accepts an uninitialize interface that + // receives the contents of the read message. This logic mocks that behavior, setting + // the target with the value `msg` + s.connection. On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { - requestMsg, ok := args.Get(0).(*json.RawMessage) - require.True(t, ok) - subscribeRequestMessage, err := json.Marshal(subscribeRequest) - require.NoError(t, err) - *requestMsg = subscribeRequestMessage + reqMsg, ok := args.Get(0).(*json.RawMessage) + s.Require().True(ok) + *reqMsg = msg }). Return(nil). Once() - // Channel to signal the test flow completion - done := make(chan struct{}, 1) - - // Simulate writing a successful subscription response back to the client - conn. - On("WriteJSON", mock.Anything). - Return(func(msg interface{}) error { - response, ok := msg.(models.SubscribeMessageResponse) - require.True(t, ok) - require.True(t, response.Success) - close(done) // Signal that response has been sent - return websocket.ErrCloseSent - }) - - // Simulate client closing connection after receiving the response - conn. + s.connection. On("ReadJSON", mock.Anything). Return(func(interface{}) error { <-done return websocket.ErrCloseSent }) - controller.HandleConnection(context.Background()) + s.connection.On("SetWriteDeadline", mock.Anything).Return(nil).Once() + s.connection. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + close(done) + return assert.AnError + }) + s.connection.On("Close").Return(nil).Once() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + controller.HandleConnection(ctx) + + // Ensure all expectations are met + s.connection.AssertExpectations(s.T()) + s.dataProviderFactory.AssertExpectations(s.T()) + blocksDataProvider.AssertExpectations(s.T()) }) -} -// TestSubscribeBlocks tests the functionality for streaming blocks to a subscriber. -func (s *WsControllerSuite) TestSubscribeBlocks() { - s.T().Run("Stream one block", func(t *testing.T) { - conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) + s.T().Run("context closed", func(*testing.T) { + controller := s.initializeController() - // Simulate data provider write a block to the controller - expectedBlock := unittest.BlockFixture() - dataProvider. - On("Run", mock.Anything). - Run(func(args mock.Arguments) { - controller.communicationChannel <- expectedBlock - }). - Return(nil). - Once() + // Mock configureConnection to succeed + s.mockConnectionSetup() - done := make(chan struct{}, 1) - s.expectSubscriptionRequest(conn, done) - s.expectSubscriptionResponse(conn, true) + s.connection.On("Close").Return(nil).Once() - // Expect a valid block to be passed to WriteJSON. - // If we got to this point, the controller executed all its logic properly - var actualBlock flow.Block - conn. - On("WriteJSON", mock.Anything). - Return(func(msg interface{}) error { - block, ok := msg.(flow.Block) - require.True(t, ok) - actualBlock = block + ctx, cancel := context.WithCancel(context.Background()) + cancel() - close(done) - return websocket.ErrCloseSent - }) + controller.HandleConnection(ctx) - controller.HandleConnection(context.Background()) - require.Equal(t, expectedBlock, actualBlock) + // Ensure all expectations are met + s.connection.AssertExpectations(s.T()) }) +} - s.T().Run("Stream many blocks", func(t *testing.T) { - conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) +// TestKeepaliveHappyCase tests the behavior of the keepalive function. +func (s *ControllerSuite) TestKeepaliveHappyCase() { + // Create a context for the test + ctx := context.Background() - // Simulate data provider writes some blocks to the controller - expectedBlocks := unittest.BlockFixtures(100) - dataProvider. - On("Run", mock.Anything). - Run(func(args mock.Arguments) { - for _, block := range expectedBlocks { - controller.communicationChannel <- *block - } - }). - Return(nil). - Once() + controller := s.initializeController() + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(nil) - done := make(chan struct{}, 1) - s.expectSubscriptionRequest(conn, done) - s.expectSubscriptionResponse(conn, true) + // Start the keepalive process in a separate goroutine + go func() { + err := controller.keepalive(ctx) + s.Require().NoError(err) + }() - i := 0 - actualBlocks := make([]*flow.Block, len(expectedBlocks)) + // Use Eventually to wait for some ping messages + expectedCalls := 3 // expected 3 ping messages for 30 seconds + s.Require().Eventually(func() bool { + return len(s.connection.Calls) == expectedCalls + }, time.Duration(expectedCalls)*PongWait, 1*time.Second, "not all ping messages were sent") - // Expect valid blocks to be passed to WriteJSON. - // If we got to this point, the controller executed all its logic properly - conn. - On("WriteJSON", mock.Anything). - Return(func(msg interface{}) error { - block, ok := msg.(flow.Block) - require.True(t, ok) + s.connection.On("Close").Return(nil).Once() + controller.shutdownConnection() - actualBlocks[i] = &block - i += 1 + // Assert that the ping was sent + s.connection.AssertExpectations(s.T()) +} - if i == len(expectedBlocks) { - close(done) - return websocket.ErrCloseSent - } +// TestKeepaliveError tests the behavior of the keepalive function when there is an error in writing the ping. +func (s *ControllerSuite) TestKeepaliveError() { + controller := s.initializeController() - return nil - }). - Times(len(expectedBlocks)) + // Setup the mock connection with an error + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(assert.AnError).Once() - controller.HandleConnection(context.Background()) - require.Equal(t, expectedBlocks, actualBlocks) - }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + expectedError := fmt.Errorf("failed to write ping message: %w", assert.AnError) + // Start the keepalive process + err := controller.keepalive(ctx) + s.Require().Error(err) + s.Require().Equal(expectedError, err) + + // Assert expectations + s.connection.AssertExpectations(s.T()) } -// newControllerMocks initializes mock WebSocket connection, data provider, and data provider factory. -// The mocked functions are expected to be called in a case when a test is expected to reach WriteJSON function. -func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.DataProviderFactory, *dpmock.DataProvider) { - conn := connmock.NewWebsocketConnection(t) - conn.On("Close").Return(nil).Once() - - id := uuid.New() - topic := "blocks" - dataProvider := dpmock.NewDataProvider(t) - dataProvider.On("ID").Return(id) - dataProvider.On("Close").Return(nil) - dataProvider.On("Topic").Return(topic) - - factory := dpmock.NewDataProviderFactory(t) - factory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(dataProvider, nil). - Once() - - return conn, factory, dataProvider +// TestKeepaliveContextCancel tests the behavior of keepalive when the context is canceled before a ping is sent and +// no ping message is sent after the context is canceled. +func (s *ControllerSuite) TestKeepaliveContextCancel() { + controller := s.initializeController() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Immediately cancel the context + + // Start the keepalive process with the context canceled + err := controller.keepalive(ctx) + s.Require().NoError(err) + + // Assert expectations + s.connection.AssertExpectations(s.T()) // Should not invoke WriteMessage after context cancellation } -// expectSubscriptionRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketConnection, done <-chan struct{}) { - requestMessage := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: "blocks", - } - - // The very first message from a client is a request to subscribe to some topic - conn.On("ReadJSON", mock.Anything). - Run(func(args mock.Arguments) { - reqMsg, ok := args.Get(0).(*json.RawMessage) - require.True(s.T(), ok) - msg, err := json.Marshal(requestMessage) - require.NoError(s.T(), err) - *reqMsg = msg - }). - Return(nil). - Once() - - // In the default case, no further communication is expected from the client. - // We wait for the writer routine to signal completion, allowing us to close the connection gracefully - conn. - On("ReadJSON", mock.Anything). - Return(func(msg interface{}) error { - <-done - return websocket.ErrCloseSent - }) +// initializeController initializes the WebSocket controller. +func (s *ControllerSuite) initializeController() *Controller { + return NewWebSocketController(s.logger, s.config, s.connection, s.dataProviderFactory) +} + +// mockDataProviderSetup is a helper which mocks a blocks data provider setup. +func (s *ControllerSuite) mockBlockDataProviderSetup(id uuid.UUID) *dpmock.DataProvider { + dataProvider := dpmock.NewDataProvider(s.T()) + dataProvider.On("ID").Return(id).Once() + dataProvider.On("Close").Return(nil).Once() + s.dataProviderFactory.On("NewDataProvider", mock.Anything, dp.BlocksTopic, mock.Anything, mock.Anything). + Return(dataProvider, nil).Once() + dataProvider.On("Run").Return(nil).Once() + + return dataProvider } -// expectSubscriptionResponse mocks the subscription response sent to the client. -func (s *WsControllerSuite) expectSubscriptionResponse(conn *connmock.WebsocketConnection, success bool) { - conn.On("WriteJSON", mock.Anything). - Run(func(args mock.Arguments) { - response, ok := args.Get(0).(models.SubscribeMessageResponse) - require.True(s.T(), ok) - require.Equal(s.T(), success, response.Success) - }). - Return(nil). - Once() +// mockConnectionSetup is a helper which mocks connection setup for SetReadDeadline and SetPongHandler. +func (s *ControllerSuite) mockConnectionSetup() { + s.connection.On("SetReadDeadline", mock.Anything).Return(nil).Once() + s.connection.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() } diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index a428a435f5a..8dbe13078ad 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -61,7 +61,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - newConn := NewGorillaWebsocketConnection(conn) - controller := NewWebSocketController(logger, h.websocketConfig, h.dataProviderFactory, newConn) - controller.HandleConnection(context.Background()) + controller := NewWebSocketController(logger, h.websocketConfig, NewWebsocketConnection(conn), h.dataProviderFactory) + controller.HandleConnection(context.TODO()) } diff --git a/engine/access/rest/websockets/legacy/websocket_handler.go b/engine/access/rest/websockets/legacy/websocket_handler.go index 7132314b16c..06aa8323de4 100644 --- a/engine/access/rest/websockets/legacy/websocket_handler.go +++ b/engine/access/rest/websockets/legacy/websocket_handler.go @@ -12,6 +12,7 @@ import ( "go.uber.org/atomic" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/engine/access/subscription" @@ -19,17 +20,6 @@ import ( "github.com/onflow/flow-go/model/flow" ) -const ( - // Time allowed to read the next pong message from the peer. - pongWait = 10 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 - - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second -) - // WebsocketController holds the necessary components and parameters for handling a WebSocket subscription. // It manages the communication between the server and the WebSocket client for subscribing. type WebsocketController struct { @@ -47,17 +37,17 @@ type WebsocketController struct { // manage incoming Pong messages. These methods allow to specify a time limit for reading from or writing to a WebSocket // connection. If the operation (reading or writing) takes longer than the specified deadline, the connection will be closed. func (wsController *WebsocketController) SetWebsocketConf() error { - err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message + err := wsController.conn.SetWriteDeadline(time.Now().Add(websockets.WriteWait)) // Set the initial write deadline for the first ping message if err != nil { return common.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err) } - err = wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message + err = wsController.conn.SetReadDeadline(time.Now().Add(websockets.PongWait)) // Set the initial read deadline for the first pong message if err != nil { return common.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err) } // Establish a Pong handler wsController.conn.SetPongHandler(func(string) error { - err := wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) + err := wsController.conn.SetReadDeadline(time.Now().Add(websockets.PongWait)) if err != nil { return err } @@ -111,7 +101,7 @@ func (wsController *WebsocketController) wsErrorHandler(err error) { // If an error occurs or the subscription channel is closed, it handles the error or termination accordingly. // The function uses a ticker to periodically send ping messages to the client to maintain the connection. func (wsController *WebsocketController) writeEvents(sub subscription.Subscription) { - ticker := time.NewTicker(pingPeriod) + ticker := time.NewTicker(websockets.PingPeriod) defer ticker.Stop() blocksSinceLastMessage := uint64(0) @@ -137,7 +127,7 @@ func (wsController *WebsocketController) writeEvents(sub subscription.Subscripti wsController.wsErrorHandler(common.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) return } - err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) + err := wsController.conn.SetWriteDeadline(time.Now().Add(websockets.WriteWait)) if err != nil { wsController.wsErrorHandler(common.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) return @@ -178,7 +168,7 @@ func (wsController *WebsocketController) writeEvents(sub subscription.Subscripti return } case <-ticker.C: - err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) + err := wsController.conn.SetWriteDeadline(time.Now().Add(websockets.WriteWait)) if err != nil { wsController.wsErrorHandler(common.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) return diff --git a/engine/access/rest/websockets/mock/websocket_connection.go b/engine/access/rest/websockets/mock/websocket_connection.go index e81b2bcec3f..02a60fd0a3c 100644 --- a/engine/access/rest/websockets/mock/websocket_connection.go +++ b/engine/access/rest/websockets/mock/websocket_connection.go @@ -2,7 +2,11 @@ package mock -import mock "github.com/stretchr/testify/mock" +import ( + time "time" + + mock "github.com/stretchr/testify/mock" +) // WebsocketConnection is an autogenerated mock type for the WebsocketConnection type type WebsocketConnection struct { @@ -45,6 +49,65 @@ func (_m *WebsocketConnection) ReadJSON(v interface{}) error { return r0 } +// SetPongHandler provides a mock function with given fields: h +func (_m *WebsocketConnection) SetPongHandler(h func(string) error) { + _m.Called(h) +} + +// SetReadDeadline provides a mock function with given fields: deadline +func (_m *WebsocketConnection) SetReadDeadline(deadline time.Time) error { + ret := _m.Called(deadline) + + if len(ret) == 0 { + panic("no return value specified for SetReadDeadline") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(deadline) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetWriteDeadline provides a mock function with given fields: deadline +func (_m *WebsocketConnection) SetWriteDeadline(deadline time.Time) error { + ret := _m.Called(deadline) + + if len(ret) == 0 { + panic("no return value specified for SetWriteDeadline") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(deadline) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// WriteControl provides a mock function with given fields: messageType, deadline +func (_m *WebsocketConnection) WriteControl(messageType int, deadline time.Time) error { + ret := _m.Called(messageType, deadline) + + if len(ret) == 0 { + panic("no return value specified for WriteControl") + } + + var r0 error + if rf, ok := ret.Get(0).(func(int, time.Time) error); ok { + r0 = rf(messageType, deadline) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // WriteJSON provides a mock function with given fields: v func (_m *WebsocketConnection) WriteJSON(v interface{}) error { ret := _m.Called(v) diff --git a/storage/mock/iter_item.go b/storage/mock/iter_item.go new file mode 100644 index 00000000000..5d699511fb8 --- /dev/null +++ b/storage/mock/iter_item.go @@ -0,0 +1,82 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import mock "github.com/stretchr/testify/mock" + +// IterItem is an autogenerated mock type for the IterItem type +type IterItem struct { + mock.Mock +} + +// Key provides a mock function with given fields: +func (_m *IterItem) Key() []byte { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Key") + } + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// KeyCopy provides a mock function with given fields: dst +func (_m *IterItem) KeyCopy(dst []byte) []byte { + ret := _m.Called(dst) + + if len(ret) == 0 { + panic("no return value specified for KeyCopy") + } + + var r0 []byte + if rf, ok := ret.Get(0).(func([]byte) []byte); ok { + r0 = rf(dst) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// Value provides a mock function with given fields: _a0 +func (_m *IterItem) Value(_a0 func([]byte) error) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Value") + } + + var r0 error + if rf, ok := ret.Get(0).(func(func([]byte) error) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewIterItem creates a new instance of IterItem. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewIterItem(t interface { + mock.TestingT + Cleanup(func()) +}) *IterItem { + mock := &IterItem{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/storage/mock/iterator.go b/storage/mock/iterator.go new file mode 100644 index 00000000000..1b094ac15e1 --- /dev/null +++ b/storage/mock/iterator.go @@ -0,0 +1,106 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + storage "github.com/onflow/flow-go/storage" + mock "github.com/stretchr/testify/mock" +) + +// Iterator is an autogenerated mock type for the Iterator type +type Iterator struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *Iterator) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// First provides a mock function with given fields: +func (_m *Iterator) First() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for First") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IterItem provides a mock function with given fields: +func (_m *Iterator) IterItem() storage.IterItem { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IterItem") + } + + var r0 storage.IterItem + if rf, ok := ret.Get(0).(func() storage.IterItem); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(storage.IterItem) + } + } + + return r0 +} + +// Next provides a mock function with given fields: +func (_m *Iterator) Next() { + _m.Called() +} + +// Valid provides a mock function with given fields: +func (_m *Iterator) Valid() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Valid") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// NewIterator creates a new instance of Iterator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewIterator(t interface { + mock.TestingT + Cleanup(func()) +}) *Iterator { + mock := &Iterator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/storage/mock/reader.go b/storage/mock/reader.go new file mode 100644 index 00000000000..f9b15b532b5 --- /dev/null +++ b/storage/mock/reader.go @@ -0,0 +1,98 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + io "io" + + storage "github.com/onflow/flow-go/storage" + mock "github.com/stretchr/testify/mock" +) + +// Reader is an autogenerated mock type for the Reader type +type Reader struct { + mock.Mock +} + +// Get provides a mock function with given fields: key +func (_m *Reader) Get(key []byte) ([]byte, io.Closer, error) { + ret := _m.Called(key) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 []byte + var r1 io.Closer + var r2 error + if rf, ok := ret.Get(0).(func([]byte) ([]byte, io.Closer, error)); ok { + return rf(key) + } + if rf, ok := ret.Get(0).(func([]byte) []byte); ok { + r0 = rf(key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func([]byte) io.Closer); ok { + r1 = rf(key) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(io.Closer) + } + } + + if rf, ok := ret.Get(2).(func([]byte) error); ok { + r2 = rf(key) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// NewIter provides a mock function with given fields: startPrefix, endPrefix, ops +func (_m *Reader) NewIter(startPrefix []byte, endPrefix []byte, ops storage.IteratorOption) (storage.Iterator, error) { + ret := _m.Called(startPrefix, endPrefix, ops) + + if len(ret) == 0 { + panic("no return value specified for NewIter") + } + + var r0 storage.Iterator + var r1 error + if rf, ok := ret.Get(0).(func([]byte, []byte, storage.IteratorOption) (storage.Iterator, error)); ok { + return rf(startPrefix, endPrefix, ops) + } + if rf, ok := ret.Get(0).(func([]byte, []byte, storage.IteratorOption) storage.Iterator); ok { + r0 = rf(startPrefix, endPrefix, ops) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(storage.Iterator) + } + } + + if rf, ok := ret.Get(1).(func([]byte, []byte, storage.IteratorOption) error); ok { + r1 = rf(startPrefix, endPrefix, ops) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewReader creates a new instance of Reader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewReader(t interface { + mock.TestingT + Cleanup(func()) +}) *Reader { + mock := &Reader{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/storage/mock/reader_batch_writer.go b/storage/mock/reader_batch_writer.go new file mode 100644 index 00000000000..c64c340704e --- /dev/null +++ b/storage/mock/reader_batch_writer.go @@ -0,0 +1,72 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + storage "github.com/onflow/flow-go/storage" + mock "github.com/stretchr/testify/mock" +) + +// ReaderBatchWriter is an autogenerated mock type for the ReaderBatchWriter type +type ReaderBatchWriter struct { + mock.Mock +} + +// AddCallback provides a mock function with given fields: _a0 +func (_m *ReaderBatchWriter) AddCallback(_a0 func(error)) { + _m.Called(_a0) +} + +// GlobalReader provides a mock function with given fields: +func (_m *ReaderBatchWriter) GlobalReader() storage.Reader { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GlobalReader") + } + + var r0 storage.Reader + if rf, ok := ret.Get(0).(func() storage.Reader); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(storage.Reader) + } + } + + return r0 +} + +// Writer provides a mock function with given fields: +func (_m *ReaderBatchWriter) Writer() storage.Writer { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Writer") + } + + var r0 storage.Writer + if rf, ok := ret.Get(0).(func() storage.Writer); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(storage.Writer) + } + } + + return r0 +} + +// NewReaderBatchWriter creates a new instance of ReaderBatchWriter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewReaderBatchWriter(t interface { + mock.TestingT + Cleanup(func()) +}) *ReaderBatchWriter { + mock := &ReaderBatchWriter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/storage/mock/writer.go b/storage/mock/writer.go new file mode 100644 index 00000000000..f80b206d39c --- /dev/null +++ b/storage/mock/writer.go @@ -0,0 +1,81 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + storage "github.com/onflow/flow-go/storage" + mock "github.com/stretchr/testify/mock" +) + +// Writer is an autogenerated mock type for the Writer type +type Writer struct { + mock.Mock +} + +// Delete provides a mock function with given fields: key +func (_m *Writer) Delete(key []byte) error { + ret := _m.Called(key) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]byte) error); ok { + r0 = rf(key) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteByRange provides a mock function with given fields: globalReader, startPrefix, endPrefix +func (_m *Writer) DeleteByRange(globalReader storage.Reader, startPrefix []byte, endPrefix []byte) error { + ret := _m.Called(globalReader, startPrefix, endPrefix) + + if len(ret) == 0 { + panic("no return value specified for DeleteByRange") + } + + var r0 error + if rf, ok := ret.Get(0).(func(storage.Reader, []byte, []byte) error); ok { + r0 = rf(globalReader, startPrefix, endPrefix) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Set provides a mock function with given fields: k, v +func (_m *Writer) Set(k []byte, v []byte) error { + ret := _m.Called(k, v) + + if len(ret) == 0 { + panic("no return value specified for Set") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok { + r0 = rf(k, v) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewWriter creates a new instance of Writer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewWriter(t interface { + mock.TestingT + Cleanup(func()) +}) *Writer { + mock := &Writer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}