Skip to content

Commit

Permalink
Added WidePow2 trait.
Browse files Browse the repository at this point in the history
Allowing optimized implementaion of WideMul squaring.

commit-id:77870a43
  • Loading branch information
orizi committed Aug 12, 2024
1 parent 9f88e35 commit 4d24ee3
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 3 deletions.
2 changes: 1 addition & 1 deletion corelib/src/integer.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ pub fn u256_wide_mul(a: u256, b: u256) -> u512 nopanic {

/// Helper function for implementation of `u256_wide_mul`.
/// Used for adding two u128s and receiving a BoundedInt for the carry result.
fn u128_add_with_bounded_int_carry(
pub(crate) fn u128_add_with_bounded_int_carry(
a: u128, b: u128
) -> (u128, core::internal::bounded_int::BoundedInt<0, 1>) nopanic {
match u128_overflowing_add(a, b) {
Expand Down
1 change: 1 addition & 0 deletions corelib/src/num/traits.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ pub use ops::wrapping::{WrappingAdd, WrappingSub, WrappingMul};
pub use ops::checked::{CheckedAdd, CheckedSub, CheckedMul};
pub use ops::saturating::{SaturatingAdd, SaturatingSub, SaturatingMul};
pub use ops::widemul::WideMul;
pub use ops::widepow2::WidePow2;
pub use ops::sqrt::Sqrt;
1 change: 1 addition & 0 deletions corelib/src/num/traits/ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod checked;
pub mod saturating;
pub(crate) mod sqrt;
pub(crate) mod widemul;
pub(crate) mod widepow2;
61 changes: 61 additions & 0 deletions corelib/src/num/traits/ops/widepow2.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use core::num::traits::WideMul;

/// A trait for a type that can be squared to produce a wider type.
pub trait WidePow2<T> {
/// The type of the result of the power of 2.
type Target;
/// Calculates the power of 2, producing a wider type.
fn wide_pow2(self: T) -> Self::Target;
}

mod wide_mul_based {
pub impl TWidePow2<T, impl TWideMul: super::WideMul<T, T>, +Copy<T>> of super::WidePow2<T> {
type Target = TWideMul::Target;
fn wide_pow2(self: T) -> Self::Target {
TWideMul::wide_mul(self, self)
}
}
}

impl WidePow2I8 = wide_mul_based::TWidePow2<i8>;
impl WidePow2I16 = wide_mul_based::TWidePow2<i16>;
impl WidePow2I32 = wide_mul_based::TWidePow2<i32>;
impl WidePow2I64 = wide_mul_based::TWidePow2<i64>;
impl WidePow2U8 = wide_mul_based::TWidePow2<u8>;
impl WidePow2U16 = wide_mul_based::TWidePow2<u16>;
impl WidePow2U32 = wide_mul_based::TWidePow2<u32>;
impl WidePow2U64 = wide_mul_based::TWidePow2<u64>;
impl WidePow2U128 = wide_mul_based::TWidePow2<u128>;
impl WidePow2U256 of WidePow2<u256> {
type Target = core::integer::u512;
fn wide_pow2(self: u256) -> Self::Target {
inner::u256_wide_pow2(self)
}
}

mod inner {
use core::integer::{u512, u128_add_with_bounded_int_carry, upcast};
use core::internal::bounded_int;
use core::num::traits::{WidePow2, WideMul, WrappingAdd};

pub fn u256_wide_pow2(value: u256) -> u512 {
let u256 { high: limb1, low: limb0 } = value.low.wide_pow2();
let u256 { high: limb2, low: limb1_part } = value.low.wide_mul(value.high);
let (limb1, limb1_overflow0) = u128_add_with_bounded_int_carry(limb1, limb1_part);
let (limb1, limb1_overflow1) = u128_add_with_bounded_int_carry(limb1, limb1_part);
let (limb2, limb2_overflow0) = u128_add_with_bounded_int_carry(limb2, limb2);
let u256 { high: limb3, low: limb2_part } = value.high.wide_pow2();
let (limb2, limb2_overflow1) = u128_add_with_bounded_int_carry(limb2, limb2_part);
// Packing together the overflow bits, making a cheaper addition into limb2.
let limb1_overflow = bounded_int::add(limb1_overflow0, limb1_overflow1);
let (limb2, limb2_overflow2) = u128_add_with_bounded_int_carry(
limb2, upcast(limb1_overflow)
);
// Packing together the overflow bits, making a cheaper addition into limb3.
let limb2_overflow = bounded_int::add(limb2_overflow0, limb2_overflow1);
let limb2_overflow = bounded_int::add(limb2_overflow, limb2_overflow2);
// No overflow since no limb4.
let limb3 = limb3.wrapping_add(upcast(limb2_overflow));
u512 { limb0, limb1, limb2, limb3 }
}
}
27 changes: 25 additions & 2 deletions corelib/src/test/integer_test.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[feature("deprecated-bounded-int-trait")]
use core::{integer, integer::{u512_safe_div_rem_by_u256, u512}};
use core::test::test_utils::{assert_eq, assert_ne, assert_le, assert_lt, assert_gt, assert_ge};
use core::num::traits::{Bounded, Sqrt, WideMul, WrappingSub};
use core::num::traits::{Bounded, Sqrt, WideMul, WidePow2, WrappingSub};

#[test]
fn test_u8_operators() {
Expand Down Expand Up @@ -706,6 +706,29 @@ fn test_u256_wide_mul() {
);
}

#[test]
fn test_u256_wide_pow2() {
assert!(0_u256.wide_pow2() == u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 });
assert!(
0x1001001001001001001001001001001001001001001001001001_u256
.wide_pow2() == u512 {
limb0: 0x0b00a009008007006005004003002001,
limb1: 0xe00f01001101201101000f00e00d00c0,
limb2: 0x00400500600700800900a00b00c00d00,
limb3: 0x1002003
}
);
assert!(
0x1000100010001000100010001000100010001000100010001000100010001_u256
.wide_pow2() == u512 {
limb0: 0x00080007000600050004000300020001,
limb1: 0x0010000f000e000d000c000b000a0009,
limb2: 0x00080009000a000b000c000d000e000f,
limb3: 0x1000200030004000500060007
}
);
}

#[test]
fn test_u512_safe_div_rem_by_u256() {
let zero = u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 };
Expand Down Expand Up @@ -846,7 +869,7 @@ fn test_u256_sqrt() {
assert!(1_u256.sqrt() == 1);
assert!(0_u256.sqrt() == 0);
assert!(Bounded::<u256>::MAX.sqrt() == Bounded::<u128>::MAX);
assert!(Bounded::<u128>::MAX.wide_mul(Bounded::<u128>::MAX).sqrt() == Bounded::<u128>::MAX);
assert!(Bounded::<u128>::MAX.wide_pow2().sqrt() == Bounded::<u128>::MAX);
}

#[test]
Expand Down

0 comments on commit 4d24ee3

Please sign in to comment.