From 79dd53945b8356426a00359c17511de73daa89b2 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 8 Apr 2022 04:46:03 -0700 Subject: [PATCH] concat --- .../backends/webgpu/op-resolve-rules.ts | 4 +- .../lib/onnxjs/backends/webgpu/ops/common.ts | 13 +- .../lib/onnxjs/backends/webgpu/ops/concat.ts | 173 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 14 +- 4 files changed, 192 insertions(+), 12 deletions(-) create mode 100644 js/web/lib/onnxjs/backends/webgpu/ops/concat.ts diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index dc786a423893d..4a3a5dfbf5003 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -4,6 +4,7 @@ import {OpSet} from '../../opset'; import * as binaryOps from './ops/binary-op'; +import {concat, parseConcatAttributes} from './ops/concat'; import {gather, parseGatherAttributes} from './ops/gather'; import {reshape} from './ops/reshape'; import * as unaryOps from './ops/unary-op'; @@ -18,8 +19,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // ['BatchNormalization', '', '7+', batchNormalization, parseBatchNormalizationAttributes], // ['Cast', '', '6+', cast, parseCastAttributes], ['Ceil', '', '6+', unaryOps.ceil], ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes], - ['Clip', '', '11+', unaryOps.clipV11], - // ['Concat', '', '4+', concat, parseConcatAttributes], + ['Clip', '', '11+', unaryOps.clipV11], ['Concat', '', '4+', concat, parseConcatAttributes], // ['Conv', '', '1+', conv, parseConvAttributes], ['Cos', '', '7+', unaryOps.cos], ['Div', '', '7+', binaryOps.div], // ['Dropout', '', '7+', unaryOps.identity], diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/common.ts b/js/web/lib/onnxjs/backends/webgpu/ops/common.ts index cc9134b7f33ad..35c79f705b437 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/common.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/common.ts @@ -30,12 +30,18 @@ export interface IndicesHelper { i2oImpl: string; /** * WGSL code of function implementation for indices-to-offset + * + * @param isPtr - whether the variable is a pointer. default is false. */ - i2oExpression: (varIndices: string) => string; + i2oExpression: (varIndices: string, isPtr?: boolean) => string; /** * WGSL code of indices variable declaration */ indicesVariableDeclaration: (v: string) => string; + /** + * data type of indices + */ + iType: string; } export const createIndicesHelper = (name: string, shape: readonly number[]) => { @@ -72,9 +78,10 @@ export const createIndicesHelper = (name: string, shape: readonly number[]) => { return ${offsets.length > 0 ? offsets.join('+') : '0u'}; }`; - const i2oExpression = (varIndices: string) => shape.length < 2 ? varIndices : `ih_i2o_${name}(&${varIndices})`; + const i2oExpression = (varIndices: string, isPtr?: boolean) => + shape.length < 2 ? `(${isPtr ? '*' : ''}${varIndices})` : `ih_i2o_${name}(${isPtr ? '' : '&'}${varIndices})`; const indicesVariableDeclaration = (v: string) => `var ${v}:${iType};`; - return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration}; + return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration, iType}; }; diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/concat.ts b/js/web/lib/onnxjs/backends/webgpu/ops/concat.ts new file mode 100644 index 0000000000000..cc03a8e5d699e --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/ops/concat.ts @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; +import {Graph} from '../../../graph'; +import {OperatorInitialization} from '../../../operators'; +import {Tensor} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {WebGpuInferenceHandler} from '../inference-handler'; +import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {createIndicesHelper, IndicesHelper, WORKGROUP_SIZE} from './common'; + +export interface ConcatAttributes extends AttributeWithCacheKey { + readonly axis: number; +} + +export const concat = async( + inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: ConcatAttributes): Promise => { + validateInputs(inputs); + return inferenceHandler.run(createConcatProgramInfoLoader(inputs, attributes), inputs); +}; + +const createConcatProgramMetadata = (inputCount: number, cacheHint: string) => + ({name: 'Concat', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint}); + +const createConcatProgramInfo = + (metadata: ProgramMetadata, inputs: Tensor[], axis: number, dataType = 'f32'): ProgramInfo => { + const inputShape = inputs[0].dims.slice(); + if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { + throw new Error('axis specified for concat doesn\'t match input dimensionality'); + } + if (axis < 0) { + axis = inputShape.length + axis; + } + // ensure all of the non-concatenated axes match each other + // calculate the shape of the output tensor while we do that + const outputShape = inputShape.slice(0); + for (let i = 1; i < inputs.length; i++) { + const dataNShape = inputs[i].dims.slice(); + for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { + // add to the placeholder for computing output shape + if (axisIndex === axis) { + outputShape[axis] += dataNShape[axisIndex]; + } + // ensure all non-cancatenated axes match each other + else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { + throw new Error('non concat dimensions must match'); + } + } + } + + const outputSize = ShapeUtil.size(outputShape); + const rank = outputShape.length; + + const sizeInConcatAxis = new Array(inputs.length); + const inputStorageBuffersDeclarations = new Array(inputs.length); + const inputIndicesHelpers = new Array(inputs.length); + + let previousSum = 0; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[axis]; + sizeInConcatAxis[i] = previousSum; + + inputStorageBuffersDeclarations[i] = + `@group(0) @binding(${i}) var input${i} : array<${dataType}>;`; + + inputIndicesHelpers[i] = createIndicesHelper(`input${i}`, inputs[i].dims); + } + + const outputIndicesHelper = createIndicesHelper('output', outputShape); + + const indicesAxis = rank < 2 ? 'indices' : `indices[${axis}]`; + const shaderSource = ` + let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u; + + ${inputStorageBuffersDeclarations.join('\n')} + @group(0) @binding(${inputs.length}) var output : array<${dataType}>; + + ${inputIndicesHelpers.map(i => i.i2oImpl).join('\n')} + ${outputIndicesHelper.o2iImpl} + + let sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); + ${calculateInputIndexImpl(sizeInConcatAxis.length)} + ${readBufferDataImpl(inputIndicesHelpers, rank, dataType)} + + @stage(compute) @workgroup_size(WORKGROUP_SIZE) + fn main(@builtin(global_invocation_id) global_id : vec3) { + + // Guard against out-of-bounds work group sizes + if (global_id.x >= ${outputSize}u) { + return; + } + + ${outputIndicesHelper.indicesVariableDeclaration('indices')} + ${outputIndicesHelper.o2iCall('global_id.x', 'indices')} + + let textureIndex = calculateInputIndex(${indicesAxis}); + if (textureIndex != 0u) { + ${indicesAxis} -= sizeInConcatAxis[textureIndex - 1u]; + } + + output[global_id.x] = readBufferData(textureIndex, &indices); + }`; + return { + ...metadata, + outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}], + shaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +const createConcatProgramInfoLoader = (inputs: Tensor[], attributes: ConcatAttributes): ProgramInfoLoader => { + const metadata = createConcatProgramMetadata(inputs.length, attributes.cacheKey); + return {...metadata, get: () => createConcatProgramInfo(metadata, inputs, attributes.axis)}; +}; + +const calculateInputIndexImpl = (numberOfTensors: number): string => ` + fn calculateInputIndex(index: u32) -> u32 { + for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { + if (index < sizeInConcatAxis[i]) { + return i; + } + } + return ${numberOfTensors}u; + }`; + +const readBufferDataImpl = (indicesHelper: readonly IndicesHelper[], tensorRank: number, dataType: string) => { + const numberOfTensors = indicesHelper.length; + const codeLines: string[] = []; + for (let i = 0; i < numberOfTensors; ++i) { + const returnSnippet = `return input${i}[${indicesHelper[i].i2oExpression('indices', true)}];`; + if (i === 0) { + codeLines.push(`if (textureIndex == ${i}u) { ${returnSnippet} }`); + } else if (i === numberOfTensors - 1) { + codeLines.push(`else { ${returnSnippet} }`); + } else { + codeLines.push(`else if (textureIndex == ${i}) { ${returnSnippet} }`); + } + } + return ` + fn readBufferData(textureIndex: u32, indices: ptr) -> ${dataType} { + ${codeLines.join('\n')} + }`; +}; + +export const parseConcatAttributes: OperatorInitialization = (node: Graph.Node): ConcatAttributes => + createAttributeWithCacheKey({axis: node.attributes.getInt('axis')}); + +const validateInputs = (inputs: Tensor[]): void => { + if (!inputs || inputs.length < 1) { + throw new Error('too few inputs'); + } + + const inputType = inputs[0].type; + const inputDimensionality = inputs[0].dims.length; + + // TODO: Support string concat + if (inputType === 'string') { + throw new Error('string tensor is not supported yet'); + } + + for (const input of inputs) { + // make sure types of all inputs match + if (input.type !== inputType) { + throw new Error('input tensors should be one type'); + } + + // make sure the dimensionality of all inputs are the same + if (input.dims.length !== inputDimensionality) { + throw new Error('input tensors should have the same shape'); + } + } +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 64d5007c7a02e..06270082d7605 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -326,12 +326,12 @@ "v{7,8,9,10}/test_clip_default_max", "v{7,8,9,10}/test_clip_default_inbounds", "v{7,8,9,10}/test_clip", - // "test_concat_1d_axis_0", - // "test_concat_2d_axis_0", - // "test_concat_2d_axis_1", - // "test_concat_3d_axis_0", - // "test_concat_3d_axis_1", - // "test_concat_3d_axis_2", + "test_concat_1d_axis_0", + "test_concat_2d_axis_0", + "test_concat_2d_axis_1", + "test_concat_3d_axis_0", + "test_concat_3d_axis_1", + "test_concat_3d_axis_2", // "test_conv_with_strides_and_asymmetric_padding", // "test_conv_with_strides_no_padding", // "test_conv_with_strides_padding", @@ -514,7 +514,7 @@ //"and.jsonc", "asin.jsonc", "ceil.jsonc", - //"concat.jsonc", + "concat.jsonc", //"conv.jsonc", "cos.jsonc", "div.jsonc",