@@ -20,10 +20,12 @@ import java.nio.charset.StandardCharsets
20
20
21
21
import scala .collection .JavaConverters ._
22
22
import scala .collection .mutable .ListBuffer
23
+
23
24
import ai .rapids .cudf
24
25
import ai .rapids .cudf .{ColumnVector , DType , HostMemoryBuffer , Scalar , Schema , Table }
25
26
import com .nvidia .spark .rapids ._
26
27
import org .apache .hadoop .conf .Configuration
28
+
27
29
import org .apache .spark .broadcast .Broadcast
28
30
import org .apache .spark .sql .SparkSession
29
31
import org .apache .spark .sql .catalyst .InternalRow
@@ -377,21 +379,24 @@ class JsonPartitionReader(
377
379
/**
 * Nulls out every entry of `input` that is not an accepted JSON number.
 *
 * An entry is kept when it matches the JSON number pattern below or, when
 * `parsedOptions.allowNonNumericNumbers` is set, when it is exactly one of
 * `NaN`, `+INF`, `-INF`, `+Infinity`, `Infinity` or `-Infinity`.
 *
 * @param input column of candidate number strings; this method does not close it
 * @return the column produced by `ifElse`: the original string where the entry is
 *         valid, a null string otherwise
 */
private def sanitizeNumbers(input: ColumnVector): ColumnVector = {
  // Note that this is not 100% consistent with Spark versions prior to Spark 3.3.0
  // due to https://issues.apache.org/jira/browse/SPARK-38060
  // cuDF `isFloat` supports some inputs that are not valid JSON numbers, such as `.1`, `1.`,
  // and `+1` so we use a regular expression to match valid JSON numbers instead
  val jsonNumberRegexp = "^-?[0-9]+(?:\\.[0-9]+)?(?:[eE][\\-\\+]?[0-9]+)?$"
  val isValid = if (parsedOptions.allowNonNumericNumbers) {
    withResource(ColumnVector.fromStrings("NaN", "+INF", "-INF", "+Infinity",
      "Infinity", "-Infinity")) { nonNumericLiterals =>
      withResource(input.matchesRe(jsonNumberRegexp)) { isJsonNumber =>
        // Distinct names for the literal set and the resulting mask: the previous
        // code reused `nonNumeric` for both, shadowing the outer binding.
        withResource(input.contains(nonNumericLiterals)) { isNonNumericLiteral =>
          isJsonNumber.or(isNonNumericLiteral)
        }
      }
    }
  } else {
    input.matchesRe(jsonNumberRegexp)
  }
  // Use the handle managed by withResource rather than reaching back to the outer val.
  withResource(isValid) { valid =>
    withResource(Scalar.fromNull(DType.STRING)) { nullString =>
      valid.ifElse(input, nullString)
    }
  }
}
0 commit comments