Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace wrapTransversable generators to prevent memory leaks #2709

Merged
merged 4 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 16 additions & 33 deletions lib/Doctrine/ODM/MongoDB/Iterator/CachingIterator.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
namespace Doctrine\ODM\MongoDB\Iterator;

use Countable;
use Generator;
use Iterator as SPLIterator;
use IteratorIterator;
use ReturnTypeWillChange;
use RuntimeException;
use Traversable;
Expand Down Expand Up @@ -33,13 +34,11 @@ final class CachingIterator implements Countable, Iterator
/** @var array<mixed, TValue> */
private array $items = [];

/** @var Generator<mixed, TValue>|null */
private ?Generator $iterator;
/** @var SPLIterator<mixed, TValue>|null */
private ?SPLIterator $iterator;

private bool $iteratorAdvanced = false;

private bool $iteratorExhausted = false;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced by setting $iterator = null, as it's already done for UnrewindableIterator


/**
* Initialize the iterator and stores the first item in the cache. This
* effectively rewinds the Traversable and the wrapping Generator, which
Expand All @@ -51,7 +50,8 @@ final class CachingIterator implements Countable, Iterator
*/
public function __construct(Traversable $iterator)
{
$this->iterator = $this->wrapTraversable($iterator);
$this->iterator = new IteratorIterator($iterator);
$this->iterator->rewind();
$this->storeCurrentItem();
}

