Skip to content

Commit

Permalink
Add {u,s}{add,sub,mul}_overflow instructions (bytecodealliance#5784)
Browse files Browse the repository at this point in the history
* add `{u,s}{add,sub,mul}_overflow` with interpreter

* add `{u,s}{add,sub,mul}_overflow` for x64

* add `{u,s}{add,sub,mul}_overflow` for aarch64

* 128bit filetests for `{u,s}{add,sub,mul}_overflow`

* `{u,s}{add,sub,mul}_overflow` emit tests for x64

* `{u,s}{add,sub,mul}_overflow` emit tests for aarch64

* Initial review changes

* add `with_flags_extended` helper

* add `with_flags_chained` helper
  • Loading branch information
T0b1-iOS authored and eduardomourar committed Apr 16, 2023
1 parent d644e07 commit ad06b86
Show file tree
Hide file tree
Showing 27 changed files with 2,194 additions and 98 deletions.
119 changes: 119 additions & 0 deletions cranelift/codegen/meta/src/shared/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,125 @@ pub(crate) fn define(
]),
);

{
let of_out = Operand::new("of", i8).with_doc("Overflow flag");
ig.push(
Inst::new(
"uadd_overflow",
r#"
Add integers unsigned with overflow out.
``of`` is set when the addition overflowed.
```text
a &= x + y \pmod 2^B \\
of &= x+y >= 2^B
```
Polymorphic over all scalar integer types, but does not support vector
types.
"#,
&formats.binary,
)
.operands_in(vec![Operand::new("x", iB), Operand::new("y", iB)])
.operands_out(vec![Operand::new("a", iB), of_out.clone()]),
);

ig.push(
Inst::new(
"sadd_overflow",
r#"
Add integers signed with overflow out.
``of`` is set when the addition over- or underflowed.
Polymorphic over all scalar integer types, but does not support vector
types.
"#,
&formats.binary,
)
.operands_in(vec![Operand::new("x", iB), Operand::new("y", iB)])
.operands_out(vec![Operand::new("a", iB), of_out.clone()]),
);

ig.push(
Inst::new(
"usub_overflow",
r#"
Subtract integers unsigned with overflow out.
``of`` is set when the subtraction underflowed.
```text
a &= x - y \pmod 2^B \\
of &= x - y < 0
```
Polymorphic over all scalar integer types, but does not support vector
types.
"#,
&formats.binary,
)
.operands_in(vec![Operand::new("x", iB), Operand::new("y", iB)])
.operands_out(vec![Operand::new("a", iB), of_out.clone()]),
);

ig.push(
Inst::new(
"ssub_overflow",
r#"
Subtract integers signed with overflow out.
``of`` is set when the subtraction over- or underflowed.
Polymorphic over all scalar integer types, but does not support vector
types.
"#,
&formats.binary,
)
.operands_in(vec![Operand::new("x", iB), Operand::new("y", iB)])
.operands_out(vec![Operand::new("a", iB), of_out.clone()]),
);

{
let NarrowScalar = &TypeVar::new(
"NarrowScalar",
"A scalar integer type up to 64 bits",
TypeSetBuilder::new().ints(8..64).build(),
);

ig.push(
Inst::new(
"umul_overflow",
r#"
Multiply integers unsigned with overflow out.
``of`` is set when the multiplication overflowed.
```text
a &= x * y \pmod 2^B \\
of &= x * y > 2^B
```
Polymorphic over all scalar integer types except i128, but does not support vector
types.
"#,
&formats.binary,
)
.operands_in(vec![
Operand::new("x", NarrowScalar),
Operand::new("y", NarrowScalar),
])
.operands_out(vec![Operand::new("a", NarrowScalar), of_out.clone()]),
);

ig.push(
Inst::new(
"smul_overflow",
r#"
Multiply integers signed with overflow out.
``of`` is set when the multiplication over- or underflowed.
Polymorphic over all scalar integer types except i128, but does not support vector
types.
"#,
&formats.binary,
)
.operands_in(vec![
Operand::new("x", NarrowScalar),
Operand::new("y", NarrowScalar),
])
.operands_out(vec![Operand::new("a", NarrowScalar), of_out.clone()]),
);
}
}

