diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 1bc6c4267b..a57247ebec 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -53,6 +53,7 @@ import { collectSubfields as _collectSubfields, } from './collectFields.js'; import { mapAsyncIterable } from './mapAsyncIterable.js'; +import { Publisher } from './publisher.js'; import { getArgumentValues, getDirectiveValues, @@ -121,7 +122,11 @@ export interface ExecutionContext { typeResolver: GraphQLTypeResolver; subscribeFieldResolver: GraphQLFieldResolver; errors: Array; - subsequentPayloads: Set; + publisher: Publisher< + AsyncPayloadRecord, + IncrementalResult, + SubsequentIncrementalExecutionResult + >; } /** @@ -357,13 +362,14 @@ function executeImpl( return result.then( (data) => { const initialResult = buildResponse(data, exeContext.errors); - if (exeContext.subsequentPayloads.size > 0) { + const publisher = exeContext.publisher; + if (publisher.hasNext()) { return { initialResult: { ...initialResult, hasNext: true, }, - subsequentResults: yieldSubsequentPayloads(exeContext), + subsequentResults: publisher.subscribe(), }; } return initialResult; @@ -375,13 +381,14 @@ function executeImpl( ); } const initialResult = buildResponse(result, exeContext.errors); - if (exeContext.subsequentPayloads.size > 0) { + const publisher = exeContext.publisher; + if (publisher.hasNext()) { return { initialResult: { ...initialResult, hasNext: true, }, - subsequentResults: yieldSubsequentPayloads(exeContext), + subsequentResults: publisher.subscribe(), }; } return initialResult; @@ -503,7 +510,7 @@ export function buildExecutionContext( fieldResolver: fieldResolver ?? defaultFieldResolver, typeResolver: typeResolver ?? defaultTypeResolver, subscribeFieldResolver: subscribeFieldResolver ?? defaultFieldResolver, - subsequentPayloads: new Set(), + publisher: new Publisher(resultFromAsyncPayloadRecord, payloadFromResults), errors: [], }; } @@ -515,7 +522,7 @@ function buildPerEventExecutionContext( return { ...exeContext, rootValue: payload, - subsequentPayloads: new Set(), + publisher: new Publisher(resultFromAsyncPayloadRecord, payloadFromResults), errors: [], }; } @@ -2038,132 +2045,49 @@ function filterSubsequentPayloads( currentAsyncRecord: AsyncPayloadRecord | undefined, ): void { const nullPathArray = pathToArray(nullPath); - exeContext.subsequentPayloads.forEach((asyncRecord) => { + exeContext.publisher.filter((asyncRecord) => { if (asyncRecord === currentAsyncRecord) { // don't remove payload from where error originates - return; + return true; } for (let i = 0; i < nullPathArray.length; i++) { if (asyncRecord.path[i] !== nullPathArray[i]) { // asyncRecord points to a path unaffected by this payload - return; + return true; } } - // asyncRecord path points to nulled error field - if (isStreamPayload(asyncRecord) && asyncRecord.iterator?.return) { - asyncRecord.iterator.return().catch(() => { - // ignore error - }); - } - exeContext.subsequentPayloads.delete(asyncRecord); + + return false; }); } -function getCompletedIncrementalResults( - exeContext: ExecutionContext, -): Array { - const incrementalResults: Array = []; - for (const asyncPayloadRecord of exeContext.subsequentPayloads) { - const incrementalResult: IncrementalResult = {}; - if (!asyncPayloadRecord.isCompleted) { - continue; - } - exeContext.subsequentPayloads.delete(asyncPayloadRecord); - if (isStreamPayload(asyncPayloadRecord)) { - const items = asyncPayloadRecord.items; - if (asyncPayloadRecord.isCompletedIterator) { - // async iterable resolver just finished but there may be pending payloads - continue; - } - (incrementalResult as IncrementalStreamResult).items = items; - } else { - const data = asyncPayloadRecord.data; - (incrementalResult as IncrementalDeferResult).data = data ?? null; - } - - incrementalResult.path = asyncPayloadRecord.path; - if (asyncPayloadRecord.label) { - incrementalResult.label = asyncPayloadRecord.label; - } - if (asyncPayloadRecord.errors.length > 0) { - incrementalResult.errors = asyncPayloadRecord.errors; - } - incrementalResults.push(incrementalResult); +function resultFromAsyncPayloadRecord( + asyncPayloadRecord: AsyncPayloadRecord, +): IncrementalResult { + const incrementalResult: IncrementalResult = {}; + if (isStreamPayload(asyncPayloadRecord)) { + const items = asyncPayloadRecord.items; + (incrementalResult as IncrementalStreamResult).items = items; + } else { + const data = asyncPayloadRecord.data; + (incrementalResult as IncrementalDeferResult).data = data ?? null; } - return incrementalResults; -} - -function yieldSubsequentPayloads( - exeContext: ExecutionContext, -): AsyncGenerator { - let isDone = false; - - async function next(): Promise< - IteratorResult - > { - if (isDone) { - return { value: undefined, done: true }; - } - await Promise.race( - Array.from(exeContext.subsequentPayloads).map((p) => p.promise), - ); - - if (isDone) { - // a different call to next has exhausted all payloads - return { value: undefined, done: true }; - } - - const incremental = getCompletedIncrementalResults(exeContext); - const hasNext = exeContext.subsequentPayloads.size > 0; - - if (!incremental.length && hasNext) { - return next(); - } - - if (!hasNext) { - isDone = true; - } - - return { - value: incremental.length ? { incremental, hasNext } : { hasNext }, - done: false, - }; + incrementalResult.path = asyncPayloadRecord.path; + if (asyncPayloadRecord.label) { + incrementalResult.label = asyncPayloadRecord.label; } - - function returnStreamIterators() { - const promises: Array>> = []; - exeContext.subsequentPayloads.forEach((asyncPayloadRecord) => { - if ( - isStreamPayload(asyncPayloadRecord) && - asyncPayloadRecord.iterator?.return - ) { - promises.push(asyncPayloadRecord.iterator.return()); - } - }); - return Promise.all(promises); + if (asyncPayloadRecord.errors.length > 0) { + incrementalResult.errors = asyncPayloadRecord.errors; } + return incrementalResult; +} - return { - [Symbol.asyncIterator]() { - return this; - }, - next, - async return(): Promise< - IteratorResult - > { - await returnStreamIterators(); - isDone = true; - return { value: undefined, done: true }; - }, - async throw( - error?: unknown, - ): Promise> { - await returnStreamIterators(); - isDone = true; - return Promise.reject(error); - }, - }; +function payloadFromResults( + incremental: ReadonlyArray, + hasNext: boolean, +): SubsequentIncrementalExecutionResult { + return incremental.length ? { incremental, hasNext } : { hasNext }; } class DeferredFragmentRecord { @@ -2189,7 +2113,7 @@ class DeferredFragmentRecord { this.parentContext = opts.parentContext; this.errors = []; this._exeContext = opts.exeContext; - this._exeContext.subsequentPayloads.add(this); + this._exeContext.publisher.add(this); this.isCompleted = false; this.data = null; this.promise = new Promise | null>((resolve) => { @@ -2240,7 +2164,7 @@ class StreamRecord { this.iterator = opts.iterator; this.errors = []; this._exeContext = opts.exeContext; - this._exeContext.subsequentPayloads.add(this); + this._exeContext.publisher.add(this); this.isCompleted = false; this.items = null; this.promise = new Promise | null>((resolve) => { diff --git a/src/execution/publisher.ts b/src/execution/publisher.ts new file mode 100644 index 0000000000..0378aae34e --- /dev/null +++ b/src/execution/publisher.ts @@ -0,0 +1,129 @@ +interface Source { + promise: Promise; + isCompleted: boolean; + isCompletedIterator?: boolean | undefined; + iterator?: AsyncIterator | undefined; +} + +type ToIncrementalResult = ( + source: TSource, +) => TIncremental; + +type ToPayload = ( + incremental: ReadonlyArray, + hasNext: boolean, +) => TPayload; + +/** + * @internal + */ +export class Publisher { + sources: Set; + toIncrementalResult: ToIncrementalResult; + toPayload: ToPayload; + + constructor( + toIncrementalResult: ToIncrementalResult, + toPayload: ToPayload, + ) { + this.sources = new Set(); + this.toIncrementalResult = toIncrementalResult; + this.toPayload = toPayload; + } + + add(source: TSource) { + this.sources.add(source); + } + + hasNext(): boolean { + return this.sources.size > 0; + } + + filter(predicate: (source: TSource) => boolean): void { + this.sources.forEach((source) => { + if (predicate(source)) { + return; + } + if (source.iterator?.return) { + source.iterator.return().catch(() => { + // ignore error + }); + } + this.sources.delete(source); + }); + } + + _getCompletedIncrementalResults(): Array { + const incrementalResults: Array = []; + for (const source of this.sources) { + if (!source.isCompleted) { + continue; + } + this.sources.delete(source); + if (source.isCompletedIterator) { + continue; + } + incrementalResults.push(this.toIncrementalResult(source)); + } + return incrementalResults; + } + + subscribe(): AsyncGenerator { + let isDone = false; + + const next = async (): Promise> => { + if (isDone) { + return { value: undefined, done: true }; + } + + await Promise.race(Array.from(this.sources).map((p) => p.promise)); + + if (isDone) { + return { value: undefined, done: true }; + } + + const incremental = this._getCompletedIncrementalResults(); + const hasNext = this.sources.size > 0; + + if (!incremental.length && hasNext) { + return next(); + } + + if (!hasNext) { + isDone = true; + } + + return { + value: this.toPayload(incremental, hasNext), + done: false, + }; + }; + + const returnIterators = () => { + const promises: Array>> = []; + this.sources.forEach((source) => { + if (source.iterator?.return) { + promises.push(source.iterator.return()); + } + }); + return Promise.all(promises); + }; + + return { + [Symbol.asyncIterator]() { + return this; + }, + next, + async return(): Promise> { + await returnIterators(); + isDone = true; + return { value: undefined, done: true }; + }, + async throw(error?: unknown): Promise> { + await returnIterators(); + isDone = true; + return Promise.reject(error); + }, + }; + } +}