[SPARK-34615][SQL] Support java.time.Period as an external type of the year-month interval type

### What changes were proposed in this pull request?
In the PR, I propose to extend the Spark SQL API to accept [`java.time.Period`](https://docs.oracle.com/javase/8/docs/api/java/time/Period.html) as an external type of the recently added Catalyst type `YearMonthIntervalType` (see #31614). The Java class `java.time.Period` has semantics similar to the ANSI SQL year-month interval type, which makes it the most suitable external type for `YearMonthIntervalType`. In more detail:
1. Added `PeriodConverter`, which converts `java.time.Period` instances to/from the internal representation of the Catalyst type `YearMonthIntervalType` (an `Int`). The `PeriodConverter` object uses new methods of `IntervalUtils`:
    - `periodToMonths()` converts the input period to its total length in months. If the result does not fit in an `Int`, the method throws an `ArithmeticException`. **Note:** _the input period has "days" precision, but the method ignores the days unit._
    - `monthsToPeriod()` obtains a `java.time.Period` representing a number of months.
2. Supported the new type `YearMonthIntervalType` in `RowEncoder` via the methods `createDeserializerForPeriod()` and `createSerializerForJavaPeriod()`.
3. Extended the Literal API to construct literals from `java.time.Period` instances.
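
The two conversions in item 1 can be sketched in plain Java without Spark. This is a minimal standalone re-implementation that mirrors the `IntervalUtils` methods shown later in this diff. The class name `PeriodMonthsSketch` is hypothetical and not part of Spark.

```java
import java.time.Period;

// Standalone sketch; the class name is hypothetical and not part of Spark.
final class PeriodMonthsSketch {
    private static final int MONTHS_PER_YEAR = 12;

    // Mirrors IntervalUtils.periodToMonths from this diff: the days unit is
    // ignored, and ArithmeticException is thrown if the total overflows an int.
    static int periodToMonths(Period period) {
        int monthsInYears = Math.multiplyExact(period.getYears(), MONTHS_PER_YEAR);
        return Math.addExact(monthsInYears, period.getMonths());
    }

    // Mirrors IntervalUtils.monthsToPeriod from this diff.
    static Period monthsToPeriod(int months) {
        return Period.ofMonths(months).normalized();
    }

    public static void main(String[] args) {
        // 10 years, 2 months, 5 days: the 5 days are dropped by the conversion.
        int months = periodToMonths(Period.of(10, 2, 5));
        System.out.println(months);                 // 122
        System.out.println(monthsToPeriod(months)); // P10Y2M

        try {
            // Integer.MAX_VALUE years cannot be expressed as an int month count.
            periodToMonths(Period.ofYears(Integer.MAX_VALUE));
        } catch (ArithmeticException e) {
            System.out.println("overflow detected");
        }
    }
}
```

Because the days unit is dropped, converting a period to months and back is lossy for any period with a non-zero days component.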

### Why are the changes needed?
1. To allow users to parallelize `java.time.Period` collections and construct year-month interval columns, and to collect such columns back to the driver side.
2. This will allow writing tests in other sub-tasks of SPARK-27790.

### Does this PR introduce _any_ user-facing change?
The PR extends existing functionality, so users can parallelize instances of the `java.time.Period` class and collect them back:

```scala
scala> val ds = Seq(java.time.Period.ofYears(10).withMonths(2)).toDS
ds: org.apache.spark.sql.Dataset[java.time.Period] = [value: yearmonthinterval]

scala> ds.collect
res0: Array[java.time.Period] = Array(P10Y2M)
```

### How was this patch tested?
- Added a few tests to `CatalystTypeConvertersSuite` to check conversion from/to `java.time.Period`.
- Checked row encoding with new tests in `RowEncoderSuite`.
- Literals of `YearMonthIntervalType` are tested in `LiteralExpressionSuite`.
- Checked collecting in `DatasetSuite` and `JavaDatasetSuite`.
- Added new tests to `IntervalUtilsSuite` to check the conversions `java.time.Period` <-> months.
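
As an illustration of the kind of property the `java.time.Period` <-> months tests can check, here is a hedged round-trip sketch in plain Java. It is not the code from the Spark suites. The class `PeriodRoundTripCheck` and its helpers are hypothetical stand-ins that mirror the `IntervalUtils` methods in this diff.

```java
import java.time.Period;

// Hypothetical standalone check; not taken from the Spark test suites.
final class PeriodRoundTripCheck {
    // Same logic as IntervalUtils.periodToMonths in this diff.
    static int periodToMonths(Period p) {
        return Math.addExact(Math.multiplyExact(p.getYears(), 12), p.getMonths());
    }

    // Same logic as IntervalUtils.monthsToPeriod in this diff.
    static Period monthsToPeriod(int months) {
        return Period.ofMonths(months).normalized();
    }

    // months -> Period -> months should return the original value.
    static boolean roundTrips(int months) {
        return periodToMonths(monthsToPeriod(months)) == months;
    }

    public static void main(String[] args) {
        int[] samples = {Integer.MIN_VALUE, -13, -1, 0, 1, 11, 12, 27, Integer.MAX_VALUE};
        for (int m : samples) {
            if (!roundTrips(m)) {
                throw new AssertionError("round trip failed for " + m);
            }
        }
        System.out.println("all round trips passed");
    }
}
```

The samples include both `Int` boundaries because the normalized years and months always recombine to the original month count, so the round trip should not overflow there.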

Closes #31765 from MaxGekk/java-time-period.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
MaxGekk authored and cloud-fan committed Mar 8, 2021
1 parent 43f355b commit e10bf64
Showing 23 changed files with 230 additions and 19 deletions.

@@ -86,6 +86,9 @@ public static Object read(
if (dataType instanceof DayTimeIntervalType) {
return obj.getLong(ordinal);
}
if (dataType instanceof YearMonthIntervalType) {
return obj.getInt(ordinal);
}

throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}

@@ -143,6 +143,14 @@ object Encoders {
*/
def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()

/**
* Creates an encoder that serializes instances of the `java.time.Period` class
* to the internal representation of nullable Catalyst's YearMonthIntervalType.
*
* @since 3.2.0
*/
def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()

/**
* Creates an encoder for Java Bean of type T.
*

@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate, Period}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable

@@ -75,6 +75,7 @@ object CatalystTypeConverters {
case FloatType => FloatConverter
case DoubleType => DoubleConverter
case DayTimeIntervalType => DurationConverter
case YearMonthIntervalType => PeriodConverter
case dataType: DataType => IdentityConverter(dataType)
}
converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -413,6 +414,18 @@ object CatalystTypeConverters {
IntervalUtils.microsToDuration(row.getLong(column))
}

private object PeriodConverter extends CatalystTypeConverter[Period, Period, Any] {
override def toCatalystImpl(scalaValue: Period): Int = {
IntervalUtils.periodToMonths(scalaValue)
}
override def toScala(catalystValue: Any): Period = {
if (catalystValue == null) null
else IntervalUtils.monthsToPeriod(catalystValue.asInstanceOf[Int])
}
override def toScalaImpl(row: InternalRow, column: Int): Period =
IntervalUtils.monthsToPeriod(row.getInt(column))
}

/**
* Creates a converter function that will convert Scala objects to the specified Catalyst type.
* Typical use case would be converting a collection of rows that have the same schema. You will
@@ -479,6 +492,7 @@ object CatalystTypeConverters {
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
case d: Duration => DurationConverter.toCatalyst(d)
case p: Period => PeriodConverter.toCatalyst(p)
case other => other
}


@@ -152,6 +152,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

def createDeserializerForPeriod(path: Expression): Expression = {
StaticInvoke(
IntervalUtils.getClass,
ObjectType(classOf[java.time.Period]),
"monthsToPeriod",
path :: Nil,
returnNullable = false)
}

/**
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lost the required data type, which may lead to runtime error if the real type doesn't

@@ -132,7 +132,8 @@ object InternalRow {
case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
case IntegerType | DateType | YearMonthIntervalType =>
(input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType | DayTimeIntervalType =>
(input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean])
case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case IntegerType | DateType | YearMonthIntervalType =>
(input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case LongType | TimestampType | DayTimeIntervalType =>
(input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])

@@ -119,6 +119,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
@@ -253,6 +254,9 @@ object JavaTypeInference {
case c if c == classOf[java.time.Duration] =>
createDeserializerForDuration(path)

case c if c == classOf[java.time.Period] =>
createDeserializerForPeriod(path)

case c if c == classOf[java.lang.String] =>
createDeserializerForString(path, returnNullable = true)

@@ -412,6 +416,8 @@ object JavaTypeInference {

case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)

case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
createSerializerForJavaBigDecimal(inputObject)


@@ -243,6 +243,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createDeserializerForDuration(path)

case t if isSubtype(t, localTypeOf[java.time.Period]) =>
createDeserializerForPeriod(path)

case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)

@@ -528,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createSerializerForJavaDuration(inputObject)

case t if isSubtype(t, localTypeOf[java.time.Period]) =>
createSerializerForJavaPeriod(inputObject)

case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createSerializerForScalaBigDecimal(inputObject)

@@ -748,6 +754,8 @@ object ScalaReflection extends ScalaReflection {
Schema(CalendarIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
Schema(DayTimeIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.Period]) =>
Schema(YearMonthIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -846,7 +854,8 @@ object ScalaReflection extends ScalaReflection {
TimestampType -> classOf[TimestampType.InternalType],
BinaryType -> classOf[BinaryType.InternalType],
CalendarIntervalType -> classOf[CalendarInterval],
DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
YearMonthIntervalType -> classOf[YearMonthIntervalType.InternalType]
)

val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -859,7 +868,8 @@ object ScalaReflection extends ScalaReflection {
DoubleType -> classOf[java.lang.Double],
DateType -> classOf[java.lang.Integer],
TimestampType -> classOf[java.lang.Long],
DayTimeIntervalType -> classOf[java.lang.Long]
DayTimeIntervalType -> classOf[java.lang.Long],
YearMonthIntervalType -> classOf[java.lang.Integer]
)

def dataTypeJavaClass(dt: DataType): Class[_] = {

@@ -113,6 +113,15 @@ object SerializerBuildHelper {
returnNullable = false)
}

def createSerializerForJavaPeriod(inputObject: Expression): Expression = {
StaticInvoke(
IntervalUtils.getClass,
YearMonthIntervalType,
"periodToMonths",
inputObject :: Nil,
returnNullable = false)
}

def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,

@@ -302,6 +302,11 @@ package object dsl {
AttributeReference(s, DayTimeIntervalType, nullable = true)()
}

/** Creates a new AttributeReference of the year-month interval type */
def yearMonthInterval: AttributeReference = {
AttributeReference(s, YearMonthIntervalType, nullable = true)()
}

/** Creates a new AttributeReference of type binary */
def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()


@@ -54,6 +54,7 @@ import org.apache.spark.sql.types._
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
*
* DayTimeIntervalType -> java.time.Duration
* YearMonthIntervalType -> java.time.Period
*
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
@@ -112,6 +113,8 @@ object RowEncoder {

case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)

case YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)

case d: DecimalType =>
CheckOverflow(StaticInvoke(
Decimal.getClass,
@@ -231,6 +234,7 @@ object RowEncoder {
ObjectType(classOf[java.sql.Date])
}
case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
case YearMonthIntervalType => ObjectType(classOf[java.time.Period])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -288,6 +292,8 @@ object RowEncoder {

case DayTimeIntervalType => createDeserializerForDuration(input)

case YearMonthIntervalType => createDeserializerForPeriod(input)

case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

case StringType => createDeserializerForString(input, returnNullable = false)

@@ -157,7 +157,7 @@ object InterpretedUnsafeProjection {
case ShortType =>
(v, i) => writer.write(i, v.getShort(i))

case IntegerType | DateType =>
case IntegerType | DateType | YearMonthIntervalType =>
(v, i) => writer.write(i, v.getInt(i))

case LongType | TimestampType | DayTimeIntervalType =>

@@ -193,8 +193,8 @@ final class MutableAny extends MutableValue {
final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGenericInternalRow {

private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
// We use INT for DATE internally
case IntegerType | DateType => new MutableInt
// We use INT for DATE and YearMonthIntervalType internally
case IntegerType | DateType | YearMonthIntervalType => new MutableInt
// We use Long for Timestamp and DayTimeInterval internally
case LongType | TimestampType | DayTimeIntervalType => new MutableLong
case FloatType => new MutableFloat

@@ -1812,7 +1812,7 @@ object CodeGenerator extends Logging {
case BooleanType => JAVA_BOOLEAN
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
case IntegerType | DateType => JAVA_INT
case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
@@ -1833,7 +1833,7 @@ object CodeGenerator extends Logging {
case BooleanType => java.lang.Boolean.TYPE
case ByteType => java.lang.Byte.TYPE
case ShortType => java.lang.Short.TYPE
case IntegerType | DateType => java.lang.Integer.TYPE
case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE
case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
case FloatType => java.lang.Float.TYPE
case DoubleType => java.lang.Double.TYPE

@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate, Period}
import java.util
import java.util.Objects
import javax.xml.bind.DatatypeConverter
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -78,6 +78,7 @@ object Literal {
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
case p: Period => Literal(periodToMonths(p), YearMonthIntervalType)
case a: Array[Byte] => Literal(a, BinaryType)
case a: collection.mutable.WrappedArray[_] => apply(a.array)
case a: Array[_] =>
@@ -114,6 +115,7 @@ object Literal {
case _ if clz == classOf[Instant] => TimestampType
case _ if clz == classOf[Timestamp] => TimestampType
case _ if clz == classOf[Duration] => DayTimeIntervalType
case _ if clz == classOf[Period] => YearMonthIntervalType
case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[Array[Byte]] => BinaryType
case _ if clz == classOf[Array[Char]] => StringType
@@ -171,6 +173,7 @@ object Literal {
case DateType => create(0, DateType)
case TimestampType => create(0L, TimestampType)
case DayTimeIntervalType => create(0L, DayTimeIntervalType)
case YearMonthIntervalType => create(0, YearMonthIntervalType)
case StringType => Literal("")
case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -189,7 +192,7 @@ object Literal {
case BooleanType => v.isInstanceOf[Boolean]
case ByteType => v.isInstanceOf[Byte]
case ShortType => v.isInstanceOf[Short]
case IntegerType | DateType => v.isInstanceOf[Int]
case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int]
case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
case FloatType => v.isInstanceOf[Float]
case DoubleType => v.isInstanceOf[Double]
@@ -366,7 +369,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
}
dataType match {
case BooleanType | IntegerType | DateType =>
case BooleanType | IntegerType | DateType | YearMonthIntervalType =>
toExprCode(value.toString)
case FloatType =>
value.asInstanceOf[Float] match {

@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.util

import java.time.Duration
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit

@@ -791,4 +791,35 @@ object IntervalUtils {
* @return A [[Duration]], not null
*/
def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS)

/**
* Gets the total number of months in this period.
* <p>
* This returns the total number of months in the period by multiplying the
* number of years by 12 and adding the number of months.
* <p>
*
* @return The total number of months in the period, may be negative
* @throws ArithmeticException If numeric overflow occurs
*/
def periodToMonths(period: Period): Int = {
val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR)
Math.addExact(monthsInYears, period.getMonths)
}

/**
* Obtains a [[Period]] representing a number of months. The days unit will be zero, and the years
* and months units will be normalized.
*
* <p>
* The months unit is adjusted to have an absolute value < 12, with the years unit being adjusted
* to compensate. For example, the method returns "2 years and 3 months" for the 27 input months.
* <p>
* The sign of the years and months units will be the same after normalization.
* For example, -13 months will be converted to "-1 year and -1 month".
*
* @param months The number of months, positive or negative
* @return The period of months, not null
*/
def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized()
}
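
The normalization rules described in the `monthsToPeriod` doc comment can be observed directly with `java.time.Period`. The sketch below is plain Java rather than Spark code, and the class name `PeriodNormalizationSketch` is hypothetical. It reproduces the two examples from the comment: 27 months and -13 months.

```java
import java.time.Period;

// Hypothetical standalone demo of Period.normalized(); not part of Spark.
final class PeriodNormalizationSketch {
    // Same expression as IntervalUtils.monthsToPeriod in this diff.
    static Period monthsToPeriod(int months) {
        return Period.ofMonths(months).normalized();
    }

    public static void main(String[] args) {
        Period positive = monthsToPeriod(27);
        // 27 months normalize to 2 years and 3 months; days stay zero.
        System.out.println(positive.getYears() + " years, " + positive.getMonths() + " months");

        Period negative = monthsToPeriod(-13);
        // -13 months normalize to -1 year and -1 month: both units share the sign.
        System.out.println(negative.getYears() + " years, " + negative.getMonths() + " months");
    }
}
```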

@@ -171,7 +171,7 @@ object DataType {
private val otherTypes = {
Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType,
DayTimeIntervalType)
DayTimeIntervalType, YearMonthIntervalType)
.map(t => t.typeName -> t).toMap
}
