Skip to content

Commit

Permalink
Also rename groups_of -> groups
Browse files Browse the repository at this point in the history
  • Loading branch information
purefunctor committed Nov 29, 2023
1 parent a48119c commit 7aebcac
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions RTNeural/ModelT.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,21 @@ namespace modelt_detail
}
}

template <typename T, int in_size, int out_size, int kernel_size, int dilation_rate, int groups_of, bool dynamic_state>
void loadLayer(Conv1DT<T, in_size, out_size, kernel_size, dilation_rate, groups_of, dynamic_state>& conv, int& json_stream_idx, const nlohmann::json& l,
template <typename T, int in_size, int out_size, int kernel_size, int dilation_rate, int groups, bool dynamic_state>
void loadLayer(Conv1DT<T, in_size, out_size, kernel_size, dilation_rate, groups, dynamic_state>& conv, int& json_stream_idx, const nlohmann::json& l,
const std::string& type, int layerDims, bool debug)
{
using namespace json_parser;

debug_print("Layer: " + type, debug);
debug_print(" Dims: " + std::to_string(layerDims), debug);
const auto& weights = l["weights"];
const auto kernel = l["kernel_size"].back().get<int>();
const auto dilation = l["dilation"].back().get<int>();
const auto groups = l.value("groups", 1);
const auto& l_weights = l["weights"];
const auto l_kernel = l["kernel_size"].back().get<int>();
const auto l_dilation = l["dilation"].back().get<int>();
const auto l_groups = l.value("groups", 1);

if(checkConv1D<T>(conv, type, layerDims, kernel, dilation, groups, debug))
loadConv1D<T>(conv, kernel, dilation, weights);
if(checkConv1D<T>(conv, type, layerDims, l_kernel, l_dilation, l_groups, debug))
loadConv1D<T>(conv, l_kernel, l_dilation, l_weights);

if(!l.contains("activation"))
{
Expand Down

0 comments on commit 7aebcac

Please sign in to comment.