Skip to content

Commit

Permalink
Merge pull request #83 from merlimat/ctx-compress
Browse files Browse the repository at this point in the history
Added Ctx compress/decompress
  • Loading branch information
Viq111 authored Jun 17, 2020
2 parents 0e71ac6 + fda5922 commit 89f69fb
Show file tree
Hide file tree
Showing 2 changed files with 328 additions and 0 deletions.
156 changes: 156 additions & 0 deletions zstd_ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package zstd

/*
#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "stdint.h" // for uintptr_t
// The following *_wrapper functions are used for removing superfluous
// memory allocations when calling the wrapped functions from Go code.
// See https://github.com/golang/go/issues/24450 for details.
static size_t ZSTD_compressCCtx_wrapper(ZSTD_CCtx* cctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
return ZSTD_compressCCtx(cctx, (void*)dst, maxDstSize, (const void*)src, srcSize, compressionLevel);
}
static size_t ZSTD_decompressDCtx_wrapper(ZSTD_DCtx* dctx, uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
return ZSTD_decompressDCtx(dctx, (void*)dst, maxDstSize, (const void *)src, srcSize);
}
*/
import "C"
import (
"bytes"
"io/ioutil"
"runtime"
"unsafe"
)

// Ctx groups the one-shot compression and decompression operations offered
// by a reusable zstd context. Values are obtained from NewCtx.
type Ctx interface {
// Compress compresses src into dst. If you have a buffer to use, you can
// pass it to prevent allocation. If it is too small, or if nil is passed,
// a new buffer will be allocated and returned.
Compress(dst, src []byte) ([]byte, error)

// CompressLevel is the same as Compress, but the caller chooses the
// compression level.
CompressLevel(dst, src []byte, level int) ([]byte, error)

// Decompress decompresses src into dst. If you have a buffer to use, you
// can pass it to prevent allocation. If it is too small, or if nil is
// passed, a new buffer will be allocated and returned.
Decompress(dst, src []byte) ([]byte, error)
}

// ctx is the Ctx implementation in this file. It holds one zstd compression
// context and one zstd decompression context; both are created in NewCtx and
// freed in finalizeCtx.
type ctx struct {
// cctx is the handle passed to ZSTD_compressCCtx by CompressLevel.
cctx *C.ZSTD_CCtx
// dctx is the handle passed to ZSTD_decompressDCtx by Decompress.
dctx *C.ZSTD_DCtx
}

// NewCtx creates a new zstd context holding one compression context and one
// decompression context.
//
// When compressing or decompressing many times, it is recommended to allocate
// a context just once and re-use it for each successive operation, which makes
// the workload friendlier to the system's memory. Re-using a context is only a
// speed / resource optimization: it does not change the compression ratio,
// which remains identical.
//
// In multi-threaded environments, use a different context per thread for
// parallel execution.
//
// finalizeCtx is registered as the finalizer of the returned value, so the
// underlying C contexts are freed when the garbage collector finalizes it.
func NewCtx() Ctx {
	c := new(ctx)
	c.cctx = C.ZSTD_createCCtx()
	c.dctx = C.ZSTD_createDCtx()
	runtime.SetFinalizer(c, finalizeCtx)
	return c
}

// Compress compresses src by delegating to CompressLevel with
// DefaultCompression. dst is forwarded unchanged as the optional output
// buffer, and the result of CompressLevel is returned as-is.
func (c *ctx) Compress(dst, src []byte) ([]byte, error) {
return c.CompressLevel(dst, src, DefaultCompression)
}

// CompressLevel compresses src at the given level using the context's
// compression context and returns the compressed bytes.
//
// If cap(dst) is at least CompressBound(len(src)), the backing array of dst is
// reused for the output; otherwise a new buffer of that size is allocated.
// When getError reports a failure for the value returned by zstd, the
// returned slice is nil and that error is returned.
func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) {
	bound := CompressBound(len(src))
	if cap(dst) >= bound {
		dst = dst[0:bound] // Reuse dst buffer
	} else {
		dst = make([]byte, bound)
	}

	srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
	if len(src) > 0 {
		srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
	}

	// NOTE(review): &dst[0] assumes CompressBound never returns 0 — confirm.
	cWritten := C.ZSTD_compressCCtx_wrapper(
		c.cctx,
		C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
		C.size_t(len(dst)),
		srcPtr,
		C.size_t(len(src)),
		C.int(level))

	// src crosses into C as a plain integer, so it must be kept reachable
	// explicitly until the call has returned.
	runtime.KeepAlive(src)
	// c has a finalizer that frees c.cctx. Without this, c could become
	// unreachable right after c.cctx is read above and be finalized while the
	// C call is still using the context.
	runtime.KeepAlive(c)

	written := int(cWritten)
	// Check if the return is an Error code
	if err := getError(written); err != nil {
		return nil, err
	}
	return dst[:written], nil
}