let i32_64 = &TypeVar::new(
"i32_64",
"A 32 or 64-bit scalar integer type",
Expand Down
12 changes: 12 additions & 0 deletions cranelift/codegen/src/data_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,18 @@ impl DataValue {
(DataValue::F32(a), DataValue::F32(b)) => a.bits() == b.bits(),
(DataValue::F64(a), DataValue::F64(b)) => a.bits() == b.bits(),

// when testing for bitwise equality, the sign information does not matter
(DataValue::I8(a), DataValue::U8(b)) => *a as u8 == *b,
(DataValue::U8(a), DataValue::I8(b)) => *a == *b as u8,
(DataValue::I16(a), DataValue::U16(b)) => *a as u16 == *b,
(DataValue::U16(a), DataValue::I16(b)) => *a == *b as u16,
(DataValue::I32(a), DataValue::U32(b)) => *a as u32 == *b,
(DataValue::U32(a), DataValue::I32(b)) => *a == *b as u32,
(DataValue::I64(a), DataValue::U64(b)) => *a as u64 == *b,
(DataValue::U64(a), DataValue::I64(b)) => *a == *b as u64,
(DataValue::I128(a), DataValue::U128(b)) => *a as u128 == *b,
(DataValue::U128(a), DataValue::I128(b)) => *a == *b as u128,

// We don't need to worry about F32x4 / F64x2 Since we compare V128 which is already the
// raw bytes anyway
(a, b) => a == b,
Expand Down
61 changes: 52 additions & 9 deletions cranelift/codegen/src/isa/aarch64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,10 @@
(MAdd)
;; Multiply-sub
(MSub)
;; Unsigned-Multiply-add
(UMAddL)
;; Signed-Multiply-add
(SMAddL)
))

(type MoveWideOp
Expand Down Expand Up @@ -1727,6 +1731,9 @@
(decl pure partial lshl_from_u64 (Type u64) ShiftOpAndAmt)
(extern constructor lshl_from_u64 lshl_from_u64)

(decl pure partial ashr_from_u64 (Type u64) ShiftOpAndAmt)
(extern constructor ashr_from_u64 ashr_from_u64)

(decl integral_ty (Type) Type)
(extern extractor integral_ty integral_ty)

Expand Down Expand Up @@ -1966,6 +1973,15 @@
(MInst.AluRRRShift (ALUOp.SubS) size (writable_zero_reg)
src1 src2 shift)))

;; Helper for emitting `cmp` instructions, setting flags, with an arithmetic right-shifted
;; second operand register.
(decl cmp_rr_shift_asr (OperandSize Reg Reg u64) ProducesFlags)
(rule (cmp_rr_shift_asr size src1 src2 shift_amount)
(if-let shift (ashr_from_u64 $I64 shift_amount))
(ProducesFlags.ProducesFlagsSideEffect
(MInst.AluRRRShift (ALUOp.SubS) size (writable_zero_reg)
src1 src2 shift)))

;; Helper for emitting `MInst.AluRRRExtend` instructions.
(decl alu_rrr_extend (ALUOp Type Reg Reg ExtendOp) Reg)
(rule (alu_rrr_extend op ty src1 src2 extend)
Expand All @@ -1988,6 +2004,22 @@
(_ Unit (emit (MInst.AluRRRR op (operand_size ty) dst src1 src2 src3))))
dst))

;; Helper for emitting paired `MInst.AluRRR` instructions
(decl alu_rrr_with_flags_paired (Type Reg Reg ALUOp) ProducesFlags)
(rule (alu_rrr_with_flags_paired ty src1 src2 alu_op)
(let ((dst WritableReg (temp_writable_reg $I64)))
(ProducesFlags.ProducesFlagsReturnsResultWithConsumer
(MInst.AluRRR alu_op (operand_size ty) dst src1 src2)
dst)))

