Skip to content

Commit

Permalink
binary ops
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent fb81d7f commit 6627349
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 51 deletions.
10 changes: 4 additions & 6 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
['Clip', '', '11+', unaryOps.clipV11],
// ['Concat', '', '4+', concat, parseConcatAttributes],
// ['Conv', '', '1+', conv, parseConvAttributes],
['Cos', '', '7+', unaryOps.cos],
// ['Div', '', '7+', binaryOps.div],
['Cos', '', '7+', unaryOps.cos], ['Div', '', '7+', binaryOps.div],
// ['Dropout', '', '7+', unaryOps.identity],
// ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes],
// ['Equal', '', '7+', binaryOps.equal],
Expand All @@ -44,13 +43,12 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['MatMul', '', '1+', matMul, parseMatMulAttributes],
// // TODO: support new attributes for MaxPool-8 and MaxPool-10
// ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes],
// ['Mul', '', '7+', binaryOps.mul],
['Neg', '', '6+', unaryOps.neg],
['Mul', '', '7+', binaryOps.mul], ['Neg', '', '6+', unaryOps.neg],
// ['Not', '', '1+', unaryOps.not],
// ['Or', '', '7+', binaryOps.or],
// ['Pad', '', '2-10', padV2, parsePadAttributesV2],
// ['Pad', '', '11+', padV11, parsePadAttributesV11],
// ['Pow', '', '7+', binaryOps.pow],
['Pow', '', '7+', binaryOps.pow],
// ['PRelu', '', '7+', binaryOps.pRelu],
// ['ReduceLogSum', '', '1+', reduceLogSum, parseReduceAttributes],
// ['ReduceMax', '', '1+', reduceMax, parseReduceAttributes],
Expand All @@ -77,7 +75,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
['Sqrt', '', '6+', unaryOps.sqrt],
// ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes],
// ['Squeeze', '', '13+', squeezeV13],
// ['Sub', '', '7+', binaryOps.sub],
['Sub', '', '7+', binaryOps.sub],
// ['Sum', '', '6+', sum],
['Tan', '', '7+', unaryOps.tan], ['Tanh', '', '6+', unaryOps.tanh],
// ['Tile', '', '6+', tile],
Expand Down
78 changes: 47 additions & 31 deletions js/web/lib/onnxjs/backends/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {Tensor} from '../../../tensor';
import {BroadcastUtil, ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {WORKGROUP_SIZE} from './common';
import {createIndicesHelper, WORKGROUP_SIZE} from './common';

type BuiltinFunctionName = string;
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
Expand Down Expand Up @@ -34,42 +34,25 @@ const createBinaryOpProgramShader =
}

