Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TX relaying fixes #448

Merged
merged 5 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/core/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,9 @@ func (bc *Blockchain) storeBlock(block *Block) error {
}

atomic.StoreUint32(&bc.blockHeight, block.Index)
for _, tx := range block.Transactions {
bc.memPool.Remove(tx.Hash())
}
return nil
}

Expand Down
34 changes: 34 additions & 0 deletions pkg/core/mem_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool {

mp.lock.RLock()
if _, ok := mp.unsortedTxn[hash]; ok {
mp.lock.RUnlock()
return false
}
mp.unsortedTxn[hash] = pItem
Expand All @@ -131,6 +132,39 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool {
return ok
}

// Remove removes an item from the mempool, if it exists there (and does
// nothing if it doesn't).
func (mp *MemPool) Remove(hash util.Uint256) {
var mapAndPools = []struct {
unsortedMap map[util.Uint256]*PoolItem
sortedPools []*PoolItems
}{
{unsortedMap: mp.unsortedTxn, sortedPools: []*PoolItems{&mp.sortedHighPrioTxn, &mp.sortedLowPrioTxn}},
{unsortedMap: mp.unverifiedTxn, sortedPools: []*PoolItems{&mp.unverifiedSortedHighPrioTxn, &mp.unverifiedSortedLowPrioTxn}},
}
mp.lock.Lock()
for _, mapAndPool := range mapAndPools {
if _, ok := mapAndPool.unsortedMap[hash]; ok {
delete(mapAndPool.unsortedMap, hash)
for _, pool := range mapAndPool.sortedPools {
var num int
var item *PoolItem
for num, item = range *pool {
if hash.Equals(item.txn.Hash()) {
break
}
}
if num < len(*pool)-1 {
*pool = append((*pool)[:num], (*pool)[num+1:]...)
} else if num == len(*pool)-1 {
*pool = (*pool)[:num]
}
}
}
}
mp.lock.Unlock()
}

// RemoveOverCapacity removes transactions with lowest fees until the total number of transactions
// in the MemPool is within the capacity of the MemPool.
func (mp *MemPool) RemoveOverCapacity() {
Expand Down
64 changes: 64 additions & 0 deletions pkg/core/mem_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package core

import (
"testing"

"github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type FeerStub struct {
lowPriority bool
sysFee util.Fixed8
netFee util.Fixed8
perByteFee util.Fixed8
}

func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 {
return fs.netFee
}

func (fs *FeerStub) IsLowPriority(*transaction.Transaction) bool {
return fs.lowPriority
}

func (fs *FeerStub) FeePerByte(*transaction.Transaction) util.Fixed8 {
return fs.perByteFee
}

func (fs *FeerStub) SystemFee(*transaction.Transaction) util.Fixed8 {
return fs.sysFee
}

func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
mp := NewMemPool(10)
tx := newMinerTX()
item := NewPoolItem(tx, fs)
_, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok)
require.Equal(t, true, mp.TryAdd(tx.Hash(), item))
// Re-adding should fail.
require.Equal(t, false, mp.TryAdd(tx.Hash(), item))
tx2, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, true, ok)
require.Equal(t, tx, tx2)
mp.Remove(tx.Hash())
_, ok = mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok)
// Make sure nothing left in the mempool after removal.
assert.Equal(t, 0, len(mp.unsortedTxn))
assert.Equal(t, 0, len(mp.unverifiedTxn))
assert.Equal(t, 0, len(mp.sortedHighPrioTxn))
assert.Equal(t, 0, len(mp.sortedLowPrioTxn))
assert.Equal(t, 0, len(mp.unverifiedSortedHighPrioTxn))
assert.Equal(t, 0, len(mp.unverifiedSortedLowPrioTxn))
}

func TestMemPoolAddRemove(t *testing.T) {
var fs = &FeerStub{lowPriority: false}
t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
fs.lowPriority = true
t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
}
3 changes: 2 additions & 1 deletion pkg/core/transaction/invocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ func NewInvocationTX(script []byte) *Transaction {
Type: InvocationType,
Version: 1,
Data: &InvocationTX{
Script: script,
Script: script,
Version: 1,
},
Attributes: []*Attribute{},
Inputs: []*Input{},
Expand Down
19 changes: 19 additions & 0 deletions pkg/core/transaction/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,25 @@ func TestDecodeEncodeInvocationTX(t *testing.T) {
assert.Equal(t, rawInvocationTX, hex.EncodeToString(buf.Bytes()))
}

func TestNewInvocationTX(t *testing.T) {
script := []byte{0x51}
tx := NewInvocationTX(script)
txData := tx.Data.(*InvocationTX)
assert.Equal(t, InvocationType, tx.Type)
assert.Equal(t, tx.Version, txData.Version)
assert.Equal(t, script, txData.Script)
buf := io.NewBufBinWriter()
// Update hash fields to match tx2 that is gonna autoupdate them on decode.
_ = tx.Hash()
tx.EncodeBinary(buf.BinWriter)
assert.Nil(t, buf.Err)
var tx2 Transaction
r := io.NewBinReaderFromBuf(buf.Bytes())
tx2.DecodeBinary(r)
assert.Nil(t, r.Err)
assert.Equal(t, *tx, tx2)
}

func TestDecodePublishTX(t *testing.T) {
expectedTXData := &PublishTX{}
expectedTXData.Name = "Lock"
Expand Down
41 changes: 38 additions & 3 deletions pkg/network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,40 @@ func (s *Server) handleBlockCmd(p Peer, block *core.Block) error {

// handleInvCmd processes the received inventory.
func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
return errInvalidInvType
}
payload := payload.NewInventory(inv.Type, inv.Hashes)
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
}

// handleInvCmd processes the received inventory.
func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
switch inv.Type {
case payload.TXType:
for _, hash := range inv.Hashes {
tx, _, err := s.chain.GetTransaction(hash)
if err == nil {
err = p.WriteMsg(NewMessage(s.Net, CMDTX, tx))
if err != nil {
return err
}

}
}
case payload.BlockType:
for _, hash := range inv.Hashes {
b, err := s.chain.GetBlock(hash)
if err == nil {
err = p.WriteMsg(NewMessage(s.Net, CMDBlock, b))
if err != nil {
return err
}
}
}
case payload.ConsensusType:
// TODO (#431)
}
return nil
}

// handleAddrCmd will process received addresses.
func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
for _, a := range addrs.Addrs {
Expand Down Expand Up @@ -350,13 +377,21 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
}

if peer.Handshaked() {
if inv, ok := msg.Payload.(*payload.Inventory); ok {
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
return errInvalidInvType
}
}
switch msg.CommandType() {
case CMDAddr:
addrs := msg.Payload.(*payload.AddressList)
return s.handleAddrCmd(peer, addrs)
case CMDGetAddr:
// it has no payload
return s.handleGetAddrCmd(peer)
case CMDGetData:
inv := msg.Payload.(*payload.Inventory)
return s.handleGetDataCmd(peer, inv)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers)
Expand Down