Add MathsProvider to GRU layer creation and to createActivation.
mikeoliphant committed Nov 4, 2024
1 parent 3fe1f9f commit 42a2a2c
Showing 1 changed file with 10 additions and 10 deletions.
RTNeural/model_loader.h:

@@ -318,10 +318,10 @@ namespace json_parser
 }

 /** Creates a GRULayer from a json representation of the layer weights. */
-template <typename T>
-std::unique_ptr<GRULayer<T>> createGRU(int in_size, int out_size, const nlohmann::json& weights)
+template <typename T, typename MathsProvider = DefaultMathsProvider>
+std::unique_ptr<GRULayer<T, MathsProvider>> createGRU(int in_size, int out_size, const nlohmann::json& weights)
 {
-    auto gru = std::make_unique<GRULayer<T>>(in_size, out_size);
+    auto gru = std::make_unique<GRULayer<T, MathsProvider>>(in_size, out_size);
     loadGRU<T>(*gru.get(), weights);
     return std::move(gru);
 }
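
Because the new MathsProvider parameter defaults to DefaultMathsProvider, existing createGRU<float>(...) call sites compile unchanged, and a custom provider becomes an explicit opt-in. A minimal sketch follows, assuming a provider supplies static tanh/sigmoid/exp functions in the style of the STL backend's DefaultMathsProvider; the exact required interface varies by backend, and FastMathsProvider, inSize, outSize, and weightsJson are hypothetical names here.

    #include <cmath>

    // Hypothetical provider: mirrors the scalar interface assumed for
    // DefaultMathsProvider. Check the backend's DefaultMathsProvider for
    // the signatures it actually requires.
    struct FastMathsProvider
    {
        template <typename T>
        static T tanh(T x) { return std::tanh(x); } // swap in an approximation here

        template <typename T>
        static T sigmoid(T x) { return (T)1 / ((T)1 + std::exp(-x)); }

        template <typename T>
        static T exp(T x) { return std::exp(x); }
    };

    // Opt in at the call site; plain createGRU<float>(...) still uses the default.
    auto gru = RTNeural::json_parser::createGRU<float, FastMathsProvider>(inSize, outSize, weightsJson);
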
@@ -560,24 +560,24 @@ namespace json_parser
 }

 /** Creates an activation layer of a given type. */
-template <typename T>
+template <typename T, typename MathsProvider = DefaultMathsProvider>
 std::unique_ptr<Activation<T>>
 createActivation(const std::string& activationType, int dims)
 {
     if(activationType == "tanh")
-        return std::make_unique<TanhActivation<T>>(dims);
+        return std::make_unique<TanhActivation<T, MathsProvider>>(dims);

     if(activationType == "relu")
         return std::make_unique<ReLuActivation<T>>(dims);

     if(activationType == "sigmoid")
-        return std::make_unique<SigmoidActivation<T>>(dims);
+        return std::make_unique<SigmoidActivation<T, MathsProvider>>(dims);

     if(activationType == "softmax")
-        return std::make_unique<SoftmaxActivation<T>>(dims);
+        return std::make_unique<SoftmaxActivation<T, MathsProvider>>(dims);

     if(activationType == "elu")
-        return std::make_unique<ELuActivation<T>>(dims);
+        return std::make_unique<ELuActivation<T, MathsProvider>>(dims);

     return {};
 }
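
Note that the "relu" branch is left untouched on purpose: ReLU is piecewise-linear and needs no transcendental maths, so ReLuActivation takes no provider parameter. Unknown activation names fall through to `return {};`, i.e. an empty unique_ptr. A usage sketch, reusing the hypothetical FastMathsProvider from above with an assumed dims value and model pointer:

    // Null-check the result: createActivation returns an empty unique_ptr
    // for unrecognized activation names.
    auto act = RTNeural::json_parser::createActivation<float, FastMathsProvider>("tanh", dims);
    if(act != nullptr)
        model->addLayer(act.release()); // assumes `model` points to an RTNeural::Model<float>
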
@@ -639,7 +639,7 @@ namespace json_parser
     if(!activationType.empty())
     {
         debug_print(" activation: " + activationType, debug);
-        auto activation = createActivation<T>(activationType, layerDims);
+        auto activation = createActivation<T, MathsProvider>(activationType, layerDims);
         _model->addLayer(activation.release());
     }
 }
@@ -683,7 +683,7 @@ namespace json_parser
     }
     else if(type == "gru")
     {
-        auto gru = createGRU<T>(model->getNextInSize(), layerDims, weights);
+        auto gru = createGRU<T, MathsProvider>(model->getNextInSize(), layerDims, weights);
         model->addLayer(gru.release());
     }
     else if(type == "lstm")
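
The last two hunks update the call sites inside the JSON model builder, which implies the enclosing parse functions now carry the same MathsProvider template parameter. Assuming the public parseJson entry point gained a matching, defaulted parameter, end-to-end loading would look roughly like this (the model path and include path are placeholders):

    #include <fstream>
    #include <RTNeural/RTNeural.h> // include path may differ in your setup

    std::ifstream jsonStream("model.json", std::ifstream::binary); // placeholder path

    // Default provider: behaviour identical to before this commit.
    auto model = RTNeural::json_parser::parseJson<float>(jsonStream);

    // Custom provider, threaded through GRU and activation creation (sketch):
    // auto model = RTNeural::json_parser::parseJson<float, FastMathsProvider>(jsonStream);
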