let broadcastImpl = '';
let broadcastVars = '';
const outputIndicesHelper = createIndicesHelper('output', dimsOutput);
if (doBroadcast) {
broadcastVars = `var outputDims: array<u32, ${dimsOutput.length}>;`;

let calcDimsOutputImpl = '';
const outputStrides = ShapeUtil.computeStrides(dimsOutput);
for (let i = 0; i < dimsOutput.length - 1; i++) {
calcDimsOutputImpl += `
let dim${i} = current / ${outputStrides[i]}u;
let rest${i} = current % ${outputStrides[i]}u;
(*outputDims)[${i}] = dim${i};
current = rest${i};
`;
}
calcDimsOutputImpl += `(*outputDims)[${dimsOutput.length - 1}] = current;`;

const calcOffsetImpl = (dims: readonly number[]) => {
const strides = ShapeUtil.computeStrides(dims);
const offsets: string[] = [];
for (let i = dims.length - 1; i >= 0; i--) {
offsets.push(`${strides[i]}u * ((*outputDims)[${i + dimsOutput.length - dims.length}] % ${dims[i]}u)`);
offsets.push(`${strides[i]}u * ((*outputIndices)[${i + dimsOutput.length - dims.length}] % ${dims[i]}u)`);
}
return offsets.length > 0 ? offsets.join('+') : '0u';
};

broadcastImpl = `
fn calcDimsOutput(outputOffset: u32, outputDims: ptr<function, array<u32, ${dimsOutput.length}>>) {
var current = outputOffset;
${calcDimsOutputImpl}
}
${outputIndicesHelper.o2iImpl}
fn calcOffsetA(outputDims: ptr<function, array<u32, ${dimsOutput.length}>>) -> u32 {
fn calcOffsetA(outputIndices: ptr<function, array<u32, ${dimsOutput.length}>>) -> u32 {
return ${calcOffsetImpl(dimsA)};
}
fn calcOffsetB(outputDims: ptr<function, array<u32, ${dimsOutput.length}>>) -> u32 {
fn calcOffsetB(outputIndices: ptr<function, array<u32, ${dimsOutput.length}>>) -> u32 {
return ${calcOffsetImpl(dimsB)};
}
`;
Expand All @@ -79,10 +62,10 @@ const createBinaryOpProgramShader =
if (vectorize) {
if (doBroadcast) {
assignment = `
${broadcastVars}
calcDimsOutput(global_id.x * 4u, &outputDims);
let offsetA = calcOffsetA(&outputDims);
let offsetB = calcOffsetB(&outputDims);
${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
${outputIndicesHelper.o2iCall('global_id.x * 4u', 'outputIndices')}
let offsetA = calcOffsetA(&outputIndices);
let offsetB = calcOffsetB(&outputIndices);
outputData[global_id.x] = ${expressionVector('aData[offsetA / 4u]', 'bData[offsetB / 4u]')};`;
} else {
assignment = `outputData[global_id.x] = ${expressionVector('aData[global_id.x]', 'bData[global_id.x]')};`;
Expand All @@ -95,9 +78,9 @@ const createBinaryOpProgramShader =
const expressionA = `aData[indexA${x}][componentA${x}]`;
const expressionB = `bData[indexB${x}][componentB${x}]`;
return `
calcDimsOutput(global_id.x * 4u + ${x}u, &outputDims);
let offsetA${x} = calcOffsetA(&outputDims);
let offsetB${x} = calcOffsetB(&outputDims);
${outputIndicesHelper.o2iCall(`global_id.x * 4u + ${x}u`, 'outputIndices')}
let offsetA${x} = calcOffsetA(&outputIndices);
let offsetB${x} = calcOffsetB(&outputIndices);
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
Expand All @@ -106,7 +89,7 @@ const createBinaryOpProgramShader =
};

assignment = `
${broadcastVars}
${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
${singleAssignment(0)}
${singleAssignment(1)}
${singleAssignment(2)}
Expand Down Expand Up @@ -198,3 +181,36 @@ const createBinaryOpProgramInfoLoader =

/**
 * Runs the 'Add' binary-op program on the given inputs.
 *
 * The per-element WGSL expression is `a+b`, where `a` and `b` are the operand expressions supplied by the shader
 * builder.
 */
export const add = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
  const addExpression = (lhs: string, rhs: string): string => `${lhs}+${rhs}`;
  return handler.run(createBinaryOpProgramInfoLoader(inputs, 'Add', addExpression), inputs);
};

// export const and = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAnd(), 'bool'), inputs)];

/**
 * Runs the 'Div' binary-op program on the given inputs.
 *
 * The per-element WGSL expression is `a/b`, where `a` and `b` are the operand expressions supplied by the shader
 * builder.
 */
export const div = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
  const divExpression = (lhs: string, rhs: string): string => `${lhs}/${rhs}`;
  return handler.run(createBinaryOpProgramInfoLoader(inputs, 'Div', divExpression), inputs);
};

// export const equal = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslEqual(), 'bool'), inputs)];

// export const greater = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslGreater(), 'bool'), inputs)];

// export const less = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslLess(), 'bool'), inputs)];

/**
 * Runs the 'Mul' binary-op program on the given inputs.
 *
 * The per-element WGSL expression is `a*b`, where `a` and `b` are the operand expressions supplied by the shader
 * builder.
 */
export const mul = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
  const mulExpression = (lhs: string, rhs: string): string => `${lhs}*${rhs}`;
  return handler.run(createBinaryOpProgramInfoLoader(inputs, 'Mul', mulExpression), inputs);
};

// export const or = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslOr(), 'bool'), inputs)];

/**
 * Runs the 'Pow' binary-op program on the given inputs.
 *
 * Unlike the arithmetic operators, this passes a function name (`pow`) instead of a custom expression callback.
 */
export const pow = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
  const builtinFunctionName = 'pow';
  return handler.run(createBinaryOpProgramInfoLoader(inputs, 'Pow', builtinFunctionName), inputs);
};

