From b9b3e878137a7b4d975aec5b492766e9b058eb5c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 22 Oct 2021 11:02:58 +0100 Subject: [PATCH 1/5] Add functions to SharedLibraryModel to handle spike-like events --- userproject/include/sharedLibraryModel.h | 48 ++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/userproject/include/sharedLibraryModel.h b/userproject/include/sharedLibraryModel.h index 07bc53ce0d..dd08ca9530 100644 --- a/userproject/include/sharedLibraryModel.h +++ b/userproject/include/sharedLibraryModel.h @@ -180,6 +180,18 @@ class SharedLibraryModel // Call pull pushPull.second(); } + + void pullSpikeEventsFromDevice(const std::string &popName) + { + // Get push and pull spike events functions and check pull exists + const auto pushPull = getPopPushPullFunction(popName + "SpikeEvents"); + if(pushPull.second == nullptr) { + throw std::runtime_error("You cannot pull spike events from population '" + popName + "'"); + } + + // Call pull + pushPull.second(); + } void pullCurrentSpikesFromDevice(const std::string &popName) { @@ -192,7 +204,19 @@ class SharedLibraryModel // Call pull pushPull.second(); } + + void pullCurrentSpikesEventsFromDevice(const std::string &popName) + { + // Get push and pull spike events functions and check pull exists + const auto pushPull = getPopPushPullFunction(popName + "CurrentSpikeEvents"); + if(pushPull.second == nullptr) { + throw std::runtime_error("You cannot pull current spike events from population '" + popName + "'"); + } + // Call pull + pushPull.second(); + } + void pullConnectivityFromDevice(const std::string &popName) { // Get push and pull connectivity functions and check pull exists @@ -264,6 +288,18 @@ class SharedLibraryModel // Call push pushPull.first(uninitialisedOnly); } + + void pushSpikeEventsToDevice(const std::string &popName, bool uninitialisedOnly = false) + { + // Get push and pull spike events functions and check pull exists + const auto pushPull = getPopPushPullFunction(popName + "SpikeEvents"); + if(pushPull.first == nullptr) { + throw std::runtime_error("You cannot push spike events to population '" + popName + "'"); + } + + // Call push + pushPull.first(uninitialisedOnly); + } void pushCurrentSpikesToDevice(const std::string &popName, bool uninitialisedOnly = false) { @@ -276,6 +312,18 @@ class SharedLibraryModel // Call push pushPull.first(uninitialisedOnly); } + + void pushCurrentSpikeEventsToDevice(const std::string &popName, bool uninitialisedOnly = false) + { + // Get push and pull spike events functions and check pull exists + const auto pushPull = getPopPushPullFunction(popName + "CurrentSpikeEvents"); + if(pushPull.first == nullptr) { + throw std::runtime_error("You cannot push current spike events to population '" + popName + "'"); + } + + // Call push + pushPull.first(uninitialisedOnly); + } void pushConnectivityToDevice(const std::string &popName, bool uninitialisedOnly = false) { From d8899dcf254d87e427e7de4d13702f3cd79f1b69 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 22 Oct 2021 11:33:24 +0100 Subject: [PATCH 2/5] Python wrapper around spike-like events --- pygenn/genn_groups.py | 208 ++++++++++++++++++++++++++++-------------- 1 file changed, 141 insertions(+), 67 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 99f1659f65..d6ce351acf 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -334,6 +334,8 @@ def __init__(self, name, model): self.neuron = None self.spikes = None self.spike_count = None + self.spike_events = None + 
self.spike_event_count = None self.spike_que_ptr = [0] self._max_delay_steps = 0 self.spike_times = None @@ -344,78 +346,30 @@ def __init__(self, name, model): @property def current_spikes(self): """Current spikes from GeNN""" - # Get current spike queue pointer - d = self.spike_que_ptr[0] - - # If batch size is one, return single slice of spikes - if self._model.batch_size == 1: - return self.spikes[0, d, 0:self.spike_count[0, d]] - # Otherwise, return list of slices - else: - return [self.spikes[b, d, 0:self.spike_count[b, d]] - for b in range(self._model.batch_size)] + return self._get_current_events(True) @current_spikes.setter def current_spikes(self, spikes): """Current spikes from GeNN""" - # Get current spike queue pointer - d = self.spike_que_ptr[0] + self._set_current_events(spikes, True) - # If batch size is one, set single spike count and spike data - if self._model.batch_size == 1: - num_spikes = len(spikes) - self.spike_count[0, d] = num_spikes - self.spikes[0, d, 0:num_spikes] = spikes - # Otherwise - else: - # Check that spikes have been passed for each batch - if len(spikes) != self._model.batch_size: - raise Exception("When using a batched model, you must " - "set current spikes using a list of spikes " - "for each batch") + @property + def current_spike_events(self): + """Current spike events from GeNN""" + return self._get_current_events(False) - # Loop through batches and set spike counts and spike data - for b, batch_spikes in enumerate(spikes): - num_spikes = len(batch_spikes) - self.spike_count[b, d] = num_spikes - self.spikes[b, d, 0:num_spikes] = batch_spikes + @current_spike_events.setter + def current_spike_events(self, spike_events): + """Current spike events from GeNN""" + self._set_current_events(spike_events, False) @property def spike_recording_data(self): - # Get byte view of data - data_bytes = self._spike_recording_data.view(dtype=np.uint8) - - # Reshape view into a tensor with time, batches and recording bytes - spike_recording_bytes = self._spike_recording_words * 4 - data_bytes = np.reshape(data_bytes, (-1, self._model.batch_size, - spike_recording_bytes)) - - # Calculate start time of recording - start_time_ms = (self._model.timestep - data_bytes.shape[0]) * self._model.dT - if start_time_ms < 0.0: - raise Exception("spike_recording_data can only be " - "accessed once buffer is full.") - - # Unpack data (results in one byte per bit) - # **THINK** is there a way to avoid this step? 
- data_unpack = np.unpackbits(data_bytes, axis=2, - count=self.size, - bitorder="little") - - # Loop through batches - spike_data = [] - for b in range(self._model.batch_size): - # Calculate indices where there are spikes - spikes = np.where(data_unpack[:,b,:] == 1) - - # Convert spike times to ms - spike_times = start_time_ms + (spikes[0] * self._model.dT) - - # Add to list - spike_data.append((spike_times, spikes[1])) + return self._get_event_recording_data(True) - # If batch size is 1, return 1st population's spikes otherwise list - return spike_data[0] if self._model.batch_size == 1 else spike_data + @property + def spike_event_recording_data(self): + return self._get_event_recording_data(False) @property def delay_slots(self): @@ -433,6 +387,14 @@ def spike_recording_enabled(self): @spike_recording_enabled.setter def spike_recording_enabled(self, enabled): return self.pop.set_spike_recording_enabled(enabled) + + @property + def spike_event_recording_enabled(self): + return self.pop.is_spike_event_recording_enabled() + + @spike_event_recording_enabled.setter + def spike_event_recording_enabled(self, enabled): + return self.pop.set_spike_event_recording_enabled(enabled) def set_neuron(self, model, param_space, var_space): """Set neuron, its parameters and initial variables @@ -534,6 +496,24 @@ def load(self, num_recording_timesteps): self.spike_count = np.reshape(self.spike_count, (batch_size, self.delay_slots)) + # If this neuron group produces spike events and + # spike event data is present on the host + if (self.pop.is_spike_event_required() and + (self.pop.get_event_spike_location() & VarLocation_HOST) != 0): + self.spike_events = self._assign_ext_ptr_array( + "glbSpkEvnt", self.size * self.delay_slots * batch_size, + "unsigned int") + self.spike_event_count = self._assign_ext_ptr_array( + "glbSpkCntEvnt", self.delay_slots * batch_size, "unsigned int") + + # Reshape to expose delay slots and batches + self.spike_events = np.reshape(self.spike_events, (batch_size, + self.delay_slots, + self.size)) + self.spike_event_count = np.reshape(self.spike_event_count, (batch_size, + self.delay_slots)) + + # If this neuron group generates spike times # and they are accesible on the host if (self.pop.is_spike_time_required() and @@ -565,13 +545,23 @@ def load(self, num_recording_timesteps): # If spike recording is enabled if self.spike_recording_enabled: # Calculate spike recording words - recording_words = (self._spike_recording_words * num_recording_timesteps + recording_words = (self._event_recording_words * num_recording_timesteps + * batch_size) + + # Assign pointer to recording data + self._spike_recording_data = self._assign_ext_ptr_array( + "recordSpk", recording_words, "uint32_t") + + # If spike-event recording is enabled + if self.spike_event_recording_enabled: + # Calculate spike recording words + recording_words = (self._event_recording_words * num_recording_timesteps * batch_size) # Assign pointer to recording data - self._spike_recording_data = self._assign_ext_ptr_array("recordSpk", - recording_words, - "uint32_t") + self._spike_event_recording_data = self._assign_ext_ptr_array( + "recordSpkEvent", recording_words, "uint32_t") + if self.delay_slots > 1: self.spike_que_ptr = self._model._slm.assign_external_pointer_single_ui( "spkQuePtr" + self.name) @@ -593,7 +583,7 @@ def reinitialise(self): self._reinitialise_vars() @property - def _spike_recording_words(self): + def _event_recording_words(self): return ((self.size + 31) // 32) def _get_event_time_view(self, name): @@ -607,6 
+597,90 @@ def _get_event_time_view(self, name): view = np.reshape(view, (batch_size, self.delay_slots, self.size)) return view + + def _get_current_events(self, true_spike): + # Get current spike queue pointer + d = self.spike_que_ptr[0] + + # Get event data + event_count = self.spike_count if true_spike else self.spike_event_count + events = self.spikes if true_spike else self.spike_events + + # If batch size is one, return single slice of spikes + if self._model.batch_size == 1: + return events[0, d, 0:event_count[0, d]] + # Otherwise, return list of slices + else: + return [events[b, d, 0:event_count[b, d]] + for b in range(self._model.batch_size)] + + def _set_current_events(self, current_events, true_spike): + """Current spikes from GeNN""" + # Get current spike queue pointer + d = self.spike_que_ptr[0] + + # Get event data + event_count = self.spike_count if true_spike else self.spike_event_count + events = self.spikes if true_spike else self.spike_events + description = "spikes" if true_spike else "spike-events" + + # If batch size is one, set single event count and event data + if self._model.batch_size == 1: + num_events = len(current_events) + event_count[0, d] = num_events + events[0, d, 0:num_events] = current_events + # Otherwise + else: + # Check that events have been passed for each batch + if len(current_events) != self._model.batch_size: + raise Exception("When using a batched model, you must " + "set current %s using a list of %s " + "for each batch" % description) + + # Loop through batches and set spike counts and spike data + for b, batch_events in enumerate(current_events): + num_events = len(batch_spikes) + event_count[b, d] = num_events + self.spikes[b, d, 0:num_events] = batch_events + + def _get_event_recording_data(self, true_spike): + # Get byte view of data + recording_data = (self._spike_recording_data if true_spike + else self._spike_event_recording_data) + data_bytes = recording_data.view(dtype=np.uint8) + + # Reshape view into a tensor with time, batches and recording bytes + event_recording_bytes = self._event_recording_words * 4 + data_bytes = np.reshape(data_bytes, (-1, self._model.batch_size, + event_recording_bytes)) + + # Calculate start time of recording + start_time_ms = (self._model.timestep - data_bytes.shape[0]) * self._model.dT + if start_time_ms < 0.0: + raise Exception("spike_recording_data can only be " + "accessed once buffer is full.") + + # Unpack data (results in one byte per bit) + # **THINK** is there a way to avoid this step? 
+ data_unpack = np.unpackbits(data_bytes, axis=2, + count=self.size, + bitorder="little") + + # Loop through batches + event_data = [] + for b in range(self._model.batch_size): + # Calculate indices where there are events + events = np.where(data_unpack[:,b,:] == 1) + + # Convert event times to ms + event_times = start_time_ms + (events[0] * self._model.dT) + + # Add to list + event_data.append((event_times, events[1])) + + # If batch size is 1, return 1st population's events otherwise list + return event_data[0] if self._model.batch_size == 1 else event_data + class SynapseGroup(Group): From f675997cf4c873a87f37efe96e838972476a7afe Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 22 Oct 2021 11:39:36 +0100 Subject: [PATCH 3/5] forgot push and pull --- pygenn/genn_groups.py | 18 +++++++++++++++++- pygenn/genn_model.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index d6ce351acf..ea9decc135 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -425,10 +425,18 @@ def add_to(self, num_neurons): def pull_spikes_from_device(self): """Wrapper around GeNNModel.pull_spikes_from_device""" self._model.pull_spikes_from_device(self.name) + + def pull_spike_events_from_device(self): + """Wrapper around GeNNModel.pull_spike_events_from_device""" + self._model.pull_spike_events_from_device(self.name) def pull_current_spikes_from_device(self): """Wrapper around GeNNModel.pull_current_spikes_from_device""" self._model.pull_current_spikes_from_device(self.name) + + def pull_current_spike_events_from_device(self): + """Wrapper around GeNNModel.pull_current_spike_events_from_device""" + self._model.pull_current_spike_events_from_device(self.name) def pull_spike_times_from_device(self): """Helper to pull spike times from device""" @@ -453,10 +461,18 @@ def pull_prev_spike_event_times_from_device(self): def push_spikes_to_device(self): """Wrapper around GeNNModel.push_spikes_to_device""" self._model.push_spikes_to_device(self.name) - + + def push_spike_events_to_device(self): + """Wrapper around GeNNModel.push_spike_events_to_device""" + self._model.push_spike_events_to_device(self.name) + def push_current_spikes_to_device(self): """Wrapper around GeNNModel.push_current_spikes_to_device""" self._model.push_current_spikes_to_device(self.name) + + def push_current_spike_events_to_device(self): + """Wrapper around GeNNModel.push_current_spike_events_to_device""" + self._model.push_current_spike_events_to_device(self.name) def push_spike_times_to_device(self): """Helper to push spike times to device""" diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index fab3c70ff0..b49cbc7a4b 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -712,6 +712,13 @@ def pull_spikes_from_device(self, pop_name): raise Exception("GeNN model has to be loaded before pulling") self._slm.pull_spikes_from_device(pop_name) + + def pull_spike_events_from_device(self, pop_name): + """Pull spike events from the device for a given population""" + if not self._loaded: + raise Exception("GeNN model has to be loaded before pulling") + + self._slm.pull_spike_events_from_device(pop_name) def pull_current_spikes_from_device(self, pop_name): """Pull spikes from the device for a given population""" @@ -719,7 +726,14 @@ def pull_current_spikes_from_device(self, pop_name): raise Exception("GeNN model has to be loaded before pulling") self._slm.pull_current_spikes_from_device(pop_name) + + def 
pull_current_spike_events_from_device(self, pop_name): + """Pull spike events from the device for a given population""" + if not self._loaded: + raise Exception("GeNN model has to be loaded before pulling") + self._slm.pull_current_spike_events_from_device(pop_name) + def pull_connectivity_from_device(self, pop_name): """Pull connectivity from the device for a given population""" if not self._loaded: @@ -759,6 +773,13 @@ def push_spikes_to_device(self, pop_name): raise Exception("GeNN model has to be loaded before pushing") self._slm.push_spikes_to_device(pop_name) + + def push_spike_events_to_device(self, pop_name): + """Push spike events to the device for a given population""" + if not self._loaded: + raise Exception("GeNN model has to be loaded before pushing") + + self._slm.push_spike_events_to_device(pop_name) def push_current_spikes_to_device(self, pop_name): """Push current spikes to the device for a given population""" @@ -766,6 +787,13 @@ def push_current_spikes_to_device(self, pop_name): raise Exception("GeNN model has to be loaded before pushing") self._slm.push_current_spikes_to_device(pop_name) + + def push_current_spike_events_to_device(self, pop_name): + """Push current spike events to the device for a given population""" + if not self._loaded: + raise Exception("GeNN model has to be loaded before pushing") + + self._slm.push_current_spike_events_to_device(pop_name) def push_connectivity_to_device(self, pop_name): """Push connectivity to the device for a given population""" From 7a240775fb8b093e84a6c4cb1426cb5b542a4e87 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 22 Oct 2021 13:50:01 +0100 Subject: [PATCH 4/5] fixed small typo --- pygenn/genn_groups.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index ea9decc135..9211d560d4 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -655,9 +655,9 @@ def _set_current_events(self, current_events, true_spike): # Loop through batches and set spike counts and spike data for b, batch_events in enumerate(current_events): - num_events = len(batch_spikes) + num_events = len(batch_events) event_count[b, d] = num_events - self.spikes[b, d, 0:num_events] = batch_events + events[b, d, 0:num_events] = batch_events def _get_event_recording_data(self, true_spike): # Get byte view of data From 7205e8fdaa6a6f9ec1ab7a743a6732471994fa54 Mon Sep 17 00:00:00 2001 From: chanokin Date: Fri, 22 Oct 2021 22:52:25 +0100 Subject: [PATCH 5/5] typo in line 518 self.pop.get_event_spike_location() should be self.pop.get_spike_event_location() --- pygenn/genn_groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 9211d560d4..9413253b95 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -515,7 +515,7 @@ def load(self, num_recording_timesteps): # If this neuron group produces spike events and # spike event data is present on the host if (self.pop.is_spike_event_required() and - (self.pop.get_event_spike_location() & VarLocation_HOST) != 0): + (self.pop.get_spike_event_location() & VarLocation_HOST) != 0): self.spike_events = self._assign_ext_ptr_array( "glbSpkEvnt", self.size * self.delay_slots * batch_size, "unsigned int")
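
Taken together, patches 1-3 expose spike-like events through the same push/pull and current-event interface that PyGeNN already provides for true spikes. The following is a minimal usage sketch, not part of the patches themselves: it assumes a GeNNModel `model` that has already been built and loaded, and a neuron population `pre` that actually emits spike-like events (i.e. it is the presynaptic population of a synapse group whose weight update model defines an event threshold condition) and whose spike-event location includes the host.

    # Advance the simulation and read back this timestep's spike-like events;
    # with batch_size == 1, `current_spike_events` is a 1D array of neuron
    # indices, otherwise a list with one array per batch
    for t in range(100):
        model.step_time()
        pre.pull_current_spike_events_from_device()
        print(t, pre.current_spike_events)

    # Current spike events can also be written from Python and pushed back to
    # the device, mirroring the existing current_spikes setter
    pre.current_spike_events = [0, 1, 2]
    pre.push_current_spike_events_to_device()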
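
Patch 2 also mirrors the bitmask-based spike recording path for spike-like events. Below is another sketch, under the same assumptions as above, of how the new recording properties might be used; it additionally assumes that the existing GeNNModel.pull_recording_buffers_from_device() helper also copies the new recordSpkEvent buffer, and that batch_size is 1 so the decoded data is a single (times, ids) pair rather than a per-batch list.

    # Recording must be switched on before the model is loaded, because load()
    # only maps the recordSpkEvent buffer when it is enabled
    pre.spike_event_recording_enabled = True
    model.build()
    model.load(num_recording_timesteps=1000)

    # Fill the recording buffer, then copy it to the host and decode it;
    # spike_event_recording_data raises until the buffer has been filled
    for _ in range(1000):
        model.step_time()
    model.pull_recording_buffers_from_device()
    event_times_ms, event_ids = pre.spike_event_recording_data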