From 1d7bea6423212061ac1530a4db5851f7533ab099 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Mon, 8 Jul 2024 11:55:41 +0200 Subject: [PATCH] addition of f32x32 type4 --- .../src/{f16x16 => }/core_trait.cairo | 50 ++++ packages/orion-numbers/src/f16x16.cairo | 1 - packages/orion-numbers/src/f16x16/core.cairo | 8 +- packages/orion-numbers/src/f16x16/erf.cairo | 3 +- .../orion-numbers/src/f16x16/helpers.cairo | 9 +- packages/orion-numbers/src/f16x16/lut.cairo | 2 +- packages/orion-numbers/src/f16x16/math.cairo | 224 +++++++++--------- packages/orion-numbers/src/f16x16/trig.cairo | 10 +- packages/orion-numbers/src/f32x32.cairo | 4 + packages/orion-numbers/src/f32x32/core.cairo | 201 ++++++++++++++++ .../orion-numbers/src/f32x32/helpers.cairo | 40 ++++ packages/orion-numbers/src/f32x32/math.cairo | 82 +++++++ packages/orion-numbers/src/lib.cairo | 57 ++++- 13 files changed, 557 insertions(+), 134 deletions(-) rename packages/orion-numbers/src/{f16x16 => }/core_trait.cairo (64%) create mode 100644 packages/orion-numbers/src/f32x32.cairo create mode 100644 packages/orion-numbers/src/f32x32/core.cairo create mode 100644 packages/orion-numbers/src/f32x32/helpers.cairo create mode 100644 packages/orion-numbers/src/f32x32/math.cairo diff --git a/packages/orion-numbers/src/f16x16/core_trait.cairo b/packages/orion-numbers/src/core_trait.cairo similarity index 64% rename from packages/orion-numbers/src/f16x16/core_trait.cairo rename to packages/orion-numbers/src/core_trait.cairo index 06a47fee3..fde940c8f 100644 --- a/packages/orion-numbers/src/f16x16/core_trait.cairo +++ b/packages/orion-numbers/src/core_trait.cairo @@ -1,3 +1,5 @@ +// Basic Arithmetic Trait on integer 32, 64 and 128, should be included in Cairo core soon. + pub impl I32Div of Div { fn div(lhs: i32, rhs: i32) -> i32 { assert(rhs != 0, 'divisor cannot be 0'); @@ -73,6 +75,54 @@ pub impl I64Rem of Rem { } } +pub impl I128Div of Div { + fn div(lhs: i128, rhs: i128) -> i128 { + assert(rhs != 0, 'divisor cannot be 0'); + + let mut lhs_positive = lhs; + let mut rhs_positive = rhs; + + if lhs < 0 { + lhs_positive = lhs * -1; + } + if rhs < 0 { + rhs_positive = rhs * -1; + } + + let lhs_u128: u128 = lhs_positive.try_into().unwrap(); + let rhs_u128: u128 = rhs_positive.try_into().unwrap(); + + let mut result = lhs_u128 / rhs_u128; + let felt_result: felt252 = result.into(); + let signed_int_result: i128 = felt_result.try_into().unwrap(); + + // avoids mul overflow for f16x16 + if sign_i128(lhs) * rhs < 0 { + signed_int_result * -1 + } else { + signed_int_result + } + } +} + +pub impl I128Rem of Rem { + fn rem(lhs: i128, rhs: i128) -> i128 { + let div = Div::div(lhs, rhs); + lhs - rhs * div + } +} + + +pub fn sign_i128(a: i128) -> i128 { + if a == 0 { + 0 + } else if a > 0 { + 1 + } else { + -1 + } +} + pub fn sign_i32(a: i32) -> i32 { if a == 0 { 0 diff --git a/packages/orion-numbers/src/f16x16.cairo b/packages/orion-numbers/src/f16x16.cairo index 84b9046f0..474264063 100644 --- a/packages/orion-numbers/src/f16x16.cairo +++ b/packages/orion-numbers/src/f16x16.cairo @@ -4,4 +4,3 @@ pub mod trig; pub mod erf; pub mod helpers; pub mod lut; -pub mod core_trait; \ No newline at end of file diff --git a/packages/orion-numbers/src/f16x16/core.cairo b/packages/orion-numbers/src/f16x16/core.cairo index 884322559..d3ee15e58 100644 --- a/packages/orion-numbers/src/f16x16/core.cairo +++ b/packages/orion-numbers/src/f16x16/core.cairo @@ -1,7 +1,5 @@ use orion_numbers::f16x16::{math, trig, erf}; -use core::traits::TryInto; -use core::option::OptionTrait; -use core::traits::Mul; +use orion_numbers::FixedTrait; pub type f16x16 = i32; @@ -12,8 +10,8 @@ pub const HALF: f16x16 = 32768; // 2 ** 15 pub const MAX: f16x16 = 2147483647; // 2 ** 31 -1 pub const MIN: f16x16 = -2147483648; // 2 ** 31 -#[generate_trait] -pub impl f16x16Impl of FixedTrait { + +pub impl F16x16Impl of FixedTrait { fn ZERO() -> f16x16 { 0 } diff --git a/packages/orion-numbers/src/f16x16/erf.cairo b/packages/orion-numbers/src/f16x16/erf.cairo index 611e34e8e..07c43d32e 100644 --- a/packages/orion-numbers/src/f16x16/erf.cairo +++ b/packages/orion-numbers/src/f16x16/erf.cairo @@ -1,4 +1,5 @@ -use orion_numbers::f16x16::{core::{FixedTrait, f16x16, ONE}, lut}; +use orion_numbers::f16x16::{core::{f16x16, ONE}, lut}; +use orion_numbers::FixedTrait; const ERF_COMPUTATIONAL_ACCURACY: i32 = 100; const ROUND_CHECK_NUMBER: i32 = 10; diff --git a/packages/orion-numbers/src/f16x16/helpers.cairo b/packages/orion-numbers/src/f16x16/helpers.cairo index f2e26fd5e..ada4e617b 100644 --- a/packages/orion-numbers/src/f16x16/helpers.cairo +++ b/packages/orion-numbers/src/f16x16/helpers.cairo @@ -1,6 +1,5 @@ -use orion_numbers::f16x16::core::{FixedTrait, f16x16, ONE, HALF}; - -use orion_numbers::f16x16::core_trait::I32Div; +use orion_numbers::f16x16::core::{F16x16Impl, f16x16, ONE, HALF}; +use orion_numbers::core_trait::I32Div; const DEFAULT_PRECISION: i32 = 7; // 1e-4 @@ -14,7 +13,7 @@ pub fn assert_precise( Option::None => DEFAULT_PRECISION, }; - let diff = (result - FixedTrait::from_felt(expected)); + let diff = (result - F16x16Impl::from_felt(expected)); if (diff > precision) { //println!("{}", result); @@ -30,7 +29,7 @@ pub fn assert_relative( Option::None => DEFAULT_PRECISION, }; - let diff = result - FixedTrait::from_felt(expected); + let diff = result - F16x16Impl::from_felt(expected); let rel_diff = diff / result; if (rel_diff > precision) { diff --git a/packages/orion-numbers/src/f16x16/lut.cairo b/packages/orion-numbers/src/f16x16/lut.cairo index a834d4b5a..696ca91a7 100644 --- a/packages/orion-numbers/src/f16x16/lut.cairo +++ b/packages/orion-numbers/src/f16x16/lut.cairo @@ -1,6 +1,6 @@ use orion_numbers::f16x16::core::ONE; -use orion_numbers::f16x16::core_trait::I32Div; +use orion_numbers::core_trait::I32Div; // Calculates the most significant bit pub fn msb(whole: i32) -> (i32, i32) { diff --git a/packages/orion-numbers/src/f16x16/math.cairo b/packages/orion-numbers/src/f16x16/math.cairo index bc885f298..2d2273b2c 100644 --- a/packages/orion-numbers/src/f16x16/math.cairo +++ b/packages/orion-numbers/src/f16x16/math.cairo @@ -1,11 +1,7 @@ -use core::traits::TryInto; -use core::option::OptionTrait; -use core::traits::{DivRem, Into}; use core::integer; +use orion_numbers::f16x16::{core::{F16x16Impl, f16x16, ONE, HALF}, lut}; -use orion_numbers::f16x16::{core::{FixedTrait, f16x16, ONE, HALF}, lut}; - -use orion_numbers::f16x16::core_trait::{I32Rem, I32Div, I64Div}; //I32TryIntoNonZero, I32DivRem +use orion_numbers::core_trait::{I32Rem, I32Div, I64Div}; //I32TryIntoNonZero, I32DivRem pub fn abs(a: f16x16) -> f16x16 { @@ -30,9 +26,9 @@ pub fn ceil(a: f16x16) -> f16x16 { let rem = Rem::rem(a, ONE); if rem == 0 { - FixedTrait::new_unscaled(div) + F16x16Impl::new_unscaled(div) } else { - FixedTrait::new_unscaled(div) + ONE + F16x16Impl::new_unscaled(div) + ONE } } @@ -41,48 +37,48 @@ pub fn div(a: f16x16, b: f16x16) -> f16x16 { let res_i64 = a_i64 / b.into(); // Re-apply sign - FixedTrait::new(res_i64.try_into().unwrap()) + F16x16Impl::new(res_i64.try_into().unwrap()) } // Calculates the natural exponent of x: e^x pub fn exp(a: f16x16) -> f16x16 { - exp2(FixedTrait::mul(FixedTrait::new(94548), a)) // log2(e) * 2^23 ≈ 12102203 + exp2(F16x16Impl::mul(F16x16Impl::new(94548), a)) // log2(e) * 2^23 ≈ 12102203 } // Calculates the binary exponent of x: 2^x pub fn exp2(a: f16x16) -> f16x16 { if (a == 0) { - return FixedTrait::ONE(); + return F16x16Impl::ONE(); } //let (int_part, frac_part) = DivRem::div_rem(a.abs(), ONE.try_into().unwrap()); let int_part = Div::div(a.abs(), ONE); let frac_part = Rem::rem(a.abs(), ONE); - let int_res = FixedTrait::new_unscaled(lut::exp2(int_part)); + let int_res = F16x16Impl::new_unscaled(lut::exp2(int_part)); let mut res_u = int_res; if frac_part != 0 { - let frac = FixedTrait::new(frac_part); - let r7 = FixedTrait::mul(FixedTrait::new(1), frac); - let r6 = FixedTrait::mul((r7 + FixedTrait::new(10)), frac); - let r5 = FixedTrait::mul((r6 + FixedTrait::new(87)), frac); - let r4 = FixedTrait::mul((r5 + FixedTrait::new(630)), frac); - let r3 = FixedTrait::mul((r4 + FixedTrait::new(3638)), frac); - let r2 = FixedTrait::mul((r3 + FixedTrait::new(15743)), frac); - let r1 = FixedTrait::mul((r2 + FixedTrait::new(45426)), frac); - res_u = FixedTrait::mul(res_u, (r1 + FixedTrait::ONE())); + let frac = F16x16Impl::new(frac_part); + let r7 = F16x16Impl::mul(F16x16Impl::new(1), frac); + let r6 = F16x16Impl::mul((r7 + F16x16Impl::new(10)), frac); + let r5 = F16x16Impl::mul((r6 + F16x16Impl::new(87)), frac); + let r4 = F16x16Impl::mul((r5 + F16x16Impl::new(630)), frac); + let r3 = F16x16Impl::mul((r4 + F16x16Impl::new(3638)), frac); + let r2 = F16x16Impl::mul((r3 + F16x16Impl::new(15743)), frac); + let r1 = F16x16Impl::mul((r2 + F16x16Impl::new(45426)), frac); + res_u = F16x16Impl::mul(res_u, (r1 + F16x16Impl::ONE())); } if a < 0 { - FixedTrait::div(FixedTrait::ONE(), res_u) + F16x16Impl::div(F16x16Impl::ONE(), res_u) } else { res_u } } fn exp2_int(exp: i32) -> f16x16 { - FixedTrait::new_unscaled(lut::exp2(exp)) + F16x16Impl::new_unscaled(lut::exp2(exp)) } pub fn floor(a: f16x16) -> f16x16 { @@ -93,16 +89,16 @@ pub fn floor(a: f16x16) -> f16x16 { if rem == 0 { a } else if a >= 0 { - FixedTrait::new_unscaled(div) + F16x16Impl::new_unscaled(div) } else { - FixedTrait::new_unscaled(div - 1) + F16x16Impl::new_unscaled(div - 1) } } // Calculates the natural logarithm of x: ln(x) // self must be greater than zero pub fn ln(a: f16x16) -> f16x16 { - FixedTrait::mul(FixedTrait::new(45426), log2(a)) // ln(2) = 0.693... + F16x16Impl::mul(F16x16Impl::new(45426), log2(a)) // ln(2) = 0.693... } // Calculates the binary logarithm of x: log2(x) @@ -111,10 +107,10 @@ pub fn log2(a: f16x16) -> f16x16 { assert(a >= 0, 'must be positive'); if (a == ONE) { - return FixedTrait::ZERO(); + return F16x16Impl::ZERO(); } else if (a < ONE) { // Compute true inverse binary log if 0 < x < 1 - let div = FixedTrait::div(FixedTrait::ONE(), a); + let div = F16x16Impl::div(F16x16Impl::ONE(), a); return -log2(div); } @@ -122,33 +118,33 @@ pub fn log2(a: f16x16) -> f16x16 { let (msb, div) = lut::msb(whole); if a == div * ONE { - FixedTrait::new_unscaled(msb) + F16x16Impl::new_unscaled(msb) } else { - let norm = FixedTrait::div(a, FixedTrait::new_unscaled(div)); - let r8 = FixedTrait::mul(FixedTrait::new(-596), norm); - let r7 = FixedTrait::mul((r8 + FixedTrait::new(8116)), norm); - let r6 = FixedTrait::mul((r7 + FixedTrait::new(-49044)), norm); - let r5 = FixedTrait::mul((r6 + FixedTrait::new(172935)), norm); - let r4 = FixedTrait::mul((r5 + FixedTrait::new(-394096)), norm); - let r3 = FixedTrait::mul((r4 + FixedTrait::new(608566)), norm); - let r2 = FixedTrait::mul((r3 + FixedTrait::new(-655828)), norm); - let r1 = FixedTrait::mul((r2 + FixedTrait::new(534433)), norm); - - r1 + FixedTrait::new(-224487) + FixedTrait::new_unscaled(msb) + let norm = F16x16Impl::div(a, F16x16Impl::new_unscaled(div)); + let r8 = F16x16Impl::mul(F16x16Impl::new(-596), norm); + let r7 = F16x16Impl::mul((r8 + F16x16Impl::new(8116)), norm); + let r6 = F16x16Impl::mul((r7 + F16x16Impl::new(-49044)), norm); + let r5 = F16x16Impl::mul((r6 + F16x16Impl::new(172935)), norm); + let r4 = F16x16Impl::mul((r5 + F16x16Impl::new(-394096)), norm); + let r3 = F16x16Impl::mul((r4 + F16x16Impl::new(608566)), norm); + let r2 = F16x16Impl::mul((r3 + F16x16Impl::new(-655828)), norm); + let r1 = F16x16Impl::mul((r2 + F16x16Impl::new(534433)), norm); + + r1 + F16x16Impl::new(-224487) + F16x16Impl::new_unscaled(msb) } } // Calculates the base 10 log of x: log10(x) // self must be greater than zero pub fn log10(a: f16x16) -> f16x16 { - FixedTrait::mul(FixedTrait::new(19728), log2(a)) // log10(2) = 0.301... + F16x16Impl::mul(F16x16Impl::new(19728), log2(a)) // log10(2) = 0.301... } pub fn mul(a: f16x16, b: f16x16) -> f16x16 { let prod_i64 = integer::i32_wide_mul(a, b); // Re-apply sign - FixedTrait::new((prod_i64 / ONE.into()).try_into().unwrap()) + F16x16Impl::new((prod_i64 / ONE.into()).try_into().unwrap()) } // Calclates the value of x^y and checks for overflow before returning @@ -163,7 +159,7 @@ pub fn pow(a: f16x16, b: f16x16) -> f16x16 { } // x^y = exp(y*ln(x)) for x > 0 will error for x < 0 - exp(FixedTrait::mul(b, ln(a))) + exp(F16x16Impl::mul(b, ln(a))) } // Calclates the value of a^b and checks for overflow before returning @@ -172,7 +168,7 @@ fn pow_int(a: f16x16, b: i32) -> f16x16 { let mut n = b.abs(); if b < 0 { - x = FixedTrait::div(ONE, x); + x = F16x16Impl::div(ONE, x); } if n == 0 { @@ -188,14 +184,14 @@ fn pow_int(a: f16x16, b: i32) -> f16x16 { let rem = Rem::rem(n, two); if rem == 1 { - y = FixedTrait::mul(x, y); + y = F16x16Impl::mul(x, y); } - x = FixedTrait::mul(x, x); + x = F16x16Impl::mul(x, x); n = div; }; - FixedTrait::mul(x, y) + F16x16Impl::mul(x, y) } pub fn round(a: f16x16) -> f16x16 { @@ -204,15 +200,15 @@ pub fn round(a: f16x16) -> f16x16 { let rem = Rem::rem(a, ONE); if (HALF <= rem) { - FixedTrait::new_unscaled(div + 1) + F16x16Impl::new_unscaled(div + 1) } else { - FixedTrait::new_unscaled(div) + F16x16Impl::new_unscaled(div) } } pub fn sign(a: f16x16) -> f16x16 { if a == 0 { - FixedTrait::new(0) + F16x16Impl::new(0) } else if a > 0 { ONE } else { @@ -228,7 +224,7 @@ pub fn sqrt(a: f16x16) -> f16x16 { let root = integer::u64_sqrt(a.try_into().unwrap() * ONE.try_into().unwrap()); - FixedTrait::new(root.try_into().unwrap()) + F16x16Impl::new(root.try_into().unwrap()) } @@ -242,35 +238,35 @@ mod tests { use orion_numbers::f16x16::helpers::{assert_precise, assert_relative}; use super::{ - FixedTrait, ONE, HALF, f16x16, integer, lut, ceil, add, sqrt, floor, exp, exp2, exp2_int, + F16x16Impl, ONE, HALF, f16x16, integer, lut, ceil, add, sqrt, floor, exp, exp2, exp2_int, ln, log2, log10, pow, round, sign }; - use orion_numbers::f16x16::core_trait::{I32Rem, I32Div}; + use orion_numbers::core_trait::{I32Rem, I32Div}; #[test] fn test_into() { - let a = FixedTrait::new_unscaled(5); + let a = F16x16Impl::new_unscaled(5); assert(a == 5 * ONE, 'invalid result'); } #[test] fn test_ceil() { - let a = FixedTrait::new(190054); // 2.9 + let a = F16x16Impl::new(190054); // 2.9 assert(ceil(a) == 3 * ONE, 'invalid pos decimal'); } #[test] #[available_gas(10000000)] fn test_exp() { - let a = FixedTrait::new_unscaled(2); + let a = F16x16Impl::new_unscaled(2); assert_relative(exp(a), 484249, 'invalid exp of 2', Option::None(())); // 7.389056098793725 } #[test] #[available_gas(400000)] fn test_exp2() { - let a = FixedTrait::new_unscaled(5); + let a = F16x16Impl::new_unscaled(5); assert(exp2(a) == 2097152, 'invalid exp2 of 2'); } @@ -282,34 +278,34 @@ mod tests { #[test] fn test_floor() { - let a = FixedTrait::new(190054); // 2.9 + let a = F16x16Impl::new(190054); // 2.9 assert(floor(a) == 2 * ONE, 'invalid pos decimal'); } #[test] #[available_gas(1000000)] fn test_ln() { - let mut a = FixedTrait::new_unscaled(1); + let mut a = F16x16Impl::new_unscaled(1); assert(ln(a) == 0, 'invalid ln of 1'); - a = FixedTrait::new(178145); + a = F16x16Impl::new(178145); assert_relative(ln(a), ONE.into(), 'invalid ln of 2.7...', Option::None(())); } #[test] #[available_gas(1000000)] fn test_log2() { - let mut a = FixedTrait::new_unscaled(32); - assert(log2(a) == FixedTrait::new_unscaled(5), 'invalid log2 32'); + let mut a = F16x16Impl::new_unscaled(32); + assert(log2(a) == F16x16Impl::new_unscaled(5), 'invalid log2 32'); - a = FixedTrait::new_unscaled(10); + a = F16x16Impl::new_unscaled(10); assert_relative(log2(a), 217706, 'invalid log2 10', Option::None(())); // 3.321928094887362 } #[test] #[available_gas(1000000)] fn test_log10() { - let a = FixedTrait::new_unscaled(100); + let a = F16x16Impl::new_unscaled(100); assert_relative(log10(a), 2 * ONE.into(), 'invalid log10', Option::None(())); } @@ -317,16 +313,16 @@ mod tests { #[test] #[available_gas(600000)] fn test_pow() { - let a = FixedTrait::new_unscaled(3); - let b = FixedTrait::new_unscaled(4); + let a = F16x16Impl::new_unscaled(3); + let b = F16x16Impl::new_unscaled(4); assert(pow(a, b) == 81 * ONE, 'invalid pos base power'); } #[test] #[available_gas(900000)] fn test_pow_frac() { - let a = FixedTrait::new_unscaled(3); - let b = FixedTrait::new(32768); // 0.5 + let a = F16x16Impl::new_unscaled(3); + let b = F16x16Impl::new(32768); // 0.5 assert_relative( pow(a, b), 113512, 'invalid pos base power', Option::None(()) ); // 1.7320508075688772 @@ -334,50 +330,50 @@ mod tests { #[test] fn test_round() { - let a = FixedTrait::new(190054); // 2.9 + let a = F16x16Impl::new(190054); // 2.9 assert(round(a) == 3 * ONE, 'invalid pos decimal'); } #[test] fn test_sqrt() { - let mut a = FixedTrait::new_unscaled(0); + let mut a = F16x16Impl::new_unscaled(0); assert(sqrt(a) == 0, 'invalid zero root'); - a = FixedTrait::new_unscaled(25); + a = F16x16Impl::new_unscaled(25); assert(sqrt(a) == 5 * ONE, 'invalid pos root'); } #[test] #[should_panic] fn test_sqrt_fail() { - let a = FixedTrait::new_unscaled(-25); + let a = F16x16Impl::new_unscaled(-25); sqrt(a); } #[test] #[available_gas(2000000)] fn test_sign() { - let a = FixedTrait::new(0); + let a = F16x16Impl::new(0); assert(a.sign() == 0, 'invalid sign (0)'); - let a = FixedTrait::new(-HALF); + let a = F16x16Impl::new(-HALF); assert(a.sign() == -ONE, 'invalid sign (-HALF)'); - let a = FixedTrait::new(HALF); + let a = F16x16Impl::new(HALF); assert(a.sign() == ONE, 'invalid sign (HALF)'); - let a = FixedTrait::new(-ONE); + let a = F16x16Impl::new(-ONE); assert(a.sign() == -ONE, 'invalid sign (-ONE)'); - let a = FixedTrait::new(ONE); + let a = F16x16Impl::new(ONE); assert(a.sign() == ONE, 'invalid sign (ONE)'); } #[test] #[available_gas(100000)] fn test_msb() { - let a = FixedTrait::new_unscaled(100); + let a = F16x16Impl::new_unscaled(100); let (msb, div) = lut::msb(a / ONE); assert(msb == 6, 'invalid msb'); assert(div == 64, 'invalid msb ceil'); @@ -385,49 +381,49 @@ mod tests { #[test] fn test_eq() { - let a = FixedTrait::new_unscaled(42); - let b = FixedTrait::new_unscaled(42); + let a = F16x16Impl::new_unscaled(42); + let b = F16x16Impl::new_unscaled(42); let c = a == b; assert(c, 'invalid result'); } #[test] fn test_ne() { - let a = FixedTrait::new_unscaled(42); - let b = FixedTrait::new_unscaled(42); + let a = F16x16Impl::new_unscaled(42); + let b = F16x16Impl::new_unscaled(42); let c = a != b; assert(!c, 'invalid result'); } #[test] fn test_add() { - let a = FixedTrait::new_unscaled(1); - let b = FixedTrait::new_unscaled(2); - assert(add(a, b) == FixedTrait::new_unscaled(3), 'invalid result'); + let a = F16x16Impl::new_unscaled(1); + let b = F16x16Impl::new_unscaled(2); + assert(add(a, b) == F16x16Impl::new_unscaled(3), 'invalid result'); } #[test] fn test_add_eq() { - let mut a = FixedTrait::new_unscaled(1); - let b = FixedTrait::new_unscaled(2); + let mut a = F16x16Impl::new_unscaled(1); + let b = F16x16Impl::new_unscaled(2); a += b; - assert(a == FixedTrait::new_unscaled(3), 'invalid result'); + assert(a == F16x16Impl::new_unscaled(3), 'invalid result'); } #[test] fn test_sub() { - let a = FixedTrait::new_unscaled(5); - let b = FixedTrait::new_unscaled(2); + let a = F16x16Impl::new_unscaled(5); + let b = F16x16Impl::new_unscaled(2); let c = a - b; - assert(c == FixedTrait::new_unscaled(3), 'false result invalid'); + assert(c == F16x16Impl::new_unscaled(3), 'false result invalid'); } #[test] fn test_sub_eq() { - let mut a = FixedTrait::new_unscaled(5); - let b = FixedTrait::new_unscaled(2); + let mut a = F16x16Impl::new_unscaled(5); + let b = F16x16Impl::new_unscaled(2); a -= b; - assert(a == FixedTrait::new_unscaled(3), 'invalid result'); + assert(a == F16x16Impl::new_unscaled(3), 'invalid result'); } #[test] @@ -435,32 +431,32 @@ mod tests { fn test_mul_pos() { let a = 190054; let b = 190054; - let c = FixedTrait::mul(a, b); + let c = F16x16Impl::mul(a, b); assert(c == 551155, 'invalid result'); } #[test] fn test_mul_neg() { - let a = FixedTrait::new_unscaled(5); - let b = FixedTrait::new_unscaled(-2); - let c = FixedTrait::mul(a, b); - assert(c == FixedTrait::new_unscaled(-10), 'invalid result'); + let a = F16x16Impl::new_unscaled(5); + let b = F16x16Impl::new_unscaled(-2); + let c = F16x16Impl::mul(a, b); + assert(c == F16x16Impl::new_unscaled(-10), 'invalid result'); } #[test] fn test_div() { - let a = FixedTrait::new_unscaled(10); - let b = FixedTrait::new(190054); // 2.9 - let c = FixedTrait::div(a, b); + let a = F16x16Impl::new_unscaled(10); + let b = F16x16Impl::new(190054); // 2.9 + let c = F16x16Impl::div(a, b); assert(c == 225986, 'invalid pos decimal'); // 3.4482758620689653 } #[test] fn test_le() { - let a = FixedTrait::new_unscaled(1); - let b = FixedTrait::new_unscaled(0); - let c = FixedTrait::new_unscaled(-1); + let a = F16x16Impl::new_unscaled(1); + let b = F16x16Impl::new_unscaled(0); + let c = F16x16Impl::new_unscaled(-1); assert(a <= a, 'a <= a'); assert(!(a <= b), 'a <= b'); @@ -477,9 +473,9 @@ mod tests { #[test] fn test_lt() { - let a = FixedTrait::new_unscaled(1); - let b = FixedTrait::new_unscaled(0); - let c = FixedTrait::new_unscaled(-1); + let a = F16x16Impl::new_unscaled(1); + let b = F16x16Impl::new_unscaled(0); + let c = F16x16Impl::new_unscaled(-1); assert(!(a < a), 'a < a'); assert(!(a < b), 'a < b'); @@ -496,9 +492,9 @@ mod tests { #[test] fn test_ge() { - let a = FixedTrait::new_unscaled(1); - let b = FixedTrait::new_unscaled(0); - let c = FixedTrait::new_unscaled(-1); + let a = F16x16Impl::new_unscaled(1); + let b = F16x16Impl::new_unscaled(0); + let c = F16x16Impl::new_unscaled(-1); assert(a >= a, 'a >= a'); assert(a >= b, 'a >= b'); @@ -515,9 +511,9 @@ mod tests { #[test] fn test_gt() { - let a = FixedTrait::new_unscaled(1); - let b = FixedTrait::new_unscaled(0); - let c = FixedTrait::new_unscaled(-1); + let a = F16x16Impl::new_unscaled(1); + let b = F16x16Impl::new_unscaled(0); + let c = F16x16Impl::new_unscaled(-1); assert(!(a > a), 'a > a'); assert(a > b, 'a > b'); diff --git a/packages/orion-numbers/src/f16x16/trig.cairo b/packages/orion-numbers/src/f16x16/trig.cairo index 551218ff7..97f551e6d 100644 --- a/packages/orion-numbers/src/f16x16/trig.cairo +++ b/packages/orion-numbers/src/f16x16/trig.cairo @@ -1,10 +1,8 @@ -use core::option::OptionTrait; -use core::traits::TryInto; use core::integer; -use orion_numbers::f16x16::core::{FixedTrait, f16x16, ONE, HALF, TWO}; -use orion_numbers::f16x16::lut; +use orion_numbers::f16x16::{core::{f16x16, ONE, HALF, TWO}, lut}; +use orion_numbers::FixedTrait; -use orion_numbers::f16x16::core_trait::{I32Div, I32Rem}; +use orion_numbers::core_trait::{I32Div, I32Rem}; // CONSTANTS const TWO_PI: i32 = 411775; @@ -180,7 +178,7 @@ mod tests { sin_fast, tan_fast, acosh, asinh, atanh, cosh, sinh, tanh }; - use orion_numbers::f16x16::core_trait::I32Div; + use orion_numbers::core_trait::I32Div; #[test] #[available_gas(8000000)] diff --git a/packages/orion-numbers/src/f32x32.cairo b/packages/orion-numbers/src/f32x32.cairo new file mode 100644 index 000000000..59787e7f4 --- /dev/null +++ b/packages/orion-numbers/src/f32x32.cairo @@ -0,0 +1,4 @@ +pub mod core; +pub mod math; +pub mod helpers; + diff --git a/packages/orion-numbers/src/f32x32/core.cairo b/packages/orion-numbers/src/f32x32/core.cairo new file mode 100644 index 000000000..70b2e63d6 --- /dev/null +++ b/packages/orion-numbers/src/f32x32/core.cairo @@ -0,0 +1,201 @@ +use orion_numbers::f32x32::math; +use orion_numbers::FixedTrait; + +pub type f32x32 = i64; + +// CONSTANTS +pub const TWO: f32x32 = 8589934592; // 2 ** 33 +pub const ONE: f32x32 = 4294967296; // 2 ** 32 +pub const HALF: f32x32 = 2147483648; // 2 ** 31 +pub const MAX: f32x32 = 9223372036854775807; // 2 ** 63 -1 +pub const MIN: f32x32 = -9223372036854775808; // -2 ** 63 + + +pub impl F32x32Impl of FixedTrait { + fn ZERO() -> f32x32 { + 0 + } + + fn HALF() -> f32x32 { + HALF + } + + fn ONE() -> f32x32 { + ONE + } + + fn MAX() -> f32x32 { + MAX + } + + fn MIN() -> f32x32 { + MIN + } + + fn new_unscaled(x: i64) -> f32x32 { + x * ONE + } + + fn new(x: i64) -> f32x32 { + x + } + + fn from_felt(x: felt252) -> f32x32 { + x.try_into().unwrap() + } + + fn from_unscaled_felt(x: felt252) -> f32x32 { + return FixedTrait::from_felt(x * ONE.into()); + } + + fn abs(self: f32x32) -> f32x32 { + math::abs(self) + } + + fn div(self: f32x32, rhs: f32x32) -> f32x32 { + math::div(self, rhs) + } + + fn mul(self: f32x32, rhs: f32x32) -> f32x32 { + math::mul(self, rhs) + } + + fn sign(self: f32x32) -> f32x32 { + math::sign(self) + } + + fn acos(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn acosh(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn asin(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn asinh(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn atan(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn atanh(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn add(lhs: f32x32, rhs: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn ceil(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn cos(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn cosh(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn exp(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn exp2(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn floor(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn ln(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn log2(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn log10(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn pow(self: f32x32, b: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn round(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn sin(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn sinh(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn sqrt(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn tan(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn tanh(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + + fn sub(lhs: f32x32, rhs: f32x32) -> f32x32 { + panic!("not implem yet") + } + + fn NaN() -> f32x32 { + -0 + } + + fn is_nan(self: f32x32) -> bool { + self == -0 + } + + fn INF() -> f32x32 { + MAX + } + + fn POS_INF() -> f32x32 { + MAX + } + + fn NEG_INF() -> f32x32 { + MIN + } + + fn is_inf(self: f32x32) -> bool { + self == MAX + } + + fn is_pos_inf(self: f32x32) -> bool { + self == MAX + } + + fn is_neg_inf(self: f32x32) -> bool { + self == MIN + } + + fn erf(self: f32x32) -> f32x32 { + panic!("not implem yet") + } + + +} diff --git a/packages/orion-numbers/src/f32x32/helpers.cairo b/packages/orion-numbers/src/f32x32/helpers.cairo new file mode 100644 index 000000000..606af6efe --- /dev/null +++ b/packages/orion-numbers/src/f32x32/helpers.cairo @@ -0,0 +1,40 @@ +use orion_numbers::f32x32::core::{F32x32Impl, f32x32, ONE, HALF}; +use orion_numbers::core_trait::I64Div; + +const DEFAULT_PRECISION: i64 = 429497; // 1e-4 + +// To use `DEFAULT_PRECISION`, final arg is: `Option::None(())`. +// To use `custom_precision` of 430_i64: `Option::Some(430_i64)`. +pub fn assert_precise( + result: f32x32, expected: felt252, msg: felt252, custom_precision: Option +) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None => DEFAULT_PRECISION, + }; + + let diff = (result - F32x32Impl::from_felt(expected)); + + if (diff > precision) { + //println!("{}", result); + assert(diff <= precision, msg); + } +} + +pub fn assert_relative( + result: f32x32, expected: felt252, msg: felt252, custom_precision: Option +) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None => DEFAULT_PRECISION, + }; + + let diff = result - F32x32Impl::from_felt(expected); + let rel_diff = diff / result; + + if (rel_diff > precision) { + //println!("{}", result); + assert(rel_diff <= precision, msg); + } +} + diff --git a/packages/orion-numbers/src/f32x32/math.cairo b/packages/orion-numbers/src/f32x32/math.cairo new file mode 100644 index 000000000..208736f60 --- /dev/null +++ b/packages/orion-numbers/src/f32x32/math.cairo @@ -0,0 +1,82 @@ +use core::integer; +use orion_numbers::f32x32::core::{F32x32Impl, f32x32, ONE, HALF}; + +use orion_numbers::core_trait::{I64Div, I64Rem, I128Div}; + + +pub fn abs(a: f32x32) -> f32x32 { + if a >= 0 { + a + } else { + a * -1_i64 + } +} + +pub fn div(a: f32x32, b: f32x32) -> f32x32 { + let a_i128 = integer::i64_wide_mul(a, ONE); + let res_i128 = a_i128 / b.into(); + + // Re-apply sign + F32x32Impl::new(res_i128.try_into().unwrap()) +} + +pub fn mul(a: f32x32, b: f32x32) -> f32x32 { + let prod_i128 = integer::i64_wide_mul(a, b); + + // Re-apply sign + F32x32Impl::new((prod_i128 / ONE.into()).try_into().unwrap()) +} + +pub fn round(a: f32x32) -> f32x32 { + //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); + let div = Div::div(a, ONE); + let rem = Rem::rem(a, ONE); + + if (HALF <= rem) { + F32x32Impl::new_unscaled(div + 1) + } else { + F32x32Impl::new_unscaled(div) + } +} + +pub fn sign(a: f32x32) -> f32x32 { + if a == 0 { + F32x32Impl::new(0) + } else if a > 0 { + ONE + } else { + -ONE + } +} + +// Tests +// +// +// -------------------------------------------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use orion_numbers::f32x32::helpers::{assert_precise, assert_relative}; + + use super::{F32x32Impl, ONE, HALF, f32x32, integer, round, sign}; + + + #[test] + #[available_gas(2000000)] + fn test_sign() { + let a = F32x32Impl::new(0); + assert(a.sign() == 0, 'invalid sign (0)'); + + let a = F32x32Impl::new(-HALF); + assert(a.sign() == -ONE, 'invalid sign (-HALF)'); + + let a = F32x32Impl::new(HALF); + assert(a.sign() == ONE, 'invalid sign (HALF)'); + + let a = F32x32Impl::new(-ONE); + assert(a.sign() == -ONE, 'invalid sign (-ONE)'); + + let a = F32x32Impl::new(ONE); + assert(a.sign() == ONE, 'invalid sign (ONE)'); + } +} diff --git a/packages/orion-numbers/src/lib.cairo b/packages/orion-numbers/src/lib.cairo index 317833eaf..42678f7b7 100644 --- a/packages/orion-numbers/src/lib.cairo +++ b/packages/orion-numbers/src/lib.cairo @@ -1 +1,56 @@ -pub mod f16x16; \ No newline at end of file +pub mod f16x16; +pub mod f32x32; +pub mod core_trait; + +use orion_numbers::f16x16::core::F16x16Impl; +use orion_numbers::f32x32::core::F32x32Impl; + + +trait FixedTrait { + fn ZERO() -> T; + fn HALF() -> T; + fn ONE() -> T; + fn MAX() -> T; + fn MIN() -> T; + fn new_unscaled(x: T) -> T; + fn new(x: T) -> T; + fn from_felt(x: felt252) -> T; + fn from_unscaled_felt(x: felt252) -> T; + fn abs(self: T) -> T; + fn acos(self: T) -> T; + fn acosh(self: T) -> T; + fn asin(self: T) -> T; + fn asinh(self: T) -> T; + fn atan(self: T) -> T; + fn atanh(self: T) -> T; + fn add(lhs: T, rhs: T) -> T; + fn ceil(self: T) -> T; + fn cos(self: T) -> T; + fn cosh(self: T) -> T; + fn div(self: T, rhs: T) -> T; + fn exp(self: T) -> T; + fn exp2(self: T) -> T; + fn floor(self: T) -> T; + fn ln(self: T) -> T; + fn log2(self: T) -> T; + fn log10(self: T) -> T; + fn mul(self: T, rhs: T) -> T; + fn pow(self: T, b: T) -> T; + fn round(self: T) -> T; + fn sin(self: T) -> T; + fn sinh(self: T) -> T; + fn sqrt(self: T) -> T; + fn tan(self: T) -> T; + fn tanh(self: T) -> T; + fn sign(self: T) -> T; + fn sub(lhs: T, rhs: T) -> T; + fn NaN() -> T; + fn is_nan(self: T) -> bool; + fn INF() -> T; + fn POS_INF() -> T; + fn NEG_INF() -> T; + fn is_inf(self: T) -> bool; + fn is_pos_inf(self: T) -> bool; + fn is_neg_inf(self: T) -> bool; + fn erf(self: T) -> T; +}