// Decompress decompresses src using the context's decompression context and
// returns the decompressed bytes.
//
// An empty src yields an empty slice and ErrEmptySlice. If dst has a non-zero
// length it is used as the first output buffer. Otherwise a buffer is sized
// from ZSTD_getDecompressedSize, falling back to 3*len(src) when that does not
// report a positive size. Whenever zstd reports that the destination is too
// small, a new buffer of twice the length is tried, for three attempts in
// total; after that the input is decoded through the streaming reader.
func (c *ctx) Decompress(dst, src []byte) ([]byte, error) {
	if len(src) == 0 {
		return []byte{}, ErrEmptySlice
	}

	// decompress runs one single-shot attempt of in into out.
	decompress := func(out, in []byte) ([]byte, error) {
		cWritten := C.ZSTD_decompressDCtx_wrapper(
			c.dctx,
			C.uintptr_t(uintptr(unsafe.Pointer(&out[0]))),
			C.size_t(len(out)),
			C.uintptr_t(uintptr(unsafe.Pointer(&in[0]))),
			C.size_t(len(in)))

		// in crosses into C as a plain integer, so it must be kept reachable
		// explicitly until the call has returned.
		runtime.KeepAlive(in)
		// c has a finalizer that frees c.dctx. Without this, c could be
		// finalized while the C call is still using the context.
		runtime.KeepAlive(c)

		written := int(cWritten)
		// Check error
		if err := getError(written); err != nil {
			return nil, err
		}
		return out[:written], nil
	}

	if len(dst) == 0 {
		// Attempt to use zStd to determine decompressed size (may result in error or 0)
		size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))

		if err := getError(size); err != nil {
			return nil, err
		}

		if size > 0 {
			dst = make([]byte, size)
		} else {
			dst = make([]byte, len(src)*3) // starting guess
		}
	}

	const maxAttempts = 3 // tries with an increasingly large buffer
	for attempt := 1; ; attempt++ {
		result, err := decompress(dst, src)
		if !IsDstSizeTooSmallError(err) {
			return result, err
		}
		if attempt == maxAttempts {
			// Do not allocate a bigger buffer that would never be used.
			break
		}
		dst = make([]byte, len(dst)*2) // Grow buffer by 2
	}

	// We failed getting a dst buffer of correct size, use stream API
	r := NewReader(bytes.NewReader(src))
	defer r.Close()
	return ioutil.ReadAll(r)
}

// finalizeCtx frees the C compression and decompression contexts held by c.
// NewCtx registers it as the finalizer of every ctx it creates.
func finalizeCtx(c *ctx) {
C.ZSTD_freeCCtx(c.cctx)
C.ZSTD_freeDCtx(c.dctx)
}
172 changes: 172 additions & 0 deletions zstd_ctx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package zstd

import (
"bytes"
"testing"
)

// TestCtxCompressDecompress round-trips a short input through one context,
// once with nil buffers and once with caller-supplied buffers.
func TestCtxCompressDecompress(t *testing.T) {
	ctx := NewCtx()

	input := []byte("Hello World!")
	out, err := ctx.Compress(nil, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}
	out2 := make([]byte, 1000)
	out2, err = ctx.Compress(out2, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}
	t.Logf("Compressed: %v", out)

	rein, err := ctx.Decompress(nil, out)
	if err != nil {
		t.Fatalf("Error while decompressing: %v", err)
	}
	// 10 bytes is smaller than the 12-byte input, so Decompress cannot use
	// this buffer as-is.
	rein2 := make([]byte, 10)
	rein2, err = ctx.Decompress(rein2, out2)
	if err != nil {
		t.Fatalf("Error while decompressing: %v", err)
	}

	if string(input) != string(rein) {
		t.Fatalf("Cannot compress and decompress: %s != %s", input, rein)
	}
	if string(input) != string(rein2) {
		// Report the buffer that was actually compared.
		t.Fatalf("Cannot compress and decompress: %s != %s", input, rein2)
	}
}

// TestCtxEmptySliceCompress checks that compressing an empty slice succeeds
// and that the result decompresses back to an empty payload.
func TestCtxEmptySliceCompress(t *testing.T) {
	ctx := NewCtx()

	compressed, err := ctx.Compress(nil, []byte{})
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}
	t.Logf("Compressing empty slice gives 0x%x", compressed)
	decompressed, err := ctx.Decompress(nil, compressed)
	if err != nil {
		// This failure comes from Decompress, so say so.
		t.Fatalf("Error while decompressing: %v", err)
	}
	if string(decompressed) != "" {
		t.Fatalf("Expected empty slice as decompressed, got %v instead", decompressed)
	}
}

