Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make LIF cells probeable. #2021

Merged
merged 8 commits into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion arbor/include/arbor/lif_cell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@ struct ARB_SYMBOL_VISIBLE lif_cell {
double C_m = 20; // Membrane capacitance [pF].
double E_L = 0; // Resting potential [mV].
double V_m = E_L; // Initial value of the Membrane potential [mV].
double V_reset = E_L; // Reset potential [mV].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this member there in the first place?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a very good question. It was never used, but set in all the examples and mentioned in the docs...
I am contemplating on whether we should be adding it back in, but this time it should be actually doing
something ;)

I'll ask some of your LIF-savvy users.

double t_ref = 2; // Refractory period [ms].

lif_cell() = delete;
lif_cell(cell_tag_type source, cell_tag_type target): source(std::move(source)), target(std::move(target)) {}
};

// LIF probe metadata, to be passed to sampler callbacks. Intentionally left blank.
struct ARB_SYMBOL_VISIBLE lif_probe_metadata {};

// Voltage estimate [mV].
// Sample value type: `double`
struct ARB_SYMBOL_VISIBLE lif_probe_voltage {};

} // namespace arb
218 changes: 168 additions & 50 deletions arbor/lif_cell_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "profile/profiler_macro.hpp"
#include "util/rangeutil.hpp"
#include "util/span.hpp"
#include "util/filter.hpp"
#include "util/maputil.hpp"

using namespace arb;

