Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved pre-warm to DSP and call it in get_dsp() #90

Merged
merged 1 commit into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NAM/convnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,13 @@ convnet::ConvNet::ConvNet(const int channels, const std::vector<int>& dilations,
this->_head = _Head(channels, it);
if (it != params.end())
throw std::runtime_error("Didn't touch all the params when initializing ConvNet");

_prewarm_samples = 1;
for (size_t i = 0; i < dilations.size(); i++)
_prewarm_samples += dilations[i];
sdatkinson marked this conversation as resolved.
Show resolved Hide resolved
}


void convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames)

{
Expand Down
17 changes: 17 additions & 0 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ DSP::DSP(const double expected_sample_rate)
{
}

void DSP::prewarm()
{
if (_prewarm_samples == 0)
return;

NAM_SAMPLE sample = 0;
NAM_SAMPLE* sample_ptr = &sample;

// pre-warm the model for a model-specific number of samples
for (long i = 0; i < _prewarm_samples; i++)
{
this->process(sample_ptr, sample_ptr, 1);
this->finalize_(1);
sample = 0;
}
}

void DSP::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames)
{
// Default implementation is the null operation
Expand Down
4 changes: 4 additions & 0 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class DSP
// We may choose to have the models figure out for themselves how loud they are in here in the future.
DSP(const double expected_sample_rate);
virtual ~DSP() = default;
// prewarm() does any required intial work required to "settle" model initial conditions
// it can be somewhat expensive, so should not be called during realtime audio processing
virtual void prewarm();
// process() does all of the processing requried to take `input` array and
// fill in the required values on `output`.
// To do this:
Expand Down Expand Up @@ -87,6 +90,7 @@ class DSP
std::unordered_map<std::string, double> _params;
// If the params have changed since the last buffer was processed:
bool _stale_params = true;
int _prewarm_samples = 0;

// Methods

Expand Down
4 changes: 4 additions & 0 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,9 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
{
out->SetLoudness(loudness);
}

// "pre-warm" the model to settle initial conditions
out->prewarm();

return out;
}
15 changes: 2 additions & 13 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,20 +263,9 @@ wavenet::WaveNet::WaveNet(const std::vector<wavenet::LayerArrayParams>& layer_ar
this->_head_output.resize(1, 0); // Mono output!
this->set_params_(params);

long receptive_field = 1;
_prewarm_samples = 1;
for (size_t i = 0; i < this->_layer_arrays.size(); i++)
receptive_field += this->_layer_arrays[i].get_receptive_field();

NAM_SAMPLE sample = 0;
NAM_SAMPLE* sample_ptr = &sample;

// pre-warm the model over the size of the receptive field
for (long i = 0; i < receptive_field; i++)
{
this->process(sample_ptr, sample_ptr, 1);
this->finalize_(1);
sample = 0;
}
_prewarm_samples += this->_layer_arrays[i].get_receptive_field();
}

void wavenet::WaveNet::finalize_(const int num_frames)
Expand Down
1 change: 0 additions & 1 deletion NAM/wavenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ class WaveNet : public DSP
// Get the info from the parametric config
void _init_parametric_(nlohmann::json& parametric);
void _prepare_for_frames_(const long num_frames);
// Reminder: From ._input_post_gain to ._core_dsp_output
void process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames) override;

// Ensure that all buffer arrays are the right size for this num_frames
Expand Down