Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose spike-like events to PyGeNN #469

Merged
merged 5 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 158 additions & 68 deletions pygenn/genn_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def __init__(self, name, model):
self.neuron = None
self.spikes = None
self.spike_count = None
self.spike_events = None
self.spike_event_count = None
self.spike_que_ptr = [0]
self._max_delay_steps = 0
self.spike_times = None
Expand All @@ -344,78 +346,30 @@ def __init__(self, name, model):
@property
def current_spikes(self):
"""Current spikes from GeNN"""
# Get current spike queue pointer
d = self.spike_que_ptr[0]

# If batch size is one, return single slice of spikes
if self._model.batch_size == 1:
return self.spikes[0, d, 0:self.spike_count[0, d]]
# Otherwise, return list of slices
else:
return [self.spikes[b, d, 0:self.spike_count[b, d]]
for b in range(self._model.batch_size)]
return self._get_current_events(True)

@current_spikes.setter
def current_spikes(self, spikes):
"""Current spikes from GeNN"""
# Get current spike queue pointer
d = self.spike_que_ptr[0]
self._set_current_events(spikes, True)

# If batch size is one, set single spike count and spike data
if self._model.batch_size == 1:
num_spikes = len(spikes)
self.spike_count[0, d] = num_spikes
self.spikes[0, d, 0:num_spikes] = spikes
# Otherwise
else:
# Check that spikes have been passed for each batch
if len(spikes) != self._model.batch_size:
raise Exception("When using a batched model, you must "
"set current spikes using a list of spikes "
"for each batch")
@property
def current_spike_events(self):
"""Current spike events from GeNN"""
return self._get_current_events(False)

# Loop through batches and set spike counts and spike data
for b, batch_spikes in enumerate(spikes):
num_spikes = len(batch_spikes)
self.spike_count[b, d] = num_spikes
self.spikes[b, d, 0:num_spikes] = batch_spikes
@current_spike_events.setter
def current_spike_events(self, spike_events):
"""Current spike events from GeNN"""
self._set_current_events(spike_events, False)

@property
def spike_recording_data(self):
# Get byte view of data
data_bytes = self._spike_recording_data.view(dtype=np.uint8)
return self._get_event_recording_data(True)

# Reshape view into a tensor with time, batches and recording bytes
spike_recording_bytes = self._spike_recording_words * 4
data_bytes = np.reshape(data_bytes, (-1, self._model.batch_size,
spike_recording_bytes))

# Calculate start time of recording
start_time_ms = (self._model.timestep - data_bytes.shape[0]) * self._model.dT
if start_time_ms < 0.0:
raise Exception("spike_recording_data can only be "
"accessed once buffer is full.")

# Unpack data (results in one byte per bit)
# **THINK** is there a way to avoid this step?
data_unpack = np.unpackbits(data_bytes, axis=2,
count=self.size,
bitorder="little")

# Loop through batches
spike_data = []
for b in range(self._model.batch_size):
# Calculate indices where there are spikes
spikes = np.where(data_unpack[:,b,:] == 1)

# Convert spike times to ms
spike_times = start_time_ms + (spikes[0] * self._model.dT)

# Add to list
spike_data.append((spike_times, spikes[1]))

# If batch size is 1, return 1st population's spikes otherwise list
return spike_data[0] if self._model.batch_size == 1 else spike_data
@property
def spike_event_recording_data(self):
return self._get_event_recording_data(False)

@property
def delay_slots(self):
Expand All @@ -433,6 +387,14 @@ def spike_recording_enabled(self):
@spike_recording_enabled.setter
def spike_recording_enabled(self, enabled):
return self.pop.set_spike_recording_enabled(enabled)

@property
def spike_event_recording_enabled(self):
return self.pop.is_spike_event_recording_enabled()

@spike_event_recording_enabled.setter
def spike_event_recording_enabled(self, enabled):
return self.pop.set_spike_event_recording_enabled(enabled)

