diff --git a/package.json b/package.json index 444ce5e..ec595e3 100644 --- a/package.json +++ b/package.json @@ -27,11 +27,12 @@ "build": "npm run clean && npm run compile && npm run add-package-jsons", "watch": "tsc --watch", "prepublishOnly": "npm run build", - "test:unit": "mocha --recursive --full-trace --exit --reporter json > test/reports/test-results.json", + "test:unit": "mocha --recursive --full-trace --exit", + "test:unit:report": "mocha --recursive --full-trace --exit --reporter json > test/reports/test-results.json", "pretest:integration:init": "npm run build", "test:integration:init": "sh ./test/scripts/initIntTests.sh", "test:integration": "npm run test:integration:init && cucumber-js --config ./test/config/cucumber.mjs", - "test": "npm run test:unit && npm run test:integration", + "test": "npm run test:unit:report && npm run test:integration", "coverage": "c8 mocha && c8 report --reporter=html && c8 report --reporter=json-summary", "lcov": "c8 mocha && c8 report --reporter=lcov", "lint": "eslint .", diff --git a/src/consumer.ts b/src/consumer.ts index b96a367..44f24d1 100644 --- a/src/consumer.ts +++ b/src/consumer.ts @@ -18,7 +18,11 @@ import { MessageSystemAttributeName, } from "@aws-sdk/client-sqs"; -import { ConsumerOptions, StopOptions, UpdatableOptions } from "./types.js"; +import type { + ConsumerOptions, + StopOptions, + UpdatableOptions, +} from "./types.js"; import { TypedEventEmitter } from "./emitter.js"; import { SQSError, @@ -51,7 +55,10 @@ export class Consumer extends TypedEventEmitter { private alwaysAcknowledge: boolean; private batchSize: number; private visibilityTimeout: number; - private terminateVisibilityTimeout: boolean | number; + private terminateVisibilityTimeout: + | boolean + | number + | ((message: Message[]) => number); private waitTimeSeconds: number; private authenticationErrorTimeout: number; private pollingWaitTimeMs: number; @@ -363,11 +370,16 @@ export class Consumer extends TypedEventEmitter { this.emitError(err, message); if (this.terminateVisibilityTimeout !== false) { - const timeout = - this.terminateVisibilityTimeout === true - ? 0 - : this.terminateVisibilityTimeout; - await this.changeVisibilityTimeout(message, timeout); + if (typeof this.terminateVisibilityTimeout === "function") { + const timeout = this.terminateVisibilityTimeout([message]); + await this.changeVisibilityTimeout(message, timeout); + } else { + const timeout = + this.terminateVisibilityTimeout === true + ? 0 + : this.terminateVisibilityTimeout; + await this.changeVisibilityTimeout(message, timeout); + } } } finally { if (this.heartbeatInterval) { @@ -405,11 +417,16 @@ export class Consumer extends TypedEventEmitter { this.emit("error", err, messages); if (this.terminateVisibilityTimeout !== false) { - const timeout = - this.terminateVisibilityTimeout === true - ? 0 - : this.terminateVisibilityTimeout; - await this.changeVisibilityTimeoutBatch(messages, timeout); + if (typeof this.terminateVisibilityTimeout === "function") { + const timeout = this.terminateVisibilityTimeout(messages); + await this.changeVisibilityTimeoutBatch(messages, timeout); + } else { + const timeout = + this.terminateVisibilityTimeout === true + ? 0 + : this.terminateVisibilityTimeout; + await this.changeVisibilityTimeoutBatch(messages, timeout); + } } } finally { clearInterval(heartbeatTimeoutId); @@ -430,9 +447,9 @@ export class Consumer extends TypedEventEmitter { messages, this.visibilityTimeout, ); - } else { - return this.changeVisibilityTimeout(message, this.visibilityTimeout); } + + return this.changeVisibilityTimeout(message, this.visibilityTimeout); }, this.heartbeatInterval * 1000); } @@ -511,7 +528,7 @@ export class Consumer extends TypedEventEmitter { let handleMessageTimeoutId: NodeJS.Timeout | undefined = undefined; try { - let result; + let result: Message | void; if (this.handleMessageTimeout) { const pending: Promise = new Promise((_, reject): void => { @@ -533,7 +550,8 @@ export class Consumer extends TypedEventEmitter { err, `Message handler timed out after ${this.handleMessageTimeout}ms: Operation timed out.`, ); - } else if (err instanceof Error) { + } + if (err instanceof Error) { throw toStandardError( err, `Unexpected message handler failure: ${err.message}`, diff --git a/src/types.ts b/src/types.ts index a1ed7aa..839809e 100644 --- a/src/types.ts +++ b/src/types.ts @@ -68,9 +68,14 @@ export interface ConsumerOptions { /** * If true, sets the message visibility timeout to 0 after a `processing_error`. You can * also specify a different timeout using a number. + * If you would like to use exponential backoff, you can pass a function that returns + * a number and it will use that as the value for the timeout. * @defaultvalue `false` */ - terminateVisibilityTimeout?: boolean | number; + terminateVisibilityTimeout?: + | boolean + | number + | ((messages: Message[]) => number); /** * The interval (in seconds) between requests to extend the message visibility timeout. * diff --git a/test/tests/consumer.test.ts b/test/tests/consumer.test.ts index 871bcfa..974277c 100644 --- a/test/tests/consumer.test.ts +++ b/test/tests/consumer.test.ts @@ -6,6 +6,7 @@ import { ReceiveMessageCommand, SQSClient, QueueAttributeName, + Message, } from "@aws-sdk/client-sqs"; import { assert } from "chai"; import * as sinon from "sinon"; @@ -1008,7 +1009,7 @@ describe("Consumer", () => { consumer.stop(); }); - it("terminate message visibility timeout on processing error", async () => { + it("terminates message visibility timeout on processing error", async () => { handleMessage.rejects(new Error("Processing error")); consumer.terminateVisibilityTimeout = true; @@ -1031,6 +1032,54 @@ describe("Consumer", () => { ); }); + it("terminates message visibility timeout with a function to calculate timeout on processing error", async () => { + const messageWithAttr = { + ReceiptHandle: "receipt-handle", + MessageId: "1", + Body: "body-2", + Attributes: { + ApproximateReceiveCount: 2, + }, + }; + sqs.send.withArgs(mockReceiveMessage).resolves({ + Messages: [messageWithAttr], + }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + messageSystemAttributeNames: ["ApproximateReceiveCount"], + region: REGION, + handleMessage, + sqs, + terminateVisibilityTimeout: (messages: Message[]) => { + const receiveCount = + Number.parseInt( + messages[0].Attributes?.ApproximateReceiveCount || "1", + ) || 1; + return receiveCount * 10; + }, + }); + + handleMessage.rejects(new Error("Processing error")); + + consumer.start(); + await pEvent(consumer, "processing_error"); + consumer.stop(); + + sandbox.assert.calledWith( + sqs.send.secondCall, + mockChangeMessageVisibility, + ); + sandbox.assert.match( + sqs.send.secondCall.args[0].input, + sinon.match({ + QueueUrl: QUEUE_URL, + ReceiptHandle: "receipt-handle", + VisibilityTimeout: 20, + }), + ); + }); + it("changes message visibility timeout on processing error", async () => { handleMessage.rejects(new Error("Processing error"));