Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pre-warm WaveNet on creation over the size of the receptive field #71

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 17 additions & 35 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,23 @@ wavenet::WaveNet::WaveNet(const double loudness, const std::vector<wavenet::Laye
}
this->_head_output.resize(1, 0); // Mono output!
this->set_params_(params);
this->_reset_anti_pop_();

long receptive_field = 1;
for (size_t i = 0; i < this->_layer_arrays.size(); i++)
receptive_field += this->_layer_arrays[i].get_receptive_field();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is a receptive field getter not defined for WaveNet? That'd be handy/surely this exists somewhere else in the code?

...looks like no! Huh.


NAM_SAMPLE sample = 0;
NAM_SAMPLE* sample_ptr = &sample;

std::unordered_map<std::string, double> param_dict = {};

// pre-warm the model over the size of the receptive field
for (long i = 0; i < receptive_field; i++)
{
this->process(&sample_ptr, &sample_ptr, 1, 1, 1.0, 1.0, param_dict);
this->finalize_(1);
sample = 0;
}
}

void wavenet::WaveNet::finalize_(const int num_frames)
Expand Down Expand Up @@ -315,11 +331,6 @@ void wavenet::WaveNet::_process_core_()
this->_set_num_frames_(num_frames);
this->_prepare_for_frames_(num_frames);

// NOTE: During warm-up, weird things can happen that NaN out the layers.
// We could solve this by anti-popping the *input*. But, it's easier to check
// the outputs for NaNs and zero them out.
// They'll flush out eventually because the model doesn't use any feedback.

// Fill into condition array:
// Clumsy...
for (int j = 0; j < num_frames; j++)
Expand Down Expand Up @@ -351,13 +362,8 @@ void wavenet::WaveNet::_process_core_()
for (int s = 0; s < num_frames; s++)
{
float out = this->_head_scale * this->_head_arrays[final_head_array](0, s);
// This is the NaN check that we could fix with anti-popping the input
if (isnan(out))
out = 0.0;
this->_core_dsp_output[s] = out;
}
// Apply anti-pop
this->_anti_pop_();
}

void wavenet::WaveNet::_set_num_frames_(const long num_frames)
Expand All @@ -377,27 +383,3 @@ void wavenet::WaveNet::_set_num_frames_(const long num_frames)
// this->_head.set_num_frames_(num_frames);
this->_num_frames = num_frames;
}

void wavenet::WaveNet::_anti_pop_()
{
if (this->_anti_pop_countdown >= this->_anti_pop_ramp)
return;
const float slope = 1.0f / float(this->_anti_pop_ramp);
for (size_t i = 0; i < this->_core_dsp_output.size(); i++)
{
if (this->_anti_pop_countdown >= this->_anti_pop_ramp)
break;
const float gain = std::max(slope * float(this->_anti_pop_countdown), 0.0f);
this->_core_dsp_output[i] *= gain;
this->_anti_pop_countdown++;
}
}

void wavenet::WaveNet::_reset_anti_pop_()
{
// You need the "real" receptive field, not the buffers.
long receptive_field = 1;
for (size_t i = 0; i < this->_layer_arrays.size(); i++)
receptive_field += this->_layer_arrays[i].get_receptive_field();
this->_anti_pop_countdown = -receptive_field;
}
9 changes: 0 additions & 9 deletions NAM/wavenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,5 @@ class WaveNet : public DSP

// Ensure that all buffer arrays are the right size for this num_frames
void _set_num_frames_(const long num_frames);

// The net starts with random parameters inside; we need to wait for a full
// receptive field to pass through before we can count on the output being
// ok. This implements a gentle "ramp-up" so that there's no "pop" at the
// start.
long _anti_pop_countdown;
const long _anti_pop_ramp = 4000;
void _anti_pop_();
void _reset_anti_pop_();
};
}; // namespace wavenet