-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #83 from merlimat/ctx-compress
Added Ctx compress/decompress
- Loading branch information
Showing
2 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
package zstd | ||
|
||
/* | ||
#define ZSTD_STATIC_LINKING_ONLY | ||
#include "zstd.h" | ||
#include "stdint.h" // for uintptr_t | ||
// The following *_wrapper functions are used to remove superfluous
// memory allocations when calling the wrapped functions from Go code.
// See https://github.com/golang/go/issues/24450 for details.
// Shim over ZSTD_compressCCtx: the destination and source buffers arrive as
// integer addresses and are cast back to pointers before the real call.
static size_t ZSTD_compressCCtx_wrapper(ZSTD_CCtx* cctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
    void* dstBuf = (void*)dst;
    const void* srcBuf = (const void*)src;
    return ZSTD_compressCCtx(cctx, dstBuf, maxDstSize, srcBuf, srcSize, compressionLevel);
}
// Shim over ZSTD_decompressDCtx: the destination and source buffers arrive as
// integer addresses and are cast back to pointers before the real call.
static size_t ZSTD_decompressDCtx_wrapper(ZSTD_DCtx* dctx, uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
    void* dstBuf = (void*)dst;
    const void* srcBuf = (const void *)src;
    return ZSTD_decompressDCtx(dctx, dstBuf, maxDstSize, srcBuf, srcSize);
}
*/ | ||
import "C" | ||
import ( | ||
"bytes" | ||
"io/ioutil" | ||
"runtime" | ||
"unsafe" | ||
) | ||
|
||
// Ctx is a reusable zstd context offering one-shot compression and
// decompression of byte slices.
type Ctx interface {
	// Compress writes the compressed form of src into dst and returns the
	// resulting slice. dst is an optional scratch buffer supplied to avoid an
	// allocation; when it is nil or too small, a new buffer is allocated and
	// returned instead.
	Compress(dst, src []byte) ([]byte, error)

	// CompressLevel behaves like Compress, except that the caller picks the
	// compression level.
	CompressLevel(dst, src []byte, level int) ([]byte, error)

	// Decompress writes the decompressed form of src into dst and returns the
	// resulting slice. dst is an optional scratch buffer supplied to avoid an
	// allocation; when it is nil or too small, a new buffer is allocated and
	// returned instead.
	Decompress(dst, src []byte) ([]byte, error)
}
|
||
type ctx struct { | ||
cctx *C.ZSTD_CCtx | ||
dctx *C.ZSTD_DCtx | ||
} | ||
|
||
// Create a new ZStd Context. | ||
// When compressing/decompressing many times, it is recommended to allocate a | ||
// context just once, and re-use it for each successive compression operation. | ||
// This will make workload friendlier for system's memory. | ||
// Note : re-using context is just a speed / resource optimization. | ||
// It doesn't change the compression ratio, which remains identical. | ||
// Note 2 : In multi-threaded environments, | ||
// use one different context per thread for parallel execution. | ||
// | ||
func NewCtx() Ctx { | ||
c := &ctx{ | ||
cctx: C.ZSTD_createCCtx(), | ||
dctx: C.ZSTD_createDCtx(), | ||
} | ||
|
||
runtime.SetFinalizer(c, finalizeCtx) | ||
return c | ||
} | ||
|
||
func (c *ctx) Compress(dst, src []byte) ([]byte, error) { | ||
return c.CompressLevel(dst, src, DefaultCompression) | ||
} | ||
|
||
func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) { | ||
bound := CompressBound(len(src)) | ||
if cap(dst) >= bound { | ||
dst = dst[0:bound] // Reuse dst buffer | ||
} else { | ||
dst = make([]byte, bound) | ||
} | ||
|
||
srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty | ||
if len(src) > 0 { | ||
srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))) | ||
} | ||
|
||
cWritten := C.ZSTD_compressCCtx_wrapper( | ||
c.cctx, | ||
C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), | ||
C.size_t(len(dst)), | ||
srcPtr, | ||
C.size_t(len(src)), | ||
C.int(level)) | ||
|
||
runtime.KeepAlive(src) | ||
written := int(cWritten) | ||
// Check if the return is an Error code | ||
if err := getError(written); err != nil { | ||
return nil, err | ||
} | ||
return dst[:written], nil | ||
} | ||
|
||
|
||
func (c *ctx) Decompress(dst, src []byte) ([]byte, error) { | ||
if len(src) == 0 { | ||
return []byte{}, ErrEmptySlice | ||
} | ||
decompress := func(dst, src []byte) ([]byte, error) { | ||
|
||
cWritten := C.ZSTD_decompressDCtx_wrapper( | ||
c.dctx, | ||
C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), | ||
C.size_t(len(dst)), | ||
C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), | ||
C.size_t(len(src))) | ||
|
||
runtime.KeepAlive(src) | ||
written := int(cWritten) | ||
// Check error | ||
if err := getError(written); err != nil { | ||
return nil, err | ||
} | ||
return dst[:written], nil | ||
} | ||
|
||
if len(dst) == 0 { | ||
// Attempt to use zStd to determine decompressed size (may result in error or 0) | ||
size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))) | ||
|
||
if err := getError(size); err != nil { | ||
return nil, err | ||
} | ||
|
||
if size > 0 { | ||
dst = make([]byte, size) | ||
} else { | ||
dst = make([]byte, len(src)*3) // starting guess | ||
} | ||
} | ||
for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer | ||
result, err := decompress(dst, src) | ||
if !IsDstSizeTooSmallError(err) { | ||
return result, err | ||
} | ||
dst = make([]byte, len(dst)*2) // Grow buffer by 2 | ||
} | ||
|
||
// We failed getting a dst buffer of correct size, use stream API | ||
r := NewReader(bytes.NewReader(src)) | ||
defer r.Close() | ||
return ioutil.ReadAll(r) | ||
} | ||
|
||
func finalizeCtx(c *ctx) { | ||
C.ZSTD_freeCCtx(c.cctx) | ||
C.ZSTD_freeDCtx(c.dctx) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
package zstd | ||
|
||
import ( | ||
"bytes" | ||
"testing" | ||
) | ||
|
||
// Test compression | ||
func TestCtxCompressDecompress(t *testing.T) { | ||
ctx := NewCtx() | ||
|
||
input := []byte("Hello World!") | ||
out, err := ctx.Compress(nil, input) | ||
if err != nil { | ||
t.Fatalf("Error while compressing: %v", err) | ||
} | ||
out2 := make([]byte, 1000) | ||
out2, err = ctx.Compress(out2, input) | ||
if err != nil { | ||
t.Fatalf("Error while compressing: %v", err) | ||
} | ||
t.Logf("Compressed: %v", out) | ||
|
||
rein, err := ctx.Decompress(nil, out) | ||
if err != nil { | ||
t.Fatalf("Error while decompressing: %v", err) | ||
} | ||
rein2 := make([]byte, 10) | ||
rein2, err = ctx.Decompress(rein2, out2) | ||
if err != nil { | ||
t.Fatalf("Error while decompressing: %v", err) | ||
} | ||
|
||
if string(input) != string(rein) { | ||
t.Fatalf("Cannot compress and decompress: %s != %s", input, rein) | ||
} | ||
if string(input) != string(rein2) { | ||
t.Fatalf("Cannot compress and decompress: %s != %s", input, rein) | ||
} | ||
} | ||
|
||
func TestCtxEmptySliceCompress(t *testing.T) { | ||
ctx := NewCtx() | ||
|
||
compressed, err := ctx.Compress(nil, []byte{}) | ||
if err != nil { | ||
t.Fatalf("Error while compressing: %v", err) | ||
} | ||
t.Logf("Compressing empty slice gives 0x%x", compressed) | ||
decompressed, err := ctx.Decompress(nil, compressed) | ||
if err != nil { | ||
t.Fatalf("Error while compressing: %v", err) | ||
} | ||
if string(decompressed) != "" { | ||
t.Fatalf("Expected empty slice as decompressed, got %v instead", decompressed) | ||
} | ||
} | ||
|
||
func TestCtxEmptySliceDecompress(t *testing.T) { | ||
ctx := NewCtx() | ||
|
||
_, err := ctx.Decompress(nil, []byte{}) | ||
if err != ErrEmptySlice { | ||
t.Fatalf("Did not get the correct error: %s", err) | ||
} | ||
} | ||
|
||
func TestCtxDecompressZeroLengthBuf(t *testing.T) { | ||
ctx := NewCtx() | ||
|
||
input := []byte("Hello World!") | ||
out, err := ctx.Compress(nil, input) | ||
if err != nil { | ||
t.Fatalf("Error while compressing: %v", err) | ||
} | ||
|
||
buf := make([]byte, 0) | ||
decompressed, err := ctx.Decompress(buf, out) | ||
if err != nil { | ||
t.Fatalf("Error while decompressing: %v", err) | ||
} | ||
|
||
if res, exp := string(input), string(decompressed); res != exp { | ||
t.Fatalf("expected %s but decompressed to %s", exp, res) | ||
} | ||
} | ||
|
||
func TestCtxTooSmall(t *testing.T) { | ||
ctx := NewCtx() | ||
|
||
var long bytes.Buffer | ||
for i := 0; i < 10000; i++ { | ||
long.Write([]byte("Hellow World!")) | ||
} | ||
input := long.Bytes() | ||
out, err := ctx.Compress(nil, input) | ||
if err != nil { | ||
t.Fatalf("Error while compressing: %v", err) | ||
} | ||
rein := make([]byte, 1) | ||
// This should switch to the decompression stream to handle too small dst | ||
rein, err = ctx.Decompress(rein, out) | ||
if err != nil { | ||
t.Fatalf("Failed decompressing: %s", err) | ||
} | ||
if string(input) != string(rein) { | ||
t.Fatalf("Cannot compress and decompress: %s != %s", input, rein) | ||
} | ||
} | ||
|
||
func TestCtxRealPayload(t *testing.T) { | ||
ctx := NewCtx() | ||
|
||
if raw == nil { | ||
t.Skip(ErrNoPayloadEnv) | ||
} | ||
dst, err := ctx.Compress(nil, raw) | ||
if err != nil { | ||
t.Fatalf("Failed to compress: %s", err) | ||
} | ||
rein, err := ctx.Decompress(nil, dst) | ||
if err != nil { | ||
t.Fatalf("Failed to decompress: %s", err) | ||
} | ||
if string(raw) != string(rein) { | ||
t.Fatalf("compressed/decompressed payloads are not the same (lengths: %v & %v)", len(raw), len(rein)) | ||
} | ||
} | ||
|
||
func BenchmarkCtxCompression(b *testing.B) { | ||
ctx := NewCtx() | ||
|
||
if raw == nil { | ||
b.Fatal(ErrNoPayloadEnv) | ||
} | ||
dst := make([]byte, CompressBound(len(raw))) | ||
b.SetBytes(int64(len(raw))) | ||
b.ResetTimer() | ||
for i := 0; i < b.N; i++ { | ||
_, err := ctx.Compress(dst, raw) | ||
if err != nil { | ||
b.Fatalf("Failed compressing: %s", err) | ||
} | ||
} | ||
} | ||
|
||
func BenchmarkCtxDecompression(b *testing.B) { | ||
ctx := NewCtx() | ||
|
||
if raw == nil { | ||
b.Fatal(ErrNoPayloadEnv) | ||
} | ||
src := make([]byte, len(raw)) | ||
dst, err := ctx.Compress(nil, raw) | ||
if err != nil { | ||
b.Fatalf("Failed compressing: %s", err) | ||
} | ||
b.Logf("Reduced from %v to %v", len(raw), len(dst)) | ||
b.SetBytes(int64(len(raw))) | ||
b.ResetTimer() | ||
for i := 0; i < b.N; i++ { | ||
src2, err := ctx.Decompress(src, dst) | ||
if err != nil { | ||
b.Fatalf("Failed decompressing: %s", err) | ||
} | ||
b.StopTimer() | ||
if !bytes.Equal(raw, src2) { | ||
b.Fatalf("Results are not the same: %v != %v", len(raw), len(src2)) | ||
} | ||
b.StartTimer() | ||
} | ||
} |