Skip to content

Commit

Permalink
feat: add RustitudeError and remove unwraps, error handling should wo…
Browse files Browse the repository at this point in the history
…rk in python as well
  • Loading branch information
denehoffman committed May 24, 2024
1 parent 9bcdb46 commit 6955773
Show file tree
Hide file tree
Showing 19 changed files with 364 additions and 225 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions crates/rustitude-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub struct OmegaDalitz {
}

impl Node for OmegaDalitz {
fn precalculate(&mut self, dataset: &Dataset) -> Result<(), NodeError> {
fn precalculate(&mut self, dataset: &Dataset) -> Result<(), RustitudeError> {
(self.dalitz_z, (self.dalitz_sin3theta, self.lambda)) = dataset
.events
.read()
Expand Down Expand Up @@ -105,7 +105,7 @@ impl Node for OmegaDalitz {
Ok(())
}

fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, NodeError> {
fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, RustitudeError> {
let dalitz_z = self.dalitz_z[event.index];
let dalitz_sin3theta = self.dalitz_sin3theta[event.index];
let lambda = self.lambda[event.index];
Expand Down Expand Up @@ -137,9 +137,9 @@ impl Node for OmegaDalitz {

Let's walk through this code. First, we need to define a `struct` which has all of the general information about the amplitude, and in this case some kind of `Vec` for storing precalculated data. We consider this precalculated data to correspond to a single dataset, and each dataset gets its own copy of the amplitude `struct`. Because this particular amplitude doesn't have any input parameters, we can `#[derive(Default)]` on it to make a default constructor, which allows the amplitude to be initialized with something like `let amp = OmegaDalitz::default();`. If we wanted a parameterized constructor, we have to define our own, and while Rust has no default name for constructors, `pub fn new(...) -> rustitude_core::AmpOp` is preferred.

Next, we implement the `Node` trait for the `struct`. Traits in Rust are kind of like abstract classes or interfaces in object-oriented languages, they provide some set of methods which a `struct` must implement. The first of these methods is `fn precalculate(&mut self, dataset: &Dataset) -> Result<(), NodeError>`. As the signature suggests, it takes a `Dataset` and mutates the `struct` in some way. It should raise a `NodeError` if anything goes wrong in the evaluation. The intended usage of this function is to precalculate some terms in the amplitude's mathematical expression, things which don't change when you update the free parameter inputs to the amplitude. In this case, the four input parameters, $`\alpha`$, $`\beta`$, $`\gamma`$, and $`\delta`$, are independent from `dalitz_z`, `dalitz_sin3theta`, and `lambda`, so we can safely calculate those ahead of time and just pull from their respective `Vec`s when needed later. I won't go too far into Rust's syntax here, but typical precalculation functions will start by iterating over the dataset's events in parallel (the line `use rayon::prelude::*;` is needed to use `par_iter` here) and collecting or unzipping that iterator into a `Vec` or group of `Vec`s.
Next, we implement the `Node` trait for the `struct`. Traits in Rust are kind of like abstract classes or interfaces in object-oriented languages, they provide some set of methods which a `struct` must implement. The first of these methods is `fn precalculate(&mut self, dataset: &Dataset) -> Result<(), RustitudeError>`. As the signature suggests, it takes a `Dataset` and mutates the `struct` in some way. It should raise a `RustitudeError` if anything goes wrong in the evaluation. The intended usage of this function is to precalculate some terms in the amplitude's mathematical expression, things which don't change when you update the free parameter inputs to the amplitude. In this case, the four input parameters, $`\alpha`$, $`\beta`$, $`\gamma`$, and $`\delta`$, are independent from `dalitz_z`, `dalitz_sin3theta`, and `lambda`, so we can safely calculate those ahead of time and just pull from their respective `Vec`s when needed later. I won't go too far into Rust's syntax here, but typical precalculation functions will start by iterating over the dataset's events in parallel (the line `use rayon::prelude::*;` is needed to use `par_iter` here) and collecting or unzipping that iterator into a `Vec` or group of `Vec`s.

The calculate step has the signature `fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, NodeError>`. This means we need to take a list of parameters and a single event and turn them into a complex value. The `Event` struct contains an `index` field which can be used to access the precalculated storage arrays made in the previous step.
The calculate step has the signature `fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, RustitudeError>`. This means we need to take a list of parameters and a single event and turn them into a complex value. The `Event` struct contains an `index` field which can be used to access the precalculated storage arrays made in the previous step.

Finally, the `parameters` function just returns a list of the parameter names in the order they are expected to be input into `calculate`. In the event that an amplitude doesn't have any free parameters (like [my implementation of the `Ylm` and `Zlm` amplitudes](https://github.com/denehoffman/rustitude/blob/main/crates/rustitude-gluex/src/harmonics.rs)), we can omit this function entirely, as the default implementation returns `vec![]`.

Expand Down
120 changes: 72 additions & 48 deletions crates/rustitude-core/src/amplitude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ use std::{
ops::{Add, Mul},
sync::Arc,
};
use thiserror::Error;

use crate::dataset::{Dataset, Event};
use crate::{
dataset::{Dataset, Event},
errors::RustitudeError,
};

#[derive(Clone)]
pub struct Parameter {
Expand Down Expand Up @@ -80,14 +82,6 @@ macro_rules! amplitude {
}};
}

#[derive(Debug, Clone, Error)]
pub enum NodeError {
#[error("invalid parameter value")]
InvalidParameterValue(String),
#[error("evaluation error")]
EvaluationError(String),
}

/// A trait which contains all the required methods for a functioning [`Amplitude`].
///
/// The [`Node`] trait represents any mathematical structure which takes in some parameters and some
Expand Down Expand Up @@ -151,7 +145,7 @@ pub enum NodeError {
/// }
/// impl Node for Ylm {
/// fn parameters(&self) -> Vec<String> { vec![] }
/// fn precalculate(&mut self, dataset: &Dataset) -> Result<(), NodeError> {
/// fn precalculate(&mut self, dataset: &Dataset) -> Result<(), RustitudeError> {
/// self.1 = dataset.events.read()
/// .par_iter()
/// .map(|event| {
Expand All @@ -173,7 +167,7 @@ pub enum NodeError {
/// Ok(())
/// }
///
/// fn calculate(&self, _parameters: &[f64], event: &Event) -> Result<Complex64, NodeError> {
/// fn calculate(&self, _parameters: &[f64], event: &Event) -> Result<Complex64, RustitudeError> {
/// Ok(self.1[event.index])
/// }
/// }
Expand All @@ -185,7 +179,7 @@ pub enum NodeError {
/// use rustitude_core::prelude::*;
/// struct ComplexScalar;
/// impl Node for ComplexScalar {
/// fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, NodeError> {
/// fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, RustitudeError> {
/// Ok(Complex64::new(parameters[0], parameters[1]))
/// }
///
Expand All @@ -201,7 +195,7 @@ pub trait Node: Sync + Send {
/// parameters. For instance, to calculate a spherical harmonic, we don't actually need any
/// other information than what is contained in the [`Event`], so we can calculate a spherical
/// harmonic for every event once and then retrieve the data in the [`Node::calculate`] method.
fn precalculate(&mut self, _dataset: &Dataset) -> Result<(), NodeError> {
fn precalculate(&mut self, _dataset: &Dataset) -> Result<(), RustitudeError> {
Ok(())
}

Expand All @@ -213,7 +207,7 @@ pub trait Node: Sync + Send {
/// a slice of [`f64`]s. This slice is guaranteed to have the same length and order as
/// specified in the [`Node::parameters`] method, or it will be empty if that method returns
/// [`None`].
fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, NodeError>;
fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, RustitudeError>;

/// A method which specifies the number and order of parameters used by the [`Node`].
///
Expand Down Expand Up @@ -558,17 +552,17 @@ impl Amplitude {
cache_position: usize,
parameter_index_start: usize,
dataset: &Dataset,
) -> Result<(), NodeError> {
) -> Result<(), RustitudeError> {
self.cache_position = cache_position;
self.parameter_index_start = parameter_index_start;
self.precalculate(dataset)
}
}
impl Node for Amplitude {
fn precalculate(&mut self, dataset: &Dataset) -> Result<(), NodeError> {
fn precalculate(&mut self, dataset: &Dataset) -> Result<(), RustitudeError> {
self.node.write().precalculate(dataset)
}
fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, NodeError> {
fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, RustitudeError> {
self.node.read().calculate(
&parameters
[self.parameter_index_start..self.parameter_index_start + self.parameters().len()],
Expand All @@ -588,10 +582,23 @@ pub struct Model {
}

impl Model {
pub fn get_parameter(&self, amplitude_name: &str, parameter_name: &str) -> Option<Parameter> {
pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude, RustitudeError> {
self.amplitudes
.iter()
.find(|a: &&Amplitude| a.name == amplitude_name)
.ok_or_else(|| RustitudeError::AmplitudeNotFoundError(amplitude_name.to_string()))
.cloned()
}
pub fn get_parameter(
&self,
amplitude_name: &str,
parameter_name: &str,
) -> Result<Parameter, RustitudeError> {
self.get_amplitude(amplitude_name)?;
self.parameters
.iter()
.find(|p: &&Parameter| p.amplitude == amplitude_name && p.name == parameter_name)
.ok_or_else(|| RustitudeError::ParameterNotFoundError(parameter_name.to_string()))
.cloned()
}
pub fn print_parameters(&self) {
Expand Down Expand Up @@ -619,9 +626,9 @@ impl Model {
parameter_1: &str,
amplitude_2: &str,
parameter_2: &str,
) {
let p1 = self.get_parameter(amplitude_1, parameter_1).unwrap();
let p2 = self.get_parameter(amplitude_2, parameter_2).unwrap();
) -> Result<(), RustitudeError> {
let p1 = self.get_parameter(amplitude_1, parameter_1)?;
let p2 = self.get_parameter(amplitude_2, parameter_2)?;
for par in self.parameters.iter_mut() {
// None < Some(0)
match p1.index.cmp(&p2.index) {
Expand All @@ -645,10 +652,16 @@ impl Model {
}
}
self.reindex_parameters();
Ok(())
}

pub fn fix(&mut self, amplitude: &str, parameter: &str, value: f64) {
let search_par = self.get_parameter(amplitude, parameter).unwrap();
pub fn fix(
&mut self,
amplitude: &str,
parameter: &str,
value: f64,
) -> Result<(), RustitudeError> {
let search_par = self.get_parameter(amplitude, parameter)?;
let fixed_index = self.get_min_fixed_index();
for par in self.parameters.iter_mut() {
if par.index == search_par.index {
Expand All @@ -658,9 +671,10 @@ impl Model {
}
}
self.reindex_parameters();
Ok(())
}
pub fn free(&mut self, amplitude: &str, parameter: &str) {
let search_par = self.get_parameter(amplitude, parameter).unwrap();
pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
let search_par = self.get_parameter(amplitude, parameter)?;
let index = self.get_min_free_index();
for par in self.parameters.iter_mut() {
if par.fixed_index == search_par.fixed_index {
Expand All @@ -669,9 +683,15 @@ impl Model {
}
}
self.reindex_parameters();
Ok(())
}
pub fn set_bounds(&mut self, amplitude: &str, parameter: &str, bounds: (f64, f64)) {
let search_par = self.get_parameter(amplitude, parameter).unwrap();
pub fn set_bounds(
&mut self,
amplitude: &str,
parameter: &str,
bounds: (f64, f64),
) -> Result<(), RustitudeError> {
let search_par = self.get_parameter(amplitude, parameter)?;
if search_par.index.is_some() {
for par in self.parameters.iter_mut() {
if par.index == search_par.index {
Expand All @@ -685,9 +705,15 @@ impl Model {
}
}
}
Ok(())
}
pub fn set_initial(&mut self, amplitude: &str, parameter: &str, initial: f64) {
let search_par = self.get_parameter(amplitude, parameter).unwrap();
pub fn set_initial(
&mut self,
amplitude: &str,
parameter: &str,
initial: f64,
) -> Result<(), RustitudeError> {
let search_par = self.get_parameter(amplitude, parameter)?;
if search_par.index.is_some() {
for par in self.parameters.iter_mut() {
if par.index == search_par.index {
Expand All @@ -701,6 +727,7 @@ impl Model {
}
}
}
Ok(())
}
pub fn get_bounds(&self) -> Vec<(f64, f64)> {
let any_fixed = if self.any_fixed() { 1 } else { 0 };
Expand Down Expand Up @@ -762,7 +789,7 @@ impl Model {
parameters,
}
}
pub fn compute(&self, parameters: &[f64], event: &Event) -> f64 {
pub fn compute(&self, parameters: &[f64], event: &Event) -> Result<f64, RustitudeError> {
let pars: Vec<f64> = self
.parameters
.iter()
Expand All @@ -774,23 +801,19 @@ impl Model {
.iter()
.map(|amp| {
if amp.active {
let res = amp.calculate(&pars, event).unwrap(); // unwrap panics if any
// errors occur in calculation
Some(res)
amp.calculate(&pars, event).map(Some)
} else {
None
Ok(None)
}
})
.collect();
let res = self.root.compute(&cache).unwrap(); // unwrap panics if all the
res.re
.collect::<Result<Vec<Option<Complex64>>, RustitudeError>>()?;
Ok(self.root.compute(&cache).unwrap_or_default().re)
}
pub fn load(&mut self, dataset: &Dataset) {
pub fn load(&mut self, dataset: &Dataset) -> Result<(), RustitudeError> {
let mut next_cache_pos = 0;
let mut parameter_index = 0;
self.amplitudes.iter_mut().for_each(|amp| {
amp.register(next_cache_pos, parameter_index, dataset)
.unwrap(); // unwrap panics if precalculate fails
self.amplitudes.iter_mut().try_for_each(|amp| {
amp.register(next_cache_pos, parameter_index, dataset)?;
self.root.walk_mut().iter_mut().for_each(|r_amp| {
if r_amp.name == amp.name {
r_amp.cache_position = next_cache_pos;
Expand All @@ -799,7 +822,8 @@ impl Model {
});
next_cache_pos += 1;
parameter_index += amp.parameters().len();
});
Ok(())
})
}
fn group_by_index(&self) -> Vec<Vec<&Parameter>> {
self.parameters
Expand Down Expand Up @@ -858,7 +882,7 @@ impl Node for Scalar {
fn parameters(&self) -> Vec<String> {
vec!["value".to_string()]
}
fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, NodeError> {
fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, RustitudeError> {
Ok(Complex64::new(parameters[0], 0.0))
}
}
Expand Down Expand Up @@ -893,7 +917,7 @@ pub fn scalar(name: &str) -> AmpOp {
/// - `imag`: The imaginary part of the complex scalar.
pub struct ComplexScalar;
impl Node for ComplexScalar {
fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, NodeError> {
fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, RustitudeError> {
Ok(Complex64::new(parameters[0], parameters[1]))
}

Expand Down Expand Up @@ -933,7 +957,7 @@ pub fn cscalar(name: &str) -> AmpOp {
/// - `phi`: The phase of the complex scalar.
pub struct PolarComplexScalar;
impl Node for PolarComplexScalar {
fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, NodeError> {
fn calculate(&self, parameters: &[f64], _event: &Event) -> Result<Complex64, RustitudeError> {
Ok(parameters[0] * Complex64::cis(parameters[1]))
}

Expand Down Expand Up @@ -997,7 +1021,7 @@ impl<F> Node for Piecewise<F>
where
F: Fn(&Event) -> f64 + Send + Sync + Copy,
{
fn precalculate(&mut self, dataset: &Dataset) -> Result<(), NodeError> {
fn precalculate(&mut self, dataset: &Dataset) -> Result<(), RustitudeError> {
self.calculated_variable = dataset
.events
.read()
Expand All @@ -1007,7 +1031,7 @@ where
Ok(())
}

fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, NodeError> {
fn calculate(&self, parameters: &[f64], event: &Event) -> Result<Complex64, RustitudeError> {
let val = self.calculated_variable[event.index];
let opt_i_bin = self.edges.iter().position(|&(l, r)| val >= l && val <= r);
opt_i_bin.map_or_else(
Expand Down
Loading

0 comments on commit 6955773

Please sign in to comment.