Skip to content

Commit

Permalink
Finishes simulations regarding inhibitory plasticity
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloabur committed Jan 29, 2024
1 parent f58c1a9 commit cb68f33
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 1 deletion.
23 changes: 23 additions & 0 deletions core/equations/synapses/sfp8iSTDP.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from brian2.units import *
from core.equations.synapses.sfp8STDP import sfp8STDP, ParamDict


class sfp8iSTDP(sfp8STDP):
def __init__(self):
""" Implementation of inhibitory STDP with minifloat.
"""
super().__init__()
self.modify_model('namespace', 184, key='w_factor')
self.model += 'target_rate : integer\n'

self.modify_model('on_pre',
'''delta_w = int(Ca_pre<128 and Ca_post>128)*Ca_pre + int(Ca_pre>128 and Ca_post<128)*fp8_add_stochastic(Ca_post, target_rate)
delta_w = fp8_multiply_stochastic(delta_w, eta)
w_plast = fp8_add_stochastic(w_plast, delta_w)
w_plast = int(w_plast<128)*w_plast
''',
key='stdp_fanout')

self.parameters = ParamDict({**self.parameters,
**{'target_rate': 236}} # -20 in decimal
)
128 changes: 128 additions & 0 deletions plots/istdp.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(jsonlite))
suppressPackageStartupMessages(library(purrr))
suppressPackageStartupMessages(library(stringr))
suppressPackageStartupMessages(library(wesanderson))
suppressPackageStartupMessages(library(patchwork))
suppressPackageStartupMessages(library(latex2exp))

library(argparser)
include('plots/parse_inputs.R')
include('plots/minifloat_utils.R')

ca_trace <- function(df_data, color_map){
fig <- df_data %>%
mutate(value=map_dbl(value, minifloat2decimal)) %>%
group_by(time_ms) %>%
mutate(avg_value=mean(value)) %>%
ggplot(aes(x=time_ms, y=avg_value)) + geom_line(color=color_map[4]) +
theme_bw() + labs(x='time (ms)', y=TeX(r'(average $x$ (a.u.))'))

return(fig)
}

stats <- function(df_data, color_map, tsim){
df_stats_init <- df_data %>%
group_by(id) %>%
filter(time_ms<1000) %>%
summarise(rate=n()/1,
mean_isi=mean(diff(time_ms), na.rm=T),
sd_isi=sd(diff(time_ms), na.rm=T),
cv=sd_isi/mean_isi)
df_stats_final <- df_data %>%
group_by(id) %>%
filter(time_ms>(1000*tsim-1000)) %>%
summarise(rate=n()/1,
mean_isi=mean(diff(time_ms), na.rm=T),
sd_isi=sd(diff(time_ms), na.rm=T),
cv=sd_isi/mean_isi)

rate_init <- df_stats_init %>%
ggplot(aes(x=rate)) + geom_histogram(fill=color_map[4], binwidth=1) +
theme_bw() + labs(x='firing rate (Hz)', y='count')
cv_init <- df_stats_init %>%
ggplot(aes(x=cv)) + geom_histogram(fill=color_map[4], binwidth=0.1) +
theme_bw() + labs(x='ISI CV', y='count')
rate_final <- df_stats_final %>%
ggplot(aes(x=rate)) + geom_histogram(fill=color_map[4], binwidth=1) +
theme_bw() + labs(x='firing rate (Hz)', y='count')
cv_final <- df_stats_final %>%
ggplot(aes(x=cv)) + geom_histogram(fill=color_map[4], binwidth=0.1) +
theme_bw() + labs(x='ISI CV', y='count')

return(list(rate_init, cv_init, rate_final, cv_final, df_stats_init, df_stats_final))
}

color_map <- wes_palette('Rushmore1')

wd = getwd()
data_path <- Sys.glob(file.path(wd, argv$source, "*"))
metadata <- map(file.path(data_path, "metadata.json"), fromJSON)
protocol <- map_int(metadata, \(x) x$protocol)
tsim <- map_dbl(metadata, \(x) as.double(str_sub(x$duration, 1, -2)))

