Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: save subset of variables #36

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/rebop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import xarray as xr

from .rebop import Gillespie, __version__ # type: ignore[attr-defined]

if TYPE_CHECKING:
from collections.abc import Sequence

__all__ = ("Gillespie", "__version__")

og_run = Gillespie.run
Expand All @@ -17,13 +22,22 @@ def run_xarray( # noqa: PLR0913 too many parameters in function definition
seed: int | None = None,
*,
sparse: bool = False,
var_names: Sequence[str] | None = None,
) -> xr.Dataset:
"""Run the system until `tmax` with `nb_steps` steps.

The initial configuration is specified in the dictionary `init`.
Returns an xarray Dataset.
"""
times, result = og_run(self, init, tmax, nb_steps, seed, sparse=sparse)
times, result = og_run(
self,
init,
tmax,
nb_steps,
seed,
sparse=sparse,
var_names=var_names,
)
ds = xr.Dataset(
data_vars={
name: xr.DataArray(values, dims="time", coords={"time": times})
Expand Down
3 changes: 3 additions & 0 deletions python/rebop/rebop.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Sequence

import xarray

class Gillespie:
Expand Down Expand Up @@ -32,6 +34,7 @@ class Gillespie:
seed: int | None = None,
*,
sparse: bool = False,
var_names: Sequence[str] | None = None,
) -> xarray.Dataset:
"""Run the system until `tmax` with `nb_steps` steps.

Expand Down
37 changes: 27 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,15 @@ impl Gillespie {
/// values at the given time points. One can specify a random `seed` for reproducibility.
/// If `nb_steps` is `0`, then returns all reactions, ending with the first that happens at
/// or after `tmax`.
#[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false))]
#[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false, var_names=None))]
fn run(
&self,
init: HashMap<String, usize>,
tmax: f64,
nb_steps: usize,
seed: Option<u64>,
sparse: bool,
var_names: Option<Vec<String>>,
) -> PyResult<(Vec<f64>, HashMap<String, Vec<isize>>)> {
let mut x0 = vec![0; self.species.len()];
for (name, &value) in &init {
Expand All @@ -322,6 +323,13 @@ impl Gillespie {
Some(seed) => gillespie::Gillespie::new_with_seed(x0, sparse, seed),
None => gillespie::Gillespie::new(x0, sparse),
};
let save_indices: Vec<_> = match &var_names {
Some(x) => x
.iter()
.map(|key| self.species.get(key).unwrap().clone())
.collect(),
None => (0..self.species.len()).collect(),
};

for (rate, reactants, products) in self.reactions.iter() {
let mut vreactants = vec![0; self.species.len()];
Expand All @@ -340,34 +348,43 @@ impl Gillespie {
}
let mut times = Vec::new();
// species.shape = (species, nb_steps)
let mut species = vec![Vec::new(); self.species.len()];
let mut species = vec![Vec::new(); save_indices.len()];
if nb_steps > 0 {
for i in 0..=nb_steps {
let t = tmax * i as f64 / nb_steps as f64;
times.push(t);
g.advance_until(t);
for s in 0..self.species.len() {
species[s].push(g.get_species(s));
for (i, s) in save_indices.iter().enumerate() {
species[i].push(g.get_species(*s));
}
}
} else {
// nb_steps = 0: we return every step
let mut rates = vec![f64::NAN; g.nb_reactions()];
times.push(g.get_time());
for s in 0..self.species.len() {
species[s].push(g.get_species(s));
for (i, s) in save_indices.iter().enumerate() {
species[i].push(g.get_species(*s));
}
while g.get_time() < tmax {
g._advance_one_reaction(&mut rates);
times.push(g.get_time());
for s in 0..self.species.len() {
species[s].push(g.get_species(s));
for (i, s) in save_indices.iter().enumerate() {
species[i].push(g.get_species(*s));
}
}
}
let mut result = HashMap::new();
for (name, &id) in &self.species {
result.insert(name.clone(), species[id].clone());
match var_names {
Some(x) => {
for (id, name) in x.iter().enumerate() {
result.insert(name.clone(), species[id].clone());
}
}
None => {
for (name, &id) in &self.species {
result.insert(name.clone(), species[id].clone());
}
}
}
Ok((times, result))
}
Expand Down
32 changes: 29 additions & 3 deletions tests/test_rebop.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,33 @@ def test_all_reactions(seed: int) -> None:
def test_dense_vs_sparse() -> None:
sir = sir_model()
init = {"S": 999, "I": 1}
kwargs = {"tmax": 250, "nb_steps": 250, "seed": 42}
ds_dense = sir.run(init, **kwargs, sparse=False)
ds_sparse = sir.run(init, **kwargs, sparse=True)
tmax = 250
nb_steps = 250
seed = 42
ds_dense = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=False)
ds_sparse = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=True)
assert (ds_dense == ds_sparse).all()


@pytest.mark.parametrize("nb_steps", [0, 250])
def test_var_names(nb_steps: int) -> None:
all_variables = {"S", "I", "R"}
subset_to_save = ["S", "I"]
remaining = all_variables.difference(subset_to_save)

sir = sir_model()
init = {"S": 999, "I": 1}
tmax = 250
seed = 0

ds_all = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=None)
ds_subset = sir.run(
init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=subset_to_save
)

for s in subset_to_save:
assert s in ds_subset
for s in remaining:
assert s not in ds_subset

assert ds_all[subset_to_save] == ds_subset
Loading