Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch reductions #447

Merged
merged 29 commits into from
Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
eaed3ee
new variable access modes
neworderofjamie Jul 20, 2021
8c17846
removed strange usage of VarAccessDuplication & VarAccessDuplication …
neworderofjamie Jul 20, 2021
b36c249
finally tidied up makefile for unit tests
neworderofjamie Jul 20, 2021
590bd79
fixed bug in var access modes
neworderofjamie Jul 20, 2021
ddf88a7
checks that reduction variables aren't added to models other than cus…
neworderofjamie Jul 20, 2021
fe1c9c6
Merge branch 'name_validate' into reductions
neworderofjamie Jul 20, 2021
9f2301d
incorporate reduction variable checks into validation framework
neworderofjamie Jul 20, 2021
4f6242d
initial implementation of reduction operations
neworderofjamie Jul 21, 2021
215928c
fixed compiler warnings
neworderofjamie Jul 22, 2021
8467025
initial implemention of weight reduction operations
neworderofjamie Jul 22, 2021
27d676d
started feature test for reductions
neworderofjamie Jul 22, 2021
7623196
fixed typos
neworderofjamie Jul 22, 2021
dccd48b
complete feature test
neworderofjamie Jul 22, 2021
b0bebdc
fixed warning
neworderofjamie Jul 22, 2021
31a6f7b
fixed issue with var access flags
neworderofjamie Jul 22, 2021
250052b
respect ``isBatched`` in ``CustomUpdateWUGroupMergedBase::getVarIndex…
neworderofjamie Jul 22, 2021
79b43cb
fixed issue in PyGeNN
neworderofjamie Jul 22, 2021
321d49f
* added error message if you try and reduce into a duplicate variable
neworderofjamie Aug 5, 2021
49351c9
Small refactor of ``CustomUpdateBase::isBatched`` and ``CustomUpdate:…
neworderofjamie Aug 5, 2021
1a09636
updated tests to reflect that we can now detect some errors in ``Mode…
neworderofjamie Aug 5, 2021
f1a0d59
moved some generic reduction-handling code down from ``BackendSIMT`` …
neworderofjamie Aug 5, 2021
ca7988a
new test for reductions with batch size 1 (tests fallback for single-…
neworderofjamie Aug 5, 2021
4d6c790
simplification
neworderofjamie Aug 5, 2021
cbca852
actually BackendSIMT is a better place for ``genInitReductionTargets``
neworderofjamie Aug 5, 2021
5fe2b85
implemented single-threaded CPU non-reduction
neworderofjamie Aug 5, 2021
0191f8f
additional ``WUVarReference`` constructors and ``createWUVarRef`` wra…
neworderofjamie Aug 5, 2021
5865000
Merge branch 'master' into reductions
neworderofjamie Aug 5, 2021
b45882f
fixed comments
neworderofjamie Aug 11, 2021
8bd66ac
Merge branch 'master' into reductions
neworderofjamie Aug 11, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/genn/backends/single_threaded_cpu/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

// GeNN includes
#include "backendExport.h"
#include "varAccess.h"

// GeNN code generator includes
#include "code_generator/backendBase.h"
Expand Down Expand Up @@ -198,6 +199,28 @@ class BACKEND_EXPORT Backend : public BackendBase
}
}
}

//! Helper to generate code to copy reduced variables back to variables
/*! Because reduction operations are unnecessary in unbatched single-threaded CPU models so there's no need to actually reduce */
template<typename G>
void genWriteBackReductions(CodeStream &os, const G &cg, const std::string &idx) const
{
const auto *cm = cg.getArchetype().getCustomUpdateModel();
for(const auto &v : cm->getVars()) {
// If variable is a reduction target, copy value from register straight back into global memory
if(v.access & VarAccessModeAttribute::REDUCE) {
os << "group->" << v.name << "[" << idx << "] = l" << v.name << ";" << std::endl;
}
}

// Loop through variable references
for(const auto &v : cm->getVarRefs()) {
// If variable reference is a reduction target, copy value from register straight back into global memory
if(v.access & VarAccessModeAttribute::REDUCE) {
os << "group->" << v.name << "[" << idx<< "] = l" << v.name << ";" << std::endl;
}
}
}
};
} // namespace SingleThreadedCPU
} // namespace CodeGenerator
15 changes: 10 additions & 5 deletions include/genn/genn/code_generator/backendBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ class GENN_EXPORT BackendBase
//! Get the size of the type
size_t getSize(const std::string &type) const;

