stage2: wasm - Implement overflow arithmetic #11321

Merged
merged 2 commits on Mar 27, 2022
181 changes: 156 additions & 25 deletions src/arch/wasm/CodeGen.zig
@@ -1303,10 +1303,16 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.bool_and => self.airBinOp(inst, .@"and"),
.bool_or => self.airBinOp(inst, .@"or"),
.rem => self.airBinOp(inst, .rem),
.shl, .shl_exact => self.airBinOp(inst, .shl),
.shl => self.airWrapBinOp(inst, .shl),
.shl_exact => self.airBinOp(inst, .shl),
.shr, .shr_exact => self.airBinOp(inst, .shr),
.xor => self.airBinOp(inst, .xor),

.add_with_overflow => self.airBinOpOverflow(inst, .add),
.sub_with_overflow => self.airBinOpOverflow(inst, .sub),
.shl_with_overflow => self.airBinOpOverflow(inst, .shl),
.mul_with_overflow => self.airBinOpOverflow(inst, .mul),

.cmp_eq => self.airCmp(inst, .eq),
.cmp_gte => self.airCmp(inst, .gte),
.cmp_gt => self.airCmp(inst, .gt),
@@ -1461,13 +1467,6 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.atomic_rmw,
.tag_name,
.mul_add,

// For these 4, probably best to wait until https://github.com/ziglang/zig/issues/10248
// is implemented in the frontend before implementing them here in the wasm backend.
.add_with_overflow,
.sub_with_overflow,
.mul_with_overflow,
.shl_with_overflow,
=> |tag| return self.fail("TODO: Implement wasm inst: {s}", .{@tagName(tag)}),
};
}
@@ -1754,24 +1753,28 @@ fn airBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
const operand_ty = self.air.typeOfIndex(inst);
const ty = self.air.typeOf(bin_op.lhs);

if (isByRef(operand_ty, self.target)) {
return self.fail("TODO: Implement binary operation for type: {}", .{operand_ty.fmtDebug()});
}

return self.binOp(lhs, rhs, ty, op);
}

fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
try self.emitWValue(lhs);
try self.emitWValue(rhs);

const bin_ty = self.air.typeOf(bin_op.lhs);
const opcode: wasm.Opcode = buildOpcode(.{
.op = op,
.valtype1 = typeToValtype(bin_ty, self.target),
.signedness = if (bin_ty.isSignedInt()) .signed else .unsigned,
.valtype1 = typeToValtype(ty, self.target),
.signedness = if (ty.isSignedInt()) .signed else .unsigned,
});
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));

// save the result in a temporary
const bin_local = try self.allocLocal(bin_ty);
const bin_local = try self.allocLocal(ty);
try self.addLabel(.local_set, bin_local.local);
return bin_local;
}
@@ -1781,18 +1784,21 @@ fn airWrapBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);

return self.wrapBinOp(lhs, rhs, self.air.typeOf(bin_op.lhs), op);
}

fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
try self.emitWValue(lhs);
try self.emitWValue(rhs);

const bin_ty = self.air.typeOf(bin_op.lhs);
const opcode: wasm.Opcode = buildOpcode(.{
.op = op,
.valtype1 = typeToValtype(bin_ty, self.target),
.signedness = if (bin_ty.isSignedInt()) .signed else .unsigned,
.valtype1 = typeToValtype(ty, self.target),
.signedness = if (ty.isSignedInt()) .signed else .unsigned,
});
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));

const int_info = bin_ty.intInfo(self.target);
const int_info = ty.intInfo(self.target);
const bitsize = int_info.bits;
const is_signed = int_info.signedness == .signed;
// if the target type bitsize is x < 32 or 32 < x < 64, we perform
@@ -1820,7 +1826,7 @@ fn airWrapBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
}

// save the result in a temporary
const bin_local = try self.allocLocal(bin_ty);
const bin_local = try self.allocLocal(ty);
try self.addLabel(.local_set, bin_local.local);
return bin_local;
}
@@ -2202,18 +2208,21 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: std.math.CompareOperator) InnerError!WValue {
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
const operand_ty = self.air.typeOf(bin_op.lhs);
return self.cmp(lhs, rhs, operand_ty, op);
}

if (operand_ty.zigTypeTag() == .Optional and !operand_ty.isPtrLikeOptional()) {
fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOperator) InnerError!WValue {
if (ty.zigTypeTag() == .Optional and !ty.isPtrLikeOptional()) {
var buf: Type.Payload.ElemType = undefined;
const payload_ty = operand_ty.optionalChild(&buf);
const payload_ty = ty.optionalChild(&buf);
if (payload_ty.hasRuntimeBitsIgnoreComptime()) {
// When we hit this case, we must check the value of optionals
// that are not pointers. This means first checking both lhs and rhs
// against non-null, and then checking that their payloads match.
return self.cmpOptionals(lhs, rhs, operand_ty, op);
return self.cmpOptionals(lhs, rhs, ty, op);
}
} else if (isByRef(operand_ty, self.target)) {
return self.cmpBigInt(lhs, rhs, operand_ty, op);
} else if (isByRef(ty, self.target)) {
return self.cmpBigInt(lhs, rhs, ty, op);
}

// ensure that when we compare pointers, we emit
@@ -2229,13 +2238,13 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: std.math.CompareOperator) InnerError!WValue {

const signedness: std.builtin.Signedness = blk: {
// by default we treat the operand type as unsigned (e.g. bools and enum values)
if (operand_ty.zigTypeTag() != .Int) break :blk .unsigned;
if (ty.zigTypeTag() != .Int) break :blk .unsigned;

// in case of an actual integer, we emit the correct signedness
break :blk operand_ty.intInfo(self.target).signedness;
break :blk ty.intInfo(self.target).signedness;
};
const opcode: wasm.Opcode = buildOpcode(.{
.valtype1 = typeToValtype(operand_ty, self.target),
.valtype1 = typeToValtype(ty, self.target),
.op = switch (op) {
.lt => .lt,
.lte => .le,
@@ -3730,3 +3739,125 @@ fn airPtrSliceFieldPtr(self: *Self, inst: Air.Inst.Index, offset: u32) InnerError!WValue {
const slice_ptr = try self.resolveInst(ty_op.operand);
return self.buildPointerOffset(slice_ptr, offset, .new);
}

fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
if (self.liveness.isUnused(inst)) return WValue{ .none = {} };

const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
const lhs = try self.resolveInst(extra.lhs);
const rhs = try self.resolveInst(extra.rhs);
const lhs_ty = self.air.typeOf(extra.lhs);

// We store whether the operation overflowed in this local. As it's zero-initialized,
// we only need to update it if an overflow (or underflow) occurred.
const overflow_bit = try self.allocLocal(Type.initTag(.u1));
const int_info = lhs_ty.intInfo(self.target);
const wasm_bits = toWasmBits(int_info.bits) orelse {
return self.fail("TODO: Implement overflow arithmetic for integer bitsize: {d}", .{int_info.bits});
};

const zero = switch (wasm_bits) {
32 => WValue{ .imm32 = 0 },
64 => WValue{ .imm64 = 0 },
else => unreachable,
};
const int_max = (@as(u65, 1) << @intCast(u7, int_info.bits - @boolToInt(int_info.signedness == .signed))) - 1;
const int_max_wvalue = switch (wasm_bits) {
32 => WValue{ .imm32 = @intCast(u32, int_max) },
64 => WValue{ .imm64 = @intCast(u64, int_max) },
else => unreachable,
};
const int_min = if (int_info.signedness == .unsigned)
@as(i64, 0)
else
-@as(i64, 1) << @intCast(u6, int_info.bits - 1);
const int_min_wvalue = switch (wasm_bits) {
32 => WValue{ .imm32 = @bitCast(u32, @intCast(i32, int_min)) },
64 => WValue{ .imm64 = @bitCast(u64, int_min) },
else => unreachable,
};

if (int_info.signedness == .unsigned and op == .add) {
const diff = try self.binOp(int_max_wvalue, lhs, lhs_ty, .sub);
const cmp_res = try self.cmp(rhs, diff, lhs_ty, .gt);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
} else if (int_info.signedness == .unsigned and op == .sub) {
const cmp_res = try self.cmp(lhs, rhs, lhs_ty, .lt);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
} else if (int_info.signedness == .signed and op != .shl) {
// for overflow, we first check if lhs is > 0 (or lhs < 0 in case of subtraction). If not, we will not overflow.
// We first create an outer block, where we handle overflow.
// Then we create an inner block, where underflow is handled.
try self.startBlock(.block, wasm.block_empty);
try self.startBlock(.block, wasm.block_empty);
{
try self.emitWValue(lhs);
const cmp_result = try self.cmp(lhs, zero, lhs_ty, .lt);
try self.emitWValue(cmp_result);
}
try self.addLabel(.br_if, 0); // break to outer block, and handle underflow

// handle overflow
{
const diff = try self.binOp(int_max_wvalue, lhs, lhs_ty, .sub);
const cmp_res = try self.cmp(rhs, diff, lhs_ty, if (op == .add) .gt else .lt);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
}
try self.addLabel(.br, 1); // break from blocks, and continue regular flow.
try self.endBlock();

// handle underflow
{
const diff = try self.binOp(int_min_wvalue, lhs, lhs_ty, .sub);
const cmp_res = try self.cmp(rhs, diff, lhs_ty, if (op == .add) .lt else .gt);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
}
try self.endBlock();
}

const bin_op = if (op == .shl) blk: {
const tmp_val = try self.binOp(lhs, rhs, lhs_ty, op);
const cmp_res = try self.cmp(tmp_val, int_max_wvalue, lhs_ty, .gt);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);

try self.emitWValue(tmp_val);
try self.emitWValue(int_max_wvalue);
switch (wasm_bits) {
32 => try self.addTag(.i32_and),
64 => try self.addTag(.i64_and),
else => unreachable,
}
try self.addLabel(.local_set, tmp_val.local);
break :blk tmp_val;
} else if (op == .mul) blk: {
const bin_op = try self.wrapBinOp(lhs, rhs, lhs_ty, op);
try self.startBlock(.block, wasm.block_empty);
// check if lhs is 0. If so, break out of the block, as we can neither overflow nor underflow.
try self.emitWValue(lhs);
switch (wasm_bits) {
32 => try self.addTag(.i32_eqz),
64 => try self.addTag(.i64_eqz),
else => unreachable,
}
try self.addLabel(.br_if, 0);
const div = try self.binOp(bin_op, lhs, lhs_ty, .div);
const cmp_res = try self.cmp(div, rhs, lhs_ty, .neq);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
try self.endBlock();
break :blk bin_op;
} else try self.wrapBinOp(lhs, rhs, lhs_ty, op);

const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
try self.store(result_ptr, bin_op, lhs_ty, 0);
const offset = @intCast(u32, lhs_ty.abiSize(self.target));
try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);

return result_ptr;
}
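
For reference, the overflow checks that `airBinOpOverflow` lowers to wasm above reduce to simple preconditions: unsigned add overflows when `rhs` exceeds the headroom above `lhs`, unsigned sub underflows when `lhs < rhs`, and mul is verified by dividing the wrapped product back (guarding against `lhs == 0`). Below is a minimal Zig sketch of those checks at the language level; the helper names are illustrative and not part of this PR.

```zig
const std = @import("std");

// Mirrors the unsigned checks generated above: addition overflows when
// rhs is larger than the headroom left above lhs, subtraction underflows
// when lhs < rhs, and multiplication is checked by dividing the wrapped
// product back and comparing against rhs.
fn addOverflows(comptime T: type, lhs: T, rhs: T) bool {
    return rhs > std.math.maxInt(T) - lhs;
}

fn subOverflows(comptime T: type, lhs: T, rhs: T) bool {
    return lhs < rhs;
}

fn mulOverflows(comptime T: type, lhs: T, rhs: T) bool {
    const wrapped = lhs *% rhs;
    return lhs != 0 and wrapped / lhs != rhs;
}

test "overflow precondition checks" {
    try std.testing.expect(addOverflows(u8, 250, 100));
    try std.testing.expect(!addOverflows(u8, 250, 5));
    try std.testing.expect(subOverflows(u8, 1, 2));
    try std.testing.expect(mulOverflows(u8, 16, 16));
    try std.testing.expect(!mulOverflows(u8, 15, 17));
}
```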
7 changes: 0 additions & 7 deletions test/behavior/math.zig
@@ -639,7 +639,6 @@ test "128-bit multiplication" {

test "@addWithOverflow" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -661,7 +660,6 @@ test "@addWithOverflow" {

test "small int addition" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -686,7 +684,6 @@ test "small int addition" {

test "@mulWithOverflow" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -708,7 +705,6 @@ test "@mulWithOverflow" {

test "@subWithOverflow" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -730,7 +726,6 @@ test "@subWithOverflow" {

test "@shlWithOverflow" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -752,7 +747,6 @@ test "@shlWithOverflow" {

test "overflow arithmetic with u0 values" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO

var result: u0 = undefined;
@@ -879,7 +873,6 @@ test "quad hex float literal parsing accurate" {
}

test "truncating shift left" {
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO

try testShlTrunc(maxInt(u16));
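
The behavior tests enabled above exercise the overflow builtins end to end on the wasm backend. At the time of this PR those builtins took a result pointer and returned a bool indicating overflow (later Zig versions return a tuple instead); a small sketch of the shape these tests take, with illustrative values:

```zig
const std = @import("std");

test "@addWithOverflow wraps and reports overflow" {
    var result: u8 = undefined;
    // 250 + 100 wraps to 94 and reports overflow.
    try std.testing.expect(@addWithOverflow(u8, 250, 100, &result));
    try std.testing.expect(result == 94);
    // 100 + 150 fits in a u8, so no overflow is reported.
    try std.testing.expect(!@addWithOverflow(u8, 100, 150, &result));
    try std.testing.expect(result == 250);
}
```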