Expand Down Expand Up @@ -94,9 +94,10 @@ public function key()
/** @see http://php.net/iterator.next */
public function next(): void
{
if (! $this->iteratorExhausted) {
$this->getIterator()->next();
if ($this->iterator !== null) {
$this->iterator->next();
$this->storeCurrentItem();
$this->iteratorAdvanced = true;
}

next($this->items);
Expand Down Expand Up @@ -126,15 +127,13 @@ public function valid(): bool
*/
private function exhaustIterator(): void
{
while (! $this->iteratorExhausted) {
while ($this->iterator !== null) {
$this->next();
}

$this->iterator = null;
}

/** @return Generator<mixed, TValue> */
private function getIterator(): Generator
/** @return SPLIterator<mixed, TValue> */
private function getIterator(): SPLIterator
{
if ($this->iterator === null) {
throw new RuntimeException('Iterator has already been destroyed');
Expand All @@ -148,28 +147,12 @@ private function getIterator(): Generator
*/
private function storeCurrentItem(): void
{
$key = $this->getIterator()->key();
$key = $this->iterator->key();

if ($key === null) {
return;
$this->iterator = null;
} else {
$this->items[$key] = $this->getIterator()->current();
}

$this->items[$key] = $this->getIterator()->current();
}

/**
* @param Traversable<mixed, TValue> $traversable
*
* @return Generator<mixed, TValue>
*/
private function wrapTraversable(Traversable $traversable): Generator
{
foreach ($traversable as $key => $value) {
yield $key => $value;

$this->iteratorAdvanced = true;
}

$this->iteratorExhausted = true;
}
}
25 changes: 7 additions & 18 deletions lib/Doctrine/ODM/MongoDB/Iterator/HydratingIterator.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

use Doctrine\ODM\MongoDB\Mapping\ClassMetadata;
use Doctrine\ODM\MongoDB\UnitOfWork;
use Generator;
use Iterator;
use IteratorIterator;
use ReturnTypeWillChange;
use RuntimeException;
use Traversable;
Expand All @@ -24,8 +24,8 @@
*/
final class HydratingIterator implements Iterator
{
/** @var Generator<mixed, array<string, mixed>>|null */
private ?Generator $iterator;
/** @var Iterator<mixed, array<string, mixed>>|null */
private ?Iterator $iterator;

/**
* @param Traversable<mixed, array<string, mixed>> $traversable
Expand All @@ -34,7 +34,8 @@ final class HydratingIterator implements Iterator
*/
public function __construct(Traversable $traversable, private UnitOfWork $unitOfWork, private ClassMetadata $class, private array $unitOfWorkHints = [])
{
$this->iterator = $this->wrapTraversable($traversable);
$this->iterator = new IteratorIterator($traversable);
$this->iterator->rewind();
}

public function __destruct()
Expand Down Expand Up @@ -74,8 +75,8 @@ public function valid(): bool
return $this->key() !== null;
}

/** @return Generator<mixed, array<string, mixed>> */
private function getIterator(): Generator
/** @return Iterator<mixed, array<string, mixed>> */
private function getIterator(): Iterator
{
if ($this->iterator === null) {
throw new RuntimeException('Iterator has already been destroyed');
Expand All @@ -93,16 +94,4 @@ private function hydrate(?array $document): ?object
{
return $document !== null ? $this->unitOfWork->getOrCreateDocument($this->class->name, $document, $this->unitOfWorkHints) : null;
}

/**
* @param Traversable<mixed, array<string, mixed>> $traversable
*
* @return Generator<mixed, array<string, mixed>>
*/
private function wrapTraversable(Traversable $traversable): Generator
{
foreach ($traversable as $key => $value) {
yield $key => $value;
}
}
}
61 changes: 24 additions & 37 deletions lib/Doctrine/ODM/MongoDB/Iterator/UnrewindableIterator.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

namespace Doctrine\ODM\MongoDB\Iterator;

use Generator;
use Iterator as SPLIterator;
use IteratorIterator;
use LogicException;
use ReturnTypeWillChange;
use RuntimeException;
Expand All @@ -23,39 +24,34 @@
*/
final class UnrewindableIterator implements Iterator
{
/** @var Generator<mixed, TValue>|null */
private ?Generator $iterator;
/** @var SPLIterator<mixed, TValue>|null */
private ?SPLIterator $iterator;

private bool $iteratorAdvanced = false;

/**
* Initialize the iterator. This effectively rewinds the Traversable and
* the wrapping Generator, which will execute up to its first yield statement.
* Additionally, this mimics behavior of the SPL iterators and allows users
* to omit an explicit call to rewind() before using the other methods.
* Initialize the iterator. This effectively rewinds the Traversable.
* This mimics behavior of the SPL iterators and allows users to omit an
* explicit call to rewind() before using the other methods.
*
* @param Traversable<mixed, TValue> $iterator
*/
public function __construct(Traversable $iterator)
{
$this->iterator = $this->wrapTraversable($iterator);
$this->iterator->key();
$this->iterator = new IteratorIterator($iterator);
$this->iterator->rewind();
}

public function toArray(): array
{
$this->preventRewinding(__METHOD__);

$toArray = function () {
if (! $this->valid()) {
return;
}

yield $this->key() => $this->current();
yield from $this->getIterator();
};

return iterator_to_array($toArray());
try {
return iterator_to_array($this->getIterator());
} finally {
$this->iteratorAdvanced = true;
$this->iterator = null;
}
}

/** @return TValue|null */
Expand Down Expand Up @@ -84,6 +80,13 @@ public function next(): void
}

$this->iterator->next();
$this->iteratorAdvanced = true;

if ($this->iterator->valid()) {
return;
}

$this->iterator = null;
}

/** @see http://php.net/iterator.rewind */
Expand All @@ -108,29 +111,13 @@ private function preventRewinding(string $method): void
}
}

/** @return Generator<mixed, TValue> */
private function getIterator(): Generator
/** @return SPLIterator<mixed, TValue> */
private function getIterator(): SPLIterator
{
if ($this->iterator === null) {
throw new RuntimeException('Iterator has already been destroyed');
}

return $this->iterator;
}

/**
* @param Traversable<mixed, TValue> $traversable
*
* @return Generator<mixed, TValue>
*/
private function wrapTraversable(Traversable $traversable): Generator
{
foreach ($traversable as $key => $value) {
yield $key => $value;

$this->iteratorAdvanced = true;
}

$this->iterator = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ public function testIterationWithEmptySet(): void
self::assertFalse($iterator->valid());
}

public function testIterationWithInvalidIterator(): void
{
$mock = $this->createMock(Iterator::class);
// The method next() should not be called on a dead cursor.
$mock->expects(self::never())->method('next');
// The method valid() return false on a dead cursor.
$mock->expects(self::once())->method('valid')->willReturn(false);

$iterator = new CachingIterator($mock);

$this->assertEquals([], $iterator->toArray());
}

public function testPartialIterationDoesNotExhaust(): void
{
$traversable = $this->getTraversableThatThrows([1, 2, new Exception()]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ public function testRewindAfterPartialIteration(): void
iterator_to_array($iterator);
}

public function testRewindAfterToArray(): void
{
$iterator = new UnrewindableIterator($this->getTraversable([1, 2, 3]));

$iterator->toArray();
$this->expectException(LogicException::class);
$iterator->rewind();
}

public function testToArray(): void
{
$iterator = new UnrewindableIterator($this->getTraversable([1, 2, 3]));
Expand Down