Skip to content

Commit

Permalink
neorpc: add WS notification filter IsValid functionality
Browse files Browse the repository at this point in the history
Additional check of filters parameters added for filter validation.

Closes #3241.

Signed-off-by: Ekaterina Pavlova <[email protected]>
  • Loading branch information
AliceInHunterland committed Dec 18, 2023
1 parent 385c1d5 commit 50e55c4
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 16 deletions.
49 changes: 49 additions & 0 deletions pkg/neorpc/filters.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package neorpc

import (
"errors"
"fmt"

"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/vmstate"
)

type (
Expand Down Expand Up @@ -49,6 +54,16 @@ type (
}
)

// SubscriptionFilter is an interface for all subscription filters.
type SubscriptionFilter interface {
// IsValid checks whether the filter is valid and returns
// a specific [ErrInvalidSubscriptionFilter] error if not.
IsValid() error
}

// ErrInvalidSubscriptionFilter is returned when the filter is invalid.
var ErrInvalidSubscriptionFilter = errors.New("invalid subscription filter")

// Copy creates a deep copy of the BlockFilter. It handles nil BlockFilter correctly.
func (f *BlockFilter) Copy() *BlockFilter {
if f == nil {
Expand All @@ -70,6 +85,11 @@ func (f *BlockFilter) Copy() *BlockFilter {
return res
}

// IsValid implements SubscriptionFilter interface.
func (f BlockFilter) IsValid() error {
return nil
}

// Copy creates a deep copy of the TxFilter. It handles nil TxFilter correctly.
func (f *TxFilter) Copy() *TxFilter {
if f == nil {
Expand All @@ -87,6 +107,11 @@ func (f *TxFilter) Copy() *TxFilter {
return res
}

// IsValid implements SubscriptionFilter interface.
func (f TxFilter) IsValid() error {
return nil
}

// Copy creates a deep copy of the NotificationFilter. It handles nil NotificationFilter correctly.
func (f *NotificationFilter) Copy() *NotificationFilter {
if f == nil {
Expand All @@ -104,6 +129,14 @@ func (f *NotificationFilter) Copy() *NotificationFilter {
return res
}

// IsValid implements SubscriptionFilter interface.
func (f NotificationFilter) IsValid() error {
if f.Name != nil && len(*f.Name) > runtime.MaxEventNameLen {
return fmt.Errorf("%w: NotificationFilter name parameter must be less than %d", ErrInvalidSubscriptionFilter, runtime.MaxEventNameLen)
}
return nil
}

// Copy creates a deep copy of the ExecutionFilter. It handles nil ExecutionFilter correctly.
func (f *ExecutionFilter) Copy() *ExecutionFilter {
if f == nil {
Expand All @@ -121,6 +154,17 @@ func (f *ExecutionFilter) Copy() *ExecutionFilter {
return res
}

// IsValid implements SubscriptionFilter interface.
func (f ExecutionFilter) IsValid() error {
if f.State != nil {
if *f.State != vmstate.Halt.String() && *f.State != vmstate.Fault.String() {
return fmt.Errorf("%w: ExecutionFilter state parameter must be either %s or %s", ErrInvalidSubscriptionFilter, vmstate.Halt, vmstate.Fault)
}
}

return nil
}

// Copy creates a deep copy of the NotaryRequestFilter. It handles nil NotaryRequestFilter correctly.
func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter {
if f == nil {
Expand All @@ -141,3 +185,8 @@ func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter {
}
return res
}

// IsValid implements SubscriptionFilter interface.
func (f NotaryRequestFilter) IsValid() error {
return nil
}
2 changes: 1 addition & 1 deletion pkg/neorpc/rpcevent/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type (
// filter notifications.
Comparator interface {
EventID() neorpc.EventID
Filter() any
Filter() neorpc.SubscriptionFilter
}
// Container is an interface required from notification event to be able to
// pass filter.
Expand Down
4 changes: 2 additions & 2 deletions pkg/neorpc/rpcevent/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
type (
testComparator struct {
id neorpc.EventID
filter any
filter neorpc.SubscriptionFilter
}
testContainer struct {
id neorpc.EventID
Expand All @@ -29,7 +29,7 @@ type (
func (c testComparator) EventID() neorpc.EventID {
return c.id
}
func (c testComparator) Filter() any {
func (c testComparator) Filter() neorpc.SubscriptionFilter {
return c.filter
}
func (c testContainer) EventID() neorpc.EventID {
Expand Down
20 changes: 10 additions & 10 deletions pkg/rpcclient/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (r *blockReceiver) EventID() neorpc.EventID {
}

// Filter implements neorpc.Comparator interface.
func (r *blockReceiver) Filter() any {
func (r *blockReceiver) Filter() neorpc.SubscriptionFilter {
if r.filter == nil {
return nil
}
Expand Down Expand Up @@ -174,7 +174,7 @@ func (r *txReceiver) EventID() neorpc.EventID {
}

// Filter implements neorpc.Comparator interface.
func (r *txReceiver) Filter() any {
func (r *txReceiver) Filter() neorpc.SubscriptionFilter {
if r.filter == nil {
return nil
}
Expand Down Expand Up @@ -221,7 +221,7 @@ func (r *executionNotificationReceiver) EventID() neorpc.EventID {
}

// Filter implements neorpc.Comparator interface.
func (r *executionNotificationReceiver) Filter() any {
func (r *executionNotificationReceiver) Filter() neorpc.SubscriptionFilter {
if r.filter == nil {
return nil
}
Expand Down Expand Up @@ -268,7 +268,7 @@ func (r *executionReceiver) EventID() neorpc.EventID {
}

// Filter implements neorpc.Comparator interface.
func (r *executionReceiver) Filter() any {
func (r *executionReceiver) Filter() neorpc.SubscriptionFilter {
if r.filter == nil {
return nil
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func (r *notaryRequestReceiver) EventID() neorpc.EventID {
}

// Filter implements neorpc.Comparator interface.
func (r *notaryRequestReceiver) Filter() any {
func (r *notaryRequestReceiver) Filter() neorpc.SubscriptionFilter {
if r.filter == nil {
return nil
}
Expand Down Expand Up @@ -712,6 +712,11 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
func (c *WSClient) performSubscription(params []any, rcvr notificationReceiver) (string, error) {
var resp string

if flt := rcvr.Filter(); flt != nil {
if err := flt.IsValid(); err != nil {
return "", err
}
}
if err := c.performRequest("subscribe", params, &resp); err != nil {
return "", err
}
Expand Down Expand Up @@ -795,11 +800,6 @@ func (c *WSClient) ReceiveExecutions(flt *neorpc.ExecutionFilter, rcvr chan<- *s
}
params := []any{"transaction_executed"}
if flt != nil {
if flt.State != nil {
if *flt.State != "HALT" && *flt.State != "FAULT" {
return "", errors.New("bad state parameter")
}
}
flt = flt.Copy()
params = append(params, *flt)
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/rpcclient/wsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,19 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
wsc.Close()
}

func TestWSNotificationNameCheck(t *testing.T) {
// Will answer successfully if request slips through.
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init())
filter := "notification_from_execution_with_long_name"
_, err = wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &filter}, make(chan *state.ContainedNotificationEvent))
require.Error(t, err)
wsc.Close()
}

func TestWSFilteredSubscriptions(t *testing.T) {
var cases = []struct {
name string
Expand Down
8 changes: 7 additions & 1 deletion pkg/services/rpcsrv/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2724,7 +2724,7 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, "P2PSigExtensions are disabled")
}
// Optional filter.
var filter any
var filter neorpc.SubscriptionFilter
if p := reqParams.Value(1); p != nil {
param := *p
jd := json.NewDecoder(bytes.NewReader(param.RawMessage))
Expand Down Expand Up @@ -2759,6 +2759,12 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error())
}
}
if filter != nil {
err = filter.IsValid()
if err != nil {
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error())
}
}

s.subsLock.Lock()
var id int
Expand Down
4 changes: 2 additions & 2 deletions pkg/services/rpcsrv/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type (
// feed stores subscriber's desired event ID with filter.
feed struct {
event neorpc.EventID
filter any
filter neorpc.SubscriptionFilter
}
)

Expand All @@ -38,7 +38,7 @@ func (f feed) EventID() neorpc.EventID {
}

// Filter implements neorpc.EventComparator interface and returns notification filter.
func (f feed) Filter() any {
func (f feed) Filter() neorpc.SubscriptionFilter {
return f.filter
}

Expand Down
57 changes: 57 additions & 0 deletions pkg/services/rpcsrv/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,60 @@ func TestSubscriptionOverflow(t *testing.T) {
finishedFlag.CompareAndSwap(false, true)
c.Close()
}

func TestFilteredIsValidSubscriptions(t *testing.T) {
priv0 := testchain.PrivateKeyByID(0)
var goodSender = priv0.GetScriptHash()

var cases = map[string]struct {
params string
}{
"tx wrong sender": {
params: `["transaction_added", {"unknown_sender":"` + goodSender.StringLE() + `"}]`,
},
"notification with long name": {
params: `["notification_from_execution", {"name":"notification_from_execution_with_long_name"}]`,
},
"execution not valid vm state": {
params: `["transaction_executed", {"state":"NOTHALT"}]`,
},
}
var s string
for name, this := range cases {
t.Run(name, func(t *testing.T) {
chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

defer chain.Close()
defer rpcSrv.Shutdown()

blockSubID := callSubscribe(t, c, respMsgs, `["block_added"]`)

resp := callWSGetRaw(t, c, fmt.Sprintf(`{"jsonrpc": "2.0","method": "subscribe","params": %s,"id": 1}`, this.params), respMsgs)
require.NotNil(t, resp.Error)
require.Nil(t, resp.Result)
require.Error(t, json.Unmarshal(resp.Result, &s))

var lastBlock uint32
for _, b := range getTestBlocks(t) {
require.NoError(t, chain.AddBlock(b))
lastBlock = b.Index
}

for {
resp := getNotification(t, respMsgs)
rmap := resp.Payload[0].(map[string]any)
if resp.Event == neorpc.BlockEventID {
index := rmap["index"].(float64)
if uint32(index) == lastBlock {
break
}
continue
}
}

callUnsubscribe(t, c, respMsgs, blockSubID)
finishedFlag.CompareAndSwap(false, true)
c.Close()
})
}
}

0 comments on commit 50e55c4

Please sign in to comment.