;; Should only be used for AdcS and SbcS
(decl alu_rrr_with_flags_chained (Type Reg Reg ALUOp) ConsumesAndProducesFlags)
(rule (alu_rrr_with_flags_chained ty src1 src2 alu_op)
(let ((dst WritableReg (temp_writable_reg $I64)))
(ConsumesAndProducesFlags.ReturnsReg
(MInst.AluRRR alu_op (operand_size ty) dst src1 src2)
dst)))

;; Helper for emitting `MInst.BitRR` instructions.
(decl bit_rr (BitOp Type Reg) Reg)
(rule (bit_rr op ty src)
Expand Down Expand Up @@ -2335,7 +2367,7 @@
;; immediately by the `MInst.CCmp` instruction.
(decl ccmp (OperandSize Reg Reg NZCV Cond ProducesFlags) ProducesFlags)
(rule (ccmp size rn rm nzcv cond inst_input)
(produces_flags_append inst_input (MInst.CCmp size rn rm nzcv cond)))
(produces_flags_concat inst_input (ProducesFlags.ProducesFlagsSideEffect (MInst.CCmp size rn rm nzcv cond))))

;; Helper for generating `MInst.CCmpImm` instructions.
(decl ccmp_imm (OperandSize Reg UImm5 NZCV Cond) ConsumesFlags)
Expand Down Expand Up @@ -2411,6 +2443,14 @@
(decl msub (Type Reg Reg Reg) Reg)
(rule (msub ty x y z) (alu_rrrr (ALUOp3.MSub) ty x y z))

;; Helpers for generating `umaddl` instructions
(decl umaddl (Reg Reg Reg) Reg)
(rule (umaddl x y z) (alu_rrrr (ALUOp3.UMAddL) $I32 x y z))

;; Helpers for generating `smaddl` instructions
(decl smaddl (Reg Reg Reg) Reg)
(rule (smaddl x y z) (alu_rrrr (ALUOp3.SMAddL) $I32 x y z))

;; Helper for generating `uqadd` instructions.
(decl uqadd (Reg Reg VectorSize) Reg)
(rule (uqadd x y size) (vec_rrr (VecALUOp.Uqadd) x y size))
Expand Down Expand Up @@ -2620,6 +2660,9 @@
(decl orr_imm (Type Reg ImmLogic) Reg)
(rule (orr_imm ty x y) (alu_rr_imm_logic (ALUOp.Orr) ty x y))

(decl orr_shift (Type Reg Reg ShiftOpAndAmt) Reg)
(rule (orr_shift ty x y shift) (alu_rrr_shift (ALUOp.Orr) ty x y shift))

(decl orr_vec (Reg Reg VectorSize) Reg)
(rule (orr_vec x y size) (vec_rrr (VecALUOp.Orr) x y size))

Expand Down Expand Up @@ -3659,12 +3702,12 @@
(rm Reg (put_in_reg y)))
(vec_cmp rn rm in_ty cond)))

;; Determines the appropriate extend op given the value type and whether it is signed.
(decl lower_extend_op (Type bool) ExtendOp)
(rule (lower_extend_op $I8 $true) (ExtendOp.SXTB))
(rule (lower_extend_op $I16 $true) (ExtendOp.SXTH))
(rule (lower_extend_op $I8 $false) (ExtendOp.UXTB))
(rule (lower_extend_op $I16 $false) (ExtendOp.UXTH))
;; Determines the appropriate extend op given the value type and the given ArgumentExtension.
(decl lower_extend_op (Type ArgumentExtension) ExtendOp)
(rule (lower_extend_op $I8 (ArgumentExtension.Sext)) (ExtendOp.SXTB))
(rule (lower_extend_op $I16 (ArgumentExtension.Sext)) (ExtendOp.SXTH))
(rule (lower_extend_op $I8 (ArgumentExtension.Uext)) (ExtendOp.UXTB))
(rule (lower_extend_op $I16 (ArgumentExtension.Uext)) (ExtendOp.UXTH))

