Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Step 5 - Dynamic and typed parameters #607

Merged
merged 55 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
d23467e
started trying to switch storage of parameters from ``double`` to ``T…
neworderofjamie Nov 17, 2023
022913e
fixed compilation errors
neworderofjamie Nov 17, 2023
f0f6e56
fixed up unit tests
neworderofjamie Nov 17, 2023
4acae1d
param values also switch from double to NumericValue
neworderofjamie Nov 17, 2023
e9425ce
starting adding types to parameters
neworderofjamie Nov 17, 2023
6dfa613
fixed some compilation issues and realised you don't need Utils::Over…
neworderofjamie Nov 17, 2023
ff0c8aa
unit test fixup
neworderofjamie Nov 17, 2023
57eff8e
GCC cannot infer auto in derived parameter lambda functions
neworderofjamie Nov 17, 2023
c9ed9a6
added types to integer snippet parameters
neworderofjamie Nov 18, 2023
baea8a1
first pass at getting rid of index-based tracking of whether neuron v…
neworderofjamie Nov 20, 2023
4dfbaa4
* restored error checking to setXXXVarLocation methods
neworderofjamie Nov 20, 2023
e13a2c5
updated hashing/fusing logic to include dynamic parameters
neworderofjamie Nov 20, 2023
075b0d1
implemented unit tests for checking dynamicness changes hash
neworderofjamie Nov 20, 2023
3c34452
throwing std::bad_optional_access to the user is pretty gross - catch…
neworderofjamie Nov 20, 2023
7d3f656
fixed GCC warning
neworderofjamie Nov 20, 2023
7041bab
fixed couple of typos
neworderofjamie Nov 20, 2023
29c5df4
fixed up failing unit tests which previously used non-array EGPs
neworderofjamie Nov 20, 2023
89463bf
mark dynamicness of parameter fields
neworderofjamie Nov 20, 2023
e638d16
push pointers and non-pointers
neworderofjamie Nov 20, 2023
50da59f
correct check for optional
neworderofjamie Nov 20, 2023
51fd776
WIP runtime
neworderofjamie Nov 20, 2023
7f8b94c
Still a bit rough but finally a hopefully working dynamic parameter r…
neworderofjamie Nov 21, 2023
1ffb91b
add public API for setting dynamic parameters
neworderofjamie Nov 21, 2023
487155f
use macro to implement tedious dynamic parameter boilerplate
neworderofjamie Nov 22, 2023
4421347
separately create dynamic parameters
neworderofjamie Nov 22, 2023
776f0ea
added types to derived parameters
neworderofjamie Nov 22, 2023
0bcb94a
fixed GCC warnings
neworderofjamie Nov 22, 2023
25930ed
basic exposure of typed parameters and derived parameters to PyGeNN
neworderofjamie Nov 22, 2023
75847ee
expose setParamDynamic methods in groups
neworderofjamie Nov 22, 2023
87cf7f1
wrapped setting of parameter values at runtime
neworderofjamie Nov 22, 2023
5a4f44c
default dynamicness to true in setParamDynamic methods
neworderofjamie Nov 22, 2023
a810078
handle and deprecate ``param_names`` kwarg and replace with ``params``
neworderofjamie Nov 22, 2023
2a2f530
feature test for dynamic parameters in neuron code
neworderofjamie Nov 22, 2023
a4b0f14
missing default somehow
neworderofjamie Nov 22, 2023
d75a043
fixed typo
neworderofjamie Nov 22, 2023
257ea8a
removed more deprecated tests
neworderofjamie Nov 22, 2023
94744f9
fixed substition of t into custom connectivity update host code
neworderofjamie Nov 22, 2023
e9c0338
extended test to custom update, current source, custom connectivity u…
neworderofjamie Nov 22, 2023
0d0d9e1
completed test to cover PSM - needs debugging
neworderofjamie Nov 23, 2023
8ea4a97
dumb typo in feature test
neworderofjamie Nov 23, 2023
fb566fb
convert additional input variable values to NumericValue
neworderofjamie Nov 23, 2023
1def65f
wrap casts of NumericValue to int64_t and double
neworderofjamie Nov 23, 2023
701e004
wrap derived parameter lambda's return value in NumericValue
neworderofjamie Nov 23, 2023
ef200bb
tidy some warnings
neworderofjamie Nov 24, 2023
46a2b92
replace group merged Field tuple with struct
neworderofjamie Nov 24, 2023
ebc3925
switch field datastructure from vector to set and handle duplications
neworderofjamie Nov 24, 2023
f2746ce
stripped out special case environment for neuron variables - should r…
neworderofjamie Nov 24, 2023
1f7e8a6
fixed GCC compiler error
neworderofjamie Nov 24, 2023
f757553
allow duplication of neuron and weight update model variable fields
neworderofjamie Nov 24, 2023
1b064d1
correctly wrap derived parameter lambdas
neworderofjamie Nov 24, 2023
700d751
pybind11 handling of std::variant + another layer of wrapping solves …
neworderofjamie Nov 24, 2023
f9610d5
a more useful error message
neworderofjamie Nov 25, 2023
443b476
Small fixes
neworderofjamie Nov 26, 2023
ea38a76
at least outwardly less broken implementation of ``genCopyDelayedVars``
neworderofjamie Nov 27, 2023
2f8a491
fixed a couple of typos
neworderofjamie Nov 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/genn/backends/cuda/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT
os << "Merged" << T::name << "Group" << g.getIndex() << " group = {";
const auto sortedFields = g.getSortedFields(*this);
for(const auto &f : sortedFields) {
os << std::get<1>(f) << ", ";
os << f.name << ", ";
}
os << "};" << std::endl;

