// Integer 5-D and 6-D coordinate/shape tuples. WGSL has no built-in vector
// type wider than four components, so these are plain structs.
struct vec5 {x: i32, y: i32, z: i32, w: i32, u: i32};
struct vec6 {x: i32, y: i32, z: i32, w: i32, u: i32, v: i32};

// Checks whether coordinates lie within the bounds of the shape, i.e.
// 0 <= coord[i] < shape[i] holds for every component i.
fn coordsInBounds2D(coord : vec2<i32>, shape : vec2<i32>) -> bool {
  return all(coord >= vec2<i32>(0)) && all(coord < shape);
}

fn coordsInBounds3D(coord : vec3<i32>, shape : vec3<i32>) -> bool {
  return all(coord >= vec3<i32>(0)) && all(coord < shape);
}

fn coordsInBounds4D(coord : vec4<i32>, shape : vec4<i32>) -> bool {
  return all(coord >= vec4<i32>(0)) && all(coord < shape);
}

// Flattens N-D coordinates into a row-major linear index for a tensor of the
// given shape: the last axis has stride 1 and each earlier axis has a stride
// equal to the product of all later extents. The leading extent (shape.x) is
// never needed, and the 1-D variant ignores `shape` entirely.
fn getIndexFromCoords1D(coord : i32, shape : i32) -> i32 {
  return coord;
}

fn getIndexFromCoords2D(coords : vec2<i32>, shape : vec2<i32>) -> i32 {
  return dot(coords, vec2<i32>(shape.y, 1));
}

fn getIndexFromCoords3D(coords : vec3<i32>, shape : vec3<i32>) -> i32 {
  return dot(coords, vec3<i32>(shape.y * shape.z, shape.z, 1));
}

fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {
  return dot(coords, vec4<i32>(
      shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
}

// vec5/vec6 are structs rather than vectors, so `dot` is unavailable and the
// weighted sum is spelled out.
fn getIndexFromCoords5D(coords : vec5, shape : vec5) -> i32 {
  let shapeStrides: vec5 = vec5(
      shape.y * shape.z * shape.w * shape.u,
      shape.z * shape.w * shape.u,
      shape.w * shape.u,
      shape.u,
      1);
  return coords.x * shapeStrides.x + coords.y * shapeStrides.y +
      coords.z * shapeStrides.z + coords.w * shapeStrides.w +
      coords.u * shapeStrides.u;
}

fn getIndexFromCoords6D(coords : vec6, shape : vec6) -> i32 {
  let shapeStrides: vec6 = vec6(
      shape.y * shape.z * shape.w * shape.u * shape.v,
      shape.z * shape.w * shape.u * shape.v,
      shape.w * shape.u * shape.v,
      shape.u * shape.v,
      shape.v,
      1);
  return coords.x * shapeStrides.x + coords.y * shapeStrides.y +
      coords.z * shapeStrides.z + coords.w * shapeStrides.w +
      coords.u * shapeStrides.u + coords.v * shapeStrides.v;
}

// Integer division of a by b that rounds toward negative infinity when the
// caller reports (via `sign`) that the true quotient is negative. WGSL's
// integer `/` truncates toward zero, so a negative quotient with a non-zero
// remainder comes out one too large and is corrected here.
fn idiv(a: i32, b: i32, sign: f32) -> i32 {
  var res: i32 = a / b;
  let modulo: i32 = a % b;
  if (sign < 0. && modulo != 0) {
    res = res - 1;
  }
  return res;
}

// NaN definition in IEEE 754-1985 is:
// - sign = either 0 or 1.
// - biased exponent = all 1 bits.
// - fraction = anything except all 0 bits (since all 0 bits represents infinity).
// https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers
fn isnan(val: f32) -> bool {
  let floatToUint: u32 = bitcast<u32>(val);
  // Masking off the sign bit leaves exponent|fraction. Any pattern strictly
  // above +infinity (0x7f800000) has an all-ones exponent and a non-zero
  // fraction, i.e. it is a NaN.
  return (floatToUint & 0x7fffffffu) > 0x7f800000u;
}

// Component-wise isnan.
fn isnanVec4(val : vec4<f32>) -> vec4<bool> {
  return vec4<bool>(isnan(val[0]), isnan(val[1]), isnan(val[2]), isnan(val[3]));
}

const workGroupSizeX = 256u;
const workGroupSizeY = 1u;
const workGroupSizeZ = 1u;

// Built-in invocation ids, copied here by `_start` so that helper functions
// can read them without taking extra parameters.
var<private> localId: vec3<u32>;
var<private> globalId: vec3<u32>;
var<private> numWorkgroups: vec3<u32>;

// Only used when the y/z dimension of workgroup size is 1.
fn getGlobalIndex() -> i32 {
  return i32(globalId.x);
}

// Field order determines the uniform-buffer layout and must match whatever
// the host writes into binding 3.
struct Uniforms {
  NAN : f32,
  aShape : vec2<i32>,
  bShape : i32,
  outShape : vec2<i32>,
  outShapeStrides: i32,
  size : i32,
};

// NOTE(review): element type f32 for A and B is inferred from the f32
// conversions below; confirm against the host-side buffer dtype.
@group(0) @binding(0) var<storage, read_write> result: array<f32>;
@group(0) @binding(1) var<storage, read> A: array<f32>;
@group(0) @binding(2) var<storage, read> B: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: Uniforms;

// Converts a flat output index into 2-D output coordinates (row, column),
// using outShapeStrides as the row stride.
fn getCoordsFromIndex(index : i32) -> vec2<i32> {
  let d0 = index / uniforms.outShapeStrides;
  let d1 = index - d0 * uniforms.outShapeStrides;
  return vec2<i32>(d0, d1);
}

// Output coordinates handled by this invocation.
fn getOutputCoords() -> vec2<i32> {
  let globalIndex = getGlobalIndex();
  return getCoordsFromIndex(globalIndex);
}

// Inverse of getCoordsFromIndex.
fn getOutputIndexFromCoords(coords : vec2<i32>) -> i32 {
  return dot(coords, vec2<i32>(uniforms.outShapeStrides, 1));
}

fn setOutputAtIndex(flatIndex : i32, value : f32) {
  result[flatIndex] = f32(value);
}

fn setOutputAtIndexI32(flatIndex : i32, value : i32) {
  result[flatIndex] = f32(value);
}

fn setOutputAtCoords(d0 : i32, d1 : i32, value : f32) {
  let flatIndex = getOutputIndexFromCoords(vec2<i32>(d0, d1));
  setOutputAtIndex(flatIndex, value);
}

fn setOutputAtCoordsI32(d0 : i32, d1 : i32, value : i32) {
  let flatIndex = getOutputIndexFromCoords(vec2<i32>(d0, d1));
  setOutputAtIndexI32(flatIndex, value);
}

// Reads A at 2-D coordinates (d0, d1) using A's own shape.
fn getA(d0 : i32, d1 : i32) -> f32 {
  return f32(A[getIndexFromCoords2D(vec2<i32>(d0, d1), uniforms.aShape)]);
}

// Reads A at a flat output index. This treats A's layout as identical to the
// output's (no broadcasting of A).
fn getAByOutputIndex(globalIndex : i32) -> f32 {
  return f32(A[globalIndex]);
}

fn getAByOutputCoords(coords : vec2<i32>) -> f32 {
  return f32(A[getOutputIndexFromCoords(coords)]);
}

// Reads 1-D B at index d0.
fn getB(d0 : i32) -> f32 {
  return f32(B[getIndexFromCoords1D(i32(d0), uniforms.bShape)]);
}

// B is broadcast along the output rows: only the output column selects the
// element of B.
fn getBByOutputIndex(globalIndex : i32) -> f32 {
  let coords = getCoordsFromIndex(globalIndex);
  return f32(B[getIndexFromCoords1D(i32(coords.y), uniforms.bShape)]);
}

fn getBByOutputCoords(coordsIn : vec2<i32>) -> f32 {
  return f32(B[getIndexFromCoords1D(i32(coordsIn.y), uniforms.bShape)]);
}

// The element-wise operation applied by this shader.
fn binaryOperation(a : f32, b : f32) -> f32 {
  return a + b;
}

// Per-workgroup copy of the first two elements of B.
var<workgroup> sharedBuf : array<f32, 2>;

@compute @workgroup_size(workGroupSizeX, workGroupSizeY, workGroupSizeZ)
fn _start(@builtin(local_invocation_id) LocalId : vec3<u32>,
          @builtin(global_invocation_id) GlobalId : vec3<u32>,
          @builtin(num_workgroups) NumWorkgroups : vec3<u32>) {
  localId = LocalId;
  globalId = GlobalId;
  numWorkgroups = NumWorkgroups;
  main(getGlobalIndex());
}

// Computes result[index] = A[index] + B[column(index)], reading B through
// workgroup memory.
fn main(index : i32) {
  // Fill in the shared memory buffer.
  let localIndex = i32(localId.x);
  if (localIndex < 2) {
    sharedBuf[localIndex] = f32(B[localIndex]);
  }
  // Every invocation must reach the barrier, including those whose `index`
  // is past uniforms.size. For that reason the barrier sits outside the
  // bounds check below.
  workgroupBarrier();
  if (index < uniforms.size) {
    let coords = getCoordsFromIndex(index);
    let a = getAByOutputIndex(index);
    // NOTE(review): assumes the output's last dimension is at most 2, so
    // coords[1] stays within sharedBuf. Confirm against the host-side
    // program selection.
    let b = sharedBuf[coords[1]];
    setOutputAtIndex(index, binaryOperation(a, b));
  }
}