Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a table-based Huffman decoder #88

Merged
merged 8 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 135 additions & 124 deletions src/huffman.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,45 @@ enum HuffmanTreeNode {
Empty,
}

/// Huffman tree
#[derive(Clone, Debug, Default)]
pub(crate) struct HuffmanTree {
tree: Vec<HuffmanTreeNode>,
max_nodes: usize,
num_nodes: usize,
/// Internal representation of a Huffman tree.
#[derive(Clone, Debug)]
enum HuffmanTreeInner {
/// Degenerate tree holding exactly one symbol; `read_symbol` returns it without
/// consuming any bits.
Single(u16),
/// General case: a lookup table with a node-based tree as fallback.
Tree {
/// Node storage walked by `read_symbol_slowpath`. `build_implicit` only populates it
/// when the longest code exceeds the table's bit width; otherwise it is left empty.
tree: Vec<HuffmanTreeNode>,
/// Decoding table indexed by the low bits of the bit buffer. A non-zero entry packs
/// `(code_length << 16) | symbol`; a zero entry means the code is not in the table and
/// `tree` must be used instead.
table: Vec<u32>,
/// `table.len() - 1`, used to mask the peeked bits down to a table index.
table_mask: u16,
},
}

impl HuffmanTree {
fn is_full(&self) -> bool {
self.num_nodes == self.max_nodes
}

/// Turns a node from empty into a branch and assigns its children
fn assign_children(&mut self, node_index: usize) -> usize {
let offset_index = self.num_nodes - node_index;
self.tree[node_index] = HuffmanTreeNode::Branch(offset_index);
self.num_nodes += 2;
/// Huffman tree used to decode symbols from a bitstream.
///
/// Thin wrapper around the private `HuffmanTreeInner` enum, so the concrete representation
/// stays an implementation detail of this module.
#[derive(Clone, Debug)]
pub(crate) struct HuffmanTree(HuffmanTreeInner);

offset_index
/// The default tree is a single-node tree whose only symbol is `0`.
impl Default for HuffmanTree {
fn default() -> Self {
// A `Single` tree needs no allocation and decodes to its symbol unconditionally.
Self(HuffmanTreeInner::Single(0))
}
}

/// Init a huffman tree
fn init(num_leaves: usize) -> Result<HuffmanTree, DecodingError> {
if num_leaves == 0 {
return Err(DecodingError::HuffmanError);
}
impl HuffmanTree {
/// Builds a tree implicitly, just from code lengths
pub(crate) fn build_implicit(code_lengths: Vec<u16>) -> Result<HuffmanTree, DecodingError> {
let mut num_symbols = 0;
let mut root_symbol = 0;

let max_nodes = 2 * num_leaves - 1;
let tree = vec![HuffmanTreeNode::Empty; max_nodes];
let num_nodes = 1;
for (symbol, length) in code_lengths.iter().enumerate() {
if *length > 0 {
num_symbols += 1;
root_symbol = symbol.try_into().unwrap();
}
}

let tree = HuffmanTree {
tree,
max_nodes,
num_nodes,
if num_symbols == 0 {
return Err(DecodingError::HuffmanError);
} else if num_symbols == 1 {
return Ok(Self::build_single_node(root_symbol));
};

Ok(tree)
}

/// Converts code lengths to codes
fn code_lengths_to_codes(code_lengths: &[u16]) -> Result<Vec<Option<u16>>, DecodingError> {
let max_code_length = *code_lengths
.iter()
.reduce(|a, b| if a >= b { a } else { b })
Expand All @@ -86,129 +82,117 @@ impl HuffmanTree {

// Assign codes
let mut curr_code = 0;
let mut next_codes = [None; MAX_ALLOWED_CODE_LENGTH + 1];
let mut next_codes = [0; MAX_ALLOWED_CODE_LENGTH + 1];
for code_len in 1..=usize::from(max_code_length) {
curr_code = (curr_code + code_length_hist[code_len - 1]) << 1;
next_codes[code_len] = Some(curr_code);
next_codes[code_len] = curr_code;
}
let mut huff_codes = vec![None; code_lengths.len()];
let mut huff_codes = vec![0u16; code_lengths.len()];
for (symbol, &length) in code_lengths.iter().enumerate() {
let length = usize::from(length);
if length > 0 {
huff_codes[symbol] = next_codes[length];
if let Some(value) = next_codes[length].as_mut() {
*value += 1;
}
} else {
huff_codes[symbol] = None;
next_codes[length] += 1;
}
}

Ok(huff_codes)
}

/// Adds a symbol to a huffman tree
fn add_symbol(
&mut self,
symbol: u16,
code: u16,
code_length: u16,
) -> Result<(), DecodingError> {
let mut node_index = 0;
let code = usize::from(code);

for length in (0..code_length).rev() {
if node_index >= self.max_nodes {
return Err(DecodingError::HuffmanError);
}

let node = self.tree[node_index];

let offset = match node {
HuffmanTreeNode::Empty => {
if self.is_full() {
return Err(DecodingError::HuffmanError);
}
self.assign_children(node_index)
// Populate decoding table
let table_bits = max_code_length.min(10);
let table_size = (1 << table_bits) as usize;
let table_mask = table_size as u16 - 1;
let mut table = vec![0; table_size];
for (symbol, (&code, &length)) in huff_codes.iter().zip(code_lengths.iter()).enumerate() {
if length != 0 && length <= table_bits {
let mut j = (u16::reverse_bits(code) >> (16 - length)) as usize;
let entry = ((length as u32) << 16) | symbol as u32;
while j < table_size {
table[j] = entry;
j += 1 << length as usize;
}
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(offset) => offset,
};

node_index += offset + ((code >> length) & 1);
}

match self.tree[node_index] {
HuffmanTreeNode::Empty => self.tree[node_index] = HuffmanTreeNode::Leaf(symbol),
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(_offset) => return Err(DecodingError::HuffmanError),
}

Ok(())
}

/// Builds a tree implicitly, just from code lengths
pub(crate) fn build_implicit(code_lengths: Vec<u16>) -> Result<HuffmanTree, DecodingError> {
let mut num_symbols = 0;
let mut root_symbol = 0;

for (symbol, length) in code_lengths.iter().enumerate() {
if *length > 0 {
num_symbols += 1;
root_symbol = symbol.try_into().unwrap();
}
}

let mut tree = HuffmanTree::init(num_symbols)?;

if num_symbols == 1 {
tree.add_symbol(root_symbol, 0, 0)?;
} else {
let codes = HuffmanTree::code_lengths_to_codes(&code_lengths)?;
// If the longest code is larger than the table size, build a tree as a fallback.
let mut tree = Vec::new();
if max_code_length > table_bits {
tree = vec![HuffmanTreeNode::Empty; 2 * num_symbols - 1];

let mut num_nodes = 1;
for (symbol, &length) in code_lengths.iter().enumerate() {
if length > 0 && codes[symbol].is_some() {
tree.add_symbol(symbol.try_into().unwrap(), codes[symbol].unwrap(), length)?;
let code = huff_codes[symbol];
let code_length = length;
let symbol = symbol.try_into().unwrap();

if length > 0 {
let mut node_index = 0;
let code = usize::from(code);

for length in (0..code_length).rev() {
let node = tree[node_index];

let offset = match node {
HuffmanTreeNode::Empty => {
// Turns a node from empty into a branch and assigns its children
let offset_index = num_nodes - node_index;
tree[node_index] = HuffmanTreeNode::Branch(offset_index);
num_nodes += 2;
offset_index
}
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(offset) => offset,
};

node_index += offset + ((code >> length) & 1);
}

match tree[node_index] {
HuffmanTreeNode::Empty => tree[node_index] = HuffmanTreeNode::Leaf(symbol),
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(_offset) => {
return Err(DecodingError::HuffmanError)
}
}
}
}
}

Ok(tree)
Ok(Self(HuffmanTreeInner::Tree {
tree,
table,
table_mask,
}))
}

/// Builds a tree explicitly from lengths, codes and symbols
pub(crate) fn build_explicit(
code_lengths: Vec<u16>,
codes: Vec<u16>,
symbols: Vec<u16>,
) -> Result<HuffmanTree, DecodingError> {
let mut tree = HuffmanTree::init(symbols.len())?;

for i in 0..symbols.len() {
tree.add_symbol(symbols[i], codes[i], code_lengths[i])?;
}
/// Builds a tree that contains only `symbol`.
///
/// Reading from such a tree always yields `symbol` and consumes no bits.
pub(crate) fn build_single_node(symbol: u16) -> HuffmanTree {
Self(HuffmanTreeInner::Single(symbol))
}

Ok(tree)
/// Builds a tree with exactly two symbols, each encoded with a single bit.
///
/// A `0` bit decodes to `zero` and a `1` bit decodes to `one`.
pub(crate) fn build_two_node(zero: u16, one: u16) -> HuffmanTree {
    // Table entries pack the code length (always 1 here) into the upper 16 bits and the
    // decoded symbol into the lower 16 bits.
    let one_bit_entry = |symbol: u16| (1u32 << 16) | u32::from(symbol);

    let table = vec![one_bit_entry(zero), one_bit_entry(one)];
    let tree = vec![
        HuffmanTreeNode::Leaf(zero),
        HuffmanTreeNode::Leaf(one),
        HuffmanTreeNode::Empty,
    ];

    Self(HuffmanTreeInner::Tree {
        tree,
        table,
        // Only the lowest bit of the buffer selects the entry.
        table_mask: 0x1,
    })
}

pub(crate) fn is_single_node(&self) -> bool {
self.num_nodes == 1
matches!(self.0, HuffmanTreeInner::Single(_))
}

/// Reads a symbol using the bitstream.
///
/// You must call `bit_reader.fill()` before calling this function or it may erroneously
/// detect the end of the stream and return a bitstream error.
pub(crate) fn read_symbol<R: Read>(
&self,
#[inline(never)]
fn read_symbol_slowpath<R: Read>(
tree: &[HuffmanTreeNode],
mut v: usize,
bit_reader: &mut BitReader<R>,
) -> Result<u16, DecodingError> {
let mut v = bit_reader.peek(15) as usize;
let mut depth = 0;

let mut index = 0;
loop {
match &self.tree[index] {
match &tree[index] {
HuffmanTreeNode::Branch(children_offset) => {
index += children_offset + (v & 1);
depth += 1;
Expand All @@ -222,4 +206,31 @@ impl HuffmanTree {
}
}
}

/// Decodes the next symbol from the bitstream.
///
/// A single-node tree yields its symbol without touching the reader. Otherwise the low bits
/// of the reader's buffer index the lookup table; a non-zero entry gives the code length to
/// consume and the symbol, while a zero entry defers to the node-based slow path.
///
/// You must call `bit_reader.fill()` before calling this function or it may erroneously
/// detect the end of the stream and return a bitstream error.
pub(crate) fn read_symbol<R: Read>(
    &self,
    bit_reader: &mut BitReader<R>,
) -> Result<u16, DecodingError> {
    let (tree, table, table_mask) = match &self.0 {
        HuffmanTreeInner::Single(symbol) => return Ok(*symbol),
        HuffmanTreeInner::Tree {
            tree,
            table,
            table_mask,
        } => (tree, table, *table_mask),
    };

    // Only the low 16 bits of the buffer are needed for the lookup and the slow path.
    let bits = bit_reader.peek_full() as u16;

    match table[usize::from(bits & table_mask)] {
        // Zero marks a code that is not covered by the table.
        0 => Self::read_symbol_slowpath(tree, usize::from(bits), bit_reader),
        entry => {
            // Upper 16 bits: code length to consume. Lower 16 bits: decoded symbol.
            bit_reader.consume((entry >> 16) as u8)?;
            Ok(entry as u16)
        }
    }
}
}
29 changes: 16 additions & 13 deletions src/lossless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,24 +358,22 @@ impl<R: Read> LosslessDecoder<R> {
if simple {
let num_symbols = self.bit_reader.read_bits::<u8>(1)? + 1;

let mut code_lengths = vec![u16::from(num_symbols - 1)];
let mut codes = vec![0];
let mut symbols = Vec::new();

let is_first_8bits = self.bit_reader.read_bits::<u8>(1)?;
symbols.push(self.bit_reader.read_bits::<u16>(1 + 7 * is_first_8bits)?);

if num_symbols == 2 {
symbols.push(self.bit_reader.read_bits::<u16>(8)?);
code_lengths.push(1);
codes.push(1);
}
let zero_symbol = self.bit_reader.read_bits::<u16>(1 + 7 * is_first_8bits)?;

if symbols.iter().any(|&s| s > alphabet_size) {
if zero_symbol >= alphabet_size {
return Err(DecodingError::BitStreamError);
}

HuffmanTree::build_explicit(code_lengths, codes, symbols)
if num_symbols == 1 {
Ok(HuffmanTree::build_single_node(zero_symbol))
} else {
let one_symbol = self.bit_reader.read_bits::<u16>(8)?;
if one_symbol >= alphabet_size {
return Err(DecodingError::BitStreamError);
}
Ok(HuffmanTree::build_two_node(zero_symbol, one_symbol))
}
} else {
let mut code_length_code_lengths = vec![0; CODE_LENGTH_CODES];

Expand Down Expand Up @@ -751,6 +749,11 @@ impl<R: Read> BitReader<R> {
self.buffer & ((1 << num) - 1)
}

/// Peeks at the full buffer.
///
/// Returns the whole bit buffer as-is, without consuming any bits and without applying a
/// mask; callers are expected to mask off the bits they need.
pub(crate) fn peek_full(&self) -> u64 {
self.buffer
}

/// Consumes `num` bits from the buffer returning an error if there are not enough bits.
pub(crate) fn consume(&mut self, num: u8) -> Result<(), DecodingError> {
if self.nbits < num {
Expand Down
Loading