Skip to content

Commit

Permalink
Merge pull request #443 from genn-team/name_validate
Browse files Browse the repository at this point in the history
Validate population and variable names
  • Loading branch information
neworderofjamie authored Jul 22, 2021
2 parents 4e1d00e + d4d4f0f commit c27d722
Show file tree
Hide file tree
Showing 26 changed files with 301 additions and 1 deletion.
4 changes: 4 additions & 0 deletions include/genn/genn/currentSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// GeNN includes
#include "currentSourceModels.h"
#include "gennExport.h"
#include "gennUtils.h"
#include "variableMode.h"

// Forward declarations
Expand Down Expand Up @@ -68,6 +69,9 @@ class GENN_EXPORT CurrentSource
m_TrgNeuronGroup(trgNeuronGroup), m_VarLocation(varInitialisers.size(), defaultVarLocation),
m_ExtraGlobalParamLocation(currentSourceModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation)
{
// Validate names
Utils::validateVarPopName(name, "Current source");
getCurrentSourceModel()->validate();
}

//------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/currentSourceModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class GENN_EXPORT Base : public Models::Base
//----------------------------------------------------------------------------
//! Update hash from model
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
using Models::Base::validate;
};

//----------------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions include/genn/genn/customUpdate.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

// GeNN includes
#include "gennExport.h"
#include "gennUtils.h"
#include "customUpdateModels.h"
#include "variableMode.h"

Expand Down Expand Up @@ -56,6 +57,10 @@ class GENN_EXPORT CustomUpdateBase
m_ExtraGlobalParamLocation(customUpdateModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation),
m_Batched(false)
{
// Validate names
Utils::validateVarPopName(name, "Custom update");
Utils::validateVarPopName(updateGroupName, "Custom update group name");
getCustomUpdateModel()->validate();
}

//------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/customUpdateModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class GENN_EXPORT Base : public Models::Base
//----------------------------------------------------------------------------
//! Update hash from model
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
void validate() const;
};

//----------------------------------------------------------------------------
Expand Down
21 changes: 21 additions & 0 deletions include/genn/genn/gennUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ GENN_EXPORT bool isTypePointerToPointer(const std::string &type);
//--------------------------------------------------------------------------
GENN_EXPORT std::string getUnderlyingType(const std::string &type);

//--------------------------------------------------------------------------
//! \brief Is the variable/population name valid? GeNN variables and population names must obey C variable naming rules
//--------------------------------------------------------------------------
GENN_EXPORT void validateVarPopName(const std::string &name, const std::string &description);

//--------------------------------------------------------------------------
//! \brief Are all the parameter names in vector valid? GeNN variables and population names must obey C variable naming rules
//--------------------------------------------------------------------------
GENN_EXPORT void validateParamNames(const std::vector<std::string> &paramNames);

//--------------------------------------------------------------------------
//! \brief Are the 'name' fields of all structs in vector valid? GeNN variables and population names must obey C variable naming rules
//--------------------------------------------------------------------------
template<typename T>
void validateVecNames(const std::vector<T> &vec, const std::string &description)
{
for(const auto &v : vec) {
validateVarPopName(v.name, description);
}
}

//--------------------------------------------------------------------------
//! \brief This function writes a floating point value to a stream -setting the precision so no digits are lost
//--------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/initSparseConnectivitySnippet.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class GENN_EXPORT Base : public Snippet::Base
//------------------------------------------------------------------------
//! Update hash from snippet
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
void validate() const;
};

//----------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/initVarSnippet.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class GENN_EXPORT Base : public Snippet::Base
//! Update hash from snippet
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
using Snippet::Base::validate;

//! Does this var init snippet require kernel-based connectivity
bool requiresKernel() const;
};
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ class GENN_EXPORT Base : public Snippet::Base
// Protected methods
//------------------------------------------------------------------------
void updateHash(boost::uuids::detail::sha1 &hash) const;

//! Validate names of parameters etc
void validate() const;
};


Expand Down
4 changes: 4 additions & 0 deletions include/genn/genn/neuronGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// GeNN includes
#include "gennExport.h"
#include "gennUtils.h"
#include "neuronModels.h"
#include "variableMode.h"

Expand Down Expand Up @@ -205,6 +206,9 @@ class GENN_EXPORT NeuronGroup
m_VarLocation(varInitialisers.size(), defaultVarLocation), m_ExtraGlobalParamLocation(neuronModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation),
m_SpikeRecordingEnabled(false), m_SpikeEventRecordingEnabled(false)
{
// Validate names
Utils::validateVarPopName(name, "Neuron group");
getNeuronModel()->validate();
}

