Skip to content

Commit

Permalink
Refactored range-testing in const folding.
Browse files Browse the repository at this point in the history
commit-id:273e5bb9
  • Loading branch information
orizi committed Jul 28, 2024
1 parent 3e3df1e commit fbb424f
Showing 1 changed file with 73 additions and 80 deletions.
153 changes: 73 additions & 80 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ mod test;

use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId};
use cairo_lang_semantic::items::constant::ConstValue;
use cairo_lang_semantic::{corelib, GenericArgumentId};
use cairo_lang_semantic::{corelib, GenericArgumentId, TypeId};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use cairo_lang_utils::{try_extract_matches, Intern};
use id_arena::Arena;
use itertools::{chain, zip_eq};
use num_bigint::BigInt;
use num_traits::{Num, One, Zero};
use num_traits::Zero;
use smol_str::SmolStr;

use crate::db::LoweringGroup;
Expand Down Expand Up @@ -245,30 +245,15 @@ impl<'a> ConstFoldingContext<'a> {
FlatBlockEnd::Goto(arm.block_id, Default::default()),
)
});
} else if let Some(ty_name) = self.uadd_fns.get(&info.function) {
} else if self.uadd_fns.contains(&info.function) || self.usub_fns.contains(&info.function) {
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let value = lhs + rhs;
let (arm_index, value) = match Self::value_in_range(ty_name, value) {
Ok(value) => (0, value),
Err(value) => (1, value),
};
let arm = &info.arms[arm_index];
let actual_output = arm.var_ids[0];
let ty = self.variables[actual_output].ty;
let value = ConstValue::Int(value, ty);
self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
return Some((
Some(Statement::Const(StatementConst { value, output: actual_output })),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
));
} else if let Some(ty_name) = self.usub_fns.get(&info.function) {
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let value = lhs - rhs;
let (arm_index, value) = match Self::value_in_range(ty_name, value) {
Ok(value) => (0, value),
Err(value) => (1, value),
let value = if self.uadd_fns.contains(&info.function) { lhs + rhs } else { lhs - rhs };
let ty = self.variables[info.inputs[0].var_id].ty;
let range = self.type_value_ranges.get(&ty)?;
let (arm_index, value) = match range.normalized(value) {
NormalizedResult::InRange(value) => (0, value),
NormalizedResult::Over(value) | NormalizedResult::Under(value) => (1, value),
};
let arm = &info.arms[arm_index];
let actual_output = arm.var_ids[0];
Expand All @@ -285,9 +270,10 @@ impl<'a> ConstFoldingContext<'a> {
let value = self.as_int(input_var)?;
let success_output = info.arms[0].var_ids[0];
let ty = self.variables[success_output].ty;
let range = self.type_value_ranges.get(&ty)?;
return Some(
if corelib::validate_literal(self.db.upcast(), ty, value.clone()).is_ok() {
let value = ConstValue::Int(value.clone(), ty);
if let NormalizedResult::InRange(value) = range.normalized(value.clone()) {
let value = ConstValue::Int(value, ty);
self.var_info.insert(success_output, VarInfo::Const(value.clone()));
(
Some(Statement::Const(StatementConst {
Expand All @@ -305,50 +291,6 @@ impl<'a> ConstFoldingContext<'a> {
None
}

/// Returns the Ok(value) if the value is in the range of the type `ty_name`.
/// Otherwise, returns Err(value) with the value fixed into the range.
fn value_in_range(ty_name: &str, value: BigInt) -> Result<BigInt, BigInt> {
match ty_name {
"u8" if value < u8::MIN.into() => Err(value + 0x100),
"u8" if value > u8::MAX.into() => Err(value - 0x100),
"u8" => Ok(value),
"u16" if value < u16::MIN.into() => Err(value + 0x10000),
"u16" if value > u16::MAX.into() => Err(value - 0x10000),
"u16" => Ok(value),
"u32" if value < u32::MIN.into() => Err(value + 0x100000000u64),
"u32" if value > u32::MAX.into() => Err(value - 0x100000000u64),
"u32" => Ok(value),
"u64" if value < u64::MIN.into() => Err(value + 0x10000000000000000u128),
"u64" if value > u64::MAX.into() => Err(value - 0x10000000000000000u128),
"u64" => Ok(value),
"u128" if value < u128::MIN.into() => Err(value + (BigInt::one() << 128)),
"u128" if value > u128::MAX.into() => Err(value - (BigInt::one() << 128)),
"u128" => Ok(value),
"i8" if value < i8::MIN.into() => Err(value + 0x100),
"i8" if value > i8::MAX.into() => Err(value - 0x100),
"i8" => Ok(value),
"i16" if value < i16::MIN.into() => Err(value + 0x10000),
"i16" if value > i16::MAX.into() => Err(value - 0x10000),
"i16" => Ok(value),
"i32" if value < i32::MIN.into() => Err(value + 0x100000000u64),
"i32" if value > i32::MAX.into() => Err(value - 0x100000000u64),
"i32" => Ok(value),
"i64" if value < i64::MIN.into() => Err(value + 0x10000000000000000u128),
"i64" if value > i64::MAX.into() => Err(value - 0x10000000000000000u128),
"i64" => Ok(value),
"i128" if value < i128::MIN.into() => Err(value + (BigInt::one() << 128)),
"i128" if value > i128::MAX.into() => Err(value - (BigInt::one() << 128)),
"i128" => Ok(value),
"felt252" => Ok(value
% BigInt::from_str_radix(
"800000000000011000000000000000000000000000000000000000000000000",
16,
)
.unwrap()),
_ => unreachable!(),
}
}

/// Returns the const value of a variable if it exists.
fn as_const(&self, var_id: VariableId) -> Option<&ConstValue> {
try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
Expand Down Expand Up @@ -434,11 +376,13 @@ struct LibfuncInfo<'a> {
/// The set of functions that check if a number is zero.
nz_fns: UnorderedHashSet<FunctionId>,
/// The set of functions to add unsigned ints.
uadd_fns: UnorderedHashMap<FunctionId, &'static str>,
uadd_fns: UnorderedHashSet<FunctionId>,
/// The set of functions to subtract unsigned ints.
usub_fns: UnorderedHashMap<FunctionId, &'static str>,
usub_fns: UnorderedHashSet<FunctionId>,
/// The storage access module.
storage_access_module: ModuleHelper<'a>,
/// Type ranges.
type_value_ranges: UnorderedHashMap<TypeId, TypeRange>,
}
impl<'a> LibfuncInfo<'a> {
fn new(db: &'a dyn LoweringGroup) -> Self {
Expand All @@ -459,14 +403,30 @@ impl<'a> LibfuncInfo<'a> {
.map(|ty| integer_module.function_id(format!("{}_is_zero", ty), vec![]))
));
let utypes = ["u8", "u16", "u32", "u64", "u128"];
let uadd_fns =
UnorderedHashMap::<_, _>::from_iter(utypes.map(|ty| {
(integer_module.function_id(format!("{ty}_overflowing_add"), vec![]), ty)
}));
let usub_fns =
UnorderedHashMap::<_, _>::from_iter(utypes.map(|ty| {
(integer_module.function_id(format!("{ty}_overflowing_sub"), vec![]), ty)
}));
let uadd_fns = UnorderedHashSet::<_>::from_iter(
utypes.map(|ty| integer_module.function_id(format!("{ty}_overflowing_add"), vec![])),
);
let usub_fns = UnorderedHashSet::<_>::from_iter(
utypes.map(|ty| integer_module.function_id(format!("{ty}_overflowing_sub"), vec![])),
);
let type_value_ranges = UnorderedHashMap::from_iter(
[
("u8", TypeRange::closed(0, u8::MAX)),
("u16", TypeRange::closed(0, u16::MAX)),
("u32", TypeRange::closed(0, u32::MAX)),
("u64", TypeRange::closed(0, u64::MAX)),
("u128", TypeRange::closed(0, u128::MAX)),
("u256", TypeRange::closed(0, BigInt::from(1) << 256)),
("i8", TypeRange::closed(i8::MIN, i8::MAX)),
("i16", TypeRange::closed(i16::MIN, i16::MAX)),
("i32", TypeRange::closed(i32::MIN, i32::MAX)),
("i64", TypeRange::closed(i64::MIN, i64::MAX)),
("i128", TypeRange::closed(i128::MIN, i128::MAX)),
]
.map(|(ty, range)| {
(corelib::get_core_ty_by_name(db.upcast(), ty.into(), vec![]), range)
}),
);
Self {
felt_sub,
into_box,
Expand All @@ -477,6 +437,7 @@ impl<'a> LibfuncInfo<'a> {
uadd_fns,
usub_fns,
storage_access_module,
type_value_ranges,
}
}
}
Expand All @@ -487,3 +448,35 @@ impl<'a> std::ops::Deref for ConstFoldingContext<'a> {
&self.libfunc_info
}
}

/// The range of a type for normalizations.
struct TypeRange {
min: BigInt,
max: BigInt,
}
impl TypeRange {
fn closed(min: impl Into<BigInt>, max: impl Into<BigInt>) -> Self {
Self { min: min.into(), max: max.into() }
}
/// Normalizes the value to the range.
/// Assumes the value is within size of range of the range.
fn normalized(&self, value: BigInt) -> NormalizedResult {
if value < self.min {
NormalizedResult::Under(value - &self.min + &self.max + 1)
} else if value > self.max {
NormalizedResult::Over(value + &self.min - &self.max - 1)
} else {
NormalizedResult::InRange(value)
}
}
}

/// The result of normalizing a value to a range.
enum NormalizedResult {
/// The original value is in the range, carries the value, or an equivalent value.
InRange(BigInt),
/// The original value is larger than range max, carries the normalized value.
Over(BigInt),
/// The original value is smaller than range min, carries the normalized value.
Under(BigInt),
}

0 comments on commit fbb424f

Please sign in to comment.