fix: potential fix for false positive Arrow memory leak
advancedxy committed May 6, 2024
1 parent 4e4d528 commit 48517ef
Showing 4 changed files with 23 additions and 11 deletions.
@@ -50,13 +50,21 @@ case class StreamReader(channel: ReadableByteChannel, source: String) extends Au
   }

   override def close(): Unit = {
+    close(false)
+  }
+
+  def close(forceCloseAllocator: Boolean): Unit = {
     if (root != null) {
       arrowReader.close()
       root.close()
-      allocator.close()

       arrowReader = null
       root = null
     }
+
+    // don't close the allocator unless it's empty or forced to.
+    if (allocator != null && (forceCloseAllocator || allocator.getAllocatedMemory == 0)) {
+      allocator.close()
+      allocator = null
+    }
   }
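
The "false positive" comes from how Arrow's Java allocator does leak accounting: closing a BufferAllocator while buffers allocated from it are still outstanding raises an IllegalStateException reporting leaked memory, even if those buffers are released shortly afterwards. A minimal standalone sketch of that behaviour and of the conditional-close guard used above (the object name is illustrative; the only assumed dependency is arrow-memory):

import org.apache.arrow.memory.RootAllocator

object AllocatorCloseSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    val buf = allocator.buffer(64) // 64 bytes still outstanding

    // Closing the allocator here would report a leak even though `buf`
    // is released just below, so mirror the guard from StreamReader.close:
    if (allocator.getAllocatedMemory == 0) {
      allocator.close() // skipped: 64 bytes are still allocated
    }

    buf.close()       // release the outstanding buffer first ...
    allocator.close() // ... then the allocator closes without complaint
  }
}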
@@ -67,12 +67,12 @@ class ArrowReaderIterator(channel: ReadableByteChannel, source: String)
     reader.nextBatch()
   }

-  def close(): Unit =
+  def close(forceCloseAllocator: Boolean): Unit =
     synchronized {
       if (currentBatch != null) {
         currentBatch.close()
         currentBatch = null
       }
-      reader.close()
+      reader.close(forceCloseAllocator)
    }
}
@@ -32,6 +32,7 @@ import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.BaseShuffleHandle
 import org.apache.spark.shuffle.ShuffleReader
 import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.BlockId
 import org.apache.spark.storage.BlockManager
 import org.apache.spark.storage.BlockManagerId
@@ -98,19 +99,25 @@ class CometBlockStoreShuffleReader[K, C](
     // read iterators, it may blow up the call stack and cause OOM.
     context.addTaskCompletionListener[Unit] { _ =>
       if (currentReadIterator != null) {
-        currentReadIterator.close()
+        currentReadIterator.close(true)
      }
    }

-    IpcInputStreamIterator(inputStream, decompressingNeeded = true, context)
+    // accumulated readers/allocator to be closed after the input stream is consumed
+    val accumulatedReaders = new scala.collection.mutable.ArrayBuffer[ArrowReaderIterator]()
+    val iter = IpcInputStreamIterator(inputStream, decompressingNeeded = true, context)
       .flatMap { channel =>
         if (currentReadIterator != null) {
           // Closes previous read iterator.
-          currentReadIterator.close()
+          currentReadIterator.close(false)
+          accumulatedReaders.append(currentReadIterator)
        }
         currentReadIterator = new ArrowReaderIterator(channel, this.getClass.getSimpleName)
         currentReadIterator.map((0, _)) // use 0 as key since it's not used
      }
+    CompletionIterator[(Int, ColumnarBatch), Iterator[(Int, ColumnarBatch)]](
+      iter,
+      accumulatedReaders.foreach(_.close(true)))
  }

   // Update the context task metrics for each record read.
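
Force-closing the accumulated readers has to wait until the shuffle stream has actually been drained, which is what Spark's CompletionIterator provides: it wraps an iterator and runs a callback once the wrapped iterator is exhausted. A small sketch of that pattern with a stand-in resource (FakeReader and the package name are illustrative; CompletionIterator is private[spark], so the sketch, like the Comet shuffle reader itself, has to live under the org.apache.spark package):

package org.apache.spark.sketch

import org.apache.spark.util.CompletionIterator

// Stand-in for the accumulated ArrowReaderIterators.
class FakeReader extends AutoCloseable {
  var closed = false
  override def close(): Unit = closed = true
}

object CompletionIteratorSketch {
  def main(args: Array[String]): Unit = {
    val readers = Seq(new FakeReader, new FakeReader)
    val rows = Iterator(1, 2, 3)

    // The second argument is a by-name callback that fires once `rows`
    // is exhausted, mirroring accumulatedReaders.foreach(_.close(true)).
    val iter = CompletionIterator[Int, Iterator[Int]](rows, readers.foreach(_.close()))

    iter.foreach(println)            // drain the iterator ...
    assert(readers.forall(_.closed)) // ... after which every reader is closed
  }
}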
@@ -72,15 +72,12 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.COALESCE_PARTITIONS_ENABLED.key -> coalescePartitionsEnabled.toString,
+      "spark.comet.shuffle.enforceMode.enabled" -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val df = sql(
         "SELECT * FROM (SELECT * FROM testData WHERE key = 0) t1 FULL JOIN " +
           "testData2 t2 ON t1.key = t2.a")
-      if (coalescePartitionsEnabled) {
-        checkShuffleAnswer(df, 0)
-      } else {
-        checkShuffleAnswer(df, 2)
-      }
+      checkShuffleAnswer(df, 2)
    }
  }
}
