Skip to content

Commit

Permalink
distsqlrun: add tests for cockroachdb#13989
Browse files Browse the repository at this point in the history
We didn't have any tests exercising a stream failing to connect within
the registry's timeout.

I've changed the fr.ConnectInbound interface slightly - made it return a
RowReceiver explicitly so that callers don't access the
inboundStreamInfo's fields without the flow registry lock - which is
documented to be required.
  • Loading branch information
andreimatei committed Mar 14, 2017
1 parent 93f315c commit 24cd6fa
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 47 deletions.
3 changes: 3 additions & 0 deletions pkg/sql/distsqlrun/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ func (rb *RowBuffer) PushRow(row sqlbase.EncDatumRow) bool {

// Close is part of the RowReceiver interface.
func (rb *RowBuffer) Close(err error) {
if rb.Closed {
panic("RowBuffer already closed")
}
rb.Err = err
rb.Closed = true
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/distsqlrun/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (f *Flow) Start(ctx context.Context, doneFn func()) {
// set up the WaitGroup counter before.
f.waitGroup.Add(len(f.inboundStreams) + len(f.outboxes) + len(f.processors))

f.flowRegistry.RegisterFlow(ctx, f.id, f, f.inboundStreams)
f.flowRegistry.RegisterFlow(ctx, f.id, f, f.inboundStreams, flowStreamDefaultTimeout)
if log.V(1) {
log.Infof(ctx, "registered flow %s", f.id.Short())
}
Expand Down
84 changes: 49 additions & 35 deletions pkg/sql/distsqlrun/flow_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ import (
"github.com/pkg/errors"
)

// flowStreamTimeout is the amount of time incoming streams wait for a flow to
// flowStreamDefaultTimeout is the amount of time incoming streams wait for a flow to
// be set up before erroring out.
const flowStreamTimeout time.Duration = 10 * time.Second
const flowStreamDefaultTimeout time.Duration = 10 * time.Second

// inboundStreamInfo represents the endpoint where a data stream from another
// node connects to a flow. The external node initiates this process through a
// FlowStream RPC, which uses (*Flow).connectInboundStream() to associate the
// stream to a receiver to push rows to.
//
// All fields are protected by the flowRegistry mutex (except the receiver,
// whose methods can be called freely).
type inboundStreamInfo struct {
// receiver is the entity that will receive rows from another host, which is
// part of a processor (normally an input synchronizer).
Expand All @@ -55,7 +58,8 @@ type inboundStreamInfo struct {
}

// flowEntry is a structure associated with a (potential) flow.
// All fields are protected by the flowRegistry mutex.
// All fields are protected by the flowRegistry mutex, except flow, whose
// methods can be called freely.
type flowEntry struct {
// waitCh is set if one or more clients are waiting for the flow; the
// channel gets closed when the flow is registered.
Expand Down Expand Up @@ -125,10 +129,13 @@ func (fr *flowRegistry) releaseEntryLocked(id FlowID) {
// flow from the registry.
//
// inboundStreams are all the remote streams that will be connected into this
// flow. If any of them is not connected within a timeout, errors are
// propagated.
// flow. If any of them is not connected within timeout, errors are propagated.
func (fr *flowRegistry) RegisterFlow(
ctx context.Context, id FlowID, f *Flow, inboundStreams map[StreamID]*inboundStreamInfo,
ctx context.Context,
id FlowID,
f *Flow,
inboundStreams map[StreamID]*inboundStreamInfo,
timeout time.Duration,
) {
fr.Lock()
defer fr.Unlock()
Expand All @@ -147,11 +154,11 @@ func (fr *flowRegistry) RegisterFlow(

if len(inboundStreams) > 0 {
// Set up a function to time out inbound streams after a while.
entry.streamTimer = time.AfterFunc(flowStreamTimeout, func() {
entry.streamTimer = time.AfterFunc(timeout, func() {
fr.Lock()
defer fr.Unlock()
numTimedOut := 0
for _, is := range entry.inboundStreams {
for streamID, is := range entry.inboundStreams {
if is.timedOut {
panic("stream already marked as timed out")
}
Expand All @@ -162,7 +169,7 @@ func (fr *flowRegistry) RegisterFlow(
// its consumer; the error will propagate and eventually drain all the
// processors.
is.receiver.Close(errors.Errorf("inbound stream timed out waiting for connection"))
fr.finishInboundStreamLocked(is)
fr.finishInboundStreamLocked(id, streamID)
}
}
if numTimedOut != 0 {
Expand All @@ -171,15 +178,15 @@ func (fr *flowRegistry) RegisterFlow(
"flow id:%s : %d inbound streams timed out after %s; propagated error throughout flow",
id,
numTimedOut,
flowStreamTimeout,
timeout,
)
}
})
}
}

// UnregisterFlow removes a flow from the registry. Any subsequent
// ConnectInboundStream calls will wait for the flow in vain.
// ConnectInboundStream calls for the flow will fail to find it and timeout.
func (fr *flowRegistry) UnregisterFlow(id FlowID) {
fr.Lock()
entry := fr.flows[id]
Expand Down Expand Up @@ -232,48 +239,55 @@ func (fr *flowRegistry) waitForFlowLocked(id FlowID, timeout time.Duration) *flo
}

// ConnectInboundStream finds the inboundStreamInfo for the given ID and marks it
// as connected. FinishInboundStream must be called after the rows are
// transferred.
// as connected. It waits up to timeout for the stream to be registered with the
// registry. Non-test callers should pass flowStreamDefaultTimeout.
//
// It returns the Flow that the stream is connecting to, the receiver that the
// stream mush push data to, a cleanup function that must be called to
// de-register the flow from the registry after all the data has been pushed.
//
// The cleanup function will decrement the flow's WaitGroup, so that Flow.Wait()
// is not blocked on this stream any more.
func (fr *flowRegistry) ConnectInboundStream(
flowID FlowID, streamID StreamID,
) (*Flow, *inboundStreamInfo, error) {
flowID FlowID, streamID StreamID, timeout time.Duration,
) (*Flow, RowReceiver, func(), error) {
fr.Lock()
defer fr.Unlock()
entry := fr.waitForFlowLocked(flowID, flowStreamTimeout)
entry := fr.waitForFlowLocked(flowID, timeout)
if entry == nil {
return nil, nil, errors.Errorf("flow %s not found", flowID)
return nil, nil, nil, errors.Errorf("flow %s not found", flowID)
}

s, ok := entry.inboundStreams[streamID]
if !ok {
return nil, nil, errors.Errorf("flow %s: no inbound stream %d", flowID, streamID)
return nil, nil, nil, errors.Errorf("flow %s: no inbound stream %d", flowID, streamID)
}
if s.connected {
return nil, nil, errors.Errorf("flow %s: inbound stream %d already connected", flowID, streamID)
return nil, nil, nil, errors.Errorf("flow %s: inbound stream %d already connected", flowID, streamID)
}
if s.timedOut {
return nil, nil, errors.Errorf("flow %s: inbound stream %d came too late", flowID, streamID)
return nil, nil, nil, errors.Errorf("flow %s: inbound stream %d came too late", flowID, streamID)
}
s.connected = true
return entry.flow, s, nil
cleanup := func() {
fr.Lock()
defer fr.Unlock()
fr.finishInboundStreamLocked(flowID, streamID)
}
return entry.flow, s.receiver, cleanup, nil
}

func (fr *flowRegistry) finishInboundStreamLocked(is *inboundStreamInfo) {
if !is.connected && !is.timedOut {
panic("finishing inbound stream that didn't connect or time out")
func (fr *flowRegistry) finishInboundStreamLocked(fid FlowID, sid StreamID) {
flowEntry := fr.getEntryLocked(fid)
streamEntry := flowEntry.inboundStreams[sid]

if !streamEntry.connected && !streamEntry.timedOut {
panic("finising inbound stream that didn't connect or time out")
}
if is.finished {
if streamEntry.finished {
panic("double finish")
}

is.finished = true
is.waitGroup.Done()
}

// FinishInboundStream is to be called when we are done transferring rows for a
// stream previously connected via ConnectInboundStream.
func (fr *flowRegistry) FinishInboundStream(is *inboundStreamInfo) {
fr.Lock()
defer fr.Unlock()
fr.finishInboundStreamLocked(is)
streamEntry.finished = true
streamEntry.waitGroup.Done()
}
88 changes: 80 additions & 8 deletions pkg/sql/distsqlrun/flow_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,45 @@ import (

"golang.org/x/net/context"

"github.com/cockroachdb/cockroach/pkg/testutils"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
"github.com/pkg/errors"
)

// lookupFlow returns the registered flow with the given ID. If no such flow is
// registered, waits until it gets registered - up to the given timeout. If the
// timeout elapses, returns nil.
func lookupFlow(fr *flowRegistry, id FlowID, timeout time.Duration) *Flow {
// timeout elapses and the flow is not registered, the bool return value will be
// false.
func lookupFlow(fr *flowRegistry, fid FlowID, timeout time.Duration) *Flow {
fr.Lock()
defer fr.Unlock()
entry := fr.waitForFlowLocked(id, timeout)
entry := fr.waitForFlowLocked(fid, timeout)
if entry == nil {
return nil
}
return entry.flow
}

// lookupStreamInfo returns a stream entry from a flowRegistry. If either the
// flow or the streams are missing, an error is returned.
//
// A copy of the registry's inboundStreamInfo is returned so it can be accessed
// without locking.
func lookupStreamInfo(fr *flowRegistry, fid FlowID, sid StreamID) (inboundStreamInfo, error) {
fr.Lock()
defer fr.Unlock()
entry := fr.getEntryLocked(fid)
if entry.flow == nil {
return inboundStreamInfo{}, errors.Errorf("missing flow entry")
}
si, ok := entry.inboundStreams[sid]
if !ok {
return inboundStreamInfo{}, errors.Errorf("missing stream entry")
}
return *si, nil
}

func TestFlowRegistry(t *testing.T) {
defer leaktest.AfterTest(t)()
reg := makeFlowRegistry()
Expand Down Expand Up @@ -67,7 +89,7 @@ func TestFlowRegistry(t *testing.T) {
}

ctx := context.Background()
reg.RegisterFlow(ctx, id1, f1, nil /* inboundStreams */)
reg.RegisterFlow(ctx, id1, f1, nil /* inboundStreams */, flowStreamDefaultTimeout)

if f := lookupFlow(reg, id1, 0); f != f1 {
t.Error("couldn't lookup previously registered flow")
Expand All @@ -83,7 +105,7 @@ func TestFlowRegistry(t *testing.T) {

go func() {
time.Sleep(jiffy)
reg.RegisterFlow(ctx, id1, f1, nil /* inboundStreams */)
reg.RegisterFlow(ctx, id1, f1, nil /* inboundStreams */, flowStreamDefaultTimeout)
}()

if f := lookupFlow(reg, id1, 10*jiffy); f != f1 {
Expand Down Expand Up @@ -114,7 +136,7 @@ func TestFlowRegistry(t *testing.T) {
}()

time.Sleep(jiffy)
reg.RegisterFlow(ctx, id2, f2, nil /* inboundStreams */)
reg.RegisterFlow(ctx, id2, f2, nil /* inboundStreams */, flowStreamDefaultTimeout)
wg.Wait()

// -- Multiple lookups, with the first one failing. --
Expand All @@ -139,18 +161,68 @@ func TestFlowRegistry(t *testing.T) {
}()

wg1.Wait()
reg.RegisterFlow(ctx, id3, f3, nil /* inboundStreams */)
reg.RegisterFlow(ctx, id3, f3, nil /* inboundStreams */, flowStreamDefaultTimeout)
wg2.Wait()

// -- Lookup with huge timeout, register in the meantime. --

go func() {
time.Sleep(jiffy)
reg.RegisterFlow(ctx, id4, f4, nil /* inboundStreams */)
reg.RegisterFlow(ctx, id4, f4, nil /* inboundStreams */, flowStreamDefaultTimeout)
}()

// This should return in a jiffy.
if f := lookupFlow(reg, id4, time.Hour); f != f4 {
t.Error("couldn't lookup registered flow (with wait)")
}
}

// Test that, if inbound streams are not connected within the timeout, errors
// are propagated to their consumers and future attempts to connect them fail.
func TestStreamConnectionTimeout(t *testing.T) {
defer leaktest.AfterTest(t)()
reg := makeFlowRegistry()

jiffy := time.Nanosecond

// Register a flow with a very low timeout. After it times out, we'll attempt
// to connect a stream, but it'll be too late.
id1 := FlowID{uuid.MakeV4()}
f1 := &Flow{}
streamID1 := StreamID(1)
consumer := &RowBuffer{}
wg := &sync.WaitGroup{}
wg.Add(1)
inboundStreams := map[StreamID]*inboundStreamInfo{
streamID1: {receiver: consumer, waitGroup: wg},
}
reg.RegisterFlow(context.TODO(), id1, f1, inboundStreams, jiffy)

testutils.SucceedsSoon(t, func() error {
si, err := lookupStreamInfo(reg, id1, streamID1)
if err != nil {
t.Fatal(err)
}
if !si.timedOut {
return errors.Errorf("not timed out yet")
}
return nil
})

if !consumer.Closed {
t.Fatalf("expected consumer to have been closed when the flow timed out")
}

if _, _, _, err := reg.ConnectInboundStream(id1, streamID1, jiffy); !testutils.IsError(
err, "came too late") {
t.Fatalf("expected %q, got: %v", "came too late", err)
}

// Unregister the flow. Subsequent attempts to connect a stream should result
// in a different error than before.
reg.UnregisterFlow(id1)
if _, _, _, err := reg.ConnectInboundStream(id1, streamID1, jiffy); !testutils.IsError(
err, "not found") {
t.Fatalf("expected %q, got: %v", "not found", err)
}
}
7 changes: 4 additions & 3 deletions pkg/sql/distsqlrun/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,14 @@ func (ds *ServerImpl) flowStreamInt(ctx context.Context, stream DistSQL_FlowStre
if log.V(1) {
log.Infof(ctx, "connecting inbound stream %s/%d", flowID.Short(), streamID)
}
f, streamInfo, err := ds.flowRegistry.ConnectInboundStream(flowID, streamID)
f, receiver, cleanup, err := ds.flowRegistry.ConnectInboundStream(
flowID, streamID, flowStreamDefaultTimeout)
if err != nil {
return err
}
log.VEventf(ctx, 1, "connected inbound stream %s/%d", flowID.Short(), streamID)
defer ds.flowRegistry.FinishInboundStream(streamInfo)
return ProcessInboundStream(f.AnnotateCtx(ctx), stream, msg, streamInfo.receiver)
defer cleanup()
return ProcessInboundStream(f.AnnotateCtx(ctx), stream, msg, receiver)
}

// FlowStream is part of the DistSQLServer interface.
Expand Down

0 comments on commit 24cd6fa

Please sign in to comment.