Skip to content

Commit

Permalink
Added Hardtanh activation function (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeoliphant authored Mar 31, 2023
1 parent 070811a commit 2e5e1b2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
26 changes: 26 additions & 0 deletions dsp/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ inline float fast_tanh_(const float x) {
(2.44506634652299f + x2) * fabs(x + 0.814642734961073f * x * ax)));
}

inline float hard_tanh_(const float x) {
const float t = x < -1 ? -1 : x;
return t > 1 ? 1 : t;
}

void tanh_(Eigen::MatrixXf &x, const long i_start, const long i_end,
const long j_start, const long j_end) {
for (long j = j_start; j < j_end; j++)
Expand All @@ -227,6 +232,27 @@ void tanh_(Eigen::MatrixXf &x) {
}
}

void hard_tanh_(Eigen::MatrixXf& x, const long i_start, const long i_end,
const long j_start, const long j_end) {
for (long j = j_start; j < j_end; j++)
for (long i = i_start; i < i_end; i++)
x(i, j) = hard_tanh_(x(i, j));
}

void hard_tanh_(Eigen::MatrixXf& x, const long j_start, const long j_end) {
hard_tanh_(x, 0, x.rows(), j_start, j_end);
}

void hard_tanh_(Eigen::MatrixXf& x) {
float* ptr = x.data();

long size = x.rows() * x.cols();

for (long pos = 0; pos < size; pos++) {
ptr[pos] = hard_tanh_(ptr[pos]);
}
}

void Conv1D::set_params_(std::vector<float>::iterator &params) {
if (this->_weight.size() > 0) {
const long out_channels = this->_weight[0].rows();
Expand Down
8 changes: 8 additions & 0 deletions dsp/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ void tanh_(Eigen::MatrixXf &x, const long i_start, const long i_end);

void tanh_(Eigen::MatrixXf &x);

// In-place Hardtanh on (N,M) array
void hard_tanh_(Eigen::MatrixXf& x, const long i_start, const long i_end,
const long j_start, const long j_end);
// Subset of the columns
void hard_tanh_(Eigen::MatrixXf& x, const long i_start, const long i_end);

void hard_tanh_(Eigen::MatrixXf& x);

class Conv1D {
public:
Conv1D() { this->_dilation = 1; };
Expand Down
4 changes: 3 additions & 1 deletion dsp/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ void wavenet::_Layer::process_(const Eigen::MatrixXf &input,
this->_conv.process_(input, this->_z, i_start, ncols, 0);
// Mix-in condition
this->_z += this->_input_mixin.process(condition);
if (this->_activation == "Tanh")
if (this->_activation == "Hardtanh")
hard_tanh_(this->_z);
else if (this->_activation == "Tanh")
tanh_(this->_z);
else if (this->_activation == "ReLU")
relu_(this->_z, 0, channels, 0, this->_z.cols());
Expand Down

0 comments on commit 2e5e1b2

Please sign in to comment.