Skip to content

Commit

Permalink
Added const-folding for u*_overflowing_{add,sub}.
Browse files Browse the repository at this point in the history
commit-id:8ebf86e7
  • Loading branch information
orizi committed Jul 28, 2024
1 parent d923eb7 commit 3e3df1e
Show file tree
Hide file tree
Showing 5 changed files with 5,908 additions and 6,290 deletions.
103 changes: 96 additions & 7 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
mod test;

use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId};
use cairo_lang_semantic::corelib::{self};
use cairo_lang_semantic::items::constant::ConstValue;
use cairo_lang_semantic::GenericArgumentId;
use cairo_lang_semantic::{corelib, GenericArgumentId};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use cairo_lang_utils::{try_extract_matches, Intern};
use id_arena::Arena;
use itertools::{chain, zip_eq};
use num_bigint::BigInt;
use num_traits::Zero;
use num_traits::{Num, One, Zero};
use smol_str::SmolStr;

use crate::db::LoweringGroup;
Expand Down Expand Up @@ -246,13 +245,44 @@ impl<'a> ConstFoldingContext<'a> {
FlatBlockEnd::Goto(arm.block_id, Default::default()),
)
});
} else if let Some(ty_name) = self.uadd_fns.get(&info.function) {
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let value = lhs + rhs;
let (arm_index, value) = match Self::value_in_range(ty_name, value) {
Ok(value) => (0, value),
Err(value) => (1, value),
};
let arm = &info.arms[arm_index];
let actual_output = arm.var_ids[0];
let ty = self.variables[actual_output].ty;
let value = ConstValue::Int(value, ty);
self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
return Some((
Some(Statement::Const(StatementConst { value, output: actual_output })),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
));
} else if let Some(ty_name) = self.usub_fns.get(&info.function) {
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let value = lhs - rhs;
let (arm_index, value) = match Self::value_in_range(ty_name, value) {
Ok(value) => (0, value),
Err(value) => (1, value),
};
let arm = &info.arms[arm_index];
let actual_output = arm.var_ids[0];
let ty = self.variables[actual_output].ty;
let value = ConstValue::Int(value, ty);
self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
return Some((
Some(Statement::Const(StatementConst { value, output: actual_output })),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
));
} else if let Some(extrn) = info.function.get_extern(self.db) {
if extrn == self.downcast {
let input_var = info.inputs[0].var_id;
let Some(VarInfo::Const(ConstValue::Int(value, _))) = self.var_info.get(&input_var)
else {
return None;
};
let value = self.as_int(input_var)?;
let success_output = info.arms[0].var_ids[0];
let ty = self.variables[success_output].ty;
return Some(
Expand All @@ -275,6 +305,50 @@ impl<'a> ConstFoldingContext<'a> {
None
}

/// Returns the Ok(value) if the value is in the range of the type `ty_name`.
/// Otherwise, returns Err(value) with the value fixed into the range.
fn value_in_range(ty_name: &str, value: BigInt) -> Result<BigInt, BigInt> {
match ty_name {
"u8" if value < u8::MIN.into() => Err(value + 0x100),
"u8" if value > u8::MAX.into() => Err(value - 0x100),
"u8" => Ok(value),
"u16" if value < u16::MIN.into() => Err(value + 0x10000),
"u16" if value > u16::MAX.into() => Err(value - 0x10000),
"u16" => Ok(value),
"u32" if value < u32::MIN.into() => Err(value + 0x100000000u64),
"u32" if value > u32::MAX.into() => Err(value - 0x100000000u64),
"u32" => Ok(value),
"u64" if value < u64::MIN.into() => Err(value + 0x10000000000000000u128),
"u64" if value > u64::MAX.into() => Err(value - 0x10000000000000000u128),
"u64" => Ok(value),
"u128" if value < u128::MIN.into() => Err(value + (BigInt::one() << 128)),
"u128" if value > u128::MAX.into() => Err(value - (BigInt::one() << 128)),
"u128" => Ok(value),
"i8" if value < i8::MIN.into() => Err(value + 0x100),
"i8" if value > i8::MAX.into() => Err(value - 0x100),
"i8" => Ok(value),
"i16" if value < i16::MIN.into() => Err(value + 0x10000),
"i16" if value > i16::MAX.into() => Err(value - 0x10000),
"i16" => Ok(value),
"i32" if value < i32::MIN.into() => Err(value + 0x100000000u64),
"i32" if value > i32::MAX.into() => Err(value - 0x100000000u64),
"i32" => Ok(value),
"i64" if value < i64::MIN.into() => Err(value + 0x10000000000000000u128),
"i64" if value > i64::MAX.into() => Err(value - 0x10000000000000000u128),
"i64" => Ok(value),
"i128" if value < i128::MIN.into() => Err(value + (BigInt::one() << 128)),
"i128" if value > i128::MAX.into() => Err(value - (BigInt::one() << 128)),
"i128" => Ok(value),
"felt252" => Ok(value
% BigInt::from_str_radix(
"800000000000011000000000000000000000000000000000000000000000000",
16,
)
.unwrap()),
_ => unreachable!(),
}
}

/// Returns the const value of a variable if it exists.
fn as_const(&self, var_id: VariableId) -> Option<&ConstValue> {
try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
Expand Down Expand Up @@ -359,6 +433,10 @@ struct LibfuncInfo<'a> {
storage_base_address_from_felt252: FunctionId,
/// The set of functions that check if a number is zero.
nz_fns: UnorderedHashSet<FunctionId>,
/// The set of functions to add unsigned ints.
uadd_fns: UnorderedHashMap<FunctionId, &'static str>,
/// The set of functions to subtract unsigned ints.
usub_fns: UnorderedHashMap<FunctionId, &'static str>,
/// The storage access module.
storage_access_module: ModuleHelper<'a>,
}
Expand All @@ -380,13 +458,24 @@ impl<'a> LibfuncInfo<'a> {
["u8", "u16", "u32", "u64", "u128", "u256", "i8", "i16", "i32", "i64", "i128"]
.map(|ty| integer_module.function_id(format!("{}_is_zero", ty), vec![]))
));
let utypes = ["u8", "u16", "u32", "u64", "u128"];
let uadd_fns =
UnorderedHashMap::<_, _>::from_iter(utypes.map(|ty| {
(integer_module.function_id(format!("{ty}_overflowing_add"), vec![]), ty)
}));
let usub_fns =
UnorderedHashMap::<_, _>::from_iter(utypes.map(|ty| {
(integer_module.function_id(format!("{ty}_overflowing_sub"), vec![]), ty)
}));
Self {
felt_sub,
into_box,
upcast,
downcast,
storage_base_address_from_felt252,
nz_fns,
uadd_fns,
usub_fns,
storage_access_module,
}
}
Expand Down
Loading

0 comments on commit 3e3df1e

Please sign in to comment.