Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into feature/fixed-cab…
Browse files Browse the repository at this point in the history
…le-dt
  • Loading branch information
boeschf committed Apr 4, 2023
2 parents d3e05bf + ad0b304 commit 0c799aa
Show file tree
Hide file tree
Showing 42 changed files with 1,370 additions and 207 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ dist
# generated by YouCompleteMe Vim plugin with clangd engine support.
.clangd

# Generated by various tools
.ccls-cache
.direnv
_skbuild

# generated image files by Python examples
python/example/*.svg
python/example/*.csv
Expand Down
8 changes: 7 additions & 1 deletion arbor/arbexcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,16 @@ invalid_mechanism_kind::invalid_mechanism_kind(arb_mechanism_kind kind):
{}

bad_connection_source_gid::bad_connection_source_gid(cell_gid_type gid, cell_gid_type src_gid, cell_size_type num_cells):
arbor_exception(pprintf("Model building error on cell {}: connection source gid {} is out of range: there are only {} cells in the model, in the range [{}:{}].", gid, src_gid, num_cells, 0, num_cells-1)),
arbor_exception(pprintf("Model building error on cell {}: connection source gid {} is out of range: there are {} cells in the model, in the range [{}:{}].", gid, src_gid, num_cells, 0, num_cells-1)),
gid(gid), src_gid(src_gid), num_cells(num_cells)
{}

source_gid_exceeds_limit::source_gid_exceeds_limit(cell_gid_type gid, cell_gid_type src_gid):
arbor_exception(pprintf("Model building error on cell {}: connection source gid {} is out of range: gids may not exceed {}.",
gid, src_gid, std::numeric_limits<cell_gid_type>::max()/2)),
gid(gid), src_gid(src_gid)
{}

bad_connection_label::bad_connection_label(cell_gid_type gid, const cell_tag_type& label, const std::string& msg):
arbor_exception(pprintf("Model building error on cell {}: connection endpoint label \"{}\": {}.", gid, label, msg)),
gid(gid), label(label)
Expand Down
188 changes: 143 additions & 45 deletions arbor/communication/communicator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <numeric>
#include <utility>
#include <vector>
#include <limits>

#include <arbor/assert.hpp>
#include <arbor/common_types.hpp>
Expand Down Expand Up @@ -32,6 +33,28 @@ communicator::communicator(const recipe& rec,
distributed_{ctx.distributed},
thread_pool_{ctx.thread_pool} {}

constexpr inline
bool is_external(cell_gid_type c) {
// index of the MSB of cell_gid_type in bits
constexpr auto msb = 1 << (std::numeric_limits<cell_gid_type>::digits - 1);
// If set, we are external
return bool(c & msb);
}

constexpr inline
cell_member_type global_cell_of(const cell_remote_label_type& c) {
constexpr auto msb = 1 << (std::numeric_limits<cell_gid_type>::digits - 1);
// set the MSB
return {c.rid | msb, c.index};
}

constexpr inline
cell_member_type global_cell_of(const cell_member_type& c) {
constexpr auto msb = 1 << (std::numeric_limits<cell_gid_type>::digits - 1);
// set the MSB
return {c.gid | msb, c.index};
}

void communicator::update_connections(const connectivity& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
Expand Down Expand Up @@ -62,11 +85,17 @@ void communicator::update_connections(const connectivity& rec,
// Build the connection information for local cells.
PE(init:communicator:update:gid_connections);
std::vector<cell_connection> gid_connections;
std::vector<size_t> part_connections; part_connections.reserve(num_local_cells_);
std::vector<ext_cell_connection> gid_ext_connections;
std::vector<size_t> part_connections;
part_connections.reserve(num_local_cells_);
part_connections.push_back(0);
std::vector<size_t> part_ext_connections;
part_ext_connections.reserve(num_local_cells_);
part_ext_connections.push_back(0);
std::vector<unsigned> src_domains;
std::vector<cell_size_type> src_counts(num_domains_);
for (const auto gid: gids) {
// Local
const auto& conns = rec.connections_on(gid);
for (const auto& conn: conns) {
const auto sgid = conn.source.gid;
Expand All @@ -77,29 +106,50 @@ void communicator::update_connections(const connectivity& rec,
gid_connections.emplace_back(conn);
}
part_connections.push_back(gid_connections.size());
// Remote
const auto& ext_conns = rec.external_connections_on(gid);
for (const auto& conn: ext_conns) {
gid_ext_connections.emplace_back(conn);
}
part_ext_connections.push_back(gid_ext_connections.size());
}

util::make_partition(connection_part_, src_counts);
std::size_t n_cons = gid_connections.size();
auto n_cons = gid_connections.size();
auto n_ext_cons = gid_ext_connections.size();
PL();

// Construct the connections. The loop above gave us the information needed
// to do this in place.
// NOTE: The connections are partitioned by the domain of their source gid.
PE(init:communicator:update:connections);
connections_.resize(n_cons);
ext_connections_.resize(n_ext_cons);
auto offsets = connection_part_; // Copy, as we use this as the list of current target indices to write into
std::size_t ext = 0;
auto src_domain = src_domains.begin();
auto target_resolver = resolver(&target_resolution_map);
for (const auto index: util::make_span(num_local_cells_)) {
const auto gid = gids[index];
const auto tgt_gid = gids[index];
auto source_resolver = resolver(&source_resolution_map);
for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) {
const auto& conn = gid_connections[cidx];
auto src_gid = conn.source.gid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto src_lid = source_resolver.resolve(conn.source);
auto tgt_lid = target_resolver.resolve(gid, conn.dest);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
auto offset = offsets[*src_domain]++;
++src_domain;
connections_[offset] = {{conn.source.gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
connections_[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
}
for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
const auto& conn = gid_ext_connections[cidx];
auto src = global_cell_of(conn.source);
auto src_gid = conn.source.rid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
++ext;
}
}
PL();
Expand All @@ -120,6 +170,7 @@ void communicator::update_connections(const connectivity& rec,
[&](cell_size_type i) {
util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
});
std::sort(ext_connections_.begin(), ext_connections_.end());
PL();
}

Expand All @@ -129,13 +180,19 @@ std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_s
}

time_type communicator::min_delay() {
auto local_min = std::accumulate(connections_.begin(), connections_.end(),
std::numeric_limits<time_type>::max(),
[](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
return distributed_->min(local_min);
time_type res = std::numeric_limits<time_type>::max();
res = std::accumulate(connections_.begin(), connections_.end(),
res,
[](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
res = std::accumulate(ext_connections_.begin(), ext_connections_.end(),
res,
[](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
res = distributed_->min(res);
return res;
}

gathered_vector<spike> communicator::exchange(std::vector<spike> local_spikes) {
communicator::spikes
communicator::exchange(std::vector<spike> local_spikes) {
PE(communication:exchange:sort);
// sort the spikes in ascending order of source gid
util::sort_by(local_spikes, [](spike s){return s.source;});
Expand All @@ -147,53 +204,94 @@ gathered_vector<spike> communicator::exchange(std::vector<spike> local_spikes) {
num_spikes_ += global_spikes.size();
PL();

return global_spikes;
// Get remote spikes
PE(communication:exchange:gather:remote);
if (remote_spike_filter_) {
local_spikes.erase(std::remove_if(local_spikes.begin(),
local_spikes.end(),
[this] (const auto& s) { return !remote_spike_filter_(s); }));
}
auto remote_spikes = distributed_->remote_gather_spikes(local_spikes);
PL();

PE(communication:exchange:gather:remote:post_process);
// set the remote bit on all incoming spikes
std::for_each(remote_spikes.begin(), remote_spikes.end(),
[](spike& s) { s.source = global_cell_of(s.source); });
// sort, since we cannot trust our peers
std::sort(remote_spikes.begin(), remote_spikes.end());
PL();
return {global_spikes, remote_spikes};
}

void communicator::make_event_queues(const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues) {
arb_assert(queues.size()==num_local_cells_);
void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_spike_filter_ = p; }
void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); }
void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); }

// Internal helper to append to the event queues
template<typename S, typename C>
void append_events_from_domain(C cons,
S spks,
std::vector<pse_vector>& queues) {
// Predicate for partitioning
struct spike_pred {
bool operator()(const spike& spk, const cell_member_type& src) { return spk.source < src; }
bool operator()(const cell_member_type& src, const spike& spk) { return src < spk.source; }
};
const auto& sp = global_spikes.partition();
const auto& cp = connection_part_;
for (auto dom: util::make_span(num_domains_)) {
auto cons = util::subrange_view(connections_, cp[dom], cp[dom+1]);
auto spks = util::subrange_view(global_spikes.values(), sp[dom], sp[dom+1]);
auto sp = spks.begin(), se = spks.end();
auto cn = cons.begin(), ce = cons.end();
// We have a choice of whether to walk spikes or connections:
// i.e., we can iterate over the spikes, and for each spike search
// the for connections that have the same source; or alternatively
// for each connection, we can search the list of spikes for spikes
// with the same source.
//
// We iterate over whichever set is the smallest, which has
// complexity of order max(S log(C), C log(S)), where S is the
// number of spikes, and C is the number of connections.
if (cons.size()<spks.size()) {
while (cn!=ce && sp!=se) {
auto& queue = queues[cn->index_on_domain];
auto src = cn->source;
auto sources = std::equal_range(sp, se, src, spike_pred());
for (auto s: util::make_range(sources)) queue.push_back(cn->make_event(s));
sp = sources.first;
++cn;

auto sp = spks.begin(), se = spks.end();
auto cn = cons.begin(), ce = cons.end();
// We have a choice of whether to walk spikes or connections:
// i.e., we can iterate over the spikes, and for each spike search
// the for connections that have the same source; or alternatively
// for each connection, we can search the list of spikes for spikes
// with the same source.
//
// We iterate over whichever set is the smallest, which has
// complexity of order max(S log(C), C log(S)), where S is the
// number of spikes, and C is the number of connections.
if (cons.size() < spks.size()) {
while (cn != ce && sp != se) {
auto sources = std::equal_range(sp, se, cn->source, spike_pred());
for (auto s: util::make_range(sources)) {
queues[cn->index_on_domain].push_back(make_event(*cn, s));
}
sp = sources.first;
++cn;
}
else {
while (cn!=ce && sp!=se) {
auto targets = std::equal_range(cn, ce, sp->source);
for (auto c: util::make_range(targets)) queues[c.index_on_domain].push_back(c.make_event(*sp));
cn = targets.first;
++sp;
}
else {
while (cn != ce && sp != se) {
auto targets = std::equal_range(cn, ce, sp->source);
for (auto c: util::make_range(targets)) {
queues[c.index_on_domain].push_back(make_event(c, *sp));
}
cn = targets.first;
++sp;
}
}
num_local_events_ = util::sum_by( queues, [](const auto& q) {return q.size();}, num_local_events_);
}

void communicator::make_event_queues(
const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues,
const std::vector<spike>& external_spikes) {
arb_assert(queues.size()==num_local_cells_);
const auto& sp = global_spikes.partition();
const auto& cp = connection_part_;
for (auto dom: util::make_span(num_domains_)) {
append_events_from_domain(util::subrange_view(connections_, cp[dom], cp[dom+1]),
util::subrange_view(global_spikes.values(), sp[dom], sp[dom+1]),
queues);
}
num_local_events_ = util::sum_by(queues, [](const auto& q) {return q.size();}, num_local_events_);
// Now that all local spikes have been processed; consume the remote events coming in.
// - turn all gids into externals
auto spikes = external_spikes;
std::for_each(spikes.begin(),
spikes.end(),
[](auto& s) { s.source = global_cell_of(s.source); });
append_events_from_domain(ext_connections_, spikes, queues);
}

std::uint64_t communicator::num_spikes() const {
Expand Down
30 changes: 27 additions & 3 deletions arbor/communication/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "communication/gathered_vector.hpp"
#include "connection.hpp"
#include "epoch.hpp"
#include "execution_context.hpp"
#include "util/partition.hpp"

Expand All @@ -28,6 +29,12 @@ namespace arb {

class ARB_ARBOR_API communicator {
public:

struct spikes {
gathered_vector<spike> from_local;
std::vector<spike> from_remote;
};

communicator() = default;

explicit communicator(const recipe& rec,
Expand All @@ -43,8 +50,10 @@ class ARB_ARBOR_API communicator {
/// Perform exchange of spikes.
///
/// Takes as input the list of local_spikes that were generated on the calling domain.
/// Returns the full global set of vectors, along with meta data about their partition
gathered_vector<spike> exchange(std::vector<spike> local_spikes);
/// Returns
/// * full global set of vectors, along with meta data about their partition
/// * a list of spikes received from remote simulations
spikes exchange(std::vector<spike> local_spikes);

/// Check each global spike in turn to see it generates local events.
/// If so, make the events and insert them into the appropriate event list.
Expand All @@ -56,7 +65,8 @@ class ARB_ARBOR_API communicator {
/// in the list.
void make_event_queues(
const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues);
std::vector<pse_vector>& queues,
const std::vector<spike>& external_spikes={});

/// Returns the total number of global spikes over the duration of the simulation
std::uint64_t num_spikes() const;
Expand All @@ -68,21 +78,35 @@ class ARB_ARBOR_API communicator {

void reset();

// used for commmunicate to coupled simulations
void remote_ctrl_send_continue(const epoch&);
void remote_ctrl_send_done();

void update_connections(const connectivity& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map);

void set_remote_spike_filter(const spike_predicate&);

private:
cell_size_type num_total_cells_ = 0;
cell_size_type num_local_cells_ = 0;
cell_size_type num_local_groups_ = 0;
cell_size_type num_domains_ = 0;
// Arbor internal connections
std::vector<connection> connections_;
// partition of connections over the domains of the sources' ids.
std::vector<cell_size_type> connection_part_;
std::vector<cell_size_type> index_divisions_;
util::partition_view_type<std::vector<cell_size_type>> index_part_;

spike_predicate remote_spike_filter_;

// Connections from external simulators into Arbor.
// Currently we have no partitions/indices/acceleration structures
std::vector<connection> ext_connections_;

distributed_context_handle distributed_;
task_system_handle thread_pool_;
std::uint64_t num_spikes_ = 0u;
Expand Down
Loading

0 comments on commit 0c799aa

Please sign in to comment.