Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Remove last instances of itoa #20881

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ repository = "https://github.com/pola-rs/polars"
ahash = ">=0.8.5"
aho-corasick = "1.1"
arboard = { version = "3.4.0", default-features = false }
atoi = "2"
atoi_simd = "0.16"
atomic-waker = "1"
avro-schema = { version = "0.3" }
Expand Down
3 changes: 1 addition & 2 deletions crates/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ repository = { workspace = true }
description = "Minimal implementation of the Arrow specification forked from arrow2"

[dependencies]
atoi = { workspace = true, optional = true }
bytemuck = { workspace = true, features = ["must_cast"] }
chrono = { workspace = true }
# for timezone support
Expand Down Expand Up @@ -143,7 +142,7 @@ timezones = [
"chrono-tz",
]
dtype-array = []
dtype-decimal = ["atoi", "itoa"]
dtype-decimal = ["atoi_simd", "itoa"]
bigidx = ["polars-utils/bigidx"]
nightly = []
performant = []
Expand Down
147 changes: 67 additions & 80 deletions crates/polars-arrow/src/compute/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::sync::atomic::{AtomicBool, Ordering};

use atoi::FromRadix10SignedChecked;
use num_traits::Euclid;

static TRIM_DECIMAL_ZEROS: AtomicBool = AtomicBool::new(false);
Expand All @@ -12,103 +11,86 @@ pub fn set_trim_decimal_zeros(trim: Option<bool>) {
TRIM_DECIMAL_ZEROS.store(trim.unwrap_or(false), Ordering::Relaxed)
}

/// Count the number of b'0's at the beginning of a slice.
fn leading_zeros(bytes: &[u8]) -> u8 {
bytes.iter().take_while(|byte| **byte == b'0').count() as u8
}

fn split_decimal_bytes(bytes: &[u8]) -> (Option<&[u8]>, Option<&[u8]>) {
let mut a = bytes.splitn(2, |x| *x == b'.');
let lhs = a.next();
let rhs = a.next();
(lhs, rhs)
}

/// Parse a single i128 from bytes, ensuring the entire slice is read.
fn parse_integer_checked(bytes: &[u8]) -> Option<i128> {
let (n, len) = i128::from_radix_10_signed_checked(bytes);
n.filter(|_| len == bytes.len())
}

/// Assuming bytes are a well-formed decimal number (with or without a separator),
/// infer the scale of the number. If no separator is present, the scale is 0.
pub fn infer_scale(bytes: &[u8]) -> u8 {
let (_lhs, rhs) = split_decimal_bytes(bytes);
rhs.map_or(0, |x| x.len() as u8)
let Some(separator) = bytes.iter().position(|b| *b == b'.') else {
return 0;
};
(bytes.len() - (1 + separator)) as u8
}

