Skip to content

Commit

Permalink
webgpu/shader: Migrate all f32 expression tests to the CaseCache.
Browse files Browse the repository at this point in the history
  • Loading branch information
ben-clayton committed Nov 11, 2022
1 parent ea0cfeb commit d7e8d00
Show file tree
Hide file tree
Showing 53 changed files with 1,125 additions and 680 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,61 @@ import {
subtractionInterval,
} from '../../../../util/f32_interval.js';
import { kVectorTestValues } from '../../../../util/math.js';
import { makeCaseCache } from '../case_cache.js';
import { allInputSources, Case, makeBinaryToF32IntervalCase, run } from '../expression.js';

import { binary } from './binary.js';

export const g = makeTestGroup(GPUTest);

export const d = makeCaseCache('binary/f32_arithmetic', {
addition: () => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, additionInterval);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});
},
subtraction: () => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, subtractionInterval);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});
},
multiplication: () => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, multiplicationInterval);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});
},
division: () => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, divisionInterval);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});
},
remainder: () => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, remainderInterval);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});
},
});

g.test('addition')
.specURL('https://www.w3.org/TR/WGSL/#floating-point-evaluation')
.desc(
Expand All @@ -31,14 +80,7 @@ Accuracy: Correctly rounded
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, additionInterval);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});

const cases = await d.get('addition');
await run(t, binary('+'), [TypeF32, TypeF32], TypeF32, t.params, cases);
});

Expand All @@ -54,14 +96,7 @@ Accuracy: Correctly rounded
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, subtractionInterval);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});

const cases = await d.get('subtraction');
await run(t, binary('-'), [TypeF32, TypeF32], TypeF32, t.params, cases);
});

Expand All @@ -77,14 +112,7 @@ Accuracy: Correctly rounded
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, multiplicationInterval);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});

const cases = await d.get('multiplication');
await run(t, binary('*'), [TypeF32, TypeF32], TypeF32, t.params, cases);
});

Expand All @@ -100,14 +128,7 @@ Accuracy: 2.5 ULP for |y| in the range [2^-126, 2^126]
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, divisionInterval);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});

const cases = await d.get('division');
await run(t, binary('/'), [TypeF32, TypeF32], TypeF32, t.params, cases);
});

Expand All @@ -123,13 +144,6 @@ Accuracy: Derived from x - y * trunc(x/y)
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const makeCase = (lhs: number, rhs: number): Case => {
return makeBinaryToF32IntervalCase(lhs, rhs, remainderInterval);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1]);
});

const cases = await d.get('remainder');
await run(t, binary('%'), [TypeF32, TypeF32], TypeF32, t.params, cases);
});
112 changes: 64 additions & 48 deletions src/webgpu/shader/execution/expression/binary/f32_logical.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { GPUTest } from '../../../../gpu_test.js';
import { anyOf } from '../../../../util/compare.js';
import { bool, f32, Scalar, TypeBool, TypeF32 } from '../../../../util/conversion.js';
import { flushSubnormalScalarF32, kVectorTestValues } from '../../../../util/math.js';
import { makeCaseCache } from '../case_cache.js';
import { allInputSources, Case, run } from '../expression.js';

import { binary } from './binary.js';
Expand Down Expand Up @@ -39,6 +40,63 @@ function makeCase(
return { input: [f32_lhs, f32_rhs], expected: anyOf(...expected) };
}

export const d = makeCaseCache('binary/f32_logical', {
equals: () => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) === (rhs.value as number);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});
},
not_equals: () => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) !== (rhs.value as number);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});
},
less_than: () => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) < (rhs.value as number);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});
},
less_equals: () => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) <= (rhs.value as number);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});
},
greater_than: () => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) > (rhs.value as number);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});
},
greater_equals: () => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) >= (rhs.value as number);
};

return kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});
},
});

g.test('equals')
.specURL('https://www.w3.org/TR/WGSL/#floating-point-evaluation')
.desc(
Expand All @@ -51,14 +109,7 @@ Accuracy: Correct result
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) === (rhs.value as number);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});

const cases = await d.get('equals');
await run(t, binary('=='), [TypeF32, TypeF32], TypeBool, t.params, cases);
});

Expand All @@ -74,14 +125,7 @@ Accuracy: Correct result
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) !== (rhs.value as number);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});

const cases = await d.get('not_equals');
await run(t, binary('!='), [TypeF32, TypeF32], TypeBool, t.params, cases);
});

Expand All @@ -97,14 +141,7 @@ Accuracy: Correct result
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) < (rhs.value as number);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});

const cases = await d.get('less_than');
await run(t, binary('<'), [TypeF32, TypeF32], TypeBool, t.params, cases);
});

Expand All @@ -120,14 +157,7 @@ Accuracy: Correct result
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) <= (rhs.value as number);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});

const cases = await d.get('less_equals');
await run(t, binary('<='), [TypeF32, TypeF32], TypeBool, t.params, cases);
});

Expand All @@ -143,14 +173,7 @@ Accuracy: Correct result
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) > (rhs.value as number);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});

const cases = await d.get('greater_than');
await run(t, binary('>'), [TypeF32, TypeF32], TypeBool, t.params, cases);
});

Expand All @@ -166,13 +189,6 @@ Accuracy: Correct result
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const truthFunc = (lhs: Scalar, rhs: Scalar): boolean => {
return (lhs.value as number) >= (rhs.value as number);
};

const cases = kVectorTestValues[2].map(v => {
return makeCase(v[0], v[1], truthFunc);
});

const cases = await d.get('greater_equals');
await run(t, binary('>='), [TypeF32, TypeF32], TypeBool, t.params, cases);
});
24 changes: 14 additions & 10 deletions src/webgpu/shader/execution/expression/call/builtin/abs.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,25 @@ import { kBit } from '../../../../../util/constants.js';
import { i32Bits, TypeF32, TypeI32, TypeU32, u32Bits } from '../../../../../util/conversion.js';
import { absInterval } from '../../../../../util/f32_interval.js';
import { fullF32Range } from '../../../../../util/math.js';
import { makeCaseCache } from '../../case_cache.js';
import { allInputSources, Case, makeUnaryToF32IntervalCase, run } from '../../expression.js';

import { builtin } from './builtin.js';

export const g = makeTestGroup(GPUTest);

export const d = makeCaseCache('abs', {
f32: () => {
const makeCase = (x: number): Case => {
return makeUnaryToF32IntervalCase(x, absInterval);
};

return [Number.NEGATIVE_INFINITY, ...fullF32Range(), Number.POSITIVE_INFINITY].map(x =>
makeCase(x)
);
},
});

g.test('abstract_int')
.specURL('https://www.w3.org/TR/WGSL/#integer-builtin-functions')
.desc(`abstract int tests`)
Expand Down Expand Up @@ -147,16 +160,7 @@ g.test('f32')
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.fn(async t => {
const makeCase = (x: number): Case => {
return makeUnaryToF32IntervalCase(x, absInterval);
};

const cases: Array<Case> = [
Number.NEGATIVE_INFINITY,
...fullF32Range(),
Number.POSITIVE_INFINITY,
].map(x => makeCase(x));

const cases = await d.get('f32');
await run(t, builtin('abs'), [TypeF32], TypeF32, t.params, cases);
});

Expand Down
Loading

0 comments on commit d7e8d00

Please sign in to comment.