diff --git a/packages/logger/src/middleware/middy.ts b/packages/logger/src/middleware/middy.ts index 99c21aa25e..f68ea4a8e6 100644 --- a/packages/logger/src/middleware/middy.ts +++ b/packages/logger/src/middleware/middy.ts @@ -1,5 +1,6 @@ import { Logger } from '../Logger'; import { HandlerOptions, LogAttributes } from '../types'; +import { LOGGER_KEY } from '@aws-lambda-powertools/commons/lib/middleware'; import type { MiddlewareLikeObj, MiddyLikeRequest, @@ -35,15 +36,30 @@ const injectLambdaContext = ( ): MiddlewareLikeObj => { const loggers = target instanceof Array ? target : [target]; const persistentAttributes: LogAttributes[] = []; + const isClearState = options && options.clearState === true; + + /** + * Set the cleanup function to be called in case other middlewares return early. + * + * @param request - The request object + */ + const setCleanupFunction = (request: MiddyLikeRequest): void => { + request.internal = { + ...request.internal, + [LOGGER_KEY]: injectLambdaContextAfterOrOnError, + }; + }; const injectLambdaContextBefore = async ( request: MiddyLikeRequest ): Promise => { loggers.forEach((logger: Logger, index: number) => { - if (options && options.clearState === true) { + if (isClearState) { persistentAttributes[index] = { ...logger.getPersistentLogAttributes(), }; + + setCleanupFunction(request); } Logger.injectLambdaContextBefore( logger, @@ -55,7 +71,7 @@ const injectLambdaContext = ( }; const injectLambdaContextAfterOrOnError = async (): Promise => { - if (options && options.clearState === true) { + if (isClearState) { loggers.forEach((logger: Logger, index: number) => { Logger.injectLambdaContextAfterOrOnError( logger, diff --git a/packages/logger/tests/unit/middleware/middy.test.ts b/packages/logger/tests/unit/middleware/middy.test.ts index 4c97a93444..f2815b01cd 100644 --- a/packages/logger/tests/unit/middleware/middy.test.ts +++ b/packages/logger/tests/unit/middleware/middy.test.ts @@ -7,6 +7,7 @@ import { ContextExamples as dummyContext, Events as dummyEvent, } from '@aws-lambda-powertools/commons'; +import { cleanupMiddlewares } from '@aws-lambda-powertools/commons/lib/middleware'; import { ConfigServiceInterface, EnvironmentVariablesService, @@ -197,6 +198,64 @@ describe('Middy middleware', () => { persistentAttribsBeforeInvocation ); }); + + test('when enabled, and another middleware returns early, it still clears the state', async () => { + // Prepare + const logger = new Logger({ + logLevel: 'DEBUG', + }); + const loggerSpy = jest.spyOn(logger['console'], 'debug'); + const myCustomMiddleware = (): middy.MiddlewareObj => { + const before = async ( + request: middy.Request + ): Promise => { + // Return early on the second invocation + if (request.event.idx === 1) { + // Cleanup Powertools resources + await cleanupMiddlewares(request); + + // Then return early + return 'foo'; + } + }; + + return { + before, + }; + }; + const handler = middy( + ( + event: typeof dummyEvent.Custom.CustomEvent & { idx: number } + ): void => { + // Add a key only at the first invocation, so we can check that it's cleared + if (event.idx === 0) { + logger.appendKeys({ + details: { user_id: '1234' }, + }); + } + logger.debug('This is a DEBUG log'); + } + ) + .use(injectLambdaContext(logger, { clearState: true })) + .use(myCustomMiddleware()); + + // Act + await handler({ ...event, idx: 0 }, context); + await handler({ ...event, idx: 1 }, context); + + // Assess + const persistentAttribsAfterInvocation = { + ...logger.getPersistentLogAttributes(), + }; + expect(persistentAttribsAfterInvocation).toEqual({}); + // Only one log because the second invocation returned early + // from the custom middleware + expect(loggerSpy).toBeCalledTimes(1); + expect(loggerSpy).toHaveBeenNthCalledWith( + 1, + expect.stringContaining('"details":{"user_id":"1234"}') + ); + }); }); describe('Feature: log event', () => {