Skip to content

Commit

Permalink
Reduce overhead of conversion from Python complex
Browse files Browse the repository at this point in the history
- Add a check for the root node when counting graph bridges
  • Loading branch information
benruijl committed Feb 11, 2025
1 parent b1edfcc commit 8a4a909
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
33 changes: 14 additions & 19 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2324,10 +2324,10 @@ impl<'a> FromPyObject<'a> for PythonMultiPrecisionFloat {

impl<'a> FromPyObject<'a> for Complex<f64> {
fn extract_bound(ob: &Bound<'a, pyo3::PyAny>) -> PyResult<Self> {
if let Ok(a) = ob.extract::<f64>() {
Ok(Complex::new(a, 0.))
} else if let Ok(a) = ob.downcast::<PyComplex>() {
if let Ok(a) = ob.downcast::<PyComplex>() {
Ok(Complex::new(a.real(), a.imag()))
} else if let Ok(a) = ob.extract::<f64>() {
Ok(Complex::new(a, 0.))
} else {
Err(exceptions::PyValueError::new_err(
"Not a valid complex number",
Expand Down Expand Up @@ -10357,19 +10357,17 @@ impl PythonCompiledExpressionEvaluator {
inputs: Vec<Complex<f64>>,
) -> Vec<Bound<'py, PyComplex>> {
let n_inputs = inputs.len() / self.input_len;
let mut res = vec![PyComplex::from_doubles(py, 0., 0.); self.output_len * n_inputs];
let mut tmp = vec![Complex::new_zero(); self.output_len];
let mut res = vec![Complex::new(0., 0.); self.output_len * n_inputs];
for (r, s) in res
.chunks_mut(self.output_len)
.zip(inputs.chunks(self.input_len))
{
self.eval.evaluate(s, &mut tmp);
for (rr, t) in r.iter_mut().zip(&tmp) {
*rr = PyComplex::from_doubles(py, t.re, t.im);
}
self.eval.evaluate(s, r);
}

res
res.into_iter()
.map(|x| PyComplex::from_doubles(py, x.re, x.im))
.collect()
}

/// Evaluate the expression for multiple inputs and return the results.
Expand Down Expand Up @@ -10429,20 +10427,17 @@ impl PythonExpressionEvaluator {
) -> Vec<Bound<'py, PyComplex>> {
let mut eval = self.eval.clone().map_coeff(&|x| Complex::new(*x, 0.));
let n_inputs = inputs.len() / self.eval.get_input_len();
let mut res =
vec![PyComplex::from_doubles(py, 0., 0.); self.eval.get_output_len() * n_inputs];
let mut tmp = vec![Complex::new_zero(); self.eval.get_output_len()];
let mut res = vec![Complex::new(0., 0.); self.eval.get_output_len() * n_inputs];
for (r, s) in res
.chunks_mut(self.eval.get_output_len())
.zip(inputs.chunks(self.eval.get_input_len()))
.zip(inputs.chunks(self.eval.get_output_len()))
{
eval.evaluate(s, &mut tmp);
for (rr, t) in r.iter_mut().zip(&tmp) {
*rr = PyComplex::from_doubles(py, t.re, t.im);
}
eval.evaluate(s, r);
}

res
res.into_iter()
.map(|x| PyComplex::from_doubles(py, x.re, x.im))
.collect()
}

/// Evaluate the expression for multiple inputs and return the results.
Expand Down
1 change: 1 addition & 0 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ impl SpanningTree {
x.chain_id.is_none()
&& !self.nodes[x.parent].external
&& !x.external
&& x.parent != *n // exclude the root
&& !self.nodes[x.parent].back_edges.iter().any(|end| n == end)
})
.count()
Expand Down

0 comments on commit 8a4a909

Please sign in to comment.