
Commit

add lstm clip
szha committed Oct 31, 2018
1 parent b1cbe8f commit c6210a1
Showing 5 changed files with 127 additions and 27 deletions.
43 changes: 35 additions & 8 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -36,12 +36,13 @@ def __init__(self, hidden_size, num_layers, layout,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
mode, projection_size, h2r_weight_initializer,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
**kwargs):
super(_RNNLayer, self).__init__(**kwargs)
assert layout in ('TNC', 'NTC'), \
"Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
self._hidden_size = hidden_size
self._projection_size = projection_size
self._projection_size = projection_size if projection_size else None
self._num_layers = num_layers
self._mode = mode
self._layout = layout
@@ -53,11 +54,14 @@ def __init__(self, hidden_size, num_layers, layout,
self._i2h_bias_initializer = i2h_bias_initializer
self._h2h_bias_initializer = h2h_bias_initializer
self._h2r_weight_initializer = h2r_weight_initializer
self._lstm_state_clip_min = lstm_state_clip_min
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

ng, ni, nh = self._gates, input_size, hidden_size
if projection_size is None:
if not projection_size:
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
@@ -138,7 +142,9 @@ def state_info(self, batch_size=0):

def _unfuse(self):
"""Unfuses the fused RNN in to a stack of rnn cells."""
assert self._projection_size is None, "_unfuse does not support projection layer yet!"
assert not self._projection_size, "_unfuse does not support projection layer yet!"
assert not self._lstm_state_clip_min and not self._lstm_state_clip_max, \
"_unfuse does not support state clipping yet!"
get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
activation='relu',
**kwargs),
@@ -253,7 +259,10 @@ def _forward_kernel(self, F, inputs, states, **kwargs):
rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)
p=self._dropout, state_outputs=True, mode=self._mode,
lstm_state_clip_min=self._lstm_state_clip_min,
lstm_state_clip_max=self._lstm_state_clip_max,
lstm_state_clip_nan=self._lstm_state_clip_nan)

if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -353,7 +362,8 @@ def __init__(self, hidden_size, num_layers=1, activation='relu',
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'rnn_'+activation, None, None, **kwargs)
'rnn_'+activation, None, None, None, None, None,
**kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
@@ -410,6 +420,18 @@ class LSTM(_RNNLayer):
Initializer for the bias vector.
projection_size: int, default None
The number of features after projection.
h2r_weight_initializer : str or Initializer, default None
Initializer for the projected recurrent weights matrix, used for the linear
transformation of the recurrent state to the projected space.
state_clip_min : float or None, default None
Minimum clip value of LSTM states. This option must be used together with
state_clip_max. If None, clipping is not applied.
state_clip_max : float or None, default None
Maximum clip value of LSTM states. This option must be used together with
state_clip_min. If None, clipping is not applied.
state_clip_nan : boolean, default False
Whether to stop NaN from propagating in state by clipping it to min/max.
If the clipping range is not specified, this option is ignored.
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
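
The new constructor arguments documented above are forwarded unchanged to the fused RNN operator. A minimal usage sketch (the layer sizes and clipping bounds are illustrative, not taken from the commit; a cuDNN-enabled GPU build is assumed, since clipping is only implemented in the cuDNN path):

```python
import mxnet as mx
from mxnet import gluon

# Clip the LSTM cell state to [-10, 10] and keep NaNs from propagating by
# snapping them into the clipping range.
lstm = gluon.rnn.LSTM(hidden_size=64, num_layers=2, layout='TNC',
                      state_clip_min=-10.0, state_clip_max=10.0,
                      state_clip_nan=True)
lstm.initialize(ctx=mx.gpu(0))

x = mx.nd.random.uniform(shape=(35, 8, 32), ctx=mx.gpu(0))   # (seq_len, batch, input_size)
states = lstm.begin_state(batch_size=8, ctx=mx.gpu(0))
output, out_states = lstm(x, states)
```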
@@ -453,12 +475,16 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout=0, bidirectional=False, input_size=0,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
projection_size=None, h2r_weight_initializer=None, **kwargs):
projection_size=None, h2r_weight_initializer=None,
state_clip_min=None, state_clip_max=None, state_clip_nan=False,
**kwargs):
super(LSTM, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'lstm', projection_size, h2r_weight_initializer, **kwargs)
'lstm', projection_size, h2r_weight_initializer,
state_clip_min, state_clip_max, state_clip_nan,
**kwargs)

def state_info(self, batch_size=0):
if self._projection_size is None:
@@ -565,7 +591,8 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'gru', None, None, **kwargs)
'gru', None, None, None, None, None,
**kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
61 changes: 45 additions & 16 deletions src/operator/cudnn_rnn-inl.h
@@ -26,7 +26,9 @@
#ifndef MXNET_OPERATOR_CUDNN_RNN_INL_H_
#define MXNET_OPERATOR_CUDNN_RNN_INL_H_

#define USE_CUDNN_RNN_PROJ MXNET_USE_CUDNN == 1 && ((CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2)\
#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && ((CUDNN_MAJOR == 7 && CUDNN_MINOR >= 1)\
|| CUDNN_MAJOR > 7)
#define USE_CUDNN_LSTM_CLIP MXNET_USE_CUDNN == 1 && ((CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2)\
|| CUDNN_MAJOR > 7)


@@ -73,16 +75,31 @@ class CuDNNRNNOp : public Operator {
default:
LOG(FATAL) << "Not implemented";
}
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
if (param_.projection_size.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
<< "Projection is only supported for LSTM.";
CHECK_GE(param_.state_size, param_.projection_size.value())
<< "State size must be larger than state size.";
<< "State size must be larger than projection size.";
}
#else
CHECK(!param_.projection_size.has_value())
<< "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
#endif
#if USE_CUDNN_LSTM_CLIP
if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
<< "State clipping is only supported for LSTM.";
CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
<< "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
<< "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
}
#else
CHECK(!param_.lstm_state_clip_min.has_value()
&& !param_.lstm_state_clip_max.has_value())
<< "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
#endif
// RNN Direction
direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
@@ -108,7 +125,7 @@ class CuDNNRNNOp : public Operator {
CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));

#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
@@ -145,7 +162,7 @@ class CuDNNRNNOp : public Operator {
Storage::Get()->Free(dropout_states_);
}
}
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
@@ -197,7 +214,7 @@ class CuDNNRNNOp : public Operator {
Tensor<gpu, 1, DType> temp_space =
ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
mshadow::Shape1(temp_size), s);
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
dtype_,
@@ -237,8 +254,20 @@ class CuDNNRNNOp : public Operator {
nullptr));
}
#endif

#if USE_CUDNN_LSTM_CLIP
bool clip_state = param_.lstm_state_clip_min.has_value();
bool clip_nan = param_.lstm_state_clip_nan;
CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
rnn_desc_,
clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
clip_state ? param_.lstm_state_clip_min.value() : 0.0,
clip_state ? param_.lstm_state_clip_max.value() : 0.0));
#endif

if (ctx.is_train) {
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
rnn_desc_,
x_data_desc_,
@@ -291,7 +320,7 @@ class CuDNNRNNOp : public Operator {
reserve_space_byte_));
#endif
} else {
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
rnn_desc_,
x_data_desc_,
@@ -410,7 +439,7 @@ class CuDNNRNNOp : public Operator {
Tensor<gpu, 1, DType> temp_space =
ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
mshadow::Shape1(temp_size), s);
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
rnn_desc_,
y_data_desc_,
@@ -585,7 +614,7 @@ class CuDNNRNNOp : public Operator {
strideA[0] = dimA[2] * dimA[1];
strideA[1] = dimA[2];
strideA[2] = 1;
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
int dimB[3];
int strideB[3];
dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
@@ -597,7 +626,7 @@ class CuDNNRNNOp : public Operator {
strideB[2] = 1;
#endif

#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
dtype_,
3,
@@ -615,7 +644,7 @@ class CuDNNRNNOp : public Operator {
3,
dimA,
strideA));
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
dtype_,
3,
@@ -633,7 +662,7 @@ class CuDNNRNNOp : public Operator {
3,
dimA,
strideA));
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
dtype_,
3,
@@ -651,7 +680,7 @@ class CuDNNRNNOp : public Operator {
3,
dimA,
strideA));
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
dtype_,
3,
@@ -718,7 +747,7 @@ class CuDNNRNNOp : public Operator {
#endif
CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
#endif
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
if (param_.projection_size.has_value()) {
CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
rnn_desc_,
@@ -814,7 +843,7 @@ class CuDNNRNNOp : public Operator {
size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
int workspace_size_, dropout_size_;
std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
#if USE_CUDNN_RNN_PROJ
#if USE_CUDNN_LSTM_PROJ
cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
#endif
cudnnTensorDescriptor_t hx_desc_, cx_desc_;
21 changes: 21 additions & 0 deletions src/operator/rnn-inl.h
@@ -164,6 +164,8 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
int seq_length_, batch_size_, input_size_;
bool lstm_q_; // whether type is lstm
dmlc::optional<int> projection_size;
dmlc::optional<double> lstm_state_clip_min, lstm_state_clip_max;
bool lstm_state_clip_nan;

DMLC_DECLARE_PARAMETER(RNNParam) {
DMLC_DECLARE_FIELD(state_size)
@@ -192,6 +194,21 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
DMLC_DECLARE_FIELD(projection_size)
.set_default(dmlc::optional<int>())
.describe("size of project size");

DMLC_DECLARE_FIELD(lstm_state_clip_min)
.set_default(dmlc::optional<double>())
.describe("Minimum clip value of LSTM states. This option must be used together with "
"lstm_state_clip_max.");

DMLC_DECLARE_FIELD(lstm_state_clip_max)
.set_default(dmlc::optional<double>())
.describe("Maximum clip value of LSTM states. This option must be used together with "
"lstm_state_clip_min.");

DMLC_DECLARE_FIELD(lstm_state_clip_nan)
.set_default(false)
.describe("Whether to stop NaN from propagating in state by clipping it to min/max. "
"If clipping range is not specified, this option is ignored.");
}
};

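For intuition, the behaviour these three fields request from the backend amounts to an element-wise clip of the LSTM cell state between time steps; the real work happens inside cuDNN (see cudnnRNNSetClip in cudnn_rnn-inl.h), so the NumPy sketch below is illustrative only, and the choice of which bound a NaN is snapped to is an assumption, not something the commit specifies:

```python
import numpy as np

def clip_lstm_cell_state(c, clip_min, clip_max, clip_nan):
    # Rough model of the requested semantics; not the actual implementation.
    if clip_min is None or clip_max is None:
        return c                      # no clipping range given: options are ignored
    if clip_nan:
        # Stop NaN from propagating by replacing it with an in-range value
        # (using the upper bound here is an arbitrary illustrative choice).
        c = np.where(np.isnan(c), clip_max, c)
    return np.clip(c, clip_min, clip_max)

print(clip_lstm_cell_state(np.array([-7.0, 0.5, np.nan, 9.0]), -5.0, 5.0, clip_nan=True))
# -> [-5.   0.5  5.   5. ]
```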
@@ -367,6 +384,10 @@ class RNNOp : public Operator{
if (param_.projection_size.has_value()) {
LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
}
if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
}
}

~RNNOp() {
24 changes: 24 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
@@ -130,6 +130,30 @@ def test_lstmp():
[mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True)


@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
def test_lstm_clip():
hidden_size, projection_size = 4096, 2048
batch_size, seq_len = 32, 80
input_size = 50
clip_min, clip_max, clip_nan = -5, 5, True
lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
lstm_states = [mx.nd.uniform(shape=(2, batch_size, projection_size), ctx=mx.gpu(0)),
mx.nd.uniform(shape=(2, batch_size, hidden_size), ctx=mx.gpu(0))]
lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
input_size=input_size, prefix='lstm0_',
bidirectional=True,
state_clip_min=clip_min,
state_clip_max=clip_max,
state_clip_nan=clip_nan)
lstm_layer.initialize(ctx=mx.gpu(0))
with autograd.record():
_, layer_output_states = lstm_layer(lstm_input, lstm_states)
cell_states = layer_output_states[0].asnumpy()
assert (cell_states >= clip_min).all() and (cell_states <= clip_max).all()
assert not np.isnan(cell_states).any()


@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_rnn_layer():
5 changes: 2 additions & 3 deletions tests/python/unittest/common.py
@@ -101,12 +101,11 @@ def test_helper(orig_test):
def test_new(*args, **kwargs):
cudnn_off = os.getenv('CUDNN_OFF_TEST_ONLY') == 'true'
cudnn_env_version = os.getenv('CUDNN_VERSION', None if cudnn_off else '7.3.1')
cudnn_test_disabled = not cudnn_env_version or cudnn_env_version < min_version
cudnn_test_disabled = cudnn_off or cudnn_env_version < min_version
if not cudnn_test_disabled or mx.context.current_context().device_type == 'cpu':
orig_test(*args, **kwargs)
else:
errors = (MXNetError, RuntimeError)
assert_raises(errors, orig_test, *args, **kwargs)
assert_raises((MXNetError, RuntimeError), orig_test, *args, **kwargs)
return test_new
return test_helper
