Add LSTMLayer and LSTMUnitLayer, with tests
jeffdonahue committed Feb 16, 2015
1 parent ee44d95 commit c38f9ac
Showing 5 changed files with 901 additions and 0 deletions.
133 changes: 133 additions & 0 deletions include/caffe/sequence_layers.hpp
@@ -149,6 +149,139 @@ class RecurrentLayer : public Layer<Dtype> {
Blob<Dtype>* cont_input_blob_;
};

/**
* @brief Processes sequential inputs using a "Long Short-Term Memory" (LSTM)
 * [1] style recurrent neural network (RNN). Implemented as a network
 * unrolling the LSTM computation in time.
 *
* The specific architecture used in this implementation is as described in
* "Learning to Execute" [2], reproduced below:
* i_t := \sigmoid[ W_{hi} * h_{t-1} + W_{xi} * x_t + b_i ]
* f_t := \sigmoid[ W_{hf} * h_{t-1} + W_{xf} * x_t + b_f ]
* o_t := \sigmoid[ W_{ho} * h_{t-1} + W_{xo} * x_t + b_o ]
* g_t := \tanh[ W_{hg} * h_{t-1} + W_{xg} * x_t + b_g ]
* c_t := (f_t .* c_{t-1}) + (i_t .* g_t)
* h_t := o_t .* \tanh[c_t]
* In the implementation, the i, f, o, and g computations are performed as a
* single inner product.
*
* Notably, this implementation lacks the "diagonal" gates, as used in the
* LSTM architectures described by Alex Graves [3] and others.
*
* [1] Hochreiter, Sepp, and Schmidhuber, Jürgen. "Long short-term memory."
* Neural Computation 9, no. 8 (1997): 1735-1780.
*
* [2] Zaremba, Wojciech, and Sutskever, Ilya. "Learning to execute."
* arXiv preprint arXiv:1410.4615 (2014).
*
* [3] Graves, Alex. "Generating sequences with recurrent neural networks."
* arXiv preprint arXiv:1308.0850 (2013).
*/
template <typename Dtype>
class LSTMLayer : public RecurrentLayer<Dtype> {
public:
explicit LSTMLayer(const LayerParameter& param)
: RecurrentLayer<Dtype>(param) {}

virtual inline const char* type() const { return "LSTM"; }

protected:
virtual void FillUnrolledNet(NetParameter* net_param) const;
virtual void RecurrentInputBlobNames(vector<string>* names) const;
virtual void RecurrentOutputBlobNames(vector<string>* names) const;
virtual void OutputBlobNames(vector<string>* names) const;
};
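
For readers tracing the gate equations in the class comment above, here is a minimal single-example sketch of one LSTM timestep in plain C++. It is illustrative only: the helper names, dense-matrix types, and per-gate weight layout are assumptions, not part of this commit (which fuses the four gate products into a single inner product).

#include <cmath>
#include <vector>

typedef std::vector<double> Vec;
typedef std::vector<Vec> Mat;  // row-major: Mat[row][col]

static double Sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// y = W_h * h + W_x * x + b, with W_h of size D x |h| and W_x of size D x |x|.
static Vec Affine(const Mat& W_h, const Vec& h, const Mat& W_x, const Vec& x,
                  const Vec& b) {
  Vec y(b);
  for (size_t r = 0; r < y.size(); ++r) {
    for (size_t c = 0; c < h.size(); ++c) { y[r] += W_h[r][c] * h[c]; }
    for (size_t c = 0; c < x.size(); ++c) { y[r] += W_x[r][c] * x[c]; }
  }
  return y;
}

// One LSTM step following the class comment: gates i, f, o, g, then c_t, h_t.
void LstmStep(const Vec& x_t, const Vec& h_prev, const Vec& c_prev,
              const Mat& W_hi, const Mat& W_xi, const Vec& b_i,
              const Mat& W_hf, const Mat& W_xf, const Vec& b_f,
              const Mat& W_ho, const Mat& W_xo, const Vec& b_o,
              const Mat& W_hg, const Mat& W_xg, const Vec& b_g,
              Vec* h_t, Vec* c_t) {
  const Vec i_in = Affine(W_hi, h_prev, W_xi, x_t, b_i);
  const Vec f_in = Affine(W_hf, h_prev, W_xf, x_t, b_f);
  const Vec o_in = Affine(W_ho, h_prev, W_xo, x_t, b_o);
  const Vec g_in = Affine(W_hg, h_prev, W_xg, x_t, b_g);
  const size_t D = c_prev.size();
  c_t->resize(D);
  h_t->resize(D);
  for (size_t d = 0; d < D; ++d) {
    const double i = Sigmoid(i_in[d]);
    const double f = Sigmoid(f_in[d]);
    const double o = Sigmoid(o_in[d]);
    const double g = std::tanh(g_in[d]);
    (*c_t)[d] = f * c_prev[d] + i * g;     // c_t := (f_t .* c_{t-1}) + (i_t .* g_t)
    (*h_t)[d] = o * std::tanh((*c_t)[d]);  // h_t := o_t .* tanh(c_t)
  }
}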

/**
* @brief A helper for LSTMLayer: computes a single timestep of the
* non-linearity of the LSTM, producing the updated cell and hidden
* states.
*/
template <typename Dtype>
class LSTMUnitLayer : public Layer<Dtype> {
public:
explicit LSTMUnitLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "LSTMUnit"; }
virtual inline int ExactNumBottomBlobs() const { return 3; }
virtual inline int ExactNumTopBlobs() const { return 2; }

virtual inline bool AllowForceBackward(const int bottom_index) const {
// Can't propagate to sequence continuation indicators.
return bottom_index != 2;
}

protected:
/**
* @param bottom input Blob vector (length 3)
* -# @f$ (1 \times N \times D) @f$
* the previous timestep cell state @f$ c_{t-1} @f$
* -# @f$ (1 \times N \times 4D) @f$
* the "gate inputs" @f$ [i_t', f_t', o_t', g_t'] @f$
* -# @f$ (1 \times 1 \times N) @f$
* the sequence continuation indicators @f$ \delta_t @f$
* @param top output Blob vector (length 2)
* -# @f$ (1 \times N \times D) @f$
* the updated cell state @f$ c_t @f$, computed as:
* i_t := \sigmoid[i_t']
* f_t := \sigmoid[f_t']
* o_t := \sigmoid[o_t']
* g_t := \tanh[g_t']
* c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
* -# @f$ (1 \times N \times D) @f$
* the updated hidden state @f$ h_t @f$, computed as:
* h_t := o_t .* \tanh[c_t]
*/
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

/**
* @brief Computes the error gradient w.r.t. the LSTMUnit inputs.
*
* @param top output Blob vector (length 2), providing the error gradient with
* respect to the outputs
* -# @f$ (1 \times N \times D) @f$:
* containing error gradients @f$ \frac{\partial E}{\partial c_t} @f$
* with respect to the updated cell state @f$ c_t @f$
* -# @f$ (1 \times N \times D) @f$:
* containing error gradients @f$ \frac{\partial E}{\partial h_t} @f$
 * with respect to the updated hidden state @f$ h_t @f$
* @param propagate_down see Layer::Backward.
* @param bottom input Blob vector (length 3), into which the error gradients
* with respect to the LSTMUnit inputs @f$ c_{t-1} @f$ and the gate
 * inputs are computed. Computation of the error gradients w.r.t.
* the sequence indicators is not implemented.
* -# @f$ (1 \times N \times D) @f$
* the error gradient w.r.t. the previous timestep cell state
* @f$ c_{t-1} @f$
* -# @f$ (1 \times N \times 4D) @f$
* the error gradient w.r.t. the "gate inputs"
* @f$ [
* \frac{\partial E}{\partial i_t}
* \frac{\partial E}{\partial f_t}
* \frac{\partial E}{\partial o_t}
* \frac{\partial E}{\partial g_t}
* ] @f$
* -# @f$ (1 \times 1 \times N) @f$
* the gradient w.r.t. the sequence continuation indicators
* @f$ \delta_t @f$ is currently not computed.
*/
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

/// @brief The hidden and output dimension.
int hidden_dim_;
Blob<Dtype> X_acts_;
};
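
To make the blob layout documented above concrete, the following is a rough CPU sketch of the LSTMUnit forward pass for the (1 x N x D) and (1 x N x 4D) shapes it describes. The function name and the assumption that the 4D gate input is laid out as contiguous [i', f', o', g'] blocks per example are mine, for illustration; they are not taken from this commit's kernels.

#include <cmath>

static double Sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// c_prev: N x D, gate_input: N x 4D (assumed [i' | f' | o' | g'] per example),
// cont: N sequence-continuation indicators, outputs c and h: N x D.
void LstmUnitForwardSketch(int N, int D, const double* c_prev,
                           const double* gate_input, const double* cont,
                           double* c, double* h) {
  for (int n = 0; n < N; ++n) {
    const double* g_in = gate_input + n * 4 * D;
    for (int d = 0; d < D; ++d) {
      const double i = Sigmoid(g_in[d]);
      const double f = Sigmoid(g_in[D + d]);
      const double o = Sigmoid(g_in[2 * D + d]);
      const double g = std::tanh(g_in[3 * D + d]);
      const int k = n * D + d;
      // cont_t masks the carried-over cell state at sequence boundaries.
      c[k] = cont[n] * (f * c_prev[k]) + i * g;
      h[k] = o * std::tanh(c[k]);
    }
  }
}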

/**
* @brief Processes time-varying inputs using a simple recurrent neural network
* (RNN). Implemented as a network unrolling the RNN computation in time.
221 changes: 221 additions & 0 deletions src/caffe/layers/lstm_layer.cpp
@@ -0,0 +1,221 @@
#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/sequence_layers.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void LSTMLayer<Dtype>::RecurrentInputBlobNames(vector<string>* names) const {
names->resize(2);
(*names)[0] = "h_0";
(*names)[1] = "c_0";
}

template <typename Dtype>
void LSTMLayer<Dtype>::RecurrentOutputBlobNames(vector<string>* names) const {
names->resize(2);
(*names)[0] = "h_" + this->int_to_str(this->T_);
(*names)[1] = "c_T";
}

template <typename Dtype>
void LSTMLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
names->resize(1);
(*names)[0] = "h";
}

template <typename Dtype>
void LSTMLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
const int num_output = this->layer_param_.recurrent_param().num_output();
CHECK_GT(num_output, 0) << "num_output must be positive";
const FillerParameter& weight_filler =
this->layer_param_.recurrent_param().weight_filler();
const FillerParameter& bias_filler =
this->layer_param_.recurrent_param().bias_filler();

// Add generic LayerParameter's (without bottoms/tops) of layer types we'll
// use to save redundant code.
LayerParameter hidden_param;
hidden_param.set_type("InnerProduct");
hidden_param.mutable_inner_product_param()->set_num_output(num_output * 4);
hidden_param.mutable_inner_product_param()->set_bias_term(false);
hidden_param.mutable_inner_product_param()->set_axis(2);
hidden_param.mutable_inner_product_param()->
mutable_weight_filler()->CopyFrom(weight_filler);

LayerParameter biased_hidden_param(hidden_param);
biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
biased_hidden_param.mutable_inner_product_param()->
mutable_bias_filler()->CopyFrom(bias_filler);

LayerParameter sum_param;
sum_param.set_type("Eltwise");
sum_param.mutable_eltwise_param()->set_operation(
EltwiseParameter_EltwiseOp_SUM);

LayerParameter slice_param;
slice_param.set_type("Slice");
slice_param.mutable_slice_param()->set_axis(0);

LayerParameter split_param;
split_param.set_type("Split");

BlobShape input_shape;
input_shape.add_dim(1); // c_0 and h_0 are a single timestep
input_shape.add_dim(this->N_);
input_shape.add_dim(num_output);

net_param->add_input("c_0");
net_param->add_input_shape()->CopyFrom(input_shape);

net_param->add_input("h_0");
net_param->add_input_shape()->CopyFrom(input_shape);

LayerParameter* cont_slice_param = net_param->add_layer();
cont_slice_param->CopyFrom(slice_param);
cont_slice_param->set_name("cont_slice");
cont_slice_param->add_bottom("cont");
cont_slice_param->mutable_slice_param()->set_axis(1);

// Add layer to transform all timesteps of x to the hidden state dimension.
// W_xc_x = W_xc * x + b_c
{
LayerParameter* x_transform_param = net_param->add_layer();
x_transform_param->CopyFrom(biased_hidden_param);
x_transform_param->set_name("x_transform");
x_transform_param->add_param()->set_name("W_xc");
x_transform_param->add_param()->set_name("b_c");
x_transform_param->add_bottom("x");
x_transform_param->add_top("W_xc_x");
}

if (this->static_input_) {
// Add layer to transform x_static to the gate dimension.
// W_xc_x_static = W_xc_static * x_static
LayerParameter* x_static_transform_param = net_param->add_layer();
x_static_transform_param->CopyFrom(hidden_param);
x_static_transform_param->mutable_inner_product_param()->set_axis(1);
x_static_transform_param->set_name("W_xc_x_static");
x_static_transform_param->add_param()->set_name("W_xc_static");
x_static_transform_param->add_bottom("x_static");
x_static_transform_param->add_top("W_xc_x_static");

LayerParameter* reshape_param = net_param->add_layer();
reshape_param->set_type("Reshape");
BlobShape* new_shape =
reshape_param->mutable_reshape_param()->mutable_shape();
new_shape->add_dim(1); // One timestep.
new_shape->add_dim(this->N_);
new_shape->add_dim(
x_static_transform_param->inner_product_param().num_output());
reshape_param->add_bottom("W_xc_x_static");
reshape_param->add_top("W_xc_x_static");
}

LayerParameter* x_slice_param = net_param->add_layer();
x_slice_param->CopyFrom(slice_param);
x_slice_param->add_bottom("W_xc_x");
x_slice_param->set_name("W_xc_x_slice");

LayerParameter output_concat_layer;
output_concat_layer.set_name("h_concat");
output_concat_layer.set_type("Concat");
output_concat_layer.add_top("h");
output_concat_layer.mutable_concat_param()->set_axis(0);

for (int t = 1; t <= this->T_; ++t) {
string tm1s = this->int_to_str(t - 1);
string ts = this->int_to_str(t);

cont_slice_param->add_top("cont_" + ts);
x_slice_param->add_top("W_xc_x_" + ts);

// Add layers to flush the hidden state when beginning a new
// sequence, as indicated by cont_t.
// h_conted_{t-1} := cont_t * h_{t-1}
//
// Normally, cont_t is binary (i.e., 0 or 1), so:
// h_conted_{t-1} := h_{t-1} if cont_t == 1
// 0 otherwise
{
LayerParameter* cont_h_param = net_param->add_layer();
cont_h_param->CopyFrom(sum_param);
cont_h_param->mutable_eltwise_param()->set_coeff_blob(true);
cont_h_param->set_name("h_conted_" + tm1s);
cont_h_param->add_bottom("h_" + tm1s);
cont_h_param->add_bottom("cont_" + ts);
cont_h_param->add_top("h_conted_" + tm1s);
}

// Add layer to compute
// W_hc_h_{t-1} := W_hc * h_conted_{t-1}
{
LayerParameter* w_param = net_param->add_layer();
w_param->CopyFrom(hidden_param);
w_param->set_name("transform_" + ts);
w_param->add_param()->set_name("W_hc");
w_param->add_bottom("h_conted_" + tm1s);
w_param->add_top("W_hc_h_" + tm1s);
w_param->mutable_inner_product_param()->set_axis(2);
}

// Add the outputs of the linear transformations to compute the gate input.
// gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
// = W_hc_h_{t-1} + W_xc_x_t + b_c
{
LayerParameter* input_sum_layer = net_param->add_layer();
input_sum_layer->CopyFrom(sum_param);
input_sum_layer->set_name("gate_input_" + ts);
input_sum_layer->add_bottom("W_hc_h_" + tm1s);
input_sum_layer->add_bottom("W_xc_x_" + ts);
if (this->static_input_) {
input_sum_layer->add_bottom("W_xc_x_static");
}
input_sum_layer->add_top("gate_input_" + ts);
}

// Add LSTMUnit layer to compute the cell & hidden vectors c_t and h_t.
// Inputs: c_{t-1}, gate_input_t = (i_t, f_t, o_t, g_t), cont_t
// Outputs: c_t, h_t
// [ i_t' ]
// [ f_t' ] := gate_input_t
// [ o_t' ]
// [ g_t' ]
// i_t := \sigmoid[i_t']
// f_t := \sigmoid[f_t']
// o_t := \sigmoid[o_t']
// g_t := \tanh[g_t']
// c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
// h_t := o_t .* \tanh[c_t]
{
LayerParameter* lstm_unit_param = net_param->add_layer();
lstm_unit_param->set_type("LSTMUnit");
lstm_unit_param->add_bottom("c_" + tm1s);
lstm_unit_param->add_bottom("gate_input_" + ts);
lstm_unit_param->add_bottom("cont_" + ts);
lstm_unit_param->add_top("c_" + ts);
lstm_unit_param->add_top("h_" + ts);
lstm_unit_param->set_name("unit_" + ts);
}
output_concat_layer.add_bottom("h_" + ts);
} // for (int t = 1; t <= this->T_; ++t)

{
LayerParameter* c_T_copy_param = net_param->add_layer();
c_T_copy_param->CopyFrom(split_param);
c_T_copy_param->add_bottom("c_" + this->int_to_str(this->T_));
c_T_copy_param->add_top("c_T");
}
net_param->add_layer()->CopyFrom(output_concat_layer);
}

INSTANTIATE_CLASS(LSTMLayer);
REGISTER_LAYER_CLASS(LSTM);

} // namespace caffe
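
As a sanity check of the wiring produced by the timestep loop in FillUnrolledNet above, a small standalone program (mine, not part of this commit) can print the per-timestep layer names and their bottom/top blobs for a couple of unrolled steps; the blob and layer names below follow the loop, while the "name: bottoms -> tops" print format is an assumption for readability.

#include <iostream>
#include <sstream>
#include <string>

int main() {
  const int T = 2;  // number of unrolled timesteps, for illustration
  for (int t = 1; t <= T; ++t) {
    std::ostringstream prev, cur;
    prev << (t - 1);
    cur << t;
    const std::string tm1s = prev.str();
    const std::string ts = cur.str();
    std::cout
        << "h_conted_" << tm1s << ": h_" << tm1s << ", cont_" << ts
        << " -> h_conted_" << tm1s << "\n"
        << "transform_" << ts << ": h_conted_" << tm1s
        << " -> W_hc_h_" << tm1s << "\n"
        << "gate_input_" << ts << ": W_hc_h_" << tm1s << ", W_xc_x_" << ts
        << " -> gate_input_" << ts << "\n"
        << "unit_" << ts << ": c_" << tm1s << ", gate_input_" << ts
        << ", cont_" << ts << " -> c_" << ts << ", h_" << ts << "\n";
  }
  return 0;
}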
