diff --git a/cmd/gossamer/config.go b/cmd/gossamer/config.go index ce944d426d..5ee894b7ba 100644 --- a/cmd/gossamer/config.go +++ b/cmd/gossamer/config.go @@ -686,12 +686,16 @@ func setDotNetworkConfig(ctx *cli.Context, tomlCfg ctoml.NetworkConfig, cfg *dot func setDotRPCConfig(ctx *cli.Context, tomlCfg ctoml.RPCConfig, cfg *dot.RPCConfig) { cfg.Enabled = tomlCfg.Enabled cfg.External = tomlCfg.External + cfg.Unsafe = tomlCfg.Unsafe + cfg.UnsafeExternal = tomlCfg.UnsafeExternal cfg.Port = tomlCfg.Port cfg.Host = tomlCfg.Host cfg.Modules = tomlCfg.Modules cfg.WSPort = tomlCfg.WSPort cfg.WS = tomlCfg.WS cfg.WSExternal = tomlCfg.WSExternal + cfg.WSUnsafe = tomlCfg.WSUnsafe + cfg.WSUnsafeExternal = tomlCfg.WSUnsafeExternal // check --rpc flag and update node configuration if enabled := ctx.GlobalBool(RPCEnabledFlag.Name); enabled || cfg.Enabled { diff --git a/cmd/gossamer/flags.go b/cmd/gossamer/flags.go index 284a8a3e25..370cab4cbc 100644 --- a/cmd/gossamer/flags.go +++ b/cmd/gossamer/flags.go @@ -189,6 +189,16 @@ var ( Name: "rpc-external", Usage: "Enable external HTTP-RPC connections", } + // RPCEnabledFlag Enable the HTTP-RPC + RPCUnsafeEnabledFlag = cli.BoolFlag{ + Name: "rpc-unsafe", + Usage: "Enable the HTTP-RPC server to unsafe procedures", + } + // RPCExternalFlag Enable the external HTTP-RPC + RPCUnsafeExternalFlag = cli.BoolFlag{ + Name: "rpc-unsafe-external", + Usage: "Enable external HTTP-RPC connections to unsafe procedures", + } // RPCHostFlag HTTP-RPC server listening hostname RPCHostFlag = cli.StringFlag{ Name: "rpchost", @@ -219,6 +229,16 @@ var ( Name: "ws-external", Usage: "Enable external websocket connections", } + // WSFlag Enable the websockets server + WSUnsafeFlag = cli.BoolFlag{ + Name: "ws-unsafe", + Usage: "Enable access to websocket unsafe calls", + } + // WSExternalFlag Enable external websocket connections + WSUnsafeExternalFlag = cli.BoolFlag{ + Name: "ws-unsafe-external", + Usage: "Enable external access to websocket unsafe calls", + } ) // Account management flags @@ -329,11 +349,15 @@ var ( // rpc flags RPCEnabledFlag, RPCExternalFlag, + RPCUnsafeEnabledFlag, + RPCUnsafeExternalFlag, RPCHostFlag, RPCPortFlag, RPCModulesFlag, WSFlag, WSExternalFlag, + WSUnsafeFlag, + WSUnsafeExternalFlag, WSPortFlag, // metrics flag diff --git a/dot/config.go b/dot/config.go index ef0b23ba41..048f522e13 100644 --- a/dot/config.go +++ b/dot/config.go @@ -106,14 +106,26 @@ type CoreConfig struct { // RPCConfig is to marshal/unmarshal toml RPC config vars type RPCConfig struct { - Enabled bool - External bool - Port uint32 - Host string - Modules []string - WSPort uint32 - WS bool - WSExternal bool + Enabled bool + External bool + Unsafe bool + UnsafeExternal bool + Port uint32 + Host string + Modules []string + WSPort uint32 + WS bool + WSExternal bool + WSUnsafe bool + WSUnsafeExternal bool +} + +func (r *RPCConfig) isRPCEnabled() bool { + return r.Enabled || r.External || r.Unsafe || r.UnsafeExternal +} + +func (r *RPCConfig) isWSEnabled() bool { + return r.WS || r.WSExternal || r.WSUnsafe || r.WSUnsafeExternal } // StateConfig is the config for the State service diff --git a/dot/config/toml/config.go b/dot/config/toml/config.go index 7cc270e590..3a68175945 100644 --- a/dot/config/toml/config.go +++ b/dot/config/toml/config.go @@ -86,12 +86,16 @@ type CoreConfig struct { // RPCConfig is to marshal/unmarshal toml RPC config vars type RPCConfig struct { - Enabled bool `toml:"enabled,omitempty"` - External bool `toml:"external,omitempty"` - Port uint32 `toml:"port,omitempty"` - Host string `toml:"host,omitempty"` - Modules []string `toml:"modules,omitempty"` - WSPort uint32 `toml:"ws-port,omitempty"` - WS bool `toml:"ws,omitempty"` - WSExternal bool `toml:"ws-external,omitempty"` + Enabled bool `toml:"enabled,omitempty"` + Unsafe bool `toml:"unsafe,omitempty"` + UnsafeExternal bool `toml:"unsafe-external,omitempty"` + External bool `toml:"external,omitempty"` + Port uint32 `toml:"port,omitempty"` + Host string `toml:"host,omitempty"` + Modules []string `toml:"modules,omitempty"` + WSPort uint32 `toml:"ws-port,omitempty"` + WS bool `toml:"ws,omitempty"` + WSExternal bool `toml:"ws-external,omitempty"` + WSUnsafe bool `toml:"ws-unsafe,omitempty"` + WSUnsafeExternal bool `toml:"ws-unsafe-external,omitempty"` } diff --git a/dot/node.go b/dot/node.go index 7cdde1aecf..46bb31e65b 100644 --- a/dot/node.go +++ b/dot/node.go @@ -95,10 +95,7 @@ func InitNode(cfg *Config) error { config := state.Config{ Path: cfg.Global.BasePath, LogLevel: cfg.Global.LogLvl, - PrunerCfg: struct { - Mode pruner.Mode - RetainedBlocks int64 - }{ + PrunerCfg: pruner.Config{ Mode: cfg.Global.Pruning, RetainedBlocks: cfg.Global.RetainBlocks, }, @@ -310,7 +307,7 @@ func NewNode(cfg *Config, ks *keystore.GlobalKeystore, stopFunc func()) (*Node, nodeSrvcs = append(nodeSrvcs, sysSrvc) // check if rpc service is enabled - if enabled := cfg.RPC.Enabled || cfg.RPC.WS; enabled { + if enabled := cfg.RPC.isRPCEnabled() || cfg.RPC.isWSEnabled(); enabled { rpcSrvc := createRPCService(cfg, stateSrvc, coreSrvc, networkSrvc, bp, sysSrvc, fg) nodeSrvcs = append(nodeSrvcs, rpcSrvc) } else { diff --git a/dot/rpc/helpers.go b/dot/rpc/helpers.go index 732f68c62d..8f47c8f251 100644 --- a/dot/rpc/helpers.go +++ b/dot/rpc/helpers.go @@ -18,8 +18,12 @@ package rpc import ( "errors" + "fmt" "net" + "strings" + "github.com/ChainSafe/gossamer/dot/rpc/modules" + "github.com/go-playground/validator/v10" "github.com/gorilla/rpc/v2" "github.com/jpillora/ipfilter" ) @@ -35,6 +39,7 @@ func LocalhostFilter() *ipfilter.IPFilter { // LocalRequestOnly HTTP handler to restrict to only local connections func LocalRequestOnly(r *rpc.RequestInfo, i interface{}) error { ip, _, err := net.SplitHostPort(r.Request.RemoteAddr) + if err != nil { return errors.New("unable to parse IP") } @@ -44,3 +49,42 @@ func LocalRequestOnly(r *rpc.RequestInfo, i interface{}) error { } return errors.New("external HTTP request refused") } + +func snakeCaseFormat(method string) (string, error) { + parts := strings.Split(method, ".") + if len(parts) < 2 { + return "", fmt.Errorf("invalid rpc method format %s, should be 'module.FunctionName'", method) + } + + service, funcName := parts[0], parts[1] + funcName = strings.ToLower(string(funcName[0])) + funcName[1:] + return strings.Join([]string{service, funcName}, "_"), nil +} + +func rpcValidator(cfg *HTTPServerConfig, validate *validator.Validate) func(r *rpc.RequestInfo, i interface{}) error { + return func(r *rpc.RequestInfo, v interface{}) error { + var ( + err error + rpcmethod string + ) + + if rpcmethod, err = snakeCaseFormat(r.Method); err != nil { + return err + } + + isUnsafe := modules.IsUnsafe(rpcmethod) + if isUnsafe && !cfg.rpcUnsafeEnabled() { + return fmt.Errorf("unsafe rpc method %s cannot be reachable", rpcmethod) + } + + if err = validate.Struct(v); err != nil { + return err + } + + if !cfg.exposeRPC() || modules.IsUnsafe(rpcmethod) && !cfg.RPCUnsafeExternal { + return LocalRequestOnly(r, v) + } + + return nil + } +} diff --git a/dot/rpc/http.go b/dot/rpc/http.go index 4bad9a0bd6..250d615890 100644 --- a/dot/rpc/http.go +++ b/dot/rpc/http.go @@ -53,15 +53,36 @@ type HTTPServerConfig struct { TransactionQueueAPI modules.TransactionStateAPI RPCAPI modules.RPCAPI SystemAPI modules.SystemAPI - External bool + RPC bool + RPCExternal bool + RPCUnsafe bool + RPCUnsafeExternal bool Host string RPCPort uint32 WS bool WSExternal bool + WSUnsafe bool + WSUnsafeExternal bool WSPort uint32 Modules []string } +func (h *HTTPServerConfig) rpcUnsafeEnabled() bool { + return h.RPCUnsafe || h.RPCUnsafeExternal +} + +func (h *HTTPServerConfig) wsUnsafeEnabled() bool { + return h.WSUnsafe || h.WSUnsafeExternal +} + +func (h *HTTPServerConfig) exposeWS() bool { + return h.WSExternal || h.WSUnsafeExternal +} + +func (h *HTTPServerConfig) exposeRPC() bool { + return h.RPCExternal || h.RPCUnsafeExternal +} + var logger log.Logger // NewHTTPServer creates a new http server and registers an associated rpc server @@ -78,10 +99,6 @@ func NewHTTPServer(cfg *HTTPServerConfig) *HTTPServer { } server.RegisterModules(cfg.Modules) - if !cfg.External { - server.rpcServer.RegisterValidateRequestFunc(LocalRequestOnly) - } - return server } @@ -138,15 +155,8 @@ func (h *HTTPServer) Start() error { // Add custom validator for `common.Hash` validate.RegisterCustomTypeFunc(common.HashValidator, common.Hash{}) - validateHandler := func(r *rpc.RequestInfo, v interface{}) error { - err := validate.Struct(v) - if err != nil { - return err - } - return nil - } + h.rpcServer.RegisterValidateRequestFunc(rpcValidator(h.serverConfig, validate)) - h.rpcServer.RegisterValidateRequestFunc(validateHandler) go func() { err := http.ListenAndServe(fmt.Sprintf(":%d", h.serverConfig.RPCPort), r) if err != nil { @@ -199,7 +209,7 @@ func (h *HTTPServer) Stop() error { func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { var upg = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { - if !h.serverConfig.WSExternal { + if !h.serverConfig.exposeWS() { ip, _, error := net.SplitHostPort(r.RemoteAddr) if error != nil { logger.Error("unable to parse IP", "error") @@ -214,6 +224,7 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Debug("external websocket request refused", "error") return false } + return true }, } @@ -233,6 +244,7 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // NewWSConn to create new WebSocket Connection struct func NewWSConn(conn *websocket.Conn, cfg *HTTPServerConfig) *subscription.WSConn { c := &subscription.WSConn{ + UnsafeEnabled: cfg.wsUnsafeEnabled(), Wsconn: conn, Subscriptions: make(map[uint32]subscription.Listener), StorageAPI: cfg.StorageAPI, diff --git a/dot/rpc/http_test.go b/dot/rpc/http_test.go index cf3f21bac2..34397ab682 100644 --- a/dot/rpc/http_test.go +++ b/dot/rpc/http_test.go @@ -18,13 +18,23 @@ package rpc import ( "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" "net/http" "testing" "time" "github.com/ChainSafe/gossamer/dot/core" + "github.com/ChainSafe/gossamer/dot/rpc/modules" + "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" "github.com/ChainSafe/gossamer/dot/system" "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/btcsuite/btcutil/base58" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -47,6 +57,7 @@ func TestNewHTTPServer(t *testing.T) { require.Nil(t, err) time.Sleep(time.Second) // give server a second to start + defer s.Stop() // Valid request client := &http.Client{} @@ -55,7 +66,7 @@ func TestNewHTTPServer(t *testing.T) { buf := &bytes.Buffer{} _, err = buf.Write(data) require.Nil(t, err) - req, err := http.NewRequest("POST", "http://localhost:8545/", buf) + req, err := http.NewRequest("POST", fmt.Sprintf("http://localhost:%v/", cfg.RPCPort), buf) require.Nil(t, err) req.Header.Set("Content-Type", "application/json") @@ -67,7 +78,7 @@ func TestNewHTTPServer(t *testing.T) { require.Equal(t, "200 OK", res.Status) // nil POST - req, err = http.NewRequest("POST", "http://localhost:8545/", nil) + req, err = http.NewRequest("POST", fmt.Sprintf("http://localhost:%v/", cfg.RPCPort), nil) require.Nil(t, err) req.Header.Set("Content-Type", "application/json;") @@ -79,7 +90,7 @@ func TestNewHTTPServer(t *testing.T) { require.Equal(t, "200 OK", res.Status) // GET - req, err = http.NewRequest("GET", "http://localhost:8545/", nil) + req, err = http.NewRequest("GET", fmt.Sprintf("http://localhost:%v/", cfg.RPCPort), nil) require.Nil(t, err) req.Header.Set("Content-Type", "application/json;") @@ -90,3 +101,219 @@ func TestNewHTTPServer(t *testing.T) { require.Equal(t, "405 Method Not Allowed", res.Status) } + +func TestUnsafeRPCProtection(t *testing.T) { + cfg := &HTTPServerConfig{ + Modules: []string{"system", "author", "chain", "state", "rpc", "grandpa", "dev"}, + RPCPort: 7878, + RPCAPI: NewService(), + RPCUnsafe: false, + RPCUnsafeExternal: false, + } + + s := NewHTTPServer(cfg) + err := s.Start() + require.NoError(t, err) + + time.Sleep(time.Second) + defer s.Stop() + + for _, unsafe := range modules.UnsafeMethods { + t.Run(fmt.Sprintf("Unsafe method %s should not be reachable", unsafe), func(t *testing.T) { + data := []byte(fmt.Sprintf(`{"jsonrpc":"2.0","method":"%s","params":[],"id":1}`, unsafe)) + + buf := new(bytes.Buffer) + _, err = buf.Write(data) + require.NoError(t, err) + + _, resBody := PostRequest(t, fmt.Sprintf("http://localhost:%v/", cfg.RPCPort), buf) + expected := fmt.Sprintf( + `{"jsonrpc":"2.0","error":{"code":-32000,"message":"unsafe rpc method %s cannot be reachable","data":null},"id":1}`+"\n", + unsafe, + ) + + require.Equal(t, expected, string(resBody)) + }) + } +} +func TestRPCUnsafeExpose(t *testing.T) { + data := []byte(fmt.Sprintf( + `{"jsonrpc":"2.0","method":"%s","params":["%s"],"id":1}`, + "system_addReservedPeer", + "/ip4/198.51.100.19/tcp/30333/p2p/QmSk5HQbn6LhUwDiNMseVUjuRYhEtYj4aUZ6WfWoGURpdV")) + + buf := new(bytes.Buffer) + _, err := buf.Write(data) + require.NoError(t, err) + + netmock := new(mocks.MockNetworkAPI) + netmock.On("AddReservedPeers", mock.AnythingOfType("string")).Return(nil) + + cfg := &HTTPServerConfig{ + Modules: []string{"system"}, + RPCPort: 7879, + RPCAPI: NewService(), + RPCUnsafeExternal: true, + NetworkAPI: netmock, + } + + s := NewHTTPServer(cfg) + err = s.Start() + require.NoError(t, err) + + time.Sleep(time.Second) + defer s.Stop() + + ip, err := externalIP() + require.NoError(t, err) + + _, resBody := PostRequest(t, fmt.Sprintf("http://%s:%v/", ip, cfg.RPCPort), buf) + expected := `{"jsonrpc":"2.0","result":null,"id":1}` + "\n" + require.Equal(t, expected, string(resBody)) +} + +func TestUnsafeRPCJustToLocalhost(t *testing.T) { + unsafeMethod := "system_addReservedPeer" + data := []byte(fmt.Sprintf( + `{"jsonrpc":"2.0","method":"%s","params":["%s"],"id":1}`, + unsafeMethod, + "/ip4/198.51.100.19/tcp/30333/p2p/QmSk5HQbn6LhUwDiNMseVUjuRYhEtYj4aUZ6WfWoGURpdV")) + + buf := new(bytes.Buffer) + _, err := buf.Write(data) + require.NoError(t, err) + + netmock := new(mocks.MockNetworkAPI) + netmock.On("AddReservedPeers", mock.AnythingOfType("string")).Return(nil) + + cfg := &HTTPServerConfig{ + Modules: []string{"system"}, + RPCPort: 7880, + RPCAPI: NewService(), + RPCUnsafe: true, + NetworkAPI: netmock, + } + + s := NewHTTPServer(cfg) + err = s.Start() + require.NoError(t, err) + + time.Sleep(time.Second) + defer s.Stop() + + ip, err := externalIP() + require.NoError(t, err) + + _, resBody := PostRequest(t, fmt.Sprintf("http://%s:7880/", ip), buf) + expected := `{"jsonrpc":"2.0","error":{"code":-32000,"message":"external HTTP request refused","data":null},"id":1}` + "\n" + require.Equal(t, expected, string(resBody)) +} + +func TestRPCExternalEnable_UnsafeExternalNotEnabled(t *testing.T) { + unsafeData := []byte(fmt.Sprintf( + `{"jsonrpc":"2.0","method":"%s","params":["%s"],"id":1}`, + "system_addReservedPeer", + "/ip4/198.51.100.19/tcp/30333/p2p/QmSk5HQbn6LhUwDiNMseVUjuRYhEtYj4aUZ6WfWoGURpdV")) + unsafebuf := new(bytes.Buffer) + unsafebuf.Write(unsafeData) + + safeData := []byte(fmt.Sprintf( + `{"jsonrpc":"2.0","method":"%s","params":[],"id":2}`, + "system_localPeerId")) + safebuf := new(bytes.Buffer) + safebuf.Write(safeData) + + netmock := new(mocks.MockNetworkAPI) + netmock.On("NetworkState").Return(common.NetworkState{ + PeerID: "peer id", + }) + + httpServerConfig := &HTTPServerConfig{ + Modules: []string{"system"}, + RPCPort: 8786, + RPCAPI: NewService(), + RPCUnsafe: true, + RPCUnsafeExternal: false, + RPCExternal: true, + NetworkAPI: netmock, + } + + s := NewHTTPServer(httpServerConfig) + err := s.Start() + require.NoError(t, err) + + time.Sleep(time.Second) + defer s.Stop() + + ip, err := externalIP() + require.NoError(t, err) + + _, resBody := PostRequest(t, fmt.Sprintf("http://%s:%v/", ip, httpServerConfig.RPCPort), safebuf) + encoded := base58.Encode([]byte("peer id")) + expected := fmt.Sprintf(`{"jsonrpc":"2.0","result":"%s","id":2}`, encoded) + "\n" + require.Equal(t, expected, string(resBody)) + + // unsafe method should not be ok + _, resBody = PostRequest(t, fmt.Sprintf("http://%s:%v/", ip, httpServerConfig.RPCPort), unsafebuf) + expected = `{"jsonrpc":"2.0","error":{"code":-32000,"message":"external HTTP request refused","data":null},"id":1}` + "\n" + require.Equal(t, expected, string(resBody)) +} + +func PostRequest(t *testing.T, url string, data io.Reader) (int, []byte) { + t.Helper() + + req, err := http.NewRequest(http.MethodPost, url, data) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/json") + res, err := new(http.Client).Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + resBody, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + + responseData := new(bytes.Buffer) + _, err = responseData.Write(resBody) + require.NoError(t, err) + + return res.StatusCode, responseData.Bytes() +} + +func externalIP() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", err + } + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if iface.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + addrs, err := iface.Addrs() + if err != nil { + return "", err + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsLoopback() { + continue + } + ip = ip.To4() + if ip == nil { + continue // not an ipv4 address + } + return ip.String(), nil + } + } + return "", errors.New("are you connected to the network?") +} diff --git a/dot/rpc/modules/rpc.go b/dot/rpc/modules/rpc.go index 560492c9c6..05f0201e2b 100644 --- a/dot/rpc/modules/rpc.go +++ b/dot/rpc/modules/rpc.go @@ -16,7 +16,24 @@ package modules -import "net/http" +import ( + "net/http" +) + +var ( + // UnsafeMethods is a list of all unsafe rpc methods of https://github.com/w3f/PSPs/blob/master/PSPs/drafts/psp-6.md + UnsafeMethods = []string{ + "system_addReservedPeer", + "system_removeReservedPeer", + "author_submitExtrinsic", + "author_removeExtrinsic", + "author_insertKey", + "author_rotateKeys", + "state_getPairs", + "state_getKeysPaged", + "state_queryStorage", + } +) // RPCModule is a RPC module providing access to RPC methods type RPCModule struct { @@ -41,3 +58,14 @@ func (rm *RPCModule) Methods(r *http.Request, req *EmptyRequest, res *MethodsRes return nil } + +// IsUnsafe returns true if the `name` has the suffix +func IsUnsafe(name string) bool { + for _, unsafe := range UnsafeMethods { + if name == unsafe { + return true + } + } + + return false +} diff --git a/dot/rpc/service.go b/dot/rpc/service.go index 35fbb41979..cad18df08c 100644 --- a/dot/rpc/service.go +++ b/dot/rpc/service.go @@ -17,6 +17,7 @@ package rpc import ( + "fmt" "net/http" "reflect" "strings" @@ -85,7 +86,8 @@ func (s *Service) BuildMethodNames(rcvr interface{}, name string) { continue } - s.rpcMethods = append(s.rpcMethods, name+"_"+strings.ToLower(string(method.Name[0]))+method.Name[1:]) + s.rpcMethods = append(s.rpcMethods, + fmt.Sprintf("%s_%s%s", name, strings.ToLower(string(method.Name[0])), method.Name[1:])) } } diff --git a/dot/rpc/subscription/listeners.go b/dot/rpc/subscription/listeners.go index 3aa9287e88..01a4ad896d 100644 --- a/dot/rpc/subscription/listeners.go +++ b/dot/rpc/subscription/listeners.go @@ -60,7 +60,7 @@ type WSConnAPI interface { type StorageObserver struct { id uint32 filter map[string][]byte - wsconn WSConnAPI + wsconn *WSConn } // Change type defining key value pair representing change @@ -107,7 +107,10 @@ func (s *StorageObserver) GetFilter() map[string][]byte { func (s *StorageObserver) Listen() {} // Stop to satisfy Listener interface (but is no longer used by StorageObserver) -func (s *StorageObserver) Stop() error { return nil } +func (s *StorageObserver) Stop() error { + s.wsconn.StorageAPI.UnregisterStorageObserver(s) + return nil +} // BlockListener to handle listening for blocks importedChan type BlockListener struct { @@ -227,7 +230,6 @@ type ExtrinsicSubmitListener struct { // Listen implementation of Listen interface to listen for importedChan changes func (l *ExtrinsicSubmitListener) Listen() { - // listen for imported blocks with extrinsic go func() { defer func() { diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index 4dc89479b2..6a4c612ad2 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -51,10 +51,12 @@ func (m *MockWSConnAPI) safeSend(msg interface{}) { } func TestStorageObserver_Update(t *testing.T) { - mockConnection := &MockWSConnAPI{} + wsconn, ws, cancel := setupWSConn(t) + defer cancel() + storageObserver := StorageObserver{ id: 0, - wsconn: mockConnection, + wsconn: wsconn, } data := []state.KeyValue{{ @@ -74,13 +76,20 @@ func TestStorageObserver_Update(t *testing.T) { expected.Changes[i] = Change{common.BytesToHex(v.Key), common.BytesToHex(v.Value)} } - expectedRespones := newSubcriptionBaseResponseJSON() - expectedRespones.Method = "state_storage" - expectedRespones.Params.Result = expected + expectedResponse := newSubcriptionBaseResponseJSON() + expectedResponse.Method = stateStorageMethod + expectedResponse.Params.Result = expected storageObserver.Update(change) - time.Sleep(time.Millisecond * 10) - require.Equal(t, expectedRespones, mockConnection.lastMessage) + time.Sleep(time.Millisecond * 100) + + _, msg, err := ws.ReadMessage() + require.NoError(t, err) + + expectedResponseBytes, err := json.Marshal(expectedResponse) + require.NoError(t, err) + + require.Equal(t, string(expectedResponseBytes)+"\n", string(msg)) } func TestBlockListener_Listen(t *testing.T) { diff --git a/dot/rpc/subscription/subscription.go b/dot/rpc/subscription/subscription.go index 9dfe6cba70..cbf7648c78 100644 --- a/dot/rpc/subscription/subscription.go +++ b/dot/rpc/subscription/subscription.go @@ -6,56 +6,56 @@ import ( "strconv" ) -var errUknownParamSubscribeID = errors.New("invalid params format type") -var errCannotParseID = errors.New("could not parse param id") -var errCannotFindListener = errors.New("could not find listener") -var errCannotFindUnsubsriber = errors.New("could not find unsubsriber function") +const ( + authorSubmitAndWatchExtrinsic string = "author_submitAndWatchExtrinsic" //nolint + chainSubscribeNewHeads string = "chain_subscribeNewHeads" + chainSubscribeNewHead string = "chain_subscribeNewHead" + chainSubscribeFinalizedHeads string = "chain_subscribeFinalizedHeads" + stateSubscribeStorage string = "state_subscribeStorage" + stateSubscribeRuntimeVersion string = "state_subscribeRuntimeVersion" + grandpaSubscribeJustifications string = "grandpa_subscribeJustifications" +) -type unsubListener func(reqid float64, l Listener, params interface{}) type setupListener func(reqid float64, params interface{}) (Listener, error) +var ( + errUknownParamSubscribeID = errors.New("invalid params format type") + errCannotParseID = errors.New("could not parse param id") + errCannotFindListener = errors.New("could not find listener") + errCannotFindUnsubsriber = errors.New("could not find unsubsriber function") +) + func (c *WSConn) getSetupListener(method string) setupListener { switch method { - case "chain_subscribeNewHeads", "chain_subscribeNewHead": + case authorSubmitAndWatchExtrinsic: + return c.initExtrinsicWatch + case chainSubscribeNewHeads, chainSubscribeNewHead: return c.initBlockListener - case "state_subscribeStorage": + case stateSubscribeStorage: return c.initStorageChangeListener - case "chain_subscribeFinalizedHeads": + case chainSubscribeFinalizedHeads: return c.initBlockFinalizedListener - case "state_subscribeRuntimeVersion": + case stateSubscribeRuntimeVersion: return c.initRuntimeVersionListener - case "grandpa_subscribeJustifications": + case grandpaSubscribeJustifications: return c.initGrandpaJustificationListener default: return nil } } -func (c *WSConn) getUnsubListener(method string, params interface{}) (unsubListener, Listener, error) { +func (c *WSConn) getUnsubListener(params interface{}) (Listener, error) { subscribeID, err := parseSubscribeID(params) if err != nil { - return nil, nil, err + return nil, err } listener, ok := c.Subscriptions[subscribeID] if !ok { - return nil, nil, fmt.Errorf("subscriber id %v: %w", subscribeID, errCannotFindListener) - } - - var unsub unsubListener - - switch method { - case "state_unsubscribeStorage": - unsub = c.unsubscribeStorageListener - case "state_unsubscribeRuntimeVersion": - unsub = c.unsubscribeRuntimeVersionListener - case "grandpa_unsubscribeJustifications": - unsub = c.unsubscribeGrandpaJustificationListener - default: - return nil, nil, errCannotFindUnsubsriber + return nil, fmt.Errorf("subscriber id %v: %w", subscribeID, errCannotFindListener) } - return unsub, listener, nil + return listener, nil } func parseSubscribeID(p interface{}) (uint32, error) { diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index f53e51f952..f3ede55f39 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -29,7 +29,6 @@ import ( "sync/atomic" "github.com/ChainSafe/gossamer/dot/rpc/modules" - "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/runtime" @@ -50,6 +49,7 @@ const DEFAULT_BUFFER_SIZE = 100 // WSConn struct to hold WebSocket Connection references type WSConn struct { + UnsafeEnabled bool Wsconn *websocket.Conn mu sync.Mutex qtyListeners uint32 @@ -104,9 +104,14 @@ func (c *WSConn) HandleComm() { logger.Debug("ws method called", "method", method, "params", params) - if strings.Contains(method, "_subscribe") { + if !strings.Contains(method, "_unsubscribe") && !strings.Contains(method, "_unwatch") { setup := c.getSetupListener(method) + if setup == nil { + c.executeRPCCall(mbytes) + continue + } + listener, err := setup(reqid, params) //nolint if err != nil { logger.Warn("failed to create listener", "method", method, "error", err) @@ -117,61 +122,48 @@ func (c *WSConn) HandleComm() { continue } - if strings.Contains(method, "_unsubscribe") { - unsub, listener, err := c.getUnsubListener(method, params) //nolint - - if err != nil { - logger.Warn("failed to get unsubscriber", "method", method, "error", err) - - if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) { - c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage) - continue - } + listener, err := c.getUnsubListener(params) //nolint - if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) { - c.safeSend(newBooleanResponseJSON(false, reqid)) - continue - } - } - - unsub(reqid, listener, params) - err = listener.Stop() + if err != nil { + logger.Warn("failed to get unsubscriber", "method", method, "error", err) - if err != nil { - logger.Warn("failed to cancel listener goroutine", "method", method, "error", err) + if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) { + c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage) + continue } - continue - } - - if strings.Contains(method, "submitAndWatchExtrinsic") { - listener, err := c.initExtrinsicWatch(reqid, params) //nolint - if err != nil { - logger.Warn("failed to create listener", "method", method, "error", err) - c.safeSendError(reqid, nil, err.Error()) + if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) { + c.safeSend(newBooleanResponseJSON(false, reqid)) continue } - - listener.Listen() - continue } - // handle non-subscribe calls - request, err := c.prepareRequest(mbytes) + err = listener.Stop() if err != nil { - logger.Warn("failed while preparing the request", "error", err) - return + logger.Warn("failed to cancel listener goroutine", "method", method, "error", err) + c.safeSend(newBooleanResponseJSON(false, reqid)) } - var wsresponse interface{} - err = c.executeRequest(request, &wsresponse) - if err != nil { - logger.Warn("problems while executing the request", "error", err) - return - } + c.safeSend(newBooleanResponseJSON(true, reqid)) + continue + } +} - c.safeSend(wsresponse) +func (c *WSConn) executeRPCCall(data []byte) { + request, err := c.prepareRequest(data) + if err != nil { + logger.Warn("failed while preparing the request", "error", err) + return + } + + var wsresponse interface{} + err = c.executeRequest(request, &wsresponse) + if err != nil { + logger.Warn("problems while executing the request", "error", err) + return } + + c.safeSend(wsresponse) } func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (Listener, error) { @@ -220,18 +212,6 @@ func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (L return stgobs, nil } -func (c *WSConn) unsubscribeStorageListener(reqID float64, l Listener, _ interface{}) { - observer, ok := l.(state.Observer) - if !ok { - initRes := newBooleanResponseJSON(false, reqID) - c.safeSend(initRes) - return - } - - c.StorageAPI.UnregisterStorageObserver(observer) - c.safeSend(newBooleanResponseJSON(true, reqID)) -} - func (c *WSConn) initBlockListener(reqID float64, _ interface{}) (Listener, error) { bl := &BlockListener{ Channel: make(chan *types.Block, DEFAULT_BUFFER_SIZE), @@ -338,6 +318,7 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener err = c.CoreAPI.HandleSubmittedExtrinsic(extBytes) if err != nil { + c.safeSendError(reqID, nil, err.Error()) return nil, err } c.safeSend(NewSubscriptionResponseJSON(esl.subID, reqID)) @@ -381,19 +362,6 @@ func (c *WSConn) initRuntimeVersionListener(reqID float64, _ interface{}) (Liste return rvl, nil } -func (c *WSConn) unsubscribeRuntimeVersionListener(reqID float64, l Listener, _ interface{}) { - observer, ok := l.(VersionListener) - if !ok { - initRes := newBooleanResponseJSON(false, reqID) - c.safeSend(initRes) - return - } - id := observer.GetChannelID() - - res := c.BlockAPI.UnregisterRuntimeUpdatedChannel(id) - c.safeSend(newBooleanResponseJSON(res, reqID)) -} - func (c *WSConn) initGrandpaJustificationListener(reqID float64, _ interface{}) (Listener, error) { if c.BlockAPI == nil { c.safeSendError(reqID, nil, "error BlockAPI not set") @@ -426,17 +394,6 @@ func (c *WSConn) initGrandpaJustificationListener(reqID float64, _ interface{}) return jl, nil } -func (c *WSConn) unsubscribeGrandpaJustificationListener(reqID float64, l Listener, params interface{}) { - listener, ok := l.(*GrandpaJustificationListener) - if !ok { - c.safeSend(newBooleanResponseJSON(false, reqID)) - return - } - - c.BlockAPI.UnregisterFinalisedChannel(listener.finalisedChID) - c.safeSend(newBooleanResponseJSON(true, reqID)) -} - func (c *WSConn) safeSend(msg interface{}) { c.mu.Lock() defer c.mu.Unlock() diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index f3475376c2..74d3c88796 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -25,8 +25,6 @@ func TestWSConn_HandleComm(t *testing.T) { go wsconn.HandleComm() time.Sleep(time.Second * 2) - fmt.Println("ws defined") - // test storageChangeListener res, err := wsconn.initStorageChangeListener(1, nil) require.Nil(t, res) diff --git a/dot/rpc/websocket_test.go b/dot/rpc/websocket_test.go index 8520b6ad8e..8d378552b0 100644 --- a/dot/rpc/websocket_test.go +++ b/dot/rpc/websocket_test.go @@ -55,17 +55,17 @@ func TestHTTPServer_ServeHTTP(t *testing.T) { bAPI := modules.NewMockBlockAPI() sAPI := modules.NewMockStorageAPI() cfg := &HTTPServerConfig{ - Modules: []string{"system", "chain"}, - External: false, - RPCPort: 8545, - WSPort: 8546, - WS: true, - WSExternal: false, - RPCAPI: NewService(), - CoreAPI: coreAPI, - SystemAPI: sysAPI, - BlockAPI: bAPI, - StorageAPI: sAPI, + Modules: []string{"system", "chain"}, + RPCExternal: false, + RPCPort: 8545, + WSPort: 8546, + WS: true, + WSExternal: false, + RPCAPI: NewService(), + CoreAPI: coreAPI, + SystemAPI: sysAPI, + BlockAPI: bAPI, + StorageAPI: sAPI, } s := NewHTTPServer(cfg) diff --git a/dot/services.go b/dot/services.go index a74a85db8c..2df030659b 100644 --- a/dot/services.go +++ b/dot/services.go @@ -343,11 +343,16 @@ func createRPCService(cfg *Config, stateSrvc *state.Service, coreSrvc *core.Serv TransactionQueueAPI: stateSrvc.Transaction, RPCAPI: rpcService, SystemAPI: sysSrvc, - External: cfg.RPC.External, + RPC: cfg.RPC.Enabled, + RPCExternal: cfg.RPC.External, + RPCUnsafe: cfg.RPC.Unsafe, + RPCUnsafeExternal: cfg.RPC.UnsafeExternal, Host: cfg.RPC.Host, RPCPort: cfg.RPC.Port, WS: cfg.RPC.WS, WSExternal: cfg.RPC.WSExternal, + WSUnsafe: cfg.RPC.WSUnsafe, + WSUnsafeExternal: cfg.RPC.WSUnsafeExternal, WSPort: cfg.RPC.WSPort, Modules: cfg.RPC.Modules, } diff --git a/tests/utils/gossamer_utils.go b/tests/utils/gossamer_utils.go index c85e4476c7..72bd2a6eb0 100644 --- a/tests/utils/gossamer_utils.go +++ b/tests/utils/gossamer_utils.go @@ -449,10 +449,12 @@ func generateDefaultConfig() *ctoml.Config { MaxPeers: 3, }, RPC: ctoml.RPCConfig{ - Enabled: false, - Host: "localhost", - Modules: []string{"system", "author", "chain", "state"}, - WS: false, + Enabled: false, + Unsafe: true, + WSUnsafe: true, + Host: "localhost", + Modules: []string{"system", "author", "chain", "state"}, + WS: false, }, } }