From 6627349fb6192b0bcb7c75bbd5af9c1c25c30a95 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 7 Apr 2022 18:46:53 -0700 Subject: [PATCH] binary ops --- .../backends/webgpu/op-resolve-rules.ts | 10 +-- .../onnxjs/backends/webgpu/ops/binary-op.ts | 78 +++++++++++-------- .../lib/onnxjs/backends/webgpu/ops/common.ts | 66 ++++++++++++++++ js/web/test/suite-test-list.jsonc | 28 +++---- 4 files changed, 131 insertions(+), 51 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index ea767e914a1ed..dc1f063197cf6 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -20,8 +20,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Clip', '', '11+', unaryOps.clipV11], // ['Concat', '', '4+', concat, parseConcatAttributes], // ['Conv', '', '1+', conv, parseConvAttributes], - ['Cos', '', '7+', unaryOps.cos], - // ['Div', '', '7+', binaryOps.div], + ['Cos', '', '7+', unaryOps.cos], ['Div', '', '7+', binaryOps.div], // ['Dropout', '', '7+', unaryOps.identity], // ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes], // ['Equal', '', '7+', binaryOps.equal], @@ -44,13 +43,12 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // ['MatMul', '', '1+', matMul, parseMatMulAttributes], // // TODO: support new attributes for MaxPool-8 and MaxPool-10 // ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes], - // ['Mul', '', '7+', binaryOps.mul], - ['Neg', '', '6+', unaryOps.neg], + ['Mul', '', '7+', binaryOps.mul], ['Neg', '', '6+', unaryOps.neg], // ['Not', '', '1+', unaryOps.not], // ['Or', '', '7+', binaryOps.or], // ['Pad', '', '2-10', padV2, parsePadAttributesV2], // ['Pad', '', '11+', padV11, parsePadAttributesV11], - // ['Pow', '', '7+', binaryOps.pow], + ['Pow', '', '7+', binaryOps.pow], // ['PRelu', '', '7+', binaryOps.pRelu], // ['ReduceLogSum', '', '1+', reduceLogSum, parseReduceAttributes], // ['ReduceMax', '', '1+', reduceMax, parseReduceAttributes], @@ -77,7 +75,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Sqrt', '', '6+', unaryOps.sqrt], // ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes], // ['Squeeze', '', '13+', squeezeV13], - // ['Sub', '', '7+', binaryOps.sub], + ['Sub', '', '7+', binaryOps.sub], // ['Sum', '', '6+', sum], ['Tan', '', '7+', unaryOps.tan], ['Tanh', '', '6+', unaryOps.tanh], // ['Tile', '', '6+', tile], diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/binary-op.ts b/js/web/lib/onnxjs/backends/webgpu/ops/binary-op.ts index 0ba23f13b9314..d6489bdbda6d2 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/binary-op.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/binary-op.ts @@ -6,7 +6,7 @@ import {Tensor} from '../../../tensor'; import {BroadcastUtil, ShapeUtil} from '../../../util'; import {WebGpuInferenceHandler} from '../inference-handler'; import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; -import {WORKGROUP_SIZE} from './common'; +import {createIndicesHelper, WORKGROUP_SIZE} from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; @@ -34,42 +34,25 @@ const createBinaryOpProgramShader = } let broadcastImpl = ''; - let broadcastVars = ''; + const outputIndicesHelper = createIndicesHelper('output', dimsOutput); if (doBroadcast) { - broadcastVars = `var outputDims: array;`; - - let calcDimsOutputImpl = ''; - const outputStrides = ShapeUtil.computeStrides(dimsOutput); - for (let i = 0; i < dimsOutput.length - 1; i++) { - calcDimsOutputImpl += ` - let dim${i} = current / ${outputStrides[i]}u; - let rest${i} = current % ${outputStrides[i]}u; - (*outputDims)[${i}] = dim${i}; - current = rest${i}; - `; - } - calcDimsOutputImpl += `(*outputDims)[${dimsOutput.length - 1}] = current;`; - const calcOffsetImpl = (dims: readonly number[]) => { const strides = ShapeUtil.computeStrides(dims); const offsets: string[] = []; for (let i = dims.length - 1; i >= 0; i--) { - offsets.push(`${strides[i]}u * ((*outputDims)[${i + dimsOutput.length - dims.length}] % ${dims[i]}u)`); + offsets.push(`${strides[i]}u * ((*outputIndices)[${i + dimsOutput.length - dims.length}] % ${dims[i]}u)`); } return offsets.length > 0 ? offsets.join('+') : '0u'; }; broadcastImpl = ` - fn calcDimsOutput(outputOffset: u32, outputDims: ptr>) { - var current = outputOffset; - ${calcDimsOutputImpl} - } + ${outputIndicesHelper.o2iImpl} - fn calcOffsetA(outputDims: ptr>) -> u32 { + fn calcOffsetA(outputIndices: ptr>) -> u32 { return ${calcOffsetImpl(dimsA)}; } - fn calcOffsetB(outputDims: ptr>) -> u32 { + fn calcOffsetB(outputIndices: ptr>) -> u32 { return ${calcOffsetImpl(dimsB)}; } `; @@ -79,10 +62,10 @@ const createBinaryOpProgramShader = if (vectorize) { if (doBroadcast) { assignment = ` - ${broadcastVars} - calcDimsOutput(global_id.x * 4u, &outputDims); - let offsetA = calcOffsetA(&outputDims); - let offsetB = calcOffsetB(&outputDims); + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_id.x * 4u', 'outputIndices')} + let offsetA = calcOffsetA(&outputIndices); + let offsetB = calcOffsetB(&outputIndices); outputData[global_id.x] = ${expressionVector('aData[offsetA / 4u]', 'bData[offsetB / 4u]')};`; } else { assignment = `outputData[global_id.x] = ${expressionVector('aData[global_id.x]', 'bData[global_id.x]')};`; @@ -95,9 +78,9 @@ const createBinaryOpProgramShader = const expressionA = `aData[indexA${x}][componentA${x}]`; const expressionB = `bData[indexB${x}][componentB${x}]`; return ` - calcDimsOutput(global_id.x * 4u + ${x}u, &outputDims); - let offsetA${x} = calcOffsetA(&outputDims); - let offsetB${x} = calcOffsetB(&outputDims); + ${outputIndicesHelper.o2iCall(`global_id.x * 4u + ${x}u`, 'outputIndices')} + let offsetA${x} = calcOffsetA(&outputIndices); + let offsetB${x} = calcOffsetB(&outputIndices); let indexA${x} = offsetA${x} / 4u; let indexB${x} = offsetB${x} / 4u; let componentA${x} = offsetA${x} % 4u; @@ -106,7 +89,7 @@ const createBinaryOpProgramShader = }; assignment = ` - ${broadcastVars} + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} ${singleAssignment(0)} ${singleAssignment(1)} ${singleAssignment(2)} @@ -198,3 +181,36 @@ const createBinaryOpProgramInfoLoader = export const add = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise => handler.run(createBinaryOpProgramInfoLoader(inputs, 'Add', (a, b) => `${a}+${b}`), inputs); + +// export const and = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAnd(), 'bool'), inputs)]; + +export const div = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise => + handler.run(createBinaryOpProgramInfoLoader(inputs, 'Div', (a, b) => `${a}/${b}`), inputs); + +// export const equal = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslEqual(), 'bool'), inputs)]; + +// export const greater = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslGreater(), 'bool'), inputs)]; + +// export const less = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslLess(), 'bool'), inputs)]; + +export const mul = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise => + handler.run(createBinaryOpProgramInfoLoader(inputs, 'Mul', (a, b) => `${a}*${b}`), inputs); + +// export const or = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslOr(), 'bool'), inputs)]; + +export const pow = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise => + handler.run(createBinaryOpProgramInfoLoader(inputs, 'Pow', 'pow'), inputs); + +// export const pRelu = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPRelu()), inputs)]; + +export const sub = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise => + handler.run(createBinaryOpProgramInfoLoader(inputs, 'Sub', (a, b) => `${a}-${b}`), inputs); + +// export const xor = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslXor(), 'bool'), inputs)]; diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/common.ts b/js/web/lib/onnxjs/backends/webgpu/ops/common.ts index b436e82f0d25a..dd16781496b4d 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/common.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/common.ts @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {ShapeUtil} from '../../../util'; + /** * constant value for a workgroup size. * @@ -12,3 +14,67 @@ * from: https://surma.dev/things/webgpu/ **/ export const WORKGROUP_SIZE = 64; + +export interface IndicesHelper { + /** + * WGSL code of function implementation for offset-to-indices + */ + o2iImpl: string; + /** + * WGSL code of function call for offset-to-indices + */ + o2iCall: (varOffset: string, varIndices: string) => string; + /** + * WGSL code of function implementation for indices-to-offset + */ + i2oImpl: string; + /** + * WGSL code of function implementation for indices-to-offset + */ + i2oExpression: (varIndices: string) => string; + /** + * WGSL code of indices variable declaration + */ + indicesVariableDeclaration: (v: string) => string; +} + +export const createIndicesHelper = (name: string, shape: readonly number[]) => { + const iType = shape.length < 2 ? 'u32' : `array`; + + const strides = ShapeUtil.computeStrides(shape); + let o2iSnippet = ''; + for (let i = 0; i < shape.length - 1; i++) { + o2iSnippet += ` + let dim${i} = current / ${strides[i]}u; + let rest${i} = current % ${strides[i]}u; + (*indices)[${i}] = dim${i}; + current = rest${i}; + `; + } + o2iSnippet += `(*indices)[${shape.length - 1}] = current;`; + + const o2iImpl = shape.length < 2 ? '' : ` + fn ih_o2i_${name}(offset: u32, indices: ptr) { + var current = offset; + ${o2iSnippet} + }`; + + const o2iCall = (varOffset: string, varIndices: string) => + shape.length < 2 ? `${varIndices}=${varOffset};` : `ih_o2i_${name}(${varOffset}, &${varIndices});`; + + const offsets: string[] = []; + for (let i = shape.length - 1; i >= 0; i--) { + offsets.push(`${strides[i]}u * ((*indices)[${i}])`); + } + + const i2oImpl = shape.length < 2 ? '' : ` + fn ih_i2o_${name}(indices: ptr) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'} + }`; + + const i2oExpression = (varIndices: string) => shape.length < 2 ? varIndices : `ih_i2o_${name}(&${varIndices})`; + + const indicesVariableDeclaration = (v: string) => `var ${v}:${iType};`; + + return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration}; +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 74e4ed3e9b9e1..3dcfe91a2e77e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -338,9 +338,9 @@ "test_constant", "test_cos_example", "test_cos", - // "test_div_bcast", - // "test_div_example", - // "test_div", + "test_div_bcast", + "test_div_example", + "test_div", // "test_dropout_default", // "test_dropout_random", // "test_depthtospace_crd_mode", @@ -390,9 +390,9 @@ // "v12/test_maxpool_2d_same_upper", // "v12/test_maxpool_2d_strides", // "test_maxpool_3d_default", - // "test_mul_bcast", - // "test_mul_example", - // "test_mul", + "test_mul_bcast", + "test_mul_example", + "test_mul", "test_neg", "test_neg_example", // "test_not_2d", @@ -424,9 +424,9 @@ // "name": "test_softmax_large_number", // "condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment // }, - // "test_sub_bcast", - // "test_sub_example", - // "test_sub", + "test_sub_bcast", + "test_sub_example", + "test_sub", // "test_sum_example", // "test_sum_one_input", // "test_sum_two_inputs", @@ -517,7 +517,7 @@ //"concat.jsonc", //"conv.jsonc", "cos.jsonc", - //"div.jsonc", + "div.jsonc", //"depth-to-space.jsonc", //"equal.jsonc", "exp.jsonc", @@ -530,7 +530,7 @@ //"less.jsonc", "log.jsonc", //"matmul.jsonc", - //"mul.jsonc", + "mul.jsonc", "neg.jsonc", //"not.jsonc", //"or.jsonc", @@ -539,14 +539,14 @@ "relu.jsonc", //"pad.jsonc", //"pad-big.jsonc", - //"pow.jsonc", - //"pow-big-number.jsonc", + "pow.jsonc", + "pow-big-number.jsonc", //"reshape.jsonc", //"softmax.jsonc", "sin.jsonc", //"split.jsonc", "sqrt.jsonc", - //"sub.jsonc", + "sub.jsonc", "tan.jsonc" //"transpose.jsonc", //"xor.jsonc"