Fix by not estimating decompressed size; another allocation
milesgranger committed Feb 19, 2021
1 parent 892ed60 commit 950e3f9
Showing 6 changed files with 108 additions and 92 deletions.
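
For context, a minimal usage sketch of the Python API this commit changes. It is based on the calls exercised in tests/test_variants.py and benchmarks/test_bench.py below; the sample data is illustrative, not taken from the benchmarks.

```python
import cramjam

data = b"some bytes to compress 123" * 1000

# After this fix, decompress no longer estimates the decompressed size itself;
# with no output_len it decompresses into a growable buffer.
compressed = cramjam.snappy.compress(data)
assert cramjam.snappy.decompress(compressed) == data

# Passing the known output length still pre-sizes the output buffer exactly.
assert cramjam.snappy.decompress(compressed, output_len=len(data)) == data

# bytearray input is also accepted, and the output type matches the input type.
ba = bytearray(data)
assert isinstance(cramjam.snappy.decompress(cramjam.snappy.compress(ba)), bytearray)
```
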
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "cramjam"
version = "2.0.1"
version = "2.0.2"
authors = ["Miles Granger <[email protected]>"]
edition = "2018"
license-file = "LICENSE"
2 changes: 1 addition & 1 deletion Makefile
@@ -7,7 +7,7 @@ bench:
python -m pytest -v --benchmark-only --benchmark-sort name benchmarks/

bench-snappy:
$(BASE_BENCH_CMD) snappy
$(BASE_BENCH_CMD) test_snappy

bench-snappy-compress-into:
$(BASE_BENCH_CMD) snappy_de_compress_into
105 changes: 53 additions & 52 deletions benchmarks/README.md

Large diffs are not rendered by default.

7 changes: 2 additions & 5 deletions benchmarks/test_bench.py
@@ -26,10 +26,7 @@ def test_snappy(benchmark, file, use_cramjam: bool):
"""
import snappy

data = bytearray(
file.read_bytes()
) # bytearray avoids double allocation in cramjam snappy by default
# Can be slightly faster if passing output_len to compress/decompress ops
data = file.read_bytes()
if use_cramjam:
benchmark(
round_trip,
@@ -54,7 +51,7 @@ def test_cramjam_snappy_de_compress_into(benchmark, op, file):
"""
from cramjam import snappy

data = bytearray(file.read_bytes())
data = file.read_bytes()
compressed_data = cramjam.snappy.compress(data)

operation = getattr(snappy, op)
69 changes: 41 additions & 28 deletions src/snappy.rs
@@ -5,7 +5,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyByteArray, PyBytes};
use pyo3::wrap_pyfunction;
use pyo3::{PyResult, Python};
use snap::raw::{decompress_len, max_compress_len};
use snap::raw::max_compress_len;

pub fn init_py_module(m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compress, m)?)?;
@@ -27,34 +27,43 @@ pub fn init_py_module(m: &PyModule) -> PyResult<()> {
/// ```
#[pyfunction]
pub fn decompress<'a>(py: Python<'a>, data: BytesType<'a>, output_len: Option<usize>) -> PyResult<BytesType<'a>> {
let estimated_len = match output_len {
Some(len) => len,
None => to_py_err!(DecompressionError -> decompress_len(data.as_bytes()))?,
};
let result = match data {
BytesType::Bytes(bytes) => {
let pybytes = if output_len.is_some() {
PyBytes::new_with(py, estimated_len, |buffer| {
to_py_err!(DecompressionError -> self::internal::decompress(bytes.as_bytes(), Output::Slice(buffer)))?;
let pybytes = match output_len {
Some(len) => PyBytes::new_with(py, len, |output| {
to_py_err!(DecompressionError -> self::internal::decompress(bytes.as_bytes(), Output::Slice(output)))?;
Ok(())
})?
} else {
let mut buffer = Vec::with_capacity(estimated_len);
})?,
None => {
let mut output = Vec::with_capacity(data.len());

to_py_err!(DecompressionError -> self::internal::decompress(bytes.as_bytes(), Output::Vector(&mut buffer)))?;
PyBytes::new(py, &buffer)
to_py_err!(DecompressionError -> self::internal::decompress(bytes.as_bytes(), Output::Vector(&mut output)))?;
PyBytes::new(py, &output)
}
};
BytesType::Bytes(pybytes)
}
BytesType::ByteArray(bytes_array) => unsafe {
let mut actual_len = 0;
let pybytes = PyByteArray::new_with(py, estimated_len, |output| {
actual_len = to_py_err!(DecompressionError -> self::internal::decompress(bytes_array.as_bytes(), Output::Slice(output)))?;
Ok(())
})?;
pybytes.resize(actual_len)?;
BytesType::ByteArray(pybytes)
},
BytesType::ByteArray(bytes_array) => {
let bytes = unsafe { bytes_array.as_bytes() };
match output_len {
Some(len) => {
let mut actual_len = 0;
let pybytes = PyByteArray::new_with(py, len, |output| {
actual_len =
to_py_err!(DecompressionError -> self::internal::decompress(bytes, Output::Slice(output)))?;
Ok(())
})?;
pybytes.resize(actual_len)?;
BytesType::ByteArray(pybytes)
}
None => {
let mut output = Vec::with_capacity(data.len());
to_py_err!(DecompressionError -> self::internal::decompress(bytes, Output::Vector(&mut output)))?;
let pybytes = PyByteArray::new(py, &output);
BytesType::ByteArray(pybytes)
}
}
}
};
Ok(result)
}
@@ -165,7 +174,7 @@ pub fn decompress_into<'a>(_py: Python<'a>, data: BytesType<'a>, array: &'a PyAr
pub(crate) mod internal {
use snap::raw::{Decoder, Encoder};
use snap::read::{FrameDecoder, FrameEncoder};
use std::io::{Cursor, Error, Read, Write};
use std::io::{Error, Read};

use crate::Output;

@@ -186,7 +195,6 @@ pub(crate) mod internal {
let mut decoder = FrameDecoder::new(data);
match output {
Output::Slice(slice) => {
let mut decoder = FrameDecoder::new(data);
let mut n_bytes = 0;
loop {
let count = decoder.read(&mut slice[n_bytes..])?;
@@ -206,10 +214,15 @@ pub(crate) mod internal {
let mut encoder = FrameEncoder::new(data);
match output {
Output::Slice(slice) => {
let buffer = Cursor::new(slice);
let mut encoder = snap::write::FrameEncoder::new(buffer);
encoder.write_all(data)?;
Ok(encoder.get_ref().position() as usize)
let mut n_bytes = 0;
loop {
let count = encoder.read(&mut slice[n_bytes..])?;
if count == 0 {
break;
}
n_bytes += count;
}
Ok(n_bytes)
}
Output::Vector(v) => encoder.read_to_end(v),
}
15 changes: 10 additions & 5 deletions tests/test_variants.py
@@ -3,23 +3,30 @@
import cramjam
import hashlib


def same_same(a, b):
return hashlib.md5(a).hexdigest() == hashlib.md5(b).hexdigest()


@pytest.mark.parametrize("is_bytearray", (True, False))
@pytest.mark.parametrize(
"variant_str", ("snappy", "brotli", "lz4", "gzip", "deflate", "zstd")
)
def test_variants_simple(variant_str):
def test_variants_simple(variant_str, is_bytearray):

variant = getattr(cramjam, variant_str)

uncompressed = b"some bytes to compress 123" * 1000
if is_bytearray:
uncompressed = bytearray(uncompressed)

compressed = variant.compress(uncompressed)
assert compressed != uncompressed
assert type(compressed) == type(uncompressed)

decompressed = variant.decompress(compressed, output_len=len(uncompressed))
assert decompressed == uncompressed
assert type(decompressed) == type(uncompressed)


@pytest.mark.parametrize(
@@ -28,12 +35,10 @@ def test_variants_simple(variant_str):
def test_variants_raise_exception(variant_str):
variant = getattr(cramjam, variant_str)
with pytest.raises(cramjam.DecompressionError):
variant.decompress(b'sknow')
variant.decompress(b"sknow")


@pytest.mark.parametrize(
"variant_str", ("snappy", "brotli", "gzip", "deflate", "zstd")
)
@pytest.mark.parametrize("variant_str", ("snappy", "brotli", "gzip", "deflate", "zstd"))
def test_variants_de_compress_into(variant_str):

# TODO: support lz4 de/compress_into
