From 3b339b26465681f81d6007f4f68a977e8b350557 Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Wed, 31 Oct 2018 18:37:08 -0700
Subject: [PATCH] Revert "add lstm clip"

This reverts commit c4d7a0a60196c2ec53284fe2ce7c3c4abee2ced7.
---
 python/mxnet/gluon/rnn/rnn_layer.py | 43 ++++----------------
 src/operator/cudnn_rnn-inl.h        | 62 ++++++++---------------------
 src/operator/rnn-inl.h              | 21 ----------
 tests/python/gpu/test_gluon_gpu.py  | 24 -----------
 tests/python/unittest/common.py     |  5 ++-
 5 files changed, 27 insertions(+), 128 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index ce28729a019e..a0eac4f7ff62 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -36,13 +36,12 @@ def __init__(self, hidden_size, num_layers, layout,
                  i2h_weight_initializer, h2h_weight_initializer,
                  i2h_bias_initializer, h2h_bias_initializer,
                  mode, projection_size, h2r_weight_initializer,
-                 lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
                  **kwargs):
         super(_RNNLayer, self).__init__(**kwargs)
         assert layout in ('TNC', 'NTC'), \
             "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
         self._hidden_size = hidden_size
-        self._projection_size = projection_size if projection_size else None
+        self._projection_size = projection_size
         self._num_layers = num_layers
         self._mode = mode
         self._layout = layout
@@ -54,14 +53,11 @@ def __init__(self, hidden_size, num_layers, layout,
         self._i2h_bias_initializer = i2h_bias_initializer
         self._h2h_bias_initializer = h2h_bias_initializer
         self._h2r_weight_initializer = h2r_weight_initializer
-        self._lstm_state_clip_min = lstm_state_clip_min
-        self._lstm_state_clip_max = lstm_state_clip_max
-        self._lstm_state_clip_nan = lstm_state_clip_nan

         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

         ng, ni, nh = self._gates, input_size, hidden_size
-        if not projection_size:
+        if projection_size is None:
             for i in range(num_layers):
                 for j in ['l', 'r'][:self._dir]:
                     self._register_param('{}{}_i2h_weight'.format(j, i),
@@ -142,9 +138,7 @@ def state_info(self, batch_size=0):

     def _unfuse(self):
         """Unfuses the fused RNN in to a stack of rnn cells."""
-        assert not self._projection_size, "_unfuse does not support projection layer yet!"
-        assert not self._lstm_state_clip_min and not self._lstm_state_clip_max, \
-            "_unfuse does not support state clipping yet!"
+        assert self._projection_size is None, "_unfuse does not support projection layer yet!"
         get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
                                                                   activation='relu',
                                                                   **kwargs),
@@ -259,10 +253,7 @@ def _forward_kernel(self, F, inputs, states, **kwargs):
         rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
                     projection_size=self._projection_size,
                     num_layers=self._num_layers, bidirectional=self._dir == 2,
-                    p=self._dropout, state_outputs=True, mode=self._mode,
-                    lstm_state_clip_min=self._lstm_state_clip_min,
-                    lstm_state_clip_max=self._lstm_state_clip_max,
-                    lstm_state_clip_nan=self._lstm_state_clip_nan)
+                    p=self._dropout, state_outputs=True, mode=self._mode)

         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -362,8 +353,7 @@ def __init__(self, hidden_size, num_layers=1, activation='relu',
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
-                                  'rnn_'+activation, None, None, None, None, None,
-                                  **kwargs)
+                                  'rnn_'+activation, None, None, **kwargs)

     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
@@ -420,18 +410,6 @@ class LSTM(_RNNLayer):
         Initializer for the bias vector.
     projection_size: int, default None
         The number of features after projection.
-    h2r_weight_initializer : str or Initializer, default None
-        Initializer for the projected recurrent weights matrix, used for the linear
-        transformation of the recurrent state to the projected space.
-    state_clip_min : float or None, default None
-        Minimum clip value of LSTM states. This option must be used together with
-        state_clip_max. If None, clipping is not applied.
-    state_clip_max : float or None, default None
-        Maximum clip value of LSTM states. This option must be used together with
-        state_clip_min. If None, clipping is not applied.
-    state_clip_nan : boolean, default False
-        Whether to stop NaN from propagating in state by clipping it to min/max.
-        If the clipping range is not specified, this option is ignored.
     input_size: int, default 0
         The number of expected features in the input x. If not specified, it will be
         inferred from input.
@@ -475,16 +453,12 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
                  dropout=0, bidirectional=False, input_size=0,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
-                 projection_size=None, h2r_weight_initializer=None,
-                 state_clip_min=None, state_clip_max=None, state_clip_nan=False,
-                 **kwargs):
+                 projection_size=None, h2r_weight_initializer=None, **kwargs):
         super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                    dropout, bidirectional, input_size,
                                    i2h_weight_initializer, h2h_weight_initializer,
                                    i2h_bias_initializer, h2h_bias_initializer,