// export const pRelu = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPRelu()), inputs)];

/**
 * Runs the 'Sub' binary-op program on the given inputs.
 *
 * The per-element WGSL expression is `a-b`, where `a` and `b` are the operand expressions supplied by the shader
 * builder.
 */
export const sub = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
  const subExpression = (lhs: string, rhs: string): string => `${lhs}-${rhs}`;
  return handler.run(createBinaryOpProgramInfoLoader(inputs, 'Sub', subExpression), inputs);
};

// export const xor = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslXor(), 'bool'), inputs)];
66 changes: 66 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {ShapeUtil} from '../../../util';

/**
* constant value for a workgroup size.
*
Expand All @@ -12,3 +14,67 @@
* from: https://surma.dev/things/webgpu/
**/
export const WORKGROUP_SIZE = 64;

/**
* A set of WGSL code snippets for converting between a flat element offset and
* per-dimension indices of a tensor with a fixed shape.
*/
export interface IndicesHelper {
/**
* WGSL code of the function implementation for offset-to-indices.
*
* Empty string when the shape has fewer than 2 dimensions (no function is needed).
*/
o2iImpl: string;
/**
* Generates the WGSL statement that performs offset-to-indices.
*
* @param varOffset - WGSL expression of the flat offset
* @param varIndices - name of the WGSL indices variable that receives the result
*/
o2iCall: (varOffset: string, varIndices: string) => string;
/**
* WGSL code of the function implementation for indices-to-offset.
*
* Empty string when the shape has fewer than 2 dimensions (no function is needed).
*/
i2oImpl: string;
/**
* Generates the WGSL expression that evaluates indices-to-offset.
*
* @param varIndices - name of the WGSL indices variable to convert
*/
i2oExpression: (varIndices: string) => string;
/**
* Generates the WGSL declaration of an indices variable.
*
* @param v - name of the variable to declare
*/
indicesVariableDeclaration: (v: string) => string;
}

/**
 * Creates WGSL code snippets that convert between a flat element offset and per-dimension indices for a tensor of
 * the given shape.
 *
 * For shapes with fewer than 2 dimensions the indices type is a plain `u32` and the offset is used directly as the
 * index, so no helper functions are generated (`o2iImpl` and `i2oImpl` are empty strings). Otherwise the indices type
 * is `array<u32, rank>` and two WGSL functions, `ih_o2i_<name>` and `ih_i2o_<name>`, are generated.
 *
 * @param name - suffix used to build unique WGSL function names
 * @param shape - dimensions of the tensor; strides are obtained from `ShapeUtil.computeStrides`
 * @returns the generated implementations together with generators for the matching call sites
 */