//! Get the lowest value of a type
std::string getLowestValue(const std::string &type) const;

//! Get the prefix for accessing the address of 'scalar' variables
std::string getScalarAddressPrefix() const
{
Expand All @@ -500,9 +503,10 @@ class GENN_EXPORT BackendBase
//--------------------------------------------------------------------------
// Protected API
//--------------------------------------------------------------------------
void addType(const std::string &type, size_t size)
void addType(const std::string &type, size_t size, const std::string &lowestValue = "")
{
m_TypeBytes.emplace(type, size);
m_Types.emplace(std::piecewise_construct, std::forward_as_tuple(type),
std::forward_as_tuple(size, lowestValue));
}

void setPointerBytes(size_t pointerBytes)
Expand All @@ -514,7 +518,7 @@ class GENN_EXPORT BackendBase

void genSynapseIndexCalculation(CodeStream &os, const SynapseGroupMergedBase &sg, unsigned int batchSize) const;

void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const;
void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu, unsigned int batchSize) const;

private:
//--------------------------------------------------------------------------
Expand All @@ -523,8 +527,9 @@ class GENN_EXPORT BackendBase
//! How large is a device pointer? E.g. on some AMD devices this != sizeof(char*)
size_t m_PointerBytes;

//! Size of supported types in bytes - used for estimating memory usage
std::unordered_map<std::string, size_t> m_TypeBytes;
//! Size of supported types in bytes and string containing their lowest value
//! used for estimating memory usage and for reduction operations
std::unordered_map<std::string, std::pair<size_t, std::string>> m_Types;

//! Preferences
const PreferencesBase &m_Preferences;
Expand Down
46 changes: 45 additions & 1 deletion include/genn/genn/code_generator/backendSIMT.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

// GeNN includes
#include "gennExport.h"
#include "varAccess.h"

// GeNN code generator includes
#include "code_generator/backendBase.h"
Expand Down Expand Up @@ -212,7 +213,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase
size_t numInitializeThreads, size_t &idStart) const;

//! Adds a type - both to backend base's list of sized types but also to device types set
void addDeviceType(const std::string &type, size_t size);
void addDeviceType(const std::string &type, size_t size, const std::string &maxValue = "");

//! Is type a a device only type?
bool isDeviceType(const std::string &type) const;
Expand All @@ -224,6 +225,23 @@ class GENN_EXPORT BackendSIMT : public BackendBase
const KernelBlockSize &getKernelBlockSize() const { return m_KernelBlockSizes; }

private:
//--------------------------------------------------------------------------
// ReductionTarget
//--------------------------------------------------------------------------
//! Simple struct to hold reduction targets
struct ReductionTarget
{
ReductionTarget(const std::string &n, const std::string &t, VarAccessMode a)
: name(n), type(t), access(a)
{
}

const std::string name;
const std::string type;
const VarAccessMode access;
};


//--------------------------------------------------------------------------
// Type definitions
//--------------------------------------------------------------------------
Expand Down Expand Up @@ -311,6 +329,32 @@ class GENN_EXPORT BackendSIMT : public BackendBase
}
}