/// Deserialize bytes to a single i128 representing a decimal, at a specified precision
/// (optional) and scale (required). If precision is not specified, it is assumed to be
/// 38 (the max precision allowed by the i128 representation). The number is checked to
/// ensure it fits within the specified precision and scale. Consistent with float parsing,
/// no decimal separator is required (eg "500", "500.", and "500.0" are all accepted); this allows
/// mixed integer/decimal sequences to be parsed as decimals. All trailing zeros are assumed to
/// be significant, whether or not a separator is present: 1200 requires precision >= 4, while 1200.200
/// requires precision >= 7 and scale >= 3. Returns None if the number is not well-formed, or does not
/// fit. Only b'.' is allowed as a decimal separator (issue #6698).
/// Deserialize bytes to a single i128 representing a decimal, at a specified
/// precision (optional) and scale (required). The number is checked to ensure
/// it fits within the specified precision and scale. Consistent with float
/// parsing, no decimal separator is required (eg "500", "500.", and "500.0" are
/// all accepted); this allows mixed integer/decimal sequences to be parsed as
/// decimals. All trailing zeros are assumed to be significant, whether or not
/// a separator is present: 1200 requires precision >= 4, while 1200.200
/// requires precision >= 7 and scale >= 3. Returns None if the number is not
/// well-formed, or does not fit. Only b'.' is allowed as a decimal separator
/// (issue #6698).
#[inline]
pub fn deserialize_decimal(mut bytes: &[u8], precision: Option<u8>, scale: u8) -> Option<i128> {
// While parse_integer_checked will parse positive/negative numbers, we want to
// handle the sign ourselves, and so check for it initially, then handle it
// at the end.
pub fn deserialize_decimal(bytes: &[u8], precision: Option<u8>, scale: u8) -> Option<i128> {
let precision_digits = precision.unwrap_or(38).min(38) as usize;
if scale as usize > precision_digits {
return None;
}

let separator = bytes.iter().position(|b| *b == b'.').unwrap_or(bytes.len());
let (mut int, mut frac) = bytes.split_at(separator);
if frac.len() <= 1 || scale == 0 {
// Only integer fast path.
let n: i128 = atoi_simd::parse(int).ok()?;
let ret = n.checked_mul(POW10[scale as usize] as i128)?;
if precision.is_some() && ret >= POW10[precision_digits] as i128 {
return None;
}
return Some(ret);
}

// Skip period.
frac = &frac[1..];

// Skip sign.
let negative = match bytes.first() {
Some(s @ (b'+' | b'-')) => {
bytes = &bytes[1..];
int = &int[1..];
*s == b'-'
},
_ => false,
};
let (lhs, rhs) = split_decimal_bytes(bytes);
let precision = precision.unwrap_or(38);

let lhs_b = lhs?;
// Truncate trailing digits that extend beyond the scale.
let frac_scale = if scale as usize <= frac.len() {
frac = &frac[..scale as usize];
0
} else {
scale as usize - frac.len()
};

// For the purposes of decimal parsing, we assume that all digits other than leading zeros
// are significant, eg, 001200 has 4 significant digits, not 2. The Decimal type does
// not allow negative scales, so all trailing zeros on the LHS of any decimal separator
// will still take up space in the representation (eg, 1200 requires, at minimum, precision 4
// at scale 0; there is no scale -2 where it would only need precision 2).
let lhs_s = lhs_b.len() as u8 - leading_zeros(lhs_b);
// Parse and combine parts.
let pint: u128 = if int.is_empty() {
0
} else {
atoi_simd::parse_pos(int).ok()?
};
let pfrac: u128 = atoi_simd::parse_pos(frac).ok()?;

if lhs_s + scale > precision {
// the integer already exceeds the precision
let ret = pint
.checked_mul(POW10[scale as usize])?
.checked_add(pfrac.checked_mul(POW10[frac_scale])?)?;
if precision.is_some() && ret >= POW10[precision_digits] {
return None;
}

let abs = parse_integer_checked(lhs_b).and_then(|x| match rhs {
// A decimal separator was found, so LHS and RHS need to be combined.
Some(mut rhs) => {
if matches!(rhs.first(), Some(b'+' | b'-')) {
// RHS starts with a '+'/'-' sign and the number is not well-formed.
return None;
}
let scale_adjust = if (scale as usize) <= rhs.len() {
// Truncate trailing digits that extend beyond the scale
rhs = &rhs[..scale as usize];
None
} else {
Some(scale as u32 - rhs.len() as u32)
};

parse_integer_checked(rhs).map(|y| {
let lhs = x * 10i128.pow(scale as u32);
let rhs = scale_adjust.map_or(y, |s| y * 10i128.pow(s));
lhs + rhs
})
},
// No decimal separator was found; we have an integer / LHS only.
None => {
if lhs_b.is_empty() {
// we simply have no number at all / an empty string.
return None;
}
Some(x * 10i128.pow(scale as u32))
},
});
if negative {
Some(-abs?)
if ret > (1 << 127) {
None
} else {
Some(ret.wrapping_neg() as i128)
}
} else {
abs
ret.try_into().ok()
}
}

Expand Down Expand Up @@ -322,6 +304,11 @@ mod test {
assert_eq!(deserialize_decimal(val, Some(5), 6), None); // insufficient precision, excess scale
assert_eq!(deserialize_decimal(val, Some(5), 3), None); // insufficient precision, exact scale
assert_eq!(deserialize_decimal(val, Some(12), 5), Some(120001000)); // excess precision, excess scale
assert_eq!(deserialize_decimal(val, None, 35), None); // scale causes insufficient precision
assert_eq!(
deserialize_decimal(val, None, 35),
Some(120001000000000000000000000000000000000)
);
assert_eq!(deserialize_decimal(val, None, 36), None);
assert_eq!(deserialize_decimal(val, Some(38), 35), None); // scale causes insufficient precision
}
}
2 changes: 1 addition & 1 deletion crates/polars-time/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ polars-error = { workspace = true }
polars-ops = { workspace = true }
polars-utils = { workspace = true }

atoi = { workspace = true }
atoi_simd = { workspace = true }
bytemuck = { workspace = true }
chrono = { workspace = true }
chrono-tz = { workspace = true, optional = true }
Expand Down
7 changes: 3 additions & 4 deletions crates/polars-time/src/chunkedarray/string/strptime.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Much more opinionated, but also much faster strptrime than the one given in Chrono.
//!
use atoi::FromRadix10;
use chrono::{NaiveDate, NaiveDateTime};
use once_cell::sync::Lazy;
use regex::Regex;
Expand All @@ -14,15 +13,15 @@ static TWELVE_HOUR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"%[_-]?[Il]")
static MERIDIEM_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"%[_-]?[pP]").unwrap());

#[inline]
fn update_and_parse<T: atoi::FromRadix10>(
fn update_and_parse<T: atoi_simd::Parse>(
incr: usize,
offset: usize,
vals: &[u8],
) -> Option<(T, usize)> {
// this maybe oob because we cannot entirely sure about fmt lengths
let new_offset = offset + incr;
let bytes = vals.get(offset..new_offset)?;
let (val, parsed) = T::from_radix_10(bytes);
let (val, parsed) = atoi_simd::parse_any(bytes).ok()?;
if parsed == 0 {
None
} else {
Expand Down Expand Up @@ -154,7 +153,7 @@ impl StrpTimeState {
let new_offset = offset + 2;
let bytes = val.get_unchecked(offset..new_offset);

let (decade, parsed) = i32::from_radix_10(bytes);
let (decade, parsed) = atoi_simd::parse_any::<i32>(bytes).ok()?;
if parsed == 0 {
return None;
}
Expand Down
Loading