From ef9e103814759121a22713aca013e70dd70e9af3 Mon Sep 17 00:00:00 2001
From: Neal Beeken <neal.beeken@mongodb.com>
Date: Tue, 13 Feb 2024 13:33:16 -0500
Subject: [PATCH] perf(NODE-5928): consolidate signal use and abort promise
 wrap

---
 src/cmap/connection.ts  |  64 ++++++++++++-------
 src/utils.ts            |  30 ---------
 test/unit/utils.test.ts | 132 ----------------------------------------
 3 files changed, 41 insertions(+), 185 deletions(-)

diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts
index 3406fed594b..7b277794edc 100644
--- a/src/cmap/connection.ts
+++ b/src/cmap/connection.ts
@@ -29,7 +29,6 @@ import { type CancellationToken, TypedEventEmitter } from '../mongo_types';
 import type { ReadPreferenceLike } from '../read_preference';
 import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions';
 import {
-  abortable,
   BufferPool,
   calculateDurationInMs,
   type Callback,
@@ -37,6 +36,7 @@ import {
   maxWireVersion,
   type MongoDBNamespace,
   now,
+  promiseWithResolvers,
   uuidV4
 } from '../utils';
 import type { WriteConcern } from '../write_concern';
@@ -161,15 +161,14 @@ function streamIdentifier(stream: Stream, options: ConnectionOptions): string {
 export class Connection extends TypedEventEmitter<ConnectionEvents> {
   public id: number | '<monitor>';
   public address: string;
-  public lastHelloMS?: number;
+  public lastHelloMS = -1;
   public serverApi?: ServerApi;
-  public helloOk?: boolean;
+  public helloOk = false;
   public authContext?: AuthContext;
   public delayedTimeoutId: NodeJS.Timeout | null = null;
   public generation: number;
   public readonly description: Readonly<StreamDescription>;
   /**
-   * @public
    * Represents if the connection has been established:
    *  - TCP handshake
    *  - TLS negotiated
@@ -180,15 +179,16 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
   public established: boolean;
 
   private lastUseTime: number;
-  private socketTimeoutMS: number;
-  private monitorCommands: boolean;
-  private socket: Stream;
-  private controller: AbortController;
-  private messageStream: Readable;
-  private socketWrite: (buffer: Uint8Array) => Promise<void>;
   private clusterTime: Document | null = null;
-  /** @internal */
-  override mongoLogger: MongoLogger | undefined;
+
+  private readonly socketTimeoutMS: number;
+  private readonly monitorCommands: boolean;
+  private readonly socket: Stream;
+  private readonly controller: AbortController;
+  private readonly signal: AbortSignal;
+  private readonly messageStream: Readable;
+  private readonly socketWrite: (buffer: Uint8Array) => Promise<void>;
+  private readonly aborted: Promise<never>;
 
   /** @event */
   static readonly COMMAND_STARTED = COMMAND_STARTED;
@@ -221,7 +221,21 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
     this.lastUseTime = now();
 
     this.socket = stream;
+
+    // TODO: Remove signal from connection layer
     this.controller = new AbortController();
+    const { signal } = this.controller;
+    this.signal = signal;
+    const { promise: aborted, reject } = promiseWithResolvers<never>();
+    aborted.then(undefined, () => null); // Prevent unhandled rejection
+    this.signal.addEventListener(
+      'abort',
+      function onAbort() {
+        reject(signal.reason);
+      },
+      { once: true }
+    );
+    this.aborted = aborted;
 
     this.messageStream = this.socket
       .on('error', this.onError.bind(this))
@@ -232,13 +246,13 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
 
     const socketWrite = promisify(this.socket.write.bind(this.socket));
     this.socketWrite = async buffer => {
-      return abortable(socketWrite(buffer), { signal: this.controller.signal });
+      return Promise.race([socketWrite(buffer), this.aborted]);
     };
   }
 
   /** Indicates that the connection (including underlying TCP socket) has been closed. */
   public get closed(): boolean {
-    return this.controller.signal.aborted;
+    return this.signal.aborted;
   }
 
   public get hello() {
@@ -407,7 +421,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
   }
 
   private async *sendWire(message: WriteProtocolMessageType, options: CommandOptions) {
-    this.controller.signal.throwIfAborted();
+    this.throwIfAborted();
 
     if (typeof options.socketTimeoutMS === 'number') {
       this.socket.setTimeout(options.socketTimeoutMS);
@@ -426,7 +440,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
         return;
       }
 
-      this.controller.signal.throwIfAborted();
+      this.throwIfAborted();
 
       for await (const response of this.readMany()) {
         this.socket.setTimeout(0);
@@ -447,7 +461,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
         }
 
         yield document;
-        this.controller.signal.throwIfAborted();
+        this.throwIfAborted();
 
         if (typeof options.socketTimeoutMS === 'number') {
           this.socket.setTimeout(options.socketTimeoutMS);
@@ -481,7 +495,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
 
     let document;
     try {
-      this.controller.signal.throwIfAborted();
+      this.throwIfAborted();
       for await (document of this.sendWire(message, options)) {
         if (!Buffer.isBuffer(document) && document.writeConcernError) {
           throw new MongoWriteConcernError(document.writeConcernError, document);
@@ -511,7 +525,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
         }
 
         yield document;
-        this.controller.signal.throwIfAborted();
+        this.throwIfAborted();
       }
     } catch (error) {
       if (this.shouldEmitAndLogCommand) {
@@ -554,7 +568,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
     command: Document,
     options: CommandOptions = {}
   ): Promise<Document> {
-    this.controller.signal.throwIfAborted();
+    this.throwIfAborted();
     for await (const document of this.sendCommand(ns, command, options)) {
       return document;
     }
@@ -568,16 +582,20 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
     replyListener: Callback
   ) {
     const exhaustLoop = async () => {
-      this.controller.signal.throwIfAborted();
+      this.throwIfAborted();
       for await (const reply of this.sendCommand(ns, command, options)) {
         replyListener(undefined, reply);
-        this.controller.signal.throwIfAborted();
+        this.throwIfAborted();
       }
       throw new MongoUnexpectedServerResponseError('Server ended moreToCome unexpectedly');
     };
     exhaustLoop().catch(replyListener);
   }
 
+  private throwIfAborted() {
+    this.signal.throwIfAborted();
+  }
+
   /**
    * @internal
    *
@@ -611,7 +629,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
    * Note that `for-await` loops call `return` automatically when the loop is exited.
    */
   private async *readMany(): AsyncGenerator<OpMsgResponse | OpQueryResponse> {
-    for await (const message of onData(this.messageStream, { signal: this.controller.signal })) {
+    for await (const message of onData(this.messageStream, { signal: this.signal })) {
       const response = await decompressResponse(message);
       yield response;
 
diff --git a/src/utils.ts b/src/utils.ts
index 719367cad21..173de9053a5 100644
--- a/src/utils.ts
+++ b/src/utils.ts
@@ -1283,36 +1283,6 @@ export function isHostMatch(match: RegExp, host?: string): boolean {
   return host && match.test(host.toLowerCase()) ? true : false;
 }
 
-/**
- * Takes a promise and races it with a promise wrapping the abort event of the optionally provided signal.
- * The given promise is _always_ ordered before the signal's abort promise.
- * When given an already rejected promise and an already aborted signal, the promise's rejection takes precedence.
- *
- * @see https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Promise/race
- *
- * @param promise - A promise to discard if the signal aborts
- * @param options - An options object carrying an optional signal
- */
-export async function abortable<T>(
-  promise: Promise<T>,
-  { signal }: { signal: AbortSignal }
-): Promise<T> {
-  const { promise: aborted, reject } = promiseWithResolvers<never>();
-
-  function rejectOnAbort() {
-    reject(signal.reason);
-  }
-
-  if (signal.aborted) rejectOnAbort();
-  else signal.addEventListener('abort', rejectOnAbort, { once: true });
-
-  try {
-    return await Promise.race([promise, aborted]);
-  } finally {
-    signal.removeEventListener('abort', rejectOnAbort);
-  }
-}
-
 export function promiseWithResolvers<T>() {
   let resolve!: Parameters<ConstructorParameters<typeof Promise<T>>[0]>[0];
   let reject!: Parameters<ConstructorParameters<typeof Promise<T>>[0]>[1];
diff --git a/test/unit/utils.test.ts b/test/unit/utils.test.ts
index cf988382a28..b5fcadbffcb 100644
--- a/test/unit/utils.test.ts
+++ b/test/unit/utils.test.ts
@@ -1,9 +1,7 @@
 import { expect } from 'chai';
 import * as sinon from 'sinon';
-import { setTimeout } from 'timers';
 
 import {
-  abortable,
   BufferPool,
   ByteUtils,
   compareObjectId,
@@ -21,7 +19,6 @@ import {
   shuffle,
   TimeoutController
 } from '../mongodb';
-import { sleep } from '../tools/utils';
 import { createTimerSandbox } from './timer_sandbox';
 
 describe('driver utils', function () {
@@ -1077,133 +1074,4 @@ describe('driver utils', function () {
       });
     });
   });
-
-  describe('abortable()', () => {
-    const goodError = new Error('good error');
-    const badError = new Error('unexpected bad error!');
-    const expectedValue = "don't panic";
-
-    context('always removes the abort listener it attaches', () => {
-      let controller;
-      let removeEventListenerSpy;
-      let addEventListenerSpy;
-
-      beforeEach(() => {
-        controller = new AbortController();
-        addEventListenerSpy = sinon.spy(controller.signal, 'addEventListener');
-        removeEventListenerSpy = sinon.spy(controller.signal, 'removeEventListener');
-      });
-
-      afterEach(() => sinon.restore());
-
-      const expectListenerCleanup = () => {
-        expect(addEventListenerSpy).to.have.been.calledOnce;
-        expect(removeEventListenerSpy).to.have.been.calledOnce;
-      };
-
-      it('when promise rejects', async () => {
-        await abortable(Promise.reject(goodError), { signal: controller.signal }).catch(e => e);
-        expectListenerCleanup();
-      });
-
-      it('when promise resolves', async () => {
-        await abortable(Promise.resolve(expectedValue), { signal: controller.signal });
-        expectListenerCleanup();
-      });
-
-      it('when signal aborts', async () => {
-        setTimeout(() => controller.abort(goodError));
-        await abortable(new Promise(() => null), { signal: controller.signal }).catch(e => e);
-        expectListenerCleanup();
-      });
-    });
-
-    context('when given already rejected promise with already aborted signal', () => {
-      it('returns promise rejection', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-        controller.abort(badError);
-        const result = await abortable(Promise.reject(goodError), { signal }).catch(e => e);
-        expect(result).to.deep.equal(goodError);
-      });
-    });
-
-    context('when given already resolved promise with already aborted signal', () => {
-      it('returns promise resolution', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-        controller.abort(badError);
-        const result = await abortable(Promise.resolve(expectedValue), { signal }).catch(e => e);
-        expect(result).to.deep.equal(expectedValue);
-      });
-    });
-
-    context('when given already rejected promise with not yet aborted signal', () => {
-      it('returns promise rejection', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-        const result = await abortable(Promise.reject(goodError), { signal }).catch(e => e);
-        expect(result).to.deep.equal(goodError);
-      });
-    });
-
-    context('when given already resolved promise with not yet aborted signal', () => {
-      it('returns promise resolution', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-        const result = await abortable(Promise.resolve(expectedValue), { signal }).catch(e => e);
-        expect(result).to.deep.equal(expectedValue);
-      });
-    });
-
-    context('when given unresolved promise with an already aborted signal', () => {
-      it('returns signal reason', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-        controller.abort(goodError);
-        const result = await abortable(new Promise(() => null), { signal }).catch(e => e);
-        expect(result).to.deep.equal(goodError);
-      });
-    });
-
-    context('when given eventually rejecting promise with not yet aborted signal', () => {
-      const eventuallyReject = async () => {
-        await sleep(1);
-        throw goodError;
-      };
-
-      it('returns promise rejection', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-        const result = await abortable(eventuallyReject(), { signal }).catch(e => e);
-        expect(result).to.deep.equal(goodError);
-      });
-    });
-
-    context('when given eventually resolving promise with not yet aborted signal', () => {
-      const eventuallyResolve = async () => {
-        await sleep(1);
-        return expectedValue;
-      };
-
-      it('returns promise resolution', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-        const result = await abortable(eventuallyResolve(), { signal }).catch(e => e);
-        expect(result).to.deep.equal(expectedValue);
-      });
-    });
-
-    context('when given unresolved promise with eventually aborted signal', () => {
-      it('returns signal reason', async () => {
-        const controller = new AbortController();
-        const { signal } = controller;
-
-        setTimeout(() => controller.abort(goodError), 1);
-
-        const result = await abortable(new Promise(() => null), { signal }).catch(e => e);
-        expect(result).to.deep.equal(goodError);
-      });
-    });
-  });
 });