Expand Down
2 changes: 1 addition & 1 deletion include/genn/genn/code_generator/backendBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "synapseMatrixType.h"
#include "type.h"
#include "varAccess.h"
#include "variableMode.h"
#include "varLocation.h"

// GeNN code generator includes
#include "code_generator/codeStream.h"
Expand Down
2 changes: 1 addition & 1 deletion include/genn/genn/code_generator/codeGenUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "gennUtils.h"
#include "neuronGroupInternal.h"
#include "type.h"
#include "variableMode.h"
#include "varLocation.h"

// GeNN code generator includes
#include "backendBase.h"
Expand Down
389 changes: 143 additions & 246 deletions include/genn/genn/code_generator/environment.h

Large diffs are not rendered by default.

94 changes: 64 additions & 30 deletions include/genn/genn/code_generator/groupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Standard includes
#include <algorithm>
#include <functional>
#include <set>
#include <type_traits>
#include <vector>

Expand All @@ -24,6 +25,7 @@ class CodeStream;
namespace GeNN::Runtime
{
class ArrayBase;
class MergedDynamicFieldDestinations;
class Runtime;
}

Expand Down Expand Up @@ -64,9 +66,27 @@ class ChildGroupMerged
// Typedefines
//------------------------------------------------------------------------
typedef G GroupInternal;
typedef std::variant<Type::NumericValue, const Runtime::ArrayBase*> FieldValue;
typedef std::function<FieldValue(const Runtime::Runtime &, const G &, size_t)> GetFieldValueFunc;
typedef std::tuple<Type::ResolvedType, std::string, GetFieldValueFunc, GroupMergedFieldType> Field;
typedef std::variant<Type::NumericValue, const Runtime::ArrayBase*,
std::pair<Type::NumericValue, Runtime::MergedDynamicFieldDestinations&>> FieldValue;
typedef std::function<FieldValue(Runtime::Runtime &, const G &, size_t)> GetFieldValueFunc;

//------------------------------------------------------------------------
// Field
//------------------------------------------------------------------------
struct Field
{
std::string name;
Type::ResolvedType type;
GroupMergedFieldType fieldType;
GetFieldValueFunc getValue;

//! Less than operator (used for std::set::insert),
//! compares using only name
bool operator < (const Field &other) const
{
return (name < other.name);
}
};

ChildGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector<std::reference_wrapper<const GroupInternal>> groups)
: m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups))
Expand Down Expand Up @@ -101,7 +121,7 @@ class ChildGroupMerged
bool isParamValueHeterogeneous(const std::string &name, P getParamValuesFn) const
{
// Get value of parameter in archetype group
const double archetypeValue = getParamValuesFn(getArchetype()).at(name);
const auto archetypeValue = getParamValuesFn(getArchetype()).at(name);

// Return true if any parameter values differ from the archetype value
return std::any_of(getGroups().cbegin(), getGroups().cend(),
Expand Down Expand Up @@ -129,7 +149,7 @@ class ChildGroupMerged
// Loop through groups
for(const auto &g : getGroups()) {
// Update hash with parameter value
Utils::updateHash(getValueFn(g.get()).at(p.first), hash);
Type::updateHash(getValueFn(g.get()).at(p.first), hash);
}
}
}
Expand All @@ -147,7 +167,7 @@ class ChildGroupMerged
const auto &values = A(g.get()).getInitialisers().at(varInit.first).getParams();

// Update hash with parameter value
Utils::updateHash(values.at(p.first), hash);
Type::updateHash(values.at(p.first), hash);
}
}
}
Expand All @@ -166,7 +186,7 @@ class ChildGroupMerged
const auto &values = A(g.get()).getInitialisers().at(varInit.first).getDerivedParams();

