Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Ctx compress/decompress #83

Merged
merged 3 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions zstd_ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package zstd

/*
#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "stdint.h" // for uintptr_t

// The following *_wrapper functions are used to remove superfluous
// memory allocations when calling the wrapped functions from Go code.
// See https://github.com/golang/go/issues/24450 for details.

// Forwards to ZSTD_compressCCtx. The dst and src buffer addresses arrive as
// uintptr_t values and are cast back to pointers before the call.
static size_t ZSTD_compressCCtx_wrapper(ZSTD_CCtx* cctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
return ZSTD_compressCCtx(cctx, (void*)dst, maxDstSize, (const void*)src, srcSize, compressionLevel);
}

// Forwards to ZSTD_decompressDCtx. The dst and src buffer addresses arrive as
// uintptr_t values and are cast back to pointers before the call.
static size_t ZSTD_decompressDCtx_wrapper(ZSTD_DCtx* dctx, uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
return ZSTD_decompressDCtx(dctx, (void*)dst, maxDstSize, (const void *)src, srcSize);
}

*/
import "C"
import (
"bytes"
"io/ioutil"
"runtime"
"unsafe"
)

// Ctx is a reusable zstd context offering one-shot compression and
// decompression of byte slices. Instances are created with NewCtx.
type Ctx interface {
// Compress src into dst. If you have a buffer to use, you can pass it to
// prevent allocation. If it is too small, or if nil is passed, a new buffer
// will be allocated and returned.
Compress(dst, src []byte) ([]byte, error)

// CompressLevel is the same as Compress, but lets the caller choose the
// compression level.
CompressLevel(dst, src []byte, level int) ([]byte, error)

// Decompress src into dst. If you have a buffer to use, you can pass it to
// prevent allocation. If it is too small, or if nil is passed, a new buffer
// will be allocated and returned.
Decompress(dst, src []byte) ([]byte, error)
}

// ctx is the Ctx implementation returned by NewCtx. It holds the native zstd
// contexts, which finalizeCtx frees.
type ctx struct {
cctx *C.ZSTD_CCtx // compression context, used by CompressLevel
dctx *C.ZSTD_DCtx // decompression context, used by Decompress
}

// NewCtx creates a new ZStd context.
Viq111 marked this conversation as resolved.
Show resolved Hide resolved
// When compressing/decompressing many times, it is recommended to allocate a
// context just once, and re-use it for each successive operation.
// This makes the workload friendlier to the system's memory.
// Note: re-using a context is just a speed / resource optimization.
// It doesn't change the compression ratio, which remains identical.
// Note 2: in multi-threaded environments,
// use a separate context per thread for parallel execution.
//
func NewCtx() Ctx {
	// Allocate the native compression and decompression contexts up front.
	zc := new(ctx)
	zc.cctx = C.ZSTD_createCCtx()
	zc.dctx = C.ZSTD_createDCtx()

	// The native contexts are not managed by the Go garbage collector, so a
	// finalizer is registered to free them with the wrapper.
	runtime.SetFinalizer(zc, finalizeCtx)
	return zc
}

// Compress compresses src at the DefaultCompression level. It is shorthand
// for CompressLevel(dst, src, DefaultCompression) and shares its buffer
// handling and error behaviour.
func (c *ctx) Compress(dst, src []byte) ([]byte, error) {
	var level int = DefaultCompression
	return c.CompressLevel(dst, src, level)
}

// CompressLevel compresses src at the given level using the context's
// compression state and returns the compressed bytes.
//
// dst is reused when its capacity is at least CompressBound(len(src));
// otherwise a new buffer of that size is allocated. On failure the zstd
// error is returned together with a nil slice.
func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) {
	need := CompressBound(len(src))
	if cap(dst) < need {
		dst = make([]byte, need)
	} else {
		dst = dst[:need] // Caller's buffer is large enough: reuse it.
	}

	// The zero value is kept for empty input so that no address is taken
	// from an empty slice.
	var srcAddr C.uintptr_t
	if len(src) != 0 {
		srcAddr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
	}

	n := int(C.ZSTD_compressCCtx_wrapper(
		c.cctx,
		C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
		C.size_t(len(dst)),
		srcAddr,
		C.size_t(len(src)),
		C.int(level)))

	// src was handed to C as an integer, so keep it reachable until the
	// call has returned.
	runtime.KeepAlive(src)

	// zstd reports failures through the returned size.
	if err := getError(n); err != nil {
		return nil, err
	}
	return dst[:n], nil
}