export const createIndicesHelper = (name: string, shape: readonly number[]) => {
  const rank = shape.length;
  const needsFunctions = rank >= 2;
  const iType = needsFunctions ? `array<u32, ${rank}>` : 'u32';
  const strides = ShapeUtil.computeStrides(shape);

  let o2iImpl = '';
  let i2oImpl = '';
  if (needsFunctions) {
    // offset -> indices: peel off one dimension at a time using its stride; the remainder is the last index.
    let o2iSnippet = '';
    for (let i = 0; i < rank - 1; i++) {
      o2iSnippet += `
    let dim${i} = current / ${strides[i]}u;
    let rest${i} = current % ${strides[i]}u;
    (*indices)[${i}] = dim${i};
    current = rest${i};
    `;
    }
    o2iSnippet += `(*indices)[${rank - 1}] = current;`;

    o2iImpl = `
  fn ih_o2i_${name}(offset: u32, indices: ptr<function, ${iType}>) {
    var current = offset;
    ${o2iSnippet}
  }`;

    // indices -> offset: sum of stride * index over all dimensions (last dimension first).
    const offsets: string[] = [];
    for (let i = rank - 1; i >= 0; i--) {
      offsets.push(`${strides[i]}u * ((*indices)[${i}])`);
    }

    // WGSL requires every return statement to be terminated by a semicolon.
    i2oImpl = `
  fn ih_i2o_${name}(indices: ptr<function, ${iType}>) -> u32 {
    return ${offsets.join('+')};
  }`;
  }

  const o2iCall = (varOffset: string, varIndices: string) =>
      needsFunctions ? `ih_o2i_${name}(${varOffset}, &${varIndices});` : `${varIndices}=${varOffset};`;

  const i2oExpression = (varIndices: string) => needsFunctions ? `ih_i2o_${name}(&${varIndices})` : varIndices;

  const indicesVariableDeclaration = (v: string) => `var ${v}:${iType};`;

  return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration};
};
28 changes: 14 additions & 14 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@
"test_constant",
"test_cos_example",
"test_cos",
// "test_div_bcast",
// "test_div_example",
// "test_div",
"test_div_bcast",
"test_div_example",
"test_div",
// "test_dropout_default",
// "test_dropout_random",
// "test_depthtospace_crd_mode",
Expand Down Expand Up @@ -390,9 +390,9 @@
// "v12/test_maxpool_2d_same_upper",
// "v12/test_maxpool_2d_strides",
// "test_maxpool_3d_default",
// "test_mul_bcast",
// "test_mul_example",
// "test_mul",
"test_mul_bcast",
"test_mul_example",
"test_mul",
"test_neg",
"test_neg_example",
// "test_not_2d",
Expand Down Expand Up @@ -424,9 +424,9 @@
// "name": "test_softmax_large_number",
// "condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment
// },
// "test_sub_bcast",
// "test_sub_example",
// "test_sub",
"test_sub_bcast",
"test_sub_example",
"test_sub",
// "test_sum_example",
// "test_sum_one_input",
// "test_sum_two_inputs",
Expand Down Expand Up @@ -517,7 +517,7 @@
//"concat.jsonc",
//"conv.jsonc",
"cos.jsonc",
//"div.jsonc",
"div.jsonc",
//"depth-to-space.jsonc",
//"equal.jsonc",
"exp.jsonc",
Expand All @@ -530,7 +530,7 @@
//"less.jsonc",
"log.jsonc",
//"matmul.jsonc",
//"mul.jsonc",
"mul.jsonc",
"neg.jsonc",
//"not.jsonc",
//"or.jsonc",
Expand All @@ -539,14 +539,14 @@
"relu.jsonc",
//"pad.jsonc",
//"pad-big.jsonc",
//"pow.jsonc",
//"pow-big-number.jsonc",
"pow.jsonc",
"pow-big-number.jsonc",
//"reshape.jsonc",
//"softmax.jsonc",
"sin.jsonc",
//"split.jsonc",
"sqrt.jsonc",
//"sub.jsonc",
"sub.jsonc",
"tan.jsonc"
//"transpose.jsonc",
//"xor.jsonc"
Expand Down

0 comments on commit 6627349

Please sign in to comment.