Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define nam::DSP::Reset and nam::DSP::ResetAndPrewarm #111

Merged
merged 6 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Define DSP::Reset() and DSP::Warmup()
  • Loading branch information
sdatkinson committed Sep 8, 2024
commit 27303ad4b7a1605b5f55296be94013f4ce70165f
8 changes: 8 additions & 0 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ double nam::DSP::GetLoudness() const
return mLoudness;
}

// Record the processing configuration handed to this DSP unit.
// The base implementation only stores the values and marks the external sample
// rate as known; it performs no validation. Some subclasses might want to throw
// an exception if the sample rate is "wrong" (potentially only under a
// debugging flag).
void nam::DSP::Reset(const double sampleRate, const int maxBufferSize)
{
  mMaxBufferSize = maxBufferSize;
  mHaveExternalSampleRate = true;
  mExternalSampleRate = sampleRate;
}
void nam::DSP::SetLoudness(const double loudness)
{
mLoudness = loudness;
Expand Down
17 changes: 17 additions & 0 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,16 @@ class DSP
double GetLoudness() const;
// Get whether the model knows how loud it is.
// Returns the current value of mHasLoudness.
bool HasLoudness() const { return mHasLoudness; }
// General function for resetting the DSP unit.
// NAMs might do warmup under this.
//   sampleRate:    the external sample rate; the base implementation stores it
//                  and flags it as known.
//   maxBufferSize: the largest buffer this unit should expect to be told to
//                  process; the base implementation stores it.
// Virtual, so subclasses can override it.
virtual void Reset(const double sampleRate, const int maxBufferSize);
// Set the loudness, in dB.
// This is usually defined as the loudness in response to a standardized input. The trainer has its own definition,
// but you can always use this to define it a different way if you like yours better.
void SetLoudness(const double loudness);
// Run some zeroes through the DSP unit until it's ready.
// This is helpful for things that have a history dependence (LSTMs, convolutions, etc.).
// NOTE(review): presumably the amount processed is governed by WarmupSamples() - confirm against the definition.
void Warmup();

protected:
bool mHasLoudness = false;
Expand All @@ -77,6 +83,14 @@ class DSP
double mExpectedSampleRate;
// How many samples should be processed during "pre-warming"
int _prewarm_samples = 0;
// Have we been told what the external sample rate is? If so, what is it?
// Reset() sets both fields; until then the rate holds the -1.0 placeholder.
bool mHaveExternalSampleRate = false;
double mExternalSampleRate = -1.0;
// The largest buffer I expect to be told to process.
// Defaults to 0 so that reading it before Reset() has been called never
// observes an indeterminate value.
int mMaxBufferSize = 0;

// How many samples should be processed for me to be considered "warmed up"?
// The base implementation reports 0; being virtual, subclasses may override it.
virtual int WarmupSamples() { return 0; }
};

// Class where an input buffer is kept so that long-time effects can be
Expand All @@ -87,6 +101,8 @@ class Buffer : public DSP
public:
Buffer(const int receptive_field, const double expected_sample_rate = -1.0);
void finalize_(const int num_frames);
// TODO Could set buffer sizes etc
// virtual void Reset(const double sampleRate, const int maxBufferSize) override;

protected:
// Input buffer
Expand Down Expand Up @@ -120,6 +136,7 @@ class Linear : public Buffer

// NN modules =================================================================

// TODO conv could take care of its own ring buffer.
class Conv1D
{
public:
Expand Down
4 changes: 2 additions & 2 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ void nam::wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::M
// Mix-in condition
this->_z += this->_input_mixin.process(condition);



if (!this->_gated)
{
Expand All @@ -40,7 +39,8 @@ void nam::wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::M
{
this->_activation->apply(this->_z.topRows(channels));
activations::Activation::get_activation("Sigmoid")->apply(this->_z.bottomRows(channels));
//activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, 0, channels, this->_z.cols()));
// activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, 0, channels,
// this->_z.cols()));

this->_z.topRows(channels).array() *= this->_z.bottomRows(channels).array();
// this->_z.topRows(channels) = this->_z.topRows(channels).cwiseProduct(
Expand Down