// TestCtxEmptySliceDecompress checks that decompressing an empty slice
// returns ErrEmptySlice.
func TestCtxEmptySliceDecompress(t *testing.T) {
	ctx := NewCtx()

	_, err := ctx.Decompress(nil, []byte{})
	if err != ErrEmptySlice {
		// %v rather than %s: err may be nil here, which %s renders as
		// "%!s(<nil>)".
		t.Fatalf("Did not get the correct error: %v", err)
	}
}

// TestCtxDecompressZeroLengthBuf checks that Decompress works when given a
// non-nil destination buffer of length zero.
func TestCtxDecompressZeroLengthBuf(t *testing.T) {
	ctx := NewCtx()

	input := []byte("Hello World!")
	out, err := ctx.Compress(nil, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}

	buf := make([]byte, 0)
	decompressed, err := ctx.Decompress(buf, out)
	if err != nil {
		t.Fatalf("Error while decompressing: %v", err)
	}

	// The original input is the expectation; the decompressed bytes are the
	// result under test.
	if exp, res := string(input), string(decompressed); res != exp {
		t.Fatalf("expected %s but decompressed to %s", exp, res)
	}
}

// TestCtxTooSmall decompresses a large, repetitive payload into a one-byte
// destination buffer and expects the full input back.
func TestCtxTooSmall(t *testing.T) {
	ctx := NewCtx()

	input := bytes.Repeat([]byte("Hellow World!"), 10000)
	out, err := ctx.Compress(nil, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}

	// This should switch to the decompression stream to handle too small dst
	rein, err := ctx.Decompress(make([]byte, 1), out)
	if err != nil {
		t.Fatalf("Failed decompressing: %s", err)
	}
	if string(input) != string(rein) {
		t.Fatalf("Cannot compress and decompress: %s != %s", input, rein)
	}
}

// TestCtxRealPayload round-trips the package-level raw payload through a
// context. The test is skipped when raw is nil.
func TestCtxRealPayload(t *testing.T) {
	ctx := NewCtx()

	if raw == nil {
		t.Skip(ErrNoPayloadEnv)
	}

	compressed, err := ctx.Compress(nil, raw)
	if err != nil {
		t.Fatalf("Failed to compress: %s", err)
	}

	restored, err := ctx.Decompress(nil, compressed)
	if err != nil {
		t.Fatalf("Failed to decompress: %s", err)
	}

	if !bytes.Equal(raw, restored) {
		t.Fatalf("compressed/decompressed payloads are not the same (lengths: %v & %v)", len(raw), len(restored))
	}
}

// BenchmarkCtxCompression measures Compress on the package-level raw payload,
// reusing one context and one destination buffer of CompressBound(len(raw))
// bytes. It fails when raw is nil.
func BenchmarkCtxCompression(b *testing.B) {
	ctx := NewCtx()

	if raw == nil {
		b.Fatal(ErrNoPayloadEnv)
	}

	scratch := make([]byte, CompressBound(len(raw)))
	b.SetBytes(int64(len(raw)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := ctx.Compress(scratch, raw); err != nil {
			b.Fatalf("Failed compressing: %s", err)
		}
	}
}

// BenchmarkCtxDecompression measures Decompress on the compressed form of the
// package-level raw payload, reusing one context and one output buffer of
// len(raw) bytes. Each result is compared against raw with the timer stopped.
// It fails when raw is nil.
func BenchmarkCtxDecompression(b *testing.B) {
	ctx := NewCtx()

	if raw == nil {
		b.Fatal(ErrNoPayloadEnv)
	}

	scratch := make([]byte, len(raw))
	compressed, err := ctx.Compress(nil, raw)
	if err != nil {
		b.Fatalf("Failed compressing: %s", err)
	}
	b.Logf("Reduced from %v to %v", len(raw), len(compressed))

	b.SetBytes(int64(len(raw)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		restored, err := ctx.Decompress(scratch, compressed)
		if err != nil {
			b.Fatalf("Failed decompressing: %s", err)
		}

		// Keep the verification out of the measured time.
		b.StopTimer()
		if !bytes.Equal(raw, restored) {
			b.Fatalf("Results are not the same: %v != %v", len(raw), len(restored))
		}
		b.StartTimer()
	}
}

0 comments on commit 89f69fb

Please sign in to comment.