-                                   'lstm', projection_size, h2r_weight_initializer,
-                                   state_clip_min, state_clip_max, state_clip_nan,
-                                   **kwargs)
+                                   'lstm', projection_size, h2r_weight_initializer, **kwargs)

     def state_info(self, batch_size=0):
         if self._projection_size is None:
@@ -591,8 +565,7 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
-                                  'gru', None, None, None, None, None,
-                                  **kwargs)
+                                  'gru', None, None, **kwargs)

     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index 54c956de3d2e..e498952f8045 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -26,10 +26,7 @@
 #ifndef MXNET_OPERATOR_CUDNN_RNN_INL_H_
 #define MXNET_OPERATOR_CUDNN_RNN_INL_H_

-#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && ((CUDNN_MAJOR == 7 && CUDNN_MINOR >= 1)\
-    || CUDNN_MAJOR > 7)
-
-#define USE_CUDNN_LSTM_CLIP MXNET_USE_CUDNN == 1 && ((CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2)\
+#define USE_CUDNN_RNN_PROJ MXNET_USE_CUDNN == 1 && ((CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2)\
     || CUDNN_MAJOR > 7)


@@ -76,31 +73,16 @@ class CuDNNRNNOp : public Operator {
       default:
         LOG(FATAL) << "Not implmented";
     }
-#if USE_CUDNN_LSTM_PROJ
+#if USE_CUDNN_RNN_PROJ
     if (param_.projection_size.has_value()) {
       CHECK_EQ(param_.mode, rnn_enum::kLstm)
         << "Projection is only supported for LSTM.";
       CHECK_GE(param_.state_size, param_.projection_size.value())
-        << "State size must be larger than projection size.";
+        << "State size must be larger than state size.";
     }
 #else
     CHECK(!param_.projection_size.has_value())
       << "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
-#endif
-#if USE_CUDNN_LSTM_CLIP
-    if (param_.lstm_state_clip_min.has_value()
-        || param_.lstm_state_clip_max.has_value()) {
-      CHECK_EQ(param_.mode, rnn_enum::kLstm)
-        << "State clipping is only supported for LSTM.";
-      CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
-        << "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
-      CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
-        << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
-    }
-#else
-    CHECK(!param_.lstm_state_clip_min.has_value()
-          && !param_.lstm_state_clip_max.has_value())
-      << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
 #endif
     // RNN Direction
     direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
@@ -126,7 +108,7 @@ class CuDNNRNNOp : public Operator {
     CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
     CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));

-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
     CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
     CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
@@ -163,7 +145,7 @@ class CuDNNRNNOp : public Operator {
         Storage::Get()->Free(dropout_states_);
       }
     }
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
     CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
     CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
@@ -215,7 +197,7 @@ class CuDNNRNNOp : public Operator {
     Tensor<gpu, 1, DType> temp_space =
         ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
             mshadow::Shape1(temp_size), s);
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
     CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
                                          dtype_,
@@ -255,20 +237,8 @@ class CuDNNRNNOp : public Operator {
                                            nullptr));
     }
     #endif
-
-    #if USE_CUDNN_LSTM_CLIP
-    bool clip_state = param_.lstm_state_clip_min.has_value();
-    bool clip_nan = param_.lstm_state_clip_nan;
-    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
-                               rnn_desc_,
-                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
-                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
-                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
-                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
-    #endif
-
     if (ctx.is_train) {
-      #if USE_CUDNN_LSTM_PROJ
+      #if USE_CUDNN_RNN_PROJ
       CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
                                            rnn_desc_,
                                            x_data_desc_,
@@ -321,7 +291,7 @@ class CuDNNRNNOp : public Operator {
                                            reserve_space_byte_));
       #endif
     } else {
-      #if USE_CUDNN_LSTM_PROJ
+      #if USE_CUDNN_RNN_PROJ
       CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
                                             rnn_desc_,
                                             x_data_desc_,
@@ -440,7 +410,7 @@ class CuDNNRNNOp : public Operator {
     Tensor<gpu, 1, DType> temp_space =
         ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
             mshadow::Shape1(temp_size), s);
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
                                       rnn_desc_,
                                       y_data_desc_,
@@ -615,7 +585,7 @@ class CuDNNRNNOp : public Operator {
     strideA[0] = dimA[2] * dimA[1];
     strideA[1] = dimA[2];
     strideA[2] = 1;
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     int dimB[3];
     int strideB[3];
     dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
@@ -627,7 +597,7 @@ class CuDNNRNNOp : public Operator {
     strideB[2] = 1;
     #endif

-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
                                           dtype_,
                                           3,
@@ -645,7 +615,7 @@ class CuDNNRNNOp : public Operator {
                                           3,
                                           dimA,
                                           strideA));
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
                                           dtype_,
                                           3,
@@ -663,7 +633,7 @@ class CuDNNRNNOp : public Operator {
                                           3,
                                           dimA,
                                           strideA));
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
                                           dtype_,
                                           3,
@@ -681,7 +651,7 @@ class CuDNNRNNOp : public Operator {
                                           3,
                                           dimA,
                                           strideA));
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
                                           dtype_,
                                           3,