// Update hash with parameter value
Utils::updateHash(values.at(d.first), hash);
Type::updateHash(values.at(d.first), hash);
}
}
}
Expand Down Expand Up @@ -203,19 +223,19 @@ class GroupMerged : public ChildGroupMerged<G>
const std::string &getMemorySpace() const { return m_MemorySpace; }

//! Get group fields
const std::vector<typename ChildGroupMerged<G>::Field> &getFields() const{ return m_Fields; }
const std::set<typename ChildGroupMerged<G>::Field> &getFields() const{ return m_Fields; }

//! Get group fields, sorted into order they will appear in struct
std::vector<typename ChildGroupMerged<G>::Field> getSortedFields(const BackendBase &backend) const
{
// Make a copy of fields and sort so largest come first. This should mean that due
// Copy fields into vectorand sort so largest come first. This should mean that due
// to structure packing rules, significant memory is saved and estimate is more precise
auto sortedFields = m_Fields;
std::vector<typename ChildGroupMerged<G>::Field> sortedFields(m_Fields.cbegin(), m_Fields.cend());
const size_t pointerBytes = backend.getPointerBytes();
std::sort(sortedFields.begin(), sortedFields.end(),
[pointerBytes](const auto &a, const auto &b)
{
return (std::get<0>(a).getSize(pointerBytes) > std::get<0>(b).getSize(pointerBytes));
return (a.type.getSize(pointerBytes) > b.type.getSize(pointerBytes));
});
return sortedFields;

Expand All @@ -232,22 +252,21 @@ class GroupMerged : public ChildGroupMerged<G>
for(const auto &f : sortedFields) {
// If field is a pointer and not marked as being a host field
// (in which case the backend should leave its type alone!)
const auto &type = std::get<0>(f);
if(type.isPointer() && !(std::get<3>(f) & GroupMergedFieldType::HOST)) {
if(f.type.isPointer() && !(f.fieldType & GroupMergedFieldType::HOST)) {
// If we are generating a host structure, allow the backend to override the type
if(host) {
os << backend.getMergedGroupFieldHostTypeName(type);
os << backend.getMergedGroupFieldHostTypeName(f.type);
}
// Otherwise, allow the backend to add a prefix
else {
os << backend.getPointerPrefix() << type.getName();
os << backend.getPointerPrefix() << f.type.getName();
}
}
// Otherwise, leave the type alone
else {
os << type.getName();
os << f.type.getName();
}
os << " " << std::get<1>(f) << ";" << std::endl;
os << " " << f.name << ";" << std::endl;
}
os << std::endl;
}
Expand All @@ -261,7 +280,7 @@ class GroupMerged : public ChildGroupMerged<G>
const auto sortedFields = getSortedFields(backend);
for(size_t fieldIndex = 0; fieldIndex < sortedFields.size(); fieldIndex++) {
const auto &f = sortedFields[fieldIndex];
os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " " << std::get<1>(f);
os << backend.getMergedGroupFieldHostTypeName(f.type) << " " << f.name;
if(fieldIndex != (sortedFields.size() - 1)) {
os << ", ";
}
Expand All @@ -283,7 +302,7 @@ class GroupMerged : public ChildGroupMerged<G>
// Loop through sorted fields and set array entry
const auto sortedFields = getSortedFields(backend);
for(const auto &f : sortedFields) {
os << "merged" << name << "Group" << this->getIndex() << "[idx]." << std::get<1>(f) << " = " << std::get<1>(f) << ";" << std::endl;
os << "merged" << name << "Group" << this->getIndex() << "[idx]." << f.name << " = " << f.name << ";" << std::endl;
}
}
}
Expand All @@ -297,7 +316,7 @@ class GroupMerged : public ChildGroupMerged<G>
const auto sortedFields = getSortedFields(backend);
for(const auto &f : sortedFields) {
// Add size of field to total
const size_t fieldSize = std::get<0>(f).getSize(backend.getPointerBytes());
const size_t fieldSize = f.type.getSize(backend.getPointerBytes());
structSize += fieldSize;

// Update largest field size
Expand Down Expand Up @@ -343,11 +362,26 @@ class GroupMerged : public ChildGroupMerged<G>
}

void addField(const Type::ResolvedType &type, const std::string &name, typename ChildGroupMerged<G>::GetFieldValueFunc getFieldValue,
GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD)
GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD, bool allowDuplicate = false)
{
// Add field to data structurChildGroupMergede
m_Fields.emplace_back(type, name, getFieldValue, fieldType);
}
// Add field to data structure
auto r = m_Fields.insert({name, type, fieldType, getFieldValue});

// If field wasn't successfully inserted
if(!r.second) {
// If duplicate fields are allowed
if(allowDuplicate) {
// If other properties of the field don't match
if(r.first->type != type || r.first->fieldType != fieldType) {
throw std::runtime_error("Unable to add duplicate field '" + name + "' with different properties to merged group");
}
}
// Otherwise, give error
else {
throw std::runtime_error("Unable to add duplicate field '" + name + "' to merged group");
}
}
}