;; Integers <= 64-bits.
(rule -2 (lower_icmp_into_reg cond rn rm in_ty out_ty)
Expand All @@ -3675,13 +3718,13 @@
(rule 1 (lower_icmp cond rn rm (fits_in_16 ty))
(if (signed_cond_code cond))
(let ((rn Reg (put_in_reg_sext32 rn)))
(flags_and_cc (cmp_extend (operand_size ty) rn rm (lower_extend_op ty $true)) cond)))
(flags_and_cc (cmp_extend (operand_size ty) rn rm (lower_extend_op ty (ArgumentExtension.Sext))) cond)))
(rule -1 (lower_icmp cond rn (imm12_from_value rm) (fits_in_16 ty))
(let ((rn Reg (put_in_reg_zext32 rn)))
(flags_and_cc (cmp_imm (operand_size ty) rn rm) cond)))
(rule -2 (lower_icmp cond rn rm (fits_in_16 ty))
(let ((rn Reg (put_in_reg_zext32 rn)))
(flags_and_cc (cmp_extend (operand_size ty) rn rm (lower_extend_op ty $false)) cond)))
(flags_and_cc (cmp_extend (operand_size ty) rn rm (lower_extend_op ty (ArgumentExtension.Uext))) cond)))
(rule -3 (lower_icmp cond rn (u64_from_iconst c) ty)
(if (ty_int_ref_scalar_64 ty))
(lower_icmp_const cond rn c ty))
Expand Down
8 changes: 8 additions & 0 deletions cranelift/codegen/src/isa/aarch64/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,14 @@ impl MachInstEmit for Inst {
let (top11, bit15) = match alu_op {
ALUOp3::MAdd => (0b0_00_11011_000, 0),
ALUOp3::MSub => (0b0_00_11011_000, 1),
ALUOp3::UMAddL => {
debug_assert!(size == OperandSize::Size32);
(0b1_00_11011_1_01, 0)
}
ALUOp3::SMAddL => {
debug_assert!(size == OperandSize::Size32);
(0b1_00_11011_0_01, 0)
}
};
let top11 = top11 | size.sf_bit() << 10;
sink.put4(enc_arith_rrrr(top11, rm, bit15, ra, rn, rd));
Expand Down
24 changes: 24 additions & 0 deletions cranelift/codegen/src/isa/aarch64/inst/emit_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,30 @@ fn test_aarch64_binemit() {
"4190039B",
"msub x1, x2, x3, x4",
));
insns.push((
Inst::AluRRRR {
alu_op: ALUOp3::UMAddL,
size: OperandSize::Size32,
rd: writable_xreg(1),
rn: xreg(2),
rm: xreg(3),
ra: xreg(4),
},
"4110A39B",
"umaddl x1, w2, w3, x4",
));
insns.push((
Inst::AluRRRR {
alu_op: ALUOp3::SMAddL,
size: OperandSize::Size32,
rd: writable_xreg(1),
rn: xreg(2),
rm: xreg(3),
ra: xreg(4),
},
"4110239B",
"smaddl x1, w2, w3, x4",
));
insns.push((
Inst::AluRRR {
alu_op: ALUOp::SMulH,
Expand Down
12 changes: 7 additions & 5 deletions cranelift/codegen/src/isa/aarch64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1191,14 +1191,16 @@ impl Inst {
rm,
ra,
} => {
let op = match alu_op {
ALUOp3::MAdd => "madd",
ALUOp3::MSub => "msub",
let (op, da_size) = match alu_op {
ALUOp3::MAdd => ("madd", size),
ALUOp3::MSub => ("msub", size),
ALUOp3::UMAddL => ("umaddl", OperandSize::Size64),
ALUOp3::SMAddL => ("smaddl", OperandSize::Size64),
};
let rd = pretty_print_ireg(rd.to_reg(), size, allocs);
let rd = pretty_print_ireg(rd.to_reg(), da_size, allocs);
let rn = pretty_print_ireg(rn, size, allocs);
let rm = pretty_print_ireg(rm, size, allocs);
let ra = pretty_print_ireg(ra, size, allocs);
let ra = pretty_print_ireg(ra, da_size, allocs);

format!("{} {}, {}, {}, {}", op, rd, rn, rm, ra)
}
Expand Down
Loading

0 comments on commit ad06b86

Please sign in to comment.