From 59eaa752dc7255af97163e6034f26a80681be908 Mon Sep 17 00:00:00 2001 From: Thiago Coimbra Lemos Date: Thu, 17 Aug 2023 17:26:10 -0300 Subject: [PATCH] cherry-pick #2385 and #2396 from develop into v0.2.6 (#2412) * fix http request instance null for websocket requests (#2385) * fix ws subscribe to get filtered log notifications (#2396) --- jsonrpc/endpoints_eth.go | 10 +- jsonrpc/handler.go | 8 +- jsonrpc/server.go | 4 +- state/pgstatestorage.go | 83 +++++++------ test/e2e/sc_test.go | 251 ++++++++++++++++++++++++++++++++++----- 5 files changed, 284 insertions(+), 72 deletions(-) diff --git a/jsonrpc/endpoints_eth.go b/jsonrpc/endpoints_eth.go index 523202f364..162bb62d7b 100644 --- a/jsonrpc/endpoints_eth.go +++ b/jsonrpc/endpoints_eth.go @@ -862,7 +862,9 @@ func (e *EthEndpoints) UninstallFilter(filterID string) (interface{}, types.Erro func (e *EthEndpoints) Syncing() (interface{}, types.Error) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { _, err := e.state.GetLastL2BlockNumber(ctx, dbTx) - if err != nil { + if errors.Is(err, state.ErrStateNotSynchronized) { + return nil, types.NewRPCErrorWithData(types.DefaultErrorCode, state.ErrStateNotSynchronized.Error(), nil) + } else if err != nil { return RPCErrorResponse(types.DefaultErrorCode, "failed to get last block number from state", err) } @@ -997,7 +999,10 @@ func (e *EthEndpoints) onNewL2Block(event state.NewL2BlockEvent) { } if changes != nil { - e.sendSubscriptionResponse(filter, changes) + ethLogs := changes.([]types.Log) + for _, ethLog := range ethLogs { + e.sendSubscriptionResponse(filter, ethLog) + } } } } @@ -1027,4 +1032,5 @@ func (e *EthEndpoints) sendSubscriptionResponse(filter *Filter, data interface{} if err != nil { log.Errorf(fmt.Sprintf(errMessage, filter.ID, err.Error())) } + log.Debugf("WS message sent: %v", string(message)) } diff --git a/jsonrpc/handler.go b/jsonrpc/handler.go index cd78075f37..6a1f301940 100644 --- a/jsonrpc/handler.go +++ b/jsonrpc/handler.go @@ -156,15 +156,17 @@ func (h *Handler) Handle(req handleRequest) types.Response { } // HandleWs handle websocket requests -func (h *Handler) HandleWs(reqBody []byte, wsConn *websocket.Conn) ([]byte, error) { +func (h *Handler) HandleWs(reqBody []byte, wsConn *websocket.Conn, httpReq *http.Request) ([]byte, error) { + log.Debugf("WS message received: %v", string(reqBody)) var req types.Request if err := json.Unmarshal(reqBody, &req); err != nil { return types.NewResponse(req, nil, types.NewRPCError(types.InvalidRequestErrorCode, "Invalid json request")).Bytes() } handleReq := handleRequest{ - Request: req, - wsConn: wsConn, + Request: req, + wsConn: wsConn, + HttpRequest: httpReq, } return h.Handle(handleReq).Bytes() diff --git a/jsonrpc/server.go b/jsonrpc/server.go index 8d28ded549..1b80eb13ad 100644 --- a/jsonrpc/server.go +++ b/jsonrpc/server.go @@ -366,7 +366,7 @@ func (s *Server) handleWs(w http.ResponseWriter, req *http.Request) { go func() { mu.Lock() defer mu.Unlock() - resp, err := s.handler.HandleWs(message, wsConn) + resp, err := s.handler.HandleWs(message, wsConn, req) if err != nil { log.Error(fmt.Sprintf("Unable to handle WS request, %s", err.Error())) _ = wsConn.WriteMessage(msgType, []byte(fmt.Sprintf("WS Handle error: %s", err.Error()))) @@ -394,7 +394,7 @@ func RPCErrorResponse(code int, message string, err error) (interface{}, types.E // RPCErrorResponseWithData formats error to be returned through RPC func RPCErrorResponseWithData(code int, message string, data *[]byte, err error) (interface{}, types.Error) { if err != nil { - log.Errorf("%v:%v", message, err.Error()) + log.Errorf("%v: %v", message, err.Error()) } else { log.Error(message) } diff --git a/state/pgstatestorage.go b/state/pgstatestorage.go index 7edc6254d6..2a9a1e7e13 100644 --- a/state/pgstatestorage.go +++ b/state/pgstatestorage.go @@ -1847,51 +1847,60 @@ func (p *PostgresStorage) IsL2BlockVirtualized(ctx context.Context, blockNumber // GetLogs returns the logs that match the filter func (p *PostgresStorage) GetLogs(ctx context.Context, fromBlock uint64, toBlock uint64, addresses []common.Address, topics [][]common.Hash, blockHash *common.Hash, since *time.Time, dbTx pgx.Tx) ([]*types.Log, error) { const getLogsByBlockHashSQL = ` - SELECT t.l2_block_num, b.block_hash, l.tx_hash, l.log_index, l.address, l.data, l.topic0, l.topic1, l.topic2, l.topic3 - FROM state.log l - INNER JOIN state.transaction t ON t.hash = l.tx_hash - INNER JOIN state.l2block b ON b.block_num = t.l2_block_num - WHERE b.block_hash = $1 - ORDER BY b.block_num ASC, l.log_index ASC` - const getLogsByFilterSQL = ` - SELECT t.l2_block_num, b.block_hash, l.tx_hash, l.log_index, l.address, l.data, l.topic0, l.topic1, l.topic2, l.topic3 - FROM state.log l - INNER JOIN state.transaction t ON t.hash = l.tx_hash - INNER JOIN state.l2block b ON b.block_num = t.l2_block_num - WHERE b.block_num BETWEEN $1 AND $2 AND (l.address = any($3) OR $3 IS NULL) - AND (l.topic0 = any($4) OR $4 IS NULL) - AND (l.topic1 = any($5) OR $5 IS NULL) - AND (l.topic2 = any($6) OR $6 IS NULL) - AND (l.topic3 = any($7) OR $7 IS NULL) - AND (b.created_at >= $8 OR $8 IS NULL) - ORDER BY b.block_num ASC, l.log_index ASC` - - var err error - var rows pgx.Rows - q := p.getExecQuerier(dbTx) + SELECT t.l2_block_num, b.block_hash, l.tx_hash, l.log_index, l.address, l.data, l.topic0, l.topic1, l.topic2, l.topic3 + FROM state.log l + INNER JOIN state.transaction t ON t.hash = l.tx_hash + INNER JOIN state.l2block b ON b.block_num = t.l2_block_num + WHERE b.block_hash = $1 + AND (l.address = any($2) OR $2 IS NULL) + AND (l.topic0 = any($3) OR $3 IS NULL) + AND (l.topic1 = any($4) OR $4 IS NULL) + AND (l.topic2 = any($5) OR $5 IS NULL) + AND (l.topic3 = any($6) OR $6 IS NULL) + AND (b.created_at >= $7 OR $7 IS NULL) + ORDER BY b.block_num ASC, l.log_index ASC` + const getLogsByBlockNumbersSQL = ` + SELECT t.l2_block_num, b.block_hash, l.tx_hash, l.log_index, l.address, l.data, l.topic0, l.topic1, l.topic2, l.topic3 + FROM state.log l + INNER JOIN state.transaction t ON t.hash = l.tx_hash + INNER JOIN state.l2block b ON b.block_num = t.l2_block_num + WHERE b.block_num BETWEEN $1 AND $2 + AND (l.address = any($3) OR $3 IS NULL) + AND (l.topic0 = any($4) OR $4 IS NULL) + AND (l.topic1 = any($5) OR $5 IS NULL) + AND (l.topic2 = any($6) OR $6 IS NULL) + AND (l.topic3 = any($7) OR $7 IS NULL) + AND (b.created_at >= $8 OR $8 IS NULL) + ORDER BY b.block_num ASC, l.log_index ASC` + + var args []interface{} + var query string if blockHash != nil { - rows, err = q.Query(ctx, getLogsByBlockHashSQL, blockHash.String()) + args = []interface{}{blockHash.String()} + query = getLogsByBlockHashSQL + } else { + args = []interface{}{fromBlock, toBlock} + query = getLogsByBlockNumbersSQL + } + + if len(addresses) > 0 { + args = append(args, p.addressesToHex(addresses)) } else { - args := []interface{}{fromBlock, toBlock} + args = append(args, nil) + } - if len(addresses) > 0 { - args = append(args, p.addressesToHex(addresses)) + for i := 0; i < maxTopics; i++ { + if len(topics) > i && len(topics[i]) > 0 { + args = append(args, p.hashesToHex(topics[i])) } else { args = append(args, nil) } + } - for i := 0; i < maxTopics; i++ { - if len(topics) > i && len(topics[i]) > 0 { - args = append(args, p.hashesToHex(topics[i])) - } else { - args = append(args, nil) - } - } - - args = append(args, since) + args = append(args, since) - rows, err = q.Query(ctx, getLogsByFilterSQL, args...) - } + q := p.getExecQuerier(dbTx) + rows, err := q.Query(ctx, query, args...) if err != nil { return nil, err diff --git a/test/e2e/sc_test.go b/test/e2e/sc_test.go index c2f7b4d5c5..82950e79d8 100644 --- a/test/e2e/sc_test.go +++ b/test/e2e/sc_test.go @@ -15,6 +15,7 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethclient" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -85,11 +86,214 @@ func TestEmitLog2(t *testing.T) { err = opsMan.Setup() require.NoError(t, err) + type testCase struct { + name string + logsFromSubscription chan types.Log + subscribe func(*testing.T, *ethclient.Client, *testCase, common.Address) ethereum.Subscription + getLogs func(*testing.T, *ethclient.Client, *testCase, common.Address, *types.Receipt, ethereum.Subscription) []types.Log + validate func(*testing.T, context.Context, []types.Log, *EmitLog2.EmitLog2) + } + + testCases := []testCase{ + { + name: "validate logs by block number", + getLogs: func(t *testing.T, client *ethclient.Client, tc *testCase, scAddr common.Address, scCallTxReceipt *types.Receipt, sub ethereum.Subscription) []types.Log { + filterBlock := scCallTxReceipt.BlockNumber + logs, err := client.FilterLogs(ctx, ethereum.FilterQuery{ + FromBlock: filterBlock, ToBlock: filterBlock, + Addresses: []common.Address{scAddr}, + }) + require.NoError(t, err) + return logs + }, + validate: func(t *testing.T, ctx context.Context, logs []types.Log, sc *EmitLog2.EmitLog2) { + assert.Equal(t, 4, len(logs)) + + log0 := getLogByIndex(0, logs) + assert.Equal(t, 0, len(log0.Topics)) + + _, err = sc.ParseLog(getLogByIndex(1, logs)) + require.NoError(t, err) + + logA, err := sc.ParseLogA(getLogByIndex(2, logs)) + require.NoError(t, err) + expectedA := big.NewInt(1) + assert.Equal(t, 0, logA.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logA.A.String()) + + logABCD, err := sc.ParseLogABCD(getLogByIndex(3, logs)) + require.NoError(t, err) + expectedA = big.NewInt(1) + expectedB := big.NewInt(2) + expectedC := big.NewInt(3) + expectedD := big.NewInt(4) + assert.Equal(t, 0, logABCD.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logABCD.A.String()) + assert.Equal(t, 0, logABCD.B.Cmp(expectedB), "B expected to be: %v found: %v", expectedA.String(), logABCD.B.String()) + assert.Equal(t, 0, logABCD.C.Cmp(expectedC), "C expected to be: %v found: %v", expectedA.String(), logABCD.C.String()) + assert.Equal(t, 0, logABCD.D.Cmp(expectedD), "D expected to be: %v found: %v", expectedA.String(), logABCD.D.String()) + }, + }, + { + name: "validate logs by block number and topics", + getLogs: func(t *testing.T, client *ethclient.Client, tc *testCase, scAddr common.Address, scCallTxReceipt *types.Receipt, sub ethereum.Subscription) []types.Log { + filterBlock := scCallTxReceipt.BlockNumber + logs, err := client.FilterLogs(ctx, ethereum.FilterQuery{ + FromBlock: filterBlock, ToBlock: filterBlock, + Addresses: []common.Address{scAddr}, + Topics: [][]common.Hash{ + { + common.HexToHash("0xe5562b12d9276c5c987df08afff7b1946f2d869236866ea2285c7e2e95685a64"), + common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001"), + common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000002"), + common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000003"), + }, + }, + }) + require.NoError(t, err) + return logs + }, + validate: func(t *testing.T, ctx context.Context, logs []types.Log, sc *EmitLog2.EmitLog2) { + assert.Equal(t, 1, len(logs)) + + logABCD, err := sc.ParseLogABCD(getLogByIndex(3, logs)) + require.NoError(t, err) + expectedA := big.NewInt(1) + expectedB := big.NewInt(2) + expectedC := big.NewInt(3) + expectedD := big.NewInt(4) + assert.Equal(t, 0, logABCD.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logABCD.A.String()) + assert.Equal(t, 0, logABCD.B.Cmp(expectedB), "B expected to be: %v found: %v", expectedA.String(), logABCD.B.String()) + assert.Equal(t, 0, logABCD.C.Cmp(expectedC), "C expected to be: %v found: %v", expectedA.String(), logABCD.C.String()) + assert.Equal(t, 0, logABCD.D.Cmp(expectedD), "D expected to be: %v found: %v", expectedA.String(), logABCD.D.String()) + }, + }, + { + name: "validate logs by block hash", + getLogs: func(t *testing.T, client *ethclient.Client, tc *testCase, scAddr common.Address, scCallTxReceipt *types.Receipt, sub ethereum.Subscription) []types.Log { + filterBlock := scCallTxReceipt.BlockHash + logs, err := client.FilterLogs(ctx, ethereum.FilterQuery{ + BlockHash: &filterBlock, + Addresses: []common.Address{scAddr}, + }) + require.NoError(t, err) + return logs + }, + validate: func(t *testing.T, ctx context.Context, logs []types.Log, sc *EmitLog2.EmitLog2) { + assert.Equal(t, 4, len(logs)) + + log0 := getLogByIndex(0, logs) + assert.Equal(t, 0, len(log0.Topics)) + + _, err = sc.ParseLog(getLogByIndex(1, logs)) + require.NoError(t, err) + + logA, err := sc.ParseLogA(getLogByIndex(2, logs)) + require.NoError(t, err) + expectedA := big.NewInt(1) + assert.Equal(t, 0, logA.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logA.A.String()) + + logABCD, err := sc.ParseLogABCD(getLogByIndex(3, logs)) + require.NoError(t, err) + expectedA = big.NewInt(1) + expectedB := big.NewInt(2) + expectedC := big.NewInt(3) + expectedD := big.NewInt(4) + assert.Equal(t, 0, logABCD.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logABCD.A.String()) + assert.Equal(t, 0, logABCD.B.Cmp(expectedB), "B expected to be: %v found: %v", expectedA.String(), logABCD.B.String()) + assert.Equal(t, 0, logABCD.C.Cmp(expectedC), "C expected to be: %v found: %v", expectedA.String(), logABCD.C.String()) + assert.Equal(t, 0, logABCD.D.Cmp(expectedD), "D expected to be: %v found: %v", expectedA.String(), logABCD.D.String()) + }, + }, + { + name: "validate logs by block hash and topics", + getLogs: func(t *testing.T, client *ethclient.Client, tc *testCase, scAddr common.Address, scCallTxReceipt *types.Receipt, sub ethereum.Subscription) []types.Log { + filterBlock := scCallTxReceipt.BlockHash + logs, err := client.FilterLogs(ctx, ethereum.FilterQuery{ + BlockHash: &filterBlock, + Addresses: []common.Address{scAddr}, + Topics: [][]common.Hash{ + { + common.HexToHash("0xe5562b12d9276c5c987df08afff7b1946f2d869236866ea2285c7e2e95685a64"), + common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001"), + common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000002"), + common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000003"), + }, + }, + }) + require.NoError(t, err) + return logs + }, + validate: func(t *testing.T, ctx context.Context, logs []types.Log, sc *EmitLog2.EmitLog2) { + assert.Equal(t, 1, len(logs)) + + logABCD, err := sc.ParseLogABCD(getLogByIndex(3, logs)) + require.NoError(t, err) + expectedA := big.NewInt(1) + expectedB := big.NewInt(2) + expectedC := big.NewInt(3) + expectedD := big.NewInt(4) + assert.Equal(t, 0, logABCD.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logABCD.A.String()) + assert.Equal(t, 0, logABCD.B.Cmp(expectedB), "B expected to be: %v found: %v", expectedA.String(), logABCD.B.String()) + assert.Equal(t, 0, logABCD.C.Cmp(expectedC), "C expected to be: %v found: %v", expectedA.String(), logABCD.C.String()) + assert.Equal(t, 0, logABCD.D.Cmp(expectedD), "D expected to be: %v found: %v", expectedA.String(), logABCD.D.String()) + }, + }, + { + name: "validate logs by subscription", + subscribe: func(t *testing.T, c *ethclient.Client, tc *testCase, scAddr common.Address) ethereum.Subscription { + query := ethereum.FilterQuery{Addresses: []common.Address{scAddr}} + sub, err := c.SubscribeFilterLogs(context.Background(), query, tc.logsFromSubscription) + require.NoError(t, err) + return sub + }, + getLogs: func(t *testing.T, c *ethclient.Client, tc *testCase, a common.Address, r *types.Receipt, sub ethereum.Subscription) []types.Log { + logs := []types.Log{} + for { + select { + case err := <-sub.Err(): + require.NoError(t, err) + case vLog, closed := <-tc.logsFromSubscription: + logs = append(logs, vLog) + if len(logs) == 4 && closed { + return logs + } + } + } + }, + validate: func(t *testing.T, ctx context.Context, logs []types.Log, sc *EmitLog2.EmitLog2) { + assert.Equal(t, 4, len(logs)) + + log0 := getLogByIndex(0, logs) + assert.Equal(t, 0, len(log0.Topics)) + + _, err = sc.ParseLog(getLogByIndex(1, logs)) + require.NoError(t, err) + + logA, err := sc.ParseLogA(getLogByIndex(2, logs)) + require.NoError(t, err) + expectedA := big.NewInt(1) + assert.Equal(t, 0, logA.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logA.A.String()) + + logABCD, err := sc.ParseLogABCD(getLogByIndex(3, logs)) + require.NoError(t, err) + expectedA = big.NewInt(1) + expectedB := big.NewInt(2) + expectedC := big.NewInt(3) + expectedD := big.NewInt(4) + assert.Equal(t, 0, logABCD.A.Cmp(expectedA), "A expected to be: %v found: %v", expectedA.String(), logABCD.A.String()) + assert.Equal(t, 0, logABCD.B.Cmp(expectedB), "B expected to be: %v found: %v", expectedA.String(), logABCD.B.String()) + assert.Equal(t, 0, logABCD.C.Cmp(expectedC), "C expected to be: %v found: %v", expectedA.String(), logABCD.C.String()) + assert.Equal(t, 0, logABCD.D.Cmp(expectedD), "D expected to be: %v found: %v", expectedA.String(), logABCD.D.String()) + }, + }, + } + for _, network := range networks { log.Debugf(network.Name) client := operations.MustGetClient(network.URL) + wsClient := operations.MustGetClient(network.WebSocketURL) auth := operations.MustGetAuth(network.PrivateKey, network.ChainID) + // deploy sc scAddr, scTx, sc, err := EmitLog2.DeployEmitLog2(auth, client) require.NoError(t, err) @@ -97,40 +301,31 @@ func TestEmitLog2(t *testing.T) { err = operations.WaitTxToBeMined(ctx, client, scTx, operations.DefaultTimeoutTxToBeMined) require.NoError(t, err) - scCallTx, err := sc.EmitLogs(auth) - require.NoError(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.logsFromSubscription = make(chan types.Log) - logTx(scCallTx) - err = operations.WaitTxToBeMined(ctx, client, scCallTx, operations.DefaultTimeoutTxToBeMined) - require.NoError(t, err) + var sub ethereum.Subscription + if tc.subscribe != nil { + sub = tc.subscribe(t, wsClient, &tc, scAddr) + } - scCallTxReceipt, err := client.TransactionReceipt(ctx, scCallTx.Hash()) - require.NoError(t, err) + // emit logs + scCallTx, err := sc.EmitLogs(auth) + require.NoError(t, err) - filterBlock := scCallTxReceipt.BlockNumber - logs, err := client.FilterLogs(ctx, ethereum.FilterQuery{ - FromBlock: filterBlock, ToBlock: filterBlock, - Addresses: []common.Address{scAddr}, - }) - require.NoError(t, err) - assert.Equal(t, 4, len(logs)) + logTx(scCallTx) + err = operations.WaitTxToBeMined(ctx, client, scCallTx, operations.DefaultTimeoutTxToBeMined) + require.NoError(t, err) - log0 := getLogByIndex(0, logs) - assert.Equal(t, 0, len(log0.Topics)) + scCallTxReceipt, err := client.TransactionReceipt(ctx, scCallTx.Hash()) + require.NoError(t, err) - _, err = sc.ParseLog(getLogByIndex(1, logs)) - require.NoError(t, err) + logs := tc.getLogs(t, client, &tc, scAddr, scCallTxReceipt, sub) - logA, err := sc.ParseLogA(getLogByIndex(2, logs)) - require.NoError(t, err) - assert.Equal(t, big.NewInt(1), logA.A) - - logABCD, err := sc.ParseLogABCD(getLogByIndex(3, logs)) - require.NoError(t, err) - assert.Equal(t, big.NewInt(1), logABCD.A) - assert.Equal(t, big.NewInt(2), logABCD.B) - assert.Equal(t, big.NewInt(3), logABCD.C) - assert.Equal(t, big.NewInt(4), logABCD.D) + tc.validate(t, ctx, logs, sc) + }) + } } }