diff --git a/dht_net.go b/dht_net.go index f6261089b..0af33f38f 100644 --- a/dht_net.go +++ b/dht_net.go @@ -21,11 +21,6 @@ import ( var dhtReadMessageTimeout = time.Minute var ErrReadTimeout = fmt.Errorf("timed out reading response") -type bufferedWriteCloser interface { - ggio.WriteCloser - Flush() error -} - // The Protobuf writer performs multiple small writes when writing a message. // We need to buffer those writes, to make sure that we're not sending a new // packet for every single write. @@ -34,12 +29,26 @@ type bufferedDelimitedWriter struct { ggio.WriteCloser } -func newBufferedDelimitedWriter(str io.Writer) bufferedWriteCloser { - w := bufio.NewWriter(str) - return &bufferedDelimitedWriter{ - Writer: w, - WriteCloser: ggio.NewDelimitedWriter(w), +var writerPool = sync.Pool{ + New: func() interface{} { + w := bufio.NewWriter(nil) + return &bufferedDelimitedWriter{ + Writer: w, + WriteCloser: ggio.NewDelimitedWriter(w), + } + }, +} + +func writeMsg(w io.Writer, mes *pb.Message) error { + bw := writerPool.Get().(*bufferedDelimitedWriter) + bw.Reset(w) + err := bw.WriteMsg(mes) + if err == nil { + err = bw.Flush() } + bw.Reset(nil) + writerPool.Put(bw) + return err } func (w *bufferedDelimitedWriter) Flush() error { @@ -62,7 +71,6 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) bool { cr := ctxio.NewReader(ctx, s) // ok to use. we defer close stream in this func cw := ctxio.NewWriter(ctx, s) // ok to use. we defer close stream in this func r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax) - w := newBufferedDelimitedWriter(cw) mPeer := s.Conn().RemotePeer() for { @@ -118,10 +126,7 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) bool { } // send out response msg - err = w.WriteMsg(resp) - if err == nil { - err = w.Flush() - } + err = writeMsg(cw, resp) if err != nil { stats.Record(ctx, metrics.ReceivedMessageErrors.M(1)) logger.Debugf("error writing response: %v", err) @@ -237,7 +242,6 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa type messageSender struct { s inet.Stream r ggio.ReadCloser - w bufferedWriteCloser lk sync.Mutex p peer.ID dht *IpfsDHT @@ -281,7 +285,6 @@ func (ms *messageSender) prep(ctx context.Context) error { } ms.r = ggio.NewDelimitedReader(nstr, inet.MessageSizeMax) - ms.w = newBufferedDelimitedWriter(nstr) ms.s = nstr return nil @@ -377,10 +380,7 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb } func (ms *messageSender) writeMsg(pmes *pb.Message) error { - if err := ms.w.WriteMsg(pmes); err != nil { - return err - } - return ms.w.Flush() + return writeMsg(ms.s, pmes) } func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {