From c38f9ac750a3fde17082fdef122a90c09827181b Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Sun, 15 Feb 2015 14:57:22 -0800 Subject: [PATCH] Add LSTMLayer and LSTMUnitLayer, with tests --- include/caffe/sequence_layers.hpp | 133 ++++++++++++++ src/caffe/layers/lstm_layer.cpp | 221 ++++++++++++++++++++++ src/caffe/layers/lstm_unit_layer.cpp | 128 +++++++++++++ src/caffe/layers/lstm_unit_layer.cu | 154 ++++++++++++++++ src/caffe/test/test_lstm_layer.cpp | 265 +++++++++++++++++++++++++++ 5 files changed, 901 insertions(+) create mode 100644 src/caffe/layers/lstm_layer.cpp create mode 100644 src/caffe/layers/lstm_unit_layer.cpp create mode 100644 src/caffe/layers/lstm_unit_layer.cu create mode 100644 src/caffe/test/test_lstm_layer.cpp diff --git a/include/caffe/sequence_layers.hpp b/include/caffe/sequence_layers.hpp index 708f6380d93..8ac735435a4 100644 --- a/include/caffe/sequence_layers.hpp +++ b/include/caffe/sequence_layers.hpp @@ -149,6 +149,139 @@ class RecurrentLayer : public Layer { Blob* cont_input_blob_; }; +/** + * @brief Processes sequential inputs using a "Long Short-Term Memory" (LSTM) + * [1] style recurrent neural network (RNN). Implemented as a network + * unrolling the LSTM computation in time. + * + * + * The specific architecture used in this implementation is as described in + * "Learning to Execute" [2], reproduced below: + * i_t := \sigmoid[ W_{hi} * h_{t-1} + W_{xi} * x_t + b_i ] + * f_t := \sigmoid[ W_{hf} * h_{t-1} + W_{xf} * x_t + b_f ] + * o_t := \sigmoid[ W_{ho} * h_{t-1} + W_{xo} * x_t + b_o ] + * g_t := \tanh[ W_{hg} * h_{t-1} + W_{xg} * x_t + b_g ] + * c_t := (f_t .* c_{t-1}) + (i_t .* g_t) + * h_t := o_t .* \tanh[c_t] + * In the implementation, the i, f, o, and g computations are performed as a + * single inner product. + * + * Notably, this implementation lacks the "diagonal" gates, as used in the + * LSTM architectures described by Alex Graves [3] and others. + * + * [1] Hochreiter, Sepp, and Schmidhuber, Jürgen.
"Long short-term memory." + * Neural Computation 9, no. 8 (1997): 1735-1780. + * + * [2] Zaremba, Wojciech, and Sutskever, Ilya. "Learning to execute." + * arXiv preprint arXiv:1410.4615 (2014). + * + * [3] Graves, Alex. "Generating sequences with recurrent neural networks." + * arXiv preprint arXiv:1308.0850 (2013). + */ +template +class LSTMLayer : public RecurrentLayer { + public: + explicit LSTMLayer(const LayerParameter& param) + : RecurrentLayer(param) {} + + virtual inline const char* type() const { return "LSTM"; } + + protected: + virtual void FillUnrolledNet(NetParameter* net_param) const; + virtual void RecurrentInputBlobNames(vector* names) const; + virtual void RecurrentOutputBlobNames(vector* names) const; + virtual void OutputBlobNames(vector* names) const; +}; + +/** + * @brief A helper for LSTMLayer: computes a single timestep of the + * non-linearity of the LSTM, producing the updated cell and hidden + * states. + */ +template +class LSTMUnitLayer : public Layer { + public: + explicit LSTMUnitLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "LSTMUnit"; } + virtual inline int ExactNumBottomBlobs() const { return 3; } + virtual inline int ExactNumTopBlobs() const { return 2; } + + virtual inline bool AllowForceBackward(const int bottom_index) const { + // Can't propagate to sequence continuation indicators. 
+ return bottom_index != 2; + } + + protected: + /** + * @param bottom input Blob vector (length 3) + * -# @f$ (1 \times N \times D) @f$ + * the previous timestep cell state @f$ c_{t-1} @f$ + * -# @f$ (1 \times N \times 4D) @f$ + * the "gate inputs" @f$ [i_t', f_t', o_t', g_t'] @f$ + * -# @f$ (1 \times 1 \times N) @f$ + * the sequence continuation indicators @f$ \delta_t @f$ + * @param top output Blob vector (length 2) + * -# @f$ (1 \times N \times D) @f$ + * the updated cell state @f$ c_t @f$, computed as: + * i_t := \sigmoid[i_t'] + * f_t := \sigmoid[f_t'] + * o_t := \sigmoid[o_t'] + * g_t := \tanh[g_t'] + * c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t) + * -# @f$ (1 \times N \times D) @f$ + * the updated hidden state @f$ h_t @f$, computed as: + * h_t := o_t .* \tanh[c_t] + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the error gradient w.r.t. the LSTMUnit inputs. + * + * @param top output Blob vector (length 2), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times N \times D) @f$: + * containing error gradients @f$ \frac{\partial E}{\partial c_t} @f$ + * with respect to the updated cell state @f$ c_t @f$ + * -# @f$ (1 \times N \times D) @f$: + * containing error gradients @f$ \frac{\partial E}{\partial h_t} @f$ + * with respect to the updated hidden state @f$ h_t @f$ + * @param propagate_down see Layer::Backward. + * @param bottom input Blob vector (length 3), into which the error gradients + * with respect to the LSTMUnit inputs @f$ c_{t-1} @f$ and the gate + * inputs are computed. Computation of the error gradients w.r.t. + * the sequence indicators is not implemented. + * -# @f$ (1 \times N \times D) @f$ + * the error gradient w.r.t. the previous timestep cell state + * @f$ c_{t-1} @f$ + * -# @f$ (1 \times N \times 4D) @f$ + * the error gradient w.r.t.
the "gate inputs" + * @f$ [ + * \frac{\partial E}{\partial i_t} + * \frac{\partial E}{\partial f_t} + * \frac{\partial E}{\partial o_t} + * \frac{\partial E}{\partial g_t} + * ] @f$ + * -# @f$ (1 \times 1 \times N) @f$ + * the gradient w.r.t. the sequence continuation indicators + * @f$ \delta_t @f$ is currently not computed. + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// @brief The hidden and output dimension. + int hidden_dim_; + Blob X_acts_; +}; + /** * @brief Processes time-varying inputs using a simple recurrent neural network * (RNN). Implemented as a network unrolling the RNN computation in time. diff --git a/src/caffe/layers/lstm_layer.cpp b/src/caffe/layers/lstm_layer.cpp new file mode 100644 index 00000000000..91543f73f71 --- /dev/null +++ b/src/caffe/layers/lstm_layer.cpp @@ -0,0 +1,221 @@ +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/sequence_layers.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void LSTMLayer::RecurrentInputBlobNames(vector* names) const { + names->resize(2); + (*names)[0] = "h_0"; + (*names)[1] = "c_0"; +} + +template +void LSTMLayer::RecurrentOutputBlobNames(vector* names) const { + names->resize(2); + (*names)[0] = "h_" + this->int_to_str(this->T_); + (*names)[1] = "c_T"; +} + +template +void LSTMLayer::OutputBlobNames(vector* names) const { + names->resize(1); + (*names)[0] = "h"; +} + +template +void LSTMLayer::FillUnrolledNet(NetParameter* net_param) const { + const int num_output = this->layer_param_.recurrent_param().num_output(); + CHECK_GT(num_output, 0) << "num_output must be positive"; + const FillerParameter& weight_filler = + this->layer_param_.recurrent_param().weight_filler(); + const FillerParameter& 
bias_filler = + this->layer_param_.recurrent_param().bias_filler(); + + // Add generic LayerParameter's (without bottoms/tops) of layer types we'll + // use to save redundant code. + LayerParameter hidden_param; + hidden_param.set_type("InnerProduct"); + hidden_param.mutable_inner_product_param()->set_num_output(num_output * 4); + hidden_param.mutable_inner_product_param()->set_bias_term(false); + hidden_param.mutable_inner_product_param()->set_axis(2); + hidden_param.mutable_inner_product_param()-> + mutable_weight_filler()->CopyFrom(weight_filler); + + LayerParameter biased_hidden_param(hidden_param); + biased_hidden_param.mutable_inner_product_param()->set_bias_term(true); + biased_hidden_param.mutable_inner_product_param()-> + mutable_bias_filler()->CopyFrom(bias_filler); + + LayerParameter sum_param; + sum_param.set_type("Eltwise"); + sum_param.mutable_eltwise_param()->set_operation( + EltwiseParameter_EltwiseOp_SUM); + + LayerParameter slice_param; + slice_param.set_type("Slice"); + slice_param.mutable_slice_param()->set_axis(0); + + LayerParameter split_param; + split_param.set_type("Split"); + + BlobShape input_shape; + input_shape.add_dim(1); // c_0 and h_0 are a single timestep + input_shape.add_dim(this->N_); + input_shape.add_dim(num_output); + + net_param->add_input("c_0"); + net_param->add_input_shape()->CopyFrom(input_shape); + + net_param->add_input("h_0"); + net_param->add_input_shape()->CopyFrom(input_shape); + + LayerParameter* cont_slice_param = net_param->add_layer(); + cont_slice_param->CopyFrom(slice_param); + cont_slice_param->set_name("cont_slice"); + cont_slice_param->add_bottom("cont"); + cont_slice_param->mutable_slice_param()->set_axis(1); + + // Add layer to transform all timesteps of x to the hidden state dimension. 
+ // W_xc_x = W_xc * x + b_c + { + LayerParameter* x_transform_param = net_param->add_layer(); + x_transform_param->CopyFrom(biased_hidden_param); + x_transform_param->set_name("x_transform"); + x_transform_param->add_param()->set_name("W_xc"); + x_transform_param->add_param()->set_name("b_c"); + x_transform_param->add_bottom("x"); + x_transform_param->add_top("W_xc_x"); + } + + if (this->static_input_) { + // Add layer to transform x_static to the gate dimension. + // W_xc_x_static = W_xc_static * x_static + LayerParameter* x_static_transform_param = net_param->add_layer(); + x_static_transform_param->CopyFrom(hidden_param); + x_static_transform_param->mutable_inner_product_param()->set_axis(1); + x_static_transform_param->set_name("W_xc_x_static"); + x_static_transform_param->add_param()->set_name("W_xc_static"); + x_static_transform_param->add_bottom("x_static"); + x_static_transform_param->add_top("W_xc_x_static"); + + LayerParameter* reshape_param = net_param->add_layer(); + reshape_param->set_type("Reshape"); + BlobShape* new_shape = + reshape_param->mutable_reshape_param()->mutable_shape(); + new_shape->add_dim(1); // One timestep. 
+ new_shape->add_dim(this->N_); + new_shape->add_dim( + x_static_transform_param->inner_product_param().num_output()); + reshape_param->add_bottom("W_xc_x_static"); + reshape_param->add_top("W_xc_x_static"); + } + + LayerParameter* x_slice_param = net_param->add_layer(); + x_slice_param->CopyFrom(slice_param); + x_slice_param->add_bottom("W_xc_x"); + x_slice_param->set_name("W_xc_x_slice"); + + LayerParameter output_concat_layer; + output_concat_layer.set_name("h_concat"); + output_concat_layer.set_type("Concat"); + output_concat_layer.add_top("h"); + output_concat_layer.mutable_concat_param()->set_axis(0); + + for (int t = 1; t <= this->T_; ++t) { + string tm1s = this->int_to_str(t - 1); + string ts = this->int_to_str(t); + + cont_slice_param->add_top("cont_" + ts); + x_slice_param->add_top("W_xc_x_" + ts); + + // Add layers to flush the hidden state when beginning a new + // sequence, as indicated by cont_t. + // h_conted_{t-1} := cont_t * h_{t-1} + // + // Normally, cont_t is binary (i.e., 0 or 1), so: + // h_conted_{t-1} := h_{t-1} if cont_t == 1 + // 0 otherwise + { + LayerParameter* cont_h_param = net_param->add_layer(); + cont_h_param->CopyFrom(sum_param); + cont_h_param->mutable_eltwise_param()->set_coeff_blob(true); + cont_h_param->set_name("h_conted_" + tm1s); + cont_h_param->add_bottom("h_" + tm1s); + cont_h_param->add_bottom("cont_" + ts); + cont_h_param->add_top("h_conted_" + tm1s); + } + + // Add layer to compute + // W_hc_h_{t-1} := W_hc * h_conted_{t-1} + { + LayerParameter* w_param = net_param->add_layer(); + w_param->CopyFrom(hidden_param); + w_param->set_name("transform_" + ts); + w_param->add_param()->set_name("W_hc"); + w_param->add_bottom("h_conted_" + tm1s); + w_param->add_top("W_hc_h_" + tm1s); + w_param->mutable_inner_product_param()->set_axis(2); + } + + // Add the outputs of the linear transformations to compute the gate input. 
+ // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + // = W_hc_h_{t-1} + W_xc_x_t + b_c + { + LayerParameter* input_sum_layer = net_param->add_layer(); + input_sum_layer->CopyFrom(sum_param); + input_sum_layer->set_name("gate_input_" + ts); + input_sum_layer->add_bottom("W_hc_h_" + tm1s); + input_sum_layer->add_bottom("W_xc_x_" + ts); + if (this->static_input_) { + input_sum_layer->add_bottom("W_xc_x_static"); + } + input_sum_layer->add_top("gate_input_" + ts); + } + + // Add LSTMUnit layer to compute the cell & hidden vectors c_t and h_t. + // Inputs: c_{t-1}, gate_input_t = (i_t, f_t, o_t, g_t), cont_t + // Outputs: c_t, h_t + // [ i_t' ] + // [ f_t' ] := gate_input_t + // [ o_t' ] + // [ g_t' ] + // i_t := \sigmoid[i_t'] + // f_t := \sigmoid[f_t'] + // o_t := \sigmoid[o_t'] + // g_t := \tanh[g_t'] + // c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t) + // h_t := o_t .* \tanh[c_t] + { + LayerParameter* lstm_unit_param = net_param->add_layer(); + lstm_unit_param->set_type("LSTMUnit"); + lstm_unit_param->add_bottom("c_" + tm1s); + lstm_unit_param->add_bottom("gate_input_" + ts); + lstm_unit_param->add_bottom("cont_" + ts); + lstm_unit_param->add_top("c_" + ts); + lstm_unit_param->add_top("h_" + ts); + lstm_unit_param->set_name("unit_" + ts); + } + output_concat_layer.add_bottom("h_" + ts); + } // for (int t = 1; t <= this->T_; ++t) + + { + LayerParameter* c_T_copy_param = net_param->add_layer(); + c_T_copy_param->CopyFrom(split_param); + c_T_copy_param->add_bottom("c_" + this->int_to_str(this->T_)); + c_T_copy_param->add_top("c_T"); + } + net_param->add_layer()->CopyFrom(output_concat_layer); +} + +INSTANTIATE_CLASS(LSTMLayer); +REGISTER_LAYER_CLASS(LSTM); + +} // namespace caffe diff --git a/src/caffe/layers/lstm_unit_layer.cpp b/src/caffe/layers/lstm_unit_layer.cpp new file mode 100644 index 00000000000..74078d264f5 --- /dev/null +++ b/src/caffe/layers/lstm_unit_layer.cpp @@ -0,0 +1,128 @@ +#include +#include +#include + +#include "caffe/layer.hpp" 
+#include "caffe/sequence_layers.hpp" + +namespace caffe { + +template +inline Dtype sigmoid(Dtype x) { + return 1. / (1. + exp(-x)); +} + +template +inline Dtype tanh(Dtype x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +void LSTMUnitLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + for (int i = 0; i < bottom.size(); ++i) { + CHECK_EQ(3, bottom[i]->num_axes()); + CHECK_EQ(1, bottom[i]->shape(0)); + } + const int num_instances = bottom[0]->shape(1); + hidden_dim_ = bottom[0]->shape(2); + CHECK_EQ(num_instances, bottom[1]->shape(1)); + CHECK_EQ(4 * hidden_dim_, bottom[1]->shape(2)); + CHECK_EQ(1, bottom[2]->shape(1)); + CHECK_EQ(num_instances, bottom[2]->shape(2)); + top[0]->ReshapeLike(*bottom[0]); + top[1]->ReshapeLike(*bottom[0]); + X_acts_.ReshapeLike(*bottom[1]); +} + +template +void LSTMUnitLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const int num = bottom[0]->shape(1); + const int x_dim = hidden_dim_ * 4; + const Dtype* C_prev = bottom[0]->cpu_data(); + const Dtype* X = bottom[1]->cpu_data(); + const Dtype* flush = bottom[2]->cpu_data(); + Dtype* C = top[0]->mutable_cpu_data(); + Dtype* H = top[1]->mutable_cpu_data(); + for (int n = 0; n < num; ++n) { + for (int d = 0; d < hidden_dim_; ++d) { + const Dtype i = sigmoid(X[d]); + const Dtype f = (*flush == 0) ? 
0 : + (*flush * sigmoid(X[1 * hidden_dim_ + d])); + const Dtype o = sigmoid(X[2 * hidden_dim_ + d]); + const Dtype g = tanh(X[3 * hidden_dim_ + d]); + const Dtype c_prev = C_prev[d]; + const Dtype c = f * c_prev + i * g; + C[d] = c; + const Dtype tanh_c = tanh(c); + H[d] = o * tanh_c; + } + C_prev += hidden_dim_; + X += x_dim; + C += hidden_dim_; + H += hidden_dim_; + ++flush; + } +} + +template +void LSTMUnitLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + CHECK(!propagate_down[2]) << "Cannot backpropagate to sequence indicators."; + if (!propagate_down[0] && !propagate_down[1]) { return; } + + const int num = bottom[0]->shape(1); + const int x_dim = hidden_dim_ * 4; + const Dtype* C_prev = bottom[0]->cpu_data(); + const Dtype* X = bottom[1]->cpu_data(); + const Dtype* flush = bottom[2]->cpu_data(); + const Dtype* C = top[0]->cpu_data(); + const Dtype* H = top[1]->cpu_data(); + const Dtype* C_diff = top[0]->cpu_diff(); + const Dtype* H_diff = top[1]->cpu_diff(); + Dtype* C_prev_diff = bottom[0]->mutable_cpu_diff(); + Dtype* X_diff = bottom[1]->mutable_cpu_diff(); + for (int n = 0; n < num; ++n) { + for (int d = 0; d < hidden_dim_; ++d) { + const Dtype i = sigmoid(X[d]); + const Dtype f = (*flush == 0) ? 
0 : + (*flush * sigmoid(X[1 * hidden_dim_ + d])); + const Dtype o = sigmoid(X[2 * hidden_dim_ + d]); + const Dtype g = tanh(X[3 * hidden_dim_ + d]); + const Dtype c_prev = C_prev[d]; + const Dtype c = C[d]; + const Dtype tanh_c = tanh(c); + Dtype* c_prev_diff = C_prev_diff + d; + Dtype* i_diff = X_diff + d; + Dtype* f_diff = X_diff + 1 * hidden_dim_ + d; + Dtype* o_diff = X_diff + 2 * hidden_dim_ + d; + Dtype* g_diff = X_diff + 3 * hidden_dim_ + d; + const Dtype c_term_diff = + C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - sigmoid(X[hidden_dim_ + d])); + *o_diff = H_diff[d] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } + C_prev += hidden_dim_; + X += x_dim; + C += hidden_dim_; + H += hidden_dim_; + C_diff += hidden_dim_; + H_diff += hidden_dim_; + X_diff += x_dim; + C_prev_diff += hidden_dim_; + ++flush; + } +} + +#ifdef CPU_ONLY +STUB_GPU(LSTMUnitLayer); +#endif + +INSTANTIATE_CLASS(LSTMUnitLayer); +REGISTER_LAYER_CLASS(LSTMUnit); + +} // namespace caffe diff --git a/src/caffe/layers/lstm_unit_layer.cu b/src/caffe/layers/lstm_unit_layer.cu new file mode 100644 index 00000000000..d6bf85071f5 --- /dev/null +++ b/src/caffe/layers/lstm_unit_layer.cu @@ -0,0 +1,154 @@ +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/sequence_layers.hpp" + +namespace caffe { + +template +__device__ Dtype sigmoid(const Dtype x) { + return Dtype(1) / (Dtype(1) + exp(-x)); +} + +template +__device__ Dtype tanh(const Dtype x) { + return Dtype(2) * sigmoid(Dtype(2) * x) - Dtype(1); +} + +template +__global__ void LSTMActsForward(const int nthreads, const int dim, + const Dtype* X, Dtype* X_acts) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int x_dim = 4 * dim; + const int d = index % x_dim; + if (d < 3 * dim) { + X_acts[index] = sigmoid(X[index]); + } else { + X_acts[index] = tanh(X[index]); + } + } +} + +template +__global__ void
LSTMUnitForward(const int nthreads, const int dim, + const Dtype* C_prev, const Dtype* X, const Dtype* flush, + Dtype* C, Dtype* H) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + const Dtype* X_offset = X + 4 * dim * n; + const Dtype i = X_offset[d]; + const Dtype f = X_offset[1 * dim + d]; + const Dtype o = X_offset[2 * dim + d]; + const Dtype g = X_offset[3 * dim + d]; + const Dtype c_prev = C_prev[index]; + const Dtype c = flush[n] * f * c_prev + i * g; + C[index] = c; + const Dtype tanh_c = tanh(c); + H[index] = o * tanh_c; + } +} + +template +void LSTMUnitLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const int count = top[1]->count(); + const Dtype* C_prev = bottom[0]->gpu_data(); + const Dtype* X = bottom[1]->gpu_data(); + const Dtype* flush = bottom[2]->gpu_data(); + Dtype* X_acts = X_acts_.mutable_gpu_data(); + Dtype* C = top[0]->mutable_gpu_data(); + Dtype* H = top[1]->mutable_gpu_data(); + const int X_count = bottom[1]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + LSTMActsForward<<>>( + X_count, hidden_dim_, X, X_acts); + CUDA_POST_KERNEL_CHECK; + // NOLINT_NEXT_LINE(whitespace/operators) + LSTMUnitForward<<>>( + count, hidden_dim_, C_prev, X_acts, flush, C, H); + CUDA_POST_KERNEL_CHECK; +} + +template +__global__ void LSTMUnitBackward(const int nthreads, const int dim, + const Dtype* C_prev, const Dtype* X, const Dtype* C, const Dtype* H, + const Dtype* flush, const Dtype* C_diff, const Dtype* H_diff, + Dtype* C_prev_diff, Dtype* X_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + const Dtype* X_offset = X + 4 * dim * n; + const Dtype i = X_offset[d]; + const Dtype f = X_offset[1 * dim + d]; + const Dtype o = X_offset[2 * dim + d]; + const Dtype g = X_offset[3 * dim + d]; + const Dtype c_prev = C_prev[index]; + const Dtype c = C[index]; + const Dtype tanh_c = tanh(c); + Dtype* c_prev_diff = C_prev_diff + index; + 
Dtype* X_diff_offset = X_diff + 4 * dim * n; + Dtype* i_diff = X_diff_offset + d; + Dtype* f_diff = X_diff_offset + 1 * dim + d; + Dtype* o_diff = X_diff_offset + 2 * dim + d; + Dtype* g_diff = X_diff_offset + 3 * dim + d; + const Dtype c_term_diff = + C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c); + const Dtype flush_n = flush[n]; + *c_prev_diff = flush_n * c_term_diff * f; + *i_diff = c_term_diff * g; + *f_diff = flush_n * c_term_diff * c_prev; + *o_diff = H_diff[index] * tanh_c; + *g_diff = c_term_diff * i; + } +} + +template +__global__ void LSTMActsBackward(const int nthreads, const int dim, + const Dtype* X_acts, const Dtype* X_acts_diff, Dtype* X_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int x_dim = 4 * dim; + const int d = index % x_dim; + const Dtype X_act = X_acts[index]; + if (d < 3 * dim) { + X_diff[index] = X_acts_diff[index] * X_act * (Dtype(1) - X_act); + } else { + X_diff[index] = X_acts_diff[index] * (Dtype(1) - X_act * X_act); + } + } +} + +template +void LSTMUnitLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + CHECK(!propagate_down[2]) << "Cannot backpropagate to sequence indicators."; + if (!propagate_down[0] && !propagate_down[1]) { return; } + + const int count = top[1]->count(); + const Dtype* C_prev = bottom[0]->gpu_data(); + const Dtype* X_acts = X_acts_.gpu_data(); + const Dtype* flush = bottom[2]->gpu_data(); + const Dtype* C = top[0]->gpu_data(); + const Dtype* H = top[1]->gpu_data(); + const Dtype* C_diff = top[0]->gpu_diff(); + const Dtype* H_diff = top[1]->gpu_diff(); + Dtype* C_prev_diff = bottom[0]->mutable_gpu_diff(); + Dtype* X_acts_diff = X_acts_.mutable_gpu_diff(); + LSTMUnitBackward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>(count, hidden_dim_, + C_prev, X_acts, C, H, flush, C_diff, H_diff, C_prev_diff, X_acts_diff); + CUDA_POST_KERNEL_CHECK; + const int X_count = bottom[1]->count(); + Dtype* X_diff = bottom[1]->mutable_gpu_diff(); + 
LSTMActsBackward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + X_count, hidden_dim_, X_acts, X_acts_diff, X_diff); + CUDA_POST_KERNEL_CHECK; +} + +INSTANTIATE_LAYER_GPU_FUNCS(LSTMUnitLayer); + +} // namespace caffe diff --git a/src/caffe/test/test_lstm_layer.cpp b/src/caffe/test/test_lstm_layer.cpp new file mode 100644 index 00000000000..935a2c874f8 --- /dev/null +++ b/src/caffe/test/test_lstm_layer.cpp @@ -0,0 +1,265 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/sequence_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class LSTMLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + LSTMLayerTest() : num_output_(7) { + blob_bottom_vec_.push_back(&blob_bottom_); + blob_bottom_vec_.push_back(&blob_bottom_flush_); + blob_top_vec_.push_back(&blob_top_); + unit_blob_bottom_vec_.push_back(&unit_blob_bottom_c_prev_); + unit_blob_bottom_vec_.push_back(&unit_blob_bottom_x_); + unit_blob_bottom_vec_.push_back(&unit_blob_bottom_flush_); + unit_blob_top_vec_.push_back(&unit_blob_top_c_); + unit_blob_top_vec_.push_back(&unit_blob_top_h_); + + ReshapeBlobs(1, 3); + + layer_param_.mutable_recurrent_param()->set_num_output(num_output_); + FillerParameter* weight_filler = + layer_param_.mutable_recurrent_param()->mutable_weight_filler(); + weight_filler->set_type("gaussian"); + weight_filler->set_std(0.2); + FillerParameter* bias_filler = + layer_param_.mutable_recurrent_param()->mutable_bias_filler(); + bias_filler->set_type("gaussian"); + bias_filler->set_std(0.1); + } + + void ReshapeBlobs(int num_timesteps, int num_instances) { + blob_bottom_.Reshape(num_timesteps, num_instances, 3, 2); + vector shape(2); + shape[0] = num_timesteps; + shape[1] = num_instances; + blob_bottom_flush_.Reshape(shape); + shape.push_back(num_output_); + 
+ shape[0] = 1; shape[1] = num_instances; shape[2] = 4 * num_output_; + unit_blob_bottom_x_.Reshape(shape); + shape[0] = 1; shape[1] = num_instances; shape[2] = num_output_; + unit_blob_bottom_c_prev_.Reshape(shape); + shape[0] = 1; shape[1] = 1; shape[2] = num_instances; + unit_blob_bottom_flush_.Reshape(shape); + + FillerParameter filler_param; + filler_param.set_min(-1); + filler_param.set_max(1); + UniformFiller filler(filler_param); + filler.Fill(&blob_bottom_); + filler.Fill(&unit_blob_bottom_c_prev_); + filler.Fill(&unit_blob_bottom_x_); + } + + int num_output_; + LayerParameter layer_param_; + Blob blob_bottom_; + Blob blob_bottom_flush_; + Blob blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + + Blob unit_blob_bottom_flush_; + Blob unit_blob_bottom_c_prev_; + Blob unit_blob_bottom_x_; + Blob unit_blob_top_c_; + Blob unit_blob_top_h_; + vector*> unit_blob_bottom_vec_; + vector*> unit_blob_top_vec_; +}; + +TYPED_TEST_CASE(LSTMLayerTest, TestDtypesAndDevices); + +TYPED_TEST(LSTMLayerTest, TestSetUp) { + typedef typename TypeParam::Dtype Dtype; + LSTMLayer layer(this->layer_param_); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + vector expected_top_shape = this->blob_bottom_.shape(); + expected_top_shape.resize(3); + expected_top_shape[2] = this->num_output_; + EXPECT_TRUE(this->blob_top_.shape() == expected_top_shape); +} + +TYPED_TEST(LSTMLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + const int kNumTimesteps = 3; + const int num = this->blob_bottom_.shape(1); + this->ReshapeBlobs(kNumTimesteps, num); + + // Fill the flush blob with <0, 1, 1, ..., 1>, + // indicating a sequence that begins at the first timestep + // then continues for the rest of the sequence. + for (int t = 0; t < kNumTimesteps; ++t) { + for (int n = 0; n < num; ++n) { + this->blob_bottom_flush_.mutable_cpu_data()[t * num + n] = t > 0; + } + } + + // Process the full sequence in a single batch. 
+ FillerParameter filler_param; + filler_param.set_mean(0); + filler_param.set_std(1); + GaussianFiller sequence_filler(filler_param); + sequence_filler.Fill(&this->blob_bottom_); + shared_ptr > layer(new LSTMLayer(this->layer_param_)); + Caffe::set_random_seed(1701); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + LOG(INFO) << "Calling forward for full sequence LSTM"; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // Copy the inputs and outputs to reuse/check them later. + Blob bottom_copy(this->blob_bottom_.shape()); + bottom_copy.CopyFrom(this->blob_bottom_); + Blob top_copy(this->blob_top_.shape()); + top_copy.CopyFrom(this->blob_top_); + + // Process the batch one timestep at a time; + // check that we get the same result. + this->ReshapeBlobs(1, num); + layer.reset(new LSTMLayer(this->layer_param_)); + Caffe::set_random_seed(1701); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const int bottom_count = this->blob_bottom_.count(); + const int top_count = this->blob_top_.count(); + const Dtype kEpsilon = 1e-5; + for (int t = 0; t < kNumTimesteps; ++t) { + caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count, + this->blob_bottom_.mutable_cpu_data()); + for (int n = 0; n < num; ++n) { + this->blob_bottom_flush_.mutable_cpu_data()[n] = t > 0; + } + LOG(INFO) << "Calling forward for LSTM timestep " << t; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < top_count; ++i) { + ASSERT_LT(t * top_count + i, top_copy.count()); + EXPECT_NEAR(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i], kEpsilon) + << "t = " << t << "; i = " << i; + } + } + + // Process the batch one timestep at a time with all flush blobs set to 0. + // Check that we get a different result, except in the first timestep. 
+ Caffe::set_random_seed(1701); + layer.reset(new LSTMLayer(this->layer_param_)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + for (int t = 0; t < kNumTimesteps; ++t) { + caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count, + this->blob_bottom_.mutable_cpu_data()); + for (int n = 0; n < num; ++n) { + this->blob_bottom_flush_.mutable_cpu_data()[n] = 0; + } + LOG(INFO) << "Calling forward for LSTM timestep " << t; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < top_count; ++i) { + if (t == 0) { + EXPECT_NEAR(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i], kEpsilon) + << "t = " << t << "; i = " << i; + } else { + EXPECT_NE(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i]) + << "t = " << t << "; i = " << i; + } + } + } +} + +TYPED_TEST(LSTMLayerTest, TestLSTMUnitSetUp) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LSTMUnitLayer layer(layer_param); + layer.SetUp(this->unit_blob_bottom_vec_, this->unit_blob_top_vec_); + const int num_axes = this->unit_blob_bottom_c_prev_.num_axes(); + ASSERT_EQ(num_axes, this->unit_blob_top_c_.num_axes()); + ASSERT_EQ(num_axes, this->unit_blob_top_h_.num_axes()); + for (int i = 0; i < num_axes; ++i) { + EXPECT_EQ(this->unit_blob_bottom_c_prev_.shape(i), + this->unit_blob_top_c_.shape(i)); + EXPECT_EQ(this->unit_blob_bottom_c_prev_.shape(i), + this->unit_blob_top_h_.shape(i)); + } +} + +TYPED_TEST(LSTMLayerTest, TestLSTMUnitGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LSTMUnitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + Dtype* flush_data = this->unit_blob_bottom_flush_.mutable_cpu_data(); + flush_data[0] = 0; + flush_data[1] = 0; + flush_data[2] = 0; + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, + this->unit_blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, +
this->unit_blob_top_vec_, 1); +} + +TYPED_TEST(LSTMLayerTest, TestLSTMUnitGradientNonZeroFlush) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LSTMUnitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + Dtype* flush_data = this->unit_blob_bottom_flush_.mutable_cpu_data(); + flush_data[0] = 1; + flush_data[1] = 0; + flush_data[2] = 1; + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, + this->unit_blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, + this->unit_blob_top_vec_, 1); +} + +TYPED_TEST(LSTMLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LSTMLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(LSTMLayerTest, TestGradientNonZeroFlush) { + Caffe::set_phase(Caffe::TEST); + typedef typename TypeParam::Dtype Dtype; + LSTMLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i = 0; i < this->blob_bottom_flush_.count(); ++i) { + this->blob_bottom_flush_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(LSTMLayerTest, TestGradientNonZeroFlushBufferSize2) { + Caffe::set_phase(Caffe::TEST); + typedef typename TypeParam::Dtype Dtype; + this->ReshapeBlobs(2, 2); + // fill the values + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(&this->blob_bottom_); + LSTMLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i = 0; i < this->blob_bottom_flush_.count(); ++i) { + this->blob_bottom_flush_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +} // namespace caffe