diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs index f0135eb0..972f1c0c 100644 --- a/examples/nested_evaluation.rs +++ b/examples/nested_evaluation.rs @@ -1,73 +1,75 @@ use std::{process::Command, time::Instant}; -use ahash::HashMap; use symbolica::{ atom::{Atom, AtomView}, - evaluate::{ConstOrExpr, ExpressionEvaluator}, + domains::rational::Rational, + evaluate::{ExpressionEvaluator, FunctionMap}, state::State, }; fn main() { - let e1 = Atom::parse("x + cos(x) + f(g(x+1),h(x*2)) + p(1)").unwrap(); + let e1 = Atom::parse("x + pi + cos(x) + f(g(x+1),h(x*2)) + p(1,x)").unwrap(); let e2 = Atom::parse("x + h(x*2) + cos(x)").unwrap(); let f = Atom::parse("y^2 + z^2*y^2").unwrap(); let g = Atom::parse("i(y+7)+x*i(y+7)*(y-1)").unwrap(); let h = Atom::parse("y*(1+x*(1+x^2)) + y^2*(1+x*(1+x^2))^2 + 3*(1+x^2)").unwrap(); let i = Atom::parse("y - 1").unwrap(); - let k = Atom::parse("3*x^3 + 4*x^2 + 6*x +8").unwrap(); + let p1 = Atom::parse("3*z^3 + 4*z^2 + 6*z +8").unwrap(); - let mut const_map = HashMap::default(); + let mut fn_map = FunctionMap::new(); - let p1 = Atom::parse("p(1)").unwrap(); - let f_s = Atom::new_var(State::get_symbol("f")); - let g_s = Atom::new_var(State::get_symbol("g")); - let h_s = Atom::new_var(State::get_symbol("h")); - let i_s = Atom::new_var(State::get_symbol("i")); - - const_map.insert( - p1.into(), - ConstOrExpr::Expr(State::get_symbol("p1"), vec![], k.as_view()), + fn_map.add_constant( + Atom::new_var(State::get_symbol("pi")).into(), + Rational::from((22, 7)).into(), ); - - const_map.insert( - f_s.into(), - ConstOrExpr::Expr( + fn_map + .add_tagged_function( + State::get_symbol("p"), + vec![Atom::new_num(1).into()], + "p1".to_string(), + vec![State::get_symbol("z")], + p1.as_view(), + ) + .unwrap(); + fn_map + .add_function( State::get_symbol("f"), + "f".to_string(), vec![State::get_symbol("y"), State::get_symbol("z")], f.as_view(), - ), - ); - const_map.insert( - g_s.into(), - ConstOrExpr::Expr( + ) + .unwrap(); + fn_map + .add_function( State::get_symbol("g"), + "g".to_string(), vec![State::get_symbol("y")], g.as_view(), - ), - ); - const_map.insert( - h_s.into(), - ConstOrExpr::Expr( + ) + .unwrap(); + fn_map + .add_function( State::get_symbol("h"), + "h".to_string(), vec![State::get_symbol("y")], h.as_view(), - ), - ); - const_map.insert( - i_s.into(), - ConstOrExpr::Expr( + ) + .unwrap(); + fn_map + .add_function( State::get_symbol("i"), + "i".to_string(), vec![State::get_symbol("y")], i.as_view(), - ), - ); + ) + .unwrap(); let params = vec![Atom::parse("x").unwrap()]; let mut tree = AtomView::to_eval_tree_multiple( &[e1.as_view(), e2.as_view()], |r| r.clone(), - &const_map, + &fn_map, ¶ms, ); @@ -87,7 +89,7 @@ fn main() { std::fs::write("nested_evaluation.cpp", cpp).unwrap(); - Command::new("g++") + let r = Command::new("g++") .arg("-shared") .arg("-fPIC") .arg("-O3") @@ -97,12 +99,12 @@ fn main() { .arg("nested_evaluation.cpp") .output() .unwrap(); + println!("Compilation {}", r.status); unsafe { let lib = libloading::Library::new("./libneval.so").unwrap(); - let func: libloading::Symbol< - unsafe extern "C" fn(params: *const f64, out: *mut f64) -> f64, - > = lib.get(b"eval_double").unwrap(); + let func: libloading::Symbol = + lib.get(b"eval_double").unwrap(); let params = vec![5.]; let mut out = vec![0., 0.]; diff --git a/src/atom.rs b/src/atom.rs index 279c2391..7f10d7ab 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -193,6 +193,7 @@ impl<'a> From> for AtomView<'a> { } /// A copy-on-write structure for `Atom` and `AtomView`. +#[derive(Clone)] pub enum AtomOrView<'a> { Atom(Atom), View(AtomView<'a>), diff --git a/src/evaluate.rs b/src/evaluate.rs index eeb57384..d1f908b9 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -32,9 +32,99 @@ impl EvaluationFn { } } -pub enum ConstOrExpr<'a, T> { +#[derive(PartialEq, Eq, Hash)] +enum AtomOrTaggedFunction<'a> { + Atom(AtomOrView<'a>), + TaggedFunction(Symbol, Vec>), +} + +pub struct FunctionMap<'a, T> { + map: HashMap, ConstOrExpr<'a, T>>, + tag: HashMap, +} + +impl<'a, T> FunctionMap<'a, T> { + pub fn new() -> Self { + FunctionMap { + map: HashMap::default(), + tag: HashMap::default(), + } + } + + pub fn add_constant(&mut self, key: AtomOrView<'a>, value: T) { + self.map + .insert(AtomOrTaggedFunction::Atom(key), ConstOrExpr::Const(value)); + } + + pub fn add_function( + &mut self, + name: Symbol, + rename: String, + args: Vec, + body: AtomView<'a>, + ) -> Result<(), &str> { + if let Some(t) = self.tag.insert(name, 0) { + if t != 0 { + return Err("Cannot add the same function with a different number of parameters"); + } + } + + self.map.insert( + AtomOrTaggedFunction::Atom(Atom::new_var(name).into()), + ConstOrExpr::Expr(rename, 0, args, body), + ); + + Ok(()) + } + + pub fn add_tagged_function( + &mut self, + name: Symbol, + tags: Vec, + rename: String, + args: Vec, + body: AtomView<'a>, + ) -> Result<(), &str> { + if let Some(t) = self.tag.insert(name, tags.len()) { + if t != tags.len() { + return Err("Cannot add the same function with a different number of parameters"); + } + } + + self.map.insert( + AtomOrTaggedFunction::Atom(Atom::new_var(name).into()), + ConstOrExpr::Expr(rename, tags.len(), args, body), + ); + + Ok(()) + } + + fn get_tag_len(&self, symbol: &Symbol) -> usize { + self.tag.get(symbol).cloned().unwrap_or(0) + } + + fn get(&self, a: AtomView<'a>) -> Option<&ConstOrExpr<'a, T>> { + if let Some(c) = self.map.get(&AtomOrTaggedFunction::Atom(a.into())) { + return Some(c); + } + + if let AtomView::Fun(aa) = a { + let s = aa.get_symbol(); + let tag_len = self.get_tag_len(&s); + + if tag_len != 0 && aa.get_nargs() >= tag_len { + let tag = aa.iter().take(tag_len).map(|x| x.into()).collect(); + return self.map.get(&AtomOrTaggedFunction::TaggedFunction(s, tag)); + } + } + + None + } +} + +enum ConstOrExpr<'a, T> { Const(T), - Expr(Symbol, Vec, AtomView<'a>), + Expr(String, usize, Vec, AtomView<'a>), } impl Atom { @@ -55,14 +145,14 @@ impl Atom { } } -pub struct ExpressionWithSubexpressions { +pub struct SplitExpression { pub tree: Vec>, pub subexpressions: Vec>, } pub struct EvalTree { - functions: Vec<(Symbol, Vec, ExpressionWithSubexpressions)>, - expressions: ExpressionWithSubexpressions, + functions: Vec<(String, Vec, SplitExpression)>, + expressions: SplitExpression, } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -241,9 +331,9 @@ enum Instr { BuiltinFun(usize, Symbol, usize), } -impl ExpressionWithSubexpressions { - pub fn map_coeff T2>(&self, f: &F) -> ExpressionWithSubexpressions { - ExpressionWithSubexpressions { +impl SplitExpression { + pub fn map_coeff T2>(&self, f: &F) -> SplitExpression { + SplitExpression { tree: self.tree.iter().map(|x| x.map_coeff(f)).collect(), subexpressions: self.subexpressions.iter().map(|x| x.map_coeff(f)).collect(), } @@ -320,7 +410,7 @@ impl Expression { impl EvalTree { pub fn map_coeff T2>(&self, f: &F) -> EvalTree { EvalTree { - expressions: ExpressionWithSubexpressions { + expressions: SplitExpression { tree: self .expressions .tree @@ -337,7 +427,7 @@ impl EvalTree { functions: self .functions .iter() - .map(|(s, a, e)| (*s, a.clone(), e.map_coeff(f))) + .map(|(s, a, e)| (s.clone(), a.clone(), e.map_coeff(f))) .collect(), } } @@ -772,9 +862,7 @@ impl EvalTree } } -impl - ExpressionWithSubexpressions -{ +impl SplitExpression { pub fn common_subexpression_elimination(&mut self) { let mut h = HashMap::default(); @@ -875,9 +963,7 @@ impl Expressi } } -impl - ExpressionWithSubexpressions -{ +impl SplitExpression { /// Find and extract pairs of variables that appear in more than one instruction. /// This reduces the number of operations. Returns `true` iff an extraction could be performed. /// @@ -1466,10 +1552,10 @@ impl<'a> AtomView<'a> { >( &self, coeff_map: F, - const_map: &HashMap>, + fn_map: &FunctionMap<'a, T>, params: &[Atom], ) -> EvalTree { - Self::to_eval_tree_multiple(std::slice::from_ref(self), coeff_map, const_map, params) + Self::to_eval_tree_multiple(std::slice::from_ref(self), coeff_map, fn_map, params) } /// Convert nested expressions to a tree. @@ -1479,17 +1565,17 @@ impl<'a> AtomView<'a> { >( exprs: &[Self], coeff_map: F, - const_map: &HashMap>, + fn_map: &FunctionMap<'a, T>, params: &[Atom], ) -> EvalTree { let mut funcs = vec![]; let tree = exprs .iter() - .map(|t| t.to_eval_tree_impl(coeff_map, const_map, params, &[], &mut funcs)) + .map(|t| t.to_eval_tree_impl(coeff_map, fn_map, params, &[], &mut funcs)) .collect(); EvalTree { - expressions: ExpressionWithSubexpressions { + expressions: SplitExpression { tree, subexpressions: vec![], }, @@ -1500,20 +1586,20 @@ impl<'a> AtomView<'a> { fn to_eval_tree_impl T + Copy>( &self, coeff_map: F, - const_map: &HashMap>, + fn_map: &FunctionMap<'a, T>, params: &[Atom], args: &[Symbol], - funcs: &mut Vec<(Symbol, Vec, ExpressionWithSubexpressions)>, + funcs: &mut Vec<(String, Vec, SplitExpression)>, ) -> Expression { if let Some(p) = params.iter().position(|a| a.as_view() == *self) { return Expression::Parameter(p); } - if let Some(c) = const_map.get(&self.into()) { + if let Some(c) = fn_map.get(*self) { return match c { ConstOrExpr::Const(c) => Expression::Const(c.clone()), - ConstOrExpr::Expr(name, args, v) => { - if !args.is_empty() { + ConstOrExpr::Expr(name, tag_len, args, v) => { + if args.len() != *tag_len { panic!( "Function {} called with wrong number of arguments: 0 vs {}", self, @@ -1524,11 +1610,11 @@ impl<'a> AtomView<'a> { if let Some(pos) = funcs.iter().position(|f| f.0 == *name) { Expression::Eval(pos, vec![]) } else { - let r = v.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + let r = v.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs); funcs.push(( - *name, + name.clone(), args.clone(), - ExpressionWithSubexpressions { + SplitExpression { tree: vec![r.clone()], subexpressions: vec![], }, @@ -1577,44 +1663,44 @@ impl<'a> AtomView<'a> { if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) { assert!(f.get_nargs() == 1); let arg = f.iter().next().unwrap(); - let arg_eval = arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + let arg_eval = arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs); return Expression::BuiltinFun(f.get_symbol(), Box::new(arg_eval)); } let symb = InlineVar::new(f.get_symbol()); - let Some(fun) = const_map.get(&symb.as_view().into()) else { + let Some(fun) = fn_map.get(symb.as_view()) else { panic!("Undefined function {}", State::get_name(f.get_symbol())); }; match fun { ConstOrExpr::Const(t) => Expression::Const(t.clone()), - ConstOrExpr::Expr(name, arg_spec, e) => { - if f.get_nargs() != arg_spec.len() { + ConstOrExpr::Expr(name, tag_len, arg_spec, e) => { + if f.get_nargs() != arg_spec.len() + *tag_len { panic!( "Function {} called with wrong number of arguments: {} vs {}", f.get_symbol(), f.get_nargs(), - arg_spec.len() + arg_spec.len() + *tag_len ); } let eval_args = f .iter() + .skip(*tag_len) .map(|arg| { - arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs) + arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs) }) .collect(); if let Some(pos) = funcs.iter().position(|f| f.0 == *name) { Expression::Eval(pos, eval_args) } else { - let r = - e.to_eval_tree_impl(coeff_map, const_map, params, arg_spec, funcs); + let r = e.to_eval_tree_impl(coeff_map, fn_map, params, arg_spec, funcs); funcs.push(( - *name, + name.clone(), arg_spec.clone(), - ExpressionWithSubexpressions { + SplitExpression { tree: vec![r.clone()], subexpressions: vec![], }, @@ -1626,7 +1712,7 @@ impl<'a> AtomView<'a> { } AtomView::Pow(p) => { let (b, e) = p.get_base_exp(); - let b_eval = b.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + let b_eval = b.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs); if let AtomView::Num(n) = e { if let CoefficientView::Natural(num, den) = n.get_coeff_view() { @@ -1639,13 +1725,13 @@ impl<'a> AtomView<'a> { } } - let e_eval = e.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + let e_eval = e.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs); Expression::Powf(Box::new((b_eval, e_eval))) } AtomView::Mul(m) => { let mut muls = vec![]; for arg in m.iter() { - let a = arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + let a = arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs); if let Expression::Mul(m) = a { muls.extend(m); } else { @@ -1658,7 +1744,7 @@ impl<'a> AtomView<'a> { AtomView::Add(a) => { let mut adds = vec![]; for arg in a.iter() { - adds.push(arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs)); + adds.push(arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)); } Expression::Add(adds)