Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FP16] Implement lane access instructions. #6821

Merged
merged 1 commit into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@
("i64x2.splat", "makeUnary(UnaryOp::SplatVecI64x2)"),
("i64x2.extract_lane", "makeSIMDExtract(SIMDExtractOp::ExtractLaneVecI64x2, 2)"),
("i64x2.replace_lane", "makeSIMDReplace(SIMDReplaceOp::ReplaceLaneVecI64x2, 2)"),
("f16x8.splat", "makeUnary(UnaryOp::SplatVecF16x8)"),
("f16x8.extract_lane", "makeSIMDExtract(SIMDExtractOp::ExtractLaneVecF16x8, 8)"),
("f16x8.replace_lane", "makeSIMDReplace(SIMDReplaceOp::ReplaceLaneVecF16x8, 8)"),
("f32x4.splat", "makeUnary(UnaryOp::SplatVecF32x4)"),
("f32x4.extract_lane", "makeSIMDExtract(SIMDExtractOp::ExtractLaneVecF32x4, 4)"),
("f32x4.replace_lane", "makeSIMDReplace(SIMDReplaceOp::ReplaceLaneVecF32x4, 4)"),
Expand Down
23 changes: 23 additions & 0 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,29 @@ switch (buf[0]) {
}
case 'f': {
switch (buf[1]) {
case '1': {
switch (buf[6]) {
case 'e':
if (op == "f16x8.extract_lane"sv) {
CHECK_ERR(makeSIMDExtract(ctx, pos, annotations, SIMDExtractOp::ExtractLaneVecF16x8, 8));
return Ok{};
}
goto parse_error;
case 'r':
if (op == "f16x8.replace_lane"sv) {
CHECK_ERR(makeSIMDReplace(ctx, pos, annotations, SIMDReplaceOp::ReplaceLaneVecF16x8, 8));
return Ok{};
}
goto parse_error;
case 's':
if (op == "f16x8.splat"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::SplatVecF16x8));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
}
case '3': {
switch (buf[3]) {
case '.': {
Expand Down
2 changes: 2 additions & 0 deletions src/ir/child-typer.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ template<typename Subtype> struct ChildTyper : OverriddenVisitor<Subtype> {
case ReplaceLaneVecI64x2:
note(&curr->value, Type::i64);
break;
case ReplaceLaneVecF16x8:
case ReplaceLaneVecF32x4:
note(&curr->value, Type::f32);
break;
Expand Down Expand Up @@ -337,6 +338,7 @@ template<typename Subtype> struct ChildTyper : OverriddenVisitor<Subtype> {
case TruncSatUFloat32ToInt64:
case ReinterpretFloat32:
case PromoteFloat32:
case SplatVecF16x8:
case SplatVecF32x4:
note(&curr->value, Type::f32);
break;
Expand Down
1 change: 1 addition & 0 deletions src/ir/cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
case SplatVecI16x8:
case SplatVecI32x4:
case SplatVecI64x2:
case SplatVecF16x8:
case SplatVecF32x4:
case SplatVecF64x2:
case NotVec128:
Expand Down
4 changes: 4 additions & 0 deletions src/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ class Literal {
std::array<Literal, 8> getLanesUI16x8() const;
std::array<Literal, 4> getLanesI32x4() const;
std::array<Literal, 2> getLanesI64x2() const;
std::array<Literal, 8> getLanesF16x8() const;
std::array<Literal, 4> getLanesF32x4() const;
std::array<Literal, 2> getLanesF64x2() const;

Expand All @@ -463,6 +464,9 @@ class Literal {
Literal splatI64x2() const;
Literal extractLaneI64x2(uint8_t index) const;
Literal replaceLaneI64x2(const Literal& other, uint8_t index) const;
Literal splatF16x8() const;
Literal extractLaneF16x8(uint8_t index) const;
Literal replaceLaneF16x8(const Literal& other, uint8_t index) const;
Literal splatF32x4() const;
Literal extractLaneF32x4(uint8_t index) const;
Literal replaceLaneF32x4(const Literal& other, uint8_t index) const;
Expand Down
9 changes: 9 additions & 0 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,9 @@ struct PrintExpressionContents
case ExtractLaneVecI64x2:
o << "i64x2.extract_lane";
break;
case ExtractLaneVecF16x8:
o << "f16x8.extract_lane";
break;
case ExtractLaneVecF32x4:
o << "f32x4.extract_lane";
break;
Expand All @@ -728,6 +731,9 @@ struct PrintExpressionContents
case ReplaceLaneVecI64x2:
o << "i64x2.replace_lane";
break;
case ReplaceLaneVecF16x8:
o << "f16x8.replace_lane";
break;
case ReplaceLaneVecF32x4:
o << "f32x4.replace_lane";
break;
Expand Down Expand Up @@ -1137,6 +1143,9 @@ struct PrintExpressionContents
case SplatVecI64x2:
o << "i64x2.splat";
break;
case SplatVecF16x8:
o << "f16x8.splat";
break;
case SplatVecF32x4:
o << "f32x4.splat";
break;
Expand Down
1 change: 1 addition & 0 deletions src/tools/fuzzing/fuzzing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3579,6 +3579,7 @@ Expression* TranslateToFuzzReader::makeSIMDExtract(Type type) {
break;
case ExtractLaneSVecI16x8:
case ExtractLaneUVecI16x8:
case ExtractLaneVecF16x8:
index = upTo(8);
break;
case ExtractLaneVecI32x4:
Expand Down
3 changes: 3 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,9 @@ enum ASTNodes {
// half precision opcodes
F32_F16LoadMem = 0x30,
F32_F16StoreMem = 0x31,
F16x8Splat = 0x120,
F16x8ExtractLane = 0x121,
F16x8ReplaceLane = 0x122,

// bulk memory opcodes

Expand Down
6 changes: 6 additions & 0 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
return value.splatI32x4();
case SplatVecI64x2:
return value.splatI64x2();
case SplatVecF16x8:
return value.splatF16x8();
case SplatVecF32x4:
return value.splatF32x4();
case SplatVecF64x2:
Expand Down Expand Up @@ -1070,6 +1072,8 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
return vec.extractLaneI32x4(curr->index);
case ExtractLaneVecI64x2:
return vec.extractLaneI64x2(curr->index);
case ExtractLaneVecF16x8:
return vec.extractLaneF16x8(curr->index);
case ExtractLaneVecF32x4:
return vec.extractLaneF32x4(curr->index);
case ExtractLaneVecF64x2:
Expand Down Expand Up @@ -1098,6 +1102,8 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
return vec.replaceLaneI32x4(value, curr->index);
case ReplaceLaneVecI64x2:
return vec.replaceLaneI64x2(value, curr->index);
case ReplaceLaneVecF16x8:
return vec.replaceLaneF16x8(value, curr->index);
case ReplaceLaneVecF32x4:
return vec.replaceLaneF32x4(value, curr->index);
case ReplaceLaneVecF64x2:
Expand Down
5 changes: 5 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ enum UnaryOp {
RelaxedTruncZeroSVecF64x2ToVecI32x4,
RelaxedTruncZeroUVecF64x2ToVecI32x4,

// Half precision SIMD
SplatVecF16x8,

InvalidUnary
};

Expand Down Expand Up @@ -490,6 +493,7 @@ enum SIMDExtractOp {
ExtractLaneUVecI16x8,
ExtractLaneVecI32x4,
ExtractLaneVecI64x2,
ExtractLaneVecF16x8,
ExtractLaneVecF32x4,
ExtractLaneVecF64x2
};
Expand All @@ -499,6 +503,7 @@ enum SIMDReplaceOp {
ReplaceLaneVecI16x8,
ReplaceLaneVecI32x4,
ReplaceLaneVecI64x2,
ReplaceLaneVecF16x8,
ReplaceLaneVecF32x4,
ReplaceLaneVecF64x2,
};
Expand Down
19 changes: 19 additions & 0 deletions src/wasm/literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cmath>

#include "emscripten-optimizer/simple_ast.h"
#include "fp16.h"
#include "ir/bits.h"
#include "pretty_printing.h"
#include "support/bits.h"
Expand Down Expand Up @@ -1729,6 +1730,13 @@ LaneArray<4> Literal::getLanesI32x4() const {
LaneArray<2> Literal::getLanesI64x2() const {
return getLanes<int64_t, 2>(*this);
}
LaneArray<8> Literal::getLanesF16x8() const {
auto lanes = getLanesUI16x8();
for (size_t i = 0; i < lanes.size(); ++i) {
lanes[i] = Literal(fp16_ieee_to_fp32_value(lanes[i].geti32()));
}
return lanes;
}
LaneArray<4> Literal::getLanesF32x4() const {
auto lanes = getLanesI32x4();
for (size_t i = 0; i < lanes.size(); ++i) {
Expand Down Expand Up @@ -1766,6 +1774,10 @@ Literal Literal::splatI8x16() const { return splat<Type::i32, 16>(*this); }
Literal Literal::splatI16x8() const { return splat<Type::i32, 8>(*this); }
Literal Literal::splatI32x4() const { return splat<Type::i32, 4>(*this); }
Literal Literal::splatI64x2() const { return splat<Type::i64, 2>(*this); }
Literal Literal::splatF16x8() const {
uint16_t f16 = fp16_ieee_from_fp32_value(getf32());
return splat<Type::i32, 8>(Literal(f16));
}
Literal Literal::splatF32x4() const { return splat<Type::f32, 4>(*this); }
Literal Literal::splatF64x2() const { return splat<Type::f64, 2>(*this); }

Expand All @@ -1787,6 +1799,9 @@ Literal Literal::extractLaneI32x4(uint8_t index) const {
Literal Literal::extractLaneI64x2(uint8_t index) const {
return getLanesI64x2().at(index);
}
Literal Literal::extractLaneF16x8(uint8_t index) const {
return getLanesF16x8().at(index);
}
Literal Literal::extractLaneF32x4(uint8_t index) const {
return getLanesF32x4().at(index);
}
Expand Down Expand Up @@ -1815,6 +1830,10 @@ Literal Literal::replaceLaneI32x4(const Literal& other, uint8_t index) const {
Literal Literal::replaceLaneI64x2(const Literal& other, uint8_t index) const {
return replace<2, &Literal::getLanesI64x2>(*this, other, index);
}
Literal Literal::replaceLaneF16x8(const Literal& other, uint8_t index) const {
return replace<8, &Literal::getLanesF16x8>(
*this, Literal(fp16_ieee_from_fp32_value(other.getf32())), index);
}
Literal Literal::replaceLaneF32x4(const Literal& other, uint8_t index) const {
return replace<4, &Literal::getLanesF32x4>(*this, other, index);
}
Expand Down
14 changes: 14 additions & 0 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6239,6 +6239,10 @@ bool WasmBinaryReader::maybeVisitSIMDUnary(Expression*& out, uint32_t code) {
curr = allocator.alloc<Unary>();
curr->op = SplatVecI64x2;
break;
case BinaryConsts::F16x8Splat:
curr = allocator.alloc<Unary>();
curr->op = SplatVecF16x8;
break;
case BinaryConsts::F32x4Splat:
curr = allocator.alloc<Unary>();
curr->op = SplatVecF32x4;
Expand Down Expand Up @@ -6569,6 +6573,11 @@ bool WasmBinaryReader::maybeVisitSIMDExtract(Expression*& out, uint32_t code) {
curr->op = ExtractLaneVecI64x2;
curr->index = getLaneIndex(2);
break;
case BinaryConsts::F16x8ExtractLane:
curr = allocator.alloc<SIMDExtract>();
curr->op = ExtractLaneVecF16x8;
curr->index = getLaneIndex(8);
break;
case BinaryConsts::F32x4ExtractLane:
curr = allocator.alloc<SIMDExtract>();
curr->op = ExtractLaneVecF32x4;
Expand Down Expand Up @@ -6611,6 +6620,11 @@ bool WasmBinaryReader::maybeVisitSIMDReplace(Expression*& out, uint32_t code) {
curr->op = ReplaceLaneVecI64x2;
curr->index = getLaneIndex(2);
break;
case BinaryConsts::F16x8ReplaceLane:
curr = allocator.alloc<SIMDReplace>();
curr->op = ReplaceLaneVecF16x8;
curr->index = getLaneIndex(8);
break;
case BinaryConsts::F32x4ReplaceLane:
curr = allocator.alloc<SIMDReplace>();
curr->op = ReplaceLaneVecF32x4;
Expand Down
9 changes: 9 additions & 0 deletions src/wasm/wasm-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,9 @@ void BinaryInstWriter::visitSIMDExtract(SIMDExtract* curr) {
case ExtractLaneVecI64x2:
o << U32LEB(BinaryConsts::I64x2ExtractLane);
break;
case ExtractLaneVecF16x8:
o << U32LEB(BinaryConsts::F16x8ExtractLane);
break;
case ExtractLaneVecF32x4:
o << U32LEB(BinaryConsts::F32x4ExtractLane);
break;
Expand All @@ -615,6 +618,9 @@ void BinaryInstWriter::visitSIMDReplace(SIMDReplace* curr) {
case ReplaceLaneVecI64x2:
o << U32LEB(BinaryConsts::I64x2ReplaceLane);
break;
case ReplaceLaneVecF16x8:
o << U32LEB(BinaryConsts::F16x8ReplaceLane);
break;
case ReplaceLaneVecF32x4:
o << U32LEB(BinaryConsts::F32x4ReplaceLane);
break;
Expand Down Expand Up @@ -1050,6 +1056,9 @@ void BinaryInstWriter::visitUnary(Unary* curr) {
case SplatVecI64x2:
o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::I64x2Splat);
break;
case SplatVecF16x8:
o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Splat);
break;
case SplatVecF32x4:
o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F32x4Splat);
break;
Expand Down
9 changes: 9 additions & 0 deletions src/wasm/wasm-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,10 @@ void FunctionValidator::visitSIMDExtract(SIMDExtract* curr) {
lane_t = Type::i64;
lanes = 2;
break;
case ExtractLaneVecF16x8:
lane_t = Type::f32;
lanes = 8;
break;
case ExtractLaneVecF32x4:
lane_t = Type::f32;
lanes = 4;
Expand Down Expand Up @@ -1318,6 +1322,10 @@ void FunctionValidator::visitSIMDReplace(SIMDReplace* curr) {
lane_t = Type::i64;
lanes = 2;
break;
case ReplaceLaneVecF16x8:
lane_t = Type::f32;
lanes = 8;
break;
case ReplaceLaneVecF32x4:
lane_t = Type::f32;
lanes = 4;
Expand Down Expand Up @@ -2036,6 +2044,7 @@ void FunctionValidator::visitUnary(Unary* curr) {
shouldBeEqual(
curr->value->type, Type(Type::i64), curr, "expected i64 splat value");
break;
case SplatVecF16x8:
case SplatVecF32x4:
shouldBeEqual(
curr->type, Type(Type::v128), curr, "expected splat to have v128 type");
Expand Down
2 changes: 2 additions & 0 deletions src/wasm/wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ void SIMDExtract::finalize() {
case ExtractLaneVecI64x2:
type = Type::i64;
break;
case ExtractLaneVecF16x8:
case ExtractLaneVecF32x4:
type = Type::f32;
break;
Expand Down Expand Up @@ -636,6 +637,7 @@ void Unary::finalize() {
case SplatVecI16x8:
case SplatVecI32x4:
case SplatVecI64x2:
case SplatVecF16x8:
case SplatVecF32x4:
case SplatVecF64x2:
case NotVec128:
Expand Down
Loading
Loading