diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e05fea09e8..90591bbb0af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,11 @@ full changeset diff at the end of each section. Current Trunk ------------- +- Reference type support is added. Supported instructions are `ref.null`, + `ref.is_null`, `ref.func`, and typed `select`. Table instructions are not + supported yet. For typed `select`, C/JS API can take an additional 'type' + parameter. + v90 --- diff --git a/check.py b/check.py index 377d41918fd..f4d6bed7e64 100755 --- a/check.py +++ b/check.py @@ -153,8 +153,10 @@ def check(): shared.fail_if_not_identical_to_file(actual, f) - shared.binary_format_check(t, wasm_as_args=['-g']) # test with debuginfo - shared.binary_format_check(t, wasm_as_args=[], binary_suffix='.fromBinary.noDebugInfo') # test without debuginfo + # FIXME Remove this condition after nullref is implemented in V8 + if 'reference-types.wast' not in t: + shared.binary_format_check(t, wasm_as_args=['-g']) # test with debuginfo + shared.binary_format_check(t, wasm_as_args=[], binary_suffix='.fromBinary.noDebugInfo') # test without debuginfo shared.minify_check(t) @@ -271,9 +273,9 @@ def run_wasm_reduce_tests(): before = os.stat('a.wasm').st_size support.run_command(shared.WASM_REDUCE + ['a.wasm', '--command=%s b.wasm --fuzz-exec -all' % shared.WASM_OPT[0], '-t', 'b.wasm', '-w', 'c.wasm']) after = os.stat('c.wasm').st_size - # 0.65 is a custom threshold to check if we have shrunk the output - # sufficiently - assert after < 0.7 * before, [before, after] + # This number is a custom threshold to check if we have shrunk the + # output sufficiently + assert after < 0.75 * before, [before, after] def run_spec_tests(): @@ -323,7 +325,10 @@ def check_expected(actual, expected): # some wast files cannot be split: # * comments.wast: contains characters that are not valid utf-8, # so our string splitting code fails there - if os.path.basename(wast) not in ['comments.wast']: + + # FIXME Remove reference type tests from this list after nullref is + # implemented in V8 + if os.path.basename(wast) not in ['comments.wast', 'ref_null.wast', 'ref_is_null.wast', 'ref_func.wast', 'old_select.wast']: split_num = 0 actual = '' for module, asserts in support.split_wast(wast): diff --git a/scripts/fuzz_opt.py b/scripts/fuzz_opt.py index a0b763aabac..b77d8e9b4d1 100644 --- a/scripts/fuzz_opt.py +++ b/scripts/fuzz_opt.py @@ -228,7 +228,7 @@ def compare_vs(self, before, after): break def can_run_on_feature_opts(self, feature_opts): - return all([x in feature_opts for x in ['--disable-simd']]) + return all([x in feature_opts for x in ['--disable-simd', '--disable-reference-types', '--disable-exception-handling']]) # Fuzz the interpreter with --fuzz-exec. This tests everything in a single command (no @@ -294,7 +294,7 @@ def run(self, wasm): return out def can_run_on_feature_opts(self, feature_opts): - return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd', '--disable-threads', '--disable-bulk-memory', '--disable-nontrapping-float-to-int', '--disable-tail-call', '--disable-sign-ext']]) + return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd', '--disable-threads', '--disable-bulk-memory', '--disable-nontrapping-float-to-int', '--disable-tail-call', '--disable-sign-ext', '--disable-reference-types']]) class Asyncify(TestCaseHandler): @@ -339,7 +339,7 @@ def do_asyncify(wasm): compare(before, after_asyncify, 'Asyncify (before/after_asyncify)') def can_run_on_feature_opts(self, feature_opts): - return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd', '--disable-tail-call']]) + return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd', '--disable-tail-call', '--disable-reference-types']]) # The global list of all test case handlers diff --git a/scripts/gen-s-parser.py b/scripts/gen-s-parser.py index 79b1a60f8ad..4a900ac0e51 100755 --- a/scripts/gen-s-parser.py +++ b/scripts/gen-s-parser.py @@ -49,7 +49,9 @@ ("f32.pop", "makePop(f32)"), ("f64.pop", "makePop(f64)"), ("v128.pop", "makePop(v128)"), + ("funcref.pop", "makePop(funcref)"), ("anyref.pop", "makePop(anyref)"), + ("nullref.pop", "makePop(nullref)"), ("exnref.pop", "makePop(exnref)"), ("i32.load", "makeLoad(s, i32, /*isAtomic=*/false)"), ("i64.load", "makeLoad(s, i64, /*isAtomic=*/false)"), @@ -469,6 +471,11 @@ ("i32x4.widen_low_i16x8_u", "makeUnary(s, UnaryOp::WidenLowUVecI16x8ToVecI32x4)"), ("i32x4.widen_high_i16x8_u", "makeUnary(s, UnaryOp::WidenHighUVecI16x8ToVecI32x4)"), ("v8x16.swizzle", "makeBinary(s, BinaryOp::SwizzleVec8x16)"), + # reference types instructions + # TODO Add table instructions + ("ref.null", "makeRefNull(s)"), + ("ref.is_null", "makeRefIsNull(s)"), + ("ref.func", "makeRefFunc(s)"), # exception handling instructions ("try", "makeTry(s)"), ("throw", "makeThrow(s)"), diff --git a/src/asmjs/asm_v_wasm.cpp b/src/asmjs/asm_v_wasm.cpp index 3720ca07902..5959db43e7c 100644 --- a/src/asmjs/asm_v_wasm.cpp +++ b/src/asmjs/asm_v_wasm.cpp @@ -53,10 +53,11 @@ AsmType wasmToAsmType(Type type) { return ASM_INT64; case v128: assert(false && "v128 not implemented yet"); + case funcref: case anyref: - assert(false && "anyref is not supported by asm2wasm"); + case nullref: case exnref: - assert(false && "exnref is not supported by asm2wasm"); + assert(false && "reference types are not supported by asm2wasm"); case none: return ASM_NONE; case unreachable: @@ -77,10 +78,14 @@ char getSig(Type type) { return 'd'; case v128: return 'V'; + case funcref: + return 'F'; case anyref: - return 'a'; + return 'A'; + case nullref: + return 'N'; case exnref: - return 'e'; + return 'E'; case none: return 'v'; case unreachable: diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 82cbc4c1f01..826193e0690 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -64,13 +64,16 @@ BinaryenLiteral toBinaryenLiteral(Literal x) { case Type::f64: ret.i64 = x.reinterpreti64(); break; - case Type::v128: { + case Type::v128: memcpy(&ret.v128, x.getv128Ptr(), 16); break; - } - - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + ret.func = x.getFunc().c_str(); + break; + case Type::nullref: + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -90,8 +93,12 @@ Literal fromBinaryenLiteral(BinaryenLiteral x) { return Literal(x.i64).castToF64(); case Type::v128: return Literal(x.v128); - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + return Literal::makeFuncref(x.func); + case Type::nullref: + return Literal::makeNullref(); + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -209,8 +216,14 @@ void printArg(std::ostream& setup, std::ostream& out, BinaryenLiteral arg) { out << "BinaryenLiteralVec128(" << array << ")"; break; } - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + out << "BinaryenLiteralFuncref(" << arg.func << ")"; + break; + case Type::nullref: + out << "BinaryenLiteralNullref()"; + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -265,7 +278,9 @@ BinaryenType BinaryenTypeInt64(void) { return i64; } BinaryenType BinaryenTypeFloat32(void) { return f32; } BinaryenType BinaryenTypeFloat64(void) { return f64; } BinaryenType BinaryenTypeVec128(void) { return v128; } +BinaryenType BinaryenTypeFuncref(void) { return funcref; } BinaryenType BinaryenTypeAnyref(void) { return anyref; } +BinaryenType BinaryenTypeNullref(void) { return nullref; } BinaryenType BinaryenTypeExnref(void) { return exnref; } BinaryenType BinaryenTypeUnreachable(void) { return unreachable; } BinaryenType BinaryenTypeAuto(void) { return uint32_t(-1); } @@ -397,6 +412,15 @@ BinaryenExpressionId BinaryenMemoryCopyId(void) { BinaryenExpressionId BinaryenMemoryFillId(void) { return Expression::Id::MemoryFillId; } +BinaryenExpressionId BinaryenRefNullId(void) { + return Expression::Id::RefNullId; +} +BinaryenExpressionId BinaryenRefIsNullId(void) { + return Expression::Id::RefIsNullId; +} +BinaryenExpressionId BinaryenRefFuncId(void) { + return Expression::Id::RefFuncId; +} BinaryenExpressionId BinaryenTryId(void) { return Expression::Id::TryId; } BinaryenExpressionId BinaryenThrowId(void) { return Expression::Id::ThrowId; } BinaryenExpressionId BinaryenRethrowId(void) { @@ -1330,17 +1354,22 @@ BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module, BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, - BinaryenExpressionRef ifFalse) { + BinaryenExpressionRef ifFalse, + BinaryenType type) { auto* ret = ((Module*)module)->allocator.alloc(); + ret->condition = condition; + ret->ifTrue = ifTrue; + ret->ifFalse = ifFalse; + ret->finalize(type); + return ret; + } Return* makeReturn(Expression* value = nullptr) { auto* ret = allocator.alloc(); ret->value = value; @@ -502,6 +527,23 @@ class Builder { ret->finalize(); return ret; } + RefNull* makeRefNull() { + auto* ret = allocator.alloc(); + ret->finalize(); + return ret; + } + RefIsNull* makeRefIsNull(Expression* value) { + auto* ret = allocator.alloc(); + ret->value = value; + ret->finalize(); + return ret; + } + RefFunc* makeRefFunc(Name func) { + auto* ret = allocator.alloc(); + ret->func = func; + ret->finalize(); + return ret; + } Try* makeTry(Expression* body, Expression* catchBody) { auto* ret = allocator.alloc(); ret->body = body; @@ -569,6 +611,21 @@ class Builder { return ret; } + Expression* makeConstExpression(Literal value) { + switch (value.type) { + case Type::nullref: + return makeRefNull(); + case Type::funcref: + if (value.getFunc()[0] != 0) { + return makeRefFunc(value.getFunc()); + } + return makeRefNull(); + default: + assert(value.type.isNumber()); + return makeConst(value); + } + } + // Additional utility functions for building on top of nodes // Convenient to have these on Builder, as it has allocation built in @@ -663,6 +720,13 @@ class Builder { return block; } + Block* makeSequence(Expression* left, Expression* right, Type type) { + auto* block = makeBlock(left); + block->list.push_back(right); + block->finalize(type); + return block; + } + // Grab a slice out of a block, replacing it with nops, and returning // either another block with the contents (if more than 1) or a single // expression @@ -728,16 +792,15 @@ class Builder { value = Literal(bytes.data()); break; } + case funcref: case anyref: - // TODO Implement and return nullref - assert(false && "anyref not implemented yet"); + case nullref: case exnref: - // TODO Implement and return nullref - assert(false && "exnref not implemented yet"); + return ExpressionManipulator::refNull(curr); case none: return ExpressionManipulator::nop(curr); case unreachable: - return ExpressionManipulator::convert(curr); + return ExpressionManipulator::unreachable(curr); } return makeConst(value); } diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 571f0d1a538..f37a6edd68a 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -143,13 +143,13 @@ class ExpressionRunner : public OverriddenVisitor { if (!ret.breaking() && (curr->type.isConcrete() || ret.value.type.isConcrete())) { #if 1 // def WASM_INTERPRETER_DEBUG - if (ret.value.type != curr->type) { + if (!Type::isSubType(ret.value.type, curr->type)) { std::cerr << "expected " << curr->type << ", seeing " << ret.value.type << " from\n" << curr << '\n'; } #endif - assert(ret.value.type == curr->type); + assert(Type::isSubType(ret.value.type, curr->type)); } depth--; return ret; @@ -1095,7 +1095,7 @@ class ExpressionRunner : public OverriddenVisitor { return Literal(uint64_t(val)); } } - Flow visitAtomicFence(AtomicFence*) { + Flow visitAtomicFence(AtomicFence* curr) { // Wasm currently supports only sequentially consistent atomics, in which // case atomic_fence can be lowered to nothing. NOTE_ENTER("AtomicFence"); @@ -1123,6 +1123,26 @@ class ExpressionRunner : public OverriddenVisitor { Flow visitSIMDLoadExtend(SIMDLoad*) { WASM_UNREACHABLE("unimp"); } Flow visitPush(Push*) { WASM_UNREACHABLE("unimp"); } Flow visitPop(Pop*) { WASM_UNREACHABLE("unimp"); } + Flow visitRefNull(RefNull* curr) { + NOTE_ENTER("RefNull"); + return Literal::makeNullref(); + } + Flow visitRefIsNull(RefIsNull* curr) { + NOTE_ENTER("RefIsNull"); + Flow flow = visit(curr->value); + if (flow.breaking()) { + return flow; + } + Literal value = flow.value; + NOTE_EVAL1(value); + return Literal(value.type == nullref); + } + Flow visitRefFunc(RefFunc* curr) { + NOTE_ENTER("RefFunc"); + NOTE_NAME(curr->func); + return Literal::makeFuncref(curr->func); + } + // TODO Implement EH instructions Flow visitTry(Try*) { WASM_UNREACHABLE("unimp"); } Flow visitThrow(Throw*) { WASM_UNREACHABLE("unimp"); } Flow visitRethrow(Rethrow*) { WASM_UNREACHABLE("unimp"); } @@ -1217,8 +1237,10 @@ template class ModuleInstanceBase { return Literal(load64u(addr)).castToF64(); case v128: return Literal(load128(addr).data()); - case anyref: // anyref cannot be loaded from memory - case exnref: // exnref cannot be loaded from memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1272,8 +1294,10 @@ template class ModuleInstanceBase { case v128: store128(addr, value.getv128()); break; - case anyref: // anyref cannot be stored from memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1464,7 +1488,7 @@ template class ModuleInstanceBase { for (size_t i = 0; i < function->getNumLocals(); i++) { if (i < arguments.size()) { assert(i < params.size()); - if (params[i] != arguments[i].type) { + if (!Type::isSubType(arguments[i].type, params[i])) { std::cerr << "Function `" << function->name << "` expects type " << params[i] << " for parameter " << i << ", got " << arguments[i].type << "." << std::endl; @@ -1473,7 +1497,7 @@ template class ModuleInstanceBase { locals[i] = arguments[i]; } else { assert(function->isVar(i)); - locals[i].type = function->getLocalType(i); + locals[i] = Literal::makeZero(function->getLocalType(i)); } } } @@ -1580,7 +1604,8 @@ template class ModuleInstanceBase { } NOTE_EVAL1(index); NOTE_EVAL1(flow.value); - assert(curr->isTee() ? flow.value.type == curr->type : true); + assert(curr->isTee() ? Type::isSubType(flow.value.type, curr->type) + : true); scope.locals[index] = flow.value; return curr->isTee() ? flow : Flow(); } @@ -2067,7 +2092,7 @@ template class ModuleInstanceBase { // cannot still be breaking, it means we missed our stop assert(!flow.breaking() || flow.breakTo == RETURN_FLOW); Literal ret = flow.value; - if (function->sig.results != ret.type) { + if (!Type::isSubType(ret.type, function->sig.results)) { std::cerr << "calling " << function->name << " resulted in " << ret << " but the function type is " << function->sig.results << '\n'; diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index d7324d7564e..8cdcb88f416 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -225,6 +225,9 @@ class SExpressionWasmBuilder { Expression* makeBreak(Element& s); Expression* makeBreakTable(Element& s); Expression* makeReturn(Element& s); + Expression* makeRefNull(Element& s); + Expression* makeRefIsNull(Element& s); + Expression* makeRefFunc(Element& s); Expression* makeTry(Element& s); Expression* makeCatch(Element& s, Type type); Expression* makeThrow(Element& s); diff --git a/src/wasm-stack.h b/src/wasm-stack.h index fbd28b0d5d3..91c0c5383b9 100644 --- a/src/wasm-stack.h +++ b/src/wasm-stack.h @@ -128,6 +128,9 @@ class BinaryInstWriter : public OverriddenVisitor { void visitSelect(Select* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -207,6 +210,9 @@ class BinaryenIRWriter : public OverriddenVisitor> { void visitSelect(Select* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -698,6 +704,30 @@ void BinaryenIRWriter::visitHost(Host* curr) { emit(curr); } +template +void BinaryenIRWriter::visitRefNull(RefNull* curr) { + emit(curr); +} + +template +void BinaryenIRWriter::visitRefIsNull(RefIsNull* curr) { + visit(curr->value); + if (curr->type == Type::unreachable) { + emitUnreachable(); + return; + } + emit(curr); +} + +template +void BinaryenIRWriter::visitRefFunc(RefFunc* curr) { + if (curr->type == Type::unreachable) { + emitUnreachable(); + return; + } + emit(curr); +} + template void BinaryenIRWriter::visitTry(Try* curr) { emit(curr); visitPossibleBlockContents(curr->body); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 9c6e78360dc..c9290cbabe7 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -72,6 +72,9 @@ template struct Visitor { ReturnType visitDrop(Drop* curr) { return ReturnType(); } ReturnType visitReturn(Return* curr) { return ReturnType(); } ReturnType visitHost(Host* curr) { return ReturnType(); } + ReturnType visitRefNull(RefNull* curr) { return ReturnType(); } + ReturnType visitRefIsNull(RefIsNull* curr) { return ReturnType(); } + ReturnType visitRefFunc(RefFunc* curr) { return ReturnType(); } ReturnType visitTry(Try* curr) { return ReturnType(); } ReturnType visitThrow(Throw* curr) { return ReturnType(); } ReturnType visitRethrow(Rethrow* curr) { return ReturnType(); } @@ -167,6 +170,12 @@ template struct Visitor { DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: @@ -241,6 +250,9 @@ struct OverriddenVisitor { UNIMPLEMENTED(Drop); UNIMPLEMENTED(Return); UNIMPLEMENTED(Host); + UNIMPLEMENTED(RefNull); + UNIMPLEMENTED(RefIsNull); + UNIMPLEMENTED(RefFunc); UNIMPLEMENTED(Try); UNIMPLEMENTED(Throw); UNIMPLEMENTED(Rethrow); @@ -337,6 +349,12 @@ struct OverriddenVisitor { DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: @@ -476,6 +494,15 @@ struct UnifiedExpressionVisitor : public Visitor { ReturnType visitHost(Host* curr) { return static_cast(this)->visitExpression(curr); } + ReturnType visitRefNull(RefNull* curr) { + return static_cast(this)->visitExpression(curr); + } + ReturnType visitRefIsNull(RefIsNull* curr) { + return static_cast(this)->visitExpression(curr); + } + ReturnType visitRefFunc(RefFunc* curr) { + return static_cast(this)->visitExpression(curr); + } ReturnType visitTry(Try* curr) { return static_cast(this)->visitExpression(curr); } @@ -778,6 +805,15 @@ struct Walker : public VisitorType { static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast()); } + static void doVisitRefNull(SubType* self, Expression** currp) { + self->visitRefNull((*currp)->cast()); + } + static void doVisitRefIsNull(SubType* self, Expression** currp) { + self->visitRefIsNull((*currp)->cast()); + } + static void doVisitRefFunc(SubType* self, Expression** currp) { + self->visitRefFunc((*currp)->cast()); + } static void doVisitTry(SubType* self, Expression** currp) { self->visitTry((*currp)->cast()); } @@ -1036,6 +1072,19 @@ struct PostWalker : public Walker { } break; } + case Expression::Id::RefNullId: { + self->pushTask(SubType::doVisitRefNull, currp); + break; + } + case Expression::Id::RefIsNullId: { + self->pushTask(SubType::doVisitRefIsNull, currp); + self->pushTask(SubType::scan, &curr->cast()->value); + break; + } + case Expression::Id::RefFuncId: { + self->pushTask(SubType::doVisitRefFunc, currp); + break; + } case Expression::Id::TryId: { self->pushTask(SubType::doVisitTry, currp); self->pushTask(SubType::scan, &curr->cast()->catchBody); @@ -1099,7 +1148,7 @@ struct ControlFlowWalker : public PostWalker { Expression* findBreakTarget(Name name) { assert(!controlFlowStack.empty()); Index i = controlFlowStack.size() - 1; - while (1) { + while (true) { auto* curr = controlFlowStack[i]; if (Block* block = curr->template dynCast()) { if (name == block->name) { @@ -1111,7 +1160,7 @@ struct ControlFlowWalker : public PostWalker { } } else { // an if, ignorable - assert(curr->template is()); + assert(curr->template is() || curr->template is()); } if (i == 0) { return nullptr; @@ -1169,7 +1218,7 @@ struct ExpressionStackWalker : public PostWalker { Expression* findBreakTarget(Name name) { assert(!expressionStack.empty()); Index i = expressionStack.size() - 1; - while (1) { + while (true) { auto* curr = expressionStack[i]; if (Block* block = curr->template dynCast()) { if (name == block->name) { @@ -1179,8 +1228,6 @@ struct ExpressionStackWalker : public PostWalker { if (name == loop->name) { return curr; } - } else { - WASM_UNREACHABLE("unexpected expression type"); } if (i == 0) { return nullptr; diff --git a/src/wasm-type.h b/src/wasm-type.h index 53ef39ef833..668ac3e4d64 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -36,7 +36,9 @@ class Type { f32, f64, v128, + funcref, anyref, + nullref, exnref, _last_value_type, }; @@ -64,7 +66,8 @@ class Type { bool isInteger() const { return id == i32 || id == i64; } bool isFloat() const { return id == f32 || id == f64; } bool isVector() const { return id == v128; }; - bool isRef() const { return id == anyref || id == exnref; } + bool isNumber() const { return id >= i32 && id <= v128; } + bool isRef() const { return id >= funcref && id <= exnref; } // (In)equality must be defined for both Type and ValueType because it is // otherwise ambiguous whether to convert both this and other to int or @@ -94,6 +97,23 @@ class Type { // type. static Type get(unsigned byteSize, bool float_); + // Returns true if left is a subtype of right. Subtype includes itself. + static bool isSubType(Type left, Type right); + + // Computes the least upper bound from the type lattice. + // If one of the type is unreachable, the other type becomes the result. If + // the common supertype does not exist, returns none, a poison value. + static Type getLeastUpperBound(Type a, Type b); + + // Computes the least upper bound for all types in the given list. + template static Type mergeTypes(const T& types) { + Type type = Type::unreachable; + for (auto other : types) { + type = Type::getLeastUpperBound(type, other); + } + return type; + } + std::string toString() const; }; @@ -134,7 +154,9 @@ constexpr Type i64 = Type::i64; constexpr Type f32 = Type::f32; constexpr Type f64 = Type::f64; constexpr Type v128 = Type::v128; +constexpr Type funcref = Type::funcref; constexpr Type anyref = Type::anyref; +constexpr Type nullref = Type::nullref; constexpr Type exnref = Type::exnref; constexpr Type unreachable = Type::unreachable; diff --git a/src/wasm.h b/src/wasm.h index 48adf103b06..c4dbd2f3f51 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -531,6 +531,9 @@ class Expression { MemoryFillId, PushId, PopId, + RefNullId, + RefIsNullId, + RefFuncId, TryId, ThrowId, RethrowId, @@ -569,6 +572,8 @@ class Expression { const char* getExpressionName(Expression* curr); +Literal getLiteralFromConstExpression(Expression* curr); + typedef ArenaVector ExpressionList; template class SpecificExpression : public Expression { @@ -1008,6 +1013,7 @@ class Select : public SpecificExpression { Expression* condition; void finalize(); + void finalize(Type type_); }; class Drop : public SpecificExpression { @@ -1070,6 +1076,32 @@ class Pop : public SpecificExpression { Pop(MixedArena& allocator) {} }; +class RefNull : public SpecificExpression { +public: + RefNull() = default; + RefNull(MixedArena& allocator) {} + + void finalize(); +}; + +class RefIsNull : public SpecificExpression { +public: + RefIsNull(MixedArena& allocator) {} + + Expression* value; + + void finalize(); +}; + +class RefFunc : public SpecificExpression { +public: + RefFunc(MixedArena& allocator) {} + + Name func; + + void finalize(); +}; + class Try : public SpecificExpression { public: Try(MixedArena& allocator) {} diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 82a150257c1..4f66b36e352 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -137,8 +137,11 @@ void Literal::getBits(uint8_t (&buf)[16]) const { case Type::v128: memcpy(buf, &v128, sizeof(v128)); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + case Type::nullref: + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("invalid type"); @@ -146,10 +149,20 @@ void Literal::getBits(uint8_t (&buf)[16]) const { } bool Literal::operator==(const Literal& other) const { + if (type.isRef() && other.type.isRef()) { + if (type == Type::nullref && other.type == Type::nullref) { + return true; + } + if (type == Type::funcref && other.type == Type::funcref && + func == other.func) { + return true; + } + return false; + } if (type != other.type) { return false; } - if (type == none) { + if (type == Type::none) { return true; } uint8_t bits[16], other_bits[16]; @@ -273,8 +286,14 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { o << "i32x4 "; literal.printVec128(o, literal.getv128()); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + o << "funcref(" << literal.getFunc() << ")"; + break; + case Type::nullref: + o << "nullref"; + break; + case Type::anyref: + case Type::exnref: case Type::unreachable: WASM_UNREACHABLE("invalid type"); } @@ -477,7 +496,9 @@ Literal Literal::eqz() const { case Type::f64: return eq(Literal(double(0))); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -497,7 +518,9 @@ Literal Literal::neg() const { case Type::f64: return Literal(int64_t(i64 ^ 0x8000000000000000ULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -517,7 +540,9 @@ Literal Literal::abs() const { case Type::f64: return Literal(int64_t(i64 & 0x7fffffffffffffffULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -620,7 +645,9 @@ Literal Literal::add(const Literal& other) const { case Type::f64: return Literal(getf64() + other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -640,7 +667,9 @@ Literal Literal::sub(const Literal& other) const { case Type::f64: return Literal(getf64() - other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -731,7 +760,9 @@ Literal Literal::mul(const Literal& other) const { case Type::f64: return Literal(getf64() * other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -967,7 +998,9 @@ Literal Literal::eq(const Literal& other) const { case Type::f64: return Literal(getf64() == other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -987,7 +1020,9 @@ Literal Literal::ne(const Literal& other) const { case Type::f64: return Literal(getf64() != other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 82eb51d7e39..ba5a8d3ddb0 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -262,7 +262,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one table\n"); writeImportHeader(&wasm->table); o << U32LEB(int32_t(ExternalKind::Table)); - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -463,7 +463,7 @@ void WasmBinaryWriter::writeFunctionTableDeclaration() { BYN_TRACE("== writeFunctionTableDeclaration\n"); auto start = startSection(BinaryConsts::Section::Table); o << U32LEB(1); // Declare 1 table. - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -1059,8 +1059,12 @@ Type WasmBinaryBuilder::getType() { return f64; case BinaryConsts::EncodedType::v128: return v128; + case BinaryConsts::EncodedType::funcref: + return funcref; case BinaryConsts::EncodedType::anyref: return anyref; + case BinaryConsts::EncodedType::nullref: + return nullref; case BinaryConsts::EncodedType::exnref: return exnref; default: @@ -1258,8 +1262,8 @@ void WasmBinaryBuilder::readImports() { wasm.table.name = Name(std::string("timport$") + std::to_string(i)); auto elementType = getS32LEB(); WASM_UNUSED(elementType); - if (elementType != BinaryConsts::EncodedType::AnyFunc) { - throwError("Imported table type is not AnyFunc"); + if (elementType != BinaryConsts::EncodedType::funcref) { + throwError("Imported table type is not funcref"); } wasm.table.exists = true; bool is_shared; @@ -1802,11 +1806,17 @@ void WasmBinaryBuilder::processFunctions() { wasm.addExport(curr); } - for (auto& iter : functionCalls) { + for (auto& iter : functionRefs) { size_t index = iter.first; - auto& calls = iter.second; - for (auto* call : calls) { - call->target = getFunctionName(index); + auto& refs = iter.second; + for (auto* ref : refs) { + if (auto* call = ref->dynCast()) { + call->target = getFunctionName(index); + } else if (auto* refFunc = ref->dynCast()) { + refFunc->func = getFunctionName(index); + } else { + WASM_UNREACHABLE("Invalid type in function references"); + } } } @@ -1869,8 +1879,8 @@ void WasmBinaryBuilder::readFunctionTableDeclaration() { } wasm.table.exists = true; auto elemType = getS32LEB(); - if (elemType != BinaryConsts::EncodedType::AnyFunc) { - throwError("ElementType must be AnyFunc in MVP"); + if (elemType != BinaryConsts::EncodedType::funcref) { + throwError("ElementType must be funcref in MVP"); } bool is_shared; getResizableLimits( @@ -2117,7 +2127,8 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { visitGlobalSet((curr = allocator.alloc())->cast()); break; case BinaryConsts::Select: - visitSelect((curr = allocator.alloc()); + case BinaryConsts::SelectWithType: + visitSelect((curr = allocator.alloc(), code); break; case BinaryConsts::Return: visitReturn((curr = allocator.alloc())->cast()); @@ -2137,6 +2148,15 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { case BinaryConsts::Catch: curr = nullptr; break; + case BinaryConsts::RefNull: + visitRefNull((curr = allocator.alloc())->cast()); + break; + case BinaryConsts::RefIsNull: + visitRefIsNull((curr = allocator.alloc())->cast()); + break; + case BinaryConsts::RefFunc: + visitRefFunc((curr = allocator.alloc())->cast()); + break; case BinaryConsts::Try: visitTry((curr = allocator.alloc())->cast()); break; @@ -2510,7 +2530,7 @@ void WasmBinaryBuilder::visitCall(Call* curr) { curr->operands[num - i - 1] = popNonVoidExpression(); } curr->type = sig.results; - functionCalls[index].push_back(curr); // we don't know function names yet + functionRefs[index].push_back(curr); // we don't know function names yet curr->finalize(); } @@ -4326,12 +4346,24 @@ bool WasmBinaryBuilder::maybeVisitSIMDLoad(Expression*& out, uint32_t code) { return true; } -void WasmBinaryBuilder::visitSelect(Select* curr) { - BYN_TRACE("zz node: Select\n"); +void WasmBinaryBuilder::visitSelect(Select* curr, uint8_t code) { + BYN_TRACE("zz node: Select, code " << int32_t(code) << std::endl); + if (code == BinaryConsts::SelectWithType) { + size_t numTypes = getU32LEB(); + std::vector types; + for (size_t i = 0; i < numTypes; i++) { + types.push_back(getType()); + } + curr->type = Type(types); + } curr->condition = popNonVoidExpression(); curr->ifFalse = popNonVoidExpression(); curr->ifTrue = popNonVoidExpression(); - curr->finalize(); + if (code == BinaryConsts::SelectWithType) { + curr->finalize(curr->type); + } else { + curr->finalize(); + } } void WasmBinaryBuilder::visitReturn(Return* curr) { @@ -4383,6 +4415,27 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) { curr->finalize(); } +void WasmBinaryBuilder::visitRefNull(RefNull* curr) { + BYN_TRACE("zz node: RefNull\n"); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefIsNull(RefIsNull* curr) { + BYN_TRACE("zz node: RefIsNull\n"); + curr->value = popNonVoidExpression(); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) { + BYN_TRACE("zz node: RefFunc\n"); + Index index = getU32LEB(); + if (index >= functionImports.size() + functionSignatures.size()) { + throwError("ref.func: invalid call index"); + } + functionRefs[index].push_back(curr); // we don't know function names yet + curr->finalize(); +} + void WasmBinaryBuilder::visitTry(Try* curr) { BYN_TRACE("zz node: Try\n"); // For simplicity of implementation, like if scopes, we create a hidden block diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 20aff209124..3b12c4346f4 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -850,16 +850,22 @@ Type SExpressionWasmBuilder::stringToType(const char* str, return v128; } } + if (strncmp(str, "funcref", 7) == 0 && (prefix || str[7] == 0)) { + return funcref; + } if (strncmp(str, "anyref", 6) == 0 && (prefix || str[6] == 0)) { return anyref; } + if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) { + return nullref; + } if (strncmp(str, "exnref", 6) == 0 && (prefix || str[6] == 0)) { return exnref; } if (allowError) { return none; } - throw ParseException("invalid wasm type"); + throw ParseException(std::string("invalid wasm type: ") + str); } Type SExpressionWasmBuilder::stringToLaneType(const char* str) { @@ -936,10 +942,16 @@ Expression* SExpressionWasmBuilder::makeUnary(Element& s, UnaryOp op) { Expression* SExpressionWasmBuilder::makeSelect(Element& s) { auto ret = allocator.alloc