diff --git a/src/workerd/api/streams/internal.c++ b/src/workerd/api/streams/internal.c++ index 94b3d94d2df..5d25d0e63a2 100644 --- a/src/workerd/api/streams/internal.c++ +++ b/src/workerd/api/streams/internal.c++ @@ -48,6 +48,8 @@ kj::Promise pumpTo(ReadableStreamSource& input, WritableStreamSink& output class AllReader { // Modified from AllReader in kj/async-io.c++. + using PartList = kj::Array>; + public: explicit AllReader(ReadableStreamSource& input, uint64_t limit) : input(input), limit(limit) { @@ -59,58 +61,53 @@ public: } kj::Promise> readAllBytes() { - auto parts = co_await readParts(); - auto out = kj::heapArray(runningTotal); - copyInto(out, kj::mv(parts)); - co_return out; + return loop().then([this](PartList&& partPtrs) { + auto out = kj::heapArray(runningTotal); + copyInto(out, kj::mv(partPtrs)); + return kj::mv(out); + }); } kj::Promise readAllText() { - auto parts = co_await readParts(); - auto out = kj::heapArray(runningTotal + 1); - copyInto(out.slice(0, out.size() - 1).asBytes(), kj::mv(parts)); - out.back() = '\0'; - co_return kj::String(kj::mv(out)); + return loop().then([this](PartList&& partPtrs) { + auto out = kj::heapArray(runningTotal + 1); + copyInto(out.slice(0, out.size() - 1).asBytes(), kj::mv(partPtrs)); + out.back() = '\0'; + return kj::String(kj::mv(out)); + }); } private: ReadableStreamSource& input; uint64_t limit; + kj::Vector> parts; uint64_t runningTotal = 0; - struct Part { - kj::Array buffer; - size_t amount; - }; - using PartList = kj::Vector; - - kj::Promise readParts() { - static constexpr size_t bufferSize = 4096; - PartList parts; + kj::Promise loop() { + auto bytes = kj::heapArray(4096); - while (true) { - auto buffer = kj::heapArray(bufferSize); - auto amount = co_await input.tryRead(buffer.begin(), bufferSize, bufferSize); - runningTotal += amount; - JSG_REQUIRE(runningTotal < limit, TypeError, "Memory limit exceeded before EOF."); - - if (amount > 0) { - Part part = { .buffer = kj::mv(buffer), .amount = amount }; - parts.add(kj::mv(part)); + return input.tryRead(bytes.begin(), 1, bytes.size()) + .then([this, bytes = kj::mv(bytes)](size_t amount) mutable + -> kj::Promise { + if (amount == 0) { + return KJ_MAP(p, parts) { return p.asPtr(); }; } - if (amount < bufferSize) { - co_return kj::mv(parts); + runningTotal += amount; + if (runningTotal >= limit) { + return JSG_KJ_EXCEPTION(FAILED, TypeError, "Memory limit exceeded before EOF."); } - } + parts.add(bytes.slice(0, amount).attach(kj::mv(bytes))); + return loop(); + }); } - void copyInto(kj::ArrayPtr out, PartList&& in) { + void copyInto(kj::ArrayPtr out, PartList in) { size_t pos = 0; for (auto& part: in) { - KJ_ASSERT(part.amount <= out.size() - pos); - memcpy(out.begin() + pos, part.buffer.begin(), part.amount); - pos += part.amount; + KJ_ASSERT(part.size() <= out.size() - pos); + memcpy(out.begin() + pos, part.begin(), part.size()); + pos += part.size(); } } };