Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor getMaxBits() out of OptimizeInstructions and add beginnings of unit testing for it #3019

Merged
merged 2 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 232 additions & 69 deletions src/ir/bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,92 +20,255 @@
#include "ir/literal-utils.h"
#include "support/bits.h"
#include "wasm-builder.h"
#include <ir/load-utils.h>

namespace wasm {

struct Bits {
// get a mask to keep only the low # of bits
static int32_t lowBitMask(int32_t bits) {
uint32_t ret = -1;
if (bits >= 32) {
return ret;
}
return ret >> (32 - bits);
namespace Bits {

// get a mask to keep only the low # of bits
inline int32_t lowBitMask(int32_t bits) {
uint32_t ret = -1;
if (bits >= 32) {
return ret;
}
return ret >> (32 - bits);
}

// checks if the input is a mask of lower bits, i.e., all 1s up to some high
// bit, and all zeros from there. returns the number of masked bits, or 0 if
// this is not such a mask
static uint32_t getMaskedBits(uint32_t mask) {
if (mask == uint32_t(-1)) {
return 32; // all the bits
}
if (mask == 0) {
return 0; // trivially not a mask
}
// otherwise, see if x & (x + 1) turns this into non-zero value
// 00011111 & (00011111 + 1) => 0
if (mask & (mask + 1)) {
return 0;
}
// this is indeed a mask
return 32 - CountLeadingZeroes(mask);
// checks if the input is a mask of lower bits, i.e., all 1s up to some high
// bit, and all zeros from there. returns the number of masked bits, or 0 if
// this is not such a mask
inline uint32_t getMaskedBits(uint32_t mask) {
if (mask == uint32_t(-1)) {
return 32; // all the bits
}
if (mask == 0) {
return 0; // trivially not a mask
}
// otherwise, see if x & (x + 1) turns this into non-zero value
// 00011111 & (00011111 + 1) => 0
if (mask & (mask + 1)) {
return 0;
}
// this is indeed a mask
return 32 - CountLeadingZeroes(mask);
}

// gets the number of effective shifts a shift operation does. In
// wasm, only 5 bits matter for 32-bit shifts, and 6 for 64.
inline Index getEffectiveShifts(Index amount, Type type) {
if (type == Type::i32) {
return amount & 31;
} else if (type == Type::i64) {
return amount & 63;
}
WASM_UNREACHABLE("unexpected type");
}

// gets the number of effective shifts a shift operation does. In
// wasm, only 5 bits matter for 32-bit shifts, and 6 for 64.
static Index getEffectiveShifts(Index amount, Type type) {
if (type == Type::i32) {
return amount & 31;
} else if (type == Type::i64) {
return amount & 63;
inline Index getEffectiveShifts(Expression* expr) {
auto* amount = expr->cast<Const>();
if (amount->type == Type::i32) {
return getEffectiveShifts(amount->value.geti32(), Type::i32);
} else if (amount->type == Type::i64) {
return getEffectiveShifts(amount->value.geti64(), Type::i64);
}
WASM_UNREACHABLE("unexpected type");
}

inline Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) {
if (value->type == Type::i32) {
if (bytes == 1 || bytes == 2) {
auto shifts = bytes == 1 ? 24 : 16;
Builder builder(wasm);
return builder.makeBinary(
ShrSInt32,
builder.makeBinary(
ShlInt32,
value,
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)),
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm));
}
assert(bytes == 4);
return value; // nothing to do
} else {
assert(value->type == Type::i64);
if (bytes == 1 || bytes == 2 || bytes == 4) {
auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32);
Builder builder(wasm);
return builder.makeBinary(
ShrSInt64,
builder.makeBinary(
ShlInt64,
value,
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)),
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm));
}
WASM_UNREACHABLE("unexpected type");
assert(bytes == 8);
return value; // nothing to do
}
}

static Index getEffectiveShifts(Expression* expr) {
auto* amount = expr->cast<Const>();
if (amount->type == Type::i32) {
return getEffectiveShifts(amount->value.geti32(), Type::i32);
} else if (amount->type == Type::i64) {
return getEffectiveShifts(amount->value.geti64(), Type::i64);
// getMaxBits() helper that has pessimistic results for the bits used in locals.
struct DummyLocalInfoProvider {
Index getMaxBitsForLocal(LocalGet* get) {
if (get->type == Type::i32) {
return 32;
}
WASM_UNREACHABLE("unexpected type");
if (get->type == Type::i32) {
return 64;
}
WASM_UNREACHABLE("type has no integer bit size");
}
};

static Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) {
if (value->type == Type::i32) {
if (bytes == 1 || bytes == 2) {
auto shifts = bytes == 1 ? 24 : 16;
Builder builder(wasm);
return builder.makeBinary(
ShrSInt32,
builder.makeBinary(
ShlInt32,
value,
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)),
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm));
// Returns the maximum amount of bits used in an integer expression
// not extremely precise (doesn't look into add operands, etc.)
// LocalInfoProvider is an optional class that can provide answers about
// local.get.
template<typename LocalInfoProvider = DummyLocalInfoProvider>
Index getMaxBits(Expression* curr,
LocalInfoProvider* localInfoProvider = nullptr) {
if (auto* const_ = curr->dynCast<Const>()) {
switch (curr->type.getSingle()) {
case Type::i32:
return 32 - const_->value.countLeadingZeroes().geti32();
case Type::i64:
return 64 - const_->value.countLeadingZeroes().geti64();
default:
WASM_UNREACHABLE("invalid type");
}
} else if (auto* binary = curr->dynCast<Binary>()) {
switch (binary->op) {
// 32-bit
case AddInt32:
case SubInt32:
case MulInt32:
case DivSInt32:
case DivUInt32:
case RemSInt32:
case RemUInt32:
case RotLInt32:
case RotRInt32:
return 32;
case AndInt32:
return std::min(getMaxBits(binary->left, localInfoProvider),
getMaxBits(binary->right, localInfoProvider));
case OrInt32:
case XorInt32:
return std::max(getMaxBits(binary->left, localInfoProvider),
getMaxBits(binary->right, localInfoProvider));
case ShlInt32: {
if (auto* shifts = binary->right->dynCast<Const>()) {
return std::min(Index(32),
getMaxBits(binary->left, localInfoProvider) +
Bits::getEffectiveShifts(shifts));
}
return 32;
}
assert(bytes == 4);
return value; // nothing to do
} else {
assert(value->type == Type::i64);
if (bytes == 1 || bytes == 2 || bytes == 4) {
auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32);
Builder builder(wasm);
return builder.makeBinary(
ShrSInt64,
builder.makeBinary(
ShlInt64,
value,
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)),
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm));
case ShrUInt32: {
if (auto* shift = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
auto shifts =
std::min(Index(Bits::getEffectiveShifts(shift)),
maxBits); // can ignore more shifts than zero us out
return std::max(Index(0), maxBits - shifts);
}
return 32;
}
case ShrSInt32: {
if (auto* shift = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
if (maxBits == 32) {
return 32;
}
auto shifts =
std::min(Index(Bits::getEffectiveShifts(shift)),
maxBits); // can ignore more shifts than zero us out
return std::max(Index(0), maxBits - shifts);
}
return 32;
}
// 64-bit TODO
// comparisons
case EqInt32:
case NeInt32:
case LtSInt32:
case LtUInt32:
case LeSInt32:
case LeUInt32:
case GtSInt32:
case GtUInt32:
case GeSInt32:
case GeUInt32:
case EqInt64:
case NeInt64:
case LtSInt64:
case LtUInt64:
case LeSInt64:
case LeUInt64:
case GtSInt64:
case GtUInt64:
case GeSInt64:
case GeUInt64:
case EqFloat32:
case NeFloat32:
case LtFloat32:
case LeFloat32:
case GtFloat32:
case GeFloat32:
case EqFloat64:
case NeFloat64:
case LtFloat64:
case LeFloat64:
case GtFloat64:
case GeFloat64:
return 1;
default: {
}
assert(bytes == 8);
return value; // nothing to do
}
} else if (auto* unary = curr->dynCast<Unary>()) {
switch (unary->op) {
case ClzInt32:
case CtzInt32:
case PopcntInt32:
return 6;
case ClzInt64:
case CtzInt64:
case PopcntInt64:
return 7;
case EqZInt32:
case EqZInt64:
return 1;
case WrapInt64:
return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
default: {
}
}
} else if (auto* set = curr->dynCast<LocalSet>()) {
// a tee passes through the value
return getMaxBits(set->value, localInfoProvider);
} else if (auto* get = curr->dynCast<LocalGet>()) {
return localInfoProvider->getMaxBitsForLocal(get);
} else if (auto* load = curr->dynCast<Load>()) {
// if signed, then the sign-extension might fill all the bits
// if unsigned, then we have a limit
if (LoadUtils::isSignRelevant(load) && !load->signed_) {
return 8 * load->bytes;
}
}
};
switch (curr->type.getSingle()) {
case Type::i32:
return 32;
case Type::i64:
return 64;
case Type::unreachable:
return 64; // not interesting, but don't crash
default:
WASM_UNREACHABLE("invalid type");
}
}

} // namespace Bits

} // namespace wasm

Expand Down
Loading