def set_neuron(self, model, param_space, var_space):
"""Set neuron, its parameters and initial variables
Expand Down Expand Up @@ -463,10 +425,18 @@ def add_to(self, num_neurons):
def pull_spikes_from_device(self):
"""Wrapper around GeNNModel.pull_spikes_from_device"""
self._model.pull_spikes_from_device(self.name)

def pull_spike_events_from_device(self):
"""Wrapper around GeNNModel.pull_spike_events_from_device"""
self._model.pull_spike_events_from_device(self.name)

def pull_current_spikes_from_device(self):
"""Wrapper around GeNNModel.pull_current_spikes_from_device"""
self._model.pull_current_spikes_from_device(self.name)

def pull_current_spike_events_from_device(self):
"""Wrapper around GeNNModel.pull_current_spike_events_from_device"""
self._model.pull_current_spike_events_from_device(self.name)

def pull_spike_times_from_device(self):
"""Helper to pull spike times from device"""
Expand All @@ -491,10 +461,18 @@ def pull_prev_spike_event_times_from_device(self):
def push_spikes_to_device(self):
"""Wrapper around GeNNModel.push_spikes_to_device"""
self._model.push_spikes_to_device(self.name)


def push_spike_events_to_device(self):
"""Wrapper around GeNNModel.push_spike_events_to_device"""
self._model.push_spike_events_to_device(self.name)

def push_current_spikes_to_device(self):
"""Wrapper around GeNNModel.push_current_spikes_to_device"""
self._model.push_current_spikes_to_device(self.name)

def push_current_spike_events_to_device(self):
"""Wrapper around GeNNModel.push_current_spike_events_to_device"""
self._model.push_current_spike_events_to_device(self.name)

def push_spike_times_to_device(self):
"""Helper to push spike times to device"""
Expand Down Expand Up @@ -534,6 +512,24 @@ def load(self, num_recording_timesteps):
self.spike_count = np.reshape(self.spike_count, (batch_size,
self.delay_slots))

# If this neuron group produces spike events and
# spike event data is present on the host
if (self.pop.is_spike_event_required() and
(self.pop.get_spike_event_location() & VarLocation_HOST) != 0):
self.spike_events = self._assign_ext_ptr_array(
"glbSpkEvnt", self.size * self.delay_slots * batch_size,
"unsigned int")
self.spike_event_count = self._assign_ext_ptr_array(
"glbSpkCntEvnt", self.delay_slots * batch_size, "unsigned int")

# Reshape to expose delay slots and batches
self.spike_events = np.reshape(self.spike_events, (batch_size,
self.delay_slots,
self.size))
self.spike_event_count = np.reshape(self.spike_event_count, (batch_size,
self.delay_slots))


# If this neuron group generates spike times
# and they are accesible on the host
if (self.pop.is_spike_time_required() and
Expand Down Expand Up @@ -565,13 +561,23 @@ def load(self, num_recording_timesteps):
# If spike recording is enabled
if self.spike_recording_enabled:
# Calculate spike recording words
recording_words = (self._spike_recording_words * num_recording_timesteps
recording_words = (self._event_recording_words * num_recording_timesteps
* batch_size)

# Assign pointer to recording data
self._spike_recording_data = self._assign_ext_ptr_array("recordSpk",
recording_words,
"uint32_t")
self._spike_recording_data = self._assign_ext_ptr_array(
"recordSpk", recording_words, "uint32_t")

# If spike-event recording is enabled
if self.spike_event_recording_enabled:
# Calculate spike recording words
recording_words = (self._event_recording_words * num_recording_timesteps
* batch_size)

# Assign pointer to recording data
self._spike_event_recording_data = self._assign_ext_ptr_array(
"recordSpkEvent", recording_words, "uint32_t")

if self.delay_slots > 1:
self.spike_que_ptr = self._model._slm.assign_external_pointer_single_ui(
"spkQuePtr" + self.name)
Expand All @@ -593,7 +599,7 @@ def reinitialise(self):
self._reinitialise_vars()

@property
def _spike_recording_words(self):
def _event_recording_words(self):
return ((self.size + 31) // 32)

def _get_event_time_view(self, name):
Expand All @@ -607,6 +613,90 @@ def _get_event_time_view(self, name):
view = np.reshape(view, (batch_size, self.delay_slots,
self.size))
return view

def _get_current_events(self, true_spike):
# Get current spike queue pointer
d = self.spike_que_ptr[0]

# Get event data
event_count = self.spike_count if true_spike else self.spike_event_count
events = self.spikes if true_spike else self.spike_events

# If batch size is one, return single slice of spikes
if self._model.batch_size == 1:
return events[0, d, 0:event_count[0, d]]
# Otherwise, return list of slices
else:
return [events[b, d, 0:event_count[b, d]]
for b in range(self._model.batch_size)]

def _set_current_events(self, current_events, true_spike):
"""Current spikes from GeNN"""
# Get current spike queue pointer
d = self.spike_que_ptr[0]

# Get event data
event_count = self.spike_count if true_spike else self.spike_event_count
events = self.spikes if true_spike else self.spike_events
description = "spikes" if true_spike else "spike-events"

# If batch size is one, set single event count and event data
if self._model.batch_size == 1:
num_events = len(current_events)
event_count[0, d] = num_events
events[0, d, 0:num_events] = current_events
# Otherwise
else:
# Check that events have been passed for each batch
if len(current_events) != self._model.batch_size:
raise Exception("When using a batched model, you must "
"set current %s using a list of %s "
"for each batch" % description)

# Loop through batches and set spike counts and spike data
for b, batch_events in enumerate(current_events):
num_events = len(batch_events)
event_count[b, d] = num_events
events[b, d, 0:num_events] = batch_events

def _get_event_recording_data(self, true_spike):
# Get byte view of data
recording_data = (self._spike_recording_data if true_spike
else self._spike_event_recording_data)
data_bytes = recording_data.view(dtype=np.uint8)

# Reshape view into a tensor with time, batches and recording bytes
event_recording_bytes = self._event_recording_words * 4
data_bytes = np.reshape(data_bytes, (-1, self._model.batch_size,
event_recording_bytes))

# Calculate start time of recording
start_time_ms = (self._model.timestep - data_bytes.shape[0]) * self._model.dT
if start_time_ms < 0.0:
raise Exception("spike_recording_data can only be "
"accessed once buffer is full.")

# Unpack data (results in one byte per bit)
# **THINK** is there a way to avoid this step?
data_unpack = np.unpackbits(data_bytes, axis=2,
count=self.size,
bitorder="little")

# Loop through batches
event_data = []
for b in range(self._model.batch_size):
# Calculate indices where there are events
events = np.where(data_unpack[:,b,:] == 1)

# Convert event times to ms
event_times = start_time_ms + (events[0] * self._model.dT)

# Add to list
event_data.append((event_times, events[1]))

# If batch size is 1, return 1st population's events otherwise list
return event_data[0] if self._model.batch_size == 1 else event_data


class SynapseGroup(Group):

Expand Down
28 changes: 28 additions & 0 deletions pygenn/genn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,14 +712,28 @@ def pull_spikes_from_device(self, pop_name):
raise Exception("GeNN model has to be loaded before pulling")

self._slm.pull_spikes_from_device(pop_name)

def pull_spike_events_from_device(self, pop_name):
"""Pull spike events from the device for a given population"""
if not self._loaded:
raise Exception("GeNN model has to be loaded before pulling")

self._slm.pull_spike_events_from_device(pop_name)

def pull_current_spikes_from_device(self, pop_name):
"""Pull spikes from the device for a given population"""
if not self._loaded:
raise Exception("GeNN model has to be loaded before pulling")

self._slm.pull_current_spikes_from_device(pop_name)

def pull_current_spike_events_from_device(self, pop_name):
"""Pull spike events from the device for a given population"""
if not self._loaded:
raise Exception("GeNN model has to be loaded before pulling")

self._slm.pull_current_spike_events_from_device(pop_name)

def pull_connectivity_from_device(self, pop_name):
"""Pull connectivity from the device for a given population"""
if not self._loaded:
Expand Down Expand Up @@ -759,13 +773,27 @@ def push_spikes_to_device(self, pop_name):
raise Exception("GeNN model has to be loaded before pushing")

self._slm.push_spikes_to_device(pop_name)

def push_spike_events_to_device(self, pop_name):
"""Push spike events to the device for a given population"""
if not self._loaded:
raise Exception("GeNN model has to be loaded before pushing")

self._slm.push_spike_events_to_device(pop_name)

def push_current_spikes_to_device(self, pop_name):
"""Push current spikes to the device for a given population"""
if not self._loaded:
raise Exception("GeNN model has to be loaded before pushing")

self._slm.push_current_spikes_to_device(pop_name)

def push_current_spike_events_to_device(self, pop_name):
"""Push current spike events to the device for a given population"""
if not self._loaded:
raise Exception("GeNN model has to be loaded before pushing")

self._slm.push_current_spike_events_to_device(pop_name)

def push_connectivity_to_device(self, pop_name):
"""Push connectivity to the device for a given population"""
Expand Down
Loading