Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Fix the verify_dict_indices codegen #20920

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,52 +179,20 @@ fn no_more_bitpacked_values() -> ParquetError {
}

#[inline(always)]
#[cfg(feature = "simd")]
fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<()> {
// You would think that the compiler can do this itself, but it does not always do this
// properly. So we help it a bit.
fn verify_dict_indices(indices: &[u32], dict_size: usize) -> ParquetResult<()> {
debug_assert!(dict_size <= u32::MAX as usize);
let dict_size = dict_size as u32;

use std::simd::cmp::SimdPartialOrd;
use std::simd::u32x32;

let dict_size = u32x32::splat(dict_size as u32);
let indices = u32x32::from_slice(indices);

let is_invalid = indices.simd_ge(dict_size);
if is_invalid.any() {
Err(oob_dict_idx())
} else {
Ok(())
}
}

#[inline(always)]
#[cfg(not(feature = "simd"))]
fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<()> {
let mut is_valid = true;
for &idx in indices {
is_valid &= (idx as usize) < dict_size;
is_valid &= idx < dict_size;
}

if is_valid {
return Ok(());
}

Err(oob_dict_idx())
}

#[inline(always)]
fn verify_dict_indices_slice(indices: &[u32], dict_size: usize) -> ParquetResult<()> {
let mut is_valid = true;
for &idx in indices {
is_valid &= (idx as usize) < dict_size;
}

if is_valid {
return Ok(());
Ok(())
} else {
Err(oob_dict_idx())
}

Err(oob_dict_idx())
}

/// Skip over entire chunks in a [`HybridRleDecoder`] as long as all skipped chunks do not include
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use arrow::types::AlignedBytes;

use super::{
oob_dict_idx, required_skip_whole_chunks, verify_dict_indices, verify_dict_indices_slice,
IndexMapping,
};
use super::{oob_dict_idx, required_skip_whole_chunks, verify_dict_indices, IndexMapping};
use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder};
use crate::parquet::error::ParquetResult;

Expand Down Expand Up @@ -55,7 +52,7 @@ pub fn decode<B: AlignedBytes, D: IndexMapping<Output = B>>(

if let Some((chunk, chunk_size)) = decoder.chunked().next_inexact() {
let chunk = &chunk[num_rows_to_skip..chunk_size];
verify_dict_indices_slice(chunk, dict.len())?;
verify_dict_indices(chunk, dict.len())?;
target.extend(chunk.iter().map(|&idx| {
// SAFETY: The dict indices were verified before.
unsafe { dict.get_unchecked(idx) }
Expand All @@ -73,7 +70,7 @@ pub fn decode<B: AlignedBytes, D: IndexMapping<Output = B>>(
}

if let Some((chunk, chunk_size)) = chunked.remainder() {
verify_dict_indices_slice(&chunk[..chunk_size], dict.len())?;
verify_dict_indices(&chunk[..chunk_size], dict.len())?;
target.extend(chunk[..chunk_size].iter().map(|&idx| {
// SAFETY: The dict indices were verified before.
unsafe { dict.get_unchecked(idx) }
Expand Down
Loading