Expand All @@ -13,15 +15,24 @@ lif_cell_group::lif_cell_group(const std::vector<cell_gid_type>& gids, const rec
gids_(gids)
{
for (auto gid: gids_) {
if (!rec.get_probes(gid).empty()) {
throw bad_cell_probe(cell_kind::lif, gid);
auto probes = rec.get_probes(gid);
for (const auto lid: util::count_along(probes)) {
const auto& probe = probes[lid];
cell_member_type id = {gid, static_cast<cell_lid_type>(lid)};
if (probe.address.type() == typeid(lif_probe_voltage)) {
probes_[id] = {probe.tag, lif_probe_kind::voltage, {}};
}
else {
throw bad_cell_probe{cell_kind::lif, gid};
}
}
}
// Default to no binning of events
lif_cell_group::set_binning_policy(binning_kind::none, 0);

cells_.reserve(gids_.size());
last_time_updated_.resize(gids_.size());
next_time_updatable_.resize(gids_.size());

for (auto lid: util::make_span(gids_.size())) {
cells_.push_back(util::any_cast<lif_cell>(rec.get_cell_description(gids_[lid])));
Expand All @@ -41,11 +52,9 @@ cell_kind lif_cell_group::get_cell_kind() const {

void lif_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) {
PE(advance:lif);
if (event_lanes.size() > 0) {
for (auto lid: util::make_span(gids_.size())) {
// Advance each cell independently.
advance_cell(ep.t1, dt, lid, event_lanes[lid]);
}
for (auto lid: util::make_span(gids_.size())) {
// Advance each cell independently.
advance_cell(ep.t1, dt, lid, event_lanes);
}
PL();
}
Expand All @@ -59,10 +68,30 @@ void lif_cell_group::clear_spikes() {
}

// TODO: implement sampler
void lif_cell_group::add_sampler(sampler_association_handle h, cell_member_predicate probeset_ids,
schedule sched, sampler_function fn, sampling_policy policy) {}
void lif_cell_group::remove_sampler(sampler_association_handle h) {}
void lif_cell_group::remove_all_samplers() {}
void lif_cell_group::add_sampler(sampler_association_handle h,
cell_member_predicate probeset_ids,
schedule sched,
sampler_function fn,
sampling_policy policy) {
std::lock_guard<std::mutex> guard(sampler_mex_);
std::vector<cell_member_type> probeset =
util::assign_from(util::filter(util::keys(probes_), probeset_ids));
auto assoc = arb::sampler_association{std::move(sched),
std::move(fn),
std::move(probeset),
policy};
auto result = samplers_.insert({h, std::move(assoc)});
arb_assert(result.second);
}

void lif_cell_group::remove_sampler(sampler_association_handle h) {
std::lock_guard<std::mutex> guard(sampler_mex_);
samplers_.erase(h);
}
void lif_cell_group::remove_all_samplers() {
std::lock_guard<std::mutex> guard(sampler_mex_);
samplers_.clear();
}

// TODO: implement binner_
void lif_cell_group::set_binning_policy(binning_kind policy, time_type bin_interval) {
Expand All @@ -71,52 +100,141 @@ void lif_cell_group::set_binning_policy(binning_kind policy, time_type bin_inter
void lif_cell_group::reset() {
spikes_.clear();
util::fill(last_time_updated_, 0.);
util::fill(next_time_updatable_, 0.);
}

// Advances a single cell (lid) with the exact solution (jumps can be arbitrary).
// Parameter dt is ignored, since we make jumps between two consecutive spikes.
void lif_cell_group::advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, pse_vector& event_lane) {
// Current time of last update.
auto t = last_time_updated_[lid];
void lif_cell_group::advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, const event_lane_subrange& event_lanes) {
// our gid
const auto gid = gids_[lid];
auto& cell = cells_[lid];
const auto n_events = event_lane.size();

// Integrate until tfinal using the exact solution of membrane voltage differential equation.
for (unsigned i=0; i<n_events; ++i ) {
auto& ev = event_lane[i];
const auto time = ev.time;
auto weight = ev.weight;

if (time < t) continue; // skip event if a neuron is in refactory period
if (time >= tfinal) break; // end of integration interval

// if there are events that happened at the same time as this event, process them as well
while (i + 1 < n_events && event_lane[i+1].time <= time) {
weight += event_lane[i+1].weight;
i++;
// time of last update.
auto t = last_time_updated_[lid];
// integrate until tfinal using the exact solution of membrane voltage differential equation.
// spikes to process
const auto n_events = static_cast<int>(event_lanes.size() ? event_lanes[lid].size() : 0);
int e_idx = 0;
// collected sampling data
std::unordered_map<sampler_association_handle,
std::unordered_map<cell_member_type,
std::vector<sample_record>>> sampled;
// samples to process
std::vector<std::pair<time_type, sampler_association_handle>> samples;
std::size_t count = 0;
{
std::lock_guard<std::mutex> guard(sampler_mex_);
for (auto& [hdl, assoc]: samplers_) {
// Construct sampling times
const auto& times = util::make_range(assoc.sched.events(t, tfinal));
const auto size = times.size();
// Count up the samplers touching _our_ gid
std::size_t delta = 0;
for (const auto& pid: assoc.probeset_ids) {
if (pid.gid != gid) continue;
arb_assert (0 == sampled[hdl].count(pid));
sampled[hdl][pid].reserve(size);
delta += size;
}
if (delta == 0) continue;
count += delta;
// We only ever use exact sampling, so we over-provision for lax and
// never look at the policy
for (auto t: times) samples.emplace_back(t, hdl);
}

// Let the membrane potential decay.
auto decay = exp(-(time - t) / cell.tau_m);
cell.V_m *= decay;
auto update = weight / cell.C_m;
// Add jump due to spike.
cell.V_m += update;
t = time;
// If crossing threshold occurred
if (cell.V_m >= cell.V_th) {
cell_member_type spike_neuron_gid = {gids_[lid], 0};
spike s = {spike_neuron_gid, t};
spikes_.push_back(s);

// Advance the last_time_updated to account for the refractory period.
t += cell.t_ref;

// Reset the voltage to resting potential.
cell.V_m = cell.E_L;
}
std::sort(samples.begin(), samples.end());
const auto n_samples = static_cast<int>(samples.size());
int s_idx = 0;
// Now allocate some scratch space for the probed values, if we don't,
// re-alloc might move our data
std::vector<value_type> sampled_voltages;
sampled_voltages.reserve(count);
for (;;) {
const auto e_time = e_idx < n_events ? event_lanes[lid][e_idx].time : tfinal;
const auto s_time = s_idx < n_samples ? samples[s_idx].first : tfinal;
const auto time = std::min(e_time, s_time);
// bail at end of integration interval
if (time >= tfinal) break;
// Check what to do, we put events before samples, if they collide we'll
// see the update in sampling.
// We need to incorporate an event
if (time == e_time) {
const auto& event_lane = event_lanes[lid];
// process all events at time t
auto weight = 0.0;
for (; e_idx < n_events && event_lane[e_idx].time <= time; ++e_idx) {
weight += event_lane[e_idx].weight;
}
// skip event if neuron is in refactory period
if (time >= t) {
// Let the membrane potential decay.
cell.V_m *= exp((t - time) / cell.tau_m);
// Add jump due to spike(s).
cell.V_m += weight / cell.C_m;
// Update current time
t = time;
// If crossing threshold occurred
if (cell.V_m >= cell.V_th) {
// save spike
spikes_.push_back({{gid, 0}, time});
// Advance to account for the refractory period.
// This means decay will also start at t + t_ref
t += cell.t_ref;
// Reset the voltage to resting potential.
cell.V_m = cell.E_L;
}
}
}
// We need to probe, so figure out what to do.
if (time == s_time) {
// Consume all sample events at this time
for (; s_idx < n_samples && samples[s_idx].first <= time; ++s_idx) {
const auto& [s_time, hdl] = samples[s_idx];
for (const auto& key: samplers_[hdl].probeset_ids) {
const auto& kind = probes_.at(key).kind;
// This is the only thing we know how to do: Probing U(t)
switch (kind) {
case lif_probe_kind::voltage: {
// Compute, but do not _set_ V_m
auto U = cell.V_m;
if (time >= t) U *= exp((t - time) / cell.tau_m);
// Store U for later use.
sampled_voltages.push_back(U);
// Set up reference to sampled value
auto data_ptr = sampled_voltages.data() + sampled_voltages.size() - 1;
sampled[hdl][key].push_back(sample_record{time, {data_ptr}});
break;
}
default:
throw arbor_internal_error{"Invalid LIF probe kind"};
}
}
}
}
if ((time != s_time) && (time != e_time)) {
throw arbor_internal_error{"LIF cell group: Must select either sample or spike event; got neither."};
}
last_time_updated_[lid] = t;
}
arb_assert (sampled_voltages.size() == count);
// Now we need to call all sampler callbacks with the data we have collected
{
std::lock_guard<std::mutex> guard(sampler_mex_);
for (const auto& [k, vs]: sampled) {
const auto& fun = samplers_[k].sampler;
for (const auto& [id, us]: vs) {
auto meta = get_probe_metadata(id)[0];
fun(meta, us.size(), us.data());
}
}
}
}

// This is the last time a cell was updated.
last_time_updated_[lid] = t;
std::vector<probe_metadata> lif_cell_group::get_probe_metadata(cell_member_type key) const {
if (probes_.count(key)) {
return {probe_metadata{key, {}, 0, {&probes_.at(key).metadata}}};
} else {
return {};
}
}
21 changes: 20 additions & 1 deletion arbor/lif_cell_group.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <vector>
#include <mutex>

#include <arbor/export.hpp>
#include <arbor/common_types.hpp>
Expand All @@ -9,6 +10,7 @@
#include <arbor/sampling.hpp>
#include <arbor/spike.hpp>

#include "sampler_map.hpp"
#include "cell_group.hpp"
#include "label_resolution.hpp"

Expand Down Expand Up @@ -37,10 +39,21 @@ class ARB_ARBOR_API lif_cell_group: public cell_group {
virtual void remove_sampler(sampler_association_handle) override;
virtual void remove_all_samplers() override;

virtual std::vector<probe_metadata> get_probe_metadata(cell_member_type) const override;

private:
enum class lif_probe_kind { voltage };

struct lif_probe_info {
probe_tag tag;
lif_probe_kind kind;
lif_probe_metadata metadata;
};


// Advances a single cell (lid) with the exact solution (jumps can be arbitrary).
// Parameter dt is ignored, since we make jumps between two consecutive spikes.
void advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, pse_vector& event_lane);
void advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, const event_lane_subrange& event_lane);

// List of the gids of the cells in the group.
std::vector<cell_gid_type> gids_;
Expand All @@ -53,6 +66,12 @@ class ARB_ARBOR_API lif_cell_group: public cell_group {

// Time when the cell was last updated.
std::vector<time_type> last_time_updated_;
// Time when the cell can _next_ be updated;
std::vector<time_type> next_time_updatable_;

std::mutex sampler_mex_;
sampler_association_map samplers_;
std::unordered_map<cell_member_type, lif_probe_info> probes_;
};

} // namespace arb
1 change: 0 additions & 1 deletion doc/concepts/cable_cell.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ Once constructed, the cable cell can be queried for specific information about t
labels
mechanisms
decor
probe_sample

API
---
Expand Down
2 changes: 2 additions & 0 deletions doc/concepts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ of the model over the locally available computational resources.

In order to visualize the result of detected spikes a spike recorder can be used, and to analyse Arbor's performance a
meter manager is available.

:ref:`probesample` shows how to extract data from simulations.
Loading