Skip to content

Commit

Permalink
rpc: add filter IsValid functionality
Browse files Browse the repository at this point in the history
Additional check of filters parameters added for filter validation.

Closes #3241.

Signed-off-by: Ekaterina Pavlova <[email protected]>
  • Loading branch information
AliceInHunterland committed Dec 13, 2023
1 parent 2c647b2 commit 04c5f2c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 22 deletions.
58 changes: 58 additions & 0 deletions pkg/neorpc/filters.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package neorpc

import (
"errors"

"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
"github.com/nspcc-dev/neo-go/pkg/util"
)
Expand Down Expand Up @@ -48,6 +51,43 @@ type (
Type *mempoolevent.Type `json:"type,omitempty"`
}
)
type ValidFilter interface {
IsValid() error
}

// IsValid checks whether the filter is valid.
func (f *BlockFilter) IsValid() error {
return nil
}

// IsValid checks whether the filter is valid.
func (f *TxFilter) IsValid() error {
return nil
}

// IsValid checks whether the filter is valid.
func (f *NotificationFilter) IsValid() error {
if f.Name != nil && runtime.MaxEventNameLen < len(*f.Name) {
return errors.New("bad name parameter")
}
return nil
}

// IsValid checks whether the filter is valid.
func (f *ExecutionFilter) IsValid() error {
if f.State != nil {
if *f.State != "HALT" && *f.State != "FAULT" {
return errors.New("bad state parameter")
}
}

return nil
}

// IsValid checks whether the filter is valid.
func (f *NotaryRequestFilter) IsValid() error {
return nil
}

// Copy creates a deep copy of the BlockFilter. It handles nil BlockFilter correctly.
func (f *BlockFilter) Copy() *BlockFilter {
Expand Down Expand Up @@ -141,3 +181,21 @@ func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter {
}
return res
}

// NewFilter creates a new filter for the specified event.
func NewFilter(event EventID) ValidFilter {
switch event {
case BlockEventID:
return new(BlockFilter)
case TransactionEventID:
return new(TxFilter)
case NotaryRequestEventID:
return new(NotaryRequestFilter)
case NotificationEventID:
return new(NotificationFilter)
case ExecutionEventID:
return new(ExecutionFilter)
default:
return nil
}
}
5 changes: 0 additions & 5 deletions pkg/rpcclient/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,11 +795,6 @@ func (c *WSClient) ReceiveExecutions(flt *neorpc.ExecutionFilter, rcvr chan<- *s
}
params := []any{"transaction_executed"}
if flt != nil {
if flt.State != nil {
if *flt.State != "HALT" && *flt.State != "FAULT" {
return "", errors.New("bad state parameter")
}
}
flt = flt.Copy()
params = append(params, *flt)
}
Expand Down
29 changes: 12 additions & 17 deletions pkg/services/rpcsrv/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2729,29 +2729,24 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor
param := *p
jd := json.NewDecoder(bytes.NewReader(param.RawMessage))
jd.DisallowUnknownFields()
flt := neorpc.NewFilter(event)
err = flt.IsValid()
if err != nil {
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error())
}
err = jd.Decode(flt)
switch event {
case neorpc.BlockEventID:
flt := new(neorpc.BlockFilter)
err = jd.Decode(flt)
filter = *flt
filter = *flt.(*neorpc.BlockFilter)
case neorpc.TransactionEventID:
flt := new(neorpc.TxFilter)
err = jd.Decode(flt)
filter = *flt
filter = *flt.(*neorpc.TxFilter)
case neorpc.NotaryRequestEventID:
flt := new(neorpc.NotaryRequestFilter)
err = jd.Decode(flt)
filter = *flt
filter = *flt.(*neorpc.NotaryRequestFilter)
case neorpc.NotificationEventID:
flt := new(neorpc.NotificationFilter)
err = jd.Decode(flt)
filter = *flt
filter = *flt.(*neorpc.NotificationFilter)
case neorpc.ExecutionEventID:
flt := new(neorpc.ExecutionFilter)
err = jd.Decode(flt)
if err == nil && (flt.State == nil || (*flt.State == "HALT" || *flt.State == "FAULT")) {
filter = *flt
} else if err == nil {
filter = *flt.(*neorpc.ExecutionFilter)
if err == nil {
err = errors.New("invalid state")
}
}
Expand Down

0 comments on commit 04c5f2c

Please sign in to comment.