Skip to content

Commit

Permalink
fix: Fix memory bloat caused by holding too many unclosed `ArrowReade…
Browse files Browse the repository at this point in the history
…rIterator`s (#929)
  • Loading branch information
Kontinuation authored Sep 10, 2024
1 parent 8c15cf4 commit c905f40
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ class ArrowReaderIterator(channel: ReadableByteChannel, source: String)
private val reader = StreamReader(channel, source)
private var batch = nextBatch()
private var currentBatch: ColumnarBatch = null
private var isClosed: Boolean = false

override def hasNext: Boolean = {
if (isClosed) {
return false
}
if (batch.isDefined) {
return true
}
Expand All @@ -42,10 +46,12 @@ class ArrowReaderIterator(channel: ReadableByteChannel, source: String)
// memory leak.
if (currentBatch != null) {
currentBatch.close()
currentBatch = null
}

batch = nextBatch()
if (batch.isEmpty) {
close()
return false
}
true
Expand All @@ -69,10 +75,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel, source: String)

def close(): Unit =
synchronized {
if (currentBatch != null) {
currentBatch.close()
currentBatch = null
if (!isClosed) {
if (currentBatch != null) {
currentBatch.close()
currentBatch = null
}
reader.close()
isClosed = true
}
reader.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,20 @@ class CometBlockStoreShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
var currentReadIterator: ArrowReaderIterator = null

// Closes last read iterator after the task is finished.
// We need to close read iterator during iterating input streams,
// instead of one callback per read iterator. Otherwise if there are too many
// read iterators, it may blow up the call stack and cause OOM.
context.addTaskCompletionListener[Unit] { _ =>
if (currentReadIterator != null) {
currentReadIterator.close()
}
}

val recordIter = fetchIterator
.flatMap { case (_, inputStream) =>
var currentReadIterator: ArrowReaderIterator = null

// Closes last read iterator after the task is finished.
// We need to close read iterator during iterating input streams,
// instead of one callback per read iterator. Otherwise if there are too many
// read iterators, it may blow up the call stack and cause OOM.
context.addTaskCompletionListener[Unit] { _ =>
if (currentReadIterator != null) {
currentReadIterator.close()
}
}

IpcInputStreamIterator(inputStream, decompressingNeeded = true, context)
.flatMap { channel =>
if (currentReadIterator != null) {
Expand Down

0 comments on commit c905f40

Please sign in to comment.