Skip to content

Commit

Permalink
chore(middleware-flexible-checksums): perform checksum calculation an…
Browse files Browse the repository at this point in the history
…d validation by default (aws#6750)
  • Loading branch information
trivikr authored Jan 15, 2025
1 parent 2293f5a commit f6068c8
Show file tree
Hide file tree
Showing 9 changed files with 478 additions and 135 deletions.
13 changes: 13 additions & 0 deletions packages/middleware-flexible-checksums/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import {
Encoder,
GetAwsChunkedEncodingStream,
HashConstructor,
Provider,
StreamCollector,
StreamHasher,
} from "@smithy/types";

import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";

export interface PreviouslyResolved {
/**
* The function that will be used to convert binary data to a base64-encoded string.
Expand All @@ -31,6 +34,16 @@ export interface PreviouslyResolved {
*/
md5: ChecksumConstructor | HashConstructor;

/**
* Determines when a checksum will be calculated for request payloads
*/
requestChecksumCalculation: Provider<RequestChecksumCalculation>;

/**
* Determines when a checksum will be calculated for response payloads
*/
responseChecksumValidation: Provider<ResponseChecksumValidation>;

/**
* A constructor for a class implementing the {@link Hash} interface that computes SHA1 hashes.
* @internal
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import { setFeature } from "@aws-sdk/core";
import { afterEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";
import { flexibleChecksumsInputMiddleware } from "./flexibleChecksumsInputMiddleware";

vi.mock("@aws-sdk/core");

describe(flexibleChecksumsInputMiddleware.name, () => {
const mockNext = vi.fn();
const mockRequestValidationModeMember = "mockRequestValidationModeMember";

const mockConfig = {
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED),
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_SUPPORTED),
} as PreviouslyResolved;

afterEach(() => {
expect(mockNext).toHaveBeenCalledTimes(1);
vi.clearAllMocks();
});

describe("sets input.requestValidationModeMember", () => {
it("when requestValidationModeMember is defined and responseChecksumValidation is supported", async () => {
const mockMiddlewareConfigWithMockRequestValidationModeMember = {
requestValidationModeMember: mockRequestValidationModeMember,
};
const handler = flexibleChecksumsInputMiddleware(
mockConfig,
mockMiddlewareConfigWithMockRequestValidationModeMember
)(mockNext, {});
await handler({ input: {} });
expect(mockNext).toHaveBeenCalledWith({ input: { [mockRequestValidationModeMember]: "ENABLED" } });
});
});

describe("leaves input.requestValidationModeMember", () => {
const mockArgs = { input: {} };

it("when requestValidationModeMember is not defined", async () => {
const handler = flexibleChecksumsInputMiddleware(mockConfig, {})(mockNext, {});
await handler(mockArgs);
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

it("when responseChecksumValidation is required", async () => {
const mockConfigResWhenRequired = {
...mockConfig,
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_REQUIRED),
} as PreviouslyResolved;

const handler = flexibleChecksumsInputMiddleware(mockConfigResWhenRequired, {})(mockNext, {});
await handler(mockArgs);

expect(mockNext).toHaveBeenCalledWith(mockArgs);
});
});

describe("set feature", () => {
it.each([
[
"FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED",
"a",
"requestChecksumCalculation",
RequestChecksumCalculation.WHEN_REQUIRED,
],
[
"FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED",
"Z",
"requestChecksumCalculation",
RequestChecksumCalculation.WHEN_SUPPORTED,
],
[
"FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED",
"c",
"responseChecksumValidation",
ResponseChecksumValidation.WHEN_REQUIRED,
],
[
"FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED",
"b",
"responseChecksumValidation",
ResponseChecksumValidation.WHEN_SUPPORTED,
],
])("logs %s:%s when %s=%s", async (feature, value, configKey, configValue) => {
const mockConfigOverride = {
...mockConfig,
[configKey]: () => Promise.resolve(configValue),
} as PreviouslyResolved;

const handler = flexibleChecksumsInputMiddleware(mockConfigOverride, {})(mockNext, {});
await handler({ input: {} });

expect(setFeature).toHaveBeenCalledTimes(2);
if (configKey === "requestChecksumCalculation") {
expect(setFeature).toHaveBeenNthCalledWith(1, expect.anything(), feature, value);
} else {
expect(setFeature).toHaveBeenNthCalledWith(2, expect.anything(), feature, value);
}
});
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { setFeature } from "@aws-sdk/core";
import {
HandlerExecutionContext,
MetadataBearer,
RelativeMiddlewareOptions,
SerializeHandler,
SerializeHandlerArguments,
SerializeHandlerOutput,
SerializeMiddleware,
} from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";

export interface FlexibleChecksumsInputMiddlewareConfig {
/**
* Defines a top-level operation input member used to opt-in to best-effort validation
* of a checksum returned in the HTTP response of the operation.
*/
requestValidationModeMember?: string;
}

/**
* @internal
*/
export const flexibleChecksumsInputMiddlewareOptions: RelativeMiddlewareOptions = {
name: "flexibleChecksumsInputMiddleware",
toMiddleware: "serializerMiddleware",
relation: "before",
tags: ["BODY_CHECKSUM"],
override: true,
};

/**
* @internal
*
* The input counterpart to the flexibleChecksumsMiddleware.
*/
export const flexibleChecksumsInputMiddleware =
(
config: PreviouslyResolved,
middlewareConfig: FlexibleChecksumsInputMiddlewareConfig
): SerializeMiddleware<any, any> =>
<Output extends MetadataBearer>(
next: SerializeHandler<any, Output>,
context: HandlerExecutionContext
): SerializeHandler<any, Output> =>
async (args: SerializeHandlerArguments<any>): Promise<SerializeHandlerOutput<Output>> => {
const input = args.input;
const { requestValidationModeMember } = middlewareConfig;

const requestChecksumCalculation = await config.requestChecksumCalculation();
const responseChecksumValidation = await config.responseChecksumValidation();

switch (requestChecksumCalculation) {
case RequestChecksumCalculation.WHEN_REQUIRED:
setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED", "a");
break;
case RequestChecksumCalculation.WHEN_SUPPORTED:
setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED", "Z");
break;
}

switch (responseChecksumValidation) {
case ResponseChecksumValidation.WHEN_REQUIRED:
setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED", "c");
break;
case ResponseChecksumValidation.WHEN_SUPPORTED:
setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED", "b");
break;
}

// The value for input member to opt-in to best-effort validation of a checksum returned in the HTTP response is not set.
if (requestValidationModeMember && !input[requestValidationModeMember]) {
// Set requestValidationModeMember as ENABLED only if response checksum validation is supported.
if (responseChecksumValidation === ResponseChecksumValidation.WHEN_SUPPORTED) {
input[requestValidationModeMember] = "ENABLED";
}
}

return next(args);
};
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { BuildHandlerArguments } from "@smithy/types";
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants";
import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware";
import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest";
import { getChecksumLocationName } from "./getChecksumLocationName";
Expand All @@ -13,6 +13,7 @@ import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { stringHasher } from "./stringHasher";

vi.mock("@aws-sdk/core");
vi.mock("@smithy/protocol-http");
vi.mock("./getChecksumAlgorithmForRequest");
vi.mock("./getChecksumLocationName");
Expand All @@ -28,10 +29,14 @@ describe(flexibleChecksumsMiddleware.name, () => {
const mockChecksum = "mockChecksum";
const mockChecksumAlgorithmFunction = vi.fn();
const mockChecksumLocationName = "mock-checksum-location-name";
const mockRequestAlgorithmMember = "mockRequestAlgorithmMember";
const mockRequestAlgorithmMemberHttpHeader = "mock-request-algorithm-member-http-header";

const mockInput = {};
const mockConfig = {} as PreviouslyResolved;
const mockMiddlewareConfig = { requestChecksumRequired: false };
const mockConfig = {
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_REQUIRED),
} as PreviouslyResolved;
const mockMiddlewareConfig = { input: mockInput, requestChecksumRequired: false };

const mockBody = { body: "mockRequestBody" };
const mockHeaders = { "content-length": 100, "content-encoding": "gzip" };
Expand All @@ -41,9 +46,8 @@ describe(flexibleChecksumsMiddleware.name, () => {

beforeEach(() => {
mockNext.mockResolvedValueOnce(mockResult);
const { isInstance } = HttpRequest;
(isInstance as unknown as any).mockReturnValue(true);
vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.MD5);
vi.mocked(HttpRequest.isInstance).mockReturnValue(true);
vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.CRC32);
vi.mocked(getChecksumLocationName).mockReturnValue(mockChecksumLocationName);
vi.mocked(hasHeader).mockReturnValue(true);
vi.mocked(hasHeaderWithPrefix).mockReturnValue(false);
Expand All @@ -58,8 +62,7 @@ describe(flexibleChecksumsMiddleware.name, () => {

describe("skips", () => {
it("if not an instance of HttpRequest", async () => {
const { isInstance } = HttpRequest;
(isInstance as unknown as any).mockReturnValue(false);
vi.mocked(HttpRequest.isInstance).mockReturnValue(false);
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
await handler(mockArgs);
expect(getChecksumAlgorithmForRequest).not.toHaveBeenCalled();
Expand All @@ -77,7 +80,7 @@ describe(flexibleChecksumsMiddleware.name, () => {
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
});

it("if header is already present", async () => {
it("skip if header is already present", async () => {
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
vi.mocked(hasHeaderWithPrefix).mockReturnValue(true);

Expand All @@ -94,11 +97,53 @@ describe(flexibleChecksumsMiddleware.name, () => {

describe("adds checksum in the request header", () => {
afterEach(() => {
expect(HttpRequest.isInstance).toHaveBeenCalledTimes(1);
expect(hasHeaderWithPrefix).toHaveBeenCalledTimes(1);
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
});

describe("if input.requestAlgorithmMember can be set", () => {
describe("input[requestAlgorithmMember] is not defined and", () => {
const mockMwConfigWithReqAlgoMember = {
...mockMiddlewareConfig,
requestAlgorithmMember: {
name: mockRequestAlgorithmMember,
httpHeader: mockRequestAlgorithmMemberHttpHeader,
},
};

it("requestChecksumCalculation is supported", async () => {
const handler = flexibleChecksumsMiddleware(
{
...mockConfig,
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED),
},
mockMwConfigWithReqAlgoMember
)(mockNext, {});
await handler(mockArgs);
expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM);
expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual(
DEFAULT_CHECKSUM_ALGORITHM
);
});

it("requestChecksumRequired is set to true", async () => {
const handler = flexibleChecksumsMiddleware(mockConfig, {
...mockMwConfigWithReqAlgoMember,
requestChecksumRequired: true,
})(mockNext, {});

await handler(mockArgs);
expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM);
expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual(
DEFAULT_CHECKSUM_ALGORITHM
);
});
});
});

it("for streaming body", async () => {
vi.mocked(isStreaming).mockReturnValue(true);
const mockUpdatedBody = { body: "mockUpdatedBody" };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
} from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants";
import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { hasHeader } from "./hasHeader";
Expand Down Expand Up @@ -73,10 +73,26 @@ export const flexibleChecksumsMiddleware =
const { body: requestBody, headers } = request;
const { base64Encoder, streamHasher } = config;
const { requestChecksumRequired, requestAlgorithmMember } = middlewareConfig;
const requestChecksumCalculation = await config.requestChecksumCalculation();

const requestAlgorithmMemberName = requestAlgorithmMember?.name;
const requestAlgorithmMemberHttpHeader = requestAlgorithmMember?.httpHeader;
// The value for input member to configure flexible checksum is not set.
if (requestAlgorithmMemberName && !input[requestAlgorithmMemberName]) {
// Set requestAlgorithmMember as default checksum algorithm only if request checksum calculation is supported
// or request checksum is required.
if (requestChecksumCalculation === RequestChecksumCalculation.WHEN_SUPPORTED || requestChecksumRequired) {
input[requestAlgorithmMemberName] = DEFAULT_CHECKSUM_ALGORITHM;
if (requestAlgorithmMemberHttpHeader) {
headers[requestAlgorithmMemberHttpHeader] = DEFAULT_CHECKSUM_ALGORITHM;
}
}
}

const checksumAlgorithm = getChecksumAlgorithmForRequest(input, {
requestChecksumRequired,
requestAlgorithmMember: requestAlgorithmMember?.name,
requestChecksumCalculation,
});
let updatedBody = requestBody;
let updatedHeaders = headers;
Expand Down
Loading

0 comments on commit f6068c8

Please sign in to comment.