From f7531885b5b64653cf61b005c45f4c3054661b9f Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Thu, 23 Jan 2025 15:26:27 +0000 Subject: [PATCH] Change how we store and sign/zero extend integers. Previously we stored raw `u64`s and expected the user to remember that they needed to zero/sign extend the underlying integer whenever they wanted to do anything with that. We did not always do this, and we did it incorrectly in a couple of places! This commit introduces a new struct `ArbBitInt` which is basically a pair `(bit_width: u32, value: u64)` which hides the underlying value. To get a raw Rust-level integer you have to call a method, and those methods have names with `sign_extend` and `zero_extend` in them. While one can, of course, call the wrong one, it is now impossible not to sign/zero extend. This struct is currently rather simple, but the API is flexible enough to extend to beyond-64-bit ints transparently. The (fairly extensive) test suite is a bit overkill right now, but is intended to help give us confidence if/when we support more than 64 bit ints in the future. This commit also necessarily requires a full audit of everything to do with ints-in-traces. That means a lot of code churn, but it's absolutely necessary, and (a) makes much clearer where we should sign/zero extend (b) catches some places where we didn't do this but should. This commit isn't perfect. In particular, I'm not very happy that `Const::Int` has both a `TyIdx` that contains a bit width *and* an `ArbBitInt` that separately records a bit width. That feels icky, but doing something neater will require at least some ickiness elsewhere. I'll worry about that another day. --- ykrt/Cargo.toml | 1 + ykrt/src/compile/jitc_yk/arbbitint.rs | 840 ++++++++++++++++++ .../compile/jitc_yk/codegen/x64/lsregalloc.rs | 25 +- ykrt/src/compile/jitc_yk/codegen/x64/mod.rs | 193 ++-- ykrt/src/compile/jitc_yk/jit_ir/mod.rs | 100 +-- ykrt/src/compile/jitc_yk/jit_ir/parser.rs | 17 +- .../src/compile/jitc_yk/jit_ir/well_formed.rs | 22 +- ykrt/src/compile/jitc_yk/mod.rs | 1 + ykrt/src/compile/jitc_yk/opt/analyse.rs | 12 +- ykrt/src/compile/jitc_yk/opt/heapvalues.rs | 29 +- ykrt/src/compile/jitc_yk/opt/mod.rs | 297 ++++--- ykrt/src/compile/jitc_yk/trace_builder.rs | 30 +- ykrt/src/lib.rs | 1 + ykrt/src/log/stats.rs | 1 + 14 files changed, 1196 insertions(+), 373 deletions(-) create mode 100644 ykrt/src/compile/jitc_yk/arbbitint.rs diff --git a/ykrt/Cargo.toml b/ykrt/Cargo.toml index c38008436..6410949b6 100644 --- a/ykrt/Cargo.toml +++ b/ykrt/Cargo.toml @@ -58,3 +58,4 @@ lazy_static = "1.5.0" lrlex = "0.13" lrpar = "0.13" regex = { version = "1.9", features = ["std"] } +proptest = "1.6.0" diff --git a/ykrt/src/compile/jitc_yk/arbbitint.rs b/ykrt/src/compile/jitc_yk/arbbitint.rs new file mode 100644 index 000000000..61d5932bd --- /dev/null +++ b/ykrt/src/compile/jitc_yk/arbbitint.rs @@ -0,0 +1,840 @@ +//! An integer of an arbitrary, dynamic, bit width. +//! +//! This hides away the underlying representation, and forces the user to consider whether they +//! want to zero or sign extend the underlying the integer whenever they want access to a +//! Rust-level integer. +//! +//! Currently only up to 64 bits are supported, though the API is flexible enough to transparently +//! support greater bit widths in the future. + +use super::int_signs::{SignExtend, Truncate}; +use std::{ + hash::Hash, + ops::{BitAnd, BitOr, BitXor}, +}; + +/// An integer of an arbitrary, dynamic, bit width. 
+///
+/// Currently can only represent a max of 64 bits: this could be extended in the future.
+#[derive(Clone, Debug)]
+pub(crate) struct ArbBitInt {
+    bitw: u32,
+    /// The underlying value. Any bits above `self.bitw` have an undefined value: they may be set
+    /// or unset.
+    ///
+    /// Currently we can only store ints that can fit in 64 bits: in the future we could use
+    /// another scheme to e.g. `Box` bigger integers.
+    val: u64,
+}
+
+impl ArbBitInt {
+    /// Create a new `ArbBitInt` that is `bitw` bits wide and has a value `val`. Any bits above
+    /// `bitw` bits are ignored (i.e. it is safe for those bits to be set or unset when calling
+    /// this function).
+    pub(crate) fn from_u64(bitw: u32, val: u64) -> Self {
+        debug_assert!(bitw <= 64);
+        Self { bitw, val }
+    }
+
+    /// Create a new `ArbBitInt` that is `bitw` bits wide and has a value `val`. Any bits above
+    /// `bitw` bits are ignored (i.e. it is safe for those bits to be set or unset when calling
+    /// this function).
+    #[cfg(test)]
+    pub(crate) fn from_i64(bitw: u32, val: i64) -> Self {
+        debug_assert!(bitw <= 64);
+        Self {
+            bitw,
+            val: val as u64,
+        }
+    }
+
+    /// Create a new `ArbBitInt` with all `bitw` bits set. This can be seen as equivalent to
+    /// creating a value of `u<bitw>::MAX` (when `u<bitw>` is also a valid Rust type).
+    pub(crate) fn all_bits_set(bitw: u32) -> Self {
+        Self {
+            bitw,
+            val: u64::MAX,
+        }
+    }
+
+    /// How many bits wide is this `ArbBitInt`?
+    pub(crate) fn bitw(&self) -> u32 {
+        self.bitw
+    }
+
+    /// Sign extend this `ArbBitInt` to `to_bitw` bits.
+    ///
+    /// # Panics
+    ///
+    /// If `to_bitw` is smaller than `self.bitw()` or greater than 64.
+    pub(crate) fn sign_extend(&self, to_bitw: u32) -> Self {
+        debug_assert!(to_bitw >= self.bitw && to_bitw <= 64);
+        Self {
+            bitw: to_bitw,
+            val: self.val.sign_extend(self.bitw, to_bitw),
+        }
+    }
+
+    /// Zero extend this `ArbBitInt` to `to_bitw` bits.
+    ///
+    /// # Panics
+    ///
+    /// If `to_bitw` is smaller than `self.bitw()` or greater than 64.
+    pub(crate) fn zero_extend(&self, to_bitw: u32) -> Self {
+        debug_assert!(to_bitw >= self.bitw && to_bitw <= 64);
+        Self {
+            bitw: to_bitw,
+            val: self.val.truncate(self.bitw),
+        }
+    }
+
+    /// Sign extend the underlying value and, if it is representable as an `i8`, return it.
+    pub(crate) fn to_sign_ext_i8(&self) -> Option<i8> {
+        i8::try_from(self.val.sign_extend(self.bitw, 64) as i64).ok()
+    }
+
+    /// Sign extend the underlying value and, if it is representable as an `i16`, return it.
+    #[allow(unused)]
+    pub(crate) fn to_sign_ext_i16(&self) -> Option<i16> {
+        i16::try_from(self.val.sign_extend(self.bitw, 64) as i64).ok()
+    }
+
+    /// Sign extend the underlying value and, if it is representable as an `i32`, return it.
+    pub(crate) fn to_sign_ext_i32(&self) -> Option<i32> {
+        i32::try_from(self.val.sign_extend(self.bitw, 64) as i64).ok()
+    }
+
+    /// Sign extend the underlying value and, if it is representable as an `i64`, return it.
+    pub(crate) fn to_sign_ext_i64(&self) -> Option<i64> {
+        Some(self.val.sign_extend(self.bitw, 64) as i64)
+    }
+
+    /// Zero extend the underlying value and, if it is representable as a `u8`, return it.
+    pub(crate) fn to_zero_ext_u8(&self) -> Option<u8> {
+        u8::try_from(self.val.truncate(self.bitw)).ok()
+    }
+
+    /// Zero extend the underlying value and, if it is representable as a `u16`, return it.
+    pub(crate) fn to_zero_ext_u16(&self) -> Option<u16> {
+        u16::try_from(self.val.truncate(self.bitw)).ok()
+    }
+
+    /// Zero extend the underlying value and, if it is representable as a `u32`, return it.
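+    ///
+    /// Illustrative example (relying only on the truncation behaviour documented above): an
+    /// `ArbBitInt` created with `from_u64(4, 0xFF)` has only its low 4 bits defined, so this
+    /// method returns `Some(0xF)`.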
+ pub(crate) fn to_zero_ext_u32(&self) -> Option { + u32::try_from(self.val.truncate(self.bitw)).ok() + } + + /// zero extend the underlying value and, if it is representable as an `u64`, return it. + pub(crate) fn to_zero_ext_u64(&self) -> Option { + Some(self.val.truncate(self.bitw)) + } + + /// Return a new [ArbBitInt] that performs two's complement wrapping addition on `self` and + /// `other`. + /// + /// # Panics + /// + /// If `self` and `other` are not the same bit width. + pub(crate) fn wrapping_add(&self, other: &Self) -> Self { + debug_assert_eq!(self.bitw, other.bitw); + Self { + bitw: self.bitw, + val: self.val.wrapping_add(other.val), + } + } + + /// Return a new [ArbBitInt] that performs two's complement wrapping multiplication on `self` and + /// `other`. + /// + /// # Panics + /// + /// If `self` and `other` are not the same bit width. + pub(crate) fn wrapping_mul(&self, other: &Self) -> Self { + debug_assert_eq!(self.bitw, other.bitw); + Self { + bitw: self.bitw, + val: self.val.wrapping_mul(other.val), + } + } + + /// Return a new [ArbBitInt] that performs two's complement wrapping subtraction on `self` and + /// `other`. + /// + /// # Panics + /// + /// If `self` and `other` are not the same bit width. + pub(crate) fn wrapping_sub(&self, other: &Self) -> Self { + debug_assert_eq!(self.bitw, other.bitw); + Self { + bitw: self.bitw, + val: self.val.wrapping_sub(other.val), + } + } + + /// Return a new [ArbBitInt] that left shifts `self` by `bits`s or `None` if `bits >= + /// self.bitw()`. + pub(crate) fn checked_shl(&self, bits: u32) -> Option { + if bits < self.bitw { + Some(Self { + bitw: self.bitw, + val: self.val.checked_shl(bits).unwrap(), // unwrap cannot fail + }) + } else { + None + } + } + + /// Return a new [ArbBitInt] that right shifts `self` by `bits` or `None` if `bits >= + /// self.bitw()`. + pub(crate) fn checked_shr(&self, bits: u32) -> Option { + if bits < self.bitw { + Some(Self { + bitw: self.bitw, + val: self.val.checked_shr(bits).unwrap(), // unwrap cannot fail + }) + } else { + None + } + } + + /// Return a new [ArbBitInt] that performs bitwise `AND` on `self` and `other`. + /// + /// # Panics + /// + /// If `self` and `other` are not the same bit width. + pub(crate) fn bitand(&self, other: &Self) -> Self { + debug_assert_eq!(self.bitw, other.bitw); + Self { + bitw: self.bitw, + val: self.val.bitand(other.val), + } + } + + /// Return a new [ArbBitInt] that performs bitwise `OR` on `self` and `other`. + /// + /// # Panics + /// + /// If `self` and `other` are not the same bit width. + pub(crate) fn bitor(&self, other: &Self) -> Self { + debug_assert_eq!(self.bitw, other.bitw); + Self { + bitw: self.bitw, + val: self.val.bitor(other.val), + } + } + + /// Return a new [ArbBitInt] that performs bitwise `XOR` on `self` and `other`. + /// + /// # Panics + /// + /// If `self` and `other` are not the same bit width. + pub(crate) fn bitxor(&self, other: &Self) -> Self { + debug_assert_eq!(self.bitw, other.bitw); + Self { + bitw: self.bitw, + val: self.val.bitxor(other.val), + } + } +} + +impl Hash for ArbBitInt { + fn hash(&self, state: &mut H) { + self.bitw.hash(state); + self.val.truncate(self.bitw).hash(state); + } +} + +impl PartialEq for ArbBitInt { + fn eq(&self, other: &Self) -> bool { + self.bitw == other.bitw && self.val.truncate(self.bitw) == other.val.truncate(self.bitw) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + proptest! 
{ + #[test] + fn arbbitint_8bit(x in any::(), y in any::()) { + assert_eq!(ArbBitInt::from_i64(8, x as i64).to_sign_ext_i8(), Some(x)); + + // wrapping_add + // i8 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .wrapping_add(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_add(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .wrapping_add(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_add(y)).ok() + ); + + // wrapping_sub + // i8 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .wrapping_sub(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_sub(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .wrapping_sub(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_sub(y)).ok() + ); + + // wrapping_mul + // i8 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .wrapping_mul(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_mul(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .wrapping_mul(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_mul(y)).ok() + ); + + // bitadd + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .bitand(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitand(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .bitand(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitand(y)).ok() + ); + + // bitor + // i16 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .bitor(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .bitor(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitor(y)).ok() + ); + + // bitxor + // i8 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .bitxor(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitxor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(8, x as i64) + .bitxor(&ArbBitInt::from_i64(8, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitxor(y)).ok() + ); + } + + #[test] + fn arbbitint_16bit(x in any::(), y in any::()) { + match (i8::try_from(x), ArbBitInt::from_i64(16, x as i64).to_sign_ext_i8()) { + (Ok(a), Some(b)) if a == b => (), + (Err(_), None) => (), + a => panic!("{a:?}") + } + assert_eq!(ArbBitInt::from_i64(16, x as i64).to_sign_ext_i16(), Some(x)); + + // wrapping_add + // i8 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .wrapping_add(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_add(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .wrapping_add(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_add(y)).ok() + ); + + // wrapping_sub + // i8 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .wrapping_sub(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_sub(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .wrapping_sub(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_sub(y)).ok() + ); + + // wrapping_mul + // i8 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .wrapping_mul(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_mul(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .wrapping_mul(&ArbBitInt::from_i64(16, y as 
i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_mul(y)).ok() + ); + + // bitand + // i8 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .bitand(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitand(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .bitand(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitand(y)).ok() + ); + + // bitor + // i8 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .bitor(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .bitor(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitor(y)).ok() + ); + + // bitxor + // i8 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .bitxor(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitxor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(16, x as i64) + .bitxor(&ArbBitInt::from_i64(16, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitxor(y)).ok() + ); + } + + #[test] + fn arbbitint_32bit(x in any::(), y in any::()) { + match (i8::try_from(x), ArbBitInt::from_i64(32, x as i64).to_sign_ext_i8()) { + (Ok(a), Some(b)) if a == b => (), + (Err(_), None) => (), + a => panic!("{a:?}") + } + match (i16::try_from(x), ArbBitInt::from_i64(32, x as i64).to_sign_ext_i16()) { + (Ok(a), Some(b)) if a == b => (), + (Err(_), None) => (), + a => panic!("{a:?}") + } + assert_eq!(ArbBitInt::from_i64(32, x as i64).to_sign_ext_i32(), Some(x)); + + // wrapping_add + // i8 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_add(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_add(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_add(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_add(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_add(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i32(), + i32::try_from(x.wrapping_add(y)).ok() + ); + + // wrapping_sub + // i8 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_sub(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_sub(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_sub(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_sub(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_sub(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i32(), + i32::try_from(x.wrapping_sub(y)).ok() + ); + + // wrapping_mul + // i8 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_mul(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i8(), + i8::try_from(x.wrapping_mul(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_mul(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i16(), + i16::try_from(x.wrapping_mul(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .wrapping_mul(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i32(), + i32::try_from(x.wrapping_mul(y)).ok() + ); + + // bitand + // i8 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitand(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitand(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitand(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitand(y)).ok() + ); + // i32 + assert_eq!( + 
ArbBitInt::from_i64(32, x as i64) + .bitand(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i32(), + i32::try_from(x.bitand(y)).ok() + ); + + // bitor + // i8 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitor(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitor(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitor(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitor(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i32(), + i32::try_from(x.bitor(y)).ok() + ); + + // bitxor + // i8 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitxor(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i8(), + i8::try_from(x.bitxor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitxor(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i16(), + i16::try_from(x.bitxor(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(32, x as i64) + .bitxor(&ArbBitInt::from_i64(32, y as i64)).to_sign_ext_i32(), + i32::try_from(x.bitxor(y)).ok() + ); + } + + #[test] + fn arbbitint_64bit(x in any::(), y in any::()) { + match (i8::try_from(x), ArbBitInt::from_i64(64, x).to_sign_ext_i8()) { + (Ok(a), Some(b)) if a == b => (), + (Err(_), None) => (), + a => panic!("{a:?}") + } + match (i16::try_from(x), ArbBitInt::from_i64(64, x).to_sign_ext_i16()) { + (Ok(a), Some(b)) if a == b => (), + (Err(_), None) => (), + a => panic!("{a:?}") + } + match (i32::try_from(x), ArbBitInt::from_i64(64, x).to_sign_ext_i32()) { + (Ok(a), Some(b)) if a == b => (), + (Err(_), None) => (), + a => panic!("{a:?}") + } + assert_eq!(ArbBitInt::from_i64(64, x).to_sign_ext_i64(), Some(x)); + + // wrapping_add + // i8 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_add(&ArbBitInt::from_i64(64, y)).to_sign_ext_i8(), + i8::try_from(x.wrapping_add(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_add(&ArbBitInt::from_i64(64, y)).to_sign_ext_i16(), + i16::try_from(x.wrapping_add(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_add(&ArbBitInt::from_i64(64, y)).to_sign_ext_i32(), + i32::try_from(x.wrapping_add(y)).ok() + ); + // i64 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_add(&ArbBitInt::from_i64(64, y)).to_sign_ext_i64(), + i64::try_from(x.wrapping_add(y)).ok() + ); + + // wrapping_sub + // i8 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_sub(&ArbBitInt::from_i64(64, y)).to_sign_ext_i8(), + i8::try_from(x.wrapping_sub(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_sub(&ArbBitInt::from_i64(64, y)).to_sign_ext_i16(), + i16::try_from(x.wrapping_sub(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_sub(&ArbBitInt::from_i64(64, y)).to_sign_ext_i32(), + i32::try_from(x.wrapping_sub(y)).ok() + ); + // i64 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_sub(&ArbBitInt::from_i64(64, y)).to_sign_ext_i64(), + i64::try_from(x.wrapping_sub(y)).ok() + ); + + // wrapping_mul + // i8 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_mul(&ArbBitInt::from_i64(64, y)).to_sign_ext_i8(), + i8::try_from(x.wrapping_mul(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_mul(&ArbBitInt::from_i64(64, y)).to_sign_ext_i16(), + i16::try_from(x.wrapping_mul(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_mul(&ArbBitInt::from_i64(64, y)).to_sign_ext_i32(), + 
i32::try_from(x.wrapping_mul(y)).ok() + ); + // i64 + assert_eq!( + ArbBitInt::from_i64(64, x) + .wrapping_mul(&ArbBitInt::from_i64(64, y)).to_sign_ext_i64(), + i64::try_from(x.wrapping_mul(y)).ok() + ); + + // bitand + // i8 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitand(&ArbBitInt::from_i64(64, y)).to_sign_ext_i8(), + i8::try_from(x.bitand(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitand(&ArbBitInt::from_i64(64, y)).to_sign_ext_i16(), + i16::try_from(x.bitand(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitand(&ArbBitInt::from_i64(64, y)).to_sign_ext_i32(), + i32::try_from(x.bitand(y)).ok() + ); + // i64 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitand(&ArbBitInt::from_i64(64, y)).to_sign_ext_i64(), + i64::try_from(x.bitand(y)).ok() + ); + + // bitor + // i8 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i8(), + i8::try_from(x.bitor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i16(), + i16::try_from(x.bitor(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i32(), + i32::try_from(x.bitor(y)).ok() + ); + // i64 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i64(), + i64::try_from(x.bitor(y)).ok() + ); + + // bitxor + // i8 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitxor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i8(), + i8::try_from(x.bitxor(y)).ok() + ); + // i16 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitxor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i16(), + i16::try_from(x.bitxor(y)).ok() + ); + // i32 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitxor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i32(), + i32::try_from(x.bitxor(y)).ok() + ); + // i64 + assert_eq!( + ArbBitInt::from_i64(64, x) + .bitxor(&ArbBitInt::from_i64(64, y)).to_sign_ext_i64(), + i64::try_from(x.bitxor(y)).ok() + ); + } + + #[test] + fn arbbitint_8bit_shl(x in any::(), y in 0u32..=8) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". + assert_eq!( + ArbBitInt::from_u64(8, x as u64).checked_shl(y).map(|x| x.to_zero_ext_u8()), + x.checked_shl(y).map(|x| u8::try_from(x).ok()) + ); + } + + #[test] + fn arbbitint_16bit_shl(x in any::(), y in 0u32..=16) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". + assert_eq!( + ArbBitInt::from_u64(16, x as u64).checked_shl(y).map(|x| x.to_zero_ext_u16()), + x.checked_shl(y).map(|x| u16::try_from(x).ok()) + ); + } + + #[test] + fn arbbitint_32bit_shl(x in any::(), y in 0u32..=32) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". 
+ assert_eq!( + ArbBitInt::from_u64(32, x as u64).checked_shl(y).map(|x| x.to_zero_ext_u32()), + x.checked_shl(y).map(|x| u32::try_from(x).ok()) + ); + } + + #[test] + fn arbbitint_64bit_shl(x in any::(), y in 0u32..=63) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". + assert_eq!( + ArbBitInt::from_u64(64, x).checked_shl(y).map(|x| x.to_zero_ext_u64()), + x.checked_shl(y).map(|x| u64::try_from(x).ok()) + ); + } + + #[test] + fn arbbitint_8bit_shr(x in any::(), y in 0u32..=8) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". + assert_eq!( + ArbBitInt::from_u64(8, x as u64).checked_shr(y).map(|x| x.to_zero_ext_u8()), + x.checked_shr(y).map(|x| u8::try_from(x).ok()) + ); + } + + #[test] + fn arbbitint_16bit_shr(x in any::(), y in 0u32..=16) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". + assert_eq!( + ArbBitInt::from_u64(16, x as u64).checked_shr(y).map(|x| x.to_zero_ext_u16()), + x.checked_shr(y).map(|x| u16::try_from(x).ok()) + ); + } + + #[test] + fn arbbitint_32bit_shr(x in any::(), y in 0u32..=32) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". + assert_eq!( + ArbBitInt::from_u64(32, x as u64).checked_shr(y).map(|x| x.to_zero_ext_u32()), + x.checked_shr(y).map(|x| u32::try_from(x).ok()) + ); + } + + #[test] + fn arbbitint_64bit_shr(x in any::(), y in 0u32..=63) { + // Notice that we deliberately allow y to extend beyond to the shiftable range, to make + // sure that we test the "failure" case, while not biasing our testing too-far to + // "failure cases only". + assert_eq!( + ArbBitInt::from_u64(64, x).checked_shr(y).map(|x| x.to_zero_ext_u64()), + x.checked_shr(y).map(|x| u64::try_from(x).ok()) + ); + } + } +} diff --git a/ykrt/src/compile/jitc_yk/codegen/x64/lsregalloc.rs b/ykrt/src/compile/jitc_yk/codegen/x64/lsregalloc.rs index 9bca0fe36..ebd8a21fc 100644 --- a/ykrt/src/compile/jitc_yk/codegen/x64/lsregalloc.rs +++ b/ykrt/src/compile/jitc_yk/codegen/x64/lsregalloc.rs @@ -29,7 +29,7 @@ use super::{rev_analyse::RevAnalyse, Register, VarLocation}; use crate::compile::jitc_yk::{ - codegen::{abs_stack::AbstractStack, x64::REG64_BYTESIZE}, + codegen::abs_stack::AbstractStack, jit_ir::{Const, ConstIdx, FloatTy, Inst, InstIdx, Module, Operand, PtrAddInst, Ty}, }; use dynasmrt::{ @@ -831,7 +831,7 @@ impl LSRegAlloc<'_> { let inst = self.m.inst(iidx); let bitw = inst.def_bitw(self.m); let bytew = inst.def_byte_size(self.m); - debug_assert!(bitw >= bytew); + debug_assert!(usize::try_from(bitw).unwrap() >= bytew); self.stack.align(bytew); let frame_off = self.stack.grow(bytew); let off = i32::try_from(frame_off).unwrap(); @@ -921,13 +921,8 @@ impl LSRegAlloc<'_> { fn load_const_into_gp_reg(&mut self, asm: &mut Assembler, cidx: ConstIdx, reg: Rq) { match self.m.const_(cidx) { Const::Float(_tyidx, _x) => todo!(), - Const::Int(tyidx, x) => { - // `unwrap` cannot fail, integers are sized. 
- if self.m.type_(*tyidx).byte_size().unwrap() <= REG64_BYTESIZE { - dynasm!(asm; mov Rq(reg.code()), QWORD *x as i64); - } else { - todo!(); - } + Const::Int(_, x) => { + dynasm!(asm; mov Rq(reg.code()), QWORD x.to_zero_ext_u64().unwrap() as i64); } Const::Ptr(x) => { dynasm!(asm; mov Rq(reg.code()), QWORD *x as i64) @@ -964,12 +959,10 @@ impl LSRegAlloc<'_> { Inst::Copy(_) => panic!(), Inst::Const(cidx) => match self.m.const_(cidx) { Const::Float(_, v) => VarLocation::ConstFloat(*v), - Const::Int(tyidx, v) => { - let Ty::Integer(bits) = self.m.type_(*tyidx) else { - panic!() - }; - VarLocation::ConstInt { bits: *bits, v: *v } - } + Const::Int(_, x) => VarLocation::ConstInt { + bits: x.bitw(), + v: x.to_zero_ext_u64().unwrap(), + }, Const::Ptr(p) => VarLocation::ConstInt { bits: 64, v: u64::try_from(*p).unwrap(), @@ -1388,7 +1381,7 @@ impl LSRegAlloc<'_> { let inst = self.m.inst(iidx); let bitw = inst.def_bitw(self.m); let bytew = inst.def_byte_size(self.m); - debug_assert!(bitw >= bytew); + debug_assert!(usize::try_from(bitw).unwrap() >= bytew); self.stack.align(bytew); let frame_off = self.stack.grow(bytew); let off = i32::try_from(frame_off).unwrap(); diff --git a/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs b/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs index 3b234535e..80dc18d77 100644 --- a/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs +++ b/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs @@ -16,7 +16,6 @@ use super::{ super::{ - int_signs::{SignExtend, Truncate}, jit_ir::{self, BinOp, FloatTy, Inst, InstIdx, Module, Operand, TraceKind, Ty}, CompilationError, }, @@ -127,7 +126,7 @@ static JITFUNC_LIVEVARS_ARGIDX: usize = 0; /// The size of a 64-bit register in bytes. pub(crate) static REG64_BYTESIZE: usize = 8; -static REG64_BITSIZE: usize = REG64_BYTESIZE * 8; +static REG64_BITSIZE: u32 = 64; static RBP_DWARF_NUM: u16 = 6; /// The x64 SysV ABI requires a 16-byte aligned stack prior to any call. 
@@ -1399,25 +1398,34 @@ impl<'a> Assemble<'a> { match self.m.type_(val.tyidx(self.m)) { Ty::Integer(_) | Ty::Ptr => { let bitw = val.bitw(self.m); - if let Some(imm) = self.op_to_zero_ext_immediate(&val) { + if bitw == 8 + && let Some(v) = self.op_to_zero_ext_i8(&val) + { let [tgt_reg] = self.ra .assign_gp_regs(&mut self.asm, iidx, [RegConstraint::Input(tgt_op)]); - match (imm, bitw) { - (Immediate::I8(v), 8) => { - dynasm!(self.asm ; mov BYTE [Rq(tgt_reg.code()) + off], v) - } - (Immediate::I16(v), 16) => { - dynasm!(self.asm ; mov WORD [Rq(tgt_reg.code()) + off], v) - } - (Immediate::I32(v), 32) => { - dynasm!(self.asm ; mov DWORD [Rq(tgt_reg.code()) + off], v) - } - (Immediate::I32(v), 64) => { - dynasm!(self.asm ; mov QWORD [Rq(tgt_reg.code()) + off], v) - } - _ => todo!(), - } + dynasm!(self.asm ; mov BYTE [Rq(tgt_reg.code()) + off], v); + } else if bitw == 16 + && let Some(v) = self.op_to_zero_ext_i16(&val) + { + let [tgt_reg] = + self.ra + .assign_gp_regs(&mut self.asm, iidx, [RegConstraint::Input(tgt_op)]); + dynasm!(self.asm ; mov WORD [Rq(tgt_reg.code()) + off], v); + } else if bitw == 32 + && let Some(v) = self.op_to_zero_ext_i32(&val) + { + let [tgt_reg] = + self.ra + .assign_gp_regs(&mut self.asm, iidx, [RegConstraint::Input(tgt_op)]); + dynasm!(self.asm ; mov DWORD [Rq(tgt_reg.code()) + off], v); + } else if bitw == 64 + && let Some(v) = self.op_to_zero_ext_i32(&val) + { + let [tgt_reg] = + self.ra + .assign_gp_regs(&mut self.asm, iidx, [RegConstraint::Input(tgt_op)]); + dynasm!(self.asm ; mov QWORD [Rq(tgt_reg.code()) + off], v); } else { let [tgt_reg, val_reg] = self.ra.assign_gp_regs( &mut self.asm, @@ -1679,15 +1687,10 @@ impl<'a> Assemble<'a> { Operand::Var(iidx) => self.ra.var_location(iidx), Operand::Const(cidx) => match self.m.const_(cidx) { Const::Float(_, v) => VarLocation::ConstFloat(*v), - Const::Int(tyidx, v) => { - let Ty::Integer(bit_size) = self.m.type_(*tyidx) else { - panic!() - }; - VarLocation::ConstInt { - bits: *bit_size, - v: *v, - } - } + Const::Int(_, x) => VarLocation::ConstInt { + bits: x.bitw(), + v: x.to_zero_ext_u64().unwrap(), + }, Const::Ptr(v) => VarLocation::ConstPtr(*v), }, } @@ -1697,71 +1700,41 @@ impl<'a> Assemble<'a> { /// it, otherwise return `None`. fn op_to_sign_ext_i8(&self, op: &Operand) -> Option { if let Operand::Const(cidx) = op { - if let Const::Int(tyidx, v) = self.m.const_(*cidx) { - let Ty::Integer(bit_size) = self.m.type_(*tyidx) else { - panic!() - }; - if *bit_size <= 8 { - return Some(v.sign_extend(*bit_size, 8) as i8); - } else if v.truncate(8).sign_extend(8, 64) == *v { - return Some(v.truncate(8) as i8); - } + if let Const::Int(_, x) = self.m.const_(*cidx) { + return x.to_sign_ext_i8(); } } None } - /// If an `Operand` refers to a constant integer that can be represented as an `i8`, return - /// it zero-extended to 8 bits, otherwise return `None`. - fn op_to_zero_ext_i8(&self, op: &Operand) -> Option { + /// If an `Operand` refers to a constant integer that can be represented as an `i32`, return it + /// sign-extended to 32 bits, otherwise return `None`. 
+ fn op_to_sign_ext_i32(&self, op: &Operand) -> Option { if let Operand::Const(cidx) = op { - if let Const::Int(tyidx, v) = self.m.const_(*cidx) { - let Ty::Integer(bit_size) = self.m.type_(*tyidx) else { - panic!() - }; - if *bit_size <= 8 { - debug_assert_eq!(v.truncate(*bit_size), *v); - return Some(*v as i8); - } else if v.truncate(8) == *v { - return Some(v.truncate(8) as i8); - } + if let Const::Int(_, x) = self.m.const_(*cidx) { + return x.to_sign_ext_i32(); } } None } - /// If an `Operand` refers to a constant integer that can be represented as an `i16`, return - /// it zero extended to 16 bits, otherwise return `None`. - fn op_to_zero_ext_i16(&self, op: &Operand) -> Option { + /// If an `Operand` refers to a constant integer that can be represented as an `i8`, return + /// it zero-extended to 8 bits, otherwise return `None`. + fn op_to_zero_ext_i8(&self, op: &Operand) -> Option { if let Operand::Const(cidx) = op { - if let Const::Int(tyidx, v) = self.m.const_(*cidx) { - let Ty::Integer(bit_size) = self.m.type_(*tyidx) else { - panic!() - }; - if *bit_size <= 16 { - debug_assert_eq!(v.truncate(*bit_size), *v); - return Some(*v as i16); - } else if v.truncate(16) == *v { - return Some(v.truncate(16) as i16); - } + if let Const::Int(_, x) = self.m.const_(*cidx) { + return x.to_zero_ext_u8().map(|x| x as i8); } } None } - /// If an `Operand` refers to a constant integer that can be represented as an `i32`, return it - /// sign-extended to 32 bits, otherwise return `None`. - fn op_to_sign_ext_i32(&self, op: &Operand) -> Option { + /// If an `Operand` refers to a constant integer that can be represented as an `i8`, return + /// it zero-extended to 8 bits, otherwise return `None`. + fn op_to_zero_ext_i16(&self, op: &Operand) -> Option { if let Operand::Const(cidx) = op { - if let Const::Int(tyidx, v) = self.m.const_(*cidx) { - let Ty::Integer(bit_size) = self.m.type_(*tyidx) else { - panic!() - }; - if *bit_size <= 32 { - return Some(v.sign_extend(*bit_size, 32) as i32); - } else if v.truncate(32).sign_extend(32, 64) == *v { - return Some(v.truncate(32) as i32); - } + if let Const::Int(_, x) = self.m.const_(*cidx) { + return x.to_zero_ext_u16().map(|x| x as i16); } } None @@ -1771,46 +1744,14 @@ impl<'a> Assemble<'a> { /// zero-extended to 32 bits, otherwise return `None`. fn op_to_zero_ext_i32(&self, op: &Operand) -> Option { if let Operand::Const(cidx) = op { - if let Const::Int(tyidx, v) = self.m.const_(*cidx) { - let Ty::Integer(bit_size) = self.m.type_(*tyidx) else { - panic!() - }; - if *bit_size <= 32 { - debug_assert_eq!(v.truncate(*bit_size), *v); - return Some(*v as i32); - } else if v.truncate(32) == *v { - return Some(v.truncate(32) as i32); - } + if let Const::Int(_, x) = self.m.const_(*cidx) { + return x.to_zero_ext_u32().map(|x| x as i32); } } None } - /// Return a zero-extended [Immediate] if `op` is a constant and is representable as an x64 - /// immediate. Note this embeds the follow assumptions: - /// 1. 1 byte constants map to Immediate::I8. - /// 2. 2 byte constants map to Immediate::I16. - /// 3. 3 byte constants map to Immediate::I32. - /// 4. 4 byte constants map to Immediate::I32. - /// - /// Note that number (4) breaks the pattern of the (1-3)! 
- fn op_to_zero_ext_immediate(&self, op: &Operand) -> Option { - match op { - Operand::Const(cidx) => match self.m.const_(*cidx) { - Const::Float(_, _) => todo!(), - Const::Int(_, _) => match op.byte_size(self.m) { - 1 => self.op_to_zero_ext_i8(op).map(Immediate::I8), - 2 => self.op_to_zero_ext_i16(op).map(Immediate::I16), - 4 | 8 => self.op_to_zero_ext_i32(op).map(Immediate::I32), - _ => todo!(), - }, - Const::Ptr(_) => self.op_to_zero_ext_i32(op).map(Immediate::I32), - }, - Operand::Var(_) => None, - } - } - - fn cg_cmp_const(&mut self, bit_size: usize, pred: jit_ir::Predicate, lhs_reg: Rq, rhs: i32) { + fn cg_cmp_const(&mut self, bit_size: u32, pred: jit_ir::Predicate, lhs_reg: Rq, rhs: i32) { match bit_size { 0 => unreachable!(), 32 => dynasm!(self.asm; cmp Rd(lhs_reg.code()), rhs), @@ -1826,7 +1767,7 @@ impl<'a> Assemble<'a> { } } - fn cg_cmp_regs(&mut self, bit_size: usize, pred: jit_ir::Predicate, lhs_reg: Rq, rhs_reg: Rq) { + fn cg_cmp_regs(&mut self, bit_size: u32, pred: jit_ir::Predicate, lhs_reg: Rq, rhs_reg: Rq) { match bit_size { 0 => unreachable!(), 32 => dynasm!(self.asm; cmp Rd(lhs_reg.code()), Rd(rhs_reg.code())), @@ -2110,6 +2051,13 @@ impl<'a> Assemble<'a> { mov BYTE [rbp - i32::try_from(off_dst).unwrap()], v as i8), x => todo!("{x}"), }, + VarLocation::ConstPtr(v) => { + dynasm!(self.asm + ; push rax + ; mov rax, QWORD v as i64 + ; mov QWORD [rbp - i32::try_from(off_dst).unwrap()], rax + ; pop rax); + } VarLocation::Stack { frame_off: off_src, size: size_src, @@ -2326,7 +2274,7 @@ impl<'a> Assemble<'a> { unreachable!(); // must be an integer }; - if *dest_bitsize <= u32::try_from(REG64_BITSIZE).unwrap() { + if *dest_bitsize <= REG64_BITSIZE { if *src_bitsize == 64 { // The 64 bit registers are implicitly sign extended. self.ra @@ -2529,12 +2477,13 @@ impl<'a> Assemble<'a> { // it doesn't have an allocation. We can just push the actual value instead // which will be written as is during deoptimisation. match self.m.const_(x) { - Const::Int(tyidx, c) => { - let Ty::Integer(bits) = self.m.type_(*tyidx) else { - panic!() - }; - lives.push((iid.clone(), VarLocation::ConstInt { bits: *bits, v: *c })) - } + Const::Int(_, y) => lives.push(( + iid.clone(), + VarLocation::ConstInt { + bits: y.bitw(), + v: y.to_zero_ext_u64().unwrap(), + }, + )), Const::Ptr(p) => lives.push(( iid.clone(), VarLocation::ConstInt { @@ -2781,16 +2730,6 @@ impl<'a> AsmPrinter<'a> { } } -/// A representation of an x64 immediate, suitable for use in x64 instructions. -/// -/// Note that the integer values inside may be zero or sign-extended depending on the construction -/// of an instance of this enum. -enum Immediate { - I8(i8), - I16(i16), - I32(i32), -} - /// x64 tests. These use an unusual form of pattern matching. Instead of using concrete register /// names, one can refer to a class of registers e.g. `r.8` is all 8-bit registers. To match a /// register name but ignore its value uses `r.8._`: to match a register name use `r.8.x`. 
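For readers unfamiliar with the `int_signs` helpers used throughout (they are not part of this patch), the following standalone sketch shows plausible truncate/sign-extend semantics consistent with the `ArbBitInt` documentation above. The function names and exact implementation here are assumptions for illustration, not the real `int_signs` module.

// Standalone sketch only: plausible semantics for the truncate/sign-extend
// operations that `ArbBitInt` builds on.

/// Keep only the low `bitw` bits of `v`.
fn truncate(v: u64, bitw: u32) -> u64 {
    if bitw >= 64 {
        v
    } else {
        v & ((1u64 << bitw) - 1)
    }
}

/// Sign extend the low `from` bits of `v` to a `to`-bit value.
fn sign_extend(v: u64, from: u32, to: u32) -> u64 {
    assert!(from >= 1 && from <= to && to <= 64);
    // Move the value's sign bit up to bit 63, arithmetic-shift it back down so
    // it is replicated, then keep only the low `to` bits.
    let shift = 64 - from;
    truncate((((v << shift) as i64) >> shift) as u64, to)
}

fn main() {
    // The raw bits 0xFF interpreted as an 8-bit integer...
    let v = 0xFFu64;
    // ...zero extend to 255, but sign extend to -1.
    assert_eq!(truncate(v, 8), 255);
    assert_eq!(sign_extend(v, 8, 64) as i64, -1);
    // This is why e.g. `op_to_sign_ext_i8` and `op_to_zero_ext_i8` above can
    // legitimately give different answers for the same constant.
    println!("ok");
}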
diff --git a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs index 86936323e..cebb48dba 100644 --- a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs +++ b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs @@ -97,9 +97,7 @@ mod parser; mod well_formed; use super::aot_ir; -#[cfg(debug_assertions)] -use super::int_signs::Truncate; -use crate::compile::{CompilationError, SideTraceInfo}; +use crate::compile::{jitc_yk::arbbitint::ArbBitInt, CompilationError, SideTraceInfo}; use indexmap::IndexSet; use std::{ ffi::{c_void, CString}, @@ -266,13 +264,19 @@ impl Module { let mut consts = IndexSet::new(); let true_constidx = ConstIdx::try_from( consts - .insert_full(ConstIndexSetWrapper(Const::Int(int1_tyidx, 1))) + .insert_full(ConstIndexSetWrapper(Const::Int( + int1_tyidx, + ArbBitInt::from_u64(1, 1), + ))) .0, ) .unwrap(); let false_constidx = ConstIdx::try_from( consts - .insert_full(ConstIndexSetWrapper(Const::Int(int1_tyidx, 0))) + .insert_full(ConstIndexSetWrapper(Const::Int( + int1_tyidx, + ArbBitInt::from_u64(1, 0), + ))) .0, ) .unwrap(); @@ -525,23 +529,6 @@ impl Module { ConstIdx::try_from(i) } - /// Convenience method for adding a `Const::Int` to the constant pool and, in debug mode, - /// checking that its bit size is not execeeded. See [Self::insert_const] for the return value. - pub(crate) fn insert_const_int( - &mut self, - tyidx: TyIdx, - v: u64, - ) -> Result { - #[cfg(debug_assertions)] - { - let Ty::Integer(bits) = self.type_(tyidx) else { - panic!() - }; - assert_eq!(v.truncate(*bits), v); - } - self.insert_const(Const::Int(tyidx, v)) - } - /// Return the const for the specified index. /// /// # Panics @@ -1023,19 +1010,19 @@ impl Ty { } /// Returns the size of the type in bits, or `None` if asking the size makes no sense. - pub(crate) fn bitw(&self) -> Option { + pub(crate) fn bitw(&self) -> Option { match self { - Self::Void => Some(0), - Self::Integer(bits) => Some(usize::try_from(*bits).unwrap()), + Self::Void => None, + Self::Integer(bitw) => Some(*bitw), Self::Ptr => { // We make the same assumptions about pointer size as in Self::byte_size(). - Some(mem::size_of::<*const c_void>() * 8) + u32::try_from(mem::size_of::<*const c_void>() * 8).ok() } Self::Func(_) => None, - Self::Float(ft) => Some(match ft { - FloatTy::Float => mem::size_of::() * 8, - FloatTy::Double => mem::size_of::() * 8, - }), + Self::Float(ft) => match ft { + FloatTy::Float => u32::try_from(mem::size_of::() * 8).ok(), + FloatTy::Double => u32::try_from(mem::size_of::() * 8).ok(), + }, Self::Unimplemented(_) => None, } } @@ -1188,7 +1175,7 @@ impl Operand { /// # Panics /// /// Panics if asking for the size make no sense for this operand. - pub(crate) fn bitw(&self, m: &Module) -> usize { + pub(crate) fn bitw(&self, m: &Module) -> u32 { match self { Self::Var(l) => m.inst_raw(*l).def_bitw(m), Self::Const(cidx) => m.type_(m.const_(*cidx).tyidx(m)).bitw().unwrap(), @@ -1245,13 +1232,11 @@ impl fmt::Display for DisplayableOperand<'_> { /// Note that this struct deliberately does not implement `PartialEq` (or `Eq`): two instances of /// `Const` may represent the same underlying constant, but (because of floats), you as the user /// need to determine what notion of equality you wish to use on a given const. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) enum Const { Float(TyIdx, f64), - /// A constant integer at most 64 bits wide. 
This can be treated a signed or unsigned integer - /// depending on the operations that use this constant (the [Ty::Integer] type itself has no - /// concept of signedness). - Int(TyIdx, u64), + /// A constant integer. Note that, as in LLVM IR, this has no inherent signedness. + Int(TyIdx, ArbBitInt), Ptr(usize), } @@ -1282,11 +1267,8 @@ impl fmt::Display for DisplayableConst<'_> { Ty::Float(FloatTy::Double) => write!(f, "{}double", v), _ => unreachable!(), }, - Const::Int(tyidx, x) => { - let Ty::Integer(width) = self.m.type_(*tyidx) else { - panic!() - }; - write!(f, "{x}i{width}") + Const::Int(_, x) => { + write!(f, "{}i{}", x.to_zero_ext_u64().unwrap(), x.bitw()) } Const::Ptr(x) => write!(f, "{:#x}", *x), } @@ -1308,8 +1290,8 @@ impl PartialEq for ConstIndexSetWrapper { // acceptable. lhs_tyidx == rhs_tyidx && lhs_v.to_bits() == rhs_v.to_bits() } - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - lhs_tyidx == rhs_tyidx && lhs_v == rhs_v + (Const::Int(lhs_tyidx, lhs), Const::Int(rhs_tyidx, rhs)) => { + lhs_tyidx == rhs_tyidx && lhs == rhs } (Const::Ptr(lhs_v), Const::Ptr(rhs_v)) => lhs_v == rhs_v, (_, _) => false, @@ -1321,16 +1303,16 @@ impl Eq for ConstIndexSetWrapper {} impl Hash for ConstIndexSetWrapper { fn hash(&self, state: &mut H) { - match self.0 { + match &self.0 { Const::Float(tyidx, v) => { tyidx.hash(state); // We treat floats as bit patterns: because we can accept duplicates, this is // acceptable. v.to_bits().hash(state); } - Const::Int(tyidx, v) => { + Const::Int(tyidx, x) => { tyidx.hash(state); - v.hash(state); + x.hash(state); } Const::Ptr(v) => v.hash(state), } @@ -1902,7 +1884,7 @@ impl Inst { /// Panics if: /// - The instruction defines no local variable. /// - The instruction defines an unsized local variable. 
- pub(crate) fn def_bitw(&self, m: &Module) -> usize { + pub(crate) fn def_bitw(&self, m: &Module) -> u32 { if let Some(ty) = self.def_type(m) { if let Some(size) = ty.bitw() { size @@ -3294,13 +3276,27 @@ mod tests { #[test] fn stringify_int_consts() { let mut m = Module::new_testing(); - let i8_tyidx = m.insert_ty(Ty::Integer(8)).unwrap(); - assert_eq!(Const::Int(i8_tyidx, 0).display(&m).to_string(), "0i8"); - assert_eq!(Const::Int(i8_tyidx, 255).display(&m).to_string(), "255i8"); let i64_tyidx = m.insert_ty(Ty::Integer(64)).unwrap(); - assert_eq!(Const::Int(i64_tyidx, 0).display(&m).to_string(), "0i64"); assert_eq!( - Const::Int(i64_tyidx, 9223372036854775808) + Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 0)) + .display(&m) + .to_string(), + "0i8" + ); + assert_eq!( + Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 255)) + .display(&m) + .to_string(), + "255i8" + ); + assert_eq!( + Const::Int(i64_tyidx, ArbBitInt::from_u64(64, 0)) + .display(&m) + .to_string(), + "0i64" + ); + assert_eq!( + Const::Int(i64_tyidx, ArbBitInt::from_u64(64, 9223372036854775808)) .display(&m) .to_string(), "9223372036854775808i64" diff --git a/ykrt/src/compile/jitc_yk/jit_ir/parser.rs b/ykrt/src/compile/jitc_yk/jit_ir/parser.rs index ae78fec27..a5eeca1f3 100644 --- a/ykrt/src/compile/jitc_yk/jit_ir/parser.rs +++ b/ykrt/src/compile/jitc_yk/jit_ir/parser.rs @@ -8,6 +8,7 @@ use crate::compile::jitc_yk::aot_ir; use super::super::{ aot_ir::{BinOp, FloatPredicate, InstID, Predicate}, + arbbitint::ArbBitInt, jit_ir::{ BinOpInst, BitCastInst, BlackBoxInst, Const, DirectCallInst, DynPtrAddInst, FCmpInst, FNegInst, FPExtInst, FPToSIInst, FloatTy, FuncDecl, FuncTy, GuardInfo, GuardInst, ICmpInst, @@ -569,34 +570,34 @@ impl<'lexer, 'input: 'lexer> JITIRParser<'lexer, 'input, '_> { ASTOperand::ConstInt(span) => { let s = self.lexer.span_str(span); let [val, width] = <[&str; 2]>::try_from(s.split('i').collect::>()).unwrap(); - let width = width + let bitw = width .parse::() .map_err(|e| self.error_at_span(span, &e.to_string()))?; let val = if val.starts_with("-") { let val = val .parse::() .map_err(|e| self.error_at_span(span, &e.to_string()))?; - if width < 64 - && (val < -((1 << width) - 1) / 2 - 1 || val >= ((1 << width) - 1) / 2) + if bitw < 64 + && (val < -((1 << bitw) - 1) / 2 - 1 || val >= ((1 << bitw) - 1) / 2) { return Err(self.error_at_span(span, - &format!("Signed constant {val} exceeds the bit width {width} of the integer type"))); + &format!("Signed constant {val} exceeds the bit width {bitw} of the integer type"))); } val as u64 } else { let val = val .parse::() .map_err(|e| self.error_at_span(span, &e.to_string()))?; - if width < 64 && val > (1 << width) - 1 { + if bitw < 64 && val > (1 << bitw) - 1 { return Err(self.error_at_span(span, - &format!("Unsigned constant {val} exceeds the bit width {width} of the integer type"))); + &format!("Unsigned constant {val} exceeds the bit width {bitw} of the integer type"))); } val }; - let tyidx = self.m.insert_ty(Ty::Integer(width)).unwrap(); + let tyidx = self.m.insert_ty(Ty::Integer(bitw)).unwrap(); Ok(Operand::Const( self.m - .insert_const(Const::Int(tyidx, val)) + .insert_const(Const::Int(tyidx, ArbBitInt::from_u64(bitw, val))) .map_err(|e| self.error_at_span(span, &e.to_string()))?, )) } diff --git a/ykrt/src/compile/jitc_yk/jit_ir/well_formed.rs b/ykrt/src/compile/jitc_yk/jit_ir/well_formed.rs index d8ff89c37..3edff277a 100644 --- a/ykrt/src/compile/jitc_yk/jit_ir/well_formed.rs +++ b/ykrt/src/compile/jitc_yk/jit_ir/well_formed.rs @@ -134,11 +134,14 @@ 
impl Module { self.inst(iidx).display(self, iidx) ) }; - if let Operand::Const(x) = cond { - let Const::Int(_, v) = self.const_(x) else { + if let Operand::Const(cidx) = cond { + let Const::Int(_, x) = self.const_(cidx) else { unreachable!() }; - if (expect && *v == 0) || (!expect && *v == 1) { + debug_assert_eq!(x.bitw(), 1); + if (expect && x.to_zero_ext_u8().unwrap() == 0) + || (!expect && x.to_zero_ext_u8().unwrap() == 1) + { panic!( "Guard at position {iidx} references a constant that is at odds with the guard itself\n {}", inst.display(self, iidx) @@ -316,7 +319,10 @@ impl Module { #[cfg(test)] mod tests { - use super::{super::TraceKind, BinOp, BinOpInst, Const, Inst, Module, Operand}; + use super::{ + super::{ArbBitInt, TraceKind}, + BinOp, BinOpInst, Const, Inst, Module, Operand, + }; #[should_panic(expected = "Instruction at position 0 passing too few arguments")] #[test] @@ -388,8 +394,12 @@ mod tests { // The parser will reject a binop with a result type different from either operand, so to // get the test we want, we can't use the parser. let mut m = Module::new(TraceKind::HeaderOnly, 0, 0).unwrap(); - let c1 = m.insert_const(Const::Int(m.int1_tyidx(), 0)).unwrap(); - let c2 = m.insert_const(Const::Int(m.int8_tyidx(), 0)).unwrap(); + let c1 = m + .insert_const(Const::Int(m.int1_tyidx(), ArbBitInt::from_u64(1, 0))) + .unwrap(); + let c2 = m + .insert_const(Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 0))) + .unwrap(); m.push(Inst::BinOp(BinOpInst::new( Operand::Const(c1), BinOp::Add, diff --git a/ykrt/src/compile/jitc_yk/mod.rs b/ykrt/src/compile/jitc_yk/mod.rs index a1e98fc72..33f6ab07a 100644 --- a/ykrt/src/compile/jitc_yk/mod.rs +++ b/ykrt/src/compile/jitc_yk/mod.rs @@ -21,6 +21,7 @@ use ykaddr::addr::symbol_to_ptr; use yksmp::Location; pub mod aot_ir; +mod arbbitint; mod codegen; #[cfg(any(debug_assertions, test))] mod gdb; diff --git a/ykrt/src/compile/jitc_yk/opt/analyse.rs b/ykrt/src/compile/jitc_yk/opt/analyse.rs index 5d78dbe44..19ab2d5bc 100644 --- a/ykrt/src/compile/jitc_yk/opt/analyse.rs +++ b/ykrt/src/compile/jitc_yk/opt/analyse.rs @@ -56,16 +56,12 @@ impl Analyse { // that allows us to turn it into a constant. if let Inst::ICmp(inst) = m.inst(iidx) { let lhs = self.op_map(m, inst.lhs(m)); - let pred = inst.predicate(); let rhs = self.op_map(m, inst.rhs(m)); - if let (&Operand::Const(lhs_cidx), &Operand::Const(rhs_cidx)) = (&lhs, &rhs) + if let (&Operand::Const(_lhs_cidx), &Operand::Const(_rhs_cidx)) = + (&lhs, &rhs) { - if pred == Predicate::Equal && m.const_(lhs_cidx) == m.const_(rhs_cidx) - { - self.values.borrow_mut()[usize::from(iidx)] = - Value::Const(m.true_constidx()); - return Operand::Const(m.true_constidx()); - } + // Can we still hit this case? 
+ todo!(); } } op diff --git a/ykrt/src/compile/jitc_yk/opt/heapvalues.rs b/ykrt/src/compile/jitc_yk/opt/heapvalues.rs index f663ef9c2..c453052ba 100644 --- a/ykrt/src/compile/jitc_yk/opt/heapvalues.rs +++ b/ykrt/src/compile/jitc_yk/opt/heapvalues.rs @@ -117,7 +117,10 @@ impl HeapValues { #[cfg(test)] mod test { use super::*; - use crate::compile::jitc_yk::jit_ir::{Const, Ty}; + use crate::compile::jitc_yk::{ + arbbitint::ArbBitInt, + jit_ir::{Const, Ty}, + }; #[test] fn basic() { @@ -137,14 +140,18 @@ mod test { // Add a single load let addr0 = Address::from_operand(&m, Operand::Var(InstIdx::unchecked_from(0))); assert!(hv.get(&m, addr0.clone(), 1).is_none()); - let cidx0 = m.insert_const(Const::Int(m.int8_tyidx(), 0)).unwrap(); + let cidx0 = m + .insert_const(Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 0))) + .unwrap(); hv.load(&m, addr0.clone(), Operand::Const(cidx0)); assert_eq!(hv.hv.len(), 1); assert_eq!(hv.get(&m, addr0.clone(), 1), Some(Operand::Const(cidx0))); // Add a non-overlapping load let addr1 = Address::from_operand(&m, Operand::Var(InstIdx::unchecked_from(1))); - let cidx1 = m.insert_const(Const::Int(m.int8_tyidx(), 1)).unwrap(); + let cidx1 = m + .insert_const(Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 1))) + .unwrap(); hv.load(&m, addr1.clone(), Operand::Const(cidx1)); assert_eq!(hv.hv.len(), 2); assert_eq!(hv.get(&m, addr0.clone(), 1), Some(Operand::Const(cidx0))); @@ -156,7 +163,9 @@ mod test { assert!(hv.get(&m, addr2.clone(), 2).is_none()); // Add a store that replaces our knowledge of the second load but preserves the first. - let cidx2 = m.insert_const(Const::Int(m.int8_tyidx(), 2)).unwrap(); + let cidx2 = m + .insert_const(Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 2))) + .unwrap(); hv.store(&m, addr2.clone(), Operand::Const(cidx2)); assert_eq!(hv.hv.len(), 2); assert_eq!(hv.get(&m, addr0.clone(), 1), Some(Operand::Const(cidx2))); @@ -164,7 +173,9 @@ mod test { // Add an overlapping i64 store which should remove information about both preceding loads. let int64_tyidx = m.insert_ty(Ty::Integer(64)).unwrap(); - let cidx3 = m.insert_const(Const::Int(int64_tyidx, 3)).unwrap(); + let cidx3 = m + .insert_const(Const::Int(int64_tyidx, ArbBitInt::from_u64(64, 3))) + .unwrap(); hv.store(&m, addr2.clone(), Operand::Const(cidx3)); assert_eq!(hv.hv.len(), 1); assert_eq!(hv.get(&m, addr0.clone(), 8), Some(Operand::Const(cidx3))); @@ -172,7 +183,9 @@ mod test { assert!(hv.get(&m, addr1.clone(), 1).is_none()); // Add an overlapping i8 store which should remove information about the i64 load. - let cidx4 = m.insert_const(Const::Int(m.int8_tyidx(), 4)).unwrap(); + let cidx4 = m + .insert_const(Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 4))) + .unwrap(); hv.store(&m, addr1.clone(), Operand::Const(cidx4)); assert_eq!(hv.hv.len(), 1); assert_eq!(hv.get(&m, addr1.clone(), 1), Some(Operand::Const(cidx4))); @@ -181,7 +194,9 @@ mod test { // Add a store which we can't prove doesn't alias. 
let addr4 = Address::from_operand(&m, Operand::Var(InstIdx::unchecked_from(4))); - let cidx5 = m.insert_const(Const::Int(m.int8_tyidx(), 5)).unwrap(); + let cidx5 = m + .insert_const(Const::Int(m.int8_tyidx(), ArbBitInt::from_u64(8, 5))) + .unwrap(); hv.store(&m, addr4.clone(), Operand::Const(cidx5)); assert_eq!(hv.hv.len(), 1); assert_eq!(hv.get(&m, addr4.clone(), 1), Some(Operand::Const(cidx5))); diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index 989ae2be9..efccc12f7 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -7,7 +7,7 @@ // in this module; the refinement of values in the [Analyse] module. use super::{ - int_signs::{SignExtend, Truncate}, + arbbitint::ArbBitInt, jit_ir::{ BinOp, BinOpInst, Const, ConstIdx, ICmpInst, Inst, InstIdx, Module, Operand, Predicate, PtrAddInst, TraceKind, Ty, @@ -177,7 +177,7 @@ impl Opt { (Operand::Const(op_cidx), Operand::Var(op_iidx)) | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { match self.m.const_(op_cidx) { - Const::Int(_, 0) => { + Const::Int(_, x) if x.to_zero_ext_u64().unwrap() == 0 => { // Replace `x + 0` with `x`. self.m.replace(iidx, Inst::Copy(op_iidx)); } @@ -197,15 +197,13 @@ impl Opt { } (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( + (Const::Int(lhs_tyidx, lhs), Const::Int(_rhs_tyidx, rhs)) => { + debug_assert_eq!(lhs.bitw(), rhs.bitw()); + debug_assert_eq!(lhs_tyidx, _rhs_tyidx); + let cidx = self.m.insert_const(Const::Int( *lhs_tyidx, - (lhs_v.wrapping_add(*rhs_v)).truncate(*bits), - )?; + lhs.wrapping_add(rhs), + ))?; self.m.replace(iidx, Inst::Const(cidx)); } _ => todo!(), @@ -220,7 +218,7 @@ impl Opt { (Operand::Const(op_cidx), Operand::Var(op_iidx)) | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { match self.m.const_(op_cidx) { - Const::Int(_, 0) => { + Const::Int(_, x) if x.to_zero_ext_u64().unwrap() == 0 => { // Replace `x & 0` with `0`. self.m.replace(iidx, Inst::Const(op_cidx)); } @@ -240,15 +238,11 @@ impl Opt { } (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + (Const::Int(lhs_tyidx, lhs), Const::Int(rhs_tyidx, rhs)) => { debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v & rhs_v).truncate(*bits), - )?; + let cidx = self + .m + .insert_const(Const::Int(*lhs_tyidx, lhs.bitand(rhs)))?; self.m.replace(iidx, Inst::Const(cidx)); } _ => panic!(), @@ -261,29 +255,36 @@ impl Opt { self.an.op_map(&self.m, x.rhs(&self.m)), ) { (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - if let Const::Int(_, 0) = self.m.const_(op_cidx) { - // Replace `x >> 0` with `x`. - self.m.replace(iidx, Inst::Copy(op_iidx)); + if let Const::Int(_, y) = self.m.const_(op_cidx) { + if y.to_zero_ext_u64().unwrap() == 0 { + // Replace `x >> 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); + } } } (Operand::Const(op_cidx), Operand::Var(_)) => { - if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) { - // Replace `0 >> x` with `0`. 
- let new_cidx = self.m.insert_const_int(*tyidx, 0)?; - self.m.replace(iidx, Inst::Const(new_cidx)); + if let Const::Int(tyidx, y) = self.m.const_(op_cidx) { + if y.to_zero_ext_u64().unwrap() == 0 { + // Replace `0 >> x` with `0`. + let new_cidx = self.m.insert_const(Const::Int( + *tyidx, + ArbBitInt::from_u64(y.bitw(), 0), + ))?; + self.m.replace(iidx, Inst::Const(new_cidx)); + } } } (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v >> rhs_v).truncate(*bits), - )?; + (Const::Int(lhs_tyidx, lhs), Const::Int(_rhs_tyidx, rhs)) => { + debug_assert_eq!(lhs.bitw(), rhs.bitw()); + debug_assert_eq!(lhs_tyidx, _rhs_tyidx); + // If checked_shr fails, we've encountered LLVM poison and can + // choose any value. + let shr = lhs + .checked_shr(rhs.to_zero_ext_u32().unwrap()) + .unwrap_or_else(|| ArbBitInt::all_bits_set(lhs.bitw())); + let cidx = self.m.insert_const(Const::Int(*lhs_tyidx, shr))?; self.m.replace(iidx, Inst::Const(cidx)); } _ => panic!(), @@ -298,20 +299,22 @@ impl Opt { (Operand::Const(op_cidx), Operand::Var(op_iidx)) | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { match self.m.const_(op_cidx) { - Const::Int(_, 0) => { + Const::Int(_, y) if y.to_zero_ext_u64().unwrap() == 0 => { // Replace `x * 0` with `0`. self.m.replace(iidx, Inst::Const(op_cidx)); } - Const::Int(_, 1) => { + Const::Int(_, y) if y.to_zero_ext_u64().unwrap() == 1 => { // Replace `x * 1` with `x`. self.m.replace(iidx, Inst::Copy(op_iidx)); } - Const::Int(ty_idx, x) if x.is_power_of_two() => { + Const::Int(tyidx, y) + if y.to_zero_ext_u64().unwrap().is_power_of_two() => + { // Replace `x * y` with `x << ...`. - let shl = u64::from(x.ilog2()); - let shl_op = Operand::Const( - self.m.insert_const(Const::Int(*ty_idx, shl))?, - ); + let shl = u64::from(y.to_zero_ext_u64().unwrap().ilog2()); + let shl_op = Operand::Const(self.m.insert_const( + Const::Int(*tyidx, ArbBitInt::from_u64(y.bitw(), shl)), + )?); let new_inst = BinOpInst::new(Operand::Var(op_iidx), BinOp::Shl, shl_op) .into(); @@ -333,15 +336,13 @@ impl Opt { } (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( + (Const::Int(lhs_tyidx, lhs), Const::Int(_rhs_tyidx, rhs)) => { + debug_assert_eq!(lhs_tyidx, _rhs_tyidx); + debug_assert_eq!(lhs.bitw(), rhs.bitw()); + let cidx = self.m.insert_const(Const::Int( *lhs_tyidx, - (lhs_v.wrapping_mul(*rhs_v)).truncate(*bits), - )?; + lhs.wrapping_mul(rhs), + ))?; self.m.replace(iidx, Inst::Const(cidx)); } _ => todo!(), @@ -356,7 +357,7 @@ impl Opt { (Operand::Const(op_cidx), Operand::Var(op_iidx)) | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { match self.m.const_(op_cidx) { - Const::Int(_, 0) => { + Const::Int(_, y) if y.to_zero_ext_u64().unwrap() == 0 => { // Replace `x | 0` with `x`. 
                                 self.m.replace(iidx, Inst::Copy(op_iidx));
                             }
@@ -376,15 +377,12 @@ impl Opt {
                     }
                     (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => {
                         match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) {
-                            (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => {
+                            (Const::Int(lhs_tyidx, lhs), Const::Int(rhs_tyidx, rhs)) => {
                                 debug_assert_eq!(lhs_tyidx, rhs_tyidx);
-                                let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else {
-                                    panic!()
-                                };
-                                let cidx = self.m.insert_const_int(
-                                    *lhs_tyidx,
-                                    (lhs_v | rhs_v).truncate(*bits),
-                                )?;
+                                debug_assert_eq!(lhs.bitw(), rhs.bitw());
+                                let cidx = self
+                                    .m
+                                    .insert_const(Const::Int(*lhs_tyidx, lhs.bitor(rhs)))?;
                                 self.m.replace(iidx, Inst::Const(cidx));
                             }
                             _ => panic!(),
@@ -397,38 +395,36 @@ impl Opt {
                     self.an.op_map(&self.m, x.rhs(&self.m)),
                 ) {
                     (Operand::Var(op_iidx), Operand::Const(op_cidx)) => {
-                        if let Const::Int(_, 0) = self.m.const_(op_cidx) {
-                            // Replace `x << 0` with `x`.
-                            self.m.replace(iidx, Inst::Copy(op_iidx));
+                        if let Const::Int(_, y) = self.m.const_(op_cidx) {
+                            if y.to_zero_ext_u64().unwrap() == 0 {
+                                // Replace `x << 0` with `x`.
+                                self.m.replace(iidx, Inst::Copy(op_iidx));
+                            }
                         }
                     }
                     (Operand::Const(op_cidx), Operand::Var(_)) => {
-                        if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) {
-                            // Replace `0 << x` with `0`.
-                            let new_cidx = self.m.insert_const_int(*tyidx, 0)?;
-                            self.m.replace(iidx, Inst::Const(new_cidx));
+                        if let Const::Int(tyidx, y) = self.m.const_(op_cidx) {
+                            if y.to_zero_ext_u64().unwrap() == 0 {
+                                // Replace `0 << x` with `0`.
+                                let new_cidx = self.m.insert_const(Const::Int(
+                                    *tyidx,
+                                    ArbBitInt::from_u64(y.bitw(), 0),
+                                ))?;
+                                self.m.replace(iidx, Inst::Const(new_cidx));
+                            }
                         }
                     }
                     (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => {
                         match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) {
-                            (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => {
+                            (Const::Int(lhs_tyidx, lhs), Const::Int(rhs_tyidx, rhs)) => {
                                 debug_assert_eq!(lhs_tyidx, rhs_tyidx);
-                                let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else {
-                                    panic!()
-                                };
-                                // If checked_shl fails, we've encountered LLVM poison: we can
-                                // now choose any value (in this case 0) and know that we're
-                                // respecting LLVM's semantics. In case the user's program then
-                                // has UB and uses the poison value, we make it `int::MAX`
-                                // because there is a small chance that will make the UB more
-                                // obvious to them.
-                                let cidx = self.m.insert_const_int(
-                                    *lhs_tyidx,
-                                    (lhs_v
-                                        .checked_shl(u32::try_from(*rhs_v).unwrap())
-                                        .unwrap_or(u64::MAX))
-                                    .truncate(*bits),
-                                )?;
+                                debug_assert_eq!(lhs.bitw(), rhs.bitw());
+                                // If checked_shl fails, we've encountered LLVM poison and can
+                                // choose any value.
+                                let shl = lhs
+                                    .checked_shl(rhs.to_zero_ext_u32().unwrap())
+                                    .unwrap_or_else(|| ArbBitInt::all_bits_set(lhs.bitw()));
+                                let cidx = self.m.insert_const(Const::Int(*lhs_tyidx, shl))?;
                                 self.m.replace(iidx, Inst::Const(cidx));
                             }
                             _ => panic!(),
@@ -441,22 +437,22 @@ impl Opt {
                     self.an.op_map(&self.m, x.rhs(&self.m)),
                 ) {
                     (Operand::Var(op_iidx), Operand::Const(op_cidx)) => {
-                        if let Const::Int(_, 0) = self.m.const_(op_cidx) {
-                            // Replace `x - 0` with `x`.
-                            self.m.replace(iidx, Inst::Copy(op_iidx));
+                        if let Const::Int(_, y) = self.m.const_(op_cidx) {
+                            if y.to_zero_ext_u64().unwrap() == 0 {
+                                // Replace `x - 0` with `x`.
+                                self.m.replace(iidx, Inst::Copy(op_iidx));
+                            }
                         }
                     }
                     (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => {
                         match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) {
-                            (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => {
+                            (Const::Int(lhs_tyidx, lhs), Const::Int(rhs_tyidx, rhs)) => {
                                 debug_assert_eq!(lhs_tyidx, rhs_tyidx);
-                                let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else {
-                                    panic!()
-                                };
-                                let cidx = self.m.insert_const_int(
+                                debug_assert_eq!(lhs.bitw(), rhs.bitw());
+                                let cidx = self.m.insert_const(Const::Int(
                                     *lhs_tyidx,
-                                    (lhs_v.wrapping_sub(*rhs_v)).truncate(*bits),
-                                )?;
+                                    lhs.wrapping_sub(rhs),
+                                ))?;
                                 self.m.replace(iidx, Inst::Const(cidx));
                             }
                             _ => todo!(),
@@ -472,7 +468,7 @@ impl Opt {
                     (Operand::Const(op_cidx), Operand::Var(op_iidx))
                     | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => {
                         match self.m.const_(op_cidx) {
-                            Const::Int(_, 0) => {
+                            Const::Int(_, y) if y.to_zero_ext_u64().unwrap() == 0 => {
                                 // Replace `x ^ 0` with `x`.
                                 self.m.replace(iidx, Inst::Copy(op_iidx));
                             }
@@ -492,15 +488,12 @@ impl Opt {
                     }
                     (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => {
                         match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) {
-                            (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => {
+                            (Const::Int(lhs_tyidx, lhs), Const::Int(rhs_tyidx, rhs)) => {
                                 debug_assert_eq!(lhs_tyidx, rhs_tyidx);
-                                let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else {
-                                    panic!()
-                                };
-                                let cidx = self.m.insert_const_int(
-                                    *lhs_tyidx,
-                                    (lhs_v ^ rhs_v).truncate(*bits),
-                                )?;
+                                debug_assert_eq!(lhs.bitw(), rhs.bitw());
+                                let cidx = self
+                                    .m
+                                    .insert_const(Const::Int(*lhs_tyidx, lhs.bitxor(rhs)))?;
                                 self.m.replace(iidx, Inst::Const(cidx));
                             }
                             _ => panic!(),
@@ -525,7 +518,7 @@ impl Opt {
                     };
                     // DynPtrAdd indices are signed, so we have to be careful to interpret the
                     // constant as such.
-                    let v = *v as i64;
+                    let v = v.to_sign_ext_i64().unwrap();
                     // LLVM IR allows `off` to be an `i64` but our IR currently allows only an
                     // `i32`. On that basis, we can hit our limits before the program has
                     // itself hit UB, at which point we can't go any further.
@@ -576,7 +569,8 @@ impl Opt {
                     let Const::Int(_, v) = self.m.const_(cidx) else {
                         panic!()
                     };
-                    let op = match v {
+                    debug_assert_eq!(v.bitw(), 1);
+                    let op = match v.to_zero_ext_u8().unwrap() {
                         0 => sinst.falseval(&self.m),
                         1 => sinst.trueval(&self.m),
                         _ => panic!(),
@@ -591,16 +585,14 @@ impl Opt {
             }
             Inst::SExt(x) => {
                 if let Operand::Const(cidx) = self.an.op_map(&self.m, x.val(&self.m)) {
-                    let Const::Int(src_ty, src_val) = self.m.const_(cidx) else {
+                    let Const::Int(_, src_val) = self.m.const_(cidx) else {
                         unreachable!()
                     };
-                    let src_ty = self.m.type_(*src_ty);
                     let dst_ty = self.m.type_(x.dest_tyidx());
-                    let (Ty::Integer(src_bits), Ty::Integer(dst_bits)) = (src_ty, dst_ty) else {
+                    let Ty::Integer(dst_bits) = dst_ty else {
                         unreachable!()
                     };
-                    let dst_val =
-                        Const::Int(x.dest_tyidx(), src_val.sign_extend(*src_bits, *dst_bits));
+                    let dst_val = Const::Int(x.dest_tyidx(), src_val.sign_extend(*dst_bits));
                     let dst_cidx = self.m.insert_const(dst_val)?;
                     self.m.replace(iidx, Inst::Const(dst_cidx));
                 }
@@ -609,7 +601,15 @@ impl Opt {
                 // FIXME: This feels like it should be handled by trace_builder, but we can't
                 // do so yet because of https://github.com/ykjit/yk/issues/1435.
                 if let yksmp::Location::Constant(v) = self.m.param(x.paramidx()) {
-                    let cidx = self.m.insert_const(Const::Int(x.tyidx(), u64::from(*v)))?;
+                    let Ty::Integer(bitw) = self.m.type_(x.tyidx()) else {
+                        unreachable!()
+                    };
+                    // `Location::Constant` is a u32
+                    assert!(*bitw <= 32);
+                    let cidx = self.m.insert_const(Const::Int(
+                        x.tyidx(),
+                        ArbBitInt::from_u64(*bitw, u64::from(*v)),
+                    ))?;
                     self.an.set_value(&self.m, iidx, Value::Const(cidx));
                 }
             }
@@ -664,7 +664,16 @@ impl Opt {
                     Some(Operand::Var(hv_iidx)) => Operand::Var(hv_iidx) == x.val(&self.m),
                     Some(Operand::Const(hv_cidx)) => match val {
                         Operand::Var(_) => false,
-                        Operand::Const(cidx) => self.m.const_(cidx) == self.m.const_(hv_cidx),
+                        Operand::Const(cidx) => {
+                            match (self.m.const_(cidx), self.m.const_(hv_cidx)) {
+                                (Const::Int(lhs_tyidx, lhs), Const::Int(rhs_tyidx, rhs)) => {
+                                    debug_assert_eq!(lhs_tyidx, rhs_tyidx);
+                                    debug_assert_eq!(lhs.bitw(), rhs.bitw());
+                                    lhs == rhs
+                                }
+                                x => todo!("{x:?}"),
+                            }
+                        }
                     },
                 };
                 if is_dead {
@@ -680,18 +689,14 @@ impl Opt {
                     let Const::Int(_src_ty, src_val) = self.m.const_(cidx) else {
                         unreachable!()
                     };
-                    #[cfg(debug_assertions)]
-                    {
-                        let src_ty = self.m.type_(*_src_ty);
-                        let dst_ty = self.m.type_(x.dest_tyidx());
-                        let (Ty::Integer(src_bits), Ty::Integer(dst_bits)) = (src_ty, dst_ty)
-                        else {
-                            unreachable!()
-                        };
-                        debug_assert!(src_bits <= dst_bits);
-                        debug_assert!(*dst_bits <= 64);
-                    }
-                    let dst_cidx = self.m.insert_const(Const::Int(x.dest_tyidx(), *src_val))?;
+                    let dst_ty = self.m.type_(x.dest_tyidx());
+                    let Ty::Integer(dst_bits) = dst_ty else {
+                        unreachable!()
+                    };
+                    debug_assert!(*dst_bits <= 64);
+                    let dst_cidx = self
+                        .m
+                        .insert_const(Const::Int(x.dest_tyidx(), src_val.zero_extend(*dst_bits)))?;
                     self.m.replace(iidx, Inst::Const(dst_cidx));
                 }
             }
@@ -794,22 +799,40 @@ impl Opt {
             let rhs_c = self.m.const_(rhs);
             match (lhs_c, rhs_c) {
                 (Const::Float(..), Const::Float(..)) => (),
-                (Const::Int(lhs_tyidx, x), Const::Int(rhs_tyidx, y)) => {
-                    debug_assert_eq!(lhs_tyidx, rhs_tyidx);
+                (Const::Int(_, x), Const::Int(_, y)) => {
+                    debug_assert_eq!(x.bitw(), y.bitw());
                     // Constant fold comparisons of simple integers.
-                    let x = *x;
-                    let y = *y;
                     let r = match pred {
-                        Predicate::Equal => x == y,
-                        Predicate::NotEqual => x != y,
-                        Predicate::UnsignedGreater => x > y,
-                        Predicate::UnsignedGreaterEqual => x >= y,
-                        Predicate::UnsignedLess => x < y,
-                        Predicate::UnsignedLessEqual => x <= y,
-                        Predicate::SignedGreater => (x as i64) > (y as i64),
-                        Predicate::SignedGreaterEqual => (x as i64) >= (y as i64),
-                        Predicate::SignedLess => (x as i64) < (y as i64),
-                        Predicate::SignedLessEqual => (x as i64) <= (y as i64),
+                        Predicate::Equal => {
+                            x.to_zero_ext_u64().unwrap() == y.to_zero_ext_u64().unwrap()
+                        }
+                        Predicate::NotEqual => {
+                            x.to_zero_ext_u64().unwrap() != y.to_zero_ext_u64().unwrap()
+                        }
+                        Predicate::UnsignedGreater => {
+                            x.to_zero_ext_u64().unwrap() > y.to_zero_ext_u64().unwrap()
+                        }
+                        Predicate::UnsignedGreaterEqual => {
+                            x.to_zero_ext_u64().unwrap() >= y.to_zero_ext_u64().unwrap()
+                        }
+                        Predicate::UnsignedLess => {
+                            x.to_zero_ext_u64().unwrap() < y.to_zero_ext_u64().unwrap()
+                        }
+                        Predicate::UnsignedLessEqual => {
+                            x.to_zero_ext_u64().unwrap() <= y.to_zero_ext_u64().unwrap()
+                        }
+                        Predicate::SignedGreater => {
+                            (x.to_sign_ext_i64().unwrap()) > (y.to_sign_ext_i64().unwrap())
+                        }
+                        Predicate::SignedGreaterEqual => {
+                            (x.to_sign_ext_i64().unwrap()) >= (y.to_sign_ext_i64().unwrap())
+                        }
+                        Predicate::SignedLess => {
+                            (x.to_sign_ext_i64().unwrap()) < (y.to_sign_ext_i64().unwrap())
+                        }
+                        Predicate::SignedLessEqual => {
+                            (x.to_sign_ext_i64().unwrap()) <= (y.to_sign_ext_i64().unwrap())
+                        }
                     };
 
                     self.m.replace(
diff --git a/ykrt/src/compile/jitc_yk/trace_builder.rs b/ykrt/src/compile/jitc_yk/trace_builder.rs
index c3be83c76..693b74965 100644
--- a/ykrt/src/compile/jitc_yk/trace_builder.rs
+++ b/ykrt/src/compile/jitc_yk/trace_builder.rs
@@ -5,6 +5,7 @@
 use super::aot_ir::{self, BBlockId, BinOp, Module};
 use super::YkSideTraceInfo;
 use super::{
+    arbbitint::ArbBitInt,
     jit_ir::{self, Const, Operand, PackedOperand, ParamIdx, TraceKind},
     AOT_MOD,
 };
@@ -396,7 +397,10 @@ impl TraceBuilder {
                     _ => todo!("{}", x.bitw()),
                 };
                 let jit_tyidx = self.jit_mod.insert_ty(jit_ir::Ty::Integer(x.bitw()))?;
-                Ok(jit_ir::Const::Int(jit_tyidx, v))
+                Ok(jit_ir::Const::Int(
+                    jit_tyidx,
+                    ArbBitInt::from_u64(x.bitw(), v),
+                ))
             }
             aot_ir::Ty::Float(fty) => {
                 let jit_tyidx = self.jit_mod.insert_ty(jit_ir::Ty::Float(fty.clone()))?;
@@ -917,6 +921,7 @@ impl TraceBuilder {
         }
 
         let jit_tyidx = self.handle_type(test_val.type_(self.aot_mod))?;
+        let bitw = self.jit_mod.type_(jit_tyidx).bitw().unwrap();
 
         // Find out which case we traced.
         let guard = match case_dests.iter().position(|&cd| cd == next_bb.bbidx()) {
@@ -926,7 +931,7 @@ impl TraceBuilder {
                 let bb = case_dests[cidx];
 
                 // Build the constant value to guard.
-                let jit_const = jit_ir::Const::Int(jit_tyidx, val);
+                let jit_const = jit_ir::Const::Int(jit_tyidx, ArbBitInt::from_u64(bitw, val));
                 let jit_const_opnd = jit_ir::Operand::Const(self.jit_mod.insert_const(jit_const)?);
 
                 // Perform the comparison.
@@ -957,7 +962,7 @@ impl TraceBuilder {
                 let mut cmps_opnds = Vec::new();
                 for cv in case_values {
                     // Build a constant of the case value.
-                    let jit_const = jit_ir::Const::Int(jit_tyidx, *cv);
+                    let jit_const = jit_ir::Const::Int(jit_tyidx, ArbBitInt::from_u64(bitw, *cv));
                     let jit_const_opnd =
                         jit_ir::Operand::Const(self.jit_mod.insert_const(jit_const)?);
 
@@ -1083,27 +1088,28 @@ impl TraceBuilder {
                 // Insert a guard to ensure the trace only runs if the value we encounter is the
                 // same each time.
-                let ty = self.handle_type(self.aot_mod.type_(*tyidx))?;
+                let tyidx = self.handle_type(self.aot_mod.type_(*tyidx))?;
                 // Create the constant from the runtime value.
-                let c = match self.jit_mod.type_(ty) {
+                let ty = self.jit_mod.type_(tyidx);
+                let c = match ty {
                     jit_ir::Ty::Void => unreachable!(),
-                    jit_ir::Ty::Integer(width_bits) => {
-                        let width_bytes = usize::try_from(*width_bits).unwrap() / 8;
-                        let v = match width_bits {
+                    jit_ir::Ty::Integer(bitw) => {
+                        let bytew = ty.byte_size().unwrap();
+                        let v = match *bitw {
                             64 => u64::from_ne_bytes(
-                                self.promotions[self.promote_idx..self.promote_idx + width_bytes]
+                                self.promotions[self.promote_idx..self.promote_idx + bytew]
                                     .try_into()
                                     .unwrap(),
                             ),
                             32 => u64::from(u32::from_ne_bytes(
-                                self.promotions[self.promote_idx..self.promote_idx + width_bytes]
+                                self.promotions[self.promote_idx..self.promote_idx + bytew]
                                     .try_into()
                                     .unwrap(),
                             )),
                             x => todo!("{x}"),
                         };
-                        self.promote_idx += width_bytes;
-                        Const::Int(ty, v)
+                        self.promote_idx += bytew;
+                        Const::Int(tyidx, ArbBitInt::from_u64(*bitw, v))
                     }
                     jit_ir::Ty::Ptr => todo!(),
                     jit_ir::Ty::Func(_) => todo!(),
diff --git a/ykrt/src/lib.rs b/ykrt/src/lib.rs
index be6a7705c..8e5c4d18a 100644
--- a/ykrt/src/lib.rs
+++ b/ykrt/src/lib.rs
@@ -3,6 +3,7 @@
 #![cfg_attr(test, feature(test))]
 #![feature(assert_matches)]
 #![feature(int_roundings)]
+#![feature(let_chains)]
 #![feature(naked_functions)]
 #![feature(ptr_sub_ptr)]
 #![allow(clippy::type_complexity)]
diff --git a/ykrt/src/log/stats.rs b/ykrt/src/log/stats.rs
index 3350b1ae1..9c249a138 100644
--- a/ykrt/src/log/stats.rs
+++ b/ykrt/src/log/stats.rs
@@ -71,6 +71,7 @@ impl Stats {
     pub fn new() -> Self {
         Self {
             inner: Some(Mutex::new(StatsInner::new("-".to_string()))),
+            #[cfg(feature = "yk_testing")]
             wait_until_condvar: None,
         }
     }