template<typename G>
std::vector<ReductionTarget> genInitReductionTargets(CodeStream &os, const G &cg) const
{
// Loop through variables
std::vector<ReductionTarget> reductionTargets;
const auto *cm = cg.getArchetype().getCustomUpdateModel();
for(const auto &v : cm->getVars()) {
// If variable is a reduction target, define variable initialised to correct initial value for reduction
if(v.access & VarAccessModeAttribute::REDUCE) {
os << v.type << " lr" << v.name << " = " << getReductionInitialValue(*this, getVarAccessMode(v.access), v.type) << ";" << std::endl;
reductionTargets.emplace_back(v.name, v.type, getVarAccessMode(v.access));
}
}

// Loop through variable references
for(const auto &v : cm->getVarRefs()) {
// If variable reference is a reduction target, define variable initialised to correct initial value for reduction
if(v.access & VarAccessModeAttribute::REDUCE) {
os << v.type << " lr" << v.name << " = " << getReductionInitialValue(*this, v.access, v.type) << ";" << std::endl;
reductionTargets.emplace_back(v.name, v.type, v.access);
}
}
return reductionTargets;
}

template<typename T, typename S>
void genParallelGroup(CodeStream &os, const Substitutions &kernelSubs, const std::vector<T> &groups, size_t &idStart,
S getPaddedSizeFunc, GroupHandler<T> handler) const
Expand Down
9 changes: 9 additions & 0 deletions include/genn/genn/code_generator/codeGenUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ GENN_EXPORT void genTypeRange(CodeStream &os, const std::string &precision, cons
//--------------------------------------------------------------------------
GENN_EXPORT std::string ensureFtype(const std::string &oldcode, const std::string &type);

//--------------------------------------------------------------------------
//! \brief Get the initial value to start reduction operations from
//--------------------------------------------------------------------------
GENN_EXPORT std::string getReductionInitialValue(const BackendBase &backend, VarAccessMode access, const std::string &type);

//--------------------------------------------------------------------------
//! \brief Generate a reduction operation to reduce value into reduction
//--------------------------------------------------------------------------
GENN_EXPORT std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, const std::string &type);

//--------------------------------------------------------------------------
/*! \brief This function checks for unknown variable definitions and returns a gennError if any are found
Expand Down
7 changes: 2 additions & 5 deletions include/genn/genn/code_generator/groupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -1544,11 +1544,8 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged<CustomUpdat

boost::uuids::detail::sha1::digest_type getHashDigest() const;

//----------------------------------------------------------------------------
// Static API
//----------------------------------------------------------------------------
static std::string getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index);
static std::string getVarRefIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index);
std::string getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const;
std::string getVarRefIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const;

protected:
CustomUpdateWUGroupMergedBase(size_t index, const std::string &precision, const std::string &, const BackendBase &backend,
Expand Down
13 changes: 2 additions & 11 deletions include/genn/genn/currentSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
// GeNN includes
#include "currentSourceModels.h"
#include "gennExport.h"
#include "gennUtils.h"
#include "variableMode.h"

// Forward declarations
Expand Down Expand Up @@ -63,16 +62,8 @@ class GENN_EXPORT CurrentSource
protected:
CurrentSource(const std::string &name, const CurrentSourceModels::Base *currentSourceModel,
const std::vector<double> &params, const std::vector<Models::VarInit> &varInitialisers,
const NeuronGroupInternal *trgNeuronGroup, VarLocation defaultVarLocation,
VarLocation defaultExtraGlobalParamLocation)
: m_Name(name), m_CurrentSourceModel(currentSourceModel), m_Params(params), m_VarInitialisers(varInitialisers),
m_TrgNeuronGroup(trgNeuronGroup), m_VarLocation(varInitialisers.size(), defaultVarLocation),
m_ExtraGlobalParamLocation(currentSourceModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation)
{
// Validate names
Utils::validatePopName(name, "Current source");
getCurrentSourceModel()->validate();
}
const NeuronGroupInternal *trgNeuronGroup, VarLocation defaultVarLocation,
VarLocation defaultExtraGlobalParamLocation);

//------------------------------------------------------------------------
// Protected methods
Expand Down
2 changes: 1 addition & 1 deletion include/genn/genn/currentSourceModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class GENN_EXPORT Base : public Models::Base
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
using Models::Base::validate;
void validate() const;
};

//----------------------------------------------------------------------------
Expand Down
70 changes: 31 additions & 39 deletions include/genn/genn/customUpdate.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class GENN_EXPORT CustomUpdateBase
//! Is this custom update batched i.e. run in parallel across model batches
bool isBatched() const { return m_Batched; }

//! Does this custom update perform a reduction i.e. reduce some variables from DUPLICATE to SHARED
bool isReduction() const { return getCustomUpdateModel()->isReduction(); }

//! Updates hash with custom update
/*! NOTE: this can only be called after model is finalized */
void updateHash(boost::uuids::detail::sha1 &hash) const;
Expand All @@ -91,47 +94,40 @@ class GENN_EXPORT CustomUpdateBase

