diff --git a/src/nupic/algorithms/SDRClassifier.cpp b/src/nupic/algorithms/SDRClassifier.cpp index 56aadf528f..cd7d96d77b 100644 --- a/src/nupic/algorithms/SDRClassifier.cpp +++ b/src/nupic/algorithms/SDRClassifier.cpp @@ -203,15 +203,7 @@ void SDRClassifier::infer_(const vector &patternNZ, add(likelihoods->begin(), likelihoods->end(), weights.begin(bit), weights.begin(bit + 1)); } - Real64 maxLikelihoods = *max_element(likelihoods->begin(), likelihoods->end()); - for (auto likelihood : *likelihoods) { - likelihood -= maxLikelihoods; - } - range_exp(1.0, *likelihoods); - Real64 sumLikelihoods = accumulate(likelihoods->begin(), likelihoods->end(), 0); - for (auto likelihood : *likelihoods) { - likelihood /= sumLikelihoods; - } + softmax_(likelihoods->begin(), likelihoods->end()); } } @@ -226,15 +218,7 @@ vector SDRClassifier::calculateError_(const vector &bucketIdxList, add(likelihoods.begin(), likelihoods.end(), weights.begin(bit), weights.begin(bit + 1)); } - Real64 maxLikelihoods = *max_element(likelihoods.begin(), likelihoods.end()); - for (auto likelihood : likelihoods) { - likelihood -= maxLikelihoods; - } - range_exp(1.0, likelihoods); - Real64 sumLikelihoods = accumulate(likelihoods.begin(), likelihoods.end(), 0); - for (auto likelihood : likelihoods) { - likelihood /= sumLikelihoods; - } + softmax_(likelihoods.begin(), likelihoods.end()); // compute target likelihoods vector targetDistribution(maxBucketIdx_ + 1, 0.0); @@ -246,6 +230,19 @@ vector SDRClassifier::calculateError_(const vector &bucketIdxList, return likelihoods; } +template +void SDRClassifier::softmax_(Iterator begin, Iterator end) { + Iterator maxItr= max_element(begin, end); + for (auto itr = begin; itr != end; ++itr) { + *itr -= *maxItr; + } + range_exp(1.0, begin, end); + typename std::iterator_traits::value_type sum = accumulate(begin, end, 0); + for (auto itr = begin; itr != end; ++itr) { + *itr /= sum; + } +} + UInt SDRClassifier::version() const { return version_; } UInt SDRClassifier::getVerbosity() const { return verbosity_; } diff --git a/src/nupic/algorithms/SDRClassifier.hpp b/src/nupic/algorithms/SDRClassifier.hpp index a731b55af1..e9ee0a5f53 100644 --- a/src/nupic/algorithms/SDRClassifier.hpp +++ b/src/nupic/algorithms/SDRClassifier.hpp @@ -163,6 +163,10 @@ class SDRClassifier : public Serializable { vector calculateError_(const vector &bucketIdxList, const vector patternNZ, UInt step); + // softmax function + template + void softmax_(Iterator begin, Iterator end); + // The list of prediction steps to learn and infer. vector steps_;