//------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/neuronModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class GENN_EXPORT Base : public Models::Base
//----------------------------------------------------------------------------
//! Update hash from model
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
void validate() const;
};

//----------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/postsynapticModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class GENN_EXPORT Base : public Models::Base
//----------------------------------------------------------------------------
//! Update hash from model
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
using Models::Base::validate;
};

//----------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions include/genn/genn/snippet.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,14 @@ class GENN_EXPORT Base
Utils::updateHash(getExtraGlobalParams(), hash);
}

//! Validate names of parameters etc
void validate() const
{
Utils::validateParamNames(getParamNames());
Utils::validateVecNames(getDerivedParams(), "Derived parameter");
Utils::validateVecNames(getExtraGlobalParams(), "Derived parameter");
}

//------------------------------------------------------------------------
// Protected static helpers
//------------------------------------------------------------------------
Expand Down Expand Up @@ -245,6 +253,8 @@ class Init
Init(const SnippetBase *snippet, const std::vector<double> &params)
: m_Snippet(snippet), m_Params(params)
{
// Validate names
getSnippet()->validate();
}

//----------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions include/genn/genn/weightUpdateModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ class GENN_EXPORT Base : public Models::Base

//! Update hash from model
boost::uuids::detail::sha1::digest_type getHashDigest() const;

//! Validate names of parameters etc
void validate() const;
};

//----------------------------------------------------------------------------
Expand Down
8 changes: 8 additions & 0 deletions src/genn/genn/customUpdateModels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,11 @@ boost::uuids::detail::sha1::digest_type CustomUpdateModels::Base::getHashDigest(
Utils::updateHash(getVarRefs(), hash);
return hash.get_digest();
}
//----------------------------------------------------------------------------
void CustomUpdateModels::Base::validate() const
{
// Superclass
Models::Base::validate();

Utils::validateVecNames(getVarRefs(), "Variable reference");
}
27 changes: 27 additions & 0 deletions src/genn/genn/gennUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,31 @@ std::string getUnderlyingType(const std::string &type)
return type.substr(0, type.length() - 1);
}
}
//--------------------------------------------------------------------------
void validateVarPopName(const std::string &name, const std::string &description)
{
// Empty names aren't valid
if(name.empty()) {
throw std::runtime_error(description + " name invalid: cannot be empty");
}

// If first character's a number, name isn't valid
if(std::isdigit(name.front())) {
throw std::runtime_error(description + " name invalid: '" + name + "' starts with a digit");
}

// If any characters aren't underscores or alphanumeric, name isn't valud
if(std::any_of(name.cbegin(), name.cend(),
[](char c) { return (c != '_') && !std::isalnum(c); }))
{
throw std::runtime_error(description + " name invalid: '" + name + "' contains an illegal character");
}
}
//--------------------------------------------------------------------------
void validateParamNames(const std::vector<std::string> &paramNames)
{
for(const std::string &p : paramNames) {
validateVarPopName(p, "Parameter");
}
}
} // namespace utils
8 changes: 8 additions & 0 deletions src/genn/genn/initSparseConnectivitySnippet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,11 @@ boost::uuids::detail::sha1::digest_type InitSparseConnectivitySnippet::Base::get
Utils::updateHash(getHostInitCode(), hash);
return hash.get_digest();
}
//----------------------------------------------------------------------------
void InitSparseConnectivitySnippet::Base::validate() const
{
// Superclass
Snippet::Base::validate();
Utils::validateVecNames(getRowBuildStateVars(), "Row building state variable");
Utils::validateVecNames(getColBuildStateVars(), "Column building state variable");
}
11 changes: 10 additions & 1 deletion src/genn/genn/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ void Base::updateHash(boost::uuids::detail::sha1 &hash) const

Utils::updateHash(getVars(), hash);
}
//----------------------------------------------------------------------------
void Base::validate() const
{
// Superclass
Snippet::Base::validate();

Utils::validateVecNames(getVars(), "Variable");
}