// Decompress decompresses src using the context's decompression state and
// returns the decompressed bytes.
//
// The whole capacity of dst is used as the output buffer, so a caller may
// pass buf[:0] to recycle a buffer (this mirrors how CompressLevel reuses
// dst). When dst has no capacity, the initial buffer size is taken from the
// frame header. That value comes from the input and is therefore untrusted:
// it is capped relative to len(src) so that a corrupt or hostile header
// cannot force a huge allocation. If the buffer turns out to be too small it
// is doubled, up to three attempts in total, after which the streaming reader
// is used.
//
// An empty src yields ErrEmptySlice. Other failures return the zstd error (or
// the stream reader's error) and no usable data.
func (c *ctx) Decompress(dst, src []byte) ([]byte, error) {
	const (
		maxAttempts  = 3           // one-shot attempts before falling back to streaming
		guessRatio   = 3           // initial buffer = guessRatio * len(src) when the size is unknown
		hintRatio    = 50          // trust header sizes up to hintRatio * len(src) ...
		minHintLimit = 1000 * 1000 // ... or up to this many bytes, whichever is larger
	)

	if len(src) == 0 {
		return []byte{}, ErrEmptySlice
	}

	// decompressInto performs a single one-shot decompression of src into buf.
	decompressInto := func(buf []byte) ([]byte, error) {
		cWritten := C.ZSTD_decompressDCtx_wrapper(
			c.dctx,
			C.uintptr_t(uintptr(unsafe.Pointer(&buf[0]))),
			C.size_t(len(buf)),
			C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
			C.size_t(len(src)))

		// Both buffers were handed to C as integers; keep them reachable
		// until the call has returned.
		runtime.KeepAlive(src)
		runtime.KeepAlive(buf)

		written := int(cWritten)
		if err := getError(written); err != nil {
			return nil, err
		}
		return buf[:written], nil
	}

	// Make the caller's full buffer available, not just its current length.
	dst = dst[:cap(dst)]

	if len(dst) == 0 {
		// Attempt to use zstd to determine the decompressed size (may
		// result in an error or 0).
		hint := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))
		if err := getError(hint); err != nil {
			return nil, err
		}

		limit := hintRatio * len(src)
		if limit < minHintLimit {
			limit = minHintLimit
		}

		switch {
		case hint <= 0:
			// Size unknown: start from a guess.
			dst = make([]byte, len(src)*guessRatio)
		case hint > limit:
			// Implausibly large for this input: do not trust it blindly.
			// The retry loop and the stream fallback still handle data
			// that really is this big.
			dst = make([]byte, limit)
		default:
			dst = make([]byte, hint)
		}
	}

	for attempt := 1; attempt <= maxAttempts; attempt++ {
		result, err := decompressInto(dst)
		if !IsDstSizeTooSmallError(err) {
			return result, err
		}
		// Only grow when another attempt will actually use the buffer.
		if attempt < maxAttempts {
			dst = make([]byte, len(dst)*2)
		}
	}

	// We failed getting a dst buffer of correct size, use the stream API.
	r := NewReader(bytes.NewReader(src))
	defer r.Close()
	return ioutil.ReadAll(r)
}

// finalizeCtx releases the native zstd contexts owned by zc. NewCtx registers
// it as the finalizer of every ctx it creates.
func finalizeCtx(zc *ctx) {
	compressor, decompressor := zc.cctx, zc.dctx
	C.ZSTD_freeCCtx(compressor)
	C.ZSTD_freeDCtx(decompressor)
}
172 changes: 172 additions & 0 deletions zstd_ctx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package zstd

import (
"bytes"
"testing"
)

// TestCtxCompressDecompress round-trips a short payload through a single
// context twice: once with nil destination buffers and once with
// caller-supplied ones (a large one for Compress, a too-small one for
// Decompress). Both round trips must reproduce the input.
func TestCtxCompressDecompress(t *testing.T) {
	ctx := NewCtx()

	input := []byte("Hello World!")
	out, err := ctx.Compress(nil, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}
	out2 := make([]byte, 1000)
	out2, err = ctx.Compress(out2, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}
	t.Logf("Compressed: %v", out)

	rein, err := ctx.Decompress(nil, out)
	if err != nil {
		t.Fatalf("Error while decompressing: %v", err)
	}
	rein2 := make([]byte, 10)
	rein2, err = ctx.Decompress(rein2, out2)
	if err != nil {
		t.Fatalf("Error while decompressing: %v", err)
	}

	if string(input) != string(rein) {
		t.Fatalf("Cannot compress and decompress: %s != %s", input, rein)
	}
	if string(input) != string(rein2) {
		// Report the value that was actually compared (rein2, not rein).
		t.Fatalf("Cannot compress and decompress: %s != %s", input, rein2)
	}
}

