diff --git a/src/ir/bits.h b/src/ir/bits.h index 8b28bb0319b..e0bca8d879b 100644 --- a/src/ir/bits.h +++ b/src/ir/bits.h @@ -20,92 +20,255 @@ #include "ir/literal-utils.h" #include "support/bits.h" #include "wasm-builder.h" +#include namespace wasm { -struct Bits { - // get a mask to keep only the low # of bits - static int32_t lowBitMask(int32_t bits) { - uint32_t ret = -1; - if (bits >= 32) { - return ret; - } - return ret >> (32 - bits); +namespace Bits { + +// get a mask to keep only the low # of bits +inline int32_t lowBitMask(int32_t bits) { + uint32_t ret = -1; + if (bits >= 32) { + return ret; } + return ret >> (32 - bits); +} - // checks if the input is a mask of lower bits, i.e., all 1s up to some high - // bit, and all zeros from there. returns the number of masked bits, or 0 if - // this is not such a mask - static uint32_t getMaskedBits(uint32_t mask) { - if (mask == uint32_t(-1)) { - return 32; // all the bits - } - if (mask == 0) { - return 0; // trivially not a mask - } - // otherwise, see if x & (x + 1) turns this into non-zero value - // 00011111 & (00011111 + 1) => 0 - if (mask & (mask + 1)) { - return 0; - } - // this is indeed a mask - return 32 - CountLeadingZeroes(mask); +// checks if the input is a mask of lower bits, i.e., all 1s up to some high +// bit, and all zeros from there. returns the number of masked bits, or 0 if +// this is not such a mask +inline uint32_t getMaskedBits(uint32_t mask) { + if (mask == uint32_t(-1)) { + return 32; // all the bits + } + if (mask == 0) { + return 0; // trivially not a mask } + // otherwise, see if x & (x + 1) turns this into non-zero value + // 00011111 & (00011111 + 1) => 0 + if (mask & (mask + 1)) { + return 0; + } + // this is indeed a mask + return 32 - CountLeadingZeroes(mask); +} + +// gets the number of effective shifts a shift operation does. In +// wasm, only 5 bits matter for 32-bit shifts, and 6 for 64. +inline Index getEffectiveShifts(Index amount, Type type) { + if (type == Type::i32) { + return amount & 31; + } else if (type == Type::i64) { + return amount & 63; + } + WASM_UNREACHABLE("unexpected type"); +} - // gets the number of effective shifts a shift operation does. In - // wasm, only 5 bits matter for 32-bit shifts, and 6 for 64. - static Index getEffectiveShifts(Index amount, Type type) { - if (type == Type::i32) { - return amount & 31; - } else if (type == Type::i64) { - return amount & 63; +inline Index getEffectiveShifts(Expression* expr) { + auto* amount = expr->cast(); + if (amount->type == Type::i32) { + return getEffectiveShifts(amount->value.geti32(), Type::i32); + } else if (amount->type == Type::i64) { + return getEffectiveShifts(amount->value.geti64(), Type::i64); + } + WASM_UNREACHABLE("unexpected type"); +} + +inline Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) { + if (value->type == Type::i32) { + if (bytes == 1 || bytes == 2) { + auto shifts = bytes == 1 ? 24 : 16; + Builder builder(wasm); + return builder.makeBinary( + ShrSInt32, + builder.makeBinary( + ShlInt32, + value, + LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)), + LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)); + } + assert(bytes == 4); + return value; // nothing to do + } else { + assert(value->type == Type::i64); + if (bytes == 1 || bytes == 2 || bytes == 4) { + auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32); + Builder builder(wasm); + return builder.makeBinary( + ShrSInt64, + builder.makeBinary( + ShlInt64, + value, + LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)), + LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)); } - WASM_UNREACHABLE("unexpected type"); + assert(bytes == 8); + return value; // nothing to do } +} - static Index getEffectiveShifts(Expression* expr) { - auto* amount = expr->cast(); - if (amount->type == Type::i32) { - return getEffectiveShifts(amount->value.geti32(), Type::i32); - } else if (amount->type == Type::i64) { - return getEffectiveShifts(amount->value.geti64(), Type::i64); +// getMaxBits() helper that has pessimistic results for the bits used in locals. +struct DummyLocalInfoProvider { + Index getMaxBitsForLocal(LocalGet* get) { + if (get->type == Type::i32) { + return 32; } - WASM_UNREACHABLE("unexpected type"); + if (get->type == Type::i32) { + return 64; + } + WASM_UNREACHABLE("type has no integer bit size"); } +}; - static Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) { - if (value->type == Type::i32) { - if (bytes == 1 || bytes == 2) { - auto shifts = bytes == 1 ? 24 : 16; - Builder builder(wasm); - return builder.makeBinary( - ShrSInt32, - builder.makeBinary( - ShlInt32, - value, - LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)), - LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)); +// Returns the maximum amount of bits used in an integer expression +// not extremely precise (doesn't look into add operands, etc.) +// LocalInfoProvider is an optional class that can provide answers about +// local.get. +template +Index getMaxBits(Expression* curr, + LocalInfoProvider* localInfoProvider = nullptr) { + if (auto* const_ = curr->dynCast()) { + switch (curr->type.getSingle()) { + case Type::i32: + return 32 - const_->value.countLeadingZeroes().geti32(); + case Type::i64: + return 64 - const_->value.countLeadingZeroes().geti64(); + default: + WASM_UNREACHABLE("invalid type"); + } + } else if (auto* binary = curr->dynCast()) { + switch (binary->op) { + // 32-bit + case AddInt32: + case SubInt32: + case MulInt32: + case DivSInt32: + case DivUInt32: + case RemSInt32: + case RemUInt32: + case RotLInt32: + case RotRInt32: + return 32; + case AndInt32: + return std::min(getMaxBits(binary->left, localInfoProvider), + getMaxBits(binary->right, localInfoProvider)); + case OrInt32: + case XorInt32: + return std::max(getMaxBits(binary->left, localInfoProvider), + getMaxBits(binary->right, localInfoProvider)); + case ShlInt32: { + if (auto* shifts = binary->right->dynCast()) { + return std::min(Index(32), + getMaxBits(binary->left, localInfoProvider) + + Bits::getEffectiveShifts(shifts)); + } + return 32; } - assert(bytes == 4); - return value; // nothing to do - } else { - assert(value->type == Type::i64); - if (bytes == 1 || bytes == 2 || bytes == 4) { - auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32); - Builder builder(wasm); - return builder.makeBinary( - ShrSInt64, - builder.makeBinary( - ShlInt64, - value, - LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)), - LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)); + case ShrUInt32: { + if (auto* shift = binary->right->dynCast()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 32; + } + case ShrSInt32: { + if (auto* shift = binary->right->dynCast()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + if (maxBits == 32) { + return 32; + } + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 32; + } + // 64-bit TODO + // comparisons + case EqInt32: + case NeInt32: + case LtSInt32: + case LtUInt32: + case LeSInt32: + case LeUInt32: + case GtSInt32: + case GtUInt32: + case GeSInt32: + case GeUInt32: + case EqInt64: + case NeInt64: + case LtSInt64: + case LtUInt64: + case LeSInt64: + case LeUInt64: + case GtSInt64: + case GtUInt64: + case GeSInt64: + case GeUInt64: + case EqFloat32: + case NeFloat32: + case LtFloat32: + case LeFloat32: + case GtFloat32: + case GeFloat32: + case EqFloat64: + case NeFloat64: + case LtFloat64: + case LeFloat64: + case GtFloat64: + case GeFloat64: + return 1; + default: { } - assert(bytes == 8); - return value; // nothing to do + } + } else if (auto* unary = curr->dynCast()) { + switch (unary->op) { + case ClzInt32: + case CtzInt32: + case PopcntInt32: + return 6; + case ClzInt64: + case CtzInt64: + case PopcntInt64: + return 7; + case EqZInt32: + case EqZInt64: + return 1; + case WrapInt64: + return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); + default: { + } + } + } else if (auto* set = curr->dynCast()) { + // a tee passes through the value + return getMaxBits(set->value, localInfoProvider); + } else if (auto* get = curr->dynCast()) { + return localInfoProvider->getMaxBitsForLocal(get); + } else if (auto* load = curr->dynCast()) { + // if signed, then the sign-extension might fill all the bits + // if unsigned, then we have a limit + if (LoadUtils::isSignRelevant(load) && !load->signed_) { + return 8 * load->bytes; } } -}; + switch (curr->type.getSingle()) { + case Type::i32: + return 32; + case Type::i64: + return 64; + case Type::unreachable: + return 64; // not interesting, but don't crash + default: + WASM_UNREACHABLE("invalid type"); + } +} + +} // namespace Bits } // namespace wasm diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 860ea3ac748..7f66ca7eed9 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -45,152 +46,6 @@ Name F32_EXPR = "f32.expr"; Name F64_EXPR = "f64.expr"; Name ANY_EXPR = "any.expr"; -// Utilities - -// returns the maximum amount of bits used in an integer expression -// not extremely precise (doesn't look into add operands, etc.) -// LocalInfoProvider is an optional class that can provide answers about -// local.get. -template -Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { - if (auto* const_ = curr->dynCast()) { - switch (curr->type.getSingle()) { - case Type::i32: - return 32 - const_->value.countLeadingZeroes().geti32(); - case Type::i64: - return 64 - const_->value.countLeadingZeroes().geti64(); - default: - WASM_UNREACHABLE("invalid type"); - } - } else if (auto* binary = curr->dynCast()) { - switch (binary->op) { - // 32-bit - case AddInt32: - case SubInt32: - case MulInt32: - case DivSInt32: - case DivUInt32: - case RemSInt32: - case RemUInt32: - case RotLInt32: - case RotRInt32: - return 32; - case AndInt32: - return std::min(getMaxBits(binary->left, localInfoProvider), - getMaxBits(binary->right, localInfoProvider)); - case OrInt32: - case XorInt32: - return std::max(getMaxBits(binary->left, localInfoProvider), - getMaxBits(binary->right, localInfoProvider)); - case ShlInt32: { - if (auto* shifts = binary->right->dynCast()) { - return std::min(Index(32), - getMaxBits(binary->left, localInfoProvider) + - Bits::getEffectiveShifts(shifts)); - } - return 32; - } - case ShrUInt32: { - if (auto* shift = binary->right->dynCast()) { - auto maxBits = getMaxBits(binary->left, localInfoProvider); - auto shifts = - std::min(Index(Bits::getEffectiveShifts(shift)), - maxBits); // can ignore more shifts than zero us out - return std::max(Index(0), maxBits - shifts); - } - return 32; - } - case ShrSInt32: { - if (auto* shift = binary->right->dynCast()) { - auto maxBits = getMaxBits(binary->left, localInfoProvider); - if (maxBits == 32) { - return 32; - } - auto shifts = - std::min(Index(Bits::getEffectiveShifts(shift)), - maxBits); // can ignore more shifts than zero us out - return std::max(Index(0), maxBits - shifts); - } - return 32; - } - // 64-bit TODO - // comparisons - case EqInt32: - case NeInt32: - case LtSInt32: - case LtUInt32: - case LeSInt32: - case LeUInt32: - case GtSInt32: - case GtUInt32: - case GeSInt32: - case GeUInt32: - case EqInt64: - case NeInt64: - case LtSInt64: - case LtUInt64: - case LeSInt64: - case LeUInt64: - case GtSInt64: - case GtUInt64: - case GeSInt64: - case GeUInt64: - case EqFloat32: - case NeFloat32: - case LtFloat32: - case LeFloat32: - case GtFloat32: - case GeFloat32: - case EqFloat64: - case NeFloat64: - case LtFloat64: - case LeFloat64: - case GtFloat64: - case GeFloat64: - return 1; - default: {} - } - } else if (auto* unary = curr->dynCast()) { - switch (unary->op) { - case ClzInt32: - case CtzInt32: - case PopcntInt32: - return 6; - case ClzInt64: - case CtzInt64: - case PopcntInt64: - return 7; - case EqZInt32: - case EqZInt64: - return 1; - case WrapInt64: - return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); - default: {} - } - } else if (auto* set = curr->dynCast()) { - // a tee passes through the value - return getMaxBits(set->value, localInfoProvider); - } else if (auto* get = curr->dynCast()) { - return localInfoProvider->getMaxBitsForLocal(get); - } else if (auto* load = curr->dynCast()) { - // if signed, then the sign-extension might fill all the bits - // if unsigned, then we have a limit - if (LoadUtils::isSignRelevant(load) && !load->signed_) { - return 8 * load->bytes; - } - } - switch (curr->type.getSingle()) { - case Type::i32: - return 32; - case Type::i64: - return 64; - case Type::unreachable: - return 64; // not interesting, but don't crash - default: - WASM_UNREACHABLE("invalid type"); - } -} - // Useful information about locals struct LocalInfo { static const Index kUnknown = Index(-1); @@ -243,7 +98,7 @@ struct LocalScanner : PostWalker { auto* value = Properties::getFallthrough( curr->value, passOptions, getModule()->features); auto& info = localInfo[curr->index]; - info.maxBits = std::max(info.maxBits, getMaxBits(value, this)); + info.maxBits = std::max(info.maxBits, Bits::getMaxBits(value, this)); auto signExtBits = LocalInfo::kUnknown; if (Properties::getSignExtValue(value)) { signExtBits = Properties::getSignExtBits(value); @@ -373,7 +228,7 @@ struct OptimizeInstructions // if the sign-extend input cannot have a sign bit, we don't need it // we also don't need it if it already has an identical-sized sign // extend - if (getMaxBits(ext, this) + extraShifts < bits || + if (Bits::getMaxBits(ext, this) + extraShifts < bits || isSignExted(ext, bits)) { return removeAlmostSignExt(binary); } @@ -538,7 +393,7 @@ struct OptimizeInstructions return binary->left; } } else if (auto maskedBits = Bits::getMaskedBits(mask)) { - if (getMaxBits(binary->left, this) <= maskedBits) { + if (Bits::getMaxBits(binary->left, this) <= maskedBits) { // a mask of lower bits is not needed if we are already smaller return binary->left; } diff --git a/test/example/cpp-unit.cpp b/test/example/cpp-unit.cpp index e6189d9d113..2ba4388d197 100644 --- a/test/example/cpp-unit.cpp +++ b/test/example/cpp-unit.cpp @@ -1,17 +1,41 @@ // test multiple uses of the threadPool -#include +#include -#include +#include #include +#include using namespace wasm; -int main() -{ +void compare(size_t x, size_t y) { + if (x != y) { + std::cout << "comparison error!\n" << x << '\n' << y << '\n'; + abort(); + } +} + +void test_bits() { + Const c; + c.type = Type::i32; + c.value = Literal(int32_t(1)); + compare(Bits::getMaxBits(&c), 1); + c.value = Literal(int32_t(2)); + compare(Bits::getMaxBits(&c), 2); + c.value = Literal(int32_t(3)); + compare(Bits::getMaxBits(&c), 2); +} + +void test_cost() { // Some optimizations assume that the cost of a get is zero, e.g. local-cse. LocalGet get; - assert(CostAnalyzer(&get).cost == 0); + compare(CostAnalyzer(&get).cost, 0); +} + +int main() { + test_bits(); + + test_cost(); std::cout << "Success.\n"; return 0;