Residual unroll (apache#6397)
* residual unroll

* unroll for residual cell

* merge_outputs fix
szha authored and Olivier committed May 30, 2017
1 parent 266bb33 commit 62637cd
Showing 3 changed files with 82 additions and 3 deletions.
31 changes: 30 additions & 1 deletion python/mxnet/rnn/rnn_cell.py
@@ -913,6 +913,26 @@ def __call__(self, inputs, states):
        output = symbol.elemwise_add(output, inputs, name="%s_plus_residual" % output.name)
        return output, states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        self.reset()

        self.base_cell._modified = False
        outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state,
                                                layout=layout, merge_outputs=merge_outputs)
        self.base_cell._modified = True

        merge_outputs = isinstance(outputs, symbol.Symbol) if merge_outputs is None else \
            merge_outputs
        inputs, _ = _normalize_sequence(length, inputs, layout, merge_outputs)
        if merge_outputs:
            outputs = symbol.elemwise_add(outputs, inputs, name="%s_plus_residual" % outputs.name)
        else:
            outputs = [symbol.elemwise_add(output_sym, input_sym,
                                           name="%s_plus_residual" % output_sym.name)
                       for output_sym, input_sym in zip(outputs, inputs)]

        return outputs, states
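
For orientation, a minimal usage sketch (the variable names and sizes are hypothetical, mirroring the tests below). Passing `merge_outputs=None` defers the output format to the wrapped cell: a fused cell returns one merged symbol, an unfused cell returns a list of per-step symbols, and the residual add above handles either case:

```python
import mxnet as mx

# Wrap any base cell; unroll() then adds each step's input back onto
# the corresponding output (shapes must match for the elementwise add).
cell = mx.rnn.ResidualCell(mx.rnn.GRUCell(50, prefix='rnn_'))
inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(3)]
outputs, states = cell.unroll(3, inputs, merge_outputs=None)
```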


class BidirectionalCell(BaseRNNCell):
"""Bidirectional RNN cell.
Expand All @@ -928,9 +948,18 @@ class BidirectionalCell(BaseRNNCell):
"""
def __init__(self, l_cell, r_cell, params=None, output_prefix='bi_'):
super(BidirectionalCell, self).__init__('', params=params)
self._output_prefix = output_prefix
self._override_cell_params = params is not None

if self._override_cell_params:
assert l_cell._own_params and r_cell._own_params, \
"Either specify params for BidirectionalCell " \
"or child cells, not both."
l_cell.params._params.update(self.params._params)
r_cell.params._params.update(self.params._params)
self.params._params.update(l_cell.params._params)
self.params._params.update(r_cell.params._params)
self._cells = [l_cell, r_cell]
self._output_prefix = output_prefix
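
The override logic above allows two configurations; a small sketch (prefixes are hypothetical, and it assumes `RNNParams` is exposed as `mx.rnn.RNNParams`, as elsewhere in this module):

```python
import mxnet as mx

# 1) Each child cell owns its parameters (the common case):
bi = mx.rnn.BidirectionalCell(mx.rnn.GRUCell(25, prefix='rnn_l_'),
                              mx.rnn.GRUCell(25, prefix='rnn_r_'))

# 2) The wrapper supplies the parameter container; per the assertion above,
#    this is only legal if the child cells were not themselves constructed
#    with an explicit params argument.
shared = mx.rnn.RNNParams('bi_')
bi_shared = mx.rnn.BidirectionalCell(mx.rnn.GRUCell(25, prefix='rnn_l_'),
                                     mx.rnn.GRUCell(25, prefix='rnn_r_'),
                                     params=shared)
```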

    def unpack_weights(self, args):
        return _cells_unpack_weights(self._cells, args)
20 changes: 20 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -1093,6 +1093,25 @@ def test_unfuse():
    check_rnn_consistency(fused, stack)
    check_rnn_consistency(stack, fused)

def test_residual_fused():
    cell = mx.rnn.ResidualCell(
            mx.rnn.FusedRNNCell(50, num_layers=3, mode='lstm',
                                prefix='rnn_', dropout=0.5))

    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)]
    outputs, _ = cell.unroll(2, inputs, merge_outputs=None)
    assert sorted(cell.params._params.keys()) == \
           ['rnn_parameters']

    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
    assert outs == [(10, 2, 50)]
    outputs = outputs.eval(ctx=mx.gpu(0),
                           rnn_t0_data=mx.nd.ones((10, 50), ctx=mx.gpu(0))+5,
                           rnn_t1_data=mx.nd.ones((10, 50), ctx=mx.gpu(0))+5,
                           rnn_parameters=mx.nd.zeros((61200,), ctx=mx.gpu(0)))
    expected_outputs = np.ones((10, 2, 50))+5
    assert np.array_equal(outputs[0].asnumpy(), expected_outputs)
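
The flat `rnn_parameters` size is consistent with the cell configuration; a quick sketch of the arithmetic (cuDNN-style fused LSTM parameters, per this test's shapes):

```python
# FusedRNNCell(50, num_layers=3, mode='lstm') over 50-dimensional input:
# per layer, 4 gates of i2h and h2h weights plus two sets of 4 biases.
# Layers 2 and 3 take the 50-dim hidden state as input, so the same
# per-layer count applies throughout.
num_input = num_hidden = 50
per_layer = 4 * num_hidden * (num_input + num_hidden) + 2 * 4 * num_hidden
assert 3 * per_layer == 61200  # matches the shape passed above
```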

if __name__ == '__main__':
    test_countsketch()
    test_ifft()
@@ -1103,6 +1122,7 @@ def test_unfuse():
    test_gru()
    test_rnn()
    test_unfuse()
    test_residual_fused()
    test_convolution_options()
    test_convolution_versions()
    test_convolution_with_type()
34 changes: 32 additions & 2 deletions tests/python/unittest/test_rnn.py
@@ -72,8 +72,6 @@ def test_residual():

    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
    assert outs == [(10, 50), (10, 50)]
    outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50)),
                           rnn_t1_data=mx.nd.ones((10, 50)),
                           rnn_i2h_weight=mx.nd.zeros((150, 50)),
@@ -85,6 +83,38 @@
    assert np.array_equal(outputs[1].asnumpy(), expected_outputs)


def test_residual_bidirectional():
    cell = mx.rnn.ResidualCell(
            mx.rnn.BidirectionalCell(
                mx.rnn.GRUCell(25, prefix='rnn_l_'),
                mx.rnn.GRUCell(25, prefix='rnn_r_')))

    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)]
    outputs, _ = cell.unroll(2, inputs, merge_outputs=False)
    outputs = mx.sym.Group(outputs)
    assert sorted(cell.params._params.keys()) == \
           ['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
            'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight']
    assert outputs.list_outputs() == \
           ['bi_t0_plus_residual_output', 'bi_t1_plus_residual_output']

    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
    assert outs == [(10, 50), (10, 50)]
    outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50))+5,
                           rnn_t1_data=mx.nd.ones((10, 50))+5,
                           rnn_l_i2h_weight=mx.nd.zeros((75, 50)),
                           rnn_l_i2h_bias=mx.nd.zeros((75,)),
                           rnn_l_h2h_weight=mx.nd.zeros((75, 25)),
                           rnn_l_h2h_bias=mx.nd.zeros((75,)),
                           rnn_r_i2h_weight=mx.nd.zeros((75, 50)),
                           rnn_r_i2h_bias=mx.nd.zeros((75,)),
                           rnn_r_h2h_weight=mx.nd.zeros((75, 25)),
                           rnn_r_h2h_bias=mx.nd.zeros((75,)))
    expected_outputs = np.ones((10, 50))+5
    assert np.array_equal(outputs[0].asnumpy(), expected_outputs)
    assert np.array_equal(outputs[1].asnumpy(), expected_outputs)
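
As a sanity check on `expected_outputs`: with every weight and bias at zero and a zero initial state, each GRU step emits zero, so the residual connection returns the inputs unchanged (hence `ones + 5`). A tiny sketch of that arithmetic, assuming one common GRU gating convention:

```python
import numpy as np

h_prev = 0.0                         # zero initial state
z = 1.0 / (1.0 + np.exp(-0.0))       # update gate: sigmoid(0) = 0.5
h_cand = np.tanh(0.0)                # candidate state: tanh(0) = 0
h = (1.0 - z) * h_cand + z * h_prev  # next hidden state stays 0
assert h == 0.0                      # so output = input + 0 = input
```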


def test_stack():
    cell = mx.rnn.SequentialRNNCell()
    for i in range(5):