@@ -748,7 +718,7 @@ class CuDNNRNNOp : public Operator {
       #endif
       CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
     #endif
-    #if USE_CUDNN_LSTM_PROJ
+    #if USE_CUDNN_RNN_PROJ
     if (param_.projection_size.has_value()) {
       CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
                                              rnn_desc_,
@@ -844,7 +814,7 @@ class CuDNNRNNOp : public Operator {
   size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
   int workspace_size_, dropout_size_;
   std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
-  #if USE_CUDNN_LSTM_PROJ
+  #if USE_CUDNN_RNN_PROJ
   cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
   #endif
   cudnnTensorDescriptor_t hx_desc_, cx_desc_;
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 545e31bd8ff8..6d260afc0605 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -164,8 +164,6 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   int seq_length_, batch_size_, input_size_;
   bool lstm_q_;  // whether type is lstm
   dmlc::optional<int> projection_size;
-  dmlc::optional<float> lstm_state_clip_min, lstm_state_clip_max;
-  bool lstm_state_clip_nan;

   DMLC_DECLARE_PARAMETER(RNNParam) {
     DMLC_DECLARE_FIELD(state_size)
@@ -194,21 +192,6 @@
     DMLC_DECLARE_FIELD(projection_size)
     .set_default(dmlc::optional<int>())
     .describe("size of project size");
-
-    DMLC_DECLARE_FIELD(lstm_state_clip_min)
-    .set_default(dmlc::optional<float>())
-    .describe("Minimum clip value of LSTM states. This option must be used together with "
-              "lstm_state_clip_max.");
-
-    DMLC_DECLARE_FIELD(lstm_state_clip_max)
-    .set_default(dmlc::optional<float>())
-    .describe("Maximum clip value of LSTM states. This option must be used together with "
-              "lstm_state_clip_min.");
-
-    DMLC_DECLARE_FIELD(lstm_state_clip_nan)
-    .set_default(false)
-    .describe("Whether to stop NaN from propagating in state by clipping it to min/max. "
-              "If clipping range is not specified, this option is ignored.");
   }
 };

@@ -384,10 +367,6 @@ class RNNOp : public Operator{
     if (param_.projection_size.has_value()) {
      LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
     }
-    if (param_.lstm_state_clip_min.has_value()
-        || param_.lstm_state_clip_max.has_value()) {
-      LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
-    }
   }

   ~RNNOp() {
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 780dbad07fc3..1514bafbc0ed 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -130,30 +130,6 @@ def test_lstmp():
                             [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True)


-@with_seed()
-@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
-def test_lstm_clip():
-    hidden_size, projection_size = 4096, 2048
-    batch_size, seq_len = 32, 80
-    input_size = 50
-    clip_min, clip_max, clip_nan = -5, 5, True
-    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
-    lstm_states = [mx.nd.uniform(shape=(2, batch_size, projection_size), ctx=mx.gpu(0)),
-                   mx.nd.uniform(shape=(2, batch_size, hidden_size), ctx=mx.gpu(0))]
-    lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
-                                input_size=input_size, prefix='lstm0_',
-                                bidirectional=True,
-                                state_clip_min=clip_min,
-                                state_clip_max=clip_max,
-                                state_clip_nan=clip_nan)
-    lstm_layer.initialize(ctx=mx.gpu(0))
-    with autograd.record():
-        _, layer_output_states = lstm_layer(lstm_input, lstm_states)
-    cell_states = layer_output_states[0].asnumpy()
-    assert (cell_states >= clip_min).all() and (cell_states <= clip_max).all()
-    assert not np.isnan(cell_states).any()
-
-
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnn_layer():
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index abfba73ab727..dce034fda64b 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -101,11 +101,12 @@ def test_helper(orig_test):
         def test_new(*args, **kwargs):
             cudnn_off = os.getenv('CUDNN_OFF_TEST_ONLY') == 'true'
             cudnn_env_version = os.getenv('CUDNN_VERSION', None if cudnn_off else '7.3.1')
-            cudnn_test_disabled = cudnn_off or cudnn_env_version < min_version
+            cudnn_test_disabled = not cudnn_env_version or cudnn_env_version < min_version
             if not cudnn_test_disabled or mx.context.current_context().device_type == 'cpu':
                 orig_test(*args, **kwargs)
             else:
-                assert_raises((MXNetError, RuntimeError), orig_test, *args, **kwargs)
+                errors = (MXNetError, RuntimeError)
+                assert_raises(errors, orig_test, *args, **kwargs)
         return test_new
     return test_helper
