This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Gluon LSTM Projection and Clipping #13056

Merged 9 commits on Nov 2, 2018
6 changes: 6 additions & 0 deletions ci/docker/runtime_functions.sh
@@ -702,6 +702,7 @@ unittest_ubuntu_python2_gpu() {
export PYTHONPATH=./python/
export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export CUDNN_VERSION=7.0.3
nosetests-2.7 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
}

@@ -734,6 +735,7 @@ unittest_ubuntu_python3_gpu() {
export PYTHONPATH=./python/
export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export CUDNN_VERSION=7.0.3
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
}

@@ -750,6 +752,7 @@ unittest_ubuntu_tensorrt_gpu() {
export PYTHONPATH=./python/
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
export CUDNN_VERSION=7.0.3
python tests/python/tensorrt/lenet5_train.py
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/
}
@@ -761,6 +764,7 @@ unittest_ubuntu_python2_quantization_gpu() {
export PYTHONPATH=./python/
export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export CUDNN_VERSION=7.0.3
nosetests-2.7 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_quantization_gpu.xml --verbose tests/python/quantization_gpu
}

@@ -771,6 +775,7 @@ unittest_ubuntu_python3_quantization_gpu() {
export PYTHONPATH=./python/
export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export CUDNN_VERSION=7.0.3
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_quantization_gpu.xml --verbose tests/python/quantization_gpu
}

@@ -865,6 +870,7 @@ unittest_centos7_cpu() {
unittest_centos7_gpu() {
set -ex
cd /work/mxnet
export CUDNN_VERSION=7.0.3
python3.6 -m "nose" $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
}

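The `CUDNN_VERSION` export above tells the GPU test jobs which cuDNN they are running against, so tests can skip features that need a newer library. A minimal sketch of how a test could read that variable (the helper below is illustrative, not part of this PR, and the 7.2.1 threshold for projection/clipping is an assumption):

```python
import os

def cudnn_version_tuple():
    """Parse the CUDNN_VERSION exported by the CI scripts above (default 0.0.0)."""
    return tuple(int(x) for x in os.environ.get('CUDNN_VERSION', '0.0.0').split('.'))

# Assumption: LSTM projection/state clipping need a newer cuDNN than the pinned 7.0.3,
# so a GPU test could guard itself like this before exercising them.
if cudnn_version_tuple() < (7, 2, 1):
    print('skipping LSTM projection/clipping tests: cuDNN too old')
```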
130 changes: 100 additions & 30 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -35,11 +35,14 @@ def __init__(self, hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
mode, **kwargs):
mode, projection_size, h2r_weight_initializer,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
**kwargs):
super(_RNNLayer, self).__init__(**kwargs)
assert layout in ('TNC', 'NTC'), \
"Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
self._hidden_size = hidden_size
self._projection_size = projection_size if projection_size else None
self._num_layers = num_layers
self._mode = mode
self._layout = layout
@@ -50,25 +53,50 @@ def __init__(self, hidden_size, num_layers, layout,
self._h2h_weight_initializer = h2h_weight_initializer
self._i2h_bias_initializer = i2h_bias_initializer
self._h2h_bias_initializer = h2h_bias_initializer
self._h2r_weight_initializer = h2r_weight_initializer
self._lstm_state_clip_min = lstm_state_clip_min
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

ng, ni, nh = self._gates, input_size, hidden_size
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
ni = nh * self._dir
if not projection_size:
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
ni = nh * self._dir
else:
np = self._projection_size
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, np),
init=h2h_weight_initializer)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
self._register_param('{}{}_h2r_weight'.format(j, i),
shape=(np, nh),
init=h2r_weight_initializer)
ni = np * self._dir

def _register_param(self, name, shape, init):
p = self.params.get(name, shape=shape, init=init,
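A quick sketch of the shapes the projection branch above registers (parameter names abbreviated; assuming hidden_size=20, projection_size=5, input_size=10, two unidirectional layers):

```python
import mxnet as mx

layer = mx.gluon.rnn.LSTM(20, num_layers=2, input_size=10, projection_size=5)
for name, param in sorted(layer.collect_params().items()):
    print(name, param.shape)
# With 4 LSTM gates the leading dimension is 4*20 = 80:
#   l0_i2h_weight (80, 10)   l0_h2h_weight (80, 5)   l0_h2r_weight (5, 20)
#   l1_i2h_weight (80, 5)    l1_h2h_weight (80, 5)   l1_h2r_weight (5, 20)
#   plus (80,) i2h/h2h biases for each layer
```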
@@ -114,6 +142,9 @@ def state_info(self, batch_size=0):

def _unfuse(self):
"""Unfuses the fused RNN in to a stack of rnn cells."""
assert not self._projection_size, "_unfuse does not support projection layer yet!"
assert not self._lstm_state_clip_min and not self._lstm_state_clip_max, \
"_unfuse does not support state clipping yet!"
get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
activation='relu',
**kwargs),
@@ -189,7 +220,7 @@ def hybrid_forward(self, F, inputs, states=None, **kwargs):
skip_states = states is None
if skip_states:
if F is ndarray:
states = self.begin_state(batch_size, ctx=inputs.context)
states = self.begin_state(batch_size, ctx=inputs.context, dtype=inputs.dtype)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
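A small sketch of what the dtype change above gives you: states auto-created for the forward pass (or via `begin_state`) now match the input's dtype instead of defaulting to float32.

```python
import mxnet as mx

layer = mx.gluon.rnn.LSTM(8, input_size=4)
# This mirrors what hybrid_forward now does when no states are passed in:
states = layer.begin_state(batch_size=3, ctx=mx.cpu(), dtype='float16')
print([s.dtype for s in states])   # both the h and c states come back as float16
```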
@@ -209,16 +240,29 @@ def _forward_kernel(self, F, inputs, states, **kwargs):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h'])
if self._projection_size is None:
params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h'])
else:
params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h', 'h2r']
if g != 'h2r' or t != 'bias')

params = F._internal._rnn_param_concat(*params, dim=0)

rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)
p=self._dropout, state_outputs=True, mode=self._mode,
lstm_state_clip_min=self._lstm_state_clip_min,
lstm_state_clip_max=self._lstm_state_clip_max,
lstm_state_clip_nan=self._lstm_state_clip_nan)

if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
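For reference, a pure-Python sketch of the flattening order the projection branch above produces for the fused kernel: all weights before all biases, iterating layer, then direction, then gate group, with no bias for the h2r projection (2 layers, bidirectional shown):

```python
num_layers, dirs = 2, ['l', 'r']
order = ['{}{}_{}_{}'.format(d, l, g, t)
         for t in ['weight', 'bias']
         for l in range(num_layers)
         for d in dirs
         for g in ['i2h', 'h2h', 'h2r']
         if g != 'h2r' or t != 'bias']
print(order)
# ['l0_i2h_weight', 'l0_h2h_weight', 'l0_h2r_weight', 'r0_i2h_weight', ...,
#  'l0_i2h_bias', 'l0_h2h_bias', 'r0_i2h_bias', ...]
```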
@@ -318,7 +362,8 @@ def __init__(self, hidden_size, num_layers=1, activation='relu',
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'rnn_'+activation, **kwargs)
'rnn_'+activation, None, None, None, None, False,
**kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
@@ -373,6 +418,20 @@ class LSTM(_RNNLayer):
to zero.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
projection_size: int, default None
The number of features after projection.
h2r_weight_initializer : str or Initializer, default None
Initializer for the projected recurrent weights matrix, used for the linear
transformation of the recurrent state to the projected space.
state_clip_min : float or None, default None
Minimum clip value of LSTM states. This option must be used together with
state_clip_max. If None, clipping is not applied.
state_clip_max : float or None, default None
Maximum clip value of LSTM states. This option must be used together with
state_clip_min. If None, clipping is not applied.
state_clip_nan : boolean, default False
Whether to stop NaN from propagating in state by clipping it to min/max.
If the clipping range is not specified, this option is ignored.
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
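A usage sketch of the new arguments documented above; projection and state clipping go through the fused cuDNN path, so a GPU context (and a sufficiently recent cuDNN) is assumed here:

```python
import mxnet as mx

ctx = mx.gpu(0)
layer = mx.gluon.rnn.LSTM(hidden_size=512, num_layers=2, projection_size=128,
                          state_clip_min=-5.0, state_clip_max=5.0, state_clip_nan=True)
layer.initialize(ctx=ctx)

x = mx.nd.random.uniform(shape=(30, 16, 100), ctx=ctx)   # (T, N, C) for the default 'TNC' layout
out = layer(x)
print(out.shape)   # (30, 16, 128): outputs carry the projected size, not hidden_size
```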
@@ -416,18 +475,28 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout=0, bidirectional=False, input_size=0,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
projection_size=None, h2r_weight_initializer=None,
state_clip_min=None, state_clip_max=None, state_clip_nan=False,
**kwargs):
super(LSTM, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'lstm', **kwargs)
'lstm', projection_size, h2r_weight_initializer,
state_clip_min, state_clip_max, state_clip_nan,
**kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'},
{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
if self._projection_size is None:
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'},
{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
else:
return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size),
'__layout__': 'LNC'},
{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]

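The shapes implied by the `state_info` change above, as a short sketch: with projection, the recurrent state h uses projection_size while the cell state c keeps hidden_size.

```python
import mxnet as mx

layer = mx.gluon.rnn.LSTM(hidden_size=512, projection_size=128)
h, c = layer.begin_state(batch_size=16)
print(h.shape, c.shape)   # (1, 16, 128) (1, 16, 512)
```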

class GRU(_RNNLayer):
@@ -522,7 +591,8 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'gru', **kwargs)
'gru', None, None, None, None, False,
**kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
2 changes: 1 addition & 1 deletion python/mxnet/test_utils.py
@@ -336,7 +336,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None
assert(False), "unknown storage type"
return False

def rand_ndarray(shape, stype, density=None, dtype=None,
def rand_ndarray(shape, stype='default', density=None, dtype=None,
modifier_func=None, shuffle_csr_indices=False, distribution=None):
if stype == 'default':
arr = mx.nd.array(random_arrays(shape), dtype=dtype)
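With the new default above, dense test arrays no longer need an explicit stype; a quick sketch:

```python
import mxnet as mx
from mxnet.test_utils import rand_ndarray

dense = rand_ndarray((3, 4))                              # same as stype='default'
sparse = rand_ndarray((3, 4), stype='csr', density=0.3)
print(dense.stype, sparse.stype)                          # 'default' 'csr'
```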