diff --git a/dot/core/service.go b/dot/core/service.go index f4db0e7fdb..f59f3bef92 100644 --- a/dot/core/service.go +++ b/dot/core/service.go @@ -461,6 +461,16 @@ func (s *Service) HasKey(pubKeyStr, keyType string) (bool, error) { return keystore.HasKey(pubKeyStr, keyType, s.keys.Acco) } +// DecodeSessionKeys executes the runtime DecodeSessionKeys and return the scale encoded keys +func (s *Service) DecodeSessionKeys(enc []byte) ([]byte, error) { + rt, err := s.blockState.GetRuntime(nil) + if err != nil { + return nil, err + } + + return rt.DecodeSessionKeys(enc) +} + // GetRuntimeVersion gets the current RuntimeVersion func (s *Service) GetRuntimeVersion(bhash *common.Hash) (runtime.Version, error) { var stateRootHash *common.Hash diff --git a/dot/rpc/http.go b/dot/rpc/http.go index 67649e306f..4bad9a0bd6 100644 --- a/dot/rpc/http.go +++ b/dot/rpc/http.go @@ -233,15 +233,13 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // NewWSConn to create new WebSocket Connection struct func NewWSConn(conn *websocket.Conn, cfg *HTTPServerConfig) *subscription.WSConn { c := &subscription.WSConn{ - Wsconn: conn, - Subscriptions: make(map[uint]subscription.Listener), - BlockSubChannels: make(map[uint]byte), - StorageSubChannels: make(map[int]byte), - StorageAPI: cfg.StorageAPI, - BlockAPI: cfg.BlockAPI, - CoreAPI: cfg.CoreAPI, - TxStateAPI: cfg.TransactionQueueAPI, - RPCHost: fmt.Sprintf("http://%s:%d/", cfg.Host, cfg.RPCPort), + Wsconn: conn, + Subscriptions: make(map[uint32]subscription.Listener), + StorageAPI: cfg.StorageAPI, + BlockAPI: cfg.BlockAPI, + CoreAPI: cfg.CoreAPI, + TxStateAPI: cfg.TransactionQueueAPI, + RPCHost: fmt.Sprintf("http://%s:%d/", cfg.Host, cfg.RPCPort), HTTP: &http.Client{ Timeout: time.Second * 30, }, diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index 5dc1e3328c..beb9cf303c 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -76,6 +76,7 @@ type CoreAPI interface { GetRuntimeVersion(bhash *common.Hash) (runtime.Version, error) HandleSubmittedExtrinsic(types.Extrinsic) error GetMetadata(bhash *common.Hash) ([]byte, error) + DecodeSessionKeys(enc []byte) ([]byte, error) } // RPCAPI is the interface for methods related to RPC service diff --git a/dot/rpc/modules/author.go b/dot/rpc/modules/author.go index c0dd36079d..51394a2bab 100644 --- a/dot/rpc/modules/author.go +++ b/dot/rpc/modules/author.go @@ -17,13 +17,14 @@ package modules import ( - "fmt" + "errors" "net/http" - "reflect" + "strings" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/keystore" + "github.com/ChainSafe/gossamer/pkg/scale" log "github.com/ChainSafe/log15" ) @@ -35,8 +36,17 @@ type AuthorModule struct { txStateAPI TransactionStateAPI } +// HasSessionKeyRequest is used to receive the rpc data +type HasSessionKeyRequest struct { + PublicKeys string +} + // KeyInsertRequest is used as model for the JSON -type KeyInsertRequest []string +type KeyInsertRequest struct { + Type string + Seed string + PublicKey string +} // Extrinsic represents a hex-encoded extrinsic type Extrinsic struct { @@ -64,6 +74,18 @@ type RemoveExtrinsicsResponse []common.Hash // KeyRotateResponse is a byte array used to rotate type KeyRotateResponse []byte +// HasSessionKeyResponse is the response to the RPC call author_hasSessionKeys +type HasSessionKeyResponse bool + +// KeyTypeID represents the key type of a session key +type keyTypeID [4]uint8 + +// DecodedKey is the representation of a scaled decoded public key +type decodedKey struct { + Data []uint8 + Type keyTypeID +} + // ExtrinsicStatus holds the actual valid statuses type ExtrinsicStatus struct { IsFuture bool @@ -94,27 +116,66 @@ func NewAuthorModule(logger log.Logger, coreAPI CoreAPI, txStateAPI TransactionS } } -// InsertKey inserts a key into the keystore -func (am *AuthorModule) InsertKey(r *http.Request, req *KeyInsertRequest, res *KeyInsertResponse) error { - keyReq := *req +// HasSessionKeys checks if the keystore has private keys for the given session public keys. +func (am *AuthorModule) HasSessionKeys(r *http.Request, req *HasSessionKeyRequest, res *HasSessionKeyResponse) error { + pubKeysBytes, err := common.HexToBytes(req.PublicKeys) + if err != nil { + return err + } + + pkeys, err := scale.Marshal(pubKeysBytes) + if err != nil { + return err + } - pkDec, err := common.HexToBytes(keyReq[1]) + data, err := am.coreAPI.DecodeSessionKeys(pkeys) if err != nil { + *res = false return err } - privateKey, err := keystore.DecodePrivateKey(pkDec, keystore.DetermineKeyType(keyReq[0])) + var decodedKeys *[]decodedKey + err = scale.Unmarshal(data, &decodedKeys) + if err != nil { + return err + } + + if decodedKeys == nil || len(*decodedKeys) < 1 { + *res = false + return nil + } + + for _, key := range *decodedKeys { + encType := keystore.Name(key.Type[:]) + ok, err := am.coreAPI.HasKey(common.BytesToHex(key.Data), string(encType)) + + if err != nil || !ok { + *res = false + return err + } + } + + *res = true + return nil +} + +// InsertKey inserts a key into the keystore +func (am *AuthorModule) InsertKey(r *http.Request, req *KeyInsertRequest, res *KeyInsertResponse) error { + keyReq := *req + + keyBytes, err := common.HexToBytes(req.Seed) if err != nil { return err } - keyPair, err := keystore.PrivateKeyToKeypair(privateKey) + keyPair, err := keystore.DecodeKeyPairFromHex(keyBytes, keystore.DetermineKeyType(keyReq.Type)) if err != nil { return err } - if !reflect.DeepEqual(keyPair.Public().Hex(), keyReq[2]) { - return fmt.Errorf("generated public key does not equal provide public key") + //strings.EqualFold compare using case-insensitivity. + if !strings.EqualFold(keyPair.Public().Hex(), keyReq.PublicKey) { + return errors.New("generated public key does not equal provide public key") } am.coreAPI.InsertKey(keyPair) diff --git a/dot/rpc/modules/author_test.go b/dot/rpc/modules/author_test.go index 4de36c2b33..fb79cf955f 100644 --- a/dot/rpc/modules/author_test.go +++ b/dot/rpc/modules/author_test.go @@ -8,13 +8,78 @@ import ( apimocks "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/crypto/sr25519" "github.com/ChainSafe/gossamer/lib/keystore" + "github.com/ChainSafe/gossamer/lib/runtime" + "github.com/ChainSafe/gossamer/lib/runtime/wasmer" "github.com/ChainSafe/gossamer/lib/transaction" log "github.com/ChainSafe/log15" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) +func TestAuthorModule_HasSessionKey(t *testing.T) { + globalStore := keystore.NewGlobalKeystore() + + coremockapi := new(apimocks.MockCoreAPI) + mockInsertKey := coremockapi.On("InsertKey", mock.AnythingOfType("*sr25519.Keypair")) + mockInsertKey.Run(func(args mock.Arguments) { + kp := args.Get(0).(*sr25519.Keypair) + globalStore.Acco.Insert(kp) + }) + + mockHasKey := coremockapi.On("HasKey", mock.AnythingOfType("string"), mock.AnythingOfType("string")) + mockHasKey.Run(func(args mock.Arguments) { + pubKeyHex := args.Get(0).(string) + keyType := args.Get(1).(string) + + ok, err := keystore.HasKey(pubKeyHex, keyType, globalStore.Acco) + mockHasKey.ReturnArguments = []interface{}{ok, err} + }) + + keys := "0xd43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d34309a9d2a24213896ff06895db16aade8b6502f3a71cf56374cc3852042602634309a9d2a24213896ff06895db16aade8b6502f3a71cf56374cc3852042602634309a9d2a24213896ff06895db16aade8b6502f3a71cf56374cc38520426026" + runtimeInstance := wasmer.NewTestInstance(t, runtime.NODE_RUNTIME) + + decodeSessionKeysMock := coremockapi.On("DecodeSessionKeys", mock.AnythingOfType("[]uint8")) + decodeSessionKeysMock.Run(func(args mock.Arguments) { + b := args.Get(0).([]byte) + dec, err := runtimeInstance.DecodeSessionKeys(b) + decodeSessionKeysMock.ReturnArguments = []interface{}{dec, err} + }) + + module := &AuthorModule{ + coreAPI: coremockapi, + logger: log.New("service", "RPC", "module", "author"), + } + + req := &HasSessionKeyRequest{ + PublicKeys: keys, + } + + err := module.InsertKey(nil, &KeyInsertRequest{ + Type: "babe", + Seed: "0xfec0f475b818470af5caf1f3c1b1558729961161946d581d2755f9fb566534f8", + PublicKey: "0x34309a9d2a24213896ff06895db16aade8b6502f3a71cf56374cc38520426026", + }, nil) + coremockapi.AssertCalled(t, "InsertKey", mock.AnythingOfType("*sr25519.Keypair")) + require.NoError(t, err) + require.Equal(t, 1, globalStore.Acco.Size()) + + err = module.InsertKey(nil, &KeyInsertRequest{ + Type: "babe", + Seed: "0xe5be9a5092b81bca64be81d212e7f2f9eba183bb7a90954f7b76361f6edb5c0a", + PublicKey: "0xd43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d", + }, nil) + require.NoError(t, err) + require.Equal(t, 2, globalStore.Acco.Size()) + + var res HasSessionKeyResponse + err = module.HasSessionKeys(nil, req, &res) + require.NoError(t, err) + require.True(t, bool(res)) +} + func TestAuthorModule_SubmitExtrinsic(t *testing.T) { errMockCoreAPI := &apimocks.MockCoreAPI{} errMockCoreAPI.On("HandleSubmittedExtrinsic", mock.AnythingOfType("types.Extrinsic")).Return(fmt.Errorf("some error")) @@ -202,8 +267,8 @@ func TestAuthorModule_InsertKey(t *testing.T) { args: args{ req: &KeyInsertRequest{ "babe", - "0xb7e9185065667390d2ad952a5324e8c365c9bf503dcf97c67a5ce861afe97309", "0x6246ddf254e0b4b4e7dffefc8adf69d212b98ac2b579c362b473fec8c40b4c0a", + "0xdad5131003242c37c227f744f82118dd59a24b949ae264a93d949100738c196c", }, }, }, @@ -214,9 +279,10 @@ func TestAuthorModule_InsertKey(t *testing.T) { coreAPI: mockCoreAPI, }, args: args{ - req: &KeyInsertRequest{"gran", - "0xb7e9185065667390d2ad952a5324e8c365c9bf503dcf97c67a5ce861afe97309b7e9185065667390d2ad952a5324e8c365c9bf503dcf97c67a5ce861afe97309", - "0xb7e9185065667390d2ad952a5324e8c365c9bf503dcf97c67a5ce861afe97309", + req: &KeyInsertRequest{ + "gran", + "0xb48004c6e1625282313b07d1c9950935e86894a2e4f21fb1ffee9854d180c781", + "0xa7d6507d59f8871b8f1a0f2c32e219adfacff4c9fcb05b0b2d8ebd6a65c88ee6", }, }, }, diff --git a/dot/rpc/modules/mocks/core_api.go b/dot/rpc/modules/mocks/core_api.go index 380de34353..beda95a749 100644 --- a/dot/rpc/modules/mocks/core_api.go +++ b/dot/rpc/modules/mocks/core_api.go @@ -18,6 +18,29 @@ type MockCoreAPI struct { mock.Mock } +// DecodeSessionKeys provides a mock function with given fields: enc +func (_m *MockCoreAPI) DecodeSessionKeys(enc []byte) ([]byte, error) { + ret := _m.Called(enc) + + var r0 []byte + if rf, ok := ret.Get(0).(func([]byte) []byte); ok { + r0 = rf(enc) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(enc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetMetadata provides a mock function with given fields: bhash func (_m *MockCoreAPI) GetMetadata(bhash *common.Hash) ([]byte, error) { ret := _m.Called(bhash) @@ -103,17 +126,3 @@ func (_m *MockCoreAPI) HasKey(pubKeyStr string, keyType string) (bool, error) { func (_m *MockCoreAPI) InsertKey(kp crypto.Keypair) { _m.Called(kp) } - -// IsBlockProducer provides a mock function with given fields: -func (_m *MockCoreAPI) IsBlockProducer() bool { - ret := _m.Called() - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} diff --git a/dot/rpc/service_test.go b/dot/rpc/service_test.go index 64e95f4ee8..1661028388 100644 --- a/dot/rpc/service_test.go +++ b/dot/rpc/service_test.go @@ -35,7 +35,7 @@ func TestNewService(t *testing.T) { func TestService_Methods(t *testing.T) { qtySystemMethods := 13 qtyRPCMethods := 1 - qtyAuthorMethods := 7 + qtyAuthorMethods := 8 rpcService := NewService() sysMod := modules.NewSystemModule(nil, nil, nil, nil, nil, nil) diff --git a/dot/rpc/subscription/listeners.go b/dot/rpc/subscription/listeners.go index ba9e4ca7b6..7845624fa6 100644 --- a/dot/rpc/subscription/listeners.go +++ b/dot/rpc/subscription/listeners.go @@ -16,9 +16,11 @@ package subscription import ( - "context" + "errors" "fmt" + "math/big" "reflect" + "time" "github.com/ChainSafe/gossamer/dot/rpc/modules" "github.com/ChainSafe/gossamer/dot/state" @@ -26,10 +28,26 @@ import ( "github.com/ChainSafe/gossamer/lib/common" ) +const ( + grandpaJustificationsMethod = "grandpa_justifications" + stateRuntimeVersionMethod = "state_runtimeVersion" + authorExtrinsicUpdatesMethod = "author_extrinsicUpdate" + chainFinalizedHeadMethod = "chain_finalizedHead" + chainNewHeadMethod = "chain_newHead" + stateStorageMethod = "state_storage" +) + +var ( + // ErrCannotCancel when is not possible to cancel a goroutine after `cancelTimeout` seconds + ErrCannotCancel = errors.New("cannot cancel listen goroutines") + + defaultCancelTimeout = time.Second * 10 +) + // Listener interface for functions that define Listener related functions type Listener interface { Listen() - Stop() + Stop() error } // WSConnAPI interface defining methors a WSConn should have @@ -39,7 +57,7 @@ type WSConnAPI interface { // StorageObserver struct to hold data for observer (Observer Design Pattern) type StorageObserver struct { - id uint + id uint32 filter map[string][]byte wsconn WSConnAPI } @@ -68,15 +86,15 @@ func (s *StorageObserver) Update(change *state.SubscriptionResult) { } res := newSubcriptionBaseResponseJSON() - res.Method = "state_storage" + res.Method = stateStorageMethod res.Params.Result = changeResult - res.Params.SubscriptionID = s.GetID() + res.Params.SubscriptionID = s.id s.wsconn.safeSend(res) } // GetID the id for the Observer func (s *StorageObserver) GetID() uint { - return s.id + return uint(s.id) } // GetFilter returns the filter the Observer is using @@ -88,26 +106,30 @@ func (s *StorageObserver) GetFilter() map[string][]byte { func (s *StorageObserver) Listen() {} // Stop to satisfy Listener interface (but is no longer used by StorageObserver) -func (s *StorageObserver) Stop() {} +func (s *StorageObserver) Stop() error { return nil } // BlockListener to handle listening for blocks importedChan type BlockListener struct { - Channel chan *types.Block - wsconn WSConnAPI - ChanID byte - subID uint - - ctx context.Context - cancel context.CancelFunc + Channel chan *types.Block + wsconn *WSConn + ChanID byte + subID uint32 + done chan struct{} + cancel chan struct{} + cancelTimeout time.Duration } // Listen implementation of Listen interface to listen for importedChan changes func (l *BlockListener) Listen() { - l.ctx, l.cancel = context.WithCancel(context.Background()) go func() { + defer func() { + l.wsconn.BlockAPI.UnregisterImportedChannel(l.ChanID) + close(l.done) + }() + for { select { - case <-l.ctx.Done(): + case <-l.cancel: return case block, ok := <-l.Channel: if !ok { @@ -123,7 +145,7 @@ func (l *BlockListener) Listen() { } res := newSubcriptionBaseResponseJSON() - res.Method = "chain_newHead" + res.Method = chainNewHeadMethod res.Params.Result = head res.Params.SubscriptionID = l.subID l.wsconn.safeSend(res) @@ -133,26 +155,32 @@ func (l *BlockListener) Listen() { } // Stop to cancel the running goroutines to this listener -func (l *BlockListener) Stop() { l.cancel() } +func (l *BlockListener) Stop() error { + return cancelWithTimeout(l.cancel, l.done, l.cancelTimeout) +} // BlockFinalizedListener to handle listening for finalised blocks type BlockFinalizedListener struct { - channel chan *types.FinalisationInfo - wsconn WSConnAPI - chanID byte - subID uint - ctx context.Context - cancel context.CancelFunc + channel chan *types.FinalisationInfo + wsconn *WSConn + chanID byte + subID uint32 + done chan struct{} + cancel chan struct{} + cancelTimeout time.Duration } // Listen implementation of Listen interface to listen for importedChan changes func (l *BlockFinalizedListener) Listen() { - l.ctx, l.cancel = context.WithCancel(context.Background()) - go func() { + defer func() { + l.wsconn.BlockAPI.UnregisterFinalisedChannel(l.chanID) + close(l.done) + }() + for { select { - case <-l.ctx.Done(): + case <-l.cancel: return case info, ok := <-l.channel: if !ok { @@ -167,7 +195,7 @@ func (l *BlockFinalizedListener) Listen() { logger.Error("failed to convert header to JSON", "error", err) } res := newSubcriptionBaseResponseJSON() - res.Method = "chain_finalizedHead" + res.Method = chainFinalizedHeadMethod res.Params.Result = head res.Params.SubscriptionID = l.subID l.wsconn.safeSend(res) @@ -177,36 +205,39 @@ func (l *BlockFinalizedListener) Listen() { } // Stop to cancel the running goroutines to this listener -func (l *BlockFinalizedListener) Stop() { l.cancel() } +func (l *BlockFinalizedListener) Stop() error { + return cancelWithTimeout(l.cancel, l.done, l.cancelTimeout) +} // ExtrinsicSubmitListener to handle listening for extrinsic events type ExtrinsicSubmitListener struct { - wsconn WSConnAPI - subID uint - extrinsic types.Extrinsic - + wsconn *WSConn + subID uint32 + extrinsic types.Extrinsic importedChan chan *types.Block importedChanID byte importedHash common.Hash finalisedChan chan *types.FinalisationInfo finalisedChanID byte - - ctx context.Context - cancel context.CancelFunc + done chan struct{} + cancel chan struct{} + cancelTimeout time.Duration } -// AuthorExtrinsicUpdates method name -const AuthorExtrinsicUpdates = "author_extrinsicUpdate" - // Listen implementation of Listen interface to listen for importedChan changes func (l *ExtrinsicSubmitListener) Listen() { - l.ctx, l.cancel = context.WithCancel(context.Background()) // listen for imported blocks with extrinsic go func() { + defer func() { + l.wsconn.BlockAPI.UnregisterImportedChannel(l.importedChanID) + l.wsconn.BlockAPI.UnregisterFinalisedChannel(l.finalisedChanID) + close(l.done) + }() + for { select { - case <-l.ctx.Done(): + case <-l.cancel: return case block, ok := <-l.importedChan: if !ok { @@ -226,18 +257,9 @@ func (l *ExtrinsicSubmitListener) Listen() { resM["inBlock"] = block.Header.Hash().String() l.importedHash = block.Header.Hash() - l.wsconn.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, l.subID, resM)) + l.wsconn.safeSend(newSubscriptionResponse(authorExtrinsicUpdatesMethod, l.subID, resM)) } - } - } - }() - // listen for finalised headers - go func() { - for { - select { - case <-l.ctx.Done(): - return case info, ok := <-l.finalisedChan: if !ok { return @@ -246,7 +268,7 @@ func (l *ExtrinsicSubmitListener) Listen() { if reflect.DeepEqual(l.importedHash, info.Header.Hash()) { resM := make(map[string]interface{}) resM["finalised"] = info.Header.Hash().String() - l.wsconn.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, l.subID, resM)) + l.wsconn.safeSend(newSubscriptionResponse(authorExtrinsicUpdatesMethod, l.subID, resM)) } } } @@ -254,12 +276,14 @@ func (l *ExtrinsicSubmitListener) Listen() { } // Stop to cancel the running goroutines to this listener -func (l *ExtrinsicSubmitListener) Stop() { l.cancel() } +func (l *ExtrinsicSubmitListener) Stop() error { + return cancelWithTimeout(l.cancel, l.done, l.cancelTimeout) +} // RuntimeVersionListener to handle listening for Runtime Version type RuntimeVersionListener struct { wsconn *WSConn - subID uint + subID uint32 } // Listen implementation of Listen interface to listen for runtime version changes @@ -280,9 +304,70 @@ func (l *RuntimeVersionListener) Listen() { ver.TransactionVersion = rtVersion.TransactionVersion() ver.Apis = modules.ConvertAPIs(rtVersion.APIItems()) - l.wsconn.safeSend(newSubscriptionResponse("state_runtimeVersion", l.subID, ver)) + l.wsconn.safeSend(newSubscriptionResponse(stateRuntimeVersionMethod, l.subID, ver)) } // Stop to runtimeVersionListener not implemented yet because the listener // does not need to be stoped -func (l *RuntimeVersionListener) Stop() {} +func (l *RuntimeVersionListener) Stop() error { return nil } + +// GrandpaJustificationListener struct has the finalisedCh and the context to stop the goroutines +type GrandpaJustificationListener struct { + cancel chan struct{} + cancelTimeout time.Duration + done chan struct{} + wsconn *WSConn + subID uint32 + finalisedChID byte + finalisedCh chan *types.FinalisationInfo +} + +// Listen will start goroutines that listen to the finaised blocks +func (g *GrandpaJustificationListener) Listen() { + // listen for finalised headers + go func() { + defer func() { + g.wsconn.BlockAPI.UnregisterFinalisedChannel(g.finalisedChID) + close(g.done) + }() + + for { + select { + case <-g.cancel: + return + + case info, ok := <-g.finalisedCh: + if !ok { + return + } + + just, err := g.wsconn.BlockAPI.GetJustification(info.Header.Hash()) + if err != nil { + g.wsconn.safeSendError(float64(g.subID), big.NewInt(InvalidRequestCode), + fmt.Sprintf("error while retrieve justification: %v", err)) + } + + g.wsconn.safeSend(newSubscriptionResponse(grandpaJustificationsMethod, g.subID, common.BytesToHex(just))) + } + } + }() +} + +// Stop will cancel all the goroutines that are executing +func (g *GrandpaJustificationListener) Stop() error { + return cancelWithTimeout(g.cancel, g.done, g.cancelTimeout) +} + +func cancelWithTimeout(cancel, done chan struct{}, t time.Duration) error { + close(cancel) + + timeout := time.NewTimer(t) + defer timeout.Stop() + + select { + case <-done: + return nil + case <-timeout.C: + return ErrCannotCancel + } +} diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index 3c6375259b..60c53468f4 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -17,14 +17,24 @@ package subscription import ( + "encoding/json" + "fmt" + "log" "math/big" + "net/http" + "net/http/httptest" + "strings" "testing" "time" "github.com/ChainSafe/gossamer/dot/rpc/modules" + "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/grandpa" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -60,76 +70,132 @@ func TestStorageObserver_Update(t *testing.T) { expected.Changes[i] = Change{common.BytesToHex(v.Key), common.BytesToHex(v.Value)} } - expectedRespones := newSubcriptionBaseResponseJSON() - expectedRespones.Method = "state_storage" - expectedRespones.Params.Result = expected + expectedResponse := newSubcriptionBaseResponseJSON() + expectedResponse.Method = stateStorageMethod + expectedResponse.Params.Result = expected storageObserver.Update(change) time.Sleep(time.Millisecond * 10) - require.Equal(t, expectedRespones, mockConnection.lastMessage) + require.Equal(t, expectedResponse, mockConnection.lastMessage) } func TestBlockListener_Listen(t *testing.T) { + wsconn, ws, cancel := setupWSConn(t) + defer cancel() + + mockBlockAPI := new(mocks.MockBlockAPI) + mockBlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) + + wsconn.BlockAPI = mockBlockAPI + notifyChan := make(chan *types.Block) - mockConnection := &MockWSConnAPI{} bl := BlockListener{ - Channel: notifyChan, - wsconn: mockConnection, + Channel: notifyChan, + wsconn: wsconn, + cancel: make(chan struct{}), + done: make(chan struct{}), + cancelTimeout: time.Second * 5, } block := types.NewEmptyBlock() block.Header.Number = big.NewInt(1) + go bl.Listen() + defer func() { + require.NoError(t, bl.Stop()) + time.Sleep(time.Millisecond * 10) + mockBlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) + }() + + notifyChan <- block + time.Sleep(time.Second * 2) + + _, msg, err := ws.ReadMessage() + require.NoError(t, err) + head, err := modules.HeaderToJSON(*block.Header) require.NoError(t, err) expectedResposnse := newSubcriptionBaseResponseJSON() - expectedResposnse.Method = "chain_newHead" + expectedResposnse.Method = chainNewHeadMethod expectedResposnse.Params.Result = head - go bl.Listen() + expectedResponseBytes, err := json.Marshal(expectedResposnse) + require.NoError(t, err) - notifyChan <- block - time.Sleep(time.Millisecond * 10) - require.Equal(t, expectedResposnse, mockConnection.lastMessage) + require.Equal(t, string(expectedResponseBytes)+"\n", string(msg)) } func TestBlockFinalizedListener_Listen(t *testing.T) { + wsconn, ws, cancel := setupWSConn(t) + defer cancel() + + mockBlockAPI := new(mocks.MockBlockAPI) + mockBlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + + wsconn.BlockAPI = mockBlockAPI + notifyChan := make(chan *types.FinalisationInfo) - mockConnection := &MockWSConnAPI{} bfl := BlockFinalizedListener{ - channel: notifyChan, - wsconn: mockConnection, + channel: notifyChan, + wsconn: wsconn, + cancel: make(chan struct{}), + done: make(chan struct{}), + cancelTimeout: time.Second * 5, } header := types.NewEmptyHeader() + + bfl.Listen() + defer func() { + require.NoError(t, bfl.Stop()) + time.Sleep(time.Millisecond * 10) + mockBlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + }() + + notifyChan <- &types.FinalisationInfo{ + Header: header, + } + time.Sleep(time.Second * 2) + + _, msg, err := ws.ReadMessage() + require.NoError(t, err) + head, err := modules.HeaderToJSON(*header) if err != nil { logger.Error("failed to convert header to JSON", "error", err) } expectedResponse := newSubcriptionBaseResponseJSON() - expectedResponse.Method = "chain_finalizedHead" + expectedResponse.Method = chainFinalizedHeadMethod expectedResponse.Params.Result = head - go bfl.Listen() + expectedResponseBytes, err := json.Marshal(expectedResponse) + require.NoError(t, err) - notifyChan <- &types.FinalisationInfo{ - Header: header, - } - time.Sleep(time.Millisecond * 10) - require.Equal(t, expectedResponse, mockConnection.lastMessage) + require.Equal(t, string(expectedResponseBytes)+"\n", string(msg)) } func TestExtrinsicSubmitListener_Listen(t *testing.T) { + wsconn, ws, cancel := setupWSConn(t) + defer cancel() + notifyImportedChan := make(chan *types.Block, 100) notifyFinalizedChan := make(chan *types.FinalisationInfo, 100) - mockConnection := &MockWSConnAPI{} + mockBlockAPI := new(mocks.MockBlockAPI) + mockBlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) + mockBlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + + wsconn.BlockAPI = mockBlockAPI + esl := ExtrinsicSubmitListener{ importedChan: notifyImportedChan, finalisedChan: notifyFinalizedChan, - wsconn: mockConnection, + wsconn: wsconn, extrinsic: types.Extrinsic{1, 2, 3}, + cancel: make(chan struct{}), + done: make(chan struct{}), + cancelTimeout: time.Second * 5, } header := types.NewEmptyHeader() exts := []types.Extrinsic{{1, 2, 3}, {7, 8, 9, 0}, {0xa, 0xb}} @@ -142,20 +208,122 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { Body: body, } - resImported := map[string]interface{}{"inBlock": block.Header.Hash().String()} - expectedImportedRespones := newSubscriptionResponse(AuthorExtrinsicUpdates, esl.subID, resImported) + esl.Listen() + defer func() { + require.NoError(t, esl.Stop()) + time.Sleep(time.Millisecond * 10) - go esl.Listen() + mockBlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) + mockBlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + }() notifyImportedChan <- block - time.Sleep(time.Millisecond * 10) - require.Equal(t, expectedImportedRespones, mockConnection.lastMessage) + time.Sleep(time.Second * 2) + + _, msg, err := ws.ReadMessage() + require.NoError(t, err) + resImported := map[string]interface{}{"inBlock": block.Header.Hash().String()} + expectedImportedBytes, err := json.Marshal(newSubscriptionResponse(authorExtrinsicUpdatesMethod, esl.subID, resImported)) + require.NoError(t, err) + require.Equal(t, string(expectedImportedBytes)+"\n", string(msg)) notifyFinalizedChan <- &types.FinalisationInfo{ Header: header, } - time.Sleep(time.Millisecond * 10) + time.Sleep(time.Second * 2) + + _, msg, err = ws.ReadMessage() + require.NoError(t, err) resFinalised := map[string]interface{}{"finalised": block.Header.Hash().String()} - expectedFinalizedRespones := newSubscriptionResponse(AuthorExtrinsicUpdates, esl.subID, resFinalised) - require.Equal(t, expectedFinalizedRespones, mockConnection.lastMessage) + expectedFinalizedBytes, err := json.Marshal(newSubscriptionResponse(authorExtrinsicUpdatesMethod, esl.subID, resFinalised)) + require.NoError(t, err) + require.Equal(t, string(expectedFinalizedBytes)+"\n", string(msg)) +} + +func TestGrandpaJustification_Listen(t *testing.T) { + t.Run("When justification doesnt returns error", func(t *testing.T) { + wsconn, ws, cancel := setupWSConn(t) + defer cancel() + + mockedJust := grandpa.Justification{ + Round: 1, + Commit: &grandpa.Commit{ + Hash: common.Hash{}, + Number: 1, + Precommits: nil, + }, + } + + mockedJustBytes, err := mockedJust.Encode() + require.NoError(t, err) + + blockStateMock := new(mocks.MockBlockAPI) + blockStateMock.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) + blockStateMock.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + wsconn.BlockAPI = blockStateMock + + finchannel := make(chan *types.FinalisationInfo) + sub := GrandpaJustificationListener{ + subID: 10, + wsconn: wsconn, + cancel: make(chan struct{}, 1), + done: make(chan struct{}, 1), + finalisedCh: finchannel, + cancelTimeout: time.Second * 5, + } + + sub.Listen() + finchannel <- &types.FinalisationInfo{ + Header: types.NewEmptyHeader(), + } + + time.Sleep(time.Second * 3) + + _, msg, err := ws.ReadMessage() + require.NoError(t, err) + + expected := `{"jsonrpc":"2.0","method":"grandpa_justifications","params":{"result":"%s","subscription":10}}` + "\n" + expected = fmt.Sprintf(expected, common.BytesToHex(mockedJustBytes)) + + require.Equal(t, string(msg), expected) + require.NoError(t, sub.Stop()) + wsconn.Wsconn.Close() + }) + +} + +func setupWSConn(t *testing.T) (*WSConn, *websocket.Conn, func()) { + t.Helper() + + wskt := new(WSConn) + var up = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + + h := func(w http.ResponseWriter, r *http.Request) { + c, err := up.Upgrade(w, r, nil) + if err != nil { + log.Print("error while setup handler:", err) + return + } + + wskt.Wsconn = c + } + + server := httptest.NewServer(http.HandlerFunc(h)) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + ws, r, err := websocket.DefaultDialer.Dial(wsURL, nil) + defer r.Body.Close() + + require.NoError(t, err) + + cancel := func() { + server.Close() + ws.Close() + wskt.Wsconn.Close() + } + + return wskt, ws, cancel } diff --git a/dot/rpc/subscription/messages.go b/dot/rpc/subscription/messages.go index 1e37f3f265..3019f77337 100644 --- a/dot/rpc/subscription/messages.go +++ b/dot/rpc/subscription/messages.go @@ -25,7 +25,7 @@ type BaseResponseJSON struct { // Params for json param response type Params struct { Result interface{} `json:"result"` - SubscriptionID uint `json:"subscription"` + SubscriptionID uint32 `json:"subscription"` } // InvalidRequestCode error code returned for invalid request parameters, value derived from Substrate node output @@ -40,7 +40,7 @@ func newSubcriptionBaseResponseJSON() BaseResponseJSON { } } -func newSubscriptionResponse(method string, subID uint, result interface{}) BaseResponseJSON { +func newSubscriptionResponse(method string, subID uint32, result interface{}) BaseResponseJSON { return BaseResponseJSON{ Jsonrpc: "2.0", Method: method, @@ -54,12 +54,12 @@ func newSubscriptionResponse(method string, subID uint, result interface{}) Base // ResponseJSON for json subscription responses type ResponseJSON struct { Jsonrpc string `json:"jsonrpc"` - Result uint `json:"result"` + Result uint32 `json:"result"` ID float64 `json:"id"` } // NewSubscriptionResponseJSON builds a Response JSON object -func NewSubscriptionResponseJSON(subID uint, reqID float64) ResponseJSON { +func NewSubscriptionResponseJSON(subID uint32, reqID float64) ResponseJSON { return ResponseJSON{ Jsonrpc: "2.0", Result: subID, diff --git a/dot/rpc/subscription/subscription.go b/dot/rpc/subscription/subscription.go index 93413fec65..e20ed73420 100644 --- a/dot/rpc/subscription/subscription.go +++ b/dot/rpc/subscription/subscription.go @@ -24,6 +24,8 @@ func (c *WSConn) getSetupListener(method string) setupListener { return c.initBlockFinalizedListener case "state_subscribeRuntimeVersion": return c.initRuntimeVersionListener + case "grandpa_subscribeJustifications": + return c.initGrandpaJustificationListener default: return nil } @@ -45,6 +47,8 @@ func (c *WSConn) getUnsubListener(method string, params interface{}) (unsubListe switch method { case "state_unsubscribeStorage": unsub = c.unsubscribeStorageListener + case "grandpa_unsubscribeJustifications": + unsub = c.unsubscribeGrandpaJustificationListener default: return nil, nil, errCannotFindUnsubsriber } @@ -52,7 +56,7 @@ func (c *WSConn) getUnsubListener(method string, params interface{}) (unsubListe return unsub, listener, nil } -func parseSubscribeID(p interface{}) (uint, error) { +func parseSubscribeID(p interface{}) (uint32, error) { switch v := p.(type) { case []interface{}: if len(v) == 0 { @@ -62,16 +66,16 @@ func parseSubscribeID(p interface{}) (uint, error) { return 0, errUknownParamSubscribeID } - var id uint + var id uint32 switch v := p.([]interface{})[0].(type) { case float64: - id = uint(v) + id = uint32(v) case string: i, err := strconv.ParseUint(v, 10, 32) if err != nil { return 0, errCannotParseID } - id = uint(i) + id = uint32(i) default: return 0, errUknownParamSubscribeID } diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index 56e6721136..3d351b6c66 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -26,6 +26,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "github.com/ChainSafe/gossamer/dot/rpc/modules" "github.com/ChainSafe/gossamer/dot/state" @@ -48,17 +49,15 @@ const DEFAULT_BUFFER_SIZE = 100 // WSConn struct to hold WebSocket Connection references type WSConn struct { - Wsconn *websocket.Conn - mu sync.Mutex - BlockSubChannels map[uint]byte - StorageSubChannels map[int]byte - qtyListeners uint - Subscriptions map[uint]Listener - StorageAPI modules.StorageAPI - BlockAPI modules.BlockAPI - CoreAPI modules.CoreAPI - TxStateAPI modules.TransactionStateAPI - RPCHost string + Wsconn *websocket.Conn + mu sync.Mutex + qtyListeners uint32 + Subscriptions map[uint32]Listener + StorageAPI modules.StorageAPI + BlockAPI modules.BlockAPI + CoreAPI modules.CoreAPI + TxStateAPI modules.TransactionStateAPI + RPCHost string HTTP httpclient } @@ -135,7 +134,12 @@ func (c *WSConn) HandleComm() { } unsub(reqid, listener, params) - listener.Stop() + err = listener.Stop() + + if err != nil { + logger.Warn("failed to cancel listener goroutine", "method", method, "error", err) + } + continue } @@ -203,8 +207,7 @@ func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (L c.mu.Lock() - c.qtyListeners++ - stgobs.id = c.qtyListeners + stgobs.id = atomic.AddUint32(&c.qtyListeners, 1) c.Subscriptions[stgobs.id] = stgobs c.mu.Unlock() @@ -230,8 +233,11 @@ func (c *WSConn) unsubscribeStorageListener(reqID float64, l Listener, _ interfa func (c *WSConn) initBlockListener(reqID float64, _ interface{}) (Listener, error) { bl := &BlockListener{ - Channel: make(chan *types.Block, DEFAULT_BUFFER_SIZE), - wsconn: c, + Channel: make(chan *types.Block, DEFAULT_BUFFER_SIZE), + wsconn: c, + cancel: make(chan struct{}, 1), + cancelTimeout: defaultCancelTimeout, + done: make(chan struct{}, 1), } if c.BlockAPI == nil { @@ -248,10 +254,8 @@ func (c *WSConn) initBlockListener(reqID float64, _ interface{}) (Listener, erro c.mu.Lock() - c.qtyListeners++ - bl.subID = c.qtyListeners + bl.subID = atomic.AddUint32(&c.qtyListeners, 1) c.Subscriptions[bl.subID] = bl - c.BlockSubChannels[bl.subID] = bl.ChanID c.mu.Unlock() @@ -262,8 +266,11 @@ func (c *WSConn) initBlockListener(reqID float64, _ interface{}) (Listener, erro func (c *WSConn) initBlockFinalizedListener(reqID float64, _ interface{}) (Listener, error) { bfl := &BlockFinalizedListener{ - channel: make(chan *types.FinalisationInfo, DEFAULT_BUFFER_SIZE), - wsconn: c, + channel: make(chan *types.FinalisationInfo), + cancel: make(chan struct{}, 1), + done: make(chan struct{}, 1), + cancelTimeout: defaultCancelTimeout, + wsconn: c, } if c.BlockAPI == nil { @@ -279,10 +286,8 @@ func (c *WSConn) initBlockFinalizedListener(reqID float64, _ interface{}) (Liste c.mu.Lock() - c.qtyListeners++ - bfl.subID = c.qtyListeners + bfl.subID = atomic.AddUint32(&c.qtyListeners, 1) c.Subscriptions[bfl.subID] = bfl - c.BlockSubChannels[bfl.subID] = bfl.chanID c.mu.Unlock() @@ -304,7 +309,10 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener importedChan: make(chan *types.Block, DEFAULT_BUFFER_SIZE), wsconn: c, extrinsic: types.Extrinsic(extBytes), - finalisedChan: make(chan *types.FinalisationInfo, DEFAULT_BUFFER_SIZE), + finalisedChan: make(chan *types.FinalisationInfo), + cancel: make(chan struct{}, 1), + done: make(chan struct{}, 1), + cancelTimeout: defaultCancelTimeout, } if c.BlockAPI == nil { @@ -322,10 +330,8 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener c.mu.Lock() - c.qtyListeners++ - esl.subID = c.qtyListeners + esl.subID = atomic.AddUint32(&c.qtyListeners, 1) c.Subscriptions[esl.subID] = esl - c.BlockSubChannels[esl.subID] = esl.importedChanID c.mu.Unlock() @@ -337,7 +343,7 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener // TODO (ed) since HandleSubmittedExtrinsic has been called we assume the extrinsic is in the tx queue // should we add a channel to tx queue so we're notified when it's in the queue (See issue #1535) - c.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, esl.subID, "ready")) + c.safeSend(newSubscriptionResponse(authorExtrinsicUpdatesMethod, esl.subID, "ready")) // todo (ed) determine which peer extrinsic has been broadcast to, and set status return esl, err @@ -355,8 +361,7 @@ func (c *WSConn) initRuntimeVersionListener(reqID float64, _ interface{}) (Liste c.mu.Lock() - c.qtyListeners++ - rvl.subID = c.qtyListeners + rvl.subID = atomic.AddUint32(&c.qtyListeners, 1) c.Subscriptions[rvl.subID] = rvl c.mu.Unlock() @@ -366,6 +371,49 @@ func (c *WSConn) initRuntimeVersionListener(reqID float64, _ interface{}) (Liste return rvl, nil } +func (c *WSConn) initGrandpaJustificationListener(reqID float64, _ interface{}) (Listener, error) { + if c.BlockAPI == nil { + c.safeSendError(reqID, nil, "error BlockAPI not set") + return nil, fmt.Errorf("error BlockAPI not set") + } + + jl := &GrandpaJustificationListener{ + cancel: make(chan struct{}, 1), + done: make(chan struct{}, 1), + wsconn: c, + finalisedCh: make(chan *types.FinalisationInfo, 1), + cancelTimeout: defaultCancelTimeout, + } + + var err error + jl.finalisedChID, err = c.BlockAPI.RegisterFinalizedChannel(jl.finalisedCh) + if err != nil { + return nil, err + } + + c.mu.Lock() + + jl.subID = atomic.AddUint32(&c.qtyListeners, 1) + c.Subscriptions[jl.subID] = jl + + c.mu.Unlock() + + c.safeSend(NewSubscriptionResponseJSON(jl.subID, reqID)) + + return jl, nil +} + +func (c *WSConn) unsubscribeGrandpaJustificationListener(reqID float64, l Listener, params interface{}) { + listener, ok := l.(*GrandpaJustificationListener) + if !ok { + c.safeSend(newBooleanResponseJSON(false, reqID)) + return + } + + c.BlockAPI.UnregisterFinalisedChannel(listener.finalisedChID) + c.safeSend(newBooleanResponseJSON(true, reqID)) +} + func (c *WSConn) safeSend(msg interface{}) { c.mu.Lock() defer c.mu.Unlock() diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index cb2ab48906..3e72d83808 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -1,14 +1,22 @@ package subscription import ( + "fmt" "log" + "math/big" "net/http" "os" "testing" "time" + modulesmocks "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" + "github.com/ChainSafe/gossamer/dot/rpc/modules" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/grandpa" "github.com/gorilla/websocket" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -17,8 +25,7 @@ var upgrader = websocket.Upgrader{ } var wsconn = &WSConn{ - Subscriptions: make(map[uint]Listener), - BlockSubChannels: make(map[uint]byte), + Subscriptions: make(map[uint32]Listener), } func handler(w http.ResponseWriter, r *http.Request) { @@ -238,4 +245,65 @@ func TestWSConn_HandleComm(t *testing.T) { require.NotNil(t, res) require.Len(t, wsconn.Subscriptions, 8) + _, msg, err = c.ReadMessage() + require.NoError(t, err) + require.Equal(t, `{"jsonrpc":"2.0","result":8,"id":0}`+"\n", string(msg)) + + _, msg, err = c.ReadMessage() + require.NoError(t, err) + require.Equal(t, `{"jsonrpc":"2.0","method":"author_extrinsicUpdate","params":{"result":"ready","subscription":8}}`+"\n", string(msg)) + + var fCh chan<- *types.FinalisationInfo + mockedJust := grandpa.Justification{ + Round: 1, + Commit: &grandpa.Commit{ + Hash: common.Hash{}, + Number: 1, + Precommits: nil, + }, + } + + mockedJustBytes, err := mockedJust.Encode() + require.NoError(t, err) + + mockBlockAPI := new(modulesmocks.MockBlockAPI) + mockBlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). + Run(func(args mock.Arguments) { + ch := args.Get(0).(chan<- *types.FinalisationInfo) + fCh = ch + }). + Return(uint8(4), nil) + + mockBlockAPI.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) + mockBlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + + wsconn.BlockAPI = mockBlockAPI + listener, err := wsconn.initGrandpaJustificationListener(0, nil) + require.NoError(t, err) + require.NotNil(t, listener) + + _, msg, err = c.ReadMessage() + require.NoError(t, err) + require.Equal(t, `{"jsonrpc":"2.0","result":9,"id":0}`+"\n", string(msg)) + + listener.Listen() + header := &types.Header{ + ParentHash: common.Hash{}, + Number: big.NewInt(1), + } + + fCh <- &types.FinalisationInfo{ + Header: header, + } + + time.Sleep(time.Second * 2) + + expected := `{"jsonrpc":"2.0","method":"grandpa_justifications","params":{"result":"%s","subscription":9}}` + "\n" + expected = fmt.Sprintf(expected, common.BytesToHex(mockedJustBytes)) + _, msg, err = c.ReadMessage() + require.NoError(t, err) + require.Equal(t, []byte(expected), msg) + + err = listener.Stop() + require.NoError(t, err) } diff --git a/go.sum b/go.sum index c11637e1b0..91273a38f8 100644 --- a/go.sum +++ b/go.sum @@ -274,7 +274,6 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= diff --git a/lib/crypto/sr25519/sr25519.go b/lib/crypto/sr25519/sr25519.go index 157faa4dec..f39e7074f2 100644 --- a/lib/crypto/sr25519/sr25519.go +++ b/lib/crypto/sr25519/sr25519.go @@ -83,9 +83,13 @@ func NewKeypairFromPrivate(priv *PrivateKey) (*Keypair, error) { } // NewKeypairFromSeed returns a new sr25519 Keypair given a seed -func NewKeypairFromSeed(seed []byte) (*Keypair, error) { +func NewKeypairFromSeed(keystr []byte) (*Keypair, error) { + if len(keystr) != SeedLength { + return nil, errors.New("cannot generate key from seed: seed is not 32 bytes long") + } + buf := [SeedLength]byte{} - copy(buf[:], seed) + copy(buf[:], keystr) msc, err := sr25519.NewMiniSecretKeyFromRaw(buf) if err != nil { return nil, err @@ -154,6 +158,30 @@ func NewPrivateKey(in []byte) (*PrivateKey, error) { return priv, err } +// NewPrivateKeyFromHex returns a private key from a hex-encoded private key +func NewPrivateKeyFromHex(keystr string) (*PrivateKey, error) { + seedBytes, err := common.HexToBytes(keystr) + if err != nil { + return nil, err + } + + if len(seedBytes) != PrivateKeyLength { + return nil, errors.New("cannot create public key: input is not 32 bytes") + } + + var privKeyBytes [32]byte + copy(privKeyBytes[:], seedBytes) + + miniSecretKey, err := sr25519.NewMiniSecretKeyFromRaw(privKeyBytes) + if err != nil { + return nil, err + } + + return &PrivateKey{ + key: miniSecretKey.ExpandUniform(), + }, nil +} + // GenerateKeypair returns a new sr25519 keypair func GenerateKeypair() (*Keypair, error) { priv, pub, err := sr25519.GenerateKeypair() diff --git a/lib/keystore/helpers.go b/lib/keystore/helpers.go index 0cc862ff29..576be2b930 100644 --- a/lib/keystore/helpers.go +++ b/lib/keystore/helpers.go @@ -64,6 +64,20 @@ func DecodePrivateKey(in []byte, keytype crypto.KeyType) (priv crypto.PrivateKey return priv, err } +// DecodeKeyPairFromHex turns an hex-encoded private key into a keypair +func DecodeKeyPairFromHex(keystr []byte, keytype crypto.KeyType) (kp crypto.Keypair, err error) { + switch keytype { + case crypto.Sr25519Type: + kp, err = sr25519.NewKeypairFromSeed(keystr) + case crypto.Ed25519Type: + kp, err = ed25519.NewKeypairFromSeed(keystr) + default: + return nil, errors.New("cannot decode key: invalid key type") + } + + return kp, err +} + // GenerateKeypair create a new keypair with the corresponding type and saves // it to basepath/keystore/[public key].key in json format encrypted using the // specified password and returns the resulting filepath of the new key diff --git a/lib/runtime/constants.go b/lib/runtime/constants.go index 5c98744771..1d53a55eb1 100644 --- a/lib/runtime/constants.go +++ b/lib/runtime/constants.go @@ -67,6 +67,8 @@ var ( BlockBuilderApplyExtrinsic = "BlockBuilder_apply_extrinsic" // BlockBuilderFinalizeBlock is the runtime API call BlockBuilder_finalize_block BlockBuilderFinalizeBlock = "BlockBuilder_finalize_block" + // DecodeSessionKeys is the runtime API call SessionKeys_decode_session_keys + DecodeSessionKeys = "SessionKeys_decode_session_keys" ) // GrandpaAuthoritiesKey is the location of GRANDPA authority data in the storage trie for LEGACY_NODE_RUNTIME and NODE_RUNTIME diff --git a/lib/runtime/interface.go b/lib/runtime/interface.go index 5b9500e64f..8ace2d1e9a 100644 --- a/lib/runtime/interface.go +++ b/lib/runtime/interface.go @@ -47,6 +47,7 @@ type Instance interface { ApplyExtrinsic(data types.Extrinsic) ([]byte, error) FinalizeBlock() (*types.Header, error) ExecuteBlock(block *types.Block) ([]byte, error) + DecodeSessionKeys(enc []byte) ([]byte, error) // TODO: parameters and return values for these are undefined in the spec CheckInherents() diff --git a/lib/runtime/life/exports.go b/lib/runtime/life/exports.go index cc04df8112..b0ef1b38f3 100644 --- a/lib/runtime/life/exports.go +++ b/lib/runtime/life/exports.go @@ -150,6 +150,11 @@ func (in *Instance) ExecuteBlock(block *types.Block) ([]byte, error) { return in.Exec(runtime.CoreExecuteBlock, bdEnc) } +// DecodeSessionKeys decodes the given public session keys. Returns a list of raw public keys including their key type. +func (in *Instance) DecodeSessionKeys(enc []byte) ([]byte, error) { + return in.Exec(runtime.DecodeSessionKeys, enc) +} + func (in *Instance) CheckInherents() {} //nolint func (in *Instance) RandomSeed() {} //nolint func (in *Instance) OffchainWorker() {} //nolint diff --git a/lib/runtime/wasmer/exports.go b/lib/runtime/wasmer/exports.go index bb23594b90..48368eafe9 100644 --- a/lib/runtime/wasmer/exports.go +++ b/lib/runtime/wasmer/exports.go @@ -174,6 +174,11 @@ func (in *Instance) ExecuteBlock(block *types.Block) ([]byte, error) { return in.exec(runtime.CoreExecuteBlock, bdEnc) } +// DecodeSessionKeys decodes the given public session keys. Returns a list of raw public keys including their key type. +func (in *Instance) DecodeSessionKeys(enc []byte) ([]byte, error) { + return in.exec(runtime.DecodeSessionKeys, enc) +} + func (in *Instance) CheckInherents() {} //nolint func (in *Instance) RandomSeed() {} //nolint func (in *Instance) OffchainWorker() {} //nolint diff --git a/lib/runtime/wasmtime/exports.go b/lib/runtime/wasmtime/exports.go index e2574f53c1..dc6f3a6f61 100644 --- a/lib/runtime/wasmtime/exports.go +++ b/lib/runtime/wasmtime/exports.go @@ -150,6 +150,11 @@ func (in *Instance) ExecuteBlock(block *types.Block) ([]byte, error) { return in.exec(runtime.CoreExecuteBlock, bdEnc) } +// DecodeSessionKeys decodes the given public session keys. Returns a list of raw public keys including their key type. +func (in *Instance) DecodeSessionKeys(enc []byte) ([]byte, error) { + return in.exec(runtime.DecodeSessionKeys, enc) +} + func (in *Instance) CheckInherents() {} //nolint func (in *Instance) RandomSeed() {} //nolint func (in *Instance) OffchainWorker() {} //nolint