-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Finishes simulations regarding inhibitory plasticity
- Loading branch information
Showing
5 changed files
with
297 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from brian2.units import * | ||
from core.equations.synapses.sfp8STDP import sfp8STDP, ParamDict | ||
|
||
|
||
class sfp8iSTDP(sfp8STDP): | ||
def __init__(self): | ||
""" Implementation of inhibitory STDP with minifloat. | ||
""" | ||
super().__init__() | ||
self.modify_model('namespace', 184, key='w_factor') | ||
self.model += 'target_rate : integer\n' | ||
|
||
self.modify_model('on_pre', | ||
'''delta_w = int(Ca_pre<128 and Ca_post>128)*Ca_pre + int(Ca_pre>128 and Ca_post<128)*fp8_add_stochastic(Ca_post, target_rate) | ||
delta_w = fp8_multiply_stochastic(delta_w, eta) | ||
w_plast = fp8_add_stochastic(w_plast, delta_w) | ||
w_plast = int(w_plast<128)*w_plast | ||
''', | ||
key='stdp_fanout') | ||
|
||
self.parameters = ParamDict({**self.parameters, | ||
**{'target_rate': 236}} # -20 in decimal | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
suppressPackageStartupMessages(library(ggplot2)) | ||
suppressPackageStartupMessages(library(dplyr)) | ||
suppressPackageStartupMessages(library(jsonlite)) | ||
suppressPackageStartupMessages(library(purrr)) | ||
suppressPackageStartupMessages(library(stringr)) | ||
suppressPackageStartupMessages(library(wesanderson)) | ||
suppressPackageStartupMessages(library(patchwork)) | ||
suppressPackageStartupMessages(library(latex2exp)) | ||
|
||
library(argparser) | ||
include('plots/parse_inputs.R') | ||
include('plots/minifloat_utils.R') | ||
|
||
ca_trace <- function(df_data, color_map){ | ||
fig <- df_data %>% | ||
mutate(value=map_dbl(value, minifloat2decimal)) %>% | ||
group_by(time_ms) %>% | ||
mutate(avg_value=mean(value)) %>% | ||
ggplot(aes(x=time_ms, y=avg_value)) + geom_line(color=color_map[4]) + | ||
theme_bw() + labs(x='time (ms)', y=TeX(r'(average $x$ (a.u.))')) | ||
|
||
return(fig) | ||
} | ||
|
||
stats <- function(df_data, color_map, tsim){ | ||
df_stats_init <- df_data %>% | ||
group_by(id) %>% | ||
filter(time_ms<1000) %>% | ||
summarise(rate=n()/1, | ||
mean_isi=mean(diff(time_ms), na.rm=T), | ||
sd_isi=sd(diff(time_ms), na.rm=T), | ||
cv=sd_isi/mean_isi) | ||
df_stats_final <- df_data %>% | ||
group_by(id) %>% | ||
filter(time_ms>(1000*tsim-1000)) %>% | ||
summarise(rate=n()/1, | ||
mean_isi=mean(diff(time_ms), na.rm=T), | ||
sd_isi=sd(diff(time_ms), na.rm=T), | ||
cv=sd_isi/mean_isi) | ||
|
||
rate_init <- df_stats_init %>% | ||
ggplot(aes(x=rate)) + geom_histogram(fill=color_map[4], binwidth=1) + | ||
theme_bw() + labs(x='firing rate (Hz)', y='count') | ||
cv_init <- df_stats_init %>% | ||
ggplot(aes(x=cv)) + geom_histogram(fill=color_map[4], binwidth=0.1) + | ||
theme_bw() + labs(x='ISI CV', y='count') | ||
rate_final <- df_stats_final %>% | ||
ggplot(aes(x=rate)) + geom_histogram(fill=color_map[4], binwidth=1) + | ||
theme_bw() + labs(x='firing rate (Hz)', y='count') | ||
cv_final <- df_stats_final %>% | ||
ggplot(aes(x=cv)) + geom_histogram(fill=color_map[4], binwidth=0.1) + | ||
theme_bw() + labs(x='ISI CV', y='count') | ||
|
||
return(list(rate_init, cv_init, rate_final, cv_final, df_stats_init, df_stats_final)) | ||
} | ||
|
||
color_map <- wes_palette('Rushmore1') | ||
|
||
wd = getwd() | ||
data_path <- Sys.glob(file.path(wd, argv$source, "*")) | ||
metadata <- map(file.path(data_path, "metadata.json"), fromJSON) | ||
protocol <- map_int(metadata, \(x) x$protocol) | ||
tsim <- map_dbl(metadata, \(x) as.double(str_sub(x$duration, 1, -2))) | ||
|
||
####### Processing trial with high learning rate | ||
sel_dir <- match(1, protocol) | ||
state_vars <- read.csv(file.path(data_path[sel_dir], "state_vars.csv")) | ||
exc_spikes <- read.csv(file.path(data_path[sel_dir], "spikes_exc.csv")) | ||
|
||
ca_trace_high <- ca_trace(state_vars, color_map) | ||
|
||
# get small one for comparison | ||
state_vars <- read.csv(file.path(data_path[match(2, protocol)], "state_vars.csv")) | ||
ca_trace_small <- ca_trace(state_vars, color_map) | ||
|
||
exc_spikes <- read.csv(file.path(data_path[sel_dir], "spikes_exc.csv")) | ||
stats_figs <- stats(exc_spikes, color_map, tsim[sel_dir]) | ||
|
||
####### Processing trial with small learning rate | ||
sel_dir <- match(2, protocol) | ||
exc_spikes <- read.csv(file.path(data_path[sel_dir], "spikes_exc.csv")) | ||
inh_weights <- read.csv(file.path(data_path[sel_dir], "weights.csv")) | ||
|
||
w_traces <- inh_weights %>% | ||
mutate(value=map_dbl(value, minifloat2decimal)) %>% | ||
mutate(time_ms=time_ms/1000) %>% | ||
ggplot(aes(x=value, fill=time_ms, group=time_ms)) + | ||
geom_histogram(alpha=0.7) + | ||
scale_fill_gradientn(colors=color_map) + | ||
theme_bw() + labs(x='weight (a.u.)', y='count', fill='time (s)') + | ||
theme(legend.position = c(0.80, 0.8)) | ||
|
||
raster <- function(df_data, color_map, time_slice){ | ||
fig <- df_data %>% | ||
filter(time_ms>time_slice[1] & time_ms<time_slice[2]) %>% | ||
ggplot(aes(x=time_ms, y=id)) + | ||
geom_point(shape=20, size=0.05, alpha=0.2, color=color_map[4]) + | ||
theme_bw() + | ||
theme(panel.grid.minor=element_blank(), | ||
panel.grid.major=element_blank()) + | ||
scale_color_manual(values=color_map[4]) + | ||
labs(x='time (ms)', y='neuron id') | ||
|
||
return(fig) | ||
} | ||
|
||
raster_init <- raster(exc_spikes, color_map, c(0, 1000)) | ||
raster_final <- raster(exc_spikes, color_map, c(1000*tsim[sel_dir]-1000, 1000*tsim[sel_dir])) | ||
|
||
stats_figs <- stats(exc_spikes, color_map, tsim[sel_dir]) | ||
|
||
RS <- stats_figs[[5]]$cv | ||
AI <- stats_figs[[6]]$cv | ||
print(t.test(RS, AI)) | ||
|
||
fig_high_protocol <- wrap_elements(ca_trace_high + ca_trace_small + | ||
plot_annotation(title='A')) / | ||
(wrap_elements(stats_figs[[1]] + plot_annotation(title='B')) | | ||
wrap_elements(stats_figs[[2]] + plot_annotation(title='C'))) | ||
|
||
fig_small_protocol <- (wrap_elements(raster_init + stats_figs[[1]] + stats_figs[[2]] + | ||
plot_annotation(title='A')) / | ||
wrap_elements(raster_final + stats_figs[[3]] + stats_figs[[4]] + | ||
plot_annotation(title='B'))) | | ||
wrap_elements(w_traces + plot_annotation(title='C')) | ||
|
||
ggsave(str_replace(argv$dest, '.png', '_high_eta.png'), fig_high_protocol) | ||
ggsave(str_replace(argv$dest, '.png', '_small_eta.png'), fig_small_protocol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from brian2 import defaultclock, ms, Hz, PoissonGroup, SpikeMonitor, StateMonitor,\ | ||
device, prefs, run | ||
import pandas as pd | ||
import numpy as np | ||
import json | ||
|
||
from core.builder.groups_builder import create_synapses, create_neurons | ||
from core.equations.neurons.sfp8LIF import sfp8LIF | ||
from core.equations.synapses.sfp8CUBA import sfp8CUBA | ||
from core.equations.synapses.sfp8iSTDP import sfp8iSTDP | ||
from core.utils.misc import decimal2minifloat | ||
from core.utils.prepare_models import set_hardwarelike_scheme,\ | ||
generate_connection_indices | ||
from core.utils.process_responses import statemonitors2dataframe | ||
|
||
|
||
def istdp(args): | ||
defaultclock.dt = args.timestep * ms | ||
rng = np.random.default_rng() | ||
run_namespace = {} | ||
|
||
""" ================ models ================ """ | ||
neuron_model = sfp8LIF() | ||
synapse_model = sfp8CUBA() | ||
istdp_model = sfp8iSTDP() | ||
|
||
if args.protocol == 1: | ||
tsim = 10001*ms | ||
elif args.protocol == 2: | ||
tsim = 60001*ms | ||
neuron_model.modify_model('namespace', | ||
decimal2minifloat(1), | ||
key='Ca_inc') | ||
istdp_model.modify_model('parameters', | ||
decimal2minifloat(-0.203125), | ||
key='target_rate') | ||
|
||
neuron_model.modify_model('events', args.event_condition, key='active_Ca',) | ||
# this results in around 5ms of refractory period | ||
neuron_model.modify_model('parameters', 20, key='alpha_refrac') | ||
|
||
num_exc = 8000 | ||
num_inh = 2000 | ||
num_input = 1000 | ||
input_rate = 30*Hz | ||
poisson_pop = PoissonGroup(num_input, input_rate) | ||
neurons = create_neurons(num_exc+num_inh, neuron_model) | ||
exc_neurons = neurons[:num_exc] | ||
inh_neurons = neurons[num_exc:] | ||
|
||
""" ================ Wiring ================ """ | ||
synapse_model.modify_model('connection', 0.03, key='p') | ||
input_synapse = create_synapses(poisson_pop, | ||
neurons, | ||
synapse_model) | ||
|
||
synapse_model.modify_model('connection', 0.02, key='p') | ||
# this is close to '2*1/N_incoming*mV' | ||
synapse_model.modify_model('parameters', | ||
decimal2minifloat(0.013671875), | ||
key='weight') | ||
exc_conn = create_synapses(exc_neurons, neurons, synapse_model) | ||
|
||
# this is close to '10*2*1/N_incoming*mV' | ||
synapse_model.modify_model('parameters', | ||
decimal2minifloat(0.125), | ||
key='weight') | ||
synapse_model.modify_model('namespace', 184, key='w_factor') | ||
inh_conn_static = create_synapses(inh_neurons, inh_neurons, synapse_model) | ||
|
||
sources, targets = generate_connection_indices(num_inh, | ||
num_exc, | ||
0.02) | ||
istdp_model.modify_model('connection', sources, key='i') | ||
istdp_model.modify_model('connection', targets, key='j') | ||
istdp_model.modify_model('parameters', | ||
decimal2minifloat(0.125), | ||
key='w_plast') | ||
istdp_synapse = create_synapses(inh_neurons, | ||
exc_neurons, | ||
istdp_model) | ||
|
||
set_hardwarelike_scheme(prefs, [neurons], defaultclock.dt, 'fp8') | ||
|
||
""" ================ Setting up monitors ================ """ | ||
spikemon_exc_neurons = SpikeMonitor(exc_neurons, | ||
name='spikemon_exc_neurons') | ||
spikemon_inh_neurons = SpikeMonitor(inh_neurons, | ||
name='spikemon_inh_neurons') | ||
statemon_neurons = StateMonitor(exc_neurons, | ||
variables=['Ca'], | ||
record=True, | ||
dt=(tsim-1*ms)/200, | ||
name='statemon_neurons') | ||
statemon_synapses = StateMonitor(istdp_synapse, | ||
variables=['w_plast'], | ||
record=[x for x in range(len(sources))], | ||
dt=(tsim-1*ms)/2, | ||
name='statemon_synapses') | ||
|
||
metadata = {'event_condition': args.event_condition, | ||
'protocol': args.protocol, | ||
'duration': str(tsim) | ||
} | ||
with open(f'{args.save_path}/metadata.json', 'w') as f: | ||
json.dump(metadata, f) | ||
|
||
run(tsim, report='stdout', namespace=run_namespace) | ||
|
||
if args.backend == 'cpp_standalone' or args.backend == 'cuda_standalone': | ||
device.build(args.code_path) | ||
|
||
""" =================== Saving data =================== """ | ||
output_spikes = pd.DataFrame( | ||
{'time_ms': np.array(spikemon_exc_neurons.t/defaultclock.dt), | ||
'id': np.array(spikemon_exc_neurons.i)} | ||
) | ||
output_spikes.to_csv(f'{args.save_path}/spikes_exc.csv', index=False) | ||
|
||
output_spikes = pd.DataFrame( | ||
{'time_ms': np.array(spikemon_inh_neurons.t/defaultclock.dt), | ||
'id': np.array(spikemon_inh_neurons.i)} | ||
) | ||
output_spikes.to_csv(f'{args.save_path}/spikes_inh.csv', index=False) | ||
|
||
state_vars = statemonitors2dataframe([statemon_synapses]) | ||
state_vars.to_csv(f'{args.save_path}/weights.csv', index=False) | ||
state_vars = statemonitors2dataframe([statemon_neurons]) | ||
state_vars.to_csv(f'{args.save_path}/state_vars.csv', index=False) |