diff --git a/python/rebop/__init__.py b/python/rebop/__init__.py index e9bbfd4..0db0e33 100644 --- a/python/rebop/__init__.py +++ b/python/rebop/__init__.py @@ -1,9 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import xarray as xr from .rebop import Gillespie, __version__ # type: ignore[attr-defined] +if TYPE_CHECKING: + from collections.abc import Sequence + __all__ = ("Gillespie", "__version__") og_run = Gillespie.run @@ -17,13 +22,22 @@ def run_xarray( # noqa: PLR0913 too many parameters in function definition seed: int | None = None, *, sparse: bool = False, + var_names: Sequence[str] | None = None, ) -> xr.Dataset: """Run the system until `tmax` with `nb_steps` steps. The initial configuration is specified in the dictionary `init`. Returns an xarray Dataset. """ - times, result = og_run(self, init, tmax, nb_steps, seed, sparse=sparse) + times, result = og_run( + self, + init, + tmax, + nb_steps, + seed, + sparse=sparse, + var_names=var_names, + ) ds = xr.Dataset( data_vars={ name: xr.DataArray(values, dims="time", coords={"time": times}) diff --git a/python/rebop/rebop.pyi b/python/rebop/rebop.pyi index f84f8c0..da33fa3 100644 --- a/python/rebop/rebop.pyi +++ b/python/rebop/rebop.pyi @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import xarray class Gillespie: @@ -32,6 +34,7 @@ class Gillespie: seed: int | None = None, *, sparse: bool = False, + var_names: Sequence[str] | None = None, ) -> xarray.Dataset: """Run the system until `tmax` with `nb_steps` steps. diff --git a/src/lib.rs b/src/lib.rs index 4f9c064..8050315 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -303,7 +303,7 @@ impl Gillespie { /// values at the given time points. One can specify a random `seed` for reproducibility. /// If `nb_steps` is `0`, then returns all reactions, ending with the first that happens at /// or after `tmax`. - #[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false))] + #[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false, var_names=None))] fn run( &self, init: HashMap, @@ -311,6 +311,7 @@ impl Gillespie { nb_steps: usize, seed: Option, sparse: bool, + var_names: Option>, ) -> PyResult<(Vec, HashMap>)> { let mut x0 = vec![0; self.species.len()]; for (name, &value) in &init { @@ -322,6 +323,13 @@ impl Gillespie { Some(seed) => gillespie::Gillespie::new_with_seed(x0, sparse, seed), None => gillespie::Gillespie::new(x0, sparse), }; + let save_indices: Vec<_> = match &var_names { + Some(x) => x + .iter() + .map(|key| self.species.get(key).unwrap().clone()) + .collect(), + None => (0..self.species.len()).collect(), + }; for (rate, reactants, products) in self.reactions.iter() { let mut vreactants = vec![0; self.species.len()]; @@ -340,34 +348,43 @@ impl Gillespie { } let mut times = Vec::new(); // species.shape = (species, nb_steps) - let mut species = vec![Vec::new(); self.species.len()]; + let mut species = vec![Vec::new(); save_indices.len()]; if nb_steps > 0 { for i in 0..=nb_steps { let t = tmax * i as f64 / nb_steps as f64; times.push(t); g.advance_until(t); - for s in 0..self.species.len() { - species[s].push(g.get_species(s)); + for (i, s) in save_indices.iter().enumerate() { + species[i].push(g.get_species(*s)); } } } else { // nb_steps = 0: we return every step let mut rates = vec![f64::NAN; g.nb_reactions()]; times.push(g.get_time()); - for s in 0..self.species.len() { - species[s].push(g.get_species(s)); + for (i, s) in save_indices.iter().enumerate() { + species[i].push(g.get_species(*s)); } while g.get_time() < tmax { g._advance_one_reaction(&mut rates); times.push(g.get_time()); - for s in 0..self.species.len() { - species[s].push(g.get_species(s)); + for (i, s) in save_indices.iter().enumerate() { + species[i].push(g.get_species(*s)); } } } let mut result = HashMap::new(); - for (name, &id) in &self.species { - result.insert(name.clone(), species[id].clone()); + match var_names { + Some(x) => { + for (id, name) in x.iter().enumerate() { + result.insert(name.clone(), species[id].clone()); + } + } + None => { + for (name, &id) in &self.species { + result.insert(name.clone(), species[id].clone()); + } + } } Ok((times, result)) } diff --git a/tests/test_rebop.py b/tests/test_rebop.py index 2528bcf..f6853df 100644 --- a/tests/test_rebop.py +++ b/tests/test_rebop.py @@ -55,7 +55,33 @@ def test_all_reactions(seed: int) -> None: def test_dense_vs_sparse() -> None: sir = sir_model() init = {"S": 999, "I": 1} - kwargs = {"tmax": 250, "nb_steps": 250, "seed": 42} - ds_dense = sir.run(init, **kwargs, sparse=False) - ds_sparse = sir.run(init, **kwargs, sparse=True) + tmax = 250 + nb_steps = 250 + seed = 42 + ds_dense = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=False) + ds_sparse = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=True) assert (ds_dense == ds_sparse).all() + + +@pytest.mark.parametrize("nb_steps", [0, 250]) +def test_var_names(nb_steps: int) -> None: + all_variables = {"S", "I", "R"} + subset_to_save = ["S", "I"] + remaining = all_variables.difference(subset_to_save) + + sir = sir_model() + init = {"S": 999, "I": 1} + tmax = 250 + seed = 0 + + ds_all = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=None) + ds_subset = sir.run( + init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=subset_to_save + ) + + for s in subset_to_save: + assert s in ds_subset + for s in remaining: + assert s not in ds_subset + + assert ds_all[subset_to_save] == ds_subset