####### Processing trial with high learning rate
sel_dir <- match(1, protocol)
state_vars <- read.csv(file.path(data_path[sel_dir], "state_vars.csv"))
exc_spikes <- read.csv(file.path(data_path[sel_dir], "spikes_exc.csv"))

ca_trace_high <- ca_trace(state_vars, color_map)

# get small one for comparison
state_vars <- read.csv(file.path(data_path[match(2, protocol)], "state_vars.csv"))
ca_trace_small <- ca_trace(state_vars, color_map)

exc_spikes <- read.csv(file.path(data_path[sel_dir], "spikes_exc.csv"))
stats_figs <- stats(exc_spikes, color_map, tsim[sel_dir])

####### Processing trial with small learning rate
sel_dir <- match(2, protocol)
exc_spikes <- read.csv(file.path(data_path[sel_dir], "spikes_exc.csv"))
inh_weights <- read.csv(file.path(data_path[sel_dir], "weights.csv"))

w_traces <- inh_weights %>%
mutate(value=map_dbl(value, minifloat2decimal)) %>%
mutate(time_ms=time_ms/1000) %>%
ggplot(aes(x=value, fill=time_ms, group=time_ms)) +
geom_histogram(alpha=0.7) +
scale_fill_gradientn(colors=color_map) +
theme_bw() + labs(x='weight (a.u.)', y='count', fill='time (s)') +
theme(legend.position = c(0.80, 0.8))

raster <- function(df_data, color_map, time_slice){
fig <- df_data %>%
filter(time_ms>time_slice[1] & time_ms<time_slice[2]) %>%
ggplot(aes(x=time_ms, y=id)) +
geom_point(shape=20, size=0.05, alpha=0.2, color=color_map[4]) +
theme_bw() +
theme(panel.grid.minor=element_blank(),
panel.grid.major=element_blank()) +
scale_color_manual(values=color_map[4]) +
labs(x='time (ms)', y='neuron id')

return(fig)
}

raster_init <- raster(exc_spikes, color_map, c(0, 1000))
raster_final <- raster(exc_spikes, color_map, c(1000*tsim[sel_dir]-1000, 1000*tsim[sel_dir]))

stats_figs <- stats(exc_spikes, color_map, tsim[sel_dir])

RS <- stats_figs[[5]]$cv
AI <- stats_figs[[6]]$cv
print(t.test(RS, AI))

fig_high_protocol <- wrap_elements(ca_trace_high + ca_trace_small +
plot_annotation(title='A')) /
(wrap_elements(stats_figs[[1]] + plot_annotation(title='B')) |
wrap_elements(stats_figs[[2]] + plot_annotation(title='C')))

fig_small_protocol <- (wrap_elements(raster_init + stats_figs[[1]] + stats_figs[[2]] +
plot_annotation(title='A')) /
wrap_elements(raster_final + stats_figs[[3]] + stats_figs[[4]] +
plot_annotation(title='B'))) |
wrap_elements(w_traces + plot_annotation(title='C'))

ggsave(str_replace(argv$dest, '.png', '_high_eta.png'), fig_high_protocol)
ggsave(str_replace(argv$dest, '.png', '_small_eta.png'), fig_small_protocol)
15 changes: 15 additions & 0 deletions run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from simulations.balanced_network import balanced_network
from simulations.balanced_network_stdp import balanced_network_stdp
from simulations.minifloat import minifloat_operations
from simulations.istdp import istdp

import os
from datetime import datetime
Expand Down Expand Up @@ -218,6 +219,20 @@
f' All operations are stochastic.')
subparser_minifloat.set_defaults(func=minifloat_operations)

