This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v2.0] RNN: use rnn_params #20384

Merged · 22 commits · Oct 20, 2021
Changes from 5 commits
124 changes: 27 additions & 97 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -34,9 +34,7 @@ class _RNNLayer(HybridBlock):
"""Implementation of recurrent layers."""
def __init__(self, hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
mode, projection_size, h2r_weight_initializer,
param_initializer, mode, projection_size,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
dtype, use_sequence_length=False, **kwargs):
super(_RNNLayer, self).__init__(**kwargs)
@@ -50,11 +48,6 @@ def __init__(self, hidden_size, num_layers, layout,
self._dropout = dropout
self._dir = 2 if bidirectional else 1
self._input_size = input_size
self._i2h_weight_initializer = i2h_weight_initializer
self._h2h_weight_initializer = h2h_weight_initializer
self._i2h_bias_initializer = i2h_bias_initializer
self._h2h_bias_initializer = h2h_bias_initializer
self._h2r_weight_initializer = h2r_weight_initializer
self._lstm_state_clip_min = lstm_state_clip_min
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan
@@ -64,48 +57,8 @@ def __init__(self, hidden_size, num_layers, layout,

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

ng, ni, nh = self._gates, input_size, hidden_size
if not projection_size:
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer, dtype=dtype)
ni = nh * self._dir
else:
ps = self._projection_size
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, ps),
init=h2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2r_weight'.format(j, i),
shape=(ps, nh),
init=h2r_weight_initializer, dtype=dtype)
ni = ps * self._dir

def _register_param(self, name, shape, init, dtype):
p = Parameter(name, shape=shape, init=init, allow_deferred_init=True, dtype=dtype)
setattr(self, name, p)
return p
self.rnn_param = Parameter('rnn_param', shape=(-1,), init=param_initializer,
allow_deferred_init=True, dtype=dtype)

def __repr__(self):
s = '{name}({mapping}, {_layout}'
@@ -116,8 +69,7 @@ def __repr__(self):
if self._dir == 2:
s += ', bidirectional'
s += ')'
shape = self.l0_i2h_weight.shape
mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
mapping = '{0} -> {1}'.format(self._input_size if self._input_size else None, self._hidden_size)
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)
@@ -196,37 +148,26 @@ def forward(self, inputs, states, sequence_length=None):
def infer_shape(self, inputs, *args):
assert inputs.ndim == 3, \
"Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"
if not self._projection_size:
step = self._hidden_size
else:
step = self._projection_size
ni = inputs.shape[2]
for i in range(self._num_layers):
for j in ['l', 'r'][:self._dir]:
name = '{}{}_i2h_weight'.format(j, i)
getattr(self, name).shape = (self._gates*self._hidden_size, ni)
ni = step * self._dir
self._input_size = inputs.shape[2]
ng, ni, nh = self._gates, inputs.shape[2], self._hidden_size

size = nh * self._dir * ng
size1 = (ni + nh + 2) * size # first layer size
size2 = (nh * self._dir + nh + 2) * size # second layer size
if self._projection_size:
size1 = (ni + self._projection_size + 2) * size # first layer size
size2 = (self._projection_size * self._dir + \
self._projection_size + 2) * size # second layer size
param_size = size1 + (self._num_layers - 1) * size2
if self._projection_size:
param_size += self._projection_size * nh * self._num_layers * self._dir
self.rnn_param.shape = (param_size, )

def _forward_kernel(self, inputs, states, sequence_length):
""" forward using CUDNN or CPU kenrel"""
ctx = inputs.ctx
if self._layout == 'NTC':
inputs = np.swapaxes(inputs, 0, 1)
if self._projection_size is None:
params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h'])
else:
params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h', 'h2r']
if g != 'h2r' or t != 'bias')

params = np.concatenate(params, axis=0)

if self._use_sequence_length:
rnn_args = states + [sequence_length]
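The `infer_shape` hunk above replaces per-matrix shape inference with a single closed-form size for the fused `rnn_param` vector. Below is a small, self-contained sanity check of that arithmetic for a hypothetical 2-layer, unidirectional LSTM (`input_size=10`, `hidden_size=20`, no projection); it is a sketch of the formula, not code from this PR.

```python
# Sanity check (sketch, not PR code) of the fused parameter size computed in
# infer_shape, for a hypothetical 2-layer unidirectional LSTM.
ng, ni, nh, ndir, num_layers = 4, 10, 20, 1, 2   # gates, input, hidden, directions, layers

size = nh * ndir * ng                      # rows of each gate-stacked matrix block
size1 = (ni + nh + 2) * size               # layer 0: i2h + h2h weights plus two biases
size2 = (nh * ndir + nh + 2) * size        # later layers: their input width is nh * ndir
param_size = size1 + (num_layers - 1) * size2

# Count the same parameters matrix by matrix, the way the removed per-parameter
# registration laid them out (per direction; ndir == 1 here).
expected, layer_in = 0, ni
for _ in range(num_layers):
    expected += ng * nh * layer_in         # i2h weight
    expected += ng * nh * nh               # h2h weight
    expected += 2 * ng * nh                # i2h bias + h2h bias
    layer_in = nh * ndir

assert param_size == expected == 5920
```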
@@ -238,7 +179,8 @@ def _forward_kernel(self, inputs, states, sequence_length):
new_args = args.as_in_ctx(ctx)
rnn_args_ctx.append(new_args)

rnn = npx.rnn(inputs, params, *rnn_args_ctx, use_sequence_length=self._use_sequence_length,
rnn = npx.rnn(inputs, self.rnn_param.data().as_in_ctx(ctx), *rnn_args_ctx,
use_sequence_length=self._use_sequence_length,
state_size=self._hidden_size, projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode,
@@ -334,15 +276,11 @@ class RNN(_RNNLayer):
>>> output, hn = layer(input, h0)
"""
def __init__(self, hidden_size, num_layers=1, activation='relu',
layout='TNC', dropout=0, bidirectional=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
layout='TNC', dropout=0, bidirectional=False, param_initializer=None,
input_size=0, dtype='float32', **kwargs):
super(RNN, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'rnn_'+activation, None, None, None, None, False,
dropout, bidirectional, input_size, param_initializer,
'rnn_'+activation, None, None, None, False,
dtype, **kwargs)

def state_info(self, batch_size=0):
Expand Down Expand Up @@ -451,16 +389,12 @@ class LSTM(_RNNLayer):
"""
def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout=0, bidirectional=False, input_size=0,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
projection_size=None, h2r_weight_initializer=None,
param_initializer=None, projection_size=None,
state_clip_min=None, state_clip_max=None, state_clip_nan=False,
dtype='float32', **kwargs):
super(LSTM, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'lstm', projection_size, h2r_weight_initializer,
param_initializer, 'lstm', projection_size,
state_clip_min, state_clip_max, state_clip_nan,
dtype, **kwargs)

@@ -560,14 +494,10 @@ class GRU(_RNNLayer):
"""
def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout=0, bidirectional=False, input_size=0,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',

Review comment (Member): we will need an initializer for the fused parameter and use it as default. With this default initializer for RNN layers, the bias terms should be initialized as 0s. A sketch of such an initializer appears after this file's diff.

dtype='float32', **kwargs):
param_initializer=None, dtype='float32', **kwargs):
super(GRU, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'gru', None, None, None, None, False,
param_initializer, 'gru', None, None, None, False,
dtype, **kwargs)

def state_info(self, batch_size=0):
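Regarding the review comment above about a default initializer for the fused parameter: a minimal sketch of what such an initializer could look like, assuming the packed layout used by `split_rnn_params` in the next file (all weight matrices first, then all bias vectors, projection not handled). The class name `FusedRNNParamInit` and its arguments are hypothetical and not part of this PR; weights get a small uniform init and the bias tail is zeroed, as the comment suggests.

```python
import mxnet as mx

class FusedRNNParamInit(mx.init.Initializer):
    """Hypothetical default initializer for the fused rnn_param vector (sketch).

    Assumes the split_rnn_params layout: all weights first, then all biases.
    Projection layers are not handled here.
    """
    def __init__(self, mode, num_layers, input_size, hidden_size,
                 bidirectional=False, scale=0.07):
        super().__init__()
        gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
        ndir = 2 if bidirectional else 1
        # Total length of the weight section; everything after it is bias.
        weight_len, ni = 0, input_size
        for _ in range(num_layers):
            weight_len += ndir * gates * hidden_size * (ni + hidden_size)
            ni = hidden_size * ndir
        self._weight_len = weight_len
        self._scale = scale

    def _init_weight(self, name, arr):
        w = self._weight_len
        arr[:w] = mx.np.random.uniform(-self._scale, self._scale,
                                       size=(w,)).as_in_ctx(arr.ctx)
        arr[w:] = 0.0   # bias terms initialized to zeros
```

One could then pass it through the new `param_initializer` argument, e.g. `rnn.LSTM(20, input_size=10, param_initializer=FusedRNNParamInit('lstm', 1, 10, 20))`; the concrete default initializer is left open at this point in the review.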
51 changes: 51 additions & 0 deletions python/mxnet/gluon/utils.py
@@ -504,3 +504,54 @@ def _check_block_input_np_ndarrays(inputs):
for i in inputs:
_check_block_input_np_ndarrays(i)
# pylint: enable=no-else-raise


# pylint: disable=too-many-nested-blocks
def split_rnn_params(param, mode, num_layers, input_size, hidden_size, bidirectional, projection_size=None):
"""Split rnn layer parameter into weight and bias in different layer."""

Review comment (Member): Add docstring.

gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
dir = 2 if bidirectional else 1
param_dict = {}
begin = 0
if not projection_size:
for p in ['weight', 'bias']:
for l in range(num_layers):
for d in ['l', 'r'][:dir]:
for g in ['i2h', 'h2h']:
ni = input_size
if l != 0:
ni = hidden_size * dir
if g == 'h2h':
ni = hidden_size
shape0 = gates * hidden_size
if p == 'weight':
cur_len = shape0 * ni
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0, ni)
else:
cur_len = shape0
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0,)
begin += cur_len
else:
for p in ['weight', 'bias']:
for l in range(num_layers):
for d in ['l', 'r'][:dir]:
for g in ['i2h', 'h2h', 'h2r']:
if g != 'h2r' or p != 'bias':
ni = input_size
if l != 0:
ni = projection_size * dir
if g == 'h2h':
ni = projection_size
shape0 = gates * hidden_size
if p == 'weight':
cur_len = shape0 * ni
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0, ni)
else:
cur_len = shape0
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0,)
begin += cur_len
return param_dict
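A possible use of the new helper, recovering the familiar per-matrix views from the fused vector once the first forward pass has fixed its shape. This is a sketch: the layer sizes, dummy input, and zero initial states are made up for illustration.

```python
import mxnet as mx
from mxnet.gluon import rnn
from mxnet.gluon.utils import split_rnn_params

layer = rnn.LSTM(hidden_size=20, num_layers=2, input_size=10)
layer.initialize()

# Run one forward pass so infer_shape materializes rnn_param.
x = mx.np.ones((5, 3, 10))                                   # (seq_len, batch, input), TNC layout
states = [mx.np.zeros((2, 3, 20)), mx.np.zeros((2, 3, 20))]  # h0, c0: (layers * dirs, batch, hidden)
out, _ = layer(x, states)

# Keys follow the old per-parameter naming, e.g. 'l0_i2h_weight'.
views = split_rnn_params(layer.rnn_param.data(), mode='lstm', num_layers=2,
                         input_size=10, hidden_size=20, bidirectional=False)
print(views['l0_i2h_weight'].shape)   # (80, 10): 4 gates * hidden_size rows
print(views['l1_h2h_bias'].shape)     # (80,)
```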
2 changes: 1 addition & 1 deletion tests/python/unittest/test_gluon.py
@@ -1347,7 +1347,7 @@ def forward(self, x):
x = self.encoders[i](x)
return x
net = Network()
net.initialize(mx.init.Xavier(), ctx=mx.cpu())
net.initialize(mx.init.Uniform(), ctx=mx.cpu())
net.hybridize()
x = onp.random.rand(32, 10, 10)
x = mx.np.array(x).as_in_context(mx.cpu())