boost::uuids::detail::sha1::digest_type getVarLocationHashDigest() const;

//! Helper function to determine whether a custom update should be batched
template<typename R>
void finalizeBatched(unsigned int batchSize, const std::vector<R> &varRefs)
{
// If model has batching at all, custom update should be batched
// if targets of any variable references are duplicated
if(batchSize > 1) {
m_Batched = std::any_of(varRefs.cbegin(), varRefs.cend(),
[](const R &v) { return (v.getVar().access & VarAccessDuplication::DUPLICATE); });

// If custom update is batched, check that any variable references to shared variables are read-only
if(m_Batched) {
const auto modelVarRefs = getCustomUpdateModel()->getVarRefs();
for(size_t i = 0; i < modelVarRefs.size(); i++) {
if((varRefs.at(i).getVar().access & VarAccessDuplication::SHARED)
&& (modelVarRefs.at(i).access != VarAccessMode::READ_ONLY))
{
throw std::runtime_error("Variable references to SHARED variables in batched models must be read-only.");
}
}
}
}
// Otherwise, update should not be batched
else {
m_Batched = false;
}
}

//! Helper function to check if variable reference types match those specified in model
template<typename V>
void checkVarReferenceTypes(const std::vector<V> &varReferences) const
void checkVarReferences(const std::vector<V> &varRefs)
{
const auto modelVarRefs = getCustomUpdateModel()->getVarRefs();

// If target of any variable references is duplicated, custom update should be batched
m_Batched = std::any_of(varRefs.cbegin(), varRefs.cend(),
[](const V &v) { return (v.getVar().access & VarAccessDuplication::DUPLICATE); });

// Loop through all variable references
const auto varRefs = getCustomUpdateModel()->getVarRefs();
for(size_t i = 0; i < varReferences.size(); i++) {
const auto varRef = varReferences.at(i);
for(size_t i = 0; i < varRefs.size(); i++) {
const auto varRef = varRefs.at(i);
const auto modelVarRef = modelVarRefs.at(i);

// Check types of variable references against those specified in model
// **THINK** due to GeNN's current string-based type system this is rather conservative
if(varRef.getVar().type != varRefs.at(i).type) {
throw std::runtime_error("Incompatible type for variable reference '" + getCustomUpdateModel()->getVarRefs().at(i).name + "'");
if(varRef.getVar().type != modelVarRef.type) {
throw std::runtime_error("Incompatible type for variable reference '" + modelVarRef.name + "'");
}

// Check that no reduction targets reference duplicated variables
if((varRef.getVar().access & VarAccessDuplication::DUPLICATE)
&& (modelVarRef.access & VarAccessModeAttribute::REDUCE))
{
throw std::runtime_error("Reduction target variable reference must be to SHARED variables.");
}

// If custom update is batched, check that any variable references to shared variables are read-only
// **NOTE** if custom update isn't batched, it's totally fine to write to shared variables
if(m_Batched && (varRef.getVar().access & VarAccessDuplication::SHARED)
&& (modelVarRef.access != VarAccessMode::READ_ONLY))
{
throw std::runtime_error("Variable references to SHARED variables in batched custom updates must be read-only.");
}
}
}
Expand Down Expand Up @@ -179,7 +175,7 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase
//------------------------------------------------------------------------
// Protected methods
//------------------------------------------------------------------------
void finalize(unsigned int batchSize);
void finalize();

//------------------------------------------------------------------------
// Protected const methods
Expand Down Expand Up @@ -220,10 +216,6 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase
const std::vector<Models::VarInit> &varInitialisers, const std::vector<Models::WUVarReference> &varReferences,
VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation);

//------------------------------------------------------------------------
// Protected methods
//------------------------------------------------------------------------
void finalize(unsigned int batchSize);

//------------------------------------------------------------------------
// Protected const methods
Expand Down
3 changes: 2 additions & 1 deletion include/genn/genn/customUpdateInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class CustomUpdateInternal : public CustomUpdate
using CustomUpdateBase::isInitRNGRequired;
using CustomUpdateBase::isZeroCopyEnabled;
using CustomUpdateBase::isBatched;
using CustomUpdateBase::isReduction;
using CustomUpdateBase::getVarLocationHashDigest;

using CustomUpdate::finalize;
Expand Down Expand Up @@ -51,9 +52,9 @@ class CustomUpdateWUInternal : public CustomUpdateWU
using CustomUpdateBase::isInitRNGRequired;
using CustomUpdateBase::isZeroCopyEnabled;
using CustomUpdateBase::isBatched;
using CustomUpdateBase::isReduction;
using CustomUpdateBase::getVarLocationHashDigest;

using CustomUpdateWU::finalize;
using CustomUpdateWU::getHashDigest;
using CustomUpdateWU::getInitHashDigest;
using CustomUpdateWU::getSynapseGroup;
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/customUpdateModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class GENN_EXPORT Base : public Models::Base
//! Update hash from model
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Is this custom update a reduction operation?
bool isReduction() const;

//! Validate names of parameters etc
void validate() const;
};
Expand Down
5 changes: 5 additions & 0 deletions include/genn/genn/gennUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ GENN_EXPORT bool isTypePointer(const std::string &type);
//--------------------------------------------------------------------------
GENN_EXPORT bool isTypePointerToPointer(const std::string &type);

