Add nam namespace #93

Merged: 1 commit on Dec 2, 2023
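This PR wraps the library's top-level `activations` and `convnet` namespaces in a new enclosing `nam` namespace, so every symbol is now reached through a `nam::` qualifier. For downstream code the migration is a one-line change per reference; a minimal sketch using the `get_activation` lookup touched by this diff (the before/after pairing is illustrative, not code from the PR):

```cpp
// Before this PR:
//   activations::Activation* act = activations::Activation::get_activation("Tanh");

// After this PR, the same lookup is qualified with the enclosing namespace:
nam::activations::Activation* act = nam::activations::Activation::get_activation("Tanh");
```

A `using namespace nam;` directive would also keep old call sites compiling unchanged, though fully qualified names preserve the isolation the new namespace adds.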
26 changes: 13 additions & 13 deletions NAM/activations.cpp
```diff
@@ -1,29 +1,29 @@
 #include "activations.h"

-activations::ActivationTanh _TANH = activations::ActivationTanh();
-activations::ActivationFastTanh _FAST_TANH = activations::ActivationFastTanh();
-activations::ActivationHardTanh _HARD_TANH = activations::ActivationHardTanh();
-activations::ActivationReLU _RELU = activations::ActivationReLU();
-activations::ActivationSigmoid _SIGMOID = activations::ActivationSigmoid();
+nam::activations::ActivationTanh _TANH = nam::activations::ActivationTanh();
+nam::activations::ActivationFastTanh _FAST_TANH = nam::activations::ActivationFastTanh();
+nam::activations::ActivationHardTanh _HARD_TANH = nam::activations::ActivationHardTanh();
+nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU();
+nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid();

-bool activations::Activation::using_fast_tanh = false;
+bool nam::activations::Activation::using_fast_tanh = false;

-std::unordered_map<std::string, activations::Activation*> activations::Activation::_activations =
+std::unordered_map<std::string, nam::activations::Activation*> nam::activations::Activation::_activations =
   {{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH}, {"ReLU", &_RELU}, {"Sigmoid", &_SIGMOID}};

-activations::Activation* tanh_bak = nullptr;
+nam::activations::Activation* tanh_bak = nullptr;

-activations::Activation* activations::Activation::get_activation(const std::string name)
+nam::activations::Activation* nam::activations::Activation::get_activation(const std::string name)
 {
   if (_activations.find(name) == _activations.end())
     return nullptr;

   return _activations[name];
 }

-void activations::Activation::enable_fast_tanh()
+void nam::activations::Activation::enable_fast_tanh()
 {
-  activations::Activation::using_fast_tanh = true;
+  nam::activations::Activation::using_fast_tanh = true;

   if (_activations["Tanh"] != _activations["Fasttanh"])
   {
@@ -32,9 +32,9 @@ void activations::Activation::enable_fast_tanh()
   }
 }

-void activations::Activation::disable_fast_tanh()
+void nam::activations::Activation::disable_fast_tanh()
 {
-  activations::Activation::using_fast_tanh = false;
+  nam::activations::Activation::using_fast_tanh = false;

   if (_activations["Tanh"] == _activations["Fasttanh"])
   {
```
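For reference, the renamed symbols can be exercised end to end. A minimal smoke-test sketch; the `main` harness and include path are assumptions, while the `nam::activations` calls come directly from this diff:

```cpp
#include "NAM/activations.h"

int main()
{
  // Look up a registered activation by name; unknown names return nullptr.
  nam::activations::Activation* act = nam::activations::Activation::get_activation("Tanh");
  if (act == nullptr)
    return 1;

  // Swap the "Tanh" registry entry for the fast approximation, then restore it.
  nam::activations::Activation::enable_fast_tanh();
  nam::activations::Activation::disable_fast_tanh();
  return 0;
}
```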
6 changes: 4 additions & 2 deletions NAM/activations.h
```diff
@@ -5,6 +5,8 @@
 #include <unordered_map>
 #include <Eigen/Dense>

+namespace nam
+{
 namespace activations
 {
 inline float relu(float x)
@@ -119,5 +121,5 @@ class ActivationSigmoid : public Activation
     }
   }
 };
-
-}; // namespace activations
+}; // namespace activations
+}; // namespace nam
```
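A stylistic aside: since C++17, the same two-level nesting can be spelled as a single nested-namespace definition. A sketch of the equivalent form (the PR keeps the two explicit blocks shown above, and the `relu` body below is assumed from its name rather than taken from the diff):

```cpp
namespace nam::activations
{
// Declares the same scope as the nested blocks above: nam::activations::relu.
inline float relu(float x)
{
  return x > 0.0f ? x : 0.0f;
}
} // namespace nam::activations
```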
37 changes: 19 additions & 18 deletions NAM/convnet.cpp
```diff
@@ -12,7 +12,7 @@
 #include "util.h"
 #include "convnet.h"

-convnet::BatchNorm::BatchNorm(const int dim, std::vector<float>::iterator& params)
+nam::convnet::BatchNorm::BatchNorm(const int dim, std::vector<float>::iterator& params)
 {
   // Extract from param buffer
   Eigen::VectorXf running_mean(dim);
@@ -37,7 +37,7 @@ convnet::BatchNorm::BatchNorm(const int dim, std::vector<float>::iterator& param
   this->loc = _bias - this->scale.cwiseProduct(running_mean);
 }

-void convnet::BatchNorm::process_(Eigen::MatrixXf& x, const long i_start, const long i_end) const
+void nam::convnet::BatchNorm::process_(Eigen::MatrixXf& x, const long i_start, const long i_end) const
 {
   // todo using colwise?
   // #speed but conv probably dominates
@@ -48,9 +48,9 @@ void convnet::BatchNorm::process_(Eigen::MatrixXf& x, const long i_start, const
   }
 }

-void convnet::ConvNetBlock::set_params_(const int in_channels, const int out_channels, const int _dilation,
-                                        const bool batchnorm, const std::string activation,
-                                        std::vector<float>::iterator& params)
+void nam::convnet::ConvNetBlock::set_params_(const int in_channels, const int out_channels, const int _dilation,
+                                             const bool batchnorm, const std::string activation,
+                                             std::vector<float>::iterator& params)
 {
   this->_batchnorm = batchnorm;
   // HACK 2 kernel
@@ -60,8 +60,8 @@ void convnet::ConvNetBlock::set_params_(const int in_channels, const int out_cha
   this->activation = activations::Activation::get_activation(activation);
 }

-void convnet::ConvNetBlock::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start,
-                                     const long i_end) const
+void nam::convnet::ConvNetBlock::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start,
+                                          const long i_end) const
 {
   const long ncols = i_end - i_start;
   this->conv.process_(input, output, i_start, ncols, i_start);
@@ -71,30 +71,31 @@ void convnet::ConvNetBlock::process_(const Eigen::MatrixXf& input, Eigen::Matrix
   this->activation->apply(output.middleCols(i_start, ncols));
 }

-long convnet::ConvNetBlock::get_out_channels() const
+long nam::convnet::ConvNetBlock::get_out_channels() const
 {
   return this->conv.get_out_channels();
 }

-convnet::_Head::_Head(const int channels, std::vector<float>::iterator& params)
+nam::convnet::_Head::_Head(const int channels, std::vector<float>::iterator& params)
 {
   this->_weight.resize(channels);
   for (int i = 0; i < channels; i++)
     this->_weight[i] = *(params++);
   this->_bias = *(params++);
 }

-void convnet::_Head::process_(const Eigen::MatrixXf& input, Eigen::VectorXf& output, const long i_start,
-                              const long i_end) const
+void nam::convnet::_Head::process_(const Eigen::MatrixXf& input, Eigen::VectorXf& output, const long i_start,
+                                   const long i_end) const
 {
   const long length = i_end - i_start;
   output.resize(length);
   for (long i = 0, j = i_start; i < length; i++, j++)
     output(i) = this->_bias + input.col(j).dot(this->_weight);
 }

-convnet::ConvNet::ConvNet(const int channels, const std::vector<int>& dilations, const bool batchnorm,
-                          const std::string activation, std::vector<float>& params, const double expected_sample_rate)
+nam::convnet::ConvNet::ConvNet(const int channels, const std::vector<int>& dilations, const bool batchnorm,
+                               const std::string activation, std::vector<float>& params,
+                               const double expected_sample_rate)
 : Buffer(*std::max_element(dilations.begin(), dilations.end()), expected_sample_rate)
 {
   this->_verify_params(channels, dilations, batchnorm, params.size());
@@ -116,7 +117,7 @@ convnet::ConvNet::ConvNet(const int channels, const std::vector<int>& dilations,
 }


-void convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames)
+void nam::convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames)

 {
   this->_update_buffers_(input, num_frames);
@@ -135,13 +136,13 @@ void convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int
     output[s] = this->_head_output(s);
 }

-void convnet::ConvNet::_verify_params(const int channels, const std::vector<int>& dilations, const bool batchnorm,
-                                      const size_t actual_params)
+void nam::convnet::ConvNet::_verify_params(const int channels, const std::vector<int>& dilations, const bool batchnorm,
+                                           const size_t actual_params)
 {
   // TODO
 }

-void convnet::ConvNet::_update_buffers_(NAM_SAMPLE* input, const int num_frames)
+void nam::convnet::ConvNet::_update_buffers_(NAM_SAMPLE* input, const int num_frames)
 {
   this->Buffer::_update_buffers_(input, num_frames);

@@ -163,7 +164,7 @@ void convnet::ConvNet::_update_buffers_(NAM_SAMPLE* input, const int num_frames)
   }
 }

-void convnet::ConvNet::_rewind_buffers_()
+void nam::convnet::ConvNet::_rewind_buffers_()
 {
   // Need to rewind the block vals first because Buffer::rewind_buffers()
   // resets the offset index
```
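The convnet diff is the same mechanical change, so downstream code only needs the qualified type name. A minimal sketch of code written against the namespaced type; the `render` helper is an illustrative assumption, while `process` and its signature come from the diff:

```cpp
#include "NAM/convnet.h"

// Illustrative helper: the model is assumed to be constructed elsewhere from a
// real weight buffer (channels, dilations, batchnorm flag, activation name,
// params, expected sample rate, per the constructor in this diff).
void render(nam::convnet::ConvNet& model, NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames)
{
  // Only the qualified type name changed in this PR; process() is untouched.
  model.process(input, output, num_frames);
}
```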
3 changes: 3 additions & 0 deletions NAM/convnet.h
```diff
@@ -9,6 +9,8 @@

 #include <Eigen/Dense>

+namespace nam
+{
 namespace convnet
 {
 // Custom Conv that avoids re-computing on pieces of the input and trusts
@@ -82,3 +84,4 @@ class ConvNet : public Buffer
   void process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames) override;
 };
 }; // namespace convnet
+}; // namespace nam
```