[SPARK-34605][SQL] Support java.time.Duration as an external type of the day-time interval type

### What changes were proposed in this pull request?
In this PR, I propose to extend the Spark SQL API to accept [`java.time.Duration`](https://docs.oracle.com/javase/8/docs/api/java/time/Duration.html) as an external type of the recently added Catalyst type `DayTimeIntervalType` (see #31614). The Java class `java.time.Duration` has semantics similar to the ANSI SQL day-time interval type, which makes it the most suitable external type for `DayTimeIntervalType`. In more detail:
1. Added `DurationConverter`, which converts `java.time.Duration` instances to/from the internal representation of the Catalyst type `DayTimeIntervalType` (a `Long`). The `DurationConverter` object uses new methods of `IntervalUtils` (see the sketch after this list):
    - `durationToMicros()` converts the input duration to the total length in microseconds. If the duration is too large to fit in a `Long`, the method throws an `ArithmeticException`. **Note:** _the input duration has nanosecond precision; the method truncates the nanos part to microseconds by dividing by 1000._
    - `microsToDuration()` obtains a `java.time.Duration` representing a number of microseconds.
2. Supported the new type `DayTimeIntervalType` in `RowEncoder` via the new methods `createDeserializerForDuration()` and `createSerializerForJavaDuration()`.
3. Extended the Literal API to construct literals from `java.time.Duration` instances.
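
Here is a minimal sketch (illustrative, not part of the commit) of the expected conversion behavior, assuming the `IntervalUtils` implementation shown in the diff below:

```Scala
import java.time.Duration
import org.apache.spark.sql.catalyst.util.IntervalUtils._

durationToMicros(Duration.ofNanos(1500))                // 1L: the nanos part is truncated by integer division by 1000
microsToDuration(durationToMicros(Duration.ofDays(1)))  // PT24H: round-trips at microsecond precision
durationToMicros(Duration.ofSeconds(Long.MaxValue))     // throws ArithmeticException on overflow
```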

### Why are the changes needed?
1. To let users parallelize collections of `java.time.Duration`, construct day-time interval columns, and collect such columns back to the driver side.
2. To make it possible to write tests in other sub-tasks of SPARK-27790.

### Does this PR introduce _any_ user-facing change?
The PR extends existing functionality, so users can parallelize instances of the `java.time.Duration` class and collect them back:

```Scala
scala> val ds = Seq(java.time.Duration.ofDays(10)).toDS
ds: org.apache.spark.sql.Dataset[java.time.Duration] = [value: daytimeinterval]

scala> ds.collect
res0: Array[java.time.Duration] = Array(PT240H)
```

### How was this patch tested?
- Added tests to `CatalystTypeConvertersSuite` to check conversion to/from `java.time.Duration`.
- Row encoding is checked by new tests in `RowEncoderSuite`.
- Creating literals of `DayTimeIntervalType` is tested in `LiteralExpressionSuite`.
- Collecting day-time interval columns is checked by `DatasetSuite` and `JavaDatasetSuite`.

Closes #31729 from MaxGekk/java-time-duration.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
MaxGekk authored and cloud-fan committed Mar 4, 2021
1 parent e7e0161 commit 17601e0
Showing 23 changed files with 229 additions and 20 deletions.

```diff
@@ -83,6 +83,9 @@ public static Object read(
     if (handleUserDefinedType && dataType instanceof UserDefinedType) {
       return obj.get(ordinal, ((UserDefinedType)dataType).sqlType());
     }
+    if (dataType instanceof DayTimeIntervalType) {
+      return obj.getLong(ordinal);
+    }
 
     throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
   }
```

```diff
@@ -135,6 +135,14 @@ object Encoders {
    */
   def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
 
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Duration` class
+   * to the internal representation of nullable Catalyst's DayTimeIntervalType.
+   *
+   * @since 3.2.0
+   */
+  def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
+
   /**
    * Creates an encoder for Java Bean of type T.
    *
```

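As an illustration (not part of the commit), the new encoder can be passed explicitly when creating a Dataset; the `spark` session here is an assumed local session:

```Scala
import org.apache.spark.sql.Encoders

// Build a Dataset of durations with the explicit DURATION encoder.
val ds = spark.createDataset(Seq(java.time.Duration.ofMinutes(90)))(Encoders.DURATION)
ds.printSchema()  // expected: value: daytimeinterval (nullable = true)
```
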
```diff
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
 
@@ -74,6 +74,7 @@ object CatalystTypeConverters {
       case LongType => LongConverter
       case FloatType => FloatConverter
       case DoubleType => DoubleConverter
+      case DayTimeIntervalType => DurationConverter
       case dataType: DataType => IdentityConverter(dataType)
     }
     converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -400,6 +401,18 @@ object CatalystTypeConverters {
     override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column)
   }
 
+  private object DurationConverter extends CatalystTypeConverter[Duration, Duration, Any] {
+    override def toCatalystImpl(scalaValue: Duration): Long = {
+      IntervalUtils.durationToMicros(scalaValue)
+    }
+    override def toScala(catalystValue: Any): Duration = {
+      if (catalystValue == null) null
+      else IntervalUtils.microsToDuration(catalystValue.asInstanceOf[Long])
+    }
+    override def toScalaImpl(row: InternalRow, column: Int): Duration =
+      IntervalUtils.microsToDuration(row.getLong(column))
+  }
+
   /**
    * Creates a converter function that will convert Scala objects to the specified Catalyst type.
    * Typical use case would be converting a collection of rows that have the same schema. You will
@@ -465,6 +478,7 @@ object CatalystTypeConverters {
         map,
         (key: Any) => convertToCatalyst(key),
         (value: Any) => convertToCatalyst(value))
+    case d: Duration => DurationConverter.toCatalyst(d)
     case other => other
   }
 
```

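A sketch of how the new converter behaves, along the lines of the tests added in `CatalystTypeConvertersSuite` (illustrative; `CatalystTypeConverters` is a Catalyst-internal API, so this is the kind of check done in that suite rather than in user code):

```Scala
import java.time.Duration
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DayTimeIntervalType

// Converts external Duration values to the Catalyst representation: microseconds as Long.
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(DayTimeIntervalType)
toCatalyst(Duration.ofSeconds(1))  // expected: 1000000L
```
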
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.types._
 
 object DeserializerBuildHelper {
@@ -143,6 +143,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }
 
+  def createDeserializerForDuration(path: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      ObjectType(classOf[java.time.Duration]),
+      "microsToDuration",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   /**
    * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
    * and lost the required data type, which may lead to runtime error if the real type doesn't
```

```diff
@@ -133,7 +133,8 @@ object InternalRow {
     case ByteType => (input, ordinal) => input.getByte(ordinal)
     case ShortType => (input, ordinal) => input.getShort(ordinal)
     case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
-    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
+    case LongType | TimestampType | DayTimeIntervalType =>
+      (input, ordinal) => input.getLong(ordinal)
     case FloatType => (input, ordinal) => input.getFloat(ordinal)
     case DoubleType => (input, ordinal) => input.getDouble(ordinal)
     case StringType => (input, ordinal) => input.getUTF8String(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
     case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
     case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
     case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
-    case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
+    case LongType | TimestampType | DayTimeIntervalType =>
+      (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
     case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
     case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
     case CalendarIntervalType =>
```

```diff
@@ -118,6 +118,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
       case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+      case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
 
       case _ if typeToken.isArray =>
         val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
@@ -249,6 +250,9 @@ object JavaTypeInference {
       case c if c == classOf[java.sql.Timestamp] =>
         createDeserializerForSqlTimestamp(path)
 
+      case c if c == classOf[java.time.Duration] =>
+        createDeserializerForDuration(path)
+
       case c if c == classOf[java.lang.String] =>
         createDeserializerForString(path, returnNullable = true)
 
@@ -406,6 +410,8 @@ object JavaTypeInference {
 
       case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
 
+      case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)
+
       case c if c == classOf[java.math.BigDecimal] =>
         createSerializerForJavaBigDecimal(inputObject)
 
```

```diff
@@ -240,6 +240,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createDeserializerForSqlTimestamp(path)
 
+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        createDeserializerForDuration(path)
+
       case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         createDeserializerForString(path, returnNullable = false)
 
@@ -522,6 +525,9 @@ object ScalaReflection extends ScalaReflection {
 
       case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)
 
+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        createSerializerForJavaDuration(inputObject)
+
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         createSerializerForScalaBigDecimal(inputObject)
 
@@ -740,6 +746,8 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
       case t if isSubtype(t, localTypeOf[CalendarInterval]) =>
         Schema(CalendarIntervalType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        Schema(DayTimeIntervalType, nullable = true)
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
       case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -837,7 +845,8 @@ object ScalaReflection extends ScalaReflection {
     DateType -> classOf[DateType.InternalType],
     TimestampType -> classOf[TimestampType.InternalType],
     BinaryType -> classOf[BinaryType.InternalType],
-    CalendarIntervalType -> classOf[CalendarInterval]
+    CalendarIntervalType -> classOf[CalendarInterval],
+    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
   )
 
   val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -849,7 +858,8 @@ object ScalaReflection extends ScalaReflection {
     FloatType -> classOf[java.lang.Float],
     DoubleType -> classOf[java.lang.Double],
     DateType -> classOf[java.lang.Integer],
-    TimestampType -> classOf[java.lang.Long]
+    TimestampType -> classOf[java.lang.Long],
+    DayTimeIntervalType -> classOf[java.lang.Long]
   )
 
   def dataTypeJavaClass(dt: DataType): Class[_] = {
```

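An illustrative check of the new schema inference (not part of the commit):

```Scala
import org.apache.spark.sql.catalyst.ScalaReflection

// Duration fields now map to the day-time interval Catalyst type.
ScalaReflection.schemaFor[java.time.Duration]
// expected: Schema(DayTimeIntervalType, nullable = true)
```
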
```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -104,6 +104,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForJavaDuration(inputObject: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      DayTimeIntervalType,
+      "durationToMicros",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
```

```diff
@@ -297,6 +297,11 @@ package object dsl {
     /** Creates a new AttributeReference of type timestamp */
     def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()
 
+    /** Creates a new AttributeReference of the day-time interval type */
+    def dayTimeInterval: AttributeReference = {
+      AttributeReference(s, DayTimeIntervalType, nullable = true)()
+    }
+
     /** Creates a new AttributeReference of type binary */
     def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()
 
```

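A brief illustration of the new dsl helper, in the style of Catalyst's internal tests (a sketch, assuming the usual implicit conversions from `Symbol` in the dsl):

```Scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.DayTimeIntervalType

// Creates an AttributeReference with the day-time interval type.
val attr = Symbol("d").dayTimeInterval
assert(attr.dataType == DayTimeIntervalType)
```
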
```diff
@@ -53,6 +53,8 @@ import org.apache.spark.sql.types._
  *   TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
  *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
  *
+ *   DayTimeIntervalType -> java.time.Duration
+ *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
  *   MapType -> scala.collection.Map
@@ -108,6 +110,8 @@ object RowEncoder {
         createSerializerForSqlDate(inputObject)
       }
 
+    case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
+
     case d: DecimalType =>
       CheckOverflow(StaticInvoke(
         Decimal.getClass,
@@ -226,6 +230,7 @@ object RowEncoder {
       } else {
         ObjectType(classOf[java.sql.Date])
       }
+    case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -281,6 +286,8 @@ object RowEncoder {
         createDeserializerForSqlDate(input)
       }
 
+    case DayTimeIntervalType => createDeserializerForDuration(input)
+
     case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)
 
     case StringType => createDeserializerForString(input, returnNullable = false)
```

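An illustrative sketch (not from the commit) of what the `RowEncoder` changes enable for untyped rows, assuming a local `spark` session:

```Scala
import java.time.Duration
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DayTimeIntervalType, StructType}

val schema = new StructType().add("d", DayTimeIntervalType)
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(Duration.ofDays(1)))),
  schema)
df.first.getAs[Duration]("d")  // expected: PT24H
```
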
```diff
@@ -160,7 +160,7 @@ object InterpretedUnsafeProjection {
       case IntegerType | DateType =>
         (v, i) => writer.write(i, v.getInt(i))
 
-      case LongType | TimestampType =>
+      case LongType | TimestampType | DayTimeIntervalType =>
         (v, i) => writer.write(i, v.getLong(i))
 
       case FloatType =>
```

```diff
@@ -195,8 +195,8 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
   private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
     // We use INT for DATE internally
     case IntegerType | DateType => new MutableInt
-    // We use Long for Timestamp internally
-    case LongType | TimestampType => new MutableLong
+    // We use Long for Timestamp and DayTimeInterval internally
+    case LongType | TimestampType | DayTimeIntervalType => new MutableLong
     case FloatType => new MutableFloat
     case DoubleType => new MutableDouble
     case BooleanType => new MutableBoolean
```

```diff
@@ -1813,7 +1813,7 @@ object CodeGenerator extends Logging {
     case ByteType => JAVA_BYTE
     case ShortType => JAVA_SHORT
     case IntegerType | DateType => JAVA_INT
-    case LongType | TimestampType => JAVA_LONG
+    case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
     case _: DecimalType => "Decimal"
@@ -1834,7 +1834,7 @@ object CodeGenerator extends Logging {
     case ByteType => java.lang.Byte.TYPE
     case ShortType => java.lang.Short.TYPE
     case IntegerType | DateType => java.lang.Integer.TYPE
-    case LongType | TimestampType => java.lang.Long.TYPE
+    case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
     case FloatType => java.lang.Float.TYPE
     case DoubleType => java.lang.Double.TYPE
     case _: DecimalType => classOf[Decimal]
```

```diff
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate}
 import java.util
 import java.util.Objects
 import javax.xml.bind.DatatypeConverter
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
+import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -76,6 +77,7 @@ object Literal {
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
     case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
+    case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
     case a: Array[Byte] => Literal(a, BinaryType)
     case a: collection.mutable.WrappedArray[_] => apply(a.array)
     case a: Array[_] =>
@@ -111,6 +113,7 @@ object Literal {
     case _ if clz == classOf[Date] => DateType
     case _ if clz == classOf[Instant] => TimestampType
     case _ if clz == classOf[Timestamp] => TimestampType
+    case _ if clz == classOf[Duration] => DayTimeIntervalType
     case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
     case _ if clz == classOf[Array[Byte]] => BinaryType
     case _ if clz == classOf[Array[Char]] => StringType
@@ -167,6 +170,7 @@ object Literal {
     case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
     case DateType => create(0, DateType)
     case TimestampType => create(0L, TimestampType)
+    case DayTimeIntervalType => create(0L, DayTimeIntervalType)
     case StringType => Literal("")
     case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
     case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -186,7 +190,7 @@ object Literal {
     case ByteType => v.isInstanceOf[Byte]
     case ShortType => v.isInstanceOf[Short]
     case IntegerType | DateType => v.isInstanceOf[Int]
-    case LongType | TimestampType => v.isInstanceOf[Long]
+    case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
     case FloatType => v.isInstanceOf[Float]
     case DoubleType => v.isInstanceOf[Double]
     case _: DecimalType => v.isInstanceOf[Decimal]
@@ -388,7 +392,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
       }
       case ByteType | ShortType =>
         ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
-      case TimestampType | LongType =>
+      case TimestampType | LongType | DayTimeIntervalType =>
         toExprCode(s"${value}L")
       case _ =>
         val constRef = ctx.addReferenceObj("literal", value, javaType)
```

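For illustration (not from the commit), literal construction now accepts `Duration` directly, and `Literal.default` produces a zero-length interval:

```Scala
import java.time.Duration
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.DayTimeIntervalType

val lit = Literal(Duration.ofHours(36))
assert(lit.dataType == DayTimeIntervalType)
assert(lit.value == 129600000000L)  // 36 hours in microseconds
assert(Literal.default(DayTimeIntervalType).value == 0L)
```
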
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import java.time.Duration
+import java.time.temporal.ChronoUnit
 import java.util.concurrent.TimeUnit
 
 import scala.util.control.NonFatal
@@ -762,4 +764,31 @@ object IntervalUtils {
 
     new CalendarInterval(totalMonths, totalDays, micros)
   }
+
+  /**
+   * Converts this duration to the total length in microseconds.
+   * <p>
+   * If this duration is too large to fit in a [[Long]] microseconds, then an
+   * exception is thrown.
+   * <p>
+   * If this duration has greater than microsecond precision, then the conversion
+   * will drop any excess precision information as though the amount in nanoseconds
+   * was subject to integer division by one thousand.
+   *
+   * @return The total length of the duration in microseconds
+   * @throws ArithmeticException If numeric overflow occurs
+   */
+  def durationToMicros(duration: Duration): Long = {
+    val us = Math.multiplyExact(duration.getSeconds, MICROS_PER_SECOND)
+    val result = Math.addExact(us, duration.getNano / NANOS_PER_MICROS)
+    result
+  }
+
+  /**
+   * Obtains a [[Duration]] representing a number of microseconds.
+   *
+   * @param micros The number of microseconds, positive or negative
+   * @return A [[Duration]], not null
+   */
+  def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS)
 }
```