// TestCtxEmptySliceCompress checks that compressing an empty slice succeeds
// and that the resulting frame decompresses back to empty output.
func TestCtxEmptySliceCompress(t *testing.T) {
	ctx := NewCtx()

	compressed, err := ctx.Compress(nil, []byte{})
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}
	t.Logf("Compressing empty slice gives 0x%x", compressed)
	decompressed, err := ctx.Decompress(nil, compressed)
	if err != nil {
		// This is the decompression step; name it correctly in the failure.
		t.Fatalf("Error while decompressing: %v", err)
	}
	if string(decompressed) != "" {
		t.Fatalf("Expected empty slice as decompressed, got %v instead", decompressed)
	}
}

// TestCtxEmptySliceDecompress checks that decompressing an empty input is
// rejected with ErrEmptySlice.
func TestCtxEmptySliceDecompress(t *testing.T) {
	ctx := NewCtx()

	_, err := ctx.Decompress(nil, []byte{})
	if err != ErrEmptySlice {
		// %v rather than %s: err may be nil here, which %s renders as
		// "%!s(<nil>)".
		t.Fatalf("Did not get the correct error: %v", err)
	}
}

// TestCtxDecompressZeroLengthBuf checks that Decompress works when given a
// non-nil destination slice of zero length.
func TestCtxDecompressZeroLengthBuf(t *testing.T) {
	ctx := NewCtx()

	input := []byte("Hello World!")
	out, err := ctx.Compress(nil, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}

	buf := make([]byte, 0)
	decompressed, err := ctx.Decompress(buf, out)
	if err != nil {
		t.Fatalf("Error while decompressing: %v", err)
	}

	// The input is the expected value; the decompressed output is the result.
	if got, want := string(decompressed), string(input); got != want {
		t.Fatalf("expected %s but decompressed to %s", want, got)
	}
}

// TestCtxTooSmall decompresses a large payload into a one-byte destination
// buffer and checks that the original data is still recovered.
func TestCtxTooSmall(t *testing.T) {
	ctx := NewCtx()

	// 10000 repetitions of a short phrase.
	input := bytes.Repeat([]byte("Hellow World!"), 10000)

	compressed, err := ctx.Compress(nil, input)
	if err != nil {
		t.Fatalf("Error while compressing: %v", err)
	}

	// This should switch to the decompression stream to handle too small dst
	rein, err := ctx.Decompress(make([]byte, 1), compressed)
	if err != nil {
		t.Fatalf("Failed decompressing: %s", err)
	}
	if string(rein) != string(input) {
		t.Fatalf("Cannot compress and decompress: %s != %s", input, rein)
	}
}

// TestCtxRealPayload round-trips the package-level raw payload through a
// context. The test is skipped when raw is nil.
func TestCtxRealPayload(t *testing.T) {
	ctx := NewCtx()

	if raw == nil {
		t.Skip(ErrNoPayloadEnv)
	}

	compressed, err := ctx.Compress(nil, raw)
	if err != nil {
		t.Fatalf("Failed to compress: %s", err)
	}

	restored, err := ctx.Decompress(nil, compressed)
	if err != nil {
		t.Fatalf("Failed to decompress: %s", err)
	}

	if !bytes.Equal(raw, restored) {
		t.Fatalf("compressed/decompressed payloads are not the same (lengths: %v & %v)", len(raw), len(restored))
	}
}

// BenchmarkCtxCompression measures Compress on the package-level raw payload,
// reusing one context and one pre-sized destination buffer for every
// iteration. It fails when raw is nil.
func BenchmarkCtxCompression(b *testing.B) {
	ctx := NewCtx()

	if raw == nil {
		b.Fatal(ErrNoPayloadEnv)
	}

	size := len(raw)
	scratch := make([]byte, CompressBound(size))
	b.SetBytes(int64(size))
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		if _, err := ctx.Compress(scratch, raw); err != nil {
			b.Fatalf("Failed compressing: %s", err)
		}
	}
}

// BenchmarkCtxDecompression measures Decompress on the compressed form of the
// package-level raw payload, reusing one context and one destination buffer.
// Each result is verified against raw with the timer stopped. It fails when
// raw is nil.
func BenchmarkCtxDecompression(b *testing.B) {
	ctx := NewCtx()

	if raw == nil {
		b.Fatal(ErrNoPayloadEnv)
	}

	// Destination buffer sized to the uncompressed payload.
	scratch := make([]byte, len(raw))
	compressed, err := ctx.Compress(nil, raw)
	if err != nil {
		b.Fatalf("Failed compressing: %s", err)
	}
	b.Logf("Reduced from %v to %v", len(raw), len(compressed))

	b.SetBytes(int64(len(raw)))
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		restored, err := ctx.Decompress(scratch, compressed)
		if err != nil {
			b.Fatalf("Failed decompressing: %s", err)
		}

		// Verification is excluded from the measured time.
		b.StopTimer()
		if !bytes.Equal(raw, restored) {
			b.Fatalf("Results are not the same: %v != %v", len(raw), len(restored))
		}
		b.StartTimer()
	}
}