//--------------------------------------------------------------------------
//! \brief Function to determine whether a string containing a type is floating point
//--------------------------------------------------------------------------
GENN_EXPORT bool isTypeFloatingPoint(const std::string &type);

//--------------------------------------------------------------------------
//! \brief Assuming type is a string containing a pointer type, function to return the underlying type
//--------------------------------------------------------------------------
Expand Down
13 changes: 1 addition & 12 deletions include/genn/genn/neuronGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

// GeNN includes
#include "gennExport.h"
#include "gennUtils.h"
#include "neuronModels.h"
#include "variableMode.h"

Expand Down Expand Up @@ -199,17 +198,7 @@ class GENN_EXPORT NeuronGroup
protected:
NeuronGroup(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel,
const std::vector<double> &params, const std::vector<Models::VarInit> &varInitialisers,
VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) :
m_Name(name), m_NumNeurons(numNeurons), m_NeuronModel(neuronModel), m_Params(params), m_VarInitialisers(varInitialisers),
m_NumDelaySlots(1), m_VarQueueRequired(varInitialisers.size(), false), m_SpikeLocation(defaultVarLocation), m_SpikeEventLocation(defaultVarLocation),
m_SpikeTimeLocation(defaultVarLocation), m_PrevSpikeTimeLocation(defaultVarLocation), m_SpikeEventTimeLocation(defaultVarLocation), m_PrevSpikeEventTimeLocation(defaultVarLocation),
m_VarLocation(varInitialisers.size(), defaultVarLocation), m_ExtraGlobalParamLocation(neuronModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation),
m_SpikeRecordingEnabled(false), m_SpikeEventRecordingEnabled(false)
{
// Validate names
Utils::validatePopName(name, "Neuron group");
getNeuronModel()->validate();
}
VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation);

//------------------------------------------------------------------------
// Protected methods
Expand Down
Loading