From 3cd5b165e928d229c382dfd364c8b93c9adf2117 Mon Sep 17 00:00:00 2001 From: ZuoTiJia <2239651886@qq.com> Date: Fri, 14 Oct 2022 10:00:35 +0800 Subject: [PATCH] Fix the panic when lpad/rpad parameter is negative * fix the panic when lpad/rpad parameter is negative * add lpad test --- datafusion/core/tests/sql/unicode.rs | 4 + .../physical-expr/src/unicode_expressions.rs | 74 ++++++++++++------- 2 files changed, 52 insertions(+), 26 deletions(-) diff --git a/datafusion/core/tests/sql/unicode.rs b/datafusion/core/tests/sql/unicode.rs index 7a4bd10e7cbaf..b787b5a328cb3 100644 --- a/datafusion/core/tests/sql/unicode.rs +++ b/datafusion/core/tests/sql/unicode.rs @@ -49,7 +49,9 @@ async fn test_unicode_expressions() -> Result<()> { test_expression!("length('chars')", "5"); test_expression!("length('josé')", "4"); test_expression!("length(NULL)", "NULL"); + test_expression!("lpad('hi', -1, 'xy')", ""); test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', -1)", ""); test_expression!("lpad('hi', 0)", ""); test_expression!("lpad('hi', 21, 'abcdef')", "abcdefabcdefabcdefahi"); test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); @@ -71,7 +73,9 @@ async fn test_unicode_expressions() -> Result<()> { test_expression!("right('abcde', CAST(NULL AS INT))", "NULL"); test_expression!("right(NULL, 2)", "NULL"); test_expression!("right(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('hi', -1, 'xy')", ""); test_expression!("rpad('hi', 5, 'xy')", "hixyx"); + test_expression!("rpad('hi', -1)", ""); test_expression!("rpad('hi', 0)", ""); test_expression!("rpad('hi', 21, 'abcdef')", "hiabcdefabcdefabcdefa"); test_expression!("rpad('hi', 5, 'xy')", "hixyx"); diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 5ef7029e7d56c..61707b0c6f0b3 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -129,23 +129,29 @@ pub fn lpad(args: &[ArrayRef]) -> Result { .zip(length_array.iter()) .map(|(string, length)| match (string, length) { (Some(string), Some(length)) => { - let length = length as usize; + if length > i32::MAX as i64 { + return Err(DataFusionError::Internal( + "lpad requested length too large".to_string(), + )); + } + + let length = if length < 0 { 0 } else { length as usize }; if length == 0 { - Some("".to_string()) + Ok(Some("".to_string())) } else { let graphemes = string.graphemes(true).collect::>(); if length < graphemes.len() { - Some(graphemes[..length].concat()) + Ok(Some(graphemes[..length].concat())) } else { let mut s: String = " ".repeat(length - graphemes.len()); s.push_str(string); - Some(s) + Ok(Some(s)) } } } - _ => None, + _ => Ok(None), }) - .collect::>(); + .collect::>>()?; Ok(Arc::new(result) as ArrayRef) } @@ -160,18 +166,23 @@ pub fn lpad(args: &[ArrayRef]) -> Result { .zip(fill_array.iter()) .map(|((string, length), fill)| match (string, length, fill) { (Some(string), Some(length), Some(fill)) => { - let length = length as usize; + if length > i32::MAX as i64 { + return Err(DataFusionError::Internal( + "lpad requested length too large".to_string(), + )); + } + let length = if length < 0 { 0 } else { length as usize }; if length == 0 { - Some("".to_string()) + Ok(Some("".to_string())) } else { let graphemes = string.graphemes(true).collect::>(); let fill_chars = fill.chars().collect::>(); if length < graphemes.len() { - Some(graphemes[..length].concat()) + Ok(Some(graphemes[..length].concat())) } else if fill_chars.is_empty() { - Some(string.to_string()) + Ok(Some(string.to_string())) } else { let mut s = string.to_string(); let mut char_vector = @@ -185,13 +196,13 @@ pub fn lpad(args: &[ArrayRef]) -> Result { 0, char_vector.iter().collect::().as_str(), ); - Some(s) + Ok(Some(s)) } } } - _ => None, + _ => Ok(None), }) - .collect::>(); + .collect::>>()?; Ok(Arc::new(result) as ArrayRef) } @@ -262,24 +273,29 @@ pub fn rpad(args: &[ArrayRef]) -> Result { .zip(length_array.iter()) .map(|(string, length)| match (string, length) { (Some(string), Some(length)) => { - let length = length as usize; + if length > i32::MAX as i64 { + return Err(DataFusionError::Internal( + "lpad requested length too large".to_string(), + )); + } + + let length = if length < 0 { 0 } else { length as usize }; if length == 0 { - Some("".to_string()) + Ok(Some("".to_string())) } else { let graphemes = string.graphemes(true).collect::>(); if length < graphemes.len() { - Some(graphemes[..length].concat()) + Ok(Some(graphemes[..length].concat())) } else { let mut s = string.to_string(); s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Some(s) + Ok(Some(s)) } } } - _ => None, + _ => Ok(None), }) - .collect::>(); - + .collect::>>()?; Ok(Arc::new(result) as ArrayRef) } 3 => { @@ -293,14 +309,20 @@ pub fn rpad(args: &[ArrayRef]) -> Result { .zip(fill_array.iter()) .map(|((string, length), fill)| match (string, length, fill) { (Some(string), Some(length), Some(fill)) => { - let length = length as usize; + if length > i32::MAX as i64 { + return Err(DataFusionError::Internal( + "lpad requested length too large".to_string(), + )); + } + + let length = if length < 0 { 0 } else { length as usize }; let graphemes = string.graphemes(true).collect::>(); let fill_chars = fill.chars().collect::>(); if length < graphemes.len() { - Some(graphemes[..length].concat()) + Ok(Some(graphemes[..length].concat())) } else if fill_chars.is_empty() { - Some(string.to_string()) + Ok(Some(string.to_string())) } else { let mut s = string.to_string(); let mut char_vector = @@ -310,12 +332,12 @@ pub fn rpad(args: &[ArrayRef]) -> Result { .push(*fill_chars.get(l % fill_chars.len()).unwrap()); } s.push_str(char_vector.iter().collect::().as_str()); - Some(s) + Ok(Some(s)) } } - _ => None, + _ => Ok(None), }) - .collect::>(); + .collect::>>()?; Ok(Arc::new(result) as ArrayRef) }