diff --git a/CHANGELOG.md b/CHANGELOG.md index c5011f6a0bf..d3f9cf4ad4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ full changeset diff at the end of each section. Current Trunk ------------- +- `local.tee`'s C/Binaryen.js API now takes an additional type parameter for its + local type, like `local.get`. This is required to handle subtypes. - Added load_splat SIMD instructions - Binaryen.js instruction API changes: - `notify` -> `atomic.notify` diff --git a/src/asm2wasm.h b/src/asm2wasm.h index fc841e63462..319116a4284 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -1907,7 +1907,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { auto ret = allocator.alloc(); ret->index = function->getLocalIndex(assign->target()); ret->value = process(assign->value()); - ret->setTee(false); + ret->makeSet(); ret->finalize(); return ret; } @@ -2158,7 +2158,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { auto set = allocator.alloc(); set->index = function->getLocalIndex(I32_TEMP); set->value = value; - set->setTee(false); + set->makeSet(); set->finalize(); auto get = [&]() { auto ret = allocator.alloc(); @@ -2264,7 +2264,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { view.bytes, 0, processUnshifted(ast[2][1], view.bytes), - builder.makeLocalTee(temp, process(ast[2][2])), + builder.makeLocalTee(temp, process(ast[2][2]), type), type), builder.makeLocalGet(temp, type)); } else if (name == Atomics_exchange) { diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 35f394fd19d..70712d4d46f 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -1197,22 +1197,23 @@ BinaryenExpressionRef BinaryenLocalSet(BinaryenModuleRef module, ret->index = index; ret->value = (Expression*)value; - ret->setTee(false); + ret->makeSet(); ret->finalize(); return static_cast(ret); } BinaryenExpressionRef BinaryenLocalTee(BinaryenModuleRef module, BinaryenIndex index, - BinaryenExpressionRef value) { + BinaryenExpressionRef value, + BinaryenType type) { auto* ret = ((Module*)module)->allocator.alloc(); if (tracing) { - traceExpression(ret, "BinaryenLocalTee", index, value); + traceExpression(ret, "BinaryenLocalTee", index, value, type); } ret->index = index; ret->value = (Expression*)value; - ret->setTee(true); + ret->makeTee(Type(type)); ret->finalize(); return static_cast(ret); } diff --git a/src/binaryen-c.h b/src/binaryen-c.h index ab321064467..9bbca29f85f 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -650,8 +650,10 @@ BINARYEN_API BinaryenExpressionRef BinaryenLocalGet(BinaryenModuleRef module, BinaryenType type); BINARYEN_API BinaryenExpressionRef BinaryenLocalSet( BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value); -BINARYEN_API BinaryenExpressionRef BinaryenLocalTee( - BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value); +BINARYEN_API BinaryenExpressionRef BinaryenLocalTee(BinaryenModuleRef module, + BinaryenIndex index, + BinaryenExpressionRef value, + BinaryenType type); BINARYEN_API BinaryenExpressionRef BinaryenGlobalGet(BinaryenModuleRef module, const char* name, BinaryenType type); diff --git a/src/ir/ExpressionManipulator.cpp b/src/ir/ExpressionManipulator.cpp index fd0e6fd7564..fbee9f9c129 100644 --- a/src/ir/ExpressionManipulator.cpp +++ b/src/ir/ExpressionManipulator.cpp @@ -91,7 +91,7 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { } Expression* visitLocalSet(LocalSet* curr) { if (curr->isTee()) { - return builder.makeLocalTee(curr->index, copy(curr->value)); + return builder.makeLocalTee(curr->index, copy(curr->value), curr->type); } else { return builder.makeLocalSet(curr->index, copy(curr->value)); } diff --git a/src/ir/localize.h b/src/ir/localize.h index ff454382dc1..733e7bdecd3 100644 --- a/src/ir/localize.h +++ b/src/ir/localize.h @@ -36,7 +36,7 @@ struct Localizer { index = set->index; } else { index = Builder::addVar(func, expr->type); - expr = Builder(*wasm).makeLocalTee(index, expr); + expr = Builder(*wasm).makeLocalTee(index, expr, expr->type); } } }; diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js index 65a4e15dcdb..918d205cf2e 100644 --- a/src/js/binaryen.js-post.js +++ b/src/js/binaryen.js-post.js @@ -540,8 +540,11 @@ function wrapModule(module, self) { 'set': function(index, value) { return Module['_BinaryenLocalSet'](module, index, value); }, - 'tee': function(index, value) { - return Module['_BinaryenLocalTee'](module, index, value); + 'tee': function(index, value, type) { + if (typeof type === 'undefined') { + throw new Error("local.tee's type should be defined"); + } + return Module['_BinaryenLocalTee'](module, index, value, type); } } diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp index 6f698367f00..74788afb520 100644 --- a/src/passes/Flatten.cpp +++ b/src/passes/Flatten.cpp @@ -172,9 +172,10 @@ struct Flatten replaceCurrent(set->value); // trivial, no set happens } else { // use a set in a prelude + a get - set->setTee(false); + set->makeSet(); ourPreludes.push_back(set); - replaceCurrent(builder.makeLocalGet(set->index, set->value->type)); + Type localType = getFunction()->getLocalType(set->index); + replaceCurrent(builder.makeLocalGet(set->index, localType)); } } } else if (auto* br = curr->dynCast()) { diff --git a/src/passes/LocalCSE.cpp b/src/passes/LocalCSE.cpp index afd30c040f5..0816bf6ea8e 100644 --- a/src/passes/LocalCSE.cpp +++ b/src/passes/LocalCSE.cpp @@ -184,8 +184,9 @@ struct LocalCSE : public WalkerPass> { if (iter != usables.end()) { // already exists in the table, this is good to reuse auto& info = iter->second; + Type localType = getFunction()->getLocalType(info.index); set->value = - Builder(*getModule()).makeLocalGet(info.index, value->type); + Builder(*getModule()).makeLocalGet(info.index, localType); anotherPass = true; } else { // not in table, add this, maybe we can help others later diff --git a/src/passes/MergeLocals.cpp b/src/passes/MergeLocals.cpp index c20105621c5..0116753f158 100644 --- a/src/passes/MergeLocals.cpp +++ b/src/passes/MergeLocals.cpp @@ -88,7 +88,7 @@ struct MergeLocals if (auto* get = curr->value->dynCast()) { if (get->index != curr->index) { Builder builder(*getModule()); - auto* trivial = builder.makeLocalTee(get->index, get); + auto* trivial = builder.makeLocalTee(get->index, get, get->type); curr->value = trivial; copies.push_back(curr); } diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp index 55f57302d19..e0174934aac 100644 --- a/src/passes/RemoveUnusedBrs.cpp +++ b/src/passes/RemoveUnusedBrs.cpp @@ -287,7 +287,7 @@ struct RemoveUnusedBrs : public WalkerPass> { Expression* z; replaceCurrent( z = builder.makeIf( - builder.makeLocalTee(temp, curr->condition), + builder.makeLocalTee(temp, curr->condition, i32), builder.makeIf(builder.makeBinary(EqInt32, builder.makeLocalGet(temp, i32), builder.makeConst(Literal(int32_t( @@ -1074,7 +1074,7 @@ struct RemoveUnusedBrs : public WalkerPass> { iff->finalize(); Expression* replacement = iff; if (tee) { - set->setTee(false); + set->makeSet(); // We need a block too. replacement = builder.makeSequence(iff, get // reuse the get diff --git a/src/passes/SSAify.cpp b/src/passes/SSAify.cpp index bcafb2784bd..df32fc77bee 100644 --- a/src/passes/SSAify.cpp +++ b/src/passes/SSAify.cpp @@ -154,7 +154,7 @@ struct SSAify : public Pass { if (set) { // a set exists, just add a tee of its value auto* value = set->value; - auto* tee = builder.makeLocalTee(new_, value); + auto* tee = builder.makeLocalTee(new_, value, get->type); set->value = tee; // the value may have been something we tracked the location // of. if so, update that, since we moved it into the tee diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index 6b76faed32e..a3fa4a34d87 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -256,7 +256,7 @@ struct SimplifyLocals } else { this->replaceCurrent(set); assert(!set->isTee()); - set->setTee(true); + set->makeTee(this->getFunction()->getLocalType(set->index)); } // reuse the local.get that is dying *found->second.item = curr; @@ -271,7 +271,7 @@ struct SimplifyLocals auto* set = curr->value->dynCast(); if (set) { assert(set->isTee()); - set->setTee(false); + set->makeSet(); this->replaceCurrent(set); } } @@ -559,7 +559,7 @@ struct SimplifyLocals auto* set = (*breakLocalSetPointer)->template cast(); if (br->condition) { br->value = set; - set->setTee(true); + set->makeTee(this->getFunction()->getLocalType(set->index)); *breakLocalSetPointer = this->getModule()->allocator.template alloc(); // in addition, as this is a conditional br that now has a value, it now @@ -728,7 +728,8 @@ struct SimplifyLocals ifTrueBlock->finalize(); assert(ifTrueBlock->type != none); // Update the ifFalse side. - iff->ifFalse = builder.makeLocalGet(set->index, set->value->type); + iff->ifFalse = builder.makeLocalGet( + set->index, this->getFunction()->getLocalType(set->index)); iff->finalize(); // update type // Update the get count. getCounter.num[set->index]++; diff --git a/src/passes/Untee.cpp b/src/passes/Untee.cpp index 79c76b988fe..6e5fb489b4c 100644 --- a/src/passes/Untee.cpp +++ b/src/passes/Untee.cpp @@ -41,9 +41,10 @@ struct Untee : public WalkerPass> { } else { // a normal tee. replace with set and get Builder builder(*getModule()); - replaceCurrent(builder.makeSequence( - curr, builder.makeLocalGet(curr->index, curr->value->type))); - curr->setTee(false); + LocalGet* get = builder.makeLocalGet( + curr->index, getFunction()->getLocalType(curr->index)); + replaceCurrent(builder.makeSequence(curr, get)); + curr->makeSet(); } } } diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index 26fd2bd0e2d..48a55ed89b9 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -346,7 +346,7 @@ struct Vacuum : public WalkerPass> { // a drop of a tee is a set if (auto* set = curr->value->dynCast()) { assert(set->isTee()); - set->setTee(false); + set->makeSet(); replaceCurrent(set); return; } diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index 3dcb5c665b6..ce302fac61d 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -1277,7 +1277,7 @@ class TranslateToFuzzReader { } auto* value = make(valueType); if (tee) { - return builder.makeLocalTee(pick(locals), value); + return builder.makeLocalTee(pick(locals), value, valueType); } else { return builder.makeLocalSet(pick(locals), value); } diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 22fa0e6427e..918e6a4ab5e 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -232,14 +232,15 @@ class Builder { auto* ret = allocator.alloc(); ret->index = index; ret->value = value; + ret->makeSet(); ret->finalize(); return ret; } - LocalSet* makeLocalTee(Index index, Expression* value) { + LocalSet* makeLocalTee(Index index, Expression* value, Type type) { auto* ret = allocator.alloc(); ret->index = index; ret->value = value; - ret->setTee(true); + ret->makeTee(type); return ret; } GlobalGet* makeGlobalGet(Name name, Type type) { diff --git a/src/wasm.h b/src/wasm.h index fba962bdb75..cc2070eb252 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -709,8 +709,9 @@ class LocalSet : public SpecificExpression { Index index; Expression* value; - bool isTee(); - void setTee(bool is); + bool isTee() const; + void makeTee(Type type); + void makeSet(); }; class GlobalGet : public SpecificExpression { diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index d86eea8f5c1..2787725a638 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -2459,8 +2459,11 @@ void WasmBinaryBuilder::visitLocalSet(LocalSet* curr, uint8_t code) { throwError("bad local.set index"); } curr->value = popNonVoidExpression(); - curr->type = curr->value->type; - curr->setTee(code == BinaryConsts::LocalTee); + if (code == BinaryConsts::LocalTee) { + curr->makeTee(currFunction->getLocalType(curr->index)); + } else { + curr->makeSet(); + } curr->finalize(); } diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp index 9a0b145dae1..be4e51b9e9c 100644 --- a/src/wasm/wasm-emscripten.cpp +++ b/src/wasm/wasm-emscripten.cpp @@ -112,7 +112,7 @@ inline Expression* stackBoundsCheck(Builder& builder, auto check = builder.makeIf(builder.makeBinary( BinaryOp::LtUInt32, - builder.makeLocalTee(newSP, value), + builder.makeLocalTee(newSP, value, stackPointer->type), builder.makeGlobalGet(stackLimit->name, stackLimit->type)), builder.makeCall(handler, {}, none)); // (global.set $__stack_pointer (local.get $newSP)) @@ -172,7 +172,7 @@ void EmscriptenGlueGenerator::generateStackAllocFunction() { const static uint32_t bitMask = bitAlignment - 1; Const* subConst = builder.makeConst(Literal(~bitMask)); Binary* maskedSub = builder.makeBinary(AndInt32, sub, subConst); - LocalSet* teeStackLocal = builder.makeLocalTee(1, maskedSub); + LocalSet* teeStackLocal = builder.makeLocalTee(1, maskedSub, i32); Expression* storeStack = generateStoreStackPointer(function, teeStackLocal); Block* block = builder.makeBlock(); diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 10b6aead765..1319d80fc3e 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -997,7 +997,7 @@ Expression* SExpressionWasmBuilder::makeLocalTee(Element& s) { auto ret = allocator.alloc(); ret->index = getLocalIndex(*s[1]); ret->value = parseExpression(s[2]); - ret->setTee(true); + ret->makeTee(currFunction->getLocalType(ret->index)); ret->finalize(); return ret; } @@ -1006,7 +1006,7 @@ Expression* SExpressionWasmBuilder::makeLocalSet(Element& s) { auto ret = allocator.alloc(); ret->index = getLocalIndex(*s[1]); ret->value = parseExpression(s[2]); - ret->setTee(false); + ret->makeSet(); ret->finalize(); return ret; } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index c6de444c169..0efc8cf4f96 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -718,15 +718,15 @@ void FunctionValidator::visitLocalSet(LocalSet* curr) { "local.set index must be small enough")) { if (curr->value->type != unreachable) { if (curr->type != none) { // tee is ok anyhow - shouldBeEqualOrFirstIsUnreachable(curr->value->type, - curr->type, - curr, - "local.set type must be correct"); + shouldBeEqual(getFunction()->getLocalType(curr->index), + curr->type, + curr, + "local.set type must be correct"); } - shouldBeEqual(getFunction()->getLocalType(curr->index), - curr->value->type, + shouldBeEqual(curr->value->type, + getFunction()->getLocalType(curr->index), curr, - "local.set type must match function"); + "local.set's value type must be correct"); } } } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 83829418e46..783d51e0fb4 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -442,24 +442,23 @@ void CallIndirect::finalize() { } } -bool LocalSet::isTee() { return type != none; } +bool LocalSet::isTee() const { return type != none; } -void LocalSet::setTee(bool is) { - if (is) { - type = value->type; - } else { - type = none; - } +// Changes to local.tee. The type of the local should be given. +void LocalSet::makeTee(Type type_) { + type = type_; + finalize(); // type may need to be unreachable +} + +// Changes to local.set. +void LocalSet::makeSet() { + type = none; finalize(); // type may need to be unreachable } void LocalSet::finalize() { if (value->type == unreachable) { type = unreachable; - } else if (isTee()) { - type = value->type; - } else { - type = none; } } diff --git a/test/binaryen.js/kitchen-sink.js b/test/binaryen.js/kitchen-sink.js index 83e3d4a44d5..982bcc92791 100644 --- a/test/binaryen.js/kitchen-sink.js +++ b/test/binaryen.js/kitchen-sink.js @@ -469,7 +469,7 @@ function test_core() { ), module.drop(module.local.get(0, Binaryen.i32)), module.local.set(0, makeInt32(101)), - module.drop(module.local.tee(0, makeInt32(102))), + module.drop(module.local.tee(0, makeInt32(102), Binaryen.i32)), module.i32.load(0, 0, makeInt32(1)), module.i64.load16_s(2, 1, makeInt32(8)), module.f32.load(0, 0, makeInt32(2)), diff --git a/test/binaryen.js/kitchen-sink.js.txt b/test/binaryen.js/kitchen-sink.js.txt index a740bd42cfe..74a11d48c50 100644 --- a/test/binaryen.js/kitchen-sink.js.txt +++ b/test/binaryen.js/kitchen-sink.js.txt @@ -1581,7 +1581,7 @@ int main() { expressions[736] = BinaryenConst(the_module, BinaryenLiteralInt32(101)); expressions[737] = BinaryenLocalSet(the_module, 0, expressions[736]); expressions[738] = BinaryenConst(the_module, BinaryenLiteralInt32(102)); - expressions[739] = BinaryenLocalTee(the_module, 0, expressions[738]); + expressions[739] = BinaryenLocalTee(the_module, 0, expressions[738], 2); expressions[740] = BinaryenDrop(the_module, expressions[739]); expressions[741] = BinaryenConst(the_module, BinaryenLiteralInt32(1)); expressions[742] = BinaryenLoad(the_module, 4, 1, 0, 0, 2, expressions[741]); @@ -10107,7 +10107,7 @@ module loaded from binary form: ) ) -[wasm-validator error in function func] i32 != i64: local.set type must match function, on +[wasm-validator error in function func] i64 != i32: local.set's value type must be correct, on [none] (local.set $0 [i64] (i64.const 1234) ) diff --git a/test/example/c-api-kitchen-sink.c b/test/example/c-api-kitchen-sink.c index b3419d1840b..6284cdd3911 100644 --- a/test/example/c-api-kitchen-sink.c +++ b/test/example/c-api-kitchen-sink.c @@ -648,7 +648,9 @@ void test_core() { BinaryenTypeInt32())), BinaryenDrop(module, BinaryenLocalGet(module, 0, BinaryenTypeInt32())), BinaryenLocalSet(module, 0, makeInt32(module, 101)), - BinaryenDrop(module, BinaryenLocalTee(module, 0, makeInt32(module, 102))), + BinaryenDrop( + module, + BinaryenLocalTee(module, 0, makeInt32(module, 102), BinaryenTypeInt32())), BinaryenLoad(module, 4, 0, 0, 0, BinaryenTypeInt32(), makeInt32(module, 1)), BinaryenLoad(module, 2, 1, 2, 1, BinaryenTypeInt64(), makeInt32(module, 8)), BinaryenLoad( diff --git a/test/example/c-api-kitchen-sink.txt b/test/example/c-api-kitchen-sink.txt index ca9061e63f6..a75fd738018 100644 --- a/test/example/c-api-kitchen-sink.txt +++ b/test/example/c-api-kitchen-sink.txt @@ -1604,7 +1604,7 @@ int main() { expressions[747] = BinaryenConst(the_module, BinaryenLiteralInt32(101)); expressions[748] = BinaryenLocalSet(the_module, 0, expressions[747]); expressions[749] = BinaryenConst(the_module, BinaryenLiteralInt32(102)); - expressions[750] = BinaryenLocalTee(the_module, 0, expressions[749]); + expressions[750] = BinaryenLocalTee(the_module, 0, expressions[749], 2); expressions[751] = BinaryenDrop(the_module, expressions[750]); expressions[752] = BinaryenConst(the_module, BinaryenLiteralInt32(1)); expressions[753] = BinaryenLoad(the_module, 4, 0, 0, 0, 2, expressions[752]); diff --git a/test/example/c-api-relooper-unreachable-if.cpp b/test/example/c-api-relooper-unreachable-if.cpp index 3b35221ec23..329b0e3b7b0 100644 --- a/test/example/c-api-relooper-unreachable-if.cpp +++ b/test/example/c-api-relooper-unreachable-if.cpp @@ -242,7 +242,8 @@ int main() { expressions[76] = BinaryenLoad(the_module, 4, 0, 0, 0, BinaryenTypeInt32(), expressions[75]); expressions[77] = BinaryenConst(the_module, BinaryenLiteralInt32(128)); expressions[78] = BinaryenBinary(the_module, 1, expressions[76], expressions[77]); - expressions[79] = BinaryenLocalTee(the_module, 3, expressions[78]); + expressions[79] = + BinaryenLocalTee(the_module, 3, expressions[78], BinaryenTypeInt32()); expressions[80] = BinaryenStore( the_module, 4, 0, 0, expressions[75], expressions[79], BinaryenTypeInt32()); expressions[81] = BinaryenLocalGet(the_module, 3, BinaryenTypeInt32()); @@ -333,7 +334,8 @@ int main() { expressions[123] = BinaryenLoad(the_module, 4, 0, 0, 0, BinaryenTypeInt32(), expressions[122]); expressions[124] = BinaryenConst(the_module, BinaryenLiteralInt32(128)); expressions[125] = BinaryenBinary(the_module, 1, expressions[123], expressions[124]); - expressions[126] = BinaryenLocalTee(the_module, 5, expressions[125]); + expressions[126] = + BinaryenLocalTee(the_module, 5, expressions[125], BinaryenTypeInt32()); expressions[127] = BinaryenStore(the_module, 4, 0, @@ -497,7 +499,8 @@ int main() { expressions[186] = BinaryenLoad(the_module, 4, 0, 0, 0, BinaryenTypeInt32(), expressions[185]); expressions[187] = BinaryenConst(the_module, BinaryenLiteralInt32(128)); expressions[188] = BinaryenBinary(the_module, 1, expressions[186], expressions[187]); - expressions[189] = BinaryenLocalTee(the_module, 6, expressions[188]); + expressions[189] = + BinaryenLocalTee(the_module, 6, expressions[188], BinaryenTypeInt32()); expressions[190] = BinaryenStore(the_module, 4, 0,