Skip to content

Commit

Permalink
All tests pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
thorstenhater committed Feb 21, 2025
1 parent 7dd7706 commit eebf939
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 26 deletions.
11 changes: 7 additions & 4 deletions arbor/cable_cell_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,14 @@ void run_samples(const fvm_probe_multi& p,

scratch.times.clear();
scratch.values.clear();
scratch.values.reserve(n_raw_per_sample*n_sample);

for (sample_size_type j = 0; j < n_sample; ++j) {
auto offset = j*n_raw_per_sample + sc.begin_offset;
scratch.times.push_back(raw_times[offset]);
scratch.values.push_back(raw_samples[offset]);
for (sample_size_type i = 0; i < n_raw_per_sample; ++i) {
scratch.values.push_back(raw_samples[offset + i]);
}
}

do_run_sampler(sc, n_sample, n_raw_per_sample, p, scratch);
Expand All @@ -144,10 +147,10 @@ void run_samples(const fvm_probe_weighted_multi& p,

for (sample_size_type j = 0; j < n_sample; ++j) {
auto offset = j*n_raw_per_sample + sc.begin_offset;
scratch.times.push_back(raw_times[offset]);
for (sample_size_type i = 0; i < n_raw_per_sample; ++i) {
scratch.values.push_back(raw_samples[offset + i]*p.weight[i]);
}
scratch.times.push_back(raw_times[offset]);
}

do_run_sampler(sc, n_sample, n_raw_per_sample, p, scratch);
Expand All @@ -171,12 +174,12 @@ void run_samples(const fvm_probe_interpolated_multi& p,

for (sample_size_type j = 0; j < n_sample; ++j) {
auto offset = j*n_raw_per_sample + sc.begin_offset;
scratch.times.push_back(raw_times[offset]);
const auto* raw_a = raw_samples + offset;
const auto* raw_b = raw_a + n_interp_per_sample;
for (sample_size_type i = 0; i < n_interp_per_sample; ++i) {
scratch.values.push_back(raw_a[i]*p.coef[0][i] + raw_b[i]*p.coef[1][i]);
}
scratch.times.push_back(raw_times[offset]);
}
do_run_sampler(sc, n_sample, n_interp_per_sample, p, scratch);
}
Expand All @@ -202,6 +205,7 @@ void run_samples(const fvm_probe_membrane_currents& p,

for (sample_size_type j = 0; j < n_sample; ++j) {
auto offset = j*n_raw_per_sample + sc.begin_offset;
scratch.times.push_back(raw_times[offset]);
auto base = scratch.values.data() + j*n_cable;

// Each CV voltage contributes to the current sum of its parent's cables
Expand Down Expand Up @@ -235,7 +239,6 @@ void run_samples(const fvm_probe_membrane_currents& p,
base[cable_i] -= cv_stim_I*p.weight[cable_i];
}
}
scratch.times.push_back(raw_times[offset]);
}
do_run_sampler(sc, n_sample, n_cable, p, scratch);
}
Expand Down
2 changes: 1 addition & 1 deletion arbor/include/arbor/cable_cell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct ARB_SYMBOL_VISIBLE cable_probe_density_state {
// Value of state variable `state` in density mechanism `mechanism` across components of the cell.
struct ARB_SYMBOL_VISIBLE cable_probe_density_state_cell {
using value_type = cable_sample_type;
using meta_type = cable_state_meta_type;
using meta_type = cable_state_cell_meta_type;
std::string mechanism;
std::string state;
};
Expand Down
2 changes: 1 addition & 1 deletion arbor/include/arbor/sampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct sample_reader {

time_type get_time(std::size_t i) const {
arb_assert(i < n_sample);
return values[i];
return time[i];
}

meta_type get_metadata(std::size_t j) const {
Expand Down
5 changes: 4 additions & 1 deletion arbor/lif_cell_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ void lif_cell_group::advance_cell(time_type tfinal,
const auto& fun = samplers_[k].sampler;
for (auto& [id, us]: vs) {
auto meta = get_probe_metadata(id)[0];
fun(meta, sample_records{.n_sample=us.times.size(), .time=us.times.data(), .values=us.values.data()});
fun(meta, sample_records{.n_sample=us.times.size(),
.width=1,
.time=us.times.data(),
.values=const_cast<const double*>(us.values.data())});
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions test/unit/test_diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ testing::AssertionResult run(const linear& rec, const result_t exp) {
result_t sample_values;
auto sampler = [&sample_values](arb::probe_metadata pm, const arb::sample_records& recs) {
sample_values.clear();
auto reader = arb::make_sample_reader<arb::cable_state_meta_type, arb::cable_sample_type>(pm.meta, recs);
auto reader = arb::make_sample_reader<arb::cable_state_cell_meta_type, arb::cable_sample_type>(pm.meta, recs);
for (std::size_t ix = 0; ix < reader.n_sample; ++ix) {
auto loc = reader.get_metadata(ix);
auto time = reader.get_time(ix);
for (std::size_t iy = 0; iy < reader.width; ++iy) {
auto cable = reader.get_metadata(iy);
auto value = reader.get_value(ix, iy);
sample_values.push_back({time, loc.pos, value});
sample_values.push_back({time, cable.prox_pos, value});
}
}
};
Expand Down
31 changes: 17 additions & 14 deletions test/unit/test_probe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,14 +941,15 @@ void run_multi_probe_test(context ctx) {
auto trace = run_simple_sampler(ctx, 0.1*U::ms, {cable_cell{m, d}},
{0, "probe"}, cable_probe_density_state{ ls::terminal(), "param_as_state", "s"},
{0.0*U::ms});
ASSERT_EQ(1u, trace.values.size());
ASSERT_EQ(3u, trace.values.size());
ASSERT_EQ(3u, trace.metadata.size());
for (const auto& val: trace.values) ASSERT_EQ(1u, val.size());

// Expect to have received a sample on each of the terminals of branches 1, 2, and 5.
std::vector<std::pair<mlocation, double>> vals;
for (size_t ix = 0; ix < trace.n_sample; ++ix) {
for (size_t iy = 0; iy < trace.width; ++iy) {
vals.emplace_back(trace.metadata.at(iy), trace.values[ix][iy]);
vals.emplace_back(trace.metadata.at(iy), trace.values[iy][ix]);
}
}

Expand Down Expand Up @@ -989,17 +990,19 @@ void run_v_sampled_probe_test(context ctx) {

auto trace0 = run_simple_sampler(ctx, t_end, cells, {0, "Um-loc"}, probe, when);
EXPECT_EQ(probe_loc, trace0.metadata.at(0));
EXPECT_EQ(2u, trace0.values.size());
EXPECT_EQ(1u, trace0.values.size());
EXPECT_EQ(2u, trace0.values[0].size());

auto trace1 = run_simple_sampler(ctx, t_end, cells, {1, "Um-loc"}, probe, when);
EXPECT_EQ(probe_loc, trace1.metadata.at(0));
EXPECT_EQ(2u, trace1.values.size());
EXPECT_EQ(1u, trace1.values.size());
EXPECT_EQ(2u, trace1.values[0].size());

EXPECT_EQ(trace0.time[0], trace1.time[0]);
EXPECT_EQ(trace0.values[0], trace1.values[0]);
EXPECT_EQ(trace0.values[0][0], trace1.values[0][0]);

EXPECT_EQ(trace0.time[1], trace1.time[1]);
EXPECT_NE(trace0.values[1], trace1.values[1]);
EXPECT_NE(trace0.values[0][1], trace1.values[0][1]);
}


Expand Down Expand Up @@ -1099,7 +1102,7 @@ void run_total_current_probe_test(context ctx) {
double sum_current = 0;

for (auto k: util::make_span(trace.width)) {
double current = trace.values[j][k] + stim_trace.values[j][k];
double current = trace.values[k][j] + stim_trace.values[k][j];
EXPECT_NE(0.0, current);
max_abs_current = std::max(max_abs_current, std::abs(current));
sum_current += current;
Expand All @@ -1112,19 +1115,19 @@ void run_total_current_probe_test(context ctx) {
// TODO Check that we transcribed the width/length correctly
for (auto k: util::make_span(trace.n_sample)) {
const double rtol_large = 1e-3;
EXPECT_FALSE(testing::near_relative(trace.values[0][k], ion_trace.values[0][k], rtol_large));
EXPECT_FALSE(testing::near_relative(trace.values[k][0], ion_trace.values[k][0], rtol_large));
}

for (unsigned k = 0; k<trace.n_sample; ++k) {
const double rtol_small = 1e-6;
EXPECT_TRUE( testing::near_relative(trace.values[1][k], ion_trace.values[1][k], rtol_small));
EXPECT_TRUE( testing::near_relative(trace.values[k][1], ion_trace.values[k][1], rtol_small));
}

}

// Total membrane currents should differ between the two cells at t=τ.
for (unsigned k = 0; k < traces[0].n_sample; ++k) {
EXPECT_NE(traces[0].values[0][k], traces[1].values[0][k]);
for (unsigned k = 0; k < traces[0].width; ++k) {
EXPECT_NE(traces[0].values[k][0], traces[1].values[k][0]);
}
};

Expand Down Expand Up @@ -1350,7 +1353,6 @@ ARB_PP_FOREACH(RUN_GPU, PROBE_TESTS)

// Test simulator `get_probe_metadata` interface.
// (No need to run this on GPU back-end as well.)

TEST(probe, get_probe_metadata) {
// Reuse multiprobe test set-up to confirm simulator::get_probe_metadata returns
// correct vector of metadata.
Expand Down Expand Up @@ -1379,8 +1381,9 @@ TEST(probe, get_probe_metadata) {

EXPECT_EQ(0u, mm[0].index);

auto locs = *any_cast<const mlocation_list*>(mm[0].meta);

// TODO This isn't as nice as we'd like.
auto ptr = any_cast<const mlocation*>(mm[0].meta);
std::vector<mlocation> locs(ptr, ptr + 3);
util::sort(locs);
EXPECT_EQ((mlocation{1, 1.}), locs[0]);
EXPECT_EQ((mlocation{2, 1.}), locs[1]);
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_sde.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ TEST(sde, solver) {
std::size_t n_entities = samples.width;
std::size_t offset = pm.id.gid*nsteps*n_entities;
std::size_t stride = n_entities;
assert(n == nsteps);
assert(samples.n_sample == nsteps);

using probe_t = arb::cable_probe_point_state_cell;
auto reader = arb::make_sample_reader<probe_t::meta_type, probe_t::value_type>(pm.meta, samples);
Expand Down Expand Up @@ -818,7 +818,7 @@ TEST(sde, coupled) {
std::size_t n_entities = samples.width;
std::size_t offset = pm.id.gid*nsteps*n_entities;
std::size_t stride = n_entities;
assert(n == nsteps);
assert(samples.n_sample == nsteps);

using probe_t = arb::cable_probe_point_state_cell;
auto reader = arb::make_sample_reader<probe_t::meta_type, probe_t::value_type>(pm.meta, samples);
Expand Down

0 comments on commit eebf939

Please sign in to comment.