
Commit

Add writeStream (#480)
* Add methods

* Add WriteStreamTests

* Remove withWatermark
etspaceman authored Jan 8, 2021
1 parent 4e1b90b commit 45764e8
Showing 2 changed files with 96 additions and 0 deletions.
dataset/src/main/scala/frameless/TypedDatasetForwarded.scala (9 additions, 0 deletions)
@@ -4,6 +4,7 @@ import java.util

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, DataFrameWriter, SQLContext, SparkSession}
import org.apache.spark.storage.StorageLevel
@@ -116,6 +117,14 @@ trait TypedDatasetForwarded[T] { self: TypedDataset[T] =>
  def write: DataFrameWriter[T] =
    dataset.write

  /**
    * Interface for saving the content of the streaming Dataset out into external storage.
    *
    * apache/spark
    */
  def writeStream: DataStreamWriter[T] =
    dataset.writeStream

  /** Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions.
    * Similar to coalesce defined on an RDD, this operation results in a narrow dependency, e.g.
    * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
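For context, here is a minimal sketch of how the new forwarder can be used. It is not part of the commit; it assumes a local SparkSession, and the paths, the object name, and the MemoryStream source are placeholders chosen for illustration.

import frameless.TypedDataset
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream

object WriteStreamSketch extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("writeStream-sketch")
    .getOrCreate()
  import spark.implicits._
  implicit val sqlCtx: SQLContext = spark.sqlContext

  // MemoryStream acts as a toy streaming source for the sketch.
  val source = MemoryStream[String]
  val typed = TypedDataset.create(source.toDS())

  // writeStream forwards to the underlying Dataset's DataStreamWriter,
  // so the full structured-streaming sink API is available.
  val query = typed.writeStream
    .format("parquet")
    .option("checkpointLocation", "/tmp/frameless-sketch/checkpoint") // placeholder path
    .start("/tmp/frameless-sketch/output") // placeholder path

  source.addData("a", "b", "c")
  query.processAllAvailable()
  query.stop()
  spark.stop()
}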
dataset/src/test/scala/frameless/forward/WriteStreamTests.scala (87 additions, 0 deletions)
@@ -0,0 +1,87 @@
package frameless

import java.util.UUID

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.scalacheck.Prop._
import org.scalacheck.{Arbitrary, Gen, Prop}

class WriteStreamTests extends TypedDatasetSuite {

  // Nested, OptionFieldsOnly, and WriteExample are case classes defined
  // elsewhere in the frameless test sources; see the sketch after this listing.
  val genNested = for {
    d <- Arbitrary.arbitrary[Double]
    as <- Arbitrary.arbitrary[String]
  } yield Nested(d, as)

  val genOptionFieldsOnly = for {
    o1 <- Gen.option(Arbitrary.arbitrary[Int])
    o2 <- Gen.option(genNested)
  } yield OptionFieldsOnly(o1, o2)

  val genWriteExample = for {
    i <- Arbitrary.arbitrary[Int]
    s <- Arbitrary.arbitrary[String]
    on <- Gen.option(genNested)
    ooo <- Gen.option(genOptionFieldsOnly)
  } yield WriteExample(i, s, on, ooo)

test("write csv") {
val spark = session
import spark.implicits._
def prop[A: TypedEncoder: Encoder](data: List[A]): Prop = {
val uid = UUID.randomUUID()
val uidNoHyphens = uid.toString.replace("-", "")
val filePath = s"$TEST_OUTPUT_DIR/$uid}"
val checkpointPath = s"$TEST_OUTPUT_DIR/checkpoint/$uid"
val inputStream = MemoryStream[A]
val input = TypedDataset.create(inputStream.toDS())
val inputter = input.writeStream.format("csv").option("checkpointLocation", s"$checkpointPath/input").start(filePath)
inputStream.addData(data)
inputter.processAllAvailable()
val dataset = TypedDataset.createUnsafe(sqlContext.readStream.schema(input.schema).csv(filePath))

val tester = dataset
.writeStream
.option("checkpointLocation", s"$checkpointPath/tester")
.format("memory")
.queryName(s"testCsv_$uidNoHyphens")
.start()
tester.processAllAvailable()
val output = spark.table(s"testCsv_$uidNoHyphens").as[A]
TypedDataset.create(data).collect().run().groupBy(identity) ?= output.collect().groupBy(identity).map { case (k, arr) => (k, arr.toSeq) }
}

check(forAll(Gen.nonEmptyListOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(prop[String]))
check(forAll(Gen.nonEmptyListOf(Arbitrary.arbitrary[Int]))(prop[Int]))
}

test("write parquet") {
val spark = session
import spark.implicits._
def prop[A: TypedEncoder: Encoder](data: List[A]): Prop = {
val uid = UUID.randomUUID()
val uidNoHyphens = uid.toString.replace("-", "")
val filePath = s"$TEST_OUTPUT_DIR/$uid}"
val checkpointPath = s"$TEST_OUTPUT_DIR/checkpoint/$uid"
val inputStream = MemoryStream[A]
val input = TypedDataset.create(inputStream.toDS())
val inputter = input.writeStream.format("parquet").option("checkpointLocation", s"$checkpointPath/input").start(filePath)
inputStream.addData(data)
inputter.processAllAvailable()
val dataset = TypedDataset.createUnsafe(sqlContext.readStream.schema(input.schema).parquet(filePath))

val tester = dataset
.writeStream
.option("checkpointLocation", s"$checkpointPath/tester")
.format("memory")
.queryName(s"testParquet_$uidNoHyphens")
.start()
tester.processAllAvailable()
val output = spark.table(s"testParquet_$uidNoHyphens").as[A]
TypedDataset.create(data).collect().run().groupBy(identity) ?= output.collect().groupBy(identity).map { case (k, arr) => (k, arr.toSeq) }
}

check(forAll(Gen.nonEmptyListOf(genWriteExample))(prop[WriteExample]))
}
}
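Two notes on the test design. Each property round-trips its data: generated values are pushed through writeStream to a file sink, read back with readStream, written to an in-memory table via the "memory" sink, and compared to the input as multisets using groupBy(identity); processAllAvailable() blocks until each query has drained its source, which keeps the comparison deterministic. The case classes Nested, OptionFieldsOnly, and WriteExample are defined elsewhere in the frameless test sources; a plausible reconstruction, inferred from how the generators build them (the field names are assumptions):

// Shapes inferred from the generators above; the actual definitions
// live elsewhere in the frameless test sources and may differ.
case class Nested(d: Double, as: String)
case class OptionFieldsOnly(o1: Option[Int], o2: Option[Nested])
case class WriteExample(i: Int, s: String, on: Option[Nested], ooo: Option[OptionFieldsOnly])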