//----------------------------------------------------------------------------
// VarReference
//----------------------------------------------------------------------------
Expand Down Expand Up @@ -171,4 +180,4 @@ void Models::updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &has
Utils::updateHash(v.getTransposeTargetName(), hash);
Utils::updateHash(v.getTransposeVarIndex(), hash);
}
}
}
8 changes: 8 additions & 0 deletions src/genn/genn/neuronModels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@ boost::uuids::detail::sha1::digest_type NeuronModels::Base::getHashDigest() cons
Utils::updateHash(getAdditionalInputVars(), hash);
return hash.get_digest();
}
//----------------------------------------------------------------------------
void NeuronModels::Base::validate() const
{
// Superclass
Models::Base::validate();

Utils::validateVecNames(getAdditionalInputVars(), "Additional input variable");
}
5 changes: 5 additions & 0 deletions src/genn/genn/synapseGroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,11 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType
m_ConnectivityInitialiser(connectivityInitialiser), m_SparseConnectivityLocation(defaultSparseConnectivityLocation),
m_ConnectivityExtraGlobalParamLocation(connectivityInitialiser.getSnippet()->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_PSModelTargetName(name)
{
// Validate names
Utils::validateVarPopName(name, "Synapse group");
getWUModel()->validate();
getPSModel()->validate();

// If connectivity is procedural
if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) {
// If there's no row build code, give an error
Expand Down
9 changes: 9 additions & 0 deletions src/genn/genn/weightUpdateModels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,12 @@ boost::uuids::detail::sha1::digest_type WeightUpdateModels::Base::getHashDigest(
// Return digest
return hash.get_digest();
}
//----------------------------------------------------------------------------
void WeightUpdateModels::Base::validate() const
{
// Superclass
Models::Base::validate();

Utils::validateVecNames(getPreVars(), "Presynaptic variable");
Utils::validateVecNames(getPostVars(), "Presynaptic variable");
}
16 changes: 16 additions & 0 deletions tests/unit/currentSource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,19 @@ TEST(CurrentSource, CompareSameParameters)
CurrentSourceInternal *cs1Internal = static_cast<CurrentSourceInternal*>(cs1);
ASSERT_EQ(cs0Internal->getHashDigest(), cs1Internal->getHashDigest());
}

TEST(CurrentSource, InvalidName)
{
NeuronModels::Izhikevich::ParamValues paramVals(0.02, 0.2, -65.0, 8.0);
NeuronModels::Izhikevich::VarValues varVals(0.0, 0.0);

ModelSpec model;
auto *pop = model.addNeuronPopulation<NeuronModels::Izhikevich>("Pop", 10, paramVals, varVals);

try {
model.addCurrentSource<CurrentSourceModels::DC>("6CS", "Pop", {1.0}, {});
FAIL();
}
catch(const std::runtime_error &) {
}
}
42 changes: 42 additions & 0 deletions tests/unit/customUpdate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,3 +582,45 @@ TEST(CustomUpdates, CompareDifferentWUConnectivity)
ASSERT_TRUE(modelSpecMerged.getMergedCustomWUUpdateDenseInitGroups().size() == 1);
ASSERT_TRUE(modelSpecMerged.getMergedCustomWUUpdateSparseInitGroups().size() == 1);
}
//--------------------------------------------------------------------------
TEST(CustomUpdates, InvalidName)
{
ModelSpec model;

// Add neuron group to model
NeuronModels::Izhikevich::ParamValues paramVals(0.02, 0.2, -65.0, 8.0);
NeuronModels::Izhikevich::VarValues varVals(0.0, 0.0);
auto *ng1 = model.addNeuronPopulation<NeuronModels::Izhikevich>("Neuron1", 10, paramVals, varVals);

Sum::VarValues sumVarValues(0.0);
Sum::VarReferences sumVarReferences1(createVarRef(ng1, "V"), createVarRef(ng1, "U"));

try {
model.addCustomUpdate<Sum>("1Sum", "CustomUpdate",
{}, sumVarValues, sumVarReferences1);
FAIL();
}
catch(const std::runtime_error &) {
}
}
//--------------------------------------------------------------------------
TEST(CustomUpdates, InvalidUpdateGroupName)
{
ModelSpec model;

// Add neuron group to model
NeuronModels::Izhikevich::ParamValues paramVals(0.02, 0.2, -65.0, 8.0);
NeuronModels::Izhikevich::VarValues varVals(0.0, 0.0);
auto *ng1 = model.addNeuronPopulation<NeuronModels::Izhikevich>("Neuron1", 10, paramVals, varVals);

Sum::VarValues sumVarValues(0.0);
Sum::VarReferences sumVarReferences1(createVarRef(ng1, "V"), createVarRef(ng1, "U"));

try {
model.addCustomUpdate<Sum>("Sum", "1CustomUpdate",
{}, sumVarValues, sumVarReferences1);
FAIL();
}
catch(const std::runtime_error &) {
}
}
Loading

0 comments on commit c27d722

Please sign in to comment.