From 3d3fe6cc1a0e449c248e6ee62617fcf0f5c87fc3 Mon Sep 17 00:00:00 2001
From: Andrei Matei
Date: Tue, 28 Mar 2017 13:46:04 -0400
Subject: [PATCH] rpc: don't close gRPC connections on heartbeat timeouts

Fixes #13989

Before this patch, the rpc.Context would perform heartbeats (a dedicated
RPC) to check whether a connection is healthy. If the heartbeats failed,
the connection was closed (causing in-flight RPCs to fail) and the node
was marked as unhealthy.
These heartbeats, being regular RPCs, were subject to gRPC's flow
control, so they were easily blocked by other large RPCs (in particular,
by large DistSQL streams), which made them an unreliable signal of
connection health.

This patch moves to using gRPC's internal HTTP/2 ping frames for
checking connection health. These are not subject to flow control. The
gRPC transport-level connection is still closed when they fail (so
in-flight RPCs still fail eagerly), but otherwise gRPC reconnects
transparently.
Heartbeats stay for their other current uses: clock skew detection and
node health marking. Marking a node as unhealthy is debatable, given the
shortcomings of these RPCs. However, this marking currently doesn't have
big consequences: it only affects the order in which replicas are tried
when a leaseholder is unknown.
---
 pkg/rpc/context.go        |  33 ++-
 pkg/rpc/context_test.go   | 128 ++++++++++++
 pkg/testutils/net.go      | 427 ++++++++++++++++++++++++++++++++++++++
 pkg/testutils/net_test.go | 421 +++++++++++++++++++++++++++++++++++++
 4 files changed, 989 insertions(+), 20 deletions(-)
 create mode 100644 pkg/testutils/net.go
 create mode 100644 pkg/testutils/net_test.go

diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go
index f2e5ac2d7fff..b6a4ecf73cbb 100644
--- a/pkg/rpc/context.go
+++ b/pkg/rpc/context.go
@@ -28,8 +28,8 @@ import (
 	"github.com/rubyist/circuitbreaker"
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
-	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/keepalive"
 
 	"github.com/cockroachdb/cockroach/pkg/base"
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
@@ -247,7 +247,7 @@ func (ctx *Context) GRPCDial(target string, opts ...grpc.DialOption) (*grpc.Clie
 		dialOpt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))
 	}
 
-	dialOpts := make([]grpc.DialOption, 0, 2+len(opts))
+	var dialOpts []grpc.DialOption
 	dialOpts = append(dialOpts, dialOpt)
 	dialOpts = append(dialOpts, grpc.WithBackoffMaxDelay(maxBackoff))
 	dialOpts = append(dialOpts, grpc.WithDecompressor(snappyDecompressor{}))
@@ -256,6 +256,17 @@ func (ctx *Context) GRPCDial(target string, opts ...grpc.DialOption) (*grpc.Clie
 	if ctx.rpcCompression {
 		dialOpts = append(dialOpts, grpc.WithCompressor(snappyCompressor{}))
 	}
+	dialOpts = append(dialOpts, grpc.WithKeepaliveParams(keepalive.ClientParameters{
+		// Send periodic pings on the connection.
+		Time: base.NetworkTimeout,
+		// If the pings don't get a response within the timeout, we might be
+		// experiencing a network partition. gRPC will close the transport-level
+		// connection and all the pending RPCs (which may not have timeouts) will
+		// fail eagerly. gRPC will then reconnect the transport transparently.
+		Timeout: base.NetworkTimeout,
+		// Do the pings even when there are no ongoing RPCs.
+		PermitWithoutStream: true,
+	}))
 	dialOpts = append(dialOpts, opts...)
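
NOTE (illustrative sketch, not part of the patch): the hunk above is where
keepalive-based health checking gets wired into rpc.Context.GRPCDial. For
reference, a minimal standalone version of the same client-side mechanism is
sketched below; the package name, function name, target, intervals, and the
insecure transport are made up for illustration, while the real code uses
base.NetworkTimeout and the Context's TLS credentials.

package example

import (
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
)

// dialWithKeepalive opens a connection on which gRPC sends HTTP/2 ping frames
// even while no RPCs are active. If a ping is not acknowledged within Timeout,
// gRPC closes the transport (failing in-flight RPCs eagerly) and then
// transparently reconnects.
func dialWithKeepalive(target string) (*grpc.ClientConn, error) {
	return grpc.Dial(target,
		grpc.WithInsecure(), // illustration only; production code uses TLS credentials
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:                10 * time.Second, // ping after 10s of inactivity
			Timeout:             3 * time.Second,  // wait this long for the ping ack
			PermitWithoutStream: true,             // ping even with no active RPCs
		}),
	)
}
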
if SourceAddr != nil { @@ -359,24 +370,6 @@ func (ctx *Context) runHeartbeat(meta *connMeta, remoteAddr string) error { meta.heartbeatErr = err ctx.conns.Unlock() - // If we got a timeout, we might be experiencing a network partition. We - // close the connection so that all other pending RPCs (which may not have - // timeouts) fail eagerly. Any other error is likely to be noticed by - // other RPCs, so it's OK to leave the connection open while grpc - // internally reconnects if necessary. - // - // NB: This check is skipped when the connection is initiated from a CLI - // client since those clients aren't sensitive to partitions, are likely - // to be invoked while the server is starting (particularly in tests), and - // are not equipped with the retry logic necessary to deal with this - // connection termination. - // - // TODO(tamird): That we rely on the zero maxOffset to indicate a CLI - // client is a hack; we should do something more explicit. - if maxOffset != 0 && grpc.Code(err) == codes.DeadlineExceeded { - return err - } - // HACK: work around https://github.com/grpc/grpc-go/issues/1026 // Getting a "connection refused" error from the "write" system call // has confused grpc's error handling and this connection is permanently diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index 6ea87f8bf055..e20b50ba17c1 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -17,6 +17,7 @@ package rpc import ( + "math" "net" "runtime" "sync" @@ -25,9 +26,11 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util" @@ -551,3 +554,128 @@ func TestRemoteOffsetUnhealthy(t *testing.T) { } } } + +// This is a smoketest for gRPC Keepalives: rpc.Context asks gRPC to perform +// periodic pings on the transport to check that it's still alive. If the ping +// doesn't get a pong within a timeout, the transport is supposed to be closed - +// that's what we're testing here. +func TestGRPCKeepaliveFailureFailsInflightRPCs(t *testing.T) { + defer leaktest.AfterTest(t)() + + stopper := stop.NewStopper() + defer stopper.Stop() + + clock := hlc.NewClock(time.Unix(0, 20).UnixNano, time.Nanosecond) + serverCtx := NewContext( + log.AmbientContext{}, + testutils.NewNodeTestBaseContext(), + clock, + stopper, + ) + s, ln := newTestServer(t, serverCtx, true) + remoteAddr := ln.Addr().String() + + RegisterHeartbeatServer(s, &HeartbeatService{ + clock: clock, + remoteClockMonitor: serverCtx.RemoteClocks, + }) + + clientCtx := NewContext( + log.AmbientContext{}, testutils.NewNodeTestBaseContext(), clock, stopper) + // Disable automatic heartbeats. We'll send them by hand. + clientCtx.heartbeatInterval = math.MaxInt64 + + var firstConn int32 = 1 + + blockDialerFromCreatingConns := make(chan struct{}) + defer close(blockDialerFromCreatingConns) + + // We're going to open RPC transport connections using a dialer that returns + // PartitionableConns. We'll partition the first opened connection. + dialerCh := make(chan *testutils.PartitionableConn, 1) + conn, err := clientCtx.GRPCDial(remoteAddr, + grpc.WithDialer( + func(addr string, timeout time.Duration) (net.Conn, error) { + if !atomic.CompareAndSwapInt32(&firstConn, 1, 0) { + // If we allow gRPC to open a 2nd transport connection, then our RPCs + // might succeed if they're sent on that one. 
We'll block the + // connection opening, in the spirit of a partition. + <-blockDialerFromCreatingConns + return nil, errors.Errorf("the test only allows one connection") + } + + conn, err := net.DialTimeout("tcp", addr, timeout) + if err != nil { + return nil, err + } + transportConn := testutils.NewPartitionableConn(conn) + dialerCh <- transportConn + return transportConn, nil + }), + // Override the keepalive settings that the rpc.Context uses to more + // aggressive ones, so that the test doesn't take long. + grpc.WithKeepaliveParams( + keepalive.ClientParameters{ + // The aggressively low timeout we set here makes the connection very + // flaky for any RPC use, particularly when running under stress with -p + // 100. This test can't expect any RPCs to succeed reliably. + Time: time.Millisecond, + Timeout: 5 * time.Millisecond, + PermitWithoutStream: false, + }), + ) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn.Close() }() + + // We'll expect any of the errors which tests revealed that the RPC call might + // return when an RPC's transport connection is closed because of the + // heartbeats timing out. + gRPCErrorsRegex := "transport is closing|" + + "rpc error: code = Unavailable desc = grpc: the connection is unavailable|" + + "rpc error: code = Internal desc = transport: io: read/write on closed pipe|" + + "rpc error: code = Internal desc = transport: tls: use of closed connection|" + + "rpc error: code = Internal desc = transport: EOF|" + + "use of closed network connection" + + // Perform an RPC so that a connection gets opened. In theory this RPC should + // succeed (and it does when running without too much stress), but we can't + // rely on that - see comment on the timeout above. + heartbeatClient := NewHeartbeatClient(conn) + request := PingRequest{} + if _, err := heartbeatClient.Ping(context.TODO(), &request); err != nil { + if !testutils.IsError(err, gRPCErrorsRegex) { + t.Fatal(err) + } + // In the rare eventuality that we got the expected error, this test + // succeeded: even though we didn't partition the connection, the low gRPC + // heartbeats timeout caused our RPC to fail (happens under stress -p 100). + // If the heartbeats didn't timeout, we're going to simulate a network + // partition and then the heartbeats must timeout. + log.Infof(context.TODO(), "test returning early; no partition done") + return + } + + // Now partition client->server and attempt to perform an RPC. We expect it to + // fail once the grpc keepalive fails to get a response from the server. + + transportConn := <-dialerCh + defer transportConn.Finish() + + transportConn.PartitionC2S() + + if _, err := heartbeatClient.Ping(context.TODO(), &request); !testutils.IsError( + err, gRPCErrorsRegex) { + t.Fatal(err) + } + + // If the DialOptions we passed to gRPC didn't prevent it from opening new + // connections, then next RPCs would succeed since gRPC reconnects the + // transport (and that would succeed here since we've only partitioned one + // connection). We could further test that the status reported by + // Context.ConnHealth() for the remote node moves to UNAVAILABLE because of + // the (application-level) heartbeats performed by rpc.Context, but the + // behaviour of our heartbeats in the face of transport failures is + // sufficiently tested in TestHeartbeatHealthTransport. 
+} diff --git a/pkg/testutils/net.go b/pkg/testutils/net.go new file mode 100644 index 000000000000..029e8e1d3e6d --- /dev/null +++ b/pkg/testutils/net.go @@ -0,0 +1,427 @@ +// Copyright 2017 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// +// Author: Andrei Matei (andreimatei1@gmail.com) + +package testutils + +import ( + "io" + "net" + "sync" + + "github.com/pkg/errors" + "golang.org/x/net/context" + + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" +) + +// bufferSize is the size of the buffer used by PartitionableConn. Writes to a +// partitioned connection will block after the buffer gets filled. +const bufferSize = 16 << 10 // 16 KB + +// PartitionableConn is an implementation of net.Conn that allows the +// client->server and/or the server->client directions to be temporarily +// partitioned. +// +// A PartitionableConn wraps a provided net.Conn (the serverConn member) and +// forwards every read and write to it. It interposes an arbiter in front of it +// that's used to block reads/writes while the PartitionableConn is in the +// partitioned mode. +// +// While a direction is partitioned, data sent in that direction doesn't flow. A +// write while partitioned will block after an internal buffer gets filled. Data +// written to the conn after the partition has been established is not delivered +// to the remote party until the partition is lifted. At that time, all the +// buffered data is delivered. Since data is delivered async, data written +// before the partition is established may or may not be blocked by the +// partition; use application-level ACKs if that's important. +type PartitionableConn struct { + // We embed a net.Conn so that we inherit the interface. Note that we override + // Read() and Write(). + // + // This embedded Conn is half of a net.Pipe(). The other half is clientConn. + net.Conn + + clientConn net.Conn + serverConn net.Conn + + mu struct { + syncutil.Mutex + + // err, if set, is returned by any subsequent call to Read or Write. + err error + + // Are any of the two direction (client-to-server, server-to-client) + // currently partitioned? + c2sPartitioned bool + s2cPartitioned bool + + c2sBuffer buf + s2cBuffer buf + + // Conds to be signaled when the corresponding partition is lifted. + c2sWaiter *sync.Cond + s2cWaiter *sync.Cond + } +} + +type buf struct { + // A mutex used to synchronize access to all the fields. It will be set to the + // parent PartitionableConn's mutex. + *syncutil.Mutex + + data []byte + capacity int + closed bool + // The error that was passed to Close(err). See Close() for more info. + closedErr error + name string // A human-readable name, useful for debugging. + + // readerWait is signaled when a reader should wake up and check the buffer's + // state: when new data is put in the buffer, when the buffer is + // closed, and whenever the PartitionableConn wants to unblock all reads (i.e. + // on partition). 
+ readerWait *sync.Cond + + // capacityWait is signaled when a blocked writer should wake up because data + // is taken out of the buffer and there's now some capacity. It's also + // signaled when the buffer is closed. + capacityWait *sync.Cond +} + +func makeBuf(name string, capacity int, mu *syncutil.Mutex) buf { + b := buf{ + Mutex: mu, + name: name, + capacity: capacity, + } + b.readerWait = sync.NewCond(b.Mutex) + b.capacityWait = sync.NewCond(b.Mutex) + return b +} + +// Write adds data to the buffer. If there's zero free capacity, it will block +// until there's some capacity available or the buffer is closed. If there's +// non-zero insufficient capacity, it will perform a partial write. +// +// The number of bytes written is returned. +func (b *buf) Write(data []byte) (int, error) { + b.Lock() + defer b.Unlock() + for b.capacity == len(b.data) && !b.closed { + // Block for capacity. + b.capacityWait.Wait() + } + if b.closed { + return 0, b.closedErr + } + available := b.capacity - len(b.data) + toCopy := available + if len(data) < available { + toCopy = len(data) + } + b.data = append(b.data, data[:toCopy]...) + b.wakeReadersLocked() + return toCopy, nil +} + +// errEAgain is returned by buf.readLocked() when the read was blocked at the +// time when buf.readerWait was signalled (in particular, after the +// PartitionableConn interrupted the read because of a partition). The caller is +// expected to try the read again after the partition is gone. +var errEAgain = errors.New("try read again") + +// readLocked returns data from buf, up to "size" bytes. If there's no data in +// the buffer, it blocks until either some data becomes available or the buffer +// is closed. +func (b *buf) readLocked(size int) ([]byte, error) { + if len(b.data) == 0 && !b.closed { + b.readerWait.Wait() + // We were unblocked either by data arrving, or by a partition, or by + // another uninteresting reason. Return to the caller, in case it's because + // of a partition. + return nil, errEAgain + } + if b.closed && len(b.data) == 0 { + return nil, b.closedErr + } + var ret []byte + if len(b.data) < size { + ret = b.data + b.data = nil + } else { + ret = b.data[:size] + b.data = b.data[size:] + } + b.capacityWait.Broadcast() + return ret, nil +} + +// Close closes the buffer. All reads and writes that are currently blocked will +// be woken and they'll all return err. +func (b *buf) Close(err error) { + b.Lock() + b.closed = true + b.closedErr = err + b.readerWait.Broadcast() + b.capacityWait.Broadcast() + b.Unlock() +} + +// wakeReadersLocked wakes all the readers that are currently blocked. +// See comments on readerWait. +// +// This needs to be called while holding the buffer's mutex. +func (b *buf) wakeReadersLocked() { + b.readerWait.Broadcast() +} + +// NewPartitionableConn wraps serverConn in a PartitionableConn. +func NewPartitionableConn(serverConn net.Conn) *PartitionableConn { + clientEnd, clientConn := net.Pipe() + c := &PartitionableConn{ + Conn: clientEnd, + clientConn: clientConn, + serverConn: serverConn, + } + c.mu.c2sWaiter = sync.NewCond(&c.mu.Mutex) + c.mu.s2cWaiter = sync.NewCond(&c.mu.Mutex) + c.mu.c2sBuffer = makeBuf("c2sBuf", bufferSize, &c.mu.Mutex) + c.mu.s2cBuffer = makeBuf("s2cBuf", bufferSize, &c.mu.Mutex) + + // Start copying from client to server. 
+ go func() { + err := c.copy( + c.clientConn, // src + c.serverConn, // dst + &c.mu.c2sBuffer, + func() { // waitForNoPartitionLocked + for c.mu.c2sPartitioned { + c.mu.c2sWaiter.Wait() + } + }) + c.mu.Lock() + c.mu.err = err + c.mu.Unlock() + if err := c.clientConn.Close(); err != nil { + log.Errorf(context.TODO(), "unexpected error closing internal pipe: %s", err) + } + if err := c.serverConn.Close(); err != nil { + log.Errorf(context.TODO(), "error closing server conn: %s", err) + } + }() + + // Start copying from server to client. + go func() { + err := c.copy( + c.serverConn, // src + c.clientConn, // dst + &c.mu.s2cBuffer, + func() { // waitForNoPartitionLocked + for c.mu.s2cPartitioned { + c.mu.s2cWaiter.Wait() + } + }) + c.mu.Lock() + c.mu.err = err + c.mu.Unlock() + if err := c.clientConn.Close(); err != nil { + log.Fatalf(context.TODO(), "unexpected error closing internal pipe: %s", err) + } + if err := c.serverConn.Close(); err != nil { + log.Errorf(context.TODO(), "error closing server conn: %s", err) + } + }() + + return c +} + +// Finish removes any partitions that may exist so that blocked goroutines can +// finish. +// Finish() must be called if a connection may have been left in a partitioned +// state. +func (c *PartitionableConn) Finish() { + c.mu.Lock() + c.mu.c2sPartitioned = false + c.mu.c2sWaiter.Signal() + c.mu.s2cPartitioned = false + c.mu.s2cWaiter.Signal() + c.mu.Unlock() +} + +// PartitionC2S partitions the client-to-server direction. +// If UnpartitionC2S() is not called, Finish() must be called. +func (c *PartitionableConn) PartitionC2S() { + c.mu.Lock() + if c.mu.c2sPartitioned { + panic("already partitioned") + } + c.mu.c2sPartitioned = true + c.mu.c2sBuffer.wakeReadersLocked() + c.mu.Unlock() +} + +// UnpartitionC2S lifts an existing client-to-server partition. +func (c *PartitionableConn) UnpartitionC2S() { + c.mu.Lock() + if !c.mu.c2sPartitioned { + panic("not partitioned") + } + c.mu.c2sPartitioned = false + c.mu.c2sWaiter.Signal() + c.mu.Unlock() +} + +// PartitionS2C partitions the server-to-client direction. +// If UnpartitionS2C() is not called, Finish() must be called. +func (c *PartitionableConn) PartitionS2C() { + c.mu.Lock() + if c.mu.s2cPartitioned { + panic("already partitioned") + } + c.mu.s2cPartitioned = true + c.mu.s2cBuffer.wakeReadersLocked() + c.mu.Unlock() +} + +// UnpartitionS2C lifts an existing server-to-client partition. +func (c *PartitionableConn) UnpartitionS2C() { + c.mu.Lock() + if !c.mu.s2cPartitioned { + panic("not partitioned") + } + c.mu.s2cPartitioned = false + c.mu.s2cWaiter.Signal() + c.mu.Unlock() +} + +// Read is part of the net.Conn interface. +func (c *PartitionableConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + err = c.mu.err + c.mu.Unlock() + if err != nil { + return 0, err + } + + // Forward to the embedded connection. + return c.Conn.Read(b) +} + +// Write is part of the net.Conn interface. +func (c *PartitionableConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + err = c.mu.err + c.mu.Unlock() + if err != nil { + return 0, err + } + + // Forward to the embedded connection. + return c.Conn.Write(b) +} + +// readFrom copies data from src into the buffer until src.Read() returns an +// error (e.g. io.EOF). That error is returned. 
+// +// readFrom is written in the spirit of interface io.ReaderFrom, except it +// returns the io.EOF error, and also doesn't guarantee that every byte that has +// been read from src is put into the buffer (as the buffer allows concurrent +// access and buf.Write can return an error). +func (b *buf) readFrom(src io.Reader) error { + data := make([]byte, 1024) + for { + nr, err := src.Read(data) + if err != nil { + return err + } + toSend := data[:nr] + for { + nw, ew := b.Write(toSend) + if ew != nil { + return ew + } + if nw == len(toSend) { + break + } + toSend = toSend[nw:] + } + } +} + +// copyFromBuffer copies data from src to dst until src.Read() returns EOF. +// The EOF is returned (i.e. the return value is always != nil). This is because +// the PartitionableConn wants to hold on to any error, including EOF. +// +// waitForNoPartitionLocked is a function to be called before consuming data +// from src, in order to make sure that we only consume data when we're not +// partitioned. It needs to be called under src.Mutex, as the check needs to be +// done atomically with consuming the buffer's data. +func (c *PartitionableConn) copyFromBuffer( + src *buf, dst net.Conn, waitForNoPartitionLocked func(), +) error { + for { + // Don't read from the buffer while we're partitioned. + src.Mutex.Lock() + waitForNoPartitionLocked() + data, err := src.readLocked(1024 * 1024) + src.Mutex.Unlock() + + if len(data) > 0 { + nw, ew := dst.Write(data) + if ew != nil { + err = ew + } + if len(data) != nw { + err = io.ErrShortWrite + } + } else if err == nil { + err = io.EOF + } else if err == errEAgain { + continue + } + if err != nil { + return err + } + } +} + +// copy copies data from src to dst while we're not partitioned and stops doing +// so while partitioned. +// +// It runs two goroutines internally: one copying from src to an internal buffer +// and one copying from the buffer to dst. The 2nd one deals with partitions. +func (c *PartitionableConn) copy( + src net.Conn, dst net.Conn, buf *buf, waitForNoPartitionLocked func(), +) error { + tasks := make(chan error) + go func() { + err := buf.readFrom(src) + buf.Close(err) + tasks <- err + }() + go func() { + err := c.copyFromBuffer(buf, dst, waitForNoPartitionLocked) + buf.Close(err) + tasks <- err + }() + err := <-tasks + err2 := <-tasks + if err == nil { + err = err2 + } + return err +} diff --git a/pkg/testutils/net_test.go b/pkg/testutils/net_test.go new file mode 100644 index 000000000000..fecf68c3c802 --- /dev/null +++ b/pkg/testutils/net_test.go @@ -0,0 +1,421 @@ +// Copyright 2017 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. 
+// +// Author: Andrei Matei (andreimatei1@gmail.com) + +package testutils + +import ( + "bufio" + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/pkg/errors" + "golang.org/x/net/context" + + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/grpcutil" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/netutil" +) + +// RunEchoServer runs a network server that accepts one connection from ln and +// echos the data sent on it. +// +// If serverSideCh != nil, every slice of data received by the server is also +// sent on this channel before being echoed back on the connection it came on. +// Useful to observe what the server has received when this server is used with +// partitioned connections. +func RunEchoServer(ln net.Listener, serverSideCh chan<- []byte) error { + conn, err := ln.Accept() + if err != nil { + if grpcutil.IsClosedConnection(err) { + return nil + } + return err + } + if _, err := copyWithSideChan(conn, conn, serverSideCh); err != nil { + log.Warning(context.TODO(), err) + } +} + +// copyWithSideChan is like io.Copy(), but also takes a channel on which data +// read from src is sent before being written to dst. +func copyWithSideChan(dst io.Writer, src io.Reader, ch chan<- []byte) (written int64, err error) { + buf := make([]byte, 32*1024) + for { + nr, er := src.Read(buf) + if nr > 0 { + if ch != nil { + ch <- buf[:nr] + } + + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +} + +func TestPartitionableConnBasic(t *testing.T) { + defer leaktest.AfterTest(t)() + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + go func() { + if err := RunEchoServer(ln, nil); err != nil { + t.Error(err) + } + }() + defer func() { + netutil.FatalIfUnexpected(ln.Close()) + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + exp := "let's see if this value comes back\n" + fmt.Fprintf(pConn, exp) + got, err := bufio.NewReader(pConn).ReadString('\n') + if err != nil { + t.Fatal(err) + } + if got != exp { + t.Fatalf("expecting: %q , got %q", exp, got) + } +} + +func TestPartitionableConnPartitionC2S(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + serverSideCh := make(chan []byte) + go func() { + if err := RunEchoServer(ln, serverSideCh); err != nil { + t.Error(err) + } + }() + defer func() { + netutil.FatalIfUnexpected(ln.Close()) + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + // Partition the client->server connection. Afterwards, we're going to send + // something and assert that the server doesn't get it (within a timeout) by + // snooping on the server's side channel. Then we'll resolve the partition and + // expect that the server gets the message that was pending and echoes it + // back. + + pConn.PartitionC2S() + + // Client sends data. 
+ exp := "let's see when this value comes back\n" + fmt.Fprintf(pConn, exp) + + // In the background, the client waits on a read. + clientDoneCh := make(chan error) + go func() { + clientDoneCh <- func() error { + got, err := bufio.NewReader(pConn).ReadString('\n') + if err != nil { + return err + } + if got != exp { + return errors.Errorf("expecting: %q , got %q", exp, got) + } + return nil + }() + }() + + timerDoneCh := make(chan error) + time.AfterFunc(3*time.Millisecond, func() { + var err error + select { + case err = <-clientDoneCh: + err = errors.Errorf("unexpected reply while partitioned: %v", err) + case buf := <-serverSideCh: + err = errors.Errorf("server was not supposed to have received data while partitioned: %q", buf) + default: + } + timerDoneCh <- err + }) + + if err := <-timerDoneCh; err != nil { + t.Fatal(err) + } + + // Now unpartition and expect the pending data to be sent and a reply to be + // received. + + pConn.UnpartitionC2S() + + // Expect the server to receive the data. + <-serverSideCh + + if err := <-clientDoneCh; err != nil { + t.Fatal(err) + } +} + +func TestPartitionableConnPartitionS2C(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + serverSideCh := make(chan []byte) + go func() { + if err := RunEchoServer(ln, serverSideCh); err != nil { + t.Error(err) + } + }() + defer func() { + netutil.FatalIfUnexpected(ln.Close()) + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + // We're going to partition the server->client connection. Then we'll send + // some data and assert that the server gets it (by snooping on the server's + // side-channel). Then we'll assert that the client doesn't get the reply + // (with a timeout). Then we resolve the partition and assert that the client + // gets the reply. + + pConn.PartitionS2C() + + // Client sends data. + exp := "let's see when this value comes back\n" + fmt.Fprintf(pConn, exp) + + if s := <-serverSideCh; string(s) != exp { + t.Fatalf("expected server to receive %q, got %q", exp, s) + } + + // In the background, the client waits on a read. + clientDoneCh := make(chan error) + go func() { + clientDoneCh <- func() error { + got, err := bufio.NewReader(pConn).ReadString('\n') + if err != nil { + return err + } + if got != exp { + return errors.Errorf("expecting: %q , got %q", exp, got) + } + return nil + }() + }() + + // Check that the client does not get the server's response. + time.AfterFunc(3*time.Millisecond, func() { + select { + case err := <-clientDoneCh: + t.Errorf("unexpected reply while partitioned: %v", err) + default: + } + }) + + // Now unpartition and expect the pending data to be sent and a reply to be + // received. + + pConn.UnpartitionS2C() + + if err := <-clientDoneCh; err != nil { + t.Fatal(err) + } +} + +// Test that, while partitioned, a sender doesn't block while the internal +// buffer is not full. +func TestPartitionableConnBuffering(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + + // In the background, the server reads everything. 
+ exp := 5 * (bufferSize / 10) + serverDoneCh := make(chan error) + go func() { + serverDoneCh <- func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + received := 0 + for { + data := make([]byte, 1024*1024) + nr, err := conn.Read(data) + if err != nil { + if err == io.EOF { + break + } + return err + } + received += nr + } + if received != exp { + return errors.Errorf("server expecting: %d , got %d", exp, received) + } + return nil + }() + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + pConn.PartitionC2S() + defer pConn.Finish() + + // Send chunks such that they don't add up to the buffer size exactly. + data := make([]byte, bufferSize/10) + for i := 0; i < 5; i++ { + nw, err := pConn.Write(data) + if err != nil { + t.Fatal(err) + } + if nw != len(data) { + t.Fatal("unexpected partial write; PartitionableConn always writes fully") + } + } + pConn.UnpartitionC2S() + pConn.Close() + + if err := <-serverDoneCh; err != nil { + t.Fatal(err) + } +} + +// Test that, while partitioned, a party can close the connection and the other +// party will not observe this until after the partition is lifted. +func TestPartitionableConnCloseDeliveredAfterPartition(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + + // In the background, the server reads everything. + serverDoneCh := make(chan error) + go func() { + serverDoneCh <- func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + received := 0 + for { + data := make([]byte, 1024*1024) + nr, err := conn.Read(data) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + received += nr + } + }() + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + pConn.PartitionC2S() + defer pConn.Finish() + + pConn.Close() + + timerDoneCh := make(chan error) + time.AfterFunc(3*time.Millisecond, func() { + var err error + select { + case err = <-serverDoneCh: + err = errors.Errorf("server was not supposed to see the closing while partitioned: %v", err) + default: + } + timerDoneCh <- err + }) + + if err := <-timerDoneCh; err != nil { + t.Fatal(err) + } + + pConn.UnpartitionC2S() + + if err := <-serverDoneCh; err != nil { + t.Fatal(err) + } +}
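
NOTE (illustrative sketch, not part of the patch): a minimal example of how a
test typically drives the PartitionableConn introduced in pkg/testutils/net.go
above. The package name, function name, server address, and payload are made
up; error handling is reduced to returning the error.

package example

import (
	"net"

	"github.com/cockroachdb/cockroach/pkg/testutils"
)

func partitionExample(serverAddr string) error {
	// Dial the real server, then wrap the connection so that either direction
	// can be partitioned on demand.
	rawConn, err := net.Dial("tcp", serverAddr)
	if err != nil {
		return err
	}
	pConn := testutils.NewPartitionableConn(rawConn)
	defer pConn.Close()
	// Finish lifts any partition still in place so the internal copying
	// goroutines can unblock and exit.
	defer pConn.Finish()

	// While the client->server direction is partitioned, writes are absorbed
	// by the conn's internal buffer (16 KB) instead of reaching the server.
	pConn.PartitionC2S()
	if _, err := pConn.Write([]byte("hello\n")); err != nil {
		return err
	}

	// Lifting the partition delivers the buffered bytes to the server.
	pConn.UnpartitionC2S()
	return nil
}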