diff --git a/go.mod b/go.mod index 4e3c78e..02b140d 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/libp2p/go-buffer-pool + +go 1.12 diff --git a/pool_test.go b/pool_test.go index f517167..1941144 100644 --- a/pool_test.go +++ b/pool_test.go @@ -139,11 +139,27 @@ func BenchmarkPool(b *testing.B) { i = i << 1 } b := p.Get(i) + b[0] = byte(i) p.Put(b) } }) } +func BenchmarkAlloc(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + i := 7 + for pb.Next() { + if i > 1<<20 { + i = 7 + } else { + i = i << 1 + } + b := make([]byte, i) + b[1] = byte(i) + } + }) +} + func BenchmarkPoolOverlflow(b *testing.B) { var p BufferPool b.RunParallel(func(pb *testing.PB) { diff --git a/writer.go b/writer.go new file mode 100644 index 0000000..cea83f9 --- /dev/null +++ b/writer.go @@ -0,0 +1,119 @@ +package pool + +import ( + "bufio" + "io" + "sync" +) + +const WriterBufferSize = 4096 + +var bufioWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriterSize(nil, WriterBufferSize) + }, +} + +// Writer is a buffered writer that returns its internal buffer in a pool when +// not in use. +type Writer struct { + W io.Writer + bufw *bufio.Writer +} + +func (w *Writer) ensureBuffer() { + if w.bufw == nil { + w.bufw = bufioWriterPool.Get().(*bufio.Writer) + w.bufw.Reset(w.W) + } +} + +// Write writes the given byte slice to the underlying connection. +// +// Note: Write won't return the write buffer to the pool even if it ends up +// being empty after the write. You must call Flush() to do that. +func (w *Writer) Write(b []byte) (int, error) { + if w.bufw == nil { + if len(b) >= WriterBufferSize { + return w.W.Write(b) + } + w.bufw = bufioWriterPool.Get().(*bufio.Writer) + w.bufw.Reset(w.W) + } + return w.bufw.Write(b) +} + +// Size returns the size of the underlying buffer. +func (w *Writer) Size() int { + return WriterBufferSize +} + +// Available returns the amount buffer space available. +func (w *Writer) Available() int { + if w.bufw != nil { + return w.bufw.Available() + } + return WriterBufferSize +} + +// Buffered returns the amount of data buffered. +func (w *Writer) Buffered() int { + if w.bufw != nil { + return w.bufw.Buffered() + } + return 0 +} + +// WriteByte writes a single byte. +func (w *Writer) WriteByte(b byte) error { + w.ensureBuffer() + return w.bufw.WriteByte(b) +} + +// WriteRune writes a single rune, returning the number of bytes written. +func (w *Writer) WriteRune(r rune) (int, error) { + w.ensureBuffer() + return w.bufw.WriteRune(r) +} + +// WriteString writes a string, returning the number of bytes written. +func (w *Writer) WriteString(s string) (int, error) { + w.ensureBuffer() + return w.bufw.WriteString(s) +} + +// Flush flushes the write buffer, if any, and returns it to the pool. +func (w *Writer) Flush() error { + if w.bufw == nil { + return nil + } + if err := w.bufw.Flush(); err != nil { + return err + } + w.bufw.Reset(nil) + bufioWriterPool.Put(w.bufw) + w.bufw = nil + return nil +} + +// Close flushes the underlying writer and closes it if it implements the +// io.Closer interface. +// +// Note: Close() closes the writer even if Flush() fails to avoid leaking system +// resources. If you want to make sure Flush() succeeds, call it first. +func (w *Writer) Close() error { + var ( + ferr, cerr error + ) + ferr = w.Flush() + + // always close even if flush fails. + if closer, ok := w.W.(io.Closer); ok { + cerr = closer.Close() + } + + if ferr != nil { + return ferr + } + return cerr +} diff --git a/writer_test.go b/writer_test.go new file mode 100644 index 0000000..ae57520 --- /dev/null +++ b/writer_test.go @@ -0,0 +1,91 @@ +package pool + +import ( + "bytes" + "testing" +) + +func checkSize(t *testing.T, w *Writer) { + if w.Size()-w.Buffered() != w.Available() { + t.Fatalf("size (%d), buffered (%d), available (%d) mismatch", w.Size(), w.Buffered(), w.Available()) + } +} + +func TestWriter(t *testing.T) { + var b bytes.Buffer + w := Writer{W: &b} + n, err := w.Write([]byte("foobar")) + checkSize(t, &w) + + if err != nil || n != 6 { + t.Fatalf("write failed: %d, %s", n, err) + } + if b.Len() != 0 { + t.Fatal("expected the buffer to be empty") + } + if w.Buffered() != 6 { + t.Fatalf("expected 6 bytes to be buffered, got %d", w.Buffered()) + } + checkSize(t, &w) + if err := w.Flush(); err != nil { + t.Fatal(err) + } + checkSize(t, &w) + if err := w.Flush(); err != nil { + t.Fatal(err) + } + checkSize(t, &w) + if b.String() != "foobar" { + t.Fatal("expected to have written foobar") + } + b.Reset() + + buf := make([]byte, WriterBufferSize) + n, err = w.Write(buf) + if n != WriterBufferSize || err != nil { + t.Fatalf("write failed: %d, %s", n, err) + } + checkSize(t, &w) + if b.Len() != WriterBufferSize { + t.Fatal("large write should have gone through directly") + } + if err := w.Flush(); err != nil { + t.Fatal(err) + } + checkSize(t, &w) + + b.Reset() + if err := w.WriteByte(1); err != nil { + t.Fatal(err) + } + if w.Buffered() != 1 { + t.Fatalf("expected 1 byte to be buffered, got %d", w.Buffered()) + } + if n, err := w.WriteRune('1'); err != nil || n != 1 { + t.Fatal(err) + } + if w.Buffered() != 2 { + t.Fatalf("expected 2 bytes to be buffered, got %d", w.Buffered()) + } + checkSize(t, &w) + if n, err := w.WriteString("foobar"); err != nil || n != 6 { + t.Fatal(err) + } + if w.Buffered() != 8 { + t.Fatalf("expected 8 bytes to be buffered, got %d", w.Buffered()) + } + checkSize(t, &w) + if b.Len() != 0 { + t.Fatal("write should have been buffered") + } + n, err = w.Write(buf) + if n != WriterBufferSize || err != nil { + t.Fatalf("write failed: %d, %s", n, err) + } + if b.Len() != WriterBufferSize || b.Bytes()[0] != 1 || b.String()[1:8] != "1foobar" { + t.Fatalf("failed to flush properly: len:%d, prefix:%#v", b.Len(), b.Bytes()[:10]) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } +}