Skip to content

Commit

Permalink
perf: frequency optimizations
Browse files Browse the repository at this point in the history
- amortize allocations for more string variables (e.g. pre-allocate `pct_string`)
- choose the field-processing closure once, before the hot loop, instead of re-evaluating the if chains on every iteration
- skip unnecessary bounds checks
- give the crossbeam channel a buffer (`bounded(nchunks)` instead of the zero-capacity `bounded(0)`)
  • Loading branch information
jqnatividad committed Jan 29, 2025
1 parent bc2ce7c commit a4aa9a5
Showing 1 changed file with 44 additions and 116 deletions.
160 changes: 44 additions & 116 deletions src/cmd/frequency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
let mut itoa_buffer = itoa::Buffer::new();
let mut pct_decimal: Decimal;
let mut final_pct_decimal: Decimal;
let mut pct_string: String;
// most percentages are less than 10 characters, so pre-allocate 10 characters
#[allow(unused_assignments)]
let mut pct_string = String::with_capacity(10);
let mut pct_scale: u32;
let mut current_scale: u32;
let abs_dec_places = args.flag_pct_dec_places.unsigned_abs() as u32;
Expand Down Expand Up @@ -392,7 +394,7 @@ impl Args {
let nchunks = util::num_of_chunks(idx_count, chunk_size);

let pool = ThreadPool::new(njobs);
let (send, recv) = crossbeam_channel::bounded(0);
let (send, recv) = crossbeam_channel::bounded(nchunks);
for i in 0..nchunks {
let (send, args, sel) = (send.clone(), self.clone(), sel.clone());
pool.execute(move || {
Expand Down Expand Up @@ -435,128 +437,50 @@ impl Args {
.map(|i| all_unique_headers.contains(&i))
.collect();

if flag_ignore_case {
// case insensitive when computing frequencies
let mut buf = String::new();

// Pre-compute function pointers for the hot path
// instead of doing if chains repeatedly in the hot loop
let process_field = if flag_ignore_case {
if flag_no_trim {
// case-insensitive, don't trim whitespace
for row in it {
// safety: we know the row is not empty
row_buffer.clone_from(&row.unwrap());
for (i, field) in nsel.select(row_buffer.into_iter()).enumerate() {
// safety: all_unique_flag_vec.len() is the same as nsel.len()
if unsafe { *all_unique_flag_vec.get_unchecked(i) } {
// if the column has all unique values,
// we don't need to compute frequencies
continue;
}

// safety: we do get_unchecked_mut on freq_tables
// as we know that nsel_len is the same as freq_tables.len()
// so we can skip the bounds check
if !field.is_empty() {
field_buffer = {
if let Ok(s) = simdutf8::basic::from_utf8(field) {
util::to_lowercase_into(s, &mut buf);
buf.as_bytes().to_vec()
} else {
field.to_vec()
}
};
unsafe {
freq_tables.get_unchecked_mut(i).add(field_buffer);
}
} else if !flag_no_nulls {
unsafe {
freq_tables.get_unchecked_mut(i).add(null.clone());
}
}
|field: &[u8], buf: &mut String| {
if let Ok(s) = simdutf8::basic::from_utf8(field) {
util::to_lowercase_into(s, buf);
buf.as_bytes().to_vec()
} else {
field.to_vec()
}
}
} else {
// case-insensitive, trim whitespace
for row in it {
// safety: we know the row is not empty
row_buffer.clone_from(&row.unwrap());
for (i, field) in nsel.select(row_buffer.into_iter()).enumerate() {
if unsafe { *all_unique_flag_vec.get_unchecked(i) } {
continue;
}

// safety: we do get_unchecked_mut on freq_tables
// as we know that nsel_len is the same as freq_tables.len()
// so we can skip the bounds check
if !field.is_empty() {
field_buffer = {
if let Ok(s) = simdutf8::basic::from_utf8(field) {
util::to_lowercase_into(s.trim(), &mut buf);
buf.as_bytes().to_vec()
} else {
util::trim_bs_whitespace(field).to_vec()
}
};
unsafe {
freq_tables.get_unchecked_mut(i).add(field_buffer);
}
} else if !flag_no_nulls {
unsafe {
freq_tables.get_unchecked_mut(i).add(null.clone());
}
}
|field: &[u8], buf: &mut String| {
if let Ok(s) = simdutf8::basic::from_utf8(field) {
util::to_lowercase_into(s.trim(), buf);
buf.as_bytes().to_vec()
} else {
util::trim_bs_whitespace(field).to_vec()
}
}
}
} else if flag_no_trim {
|field: &[u8], _buf: &mut String| field.to_vec()
} else {
// case sensitive by default when computing frequencies
for row in it {
// safety: we know the row is not empty
row_buffer.clone_from(&row.unwrap());

if flag_no_trim {
// case-sensitive, don't trim whitespace
for (i, field) in nsel.select(row_buffer.into_iter()).enumerate() {
if unsafe { *all_unique_flag_vec.get_unchecked(i) } {
continue;
}

// safety: get_unchecked_mut on freq_tables for same safety reason above
if !field.is_empty() {
// no need to convert to string and back to bytes for a "case-sensitive"
// comparison we can just use the field directly
unsafe {
freq_tables.get_unchecked_mut(i).add(field.to_vec());
}
} else if !flag_no_nulls {
unsafe {
freq_tables.get_unchecked_mut(i).add(null.clone());
}
}
|field: &[u8], _buf: &mut String| util::trim_bs_whitespace(field).to_vec()
};

let mut string_buf = String::with_capacity(100);
for row in it {
row_buffer.clone_from(&row.unwrap());
for (i, field) in nsel.select(row_buffer.into_iter()).enumerate() {
if unsafe { *all_unique_flag_vec.get_unchecked(i) } {
continue;
}

if !field.is_empty() {
field_buffer = process_field(field, &mut string_buf);
unsafe {
freq_tables.get_unchecked_mut(i).add(field_buffer);
}
} else {
// case-sensitive, trim whitespace
for (i, field) in nsel.select(row_buffer.into_iter()).enumerate() {
if unsafe { *all_unique_flag_vec.get_unchecked(i) } {
continue;
}

// safety: get_unchecked_mut on freq_tables for same safety reason above
if !field.is_empty() {
field_buffer = {
if let Ok(s) = simdutf8::basic::from_utf8(field) {
s.trim().as_bytes().to_vec()
} else {
util::trim_bs_whitespace(field).to_vec()
}
};
unsafe {
freq_tables.get_unchecked_mut(i).add(field_buffer);
}
} else if !flag_no_nulls {
unsafe {
freq_tables.get_unchecked_mut(i).add(null.clone());
}
}
} else if !flag_no_nulls {
unsafe {
freq_tables.get_unchecked_mut(i).add(null.clone());
}
}
}
Expand Down Expand Up @@ -628,9 +552,13 @@ impl Args {
let row_count = util::count_rows(&self.rconfig())?;
FREQ_ROW_COUNT.set(row_count as u64).unwrap();

// Most datasets have relatively few columns with all unique values (e.g. ID columns)
// so pre-allocate space for 5 as a reasonable default capacity
let mut all_unique_headers_vec: Vec<usize> = Vec::with_capacity(5);
for (i, _header) in headers.iter().enumerate() {
let cardinality = col_cardinality_vec[i].1;
// safety: we know that col_cardinality_vec has the same length as headers
// as it was constructed from csv_fields which has the same length as headers
let cardinality = unsafe { col_cardinality_vec.get_unchecked(i).1 };

if cardinality == row_count {
all_unique_headers_vec.push(i);
Expand Down

0 comments on commit a4aa9a5

Please sign in to comment.