diff --git a/arbor/include/arbor/lif_cell.hpp b/arbor/include/arbor/lif_cell.hpp index cff3eeb551..64557616f2 100644 --- a/arbor/include/arbor/lif_cell.hpp +++ b/arbor/include/arbor/lif_cell.hpp @@ -16,11 +16,17 @@ struct ARB_SYMBOL_VISIBLE lif_cell { double C_m = 20; // Membrane capacitance [pF]. double E_L = 0; // Resting potential [mV]. double V_m = E_L; // Initial value of the Membrane potential [mV]. - double V_reset = E_L; // Reset potential [mV]. double t_ref = 2; // Refractory period [ms]. lif_cell() = delete; lif_cell(cell_tag_type source, cell_tag_type target): source(std::move(source)), target(std::move(target)) {} }; +// LIF probe metadata, to be passed to sampler callbacks. Intentionally left blank. +struct ARB_SYMBOL_VISIBLE lif_probe_metadata {}; + +// Voltage estimate [mV]. +// Sample value type: `double` +struct ARB_SYMBOL_VISIBLE lif_probe_voltage {}; + } // namespace arb diff --git a/arbor/lif_cell_group.cpp b/arbor/lif_cell_group.cpp index 0fa71789ec..1dffa20d3c 100644 --- a/arbor/lif_cell_group.cpp +++ b/arbor/lif_cell_group.cpp @@ -5,6 +5,8 @@ #include "profile/profiler_macro.hpp" #include "util/rangeutil.hpp" #include "util/span.hpp" +#include "util/filter.hpp" +#include "util/maputil.hpp" using namespace arb; @@ -13,8 +15,16 @@ lif_cell_group::lif_cell_group(const std::vector& gids, const rec gids_(gids) { for (auto gid: gids_) { - if (!rec.get_probes(gid).empty()) { - throw bad_cell_probe(cell_kind::lif, gid); + auto probes = rec.get_probes(gid); + for (const auto lid: util::count_along(probes)) { + const auto& probe = probes[lid]; + if (probe.address.type() == typeid(lif_probe_voltage)) { + cell_member_type id{gid, static_cast(lid)}; + probes_[id] = {probe.tag, lif_probe_kind::voltage, {}}; + } + else { + throw bad_cell_probe{cell_kind::lif, gid}; + } } } // Default to no binning of events @@ -22,6 +32,7 @@ lif_cell_group::lif_cell_group(const std::vector& gids, const rec cells_.reserve(gids_.size()); last_time_updated_.resize(gids_.size()); + next_time_updatable_.resize(gids_.size()); for (auto lid: util::make_span(gids_.size())) { cells_.push_back(util::any_cast(rec.get_cell_description(gids_[lid]))); @@ -41,11 +52,9 @@ cell_kind lif_cell_group::get_cell_kind() const { void lif_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) { PE(advance:lif); - if (event_lanes.size() > 0) { - for (auto lid: util::make_span(gids_.size())) { - // Advance each cell independently. - advance_cell(ep.t1, dt, lid, event_lanes[lid]); - } + for (auto lid: util::make_span(gids_.size())) { + // Advance each cell independently. + advance_cell(ep.t1, dt, lid, event_lanes); } PL(); } @@ -59,10 +68,30 @@ void lif_cell_group::clear_spikes() { } // TODO: implement sampler -void lif_cell_group::add_sampler(sampler_association_handle h, cell_member_predicate probeset_ids, - schedule sched, sampler_function fn, sampling_policy policy) {} -void lif_cell_group::remove_sampler(sampler_association_handle h) {} -void lif_cell_group::remove_all_samplers() {} +void lif_cell_group::add_sampler(sampler_association_handle h, + cell_member_predicate probeset_ids, + schedule sched, + sampler_function fn, + sampling_policy policy) { + std::lock_guard guard(sampler_mex_); + std::vector probeset = + util::assign_from(util::filter(util::keys(probes_), probeset_ids)); + auto assoc = arb::sampler_association{std::move(sched), + std::move(fn), + std::move(probeset), + policy}; + auto result = samplers_.insert({h, std::move(assoc)}); + arb_assert(result.second); +} + +void lif_cell_group::remove_sampler(sampler_association_handle h) { + std::lock_guard guard(sampler_mex_); + samplers_.erase(h); +} +void lif_cell_group::remove_all_samplers() { + std::lock_guard guard(sampler_mex_); + samplers_.clear(); +} // TODO: implement binner_ void lif_cell_group::set_binning_policy(binning_kind policy, time_type bin_interval) { @@ -71,52 +100,143 @@ void lif_cell_group::set_binning_policy(binning_kind policy, time_type bin_inter void lif_cell_group::reset() { spikes_.clear(); util::fill(last_time_updated_, 0.); + util::fill(next_time_updatable_, 0.); } // Advances a single cell (lid) with the exact solution (jumps can be arbitrary). // Parameter dt is ignored, since we make jumps between two consecutive spikes. -void lif_cell_group::advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, pse_vector& event_lane) { - // Current time of last update. - auto t = last_time_updated_[lid]; +void lif_cell_group::advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, const event_lane_subrange& event_lanes) { + const auto gid = gids_[lid]; auto& cell = cells_[lid]; - const auto n_events = event_lane.size(); - - // Integrate until tfinal using the exact solution of membrane voltage differential equation. - for (unsigned i=0; i= tfinal) break; // end of integration interval - - // if there are events that happened at the same time as this event, process them as well - while (i + 1 < n_events && event_lane[i+1].time <= time) { - weight += event_lane[i+1].weight; - i++; + // time of last update. + auto t = last_time_updated_[lid]; + // spikes to process + const auto n_events = static_cast(event_lanes.size() ? event_lanes[lid].size() : 0); + int event_idx = 0; + // collected sampling data + std::unordered_map>> sampled; + // samples to process + std::size_t n_values = 0; + std::vector> samples; + { + std::lock_guard guard(sampler_mex_); + for (auto& [hdl, assoc]: samplers_) { + // Construct sampling times + const auto& times = util::make_range(assoc.sched.events(t, tfinal)); + const auto n_times = times.size(); + // Count up the samplers touching _our_ gid + int delta = 0; + for (const auto& pid: assoc.probeset_ids) { + if (pid.gid != gid) continue; + arb_assert (0 == sampled[hdl].count(pid)); + sampled[hdl][pid].reserve(n_times); + delta += n_times; + } + if (delta == 0) continue; + n_values += delta; + // only exact sampling: ignore lax and never look at policy + for (auto t: times) samples.emplace_back(t, hdl); + } + } + std::sort(samples.begin(), samples.end()); + int n_samples = samples.size(); + int sample_idx = 0; + // Now allocate some scratch space for the probed values, if we don't, + // re-alloc might move our data + std::vector sampled_voltages; + sampled_voltages.reserve(n_values); + // integrate until tfinal using the exact solution of membrane voltage differential equation. + for (;;) { + const auto event_time = event_idx < n_events ? event_lanes[lid][event_idx].time : tfinal; + const auto sample_time = sample_idx < n_samples ? samples[sample_idx].first : tfinal; + const auto time = std::min(event_time, sample_time); + // bail at end of integration interval + if (time >= tfinal) break; + // Check what to do, we might need to process events **and/or** perform + // sampling. + // NB. we put events before samples, if they collide we'll see + // the update in sampling. + + bool do_event = time == event_time; + bool do_sample = time == sample_time; + + if (do_event) { + const auto& event_lane = event_lanes[lid]; + // process all events at time t + auto weight = 0.0; + for (; event_idx < n_events && event_lane[event_idx].time <= time; ++event_idx) { + weight += event_lane[event_idx].weight; + } + // skip event if neuron is in refactory period + if (time >= t) { + // Let the membrane potential decay. + cell.V_m *= exp((t - time) / cell.tau_m); + // Add jump due to spike(s). + cell.V_m += weight / cell.C_m; + // Update current time + t = time; + // If crossing threshold occurred + if (cell.V_m >= cell.V_th) { + // save spike + spikes_.push_back({{gid, 0}, time}); + // Advance to account for the refractory period. + // This means decay will also start at t + t_ref + t += cell.t_ref; + // Reset the voltage to resting potential. + cell.V_m = cell.E_L; + } + } } - // Let the membrane potential decay. - auto decay = exp(-(time - t) / cell.tau_m); - cell.V_m *= decay; - auto update = weight / cell.C_m; - // Add jump due to spike. - cell.V_m += update; - t = time; - // If crossing threshold occurred - if (cell.V_m >= cell.V_th) { - cell_member_type spike_neuron_gid = {gids_[lid], 0}; - spike s = {spike_neuron_gid, t}; - spikes_.push_back(s); - - // Advance the last_time_updated to account for the refractory period. - t += cell.t_ref; - - // Reset the voltage to resting potential. - cell.V_m = cell.E_L; + if (do_sample) { + // Consume all sample events at this time + for (; sample_idx < n_samples && samples[sample_idx].first <= time; ++sample_idx) { + const auto& [s_time, hdl] = samples[sample_idx]; + for (const auto& key: samplers_[hdl].probeset_ids) { + const auto& kind = probes_.at(key).kind; + // This is the only thing we know how to do: Probing U(t) + switch (kind) { + case lif_probe_kind::voltage: { + // Compute, but do not _set_ V_m + auto U = cell.V_m; + if (time >= t) U *= exp((t - time) / cell.tau_m); + // Store U for later use. + sampled_voltages.push_back(U); + // Set up reference to sampled value + sampled[hdl][key].push_back(sample_record{time, {&sampled_voltages.back()}}); + break; + } + default: + throw arbor_internal_error{"Invalid LIF probe kind"}; + } + } + } + } + if (!(do_sample || do_event)) { + throw arbor_internal_error{"LIF cell group: Must select either sample or spike event; got neither."}; } + last_time_updated_[lid] = t; } + arb_assert (sampled_voltages.size() == n_values); + // Now we need to call all sampler callbacks with the data we have collected + { + std::lock_guard guard(sampler_mex_); + for (const auto& [k, vs]: sampled) { + const auto& fun = samplers_[k].sampler; + for (const auto& [id, us]: vs) { + auto meta = get_probe_metadata(id)[0]; + fun(meta, us.size(), us.data()); + } + } + } +} - // This is the last time a cell was updated. - last_time_updated_[lid] = t; +std::vector lif_cell_group::get_probe_metadata(cell_member_type key) const { + if (probes_.count(key)) { + return {probe_metadata{key, {}, 0, {&probes_.at(key).metadata}}}; + } else { + return {}; + } } diff --git a/arbor/lif_cell_group.hpp b/arbor/lif_cell_group.hpp index b89e442f86..dbebbfa777 100644 --- a/arbor/lif_cell_group.hpp +++ b/arbor/lif_cell_group.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -9,6 +10,7 @@ #include #include +#include "sampler_map.hpp" #include "cell_group.hpp" #include "label_resolution.hpp" @@ -37,10 +39,21 @@ class ARB_ARBOR_API lif_cell_group: public cell_group { virtual void remove_sampler(sampler_association_handle) override; virtual void remove_all_samplers() override; + virtual std::vector get_probe_metadata(cell_member_type) const override; + private: + enum class lif_probe_kind { voltage }; + + struct lif_probe_info { + probe_tag tag; + lif_probe_kind kind; + lif_probe_metadata metadata; + }; + + // Advances a single cell (lid) with the exact solution (jumps can be arbitrary). // Parameter dt is ignored, since we make jumps between two consecutive spikes. - void advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, pse_vector& event_lane); + void advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, const event_lane_subrange& event_lane); // List of the gids of the cells in the group. std::vector gids_; @@ -53,6 +66,16 @@ class ARB_ARBOR_API lif_cell_group: public cell_group { // Time when the cell was last updated. std::vector last_time_updated_; + // Time when the cell can _next_ be updated; + std::vector next_time_updatable_; + + // SAFETY: We need to access samplers_ through a mutex since + // simulation::add_sampler might be called concurrently. + std::mutex sampler_mex_; + sampler_association_map samplers_; + + // LIF probe metadata, precalculated to pass to callbacks + std::unordered_map probes_; }; } // namespace arb diff --git a/doc/concepts/cable_cell.rst b/doc/concepts/cable_cell.rst index a619051264..3364f20fb4 100644 --- a/doc/concepts/cable_cell.rst +++ b/doc/concepts/cable_cell.rst @@ -46,7 +46,6 @@ Once constructed, the cable cell can be queried for specific information about t labels mechanisms decor - probe_sample API --- diff --git a/doc/concepts/index.rst b/doc/concepts/index.rst index f1cf92b071..cd0f20c36b 100644 --- a/doc/concepts/index.rst +++ b/doc/concepts/index.rst @@ -42,3 +42,5 @@ of the model over the locally available computational resources. In order to visualize the result of detected spikes a spike recorder can be used, and to analyse Arbor's performance a meter manager is available. + +:ref:`probesample` shows how to extract data from simulations. diff --git a/doc/concepts/lif_cell.rst b/doc/concepts/lif_cell.rst index c965c967b0..18fdcd44d2 100644 --- a/doc/concepts/lif_cell.rst +++ b/doc/concepts/lif_cell.rst @@ -5,19 +5,35 @@ LIF cells The description of a LIF cell is used to control the leaky integrate-and-fire dynamics: -* Resting potential. -* Reset potential. -* Initial value of membrane potential. -* Membrane potential decaying constant. -* Membrane capacitance. -* Firing threshold. -* Refractory period. - -The morphology of a LIF cell is automatically modelled as a single :term:`compartment `; -each cell has one built-in **source** and one built-in **target** which need to be given labels when the -cell is created. The labels are used to form connections to and from the cell. -LIF cells do not support adding additional **sources** or **targets** to the description. They do not support -**gap junctions**. They do not support adding density or point mechanisms. +* Resting potential :math:`E_\mathrm{L}` +* Membrane potential decaying constant :math:`\tau_\mathrm{m}` +* Membrane capacitance :math:`C_\mathrm{m}` +* Firing threshold :math:`U_\mathrm{threshold}` +* Refractory period :math:`t_\mathrm{ref}` + +The morphology of a LIF cell is automatically modelled as a single +:term:`compartment `; each cell has one built-in **source** and +one built-in **target** which need to be given labels when the cell is created. +The labels are used to form connections to and from the cell. LIF cells do not +support adding additional **sources** or **targets** to the description. They do +not support **gap junctions**. They do not support adding density or point +mechanisms. + +The LIF cell's time dynamics are this: + +0. :math:`U_\mathrm{m}(0) = E_\mathrm{L}` +1. If the cell is in its refractory state :math:`U_\mathrm{m}(t) = E_\mathrm{L}` +2. Otherwise :math:`U'_\mathrm{m}(t) = \sum\limits_\mathrm{spike} w_\mathrm{spike} \cdot\delta(t - t_\mathrm{spike}) -\frac{1}{\tau_\mathrm{m}}U_\mathrm{m}(t)` +3. If :math:`U_\mathrm{m}(t_0) \geq U_\mathrm{threshold}`: emit spike and enter refractory period until :math:`t = t_0 + t_\mathrm{ref}` + +LIF cells can be probed to obtain their current membrane potential, see :ref:`probesample`. + +.. figure:: ../images/lif.svg + :width: 400 + :align: center + + Plot of the potential over time for a LIF cell. + API --- diff --git a/doc/concepts/probe_sample.rst b/doc/concepts/probe_sample.rst index 0e3b0fcf94..9b553eb683 100644 --- a/doc/concepts/probe_sample.rst +++ b/doc/concepts/probe_sample.rst @@ -1,7 +1,12 @@ .. _probesample: -Cable cell probing and sampling -=============================== +Probing and Sampling +==================== + +Both cable cells and LIF cells can be probed, see here for more details on cells +:ref:`modelcells`. The LIF cell, however, has a much smaller set of observable +quantities and only offers scalar probes. Thus, the following discussion is +tailored to the cable cell. Definitions *********** @@ -64,7 +69,6 @@ Spiking Threshold detectors have a dual use: they can be used to record spike times, but are also used in propagating signals between cells. See also :term:`threshold detector` and :ref:`cablecell-threshold-detectors`. - API --- diff --git a/doc/cpp/probe_sample.rst b/doc/cpp/probe_sample.rst index cb7b08cfbb..c9eb0c231e 100644 --- a/doc/cpp/probe_sample.rst +++ b/doc/cpp/probe_sample.rst @@ -660,3 +660,19 @@ call the *sampler* callback once for probe in *probe set*, with *n* sample value In addition to the ``lax`` sampling policy, ``mc_cell_group`` supports the ``exact`` policy. Integration steps will be shortened such that any sample times associated with an ``exact`` policy can be satisfied precisely. + +LIF cell probing and sampling +=============================== + +Membrane voltage +---------------- + +.. code:: + + struct lif_probe_voltage {}; + +Queries cell membrane potential. + +* Sample value: ``double``. Membrane potential (mV). + +* Metadata: none diff --git a/doc/images/lif.svg b/doc/images/lif.svg new file mode 100644 index 0000000000..b4c31ddb8f --- /dev/null +++ b/doc/images/lif.svg @@ -0,0 +1,1297 @@ + + + + + + + + + 2022-10-27T16:15:02.520502 + image/svg+xml + + + Matplotlib v3.3.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/doc/index.rst b/doc/index.rst index af35bfc62f..a13eafa3e1 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -129,6 +129,7 @@ A full list of our software attributions can be found `here maybe_method(pybind11::object method) { std::string lif_str(const arb::lif_cell& c){ return util::pprintf( - "", - c.tau_m, c.V_th, c.C_m, c.E_L, c.V_m, c.t_ref, c.V_reset); + "", + c.tau_m, c.V_th, c.C_m, c.E_L, c.V_m, c.t_ref); } @@ -230,8 +230,6 @@ void register_cells(pybind11::module& m) { "Initial value of the Membrane potential [mV].") .def_readwrite("t_ref", &arb::lif_cell::t_ref, "Refractory period [ms].") - .def_readwrite("V_reset", &arb::lif_cell::V_reset, - "Reset potential [mV].") .def_readwrite("source", &arb::lif_cell::source, "Label of the single build-in source on the cell.") .def_readwrite("target", &arb::lif_cell::target, diff --git a/python/example/brunel.py b/python/example/brunel.py index a147e73784..ba8258e2c8 100755 --- a/python/example/brunel.py +++ b/python/example/brunel.py @@ -105,7 +105,6 @@ def cell_description(self, gid): cell.C_m = 20 cell.E_L = 0 cell.V_m = 0 - cell.V_reset = 0 cell.t_ref = 2 return cell diff --git a/python/cable_probes.cpp b/python/probes.cpp similarity index 88% rename from python/cable_probes.cpp rename to python/probes.cpp index a309578ad7..a589166bdc 100644 --- a/python/cable_probes.cpp +++ b/python/probes.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -26,7 +27,7 @@ namespace pyarb { // to cable_cell scalar- and vector-valued probes. template -struct recorder_cable_base: sample_recorder { +struct recorder_base: sample_recorder { // Return stride-column array: first column is time, remainder correspond to sample. py::object samples() const override { @@ -49,14 +50,14 @@ struct recorder_cable_base: sample_recorder { std::vector sample_raw_; std::ptrdiff_t stride_; - recorder_cable_base(const Meta* meta_ptr, std::ptrdiff_t width): + recorder_base(const Meta* meta_ptr, std::ptrdiff_t width): meta_(*meta_ptr), stride_(1+width) {} }; template -struct recorder_cable_scalar: recorder_cable_base { - using recorder_cable_base::sample_raw_; +struct recorder_cable_scalar: recorder_base { + using recorder_base::sample_raw_; void record(any_ptr, std::size_t n_sample, const arb::sample_record* records) override { for (std::size_t i = 0; i { } protected: - recorder_cable_scalar(const Meta* meta_ptr): recorder_cable_base(meta_ptr, 1) {} + recorder_cable_scalar(const Meta* meta_ptr): recorder_base(meta_ptr, 1) {} }; +struct recorder_lif: recorder_base { + using recorder_base::sample_raw_; + + void record(any_ptr, std::size_t n_sample, const arb::sample_record* records) override { + for (std::size_t i = 0; i(records[i].data)) { + sample_raw_.push_back(records[i].time); + sample_raw_.push_back(*v_ptr); + } + else { + std::string ty = records[i].data.type().name(); + throw arb::arbor_internal_error("LIF recorder: unexpected sample type " + ty); + } + } + } + + recorder_lif(const arb::lif_probe_metadata* meta_ptr): recorder_base(meta_ptr, 1) {} +}; + + template -struct recorder_cable_vector: recorder_cable_base { - using recorder_cable_base::sample_raw_; +struct recorder_cable_vector: recorder_base { + using recorder_base::sample_raw_; void record(any_ptr, std::size_t n_sample, const arb::sample_record* records) override { for (std::size_t i = 0; i { protected: recorder_cable_vector(const Meta* meta_ptr, std::ptrdiff_t width): - recorder_cable_base(meta_ptr, width) {} + recorder_base(meta_ptr, width) {} }; // Specific recorder classes: @@ -132,6 +153,8 @@ void register_probe_meta_maps(pyarb_global_ptr g) { }); } + + // Wrapper functions around cable_cell probe types that return arb::probe_info values: // (Probe tag value is implicitly left at zero.) @@ -211,6 +234,12 @@ arb::probe_info cable_probe_ion_ext_concentration_cell(const char* ion) { return arb::cable_probe_ion_ext_concentration_cell{ion}; } +// LIF cell probes +arb::probe_info lif_probe_voltage() { + return arb::lif_probe_voltage{}; +} + + // Add wrappers to module, recorder factories to global data. void register_cable_probes(pybind11::module& m, pyarb_global_ptr global_ptr) { @@ -219,6 +248,9 @@ void register_cable_probes(pybind11::module& m, pyarb_global_ptr global_ptr) { // Probe metadata wrappers: + py::class_ lif_probe_metadata(m, "lif_probe_metadata", + "Probe metadata associated with a LIF cell probe."); + py::class_ cable_probe_point_info(m, "cable_probe_point_info", "Probe metadata associated with a cable cell probe for point process state."); @@ -236,6 +268,9 @@ void register_cable_probes(pybind11::module& m, pyarb_global_ptr global_ptr) { // Probe address constructors: + m.def("lif_probe_voltage", &lif_probe_voltage, + "Probe specification for LIF cell membrane voltage."); + m.def("cable_probe_membrane_voltage", &cable_probe_membrane_voltage, "Probe specification for cable cell membrane voltage interpolated at points in a location set.", "where"_a); @@ -314,6 +349,7 @@ void register_cable_probes(pybind11::module& m, pyarb_global_ptr global_ptr) { register_probe_meta_maps(global_ptr); register_probe_meta_maps(global_ptr); register_probe_meta_maps, recorder_cable_vector_point_info>(global_ptr); + register_probe_meta_maps(global_ptr); } } // namespace pyarb diff --git a/python/pyarb.hpp b/python/pyarb.hpp index 954195a439..1b4241c113 100644 --- a/python/pyarb.hpp +++ b/python/pyarb.hpp @@ -57,7 +57,8 @@ struct recorder_factory_map { return map_.at(meta.type())(meta); } catch (std::out_of_range&) { - throw arb::arbor_internal_error("unrecognized probe metadata type"); + std::string ty = meta.type().name(); + throw arb::arbor_internal_error("unrecognized probe metadata type " + ty); } } }; diff --git a/python/test/unit/test_cable_probes.py b/python/test/unit/test_probes.py similarity index 78% rename from python/test/unit/test_cable_probes.py rename to python/test/unit/test_probes.py index daa0ffd29c..ff045a64cb 100644 --- a/python/test/unit/test_cable_probes.py +++ b/python/test/unit/test_probes.py @@ -2,6 +2,7 @@ import unittest import arbor as A +import numpy as np """ tests for cable probe wrappers @@ -166,3 +167,63 @@ def test_probe_addr_metadata(self): m = sim.probe_metadata((0, 16)) self.assertEqual(1, len(m)) self.assertEqual(all_cv_cables, m[0]) + + +class lif_recipe(A.recipe): + def __init__(self): + A.recipe.__init__(self) + + def num_cells(self): + return 1 + + def cell_kind(self, gid): + return A.cell_kind.lif + + def global_properties(self, kind): + return None + + def probes(self, gid): + return [ + # probe id (0, 0) + A.lif_probe_voltage(), + ] + + def cell_description(self, gid): + cell = A.lif_cell("src", "tgt") + cell.E_L = -42 + cell.V_m = -23 + cell.t_ref = 0.2 + return cell + + +class TestLifProbes(unittest.TestCase): + def test_probe_addr_metadata(self): + rec = lif_recipe() + sim = A.simulation(rec) + + m = sim.probe_metadata((0, 0)) + self.assertEqual(1, len(m)) + self.assertTrue(all(isinstance(i, A.lif_probe_metadata) for i in m)) + + def test_probe_result(self): + rec = lif_recipe() + sim = A.simulation(rec) + hdl = sim.sample((0, 0), A.regular_schedule(0.1)) + sim.run(1.0, 0.05) + smp = sim.samples(hdl) + exp = np.array( + [ + [0.0, -23.0], + [0.1, -22.77114618], + [0.2, -22.54456949], + [0.3, -22.32024727], + [0.4, -22.0981571], + [0.5, -21.87827676], + [0.6, -21.66058427], + [0.7, -21.44505786], + [0.8, -21.23167597], + [0.9, -21.02041726], + ] + ) + for d, _ in smp: + np.testing.assert_allclose(d, exp) diff --git a/test/unit/test_lif_cell_group.cpp b/test/unit/test_lif_cell_group.cpp index 6916f01a99..60d8d8c9ee 100644 --- a/test/unit/test_lif_cell_group.cpp +++ b/test/unit/test_lif_cell_group.cpp @@ -1,5 +1,7 @@ #include +#include "common.hpp" + #include #include #include @@ -114,7 +116,7 @@ class probe_recipe: public arb::recipe { probe_recipe() {} cell_size_type num_cells() const override { - return 1; + return 2; } cell_kind get_cell_kind(cell_gid_type gid) const override { return cell_kind::lif; @@ -123,17 +125,29 @@ class probe_recipe: public arb::recipe { return {}; } util::unique_any get_cell_description(cell_gid_type gid) const override { - return lif_cell("src", "tgt"); + auto cell = lif_cell("src", "tgt"); + if (gid == 0) { + cell.E_L = -42; + cell.V_m = -23; + cell.t_ref = 0.2; + } + return cell; } - std::vector get_probes(cell_gid_type gid) const override{ - return {arb::cable_probe_membrane_voltage{mlocation{0, 0}}}; + std::vector get_probes(cell_gid_type gid) const override { + if (gid == 0) { + return {arb::lif_probe_voltage{}, arb::lif_probe_voltage{}}; + } else { + return {arb::lif_probe_voltage{}}; + } } + std::vector event_generators(cell_gid_type) const override { return {regular_generator({"tgt"}, 100.0, 0.25, 0.05)}; } }; + TEST(lif_cell_group, throw) { probe_recipe rec; auto context = make_context(); auto decomp = partition_load_balance(rec, context); - EXPECT_THROW(simulation(rec, context, decomp), bad_cell_probe); + EXPECT_NO_THROW(simulation(rec, context, decomp)); } TEST(lif_cell_group, recipe) @@ -188,12 +202,9 @@ TEST(lif_cell_group, ring) // Total simulation time. time_type simulation_time = 100; - auto context = make_context(); auto recipe = ring_recipe(num_lif_cells, weight, delay); - auto decomp = partition_load_balance(recipe, context); - // Creates a simulation with a ring recipe of lif neurons - simulation sim(recipe, context, decomp); + simulation sim(recipe); std::vector spike_buffer; @@ -221,3 +232,115 @@ TEST(lif_cell_group, ring) } } +struct Um_type { + constexpr static double delta = 1e-6; + + double t; + double u; + + friend std::ostream& operator<<(std::ostream& os, const Um_type& um) { + os << "{ " << um.t << ", " << um.u << " }"; + return os; + } + + friend bool operator==(const Um_type& lhs, const Um_type& rhs) { + return (std::abs(lhs.t - rhs.t) <= delta) + && (std::abs(lhs.u - rhs.u) <= delta); + } +}; + +TEST(lif_cell_group, probe) { + auto ums = std::unordered_map>{}; + auto fun = [&ums](probe_metadata pm, + std::size_t n, + const sample_record* samples) { + for (int ix = 0; ix < n; ++ix) { + const auto& [t, v] = samples[ix]; + double u = *util::any_cast(v); + ums[pm.id].push_back({t, u}); + } + }; + auto rec = probe_recipe{}; + auto sim = simulation(rec); + + sim.add_sampler(all_probes, regular_schedule(0.025), fun); + + std::vector spikes; + + sim.set_global_spike_callback( + [&spikes](const std::vector& spk) { for (const auto& s: spk) spikes.push_back(s.time); } + ); + + sim.run(1.5, 0.005); + std::vector exp = {{0, -23}, + {0.025, -22.9425718}, + {0.05, -22.885287}, + {0.075, -22.8281453}, + {0.1, -22.7711462}, + {0.125, -22.7142894}, + {0.15, -22.6575746}, + {0.175, -22.6010014}, + {0.2, -22.5445695}, + {0.225, -22.4882785}, + {0.25, -17.432128}, + {0.275, -17.3886021}, + {0.3, -12.3451849}, + {0.325, -12.3143605}, + {0.35, -7.28361301}, + {0.375, -7.26542672}, + {0.4, -2.24728584}, + {0.425, -2.24167464}, + {0.45, 2.76392255}, + {0.475, 2.75702137}, + {0.5, 7.75013743}, + {0.525, 7.73078628}, + {0.55, -42}, + {0.575, -42}, + {0.6, -42}, + {0.625, -42}, + {0.65, -42}, + {0.675, -42}, + {0.7, -42}, + {0.725, -42}, + {0.75, -37}, + {0.775, -36.9076155}, + {0.8, -31.8154617}, + {0.825, -31.7360224}, + {0.85, -26.6567815}, + {0.875, -26.5902227}, + {0.9, -21.5238302}, + {0.925, -21.4700878}, + {0.95, -16.4164796}, + {0.975, -16.3754897}, + {1, -11.3346021}, + {1.025, -11.306301}, + {1.05, -6.27807055}, + {1.075, -6.26239498}, + {1.1, -1.24675854}, + {1.125, -1.24364554}, + {1.15, 3.75945969}, + {1.175, 3.75007278}, + {1.2, 8.74070931}, + {1.225, 8.71888483}, + {1.25, -42}, + {1.275, -42}, + {1.3, -42}, + {1.325, -42}, + {1.35, -42}, + {1.375, -42}, + {1.4, -42}, + {1.425, -42}, + {1.45, -37}, + {1.475, -36.9076155},}; + + ASSERT_TRUE(testing::seq_eq(ums[{0, 0}], exp)); + ASSERT_TRUE(testing::seq_eq(ums[{0, 1}], exp)); + // gid == 1 is different, but of same size + EXPECT_EQ((ums[{1, 0}].size()), exp.size()); + ASSERT_FALSE(testing::seq_eq(ums[{1, 0}], exp)); + // now check the spikes + std::sort(spikes.begin(), spikes.end()); + EXPECT_EQ(spikes.size(), 3); + std::vector sexp{0.35, 0.55, 1.25}; + ASSERT_TRUE(testing::seq_almost_eq(spikes, sexp)); +}