subparser_istdp = subparsers.add_parser(
'iSTDP',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
subparser_istdp.add_argument('--event_condition',
type=str,
default= 'Ca > 0',
help=f'Condition uppon a plasticity event is '
f'triggered.')
subparser_istdp.add_argument('--protocol',
type=int,
help=f'Type of simulation. 1 and 2 are for high '
f'and low learning rates, respectively.')
subparser_istdp.set_defaults(func=istdp)

args = parser.parse_args()

os.makedirs(args.save_path, exist_ok=True)
Expand Down
3 changes: 2 additions & 1 deletion scripts/python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
#docker run --entrypoint /bin/bash -it --rm -v $(pwd)/run_simulation.py:/app/run_simulation.py:ro -v $(pwd)/simulations:/app/simulations:ro -v $(pwd)/sim_data:/app/sim_data -v $(pwd)/core/equations/:/app/core/equations/:ro -v $(pwd)/core/utils/:/app/core/utils/:ro pabloabur/app -c "echo 'b simulations/stdp.py:167\nc' > ~/.pdbrc && micromamba run -n base -r /opt/micromamba python -m pdb run_simulation.py STDP --event_condition 'Ca > 0' --protocol 1 --precision 'fp8'"
#docker run --rm -v $(pwd)/run_simulation.py:/app/run_simulation.py:ro -v $(pwd)/simulations:/app/simulations:ro -v $(pwd)/sim_data:/app/sim_data -v $(pwd)/core/equations/:/app/core/equations/:ro -v $(pwd)/core/utils/:/app/core/utils/:ro pabloabur/app --backend cpp_standalone --save_path sim_data/ch4/simple_stdp/$(date +"%d-%m_%Hh%Mm%Ss") --quiet STDP --event_condition "Ca > 0" --protocol 1 --precision 'fp8' --w_init 0.02539062
#docker run --rm -v $(pwd)/run_simulation.py:/app/run_simulation.py:ro -v $(pwd)/simulations:/app/simulations:ro -v $(pwd)/sim_data:/app/sim_data -v $(pwd)/core/equations/:/app/core/equations/:ro -v $(pwd)/core/utils/:/app/core/utils/:ro pabloabur/app --backend cpp_standalone --save_path sim_data/stdp_test/$(date +"%d-%m_%Hh%Mm%Ss") minifloat --protocol 1
docker run --rm -v $(pwd)/run_simulation.py:/app/run_simulation.py:ro -v $(pwd)/simulations:/app/simulations:ro -v $(pwd)/sim_data:/app/sim_data -v $(pwd)/core/equations/:/app/core/equations/:ro -v $(pwd)/core/utils/:/app/core/utils/:ro pabloabur/app --backend cpp_standalone --save_path sim_data/ch4/kernel/$(date +"%d-%m_%Hh%Mm%Ss") STDP --event_condition "Ca > 0" --protocol 2 --precision 'fp8' --w_init 50
#docker run --rm -v $(pwd)/run_simulation.py:/app/run_simulation.py:ro -v $(pwd)/simulations:/app/simulations:ro -v $(pwd)/sim_data:/app/sim_data -v $(pwd)/core/equations/:/app/core/equations/:ro -v $(pwd)/core/utils/:/app/core/utils/:ro pabloabur/app --backend cpp_standalone --save_path sim_data/ch4/kernel/$(date +"%d-%m_%Hh%Mm%Ss") STDP --event_condition "Ca > 0" --protocol 2 --precision 'fp8' --w_init 50
docker run --rm -v $(pwd)/run_simulation.py:/app/run_simulation.py:ro -v $(pwd)/simulations:/app/simulations:ro -v $(pwd)/sim_data:/app/sim_data -v $(pwd)/core/equations/:/app/core/equations/:ro -v $(pwd)/core/utils/:/app/core/utils/:ro pabloabur/app --backend cpp_standalone --save_path sim_data/ch5/$(date +"%d-%m_%Hh%Mm%Ss") iSTDP --protocol 1
129 changes: 129 additions & 0 deletions simulations/istdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from brian2 import defaultclock, ms, Hz, PoissonGroup, SpikeMonitor, StateMonitor,\
device, prefs, run
import pandas as pd
import numpy as np
import json

from core.builder.groups_builder import create_synapses, create_neurons
from core.equations.neurons.sfp8LIF import sfp8LIF
from core.equations.synapses.sfp8CUBA import sfp8CUBA
from core.equations.synapses.sfp8iSTDP import sfp8iSTDP
from core.utils.misc import decimal2minifloat
from core.utils.prepare_models import set_hardwarelike_scheme,\
generate_connection_indices
from core.utils.process_responses import statemonitors2dataframe


def istdp(args):
defaultclock.dt = args.timestep * ms
rng = np.random.default_rng()
run_namespace = {}

""" ================ models ================ """
neuron_model = sfp8LIF()
synapse_model = sfp8CUBA()
istdp_model = sfp8iSTDP()

if args.protocol == 1:
tsim = 10001*ms
elif args.protocol == 2:
tsim = 60001*ms
neuron_model.modify_model('namespace',
decimal2minifloat(1),
key='Ca_inc')
istdp_model.modify_model('parameters',
decimal2minifloat(-0.203125),
key='target_rate')

neuron_model.modify_model('events', args.event_condition, key='active_Ca',)
# this results in around 5ms of refractory period
neuron_model.modify_model('parameters', 20, key='alpha_refrac')

num_exc = 8000
num_inh = 2000
num_input = 1000
input_rate = 30*Hz
poisson_pop = PoissonGroup(num_input, input_rate)
neurons = create_neurons(num_exc+num_inh, neuron_model)
exc_neurons = neurons[:num_exc]
inh_neurons = neurons[num_exc:]

""" ================ Wiring ================ """
synapse_model.modify_model('connection', 0.03, key='p')
input_synapse = create_synapses(poisson_pop,
neurons,
synapse_model)

synapse_model.modify_model('connection', 0.02, key='p')
# this is close to '2*1/N_incoming*mV'
synapse_model.modify_model('parameters',
decimal2minifloat(0.013671875),
key='weight')
exc_conn = create_synapses(exc_neurons, neurons, synapse_model)

# this is close to '10*2*1/N_incoming*mV'
synapse_model.modify_model('parameters',
decimal2minifloat(0.125),
key='weight')
synapse_model.modify_model('namespace', 184, key='w_factor')
inh_conn_static = create_synapses(inh_neurons, inh_neurons, synapse_model)

sources, targets = generate_connection_indices(num_inh,
num_exc,
0.02)
istdp_model.modify_model('connection', sources, key='i')
istdp_model.modify_model('connection', targets, key='j')
istdp_model.modify_model('parameters',
decimal2minifloat(0.125),
key='w_plast')
istdp_synapse = create_synapses(inh_neurons,
exc_neurons,
istdp_model)

set_hardwarelike_scheme(prefs, [neurons], defaultclock.dt, 'fp8')

""" ================ Setting up monitors ================ """
spikemon_exc_neurons = SpikeMonitor(exc_neurons,
name='spikemon_exc_neurons')
spikemon_inh_neurons = SpikeMonitor(inh_neurons,
name='spikemon_inh_neurons')
statemon_neurons = StateMonitor(exc_neurons,
variables=['Ca'],
record=True,
dt=(tsim-1*ms)/200,
name='statemon_neurons')
statemon_synapses = StateMonitor(istdp_synapse,
variables=['w_plast'],
record=[x for x in range(len(sources))],
dt=(tsim-1*ms)/2,
name='statemon_synapses')

metadata = {'event_condition': args.event_condition,
'protocol': args.protocol,
'duration': str(tsim)
}
with open(f'{args.save_path}/metadata.json', 'w') as f:
json.dump(metadata, f)

run(tsim, report='stdout', namespace=run_namespace)

if args.backend == 'cpp_standalone' or args.backend == 'cuda_standalone':
device.build(args.code_path)

""" =================== Saving data =================== """
output_spikes = pd.DataFrame(
{'time_ms': np.array(spikemon_exc_neurons.t/defaultclock.dt),
'id': np.array(spikemon_exc_neurons.i)}
)
output_spikes.to_csv(f'{args.save_path}/spikes_exc.csv', index=False)

output_spikes = pd.DataFrame(
{'time_ms': np.array(spikemon_inh_neurons.t/defaultclock.dt),
'id': np.array(spikemon_inh_neurons.i)}
)
output_spikes.to_csv(f'{args.save_path}/spikes_inh.csv', index=False)

state_vars = statemonitors2dataframe([statemon_synapses])
state_vars.to_csv(f'{args.save_path}/weights.csv', index=False)
state_vars = statemonitors2dataframe([statemon_neurons])
state_vars.to_csv(f'{args.save_path}/state_vars.csv', index=False)

0 comments on commit cb68f33

Please sign in to comment.