Fix batch use-after-close in partitioning, shuffle env init (NVIDIA#432)
Signed-off-by: Jason Lowe <[email protected]>
jlowe authored Jul 25, 2020
1 parent cdfacbb commit b1cb808
Showing 9 changed files with 134 additions and 70 deletions.
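
The pattern behind the partitioning changes below, shown as a hedged sketch rather than the plugin's actual code: read ColumnarBatch.numRows once while the batch is still guaranteed to be open, then hand only that primitive row count to the slicing helpers, so no code path can consult the batch after another path has closed it. The names RowCountFirstSketch, partitionBatch, and sliceByRowCount are hypothetical stand-ins.

import org.apache.spark.sql.vectorized.ColumnarBatch

object RowCountFirstSketch {
  // Hypothetical stand-in for sliceInternalGpuOrCpu: it receives only the row
  // count, never the batch, so it cannot read a batch that was closed earlier.
  def sliceByRowCount(numRows: Int, partitionIndexes: Array[Int]): Array[Int] =
    partitionIndexes.map(idx => math.min(idx, numRows))

  def partitionBatch(batch: ColumnarBatch, partitionIndexes: Array[Int]): Array[Int] = {
    // Capture the row count up front, while the batch is known to be valid.
    val numRows = batch.numRows()
    try {
      sliceByRowCount(numRows, partitionIndexes)
    } finally {
      batch.close() // safe: nothing after this point touches the batch
    }
  }
}
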
@@ -121,6 +121,7 @@ case class GpuHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
// We are doing this here because the cudf partition command is at this level
val totalRange = new NvtxRange("Hash partition", NvtxColor.PURPLE)
try {
val numRows = batch.numRows
val (partitionIndexes, partitionColumns) = {
val partitionRange = new NvtxRange("partition", NvtxColor.BLUE)
try {
@@ -129,7 +130,7 @@ case class GpuHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
partitionRange.close()
}
}
val ret = sliceInternalGpuOrCpu(batch, partitionIndexes, partitionColumns)
val ret = sliceInternalGpuOrCpu(numRows, partitionIndexes, partitionColumns)
partitionColumns.safeClose()
// Close the partition columns we copied them as a part of the slice
ret.zipWithIndex.filter(_._1 != null)
@@ -38,10 +38,10 @@ trait GpuPartitioning extends Partitioning {
ret
}

def sliceInternalOnGpu(batch: ColumnarBatch, partitionIndexes: Array[Int],
def sliceInternalOnGpu(numRows: Int, partitionIndexes: Array[Int],
partitionColumns: Array[GpuColumnVector]): Array[ColumnarBatch] = {
// The first index will always be 0, so we need to skip it.
val batches = if (batch.numRows > 0) {
val batches = if (numRows > 0) {
val parts = partitionIndexes.slice(1, partitionIndexes.length)
val splits = new ArrayBuffer[ColumnarBatch](numPartitions)
val table = new Table(partitionColumns.map(_.getBase).toArray: _*)
@@ -69,7 +69,7 @@ trait GpuPartitioning extends Partitioning {
batches
}

def sliceInternalOnCpu(batch: ColumnarBatch, partitionIndexes: Array[Int],
def sliceInternalOnCpu(numRows: Int, partitionIndexes: Array[Int],
partitionColumns: Array[GpuColumnVector]): Array[ColumnarBatch] = {
// We need to make sure that we have a null count calculated ahead of time.
// This should be a temp work around.
@@ -87,14 +87,14 @@
ret(i - 1) = sliceBatch(hostPartColumns, start, idx)
start = idx
}
ret(numPartitions - 1) = sliceBatch(hostPartColumns, start, batch.numRows())
ret(numPartitions - 1) = sliceBatch(hostPartColumns, start, numRows)
ret
} finally {
hostPartColumns.safeClose()
}
}

def sliceInternalGpuOrCpu(batch: ColumnarBatch, partitionIndexes: Array[Int],
def sliceInternalGpuOrCpu(numRows: Int, partitionIndexes: Array[Int],
partitionColumns: Array[GpuColumnVector]): Array[ColumnarBatch] = {
val rapidsShuffleEnabled = GpuShuffleEnv.isRapidsShuffleEnabled
val nvtxRangeKey = if (rapidsShuffleEnabled) {
@@ -107,9 +107,9 @@
val sliceRange = new NvtxRange(nvtxRangeKey, NvtxColor.CYAN)
try {
if (rapidsShuffleEnabled) {
sliceInternalOnGpu(batch, partitionIndexes, partitionColumns)
sliceInternalOnGpu(numRows, partitionIndexes, partitionColumns)
} else {
sliceInternalOnCpu(batch, partitionIndexes, partitionColumns)
sliceInternalOnCpu(numRows, partitionIndexes, partitionColumns)
}
} finally {
sliceRange.close()
@@ -120,6 +120,7 @@ case class GpuRangePartitioning(
slicedSortedTbl = new Table(sortColumns: _*)
//get the final column batch, remove the sort order sortColumns
finalSortedCb = GpuColumnVector.from(sortedTbl, numSortCols, sortedTbl.getNumberOfColumns)
val numRows = finalSortedCb.numRows
partitionColumns = GpuColumnVector.extractColumns(finalSortedCb)
// get the ranges table and get upper bounds if possible
// rangeBounds can be empty or of length < numPartitions in cases where the samples are less
@@ -132,7 +133,7 @@ case class GpuRangePartitioning(
retCv = slicedSortedTbl.upperBound(nullFlags.toArray, rangesTbl, descFlags.toArray)
parts = parts ++ GpuColumnVector.toIntArray(retCv)
}
slicedCb = sliceInternalGpuOrCpu(finalSortedCb, parts, partitionColumns)
slicedCb = sliceInternalGpuOrCpu(numRows, parts, partitionColumns)
} finally {
batch.close()
if (inputCvs != null) {
@@ -67,6 +67,7 @@ case class GpuRoundRobinPartitioning(numPartitions: Int)
}
val totalRange = new NvtxRange("Round Robin partition", NvtxColor.PURPLE)
try {
val numRows = batch.numRows
val (partitionIndexes, partitionColumns) = {
val partitionRange = new NvtxRange("partition", NvtxColor.BLUE)
try {
@@ -77,7 +78,7 @@ case class GpuRoundRobinPartitioning(numPartitions: Int)
}
}
val ret: Array[ColumnarBatch] =
sliceInternalGpuOrCpu(batch, partitionIndexes, partitionColumns)
sliceInternalGpuOrCpu(numRows, partitionIndexes, partitionColumns)
partitionColumns.safeClose()
// Close the partition columns we copied them as a part of the slice
ret.zipWithIndex.filter(_._1 != null)
@@ -38,14 +38,13 @@ case class GpuSinglePartitioning(expressions: Seq[Expression])
Array(batch).zipWithIndex
} else {
try {
// Need to produce a contiguous table. Until there's a direct way to do this, using
// contiguous split as a workaround, closing any degenerate table after the first one.
// Nothing needs to be sliced but a contiguous table is needed for GPU shuffle which
// slice will produce.
val sliced = sliceInternalGpuOrCpu(
batch,
Array(0, batch.numRows),
batch.numRows,
Array(0),
GpuColumnVector.extractColumns(batch))
sliced.drop(1).foreach(_.close())
sliced.take(1).zipWithIndex
sliced.zipWithIndex
} finally {
batch.close()
}
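
The replacement comment above carries the reasoning for keeping the slice call even though a single partition needs no slicing: GPU shuffle expects a partition's columns to be backed by one contiguous device buffer, and the contiguous-split path used by the slicer is what produces that layout. A rough illustration of that cudf call follows, assuming the ai.rapids.cudf Table.contiguousSplit API used elsewhere in the plugin; it is a sketch, not the code above.

import ai.rapids.cudf.{ContiguousTable, Table}

// Sketch: splitting at no interior indices still repacks the whole table into a
// single contiguous device buffer, the property GPU shuffle relies on. The
// caller owns the returned piece and must close it (and the input table).
def repackContiguously(table: Table): ContiguousTable = {
  val pieces: Array[ContiguousTable] = table.contiguousSplit()
  pieces.head
}
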
@@ -25,9 +25,6 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils

class GpuShuffleEnv extends Logging {
private val RAPIDS_SHUFFLE_CLASS = ShimLoader.getSparkShims.getRapidsShuffleManagerClass
private var isRapidsShuffleManagerInitialized: Boolean = false

private val catalog = new RapidsBufferCatalog
private var shuffleCatalog: ShuffleBufferCatalog = _
private var shuffleReceivedBufferCatalog: ShuffleReceivedBufferCatalog = _
@@ -41,19 +38,12 @@ class GpuShuffleEnv extends Logging {

lazy val isRapidsShuffleConfigured: Boolean = {
conf.contains("spark.shuffle.manager") &&
conf.get("spark.shuffle.manager") == RAPIDS_SHUFFLE_CLASS
}

// the shuffle plugin will call this on initialize
def setRapidsShuffleManagerInitialized(initialized: Boolean, className: String): Unit = {
assert(className == RAPIDS_SHUFFLE_CLASS)
logInfo("RapidsShuffleManager is initialized")
isRapidsShuffleManagerInitialized = initialized
conf.get("spark.shuffle.manager") == GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS
}

lazy val isRapidsShuffleEnabled: Boolean = {
val env = SparkEnv.get
val isRapidsManager = isRapidsShuffleManagerInitialized
val isRapidsManager = GpuShuffleEnv.isRapidsShuffleManagerInitialized
val externalShuffle = env.blockManager.externalShuffleServiceEnabled
isRapidsManager && !externalShuffle
}
@@ -110,7 +100,10 @@ class GpuShuffleEnv extends Logging {
def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage
}

object GpuShuffleEnv {
object GpuShuffleEnv extends Logging {
val RAPIDS_SHUFFLE_CLASS: String = ShimLoader.getSparkShims.getRapidsShuffleManagerClass

private var isRapidsShuffleManagerInitialized: Boolean = false
@volatile private var env: GpuShuffleEnv = _

def init(devInfo: CudaMemInfo): Unit = {
@@ -134,6 +127,10 @@

def isRapidsShuffleEnabled: Boolean = env.isRapidsShuffleEnabled

def setRapidsShuffleManagerInitialized(initialized: Boolean, className: String): Unit =
env.setRapidsShuffleManagerInitialized(initialized, className)
// the shuffle plugin will call this on initialize
def setRapidsShuffleManagerInitialized(initialized: Boolean, className: String): Unit = {
assert(className == GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS)
logInfo("RapidsShuffleManager is initialized")
isRapidsShuffleManagerInitialized = initialized
}
}
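
The shuffle-environment half of the fix moves the expected shuffle manager class name and the initialization flag from the GpuShuffleEnv instance into its companion object, so the shuffle manager can register itself even before any GpuShuffleEnv instance exists. A minimal, self-contained sketch of that registration pattern follows; the object and method names are illustrative, not the plugin's.

object ShuffleManagerRegistry {
  // In the real plugin the expected class name comes from the shim layer; here
  // it is a placeholder constant.
  val expectedShuffleClass: String = "com.example.HypotheticalShuffleManager"

  @volatile private var initialized: Boolean = false

  // Called by the shuffle manager when it is constructed.
  def setInitialized(isInit: Boolean, className: String): Unit = {
    assert(className == expectedShuffleClass, s"unexpected shuffle manager: $className")
    initialized = isInit
  }

  // Consulted later when deciding whether the GPU shuffle path is usable.
  def isInitialized: Boolean = initialized
}
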
@@ -16,15 +16,13 @@

package com.nvidia.spark.rapids

import java.io.File

import ai.rapids.cudf.{Cuda, Table}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.scalatest.FunSuite

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.rapids.{GpuShuffleEnv, RapidsDiskBlockManager}
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuPartitioningSuite extends FunSuite with Arm {
@@ -58,36 +56,19 @@ class GpuPartitioningSuite extends FunSuite with Arm {
}
}

def withGpuSparkSession(conf: SparkConf)(f: SparkSession => Unit): Unit = {
SparkSession.getActiveSession.foreach(_.close())
val spark = SparkSession.builder()
.master("local[1]")
.config(conf)
.config(RapidsConf.SQL_ENABLED.key, "true")
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
.appName(classOf[GpuPartitioningSuite].getSimpleName)
.getOrCreate()
try {
f(spark)
} finally {
spark.stop()
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
}
}

test("GPU partition") {
SparkSession.getActiveSession.foreach(_.close())
val conf = new SparkConf()
withGpuSparkSession(conf) { spark =>
TestUtils.withGpuSparkSession(conf) { _ =>
GpuShuffleEnv.init(Cuda.memGetInfo())
val partitionIndices = Array(0, 2)
val gp = new GpuPartitioning {
override val numPartitions: Int = partitionIndices.length
}
withResource(buildBatch()) { batch =>
val columns = GpuColumnVector.extractColumns(batch)
withResource(gp.sliceInternalOnGpu(batch, partitionIndices, columns)) { partitions =>
val numRows = batch.numRows
withResource(gp.sliceInternalOnGpu(numRows, partitionIndices, columns)) { partitions =>
partitions.zipWithIndex.foreach { case (partBatch, partIndex) =>
val startRow = partitionIndices(partIndex)
val endRow = if (partIndex < partitionIndices.length - 1) {
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import ai.rapids.cudf.{Cuda, Table}
import org.scalatest.FunSuite

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuSinglePartitioningSuite extends FunSuite with Arm {
private def buildBatch(): ColumnarBatch = {
withResource(new Table.TestBuilder()
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1, 1, 1, 1, 1, 1, 1)
.column("five", "two", null, null, "one", "one", "one", "one", "one", "one")
.column(5.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
.build()) { table =>
GpuColumnVector.from(table)
}
}

test("generates contiguous split") {
val conf = new SparkConf().set("spark.shuffle.manager", GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS)
TestUtils.withGpuSparkSession(conf) { _ =>
GpuShuffleEnv.init(Cuda.memGetInfo())
val partitioner = GpuSinglePartitioning(Nil)
withResource(buildBatch()) { expected =>
// partition will consume batch, so make a new batch with incremented refcounts
val columns = GpuColumnVector.extractColumns(expected)
columns.foreach(_.incRefCount())
val batch = new ColumnarBatch(columns.toArray, expected.numRows)
val result = partitioner.columnarEval(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
try {
assertResult(1)(result.length)
assertResult(0)(result.head._2)
val resultBatch = result.head._1
// verify this is a contiguous split table
assert(resultBatch.column(0).isInstanceOf[GpuColumnVectorFromBuffer])
TestUtils.compareBatches(expected, resultBatch)
} finally {
result.foreach(_._1.close())
}
}
}
}
}
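
The new test leans on a reference-counting idiom worth spelling out: columnarEval consumes the batch it is handed, so the test bumps each column's refcount and wraps the same device columns in a second ColumnarBatch, keeping the expected batch alive for the comparison. A hedged sketch of that idiom, mirroring the calls used in the test and assuming GpuColumnVector from com.nvidia.spark.rapids is in scope:

import org.apache.spark.sql.vectorized.ColumnarBatch

// Sketch: build a second batch over the same GPU columns so each batch can be
// closed independently; every extra close() is balanced by the incRefCount() here.
def duplicateBatch(original: ColumnarBatch): ColumnarBatch = {
  val columns = GpuColumnVector.extractColumns(original)
  columns.foreach(_.incRefCount())
  new ColumnarBatch(columns.toArray, original.numRows)
}
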
55 changes: 39 additions & 16 deletions tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala
@@ -17,10 +17,14 @@
package com.nvidia.spark.rapids

import java.io.File
import java.nio.ByteBuffer

import ai.rapids.cudf.{BufferType, ColumnVector, HostColumnVector, Table}
import ai.rapids.cudf.{BufferType, ColumnVector, DType, HostColumnVector, Table}
import org.scalatest.Assertions

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.vectorized.ColumnarBatch

/** A collection of utility methods useful in tests. */
@@ -52,25 +56,44 @@ object TestUtils extends Assertions with Arm {
def compareColumns(expected: ColumnVector, actual: ColumnVector): Unit = {
assertResult(expected.getType)(actual.getType)
assertResult(expected.getRowCount)(actual.getRowCount)
withResource(expected.copyToHost()) { expectedHost =>
withResource(actual.copyToHost()) { actualHost =>
compareColumnBuffers(expectedHost, actualHost, BufferType.DATA)
compareColumnBuffers(expectedHost, actualHost, BufferType.VALIDITY)
compareColumnBuffers(expectedHost, actualHost, BufferType.OFFSET)
withResource(expected.copyToHost()) { e =>
withResource(actual.copyToHost()) { a =>
(0L until expected.getRowCount).foreach { i =>
assertResult(e.isNull(i))(a.isNull(i))
if (!e.isNull(i)) {
e.getType match {
case DType.BOOL8 => assertResult(e.getBoolean(i))(a.getBoolean(i))
case DType.INT8 => assertResult(e.getByte(i))(a.getByte(i))
case DType.INT16 => assertResult(e.getShort(i))(a.getShort(i))
case DType.INT32 => assertResult(e.getInt(i))(a.getInt(i))
case DType.INT64 => assertResult(e.getLong(i))(a.getLong(i))
case DType.FLOAT32 => assertResult(e.getFloat(i))(a.getFloat(i))
case DType.FLOAT64 => assertResult(e.getDouble(i))(a.getDouble(i))
case DType.STRING => assertResult(e.getJavaString(i))(a.getJavaString(i))
case _ => throw new UnsupportedOperationException("not implemented yet")
}
}
}
}
}
}

private def compareColumnBuffers(
expected: HostColumnVector,
actual: HostColumnVector,
bufferType: BufferType): Unit = {
val expectedBuffer = expected.getHostBufferFor(bufferType)
val actualBuffer = actual.getHostBufferFor(bufferType)
if (expectedBuffer != null) {
assertResult(expectedBuffer.asByteBuffer())(actualBuffer.asByteBuffer())
} else {
assertResult(null)(actualBuffer)
def withGpuSparkSession(conf: SparkConf)(f: SparkSession => Unit): Unit = {
SparkSession.getActiveSession.foreach(_.close())
val spark = SparkSession.builder()
.master("local[1]")
.config(conf)
.config(RapidsConf.SQL_ENABLED.key, "true")
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
.appName(classOf[GpuPartitioningSuite].getSimpleName)
.getOrCreate()
try {
f(spark)
} finally {
spark.stop()
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
GpuShuffleEnv.setRapidsShuffleManagerInitialized(false, GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS)
}
}
}
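
A hypothetical caller of the helper above, showing the intended shape of a suite that needs a GPU-enabled session; the suite name and the small filtering job are illustrative only.

package com.nvidia.spark.rapids

import org.scalatest.FunSuite

import org.apache.spark.SparkConf

class ExampleGpuSessionSuite extends FunSuite {
  test("runs a small job inside a GPU-enabled session") {
    val conf = new SparkConf()
    TestUtils.withGpuSparkSession(conf) { spark =>
      import spark.implicits._
      // The helper has already attached the RAPIDS plugin to this local[1] session.
      val count = (0 until 10).toDF("x").filter($"x" > 4).count()
      assert(count == 5L)
    }
  }
}
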
