fix: subscriber data race (#1169)
* fix: subscriber data race

Signed-off-by: Jim Ma <[email protected]>
jim3ma authored and gaius-qi committed Jun 28, 2023
1 parent a46a37b commit 42e662d
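
The race in question: the subscriber's receive loop dropped its mutex before checking info.Finished and before calling s.searchNextPieceNum, so the search walked sent-piece state that concurrent goroutines were still updating (see the subscriber.go hunk below). What follows is a minimal, runnable sketch of that shape, assuming only that searchNextPieceNum reads lock-guarded shared state; the subscriber type, its sent map, and searchNext are illustrative stand-ins, not the real rpcserver types.

package main

import "sync"

// Minimal sketch of the race shape fixed in this commit. The type,
// field, and helper names are illustrative, not the real subscriber.
type subscriber struct {
    sync.Mutex
    sent map[uint32]bool // piece numbers already sent; shared state
}

// searchNext mirrors searchNextPieceNum: it walks shared state, so
// the caller must hold the lock.
func (s *subscriber) searchNext(cur uint32) uint32 {
    for ; s.sent[cur]; cur++ {
    }
    return cur
}

func main() {
    s := &subscriber{sent: map[uint32]bool{}}
    var wg sync.WaitGroup
    wg.Add(2)

    // Writer: marks pieces as sent, correctly under the lock.
    go func() {
        defer wg.Done()
        for i := uint32(0); i < 1000; i++ {
            s.Lock()
            s.sent[i] = true
            s.Unlock()
        }
    }()

    // Reader, pre-fix shape: unlock first, then search shared state.
    go func() {
        defer wg.Done()
        var next uint32
        for i := 0; i < 1000; i++ {
            s.Lock()
            // ... the pre-fix code released the lock here ...
            s.Unlock()
            next = s.searchNext(next) // map read outside the lock: data race
        }
        _ = next
    }()

    wg.Wait()
}

Run with go run -race and the detector reports the map read in searchNext racing with the writer; holding the lock across the call, as this commit does, makes it race-free.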
Showing 2 changed files with 156 additions and 149 deletions.
302 changes: 154 additions & 148 deletions client/daemon/rpcserver/rpcserver_test.go
(Whitespace-only changes are hidden below; the entire t.Run body was re-indented into the new delay loop, which accounts for most of the 154 additions and 148 deletions.)

@@ -336,180 +336,186 @@ func TestDownloadManager_SyncPieceTasks(t *testing.T) {

     for _, tc := range tests {
         t.Run(tc.name, func(t *testing.T) {
+            for _, delay := range []bool{false, true} {
+                delay := delay
                 mockStorageManger := mock_storage.NewMockManager(ctrl)

                 if tc.limit == 0 {
                     tc.limit = 1024
                 }

                 var (
                     totalPieces []*base.PieceInfo
                     lock        sync.Mutex
                 )

                 var addedPieces = make(map[uint32]*base.PieceInfo)
                 for _, p := range tc.existPieces {
                     if p.end == 0 {
                         p.end = p.start
                     }
                     for i := p.start; i <= p.end; i++ {
                         if _, ok := addedPieces[uint32(i)]; ok {
                             continue
                         }
                         piece := &base.PieceInfo{
                             PieceNum:    int32(i),
                             RangeStart:  uint64(i) * uint64(pieceSize),
                             RangeSize:   pieceSize,
                             PieceOffset: uint64(i) * uint64(pieceSize),
                             PieceStyle:  base.PieceStyle_PLAIN,
                         }
                         totalPieces = append(totalPieces, piece)
                         addedPieces[uint32(i)] = piece
                     }
                 }

                 mockStorageManger.EXPECT().GetPieces(gomock.Any(),
                     gomock.Any()).AnyTimes().DoAndReturn(
                     func(ctx context.Context, req *base.PieceTaskRequest) (*base.PiecePacket, error) {
                         var pieces []*base.PieceInfo
                         lock.Lock()
                         for i := req.StartNum; i < tc.totalPieces; i++ {
                             if piece, ok := addedPieces[i]; ok {
                                 if piece.PieceNum >= int32(req.StartNum) && len(pieces) < int(req.Limit) {
                                     pieces = append(pieces, piece)
                                 }
                             }
                         }
                         lock.Unlock()
                         return &base.PiecePacket{
                             TaskId:        req.TaskId,
                             DstPid:        req.DstPid,
                             DstAddr:       "",
                             PieceInfos:    pieces,
                             TotalPiece:    int32(tc.totalPieces),
                             ContentLength: int64(tc.totalPieces) * int64(pieceSize),
                             PieceMd5Sign:  "",
                         }, nil
                     })
                 mockTaskManager := mock_peer.NewMockTaskManager(ctrl)
                 mockTaskManager.EXPECT().Subscribe(gomock.Any()).AnyTimes().DoAndReturn(
                     func(request *base.PieceTaskRequest) (*peer.SubscribeResult, bool) {
                         ch := make(chan *peer.PieceInfo)
                         success := make(chan struct{})
                         fail := make(chan struct{})

                         go func(followingPieces []pieceRange) {
                             for i, p := range followingPieces {
                                 if p.end == 0 {
                                     p.end = p.start
                                 }
                                 for j := p.start; j <= p.end; j++ {
                                     lock.Lock()
                                     if _, ok := addedPieces[uint32(j)]; ok {
                                         continue
                                     }
                                     piece := &base.PieceInfo{
                                         PieceNum:    int32(j),
                                         RangeStart:  uint64(j) * uint64(pieceSize),
                                         RangeSize:   pieceSize,
                                         PieceOffset: uint64(j) * uint64(pieceSize),
                                         PieceStyle:  base.PieceStyle_PLAIN,
                                     }
                                     totalPieces = append(totalPieces, piece)
                                     addedPieces[uint32(j)] = piece
                                     lock.Unlock()

                                     var finished bool
                                     if i == len(followingPieces)-1 && j == p.end {
                                         finished = true
                                     }
+                                    if !delay {
                                         ch <- &peer.PieceInfo{
                                             Num:      int32(j),
                                             Finished: finished,
                                         }
+                                    }
                                 }
                             }
                             close(success)
                         }(tc.followingPieces)

                         return &peer.SubscribeResult{
                             Storage:          mockStorageManger,
                             PieceInfoChannel: ch,
                             Success:          success,
                             Fail:             fail,
                         }, true
                     })

                 s := &server{
                     KeepAlive:       clientutil.NewKeepAlive("test"),
                     peerHost:        &scheduler.PeerHost{},
                     storageManager:  mockStorageManger,
                     peerTaskManager: mockTaskManager,
                 }

                 port, client := setupPeerServerAndClient(t, s, assert, s.ServePeer)
                 defer s.peerServer.GracefulStop()

                 syncClient, err := client.SyncPieceTasks(
                     context.Background(),
                     dfnet.NetAddr{
                         Type: dfnet.TCP,
                         Addr: fmt.Sprintf("127.0.0.1:%d", port),
                     },
                     &base.PieceTaskRequest{
                         TaskId:   tc.name,
                         SrcPid:   idgen.PeerID(iputils.IPv4),
                         DstPid:   idgen.PeerID(iputils.IPv4),
                         StartNum: 0,
                         Limit:    tc.limit,
                     })
                 assert.Nil(err, "client sync piece tasks grpc call should be ok")

                 var (
                     total       = make(map[int32]bool)
                     maxNum      int32
                     requestSent = make(chan bool)
                 )
                 if len(tc.requestPieces) == 0 {
                     close(requestSent)
                 } else {
                     go func() {
                         for _, n := range tc.requestPieces {
                             request := &base.PieceTaskRequest{
                                 TaskId:   tc.name,
                                 SrcPid:   idgen.PeerID(iputils.IPv4),
                                 DstPid:   idgen.PeerID(iputils.IPv4),
                                 StartNum: uint32(n),
                                 Limit:    tc.limit,
                             }
                             assert.Nil(syncClient.Send(request))
                         }
                         close(requestSent)
                     }()
                 }
                 for {
                     p, err := syncClient.Recv()
                     if err == io.EOF {
                         break
                     }
                     for _, info := range p.PieceInfos {
                         total[info.PieceNum] = true
                         if info.PieceNum >= maxNum {
                             maxNum = info.PieceNum
                         }
                     }
                     if tc.success {
                         assert.Nil(err, "receive piece tasks should be ok")
                     }
                     if int(p.TotalPiece) == len(total) {
                         <-requestSent
                         err = syncClient.CloseSend()
                         assert.Nil(err)
                     }
                 }
                 if tc.success {
                     assert.Equal(int(maxNum+1), len(total))
                 }
+            }
         })
     }
 }
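
The test change adds a delay dimension so each case now runs twice: with delay set, pieces are recorded in addedPieces but never pushed on the piece-info channel, so the subscriber can only discover them through the s.Success branch and searchNextPieceNum — the path that raced before this fix. Assuming a checkout at this commit's parent, the race detector should reproduce the report, e.g.:

    go test -race -run TestDownloadManager_SyncPieceTasks ./client/daemon/rpcserver/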
3 changes: 2 additions & 1 deletion client/daemon/rpcserver/subscriber.go
@@ -153,11 +153,12 @@ loop:
                 s.Unlock()
                 break loop
             }
-            s.Unlock()
             if info.Finished {
+                s.Unlock()
                 break loop
             }
             nextPieceNum = s.searchNextPieceNum(nextPieceNum)
+            s.Unlock()
         case <-s.Success:
             s.Lock()
             // all pieces already sent
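
The fix itself: the single s.Unlock() that previously ran before the info.Finished check is replaced with one unlock per exit path, so the lock is now held across both the Finished check and searchNextPieceNum, which reads piece state shared with the goroutines that record sent pieces.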