protected:
//------------------------------------------------------------------------
Expand All @@ -362,14 +396,14 @@ class GroupMerged : public ChildGroupMerged<G>

// Loop through fields again to generate any dynamic field pushing functions that are required
for(const auto &f : m_Fields) {
if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC)) {
definitions << "EXPORT_FUNC void pushMerged" << name << this->getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, ";
definitions << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value);" << std::endl;
if((f.fieldType & GroupMergedFieldType::DYNAMIC)) {
definitions << "EXPORT_FUNC void pushMerged" << name << this->getIndex() << f.name << "ToDevice(unsigned int idx, ";
definitions << backend.getMergedGroupFieldHostTypeName(f.type) << " value);" << std::endl;
}

// If field is a pointer, assert that this is a host structure if field is a host or host object field
if(std::get<0>(f).isPointer()) {
assert((!(std::get<3>(f) & GroupMergedFieldType::HOST) && !(std::get<3>(f) & GroupMergedFieldType::HOST_OBJECT)) || host);
if(f.type.isPointer()) {
assert((!(f.fieldType & GroupMergedFieldType::HOST) && !(f.fieldType & GroupMergedFieldType::HOST_OBJECT)) || host);
}
}
}
Expand All @@ -379,7 +413,7 @@ class GroupMerged : public ChildGroupMerged<G>
// Members
//------------------------------------------------------------------------
std::string m_MemorySpace;
std::vector<typename ChildGroupMerged<G>::Field> m_Fields;
std::set<typename ChildGroupMerged<G>::Field> m_Fields;
};

//----------------------------------------------------------------------------
Expand Down
9 changes: 4 additions & 5 deletions include/genn/genn/code_generator/modelSpecMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,19 +276,18 @@ class GENN_EXPORT ModelSpecMerged
const auto &mergedGroup = groups[g];
for(const auto &f : mergedGroup.getFields()) {
// If field is dynamic, add record to merged EGPS
if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC)) {
if((f.fieldType & GroupMergedFieldType::DYNAMIC)) {
// Add reference to this group's variable to data structure
// **NOTE** this works fine with EGP references because the function to
// get their value will just return the name of the referenced EGP
assert(std::get<0>(f).isPointer());
os << "void pushMerged" << T::name << g << std::get<1>(f) << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value)";
os << "void pushMerged" << T::name << g << f.name << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type) << " value)";
{
CodeStream::Scope b(os);
if(host) {
os << "merged" << T::name << "Group" << g << "[idx]." << std::get<1>(f) << " = value;" << std::endl;
os << "merged" << T::name << "Group" << g << "[idx]." << f.name << " = value;" << std::endl;
}
else {
backend.genMergedDynamicVariablePush(os, T::name, g, "idx", std::get<1>(f), "value");
backend.genMergedDynamicVariablePush(os, T::name, g, "idx", f.name, "value");
}
}
}
Expand Down
32 changes: 2 additions & 30 deletions include/genn/genn/code_generator/neuronUpdateGroupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;

private:
//----------------------------------------------------------------------------
// Private API
//----------------------------------------------------------------------------
//! Is the parameter referenced? **YUCK** only used for hashing
bool isParamReferenced(const std::string &paramName) const;
};

//----------------------------------------------------------------------------
Expand All @@ -66,13 +59,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;

private:
//----------------------------------------------------------------------------
// Private API
//----------------------------------------------------------------------------
//! Is the parameter referenced? **YUCK** only used for hashing
bool isParamReferenced(const std::string &paramName) const;
};

//----------------------------------------------------------------------------
Expand Down Expand Up @@ -106,7 +92,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase
void generate(EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng,
unsigned int batchSize, bool dynamicsNotSpike);

void genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng,
void genCopyDelayedVars(EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng,
unsigned int batchSize);

//! Update hash with child groups
Expand All @@ -117,13 +103,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;

private:
//----------------------------------------------------------------------------
// Private API
//----------------------------------------------------------------------------
//! Is the parameter referenced? **YUCK** only used for hashing
bool isParamReferenced(const std::string &paramName) const;
};

//----------------------------------------------------------------------------
Expand All @@ -141,7 +120,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase
void generate(EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng,
unsigned int batchSize, bool dynamicsNotSpike);

void genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng,
void genCopyDelayedVars(EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng,
unsigned int batchSize);

//! Update hash with child groups
Expand All @@ -152,13 +131,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;

private:
//----------------------------------------------------------------------------
// Private API
//----------------------------------------------------------------------------
//! Is the parameter referenced? **YUCK** only used for hashing
bool isParamReferenced(const std::string &paramName) const;
};

NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext,
Expand Down
Loading