diff --git a/example/rnn/cudnn_lstm_bucketing.py b/example/rnn/cudnn_lstm_bucketing.py new file mode 100644 index 000000000000..b80c1e0d03ec --- /dev/null +++ b/example/rnn/cudnn_lstm_bucketing.py @@ -0,0 +1,173 @@ +import numpy as np +import mxnet as mx +import argparse + +parser = argparse.ArgumentParser(description="Train RNN on Penn Tree Bank", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--test', default=False, action='store_true', + help='whether to do testing instead of training') +parser.add_argument('--model-prefix', type=str, default=None, + help='path to save/load model') +parser.add_argument('--load-epoch', type=int, default=0, + help='load from epoch') +parser.add_argument('--num-layers', type=int, default=2, + help='number of stacked RNN layers') +parser.add_argument('--num-hidden', type=int, default=200, + help='hidden layer size') +parser.add_argument('--num-embed', type=int, default=200, + help='embedding layer size') +parser.add_argument('--gpus', type=str, + help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ' \ + 'Increase batch size when using multiple gpus for best performance.') +parser.add_argument('--kv-store', type=str, default='device', + help='key-value store type') +parser.add_argument('--num-epochs', type=int, default=25, + help='max num of epochs') +parser.add_argument('--lr', type=float, default=0.01, + help='initial learning rate') +parser.add_argument('--optimizer', type=str, default='sgd', + help='the optimizer type') +parser.add_argument('--mom', type=float, default=0.0, + help='momentum for sgd') +parser.add_argument('--wd', type=float, default=0.00001, + help='weight decay for sgd') +parser.add_argument('--batch-size', type=int, default=32, + help='the batch size.') +parser.add_argument('--disp-batches', type=int, default=50, + help='show progress for every n batches') + + +#buckets = [32] +buckets = [10, 20, 30, 40, 50, 60] + +start_label = 1 +invalid_label = 0 + +def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0): + lines = open(fname).readlines() + lines = [filter(None, i.split(' ')) for i in lines] + sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label) + return sentences, vocab + +def get_data(layout): + train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label, + invalid_label=invalid_label) + val_sent, _ = tokenize_text("./data/ptb.test.txt", vocab=vocab, start_label=start_label, + invalid_label=invalid_label) + + data_train = mx.rnn.BucketSentenceIter(train_sent, args.batch_size, buckets=buckets, + invalid_label=invalid_label, layout=layout) + data_val = mx.rnn.BucketSentenceIter(val_sent, args.batch_size, buckets=buckets, + invalid_label=invalid_label, layout=layout) + return data_train, data_val, vocab + + +def train(args): + data_train, data_val, vocab = get_data('TN') + + cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, mode='lstm') + + def sym_gen(seq_len): + data = mx.sym.Variable('data') + label = mx.sym.Variable('softmax_label') + embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed') + + output, _ = cell.unroll(seq_len, inputs=embed, merge_outputs=True, layout='TNC') + + pred = mx.sym.Reshape(output, shape=(-1, args.num_hidden)) + pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred') + + label = mx.sym.Reshape(label, shape=(-1,)) + pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax') + + return pred, ('data',), ('softmax_label',) + + if args.gpus: + contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')] + else: + contexts = mx.cpu(0) + + model = mx.mod.BucketingModule( + sym_gen = sym_gen, + default_bucket_key = data_train.default_bucket_key, + context = contexts) + + if args.load_epoch: + _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint( + cell, args.model_prefix, args.load_epoch) + else: + arg_params = None + aux_params = None + + model.fit( + train_data = data_train, + eval_data = data_val, + eval_metric = mx.metric.Perplexity(invalid_label), + kvstore = args.kv_store, + optimizer = args.optimizer, + optimizer_params = { 'learning_rate': args.lr, + 'momentum': args.mom, + 'wd': args.wd }, + initializer = mx.init.Xavier(factor_type="in", magnitude=2.34), + arg_params = arg_params, + aux_params = aux_params, + begin_epoch = args.load_epoch, + num_epoch = args.num_epochs, + batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches), + epoch_end_callback = mx.rnn.do_rnn_checkpoint(cell, args.model_prefix, 1) + if args.model_prefix else None) + +def test(args): + assert args.model_prefix, "Must specifiy path to load from" + _, data_val, vocab = get_data('NT') + + stack = mx.rnn.SequentialRNNCell() + for i in range(args.num_layers): + stack.add(mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_l%d_'%i)) + + def sym_gen(seq_len): + data = mx.sym.Variable('data') + label = mx.sym.Variable('softmax_label') + embed = mx.sym.Embedding(data=data, input_dim=len(vocab), + output_dim=args.num_embed, name='embed') + + outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True) + + pred = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden)) + pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred') + + label = mx.sym.Reshape(label, shape=(-1,)) + pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax') + + return pred, ('data',), ('softmax_label',) + + if args.gpus: + contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')] + else: + contexts = mx.cpu(0) + + model = mx.mod.BucketingModule( + sym_gen = sym_gen, + default_bucket_key = data_val.default_bucket_key, + context = contexts) + model.bind(data_val.provide_data, data_val.provide_label, for_training=False) + + # note here we load using SequentialRNNCell instead of FusedRNNCell. + _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch) + model.set_params(arg_params, aux_params) + + model.score(data_val, mx.metric.Perplexity(invalid_label), + batch_end_callback=mx.callback.Speedometer(args.batch_size, 5)) + +if __name__ == '__main__': + import logging + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.DEBUG, format=head) + + args = parser.parse_args() + if args.test: + # Demonstrates how to load a model trained with CuDNN RNN and predict + # with non-fused MXNet symbol + test(args) + else: + train(args) diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/lstm_bucketing.py index d3a9322058ee..b764639ebdea 100644 --- a/example/rnn/lstm_bucketing.py +++ b/example/rnn/lstm_bucketing.py @@ -64,16 +64,15 @@ def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0): def sym_gen(seq_len): data = mx.sym.Variable('data') label = mx.sym.Variable('softmax_label') - embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed') + embed = mx.sym.Embedding(data=data, input_dim=len(vocab), + output_dim=args.num_embed, name='embed') stack = mx.rnn.SequentialRNNCell() for i in range(args.num_layers): stack.add(mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_l%d_'%i)) - outputs, states = mx.rnn.rnn_unroll(stack, seq_len, inputs=embed) + outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True) - outputs = [mx.sym.expand_dims(x, axis=1) for x in outputs] - pred = mx.sym.Concat(*outputs, dim=1) - pred = mx.sym.Reshape(pred, shape=(-1, args.num_hidden)) + pred = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden)) pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred') label = mx.sym.Reshape(label, shape=(-1,)) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index b370b20c260d..26aa108df395 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -1,5 +1,5 @@ # coding: utf-8 -# pylint: disable=too-many-branches +# pylint: disable=too-many-branches, too-many-arguments """Initialization helper for mxnet""" from __future__ import absolute_import, print_function @@ -413,7 +413,7 @@ def __init__(self, factor_type="avg", slope=0.25): @register class Bilinear(Initializer): - """docstring for Bilinear""" + """Initialize weight for upsampling layer""" def __init__(self): super(Bilinear, self).__init__() @@ -428,3 +428,45 @@ def _init_weight(self, _, arr): y = (i / shape[3]) % shape[2] weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) arr[:] = weight.reshape(shape) + + +@register +class FusedRNN(Initializer): + """Initialze parameters for fused rnn layer + + Parameters + ---------- + init : Initializer + intializer applied to unpacked weights. + num_hidden : int + should be the same with arguments passed to FusedRNNCell. + num_layers : int + should be the same with arguments passed to FusedRNNCell. + mode : str + should be the same with arguments passed to FusedRNNCell. + bidirectional : bool + should be the same with arguments passed to FusedRNNCell. + """ + def __init__(self, init, num_hidden, num_layers, mode, bidirectional=False): + if not isinstance(init, Initializer): + klass, kwargs = json.loads(init) + init = _INITIALIZER_REGISTRY[klass.lower()](**kwargs) + super(FusedRNN, self).__init__(init=init.dumps(), num_hidden=num_hidden, + num_layers=num_layers, mode=mode, + bidirectional=bidirectional) + self._num_hidden = num_hidden + self._num_layers = num_layers + self._bidirectional = bidirectional + self._mode = mode + self._init = init + + def _init_weight(self, _, arr): + from .rnn import rnn_cell + cell = rnn_cell.FusedRNNCell(self._num_hidden, self._num_layers, + self._mode, self._bidirectional, prefix='') + args = cell.unpack_weights({'parameters': arr}) + for name in args: + desc = InitDesc(name) + self._init(desc, args[name]) + arr[:] = cell.pack_weights(args)['parameters'] + diff --git a/python/mxnet/rnn/io.py b/python/mxnet/rnn/io.py index 29e091f1a473..cd55ad5012df 100644 --- a/python/mxnet/rnn/io.py +++ b/python/mxnet/rnn/io.py @@ -79,9 +79,13 @@ class BucketSentenceIter(DataIter): name of data label_name : str, default 'softmax_label' name of label + layout : str + format of data and label. 'NT' means (batch_size, length) + and 'TN' means (length, batch_size). """ - def __init__(self, sentences, batch_size, invalid_label=-1, dtype='float32', - buckets=None, data_name='data', label_name='softmax_label'): + def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1, + data_name='data', label_name='softmax_label', dtype='float32', + layout='NTC'): super(BucketSentenceIter, self).__init__() if not buckets: buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences])) @@ -90,7 +94,7 @@ def __init__(self, sentences, batch_size, invalid_label=-1, dtype='float32', ndiscard = 0 self.data = [[] for _ in buckets] - for i in xrange(len(sentences)): + for i in range(len(sentences)): buck = bisect.bisect_left(buckets, len(sentences[i])) if buck == len(buckets): ndiscard += 1 @@ -103,43 +107,62 @@ def __init__(self, sentences, batch_size, invalid_label=-1, dtype='float32', print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard) - self.default_bucket_key = max(buckets) - - self.provide_data = [(data_name, (batch_size, self.default_bucket_key))] - self.provide_label = [(label_name, (batch_size, self.default_bucket_key))] - self.batch_size = batch_size self.buckets = buckets self.data_name = data_name self.label_name = label_name self.dtype = dtype self.invalid_label = invalid_label + self.nddata = [] + self.ndlabel = [] + self.major_axis = layout.find('N') + self.default_bucket_key = max(buckets) + + if self.major_axis == 0: + self.provide_data = [(data_name, (batch_size, self.default_bucket_key))] + self.provide_label = [(label_name, (batch_size, self.default_bucket_key))] + elif self.major_axis == 1: + self.provide_data = [(data_name, (self.default_bucket_key, batch_size))] + self.provide_label = [(label_name, (self.default_bucket_key, batch_size))] + else: + raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)") self.idx = [] for i, buck in enumerate(self.data): self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)]) self.curr_idx = 0 + self.reset() + def reset(self): self.curr_idx = 0 random.shuffle(self.idx) for buck in self.data: np.random.shuffle(buck) + self.nddata = [] + self.ndlabel = [] + for buck in self.data: + label = np.empty_like(buck) + label[:, :-1] = buck[:, 1:] + label[:, -1] = self.invalid_label + self.nddata.append(ndarray.array(buck, dtype=self.dtype)) + self.ndlabel.append(ndarray.array(label, dtype=self.dtype)) + def next(self): if self.curr_idx == len(self.idx): raise StopIteration i, j = self.idx[self.curr_idx] self.curr_idx += 1 - data = self.data[i][j:j+self.batch_size] - label = np.empty_like(data) - label[:, :-1] = data[:, 1:] - label[:, -1] = self.invalid_label - + if self.major_axis == 1: + data = self.nddata[i][j:j+self.batch_size].T + label = self.ndlabel[i][j:j+self.batch_size].T + else: + data = self.nddata[i][j:j+self.batch_size] + label = self.ndlabel[i][j:j+self.batch_size] - return DataBatch([ndarray.array(data, dtype=self.dtype)], - [ndarray.array(label, dtype=self.dtype)], + return DataBatch([data], [label], bucket_key=self.buckets[i], provide_data=[(self.data_name, data.shape)], provide_label=[(self.label_name, label.shape)]) diff --git a/python/mxnet/rnn/rnn.py b/python/mxnet/rnn/rnn.py index 624381c7b5b7..6a1213b272b0 100644 --- a/python/mxnet/rnn/rnn.py +++ b/python/mxnet/rnn/rnn.py @@ -1,57 +1,104 @@ # coding: utf-8 # pylint: disable=too-many-arguments, no-member """Functions for constructing recurrent neural networks.""" -from .. import symbol +import warnings +from ..model import save_checkpoint, load_checkpoint +from .rnn_cell import BaseRNNCell def rnn_unroll(cell, length, inputs=None, begin_state=None, input_prefix='', layout='NTC'): - """Unroll an RNN cell across time steps. + """Deprecated. Please use cell.unroll instead""" + warnings.warn('rnn_unroll is deprecated. Please call cell.unroll directly.') + return cell.unroll(length=length, inputs=inputs, begin_state=begin_state, + input_prefix=input_prefix, layout=layout) + +def save_rnn_checkpoint(cells, prefix, epoch, symbol, arg_params, aux_params): + """Save checkpoint for model using RNN cells. + Unpacks weight before saving. + + Parameters + ---------- + cells : RNNCells or list of RNNCells + The RNN cells used by this symbol. + prefix : str + Prefix of model name. + epoch : int + The epoch number of the model. + symbol : Symbol + The input symbol + arg_params : dict of str to NDArray + Model parameter, dict of name to NDArray of net's weights. + aux_params : dict of str to NDArray + Model parameter, dict of name to NDArray of net's auxiliary states. + + Notes + ----- + - ``prefix-symbol.json`` will be saved for symbol. + - ``prefix-epoch.params`` will be saved for parameters. + """ + if isinstance(cells, BaseRNNCell): + cells = [cells] + for cell in cells: + arg_params = cell.unpack_weights(arg_params) + save_checkpoint(prefix, epoch, symbol, arg_params, aux_params) + +def load_rnn_checkpoint(cells, prefix, epoch): + """Load model checkpoint from file. + Pack weights after loading. + + Parameters + ---------- + cells : RNNCells or list of RNNCells + The RNN cells used by this symbol. + prefix : str + Prefix of model name. + epoch : int + Epoch number of model we would like to load. + + Returns + ------- + symbol : Symbol + The symbol configuration of computation network. + arg_params : dict of str to NDArray + Model parameter, dict of name to NDArray of net's weights. + aux_params : dict of str to NDArray + Model parameter, dict of name to NDArray of net's auxiliary states. + + Notes + ----- + - symbol will be loaded from ``prefix-symbol.json``. + - parameters will be loaded from ``prefix-epoch.params``. + """ + sym, arg, aux = load_checkpoint(prefix, epoch) + if isinstance(cells, BaseRNNCell): + cells = [cells] + for cell in cells: + arg = cell.pack_weights(arg) + + return sym, arg, aux + +def do_rnn_checkpoint(cells, prefix, period=1): + """Make a callback to checkpoint Module to prefix every epoch. + unpacks weights used by cells before saving. Parameters ---------- - cell : children of BaseRNNCell - the cell to be unrolled. - length : int - number of steps to unroll - inputs : Symbol, list of Symbol, or None - if inputs is a single Symbol (usually the output - of Embedding symbol), it should have shape - (batch_size, length, ...) if layout == 'NTC', - or (length, batch_size, ...) if layout == 'TNC'. - - If inputs is a grouped symbol or a list of - symbols (usually output of SliceChannel or previous - unroll), they should all have shape (batch_size, ...). - - if inputs is None, Placeholder ariables are - automatically created. - begin_state : nested list of Symbol - input states. Created by cell.begin_state() - or output state of another cell. Created - from cell.begin_state() if None. - input_prefix : str - prefix for automatically created input - placehodlers. - layout : str - layout of input symbol. Only used if inputs - is a single Symbol. + cells : subclass of BaseRNNCell + RNN cells used by this module. + prefix : str + The file prefix to checkpoint to + period : int + How many epochs to wait before checkpointing. Default is 1. + + Returns + ------- + callback : function + The callback function that can be passed as iter_end_callback to fit. """ - if inputs is None: - inputs = [symbol.Variable('%st%d_data'%(input_prefix, i)) for i in range(length)] - elif isinstance(inputs, symbol.Symbol): - if len(inputs.list_outputs()) != length: - assert len(inputs.list_outputs()) == 1 - axis = layout.find('T') - inputs = symbol.SliceChannel(inputs, axis=axis, num_outputs=length, squeeze_axis=1) - else: - assert len(inputs) == length - if begin_state is None: - begin_state = cell.begin_state() - - states = begin_state - outputs = [] - for i in range(length): - output, states = cell(inputs[i], states) - outputs.append(output) - - return outputs, states + period = int(max(1, period)) + # pylint: disable=unused-argument + def _callback(iter_no, sym=None, arg=None, aux=None): + """The checkpoint function.""" + if (iter_no + 1) % period == 0: + save_rnn_checkpoint(cells, prefix, iter_no+1, sym, arg, aux) + return _callback diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index 288ca1eae093..1277d14cc1aa 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -1,9 +1,12 @@ # coding: utf-8 -# pylint: disable=no-member, invalid-name, protected-access +# pylint: disable=no-member, invalid-name, protected-access, no-self-use +# pylint: disable=too-many-branches, too-many-arguments, no-self-use """Definition of various recurrent neural network cells.""" from __future__ import print_function -from .. import symbol +import warnings + +from .. import symbol, init, ndarray from ..base import numeric_types, string_types class RNNParams(object): @@ -90,11 +93,6 @@ def state_shape(self): """shape(s) of states""" raise NotImplementedError() - @property - def output_shape(self): - """shape(s) of output""" - raise NotImplementedError() - def begin_state(self, init_sym=symbol.zeros, **kwargs): """Initial state for this cell. @@ -129,6 +127,113 @@ def recursive(shape): return recursive(state_shape) + def unpack_weights(self, args): + """Unpack fused weight matrices into separate + weight matrices + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing packed weights. + usually from Module.get_output() + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell unpacked. + """ + #pylint: disable=R0201 + return args.copy() + + def pack_weights(self, args): + """Pack separate weight matrices into fused + weight. + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing unpacked weights. + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell packed. + """ + #pylint: disable=R0201 + return args.copy() + + def unroll(self, length, inputs=None, begin_state=None, + input_prefix='', layout='NTC', merge_outputs=False): + """Unroll an RNN cell across time steps. + + Parameters + ---------- + length : int + number of steps to unroll + inputs : Symbol, list of Symbol, or None + if inputs is a single Symbol (usually the output + of Embedding symbol), it should have shape + (batch_size, length, ...) if layout == 'NTC', + or (length, batch_size, ...) if layout == 'TNC'. + + If inputs is a list of symbols (usually output of + previous unroll), they should all have shape + (batch_size, ...). + + If inputs is None, Placeholder variables are + automatically created. + begin_state : nested list of Symbol + input states. Created by begin_state() + or output state of another cell. Created + from begin_state() if None. + input_prefix : str + prefix for automatically created input + placehodlers. + layout : str + layout of input symbol. Only used if inputs + is a single Symbol. + merge_outputs : bool + if False, return outputs as a list of Symbols. + If True, concatenate output across time steps + and return a single symbol with shape + (batch_size, length, ...) if layout == 'NTC', + or (length, batch_size, ...) if layout == 'TNC'. + + Returns + ------- + outputs : list of Symbol + output symbols. + states : Symbol or nested list of Symbol + has the same structure as begin_state() + """ + axis = layout.find('T') + if inputs is None: + inputs = [symbol.Variable('%st%d_data'%(input_prefix, i)) + for i in range(length)] + elif isinstance(inputs, symbol.Symbol): + assert len(inputs.list_outputs()) == 1, \ + "unroll doesn't allow grouped symbol as input. Please " \ + "convert to list first or let unroll handle slicing" + inputs = symbol.SliceChannel(inputs, axis=axis, num_outputs=length, + squeeze_axis=1) + else: + assert len(inputs) == length + if begin_state is None: + begin_state = self.begin_state() + + states = begin_state + outputs = [] + for i in range(length): + output, states = self(inputs[i], states) + outputs.append(output) + + if merge_outputs: + outputs = [symbol.expand_dims(i, axis=axis) for i in outputs] + outputs = symbol.Concat(*outputs, dim=axis) + return outputs, states + #pylint: disable=no-self-use def _get_activation(self, inputs, activation, **kwargs): """Get activation function. Convert if is string""" @@ -168,11 +273,6 @@ def state_shape(self): """shape(s) of states""" return (0, self._num_hidden) - @property - def output_shape(self): - """shape(s) of output""" - return (0, self._num_hidden) - def __call__(self, inputs, states): """Construct symbol for one step of RNN. @@ -231,10 +331,63 @@ def state_shape(self): """shape(s) of states""" return [(0, self._num_hidden), (0, self._num_hidden)] - @property - def output_shape(self): - """shape(s) of output""" - return (0, self._num_hidden) + def unpack_weights(self, args): + """Unpack fused weight matrices into separate + weight matrices + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing packed weights. + usually from Module.get_output() + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell unpacked. + """ + args = args.copy() + outs = ['_i', '_f', '_c', '_o'] + h = self._num_hidden + for i in ['i2h', 'h2h']: + weight = args.pop('%s%s_weight'%(self._prefix, i)) + bias = args.pop('%s%s_bias'%(self._prefix, i)) + for j, name in enumerate(outs): + wname = '%s%s%s_weight'%(self._prefix, i, name) + args[wname] = weight[j*h:(j+1)*h].copy() + bname = '%s%s%s_bias'%(self._prefix, i, name) + args[bname] = bias[j*h:(j+1)*h].copy() + return args + + def pack_weights(self, args): + """Pack separate weight matrices into fused + weight. + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing unpacked weights. + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell packed. + """ + args = args.copy() + outs = ['_i', '_f', '_c', '_o'] + for i in ['i2h', 'h2h']: + weight = [] + bias = [] + for name in outs: + wname = '%s%s%s_weight'%(self._prefix, i, name) + weight.append(args.pop(wname)) + bname = '%s%s%s_bias'%(self._prefix, i, name) + bias.append(args.pop(bname)) + args['%s%s_weight'%(self._prefix, i)] = ndarray.concatenate(weight) + args['%s%s_bias'%(self._prefix, i)] = ndarray.concatenate(bias) + return args def __call__(self, inputs, states): """Construct symbol for one step of RNN. @@ -266,10 +419,10 @@ def __call__(self, inputs, states): name="%sslice"%name) in_gate = symbol.Activation(slice_gates[0], act_type="sigmoid", name='%si'%name) - in_transform = symbol.Activation(slice_gates[1], act_type="tanh", - name='%sc'%name) - forget_gate = symbol.Activation(slice_gates[2], act_type="sigmoid", + forget_gate = symbol.Activation(slice_gates[1], act_type="sigmoid", name='%sf'%name) + in_transform = symbol.Activation(slice_gates[2], act_type="tanh", + name='%sc'%name) out_gate = symbol.Activation(slice_gates[3], act_type="sigmoid", name='%so'%name) next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform, @@ -280,6 +433,240 @@ def __call__(self, inputs, states): return next_h, [next_h, next_c] +class FusedRNNCell(BaseRNNCell): + """Fusing RNN layers across time step into one kernel. + Improves speed but is less flexible. Currently only + supported if using cuDNN on GPU. + + Parameters + ---------- + """ + def __init__(self, num_hidden, num_layers=1, mode='lstm', bidirectional=False, + dropout=0., get_next_state=False, initializer=None, + prefix=None, params=None): + if prefix is None: + prefix = '%s_'%mode + super(FusedRNNCell, self).__init__(prefix=prefix, params=params) + self._num_hidden = num_hidden + self._num_layers = num_layers + self._mode = mode + self._bidirectional = bidirectional + self._dropout = dropout + self._get_next_state = get_next_state + if initializer is None: + initializer = init.Xavier(factor_type='in', magnitude=2.34) + if not isinstance(initializer, init.FusedRNN): + initializer = init.FusedRNN(initializer, num_hidden, num_layers, + mode, bidirectional) + self._parameter = self.params.get('parameters', init=initializer) + + self._directions = self._bidirectional + 1 + self._weight_names = {'rnn_relu': [''], + 'rnn_tanh': [''], + 'lstm': ['_i', '_f', '_c', '_o'], + 'gru': ['_r', '_z', '_o']}[self._mode] + self._num_weights = len(self._weight_names) + + @property + def state_shape(self): + """shape(s) of states""" + b = self._bidirectional + 1 + if self._mode == 'lstm': + return [(b*self._num_layers, 0, self._num_hidden), + (b*self._num_layers, 0, self._num_hidden)] + else: + return (b*self._num_layers, 0, self._num_hidden) + + def _slice_weights(self, arr, li, lh): + """slice fused rnn weights""" + args = {} + b = self._directions + m = self._num_weights + c = self._weight_names + d = ['l', 'r'] + + p = 0 + for i in range(self._num_layers): + for j in range(b): + for k in range(m): + name = '%s%s%d_i2h%s_weight'%(self._prefix, d[j], i, c[k]) + if i > 0: + size = b*lh*lh + args[name] = arr[p:p+size].reshape((lh, b*lh)) + else: + size = li*lh + args[name] = arr[p:p+size].reshape((lh, li)) + p += size + for k in range(m): + name = '%s%s%d_h2h%s_weight'%(self._prefix, d[j], i, c[k]) + size = lh**2 + args[name] = arr[p:p+size].reshape((lh, lh)) + p += size + + for i in range(self._num_layers): + for j in range(b): + for k in range(m): + name = '%s%s%d_i2h%s_bias'%(self._prefix, d[j], i, c[k]) + args[name] = arr[p:p+lh] + p += lh + for k in range(m): + name = '%s%s%d_h2h%s_bias'%(self._prefix, d[j], i, c[k]) + args[name] = arr[p:p+lh] + p += lh + + assert p == arr.size, "Invalid parameters size for FusedRNNCell" + return args + + def unpack_weights(self, args): + """Unpack fused weight matrices into separate + weight matrices + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing packed weights. + usually from Module.get_output() + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell unpacked. + """ + args = args.copy() + arr = args.pop(self._parameter.name) + b = self._directions + m = self._num_weights + h = self._num_hidden + num_input = arr.size/b/h/m - (self._num_layers - 1)*(h+b*h+2) - h - 2 + + nargs = self._slice_weights(arr, num_input, self._num_hidden) + args.update({name: nd.copy() for name, nd in nargs.items()}) + return args + + def pack_weights(self, args): + """Pack separate weight matrices into fused + weight. + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing unpacked weights. + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell packed. + """ + args = args.copy() + b = self._bidirectional + 1 + m = self._num_weights + c = self._weight_names + h = self._num_hidden + w0 = args['%sl0_i2h%s_weight'%(self._prefix, c[0])] + num_input = w0.shape[1] + total = (num_input+h+2)*h*m*b + (self._num_layers-1)*m*h*(h+b*h+2)*b + + arr = ndarray.zeros((total,), ctx=w0.context, dtype=w0.dtype) + for name, nd in self._slice_weights(arr, num_input, h).items(): + nd[:] = args.pop(name) + args[self._parameter.name] = arr + return args + + def __call__(self, inputs, states): + raise NotImplementedError("FusedRNNCell cannot be stepped. Please use unroll") + + def unroll(self, length, inputs=None, begin_state=None, + input_prefix='', layout='NTC', merge_outputs=False): + """Unroll an RNN cell across time steps. + + Parameters + ---------- + length : int + number of steps to unroll + inputs : Symbol, list of Symbol, or None + if inputs is a single Symbol (usually the output + of Embedding symbol), it should have shape + (batch_size, length, ...) if layout == 'NTC', + or (length, batch_size, ...) if layout == 'TNC'. + using 'TNC' is more efficient for FusedRNNCell. + + If inputs is a list of symbols (usually output of + previous unroll), they should all have shape + (batch_size, ...). using single symbol is + more efficient for FusedRNNCell. + + If inputs is None, a single placeholder variable is + automatically created. + begin_state : nested list of Symbol + input states. Created by begin_state() + or output state of another cell. Created + from begin_state() if None. + input_prefix : str + prefix for automatically created input + placehodlers. + layout : str + layout of input/output symbol. + + Returns + ------- + outputs : list of Symbol + output symbols. + states : Symbol or nested list of Symbol + has the same structure as begin_state() + """ + axis = layout.find('T') + if inputs is None: + inputs = symbol.Variable('%sdata'%input_prefix) + if isinstance(inputs, symbol.Symbol): + assert len(inputs.list_outputs()) == 1, \ + "unroll doesn't allow grouped symbol as input. Please " \ + "convert to list first or let unroll handle slicing" + if axis == 1: + warnings.warn("NTC layout detected. Consider using " + "TNC for FusedRNNCell for faster speed") + inputs = symbol.SwapAxis(inputs, dim1=0, dim2=1) + else: + assert axis == 0, "Unsupported layout %s"%layout + else: + assert len(inputs) == length + inputs = [symbol.expand_dims(i, axis=0) for i in inputs] + inputs = symbol.Concat(inputs, dim=0) + if begin_state is None: + begin_state = self.begin_state() + + states = begin_state + if self._mode == 'lstm': + states = {'state': states[0], 'state_cell': states[1]} + else: + states = {'state': states} + + rnn = symbol.RNN(data=inputs, parameters=self._parameter, + state_size=self._num_hidden, num_layers=self._num_layers, + bidirectional=self._bidirectional, p=self._dropout, + state_outputs=self._get_next_state, + mode=self._mode, name=self._prefix+'rnn', + **states) + + if not self._get_next_state: + outputs, states = rnn, [] + elif self._mode == 'lstm': + outputs, states = rnn[0], [rnn[1], rnn[2]] + else: + outputs, states = rnn[0], rnn[1] + + if not merge_outputs: + warnings.warn("Call FusedRNNCell.unroll with merge_outputs=True " + "for faster speed") + outputs = list(symbol.SliceChannel(outputs, aixs=axis, num_outputs=length, + squeeze_axis=1)) + elif axis == 1: + outputs = symbol.SwapAxis(outputs, dim1=0, dim2=1) + + return outputs, states + + class SequentialRNNCell(BaseRNNCell): """Sequantially stacking multiple RNN cells @@ -314,11 +701,6 @@ def state_shape(self): """shape(s) of states""" return [c.state_shape for c in self._cells] - @property - def output_shape(self): - """shape(s) of output""" - return self._cells[-1].output_shape - def begin_state(self, **kwargs): """Initial state for this cell. @@ -341,6 +723,16 @@ def begin_state(self, **kwargs): "cell cannot be called directly. Call the modifier cell instead." return [c.begin_state(**kwargs) for c in self._cells] + def unpack_weights(self, args): + for cell in self._cells: + args = cell.unpack_weights(args) + return args + + def pack_weights(self, args): + for cell in self._cells: + args = cell.pack_weights(args) + return args + def __call__(self, inputs, states): """Construct symbol for one step of RNN. @@ -390,11 +782,6 @@ def state_shape(self): """shape(s) of states""" return self.base_cell.state_shape - @property - def output_shape(self): - """shape(s) of output""" - return self.base_cell.output_shape - def begin_state(self, init_sym=symbol.zeros, **kwargs): """Initial state for this cell. @@ -420,6 +807,41 @@ def begin_state(self, init_sym=symbol.zeros, **kwargs): self.base_cell._modified = True return begin + def unpack_weights(self, args): + """Unpack fused weight matrices into separate + weight matrices + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing packed weights. + usually from Module.get_output() + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell unpacked. + """ + return self.base_cell.unpack_weights(args) + + def pack_weights(self, args): + """Pack separate weight matrices into fused + weight. + + Parameters + ---------- + args : dict of str -> NDArray + dictionary containing unpacked weights. + + Returns + ------- + args : dict of str -> NDArray + dictionary with weights associated to + this cell packed. + """ + return self.base_cell.pack_weights(args) + def __call__(self, inputs, states): """Construct symbol for one step of RNN. @@ -499,5 +921,3 @@ def __call__(self, inputs, states): """ raise NotImplementedError - - diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 66829d2445ad..9ec45c5d834f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -294,4 +294,9 @@ class LibInfo { // CustomOp @native def mxCustomOpRegister(regName: String, opProp: CustomOpProp): Int + + // Profiler + @native def mxSetProfilerConfig(mode: Int, fileName: String): Int + @native def mxSetProfilerState(state: Int): Int + @native def mxDumpProfile(): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Profiler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Profiler.scala new file mode 100644 index 000000000000..7a54b6acca13 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Profiler.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet + +import ml.dmlc.mxnet.Base._ + +/** + * @author Depeng Liang + */ +object Profiler { + + val mode2Int = Map("symbolic" -> 0, "all" -> 1) + val state2Int = Map("stop" -> 0, "run" -> 1) + + /** + * Set up the configure of profiler. + * @param mode, optional + * Indicting whether to enable the profiler, can + * be "symbolic" or "all". Default is "symbolic". + * @param fileName, optional + * The name of output trace file. Default is "profile.json". + */ + def profilerSetConfig(mode: String = "symbolic", fileName: String = "profile.json"): Unit = { + require(mode2Int.contains(mode)) + checkCall(_LIB.mxSetProfilerConfig(mode2Int(mode), fileName)) + } + + /** + * Set up the profiler state to record operator. + * @param state, optional + * Indicting whether to run the profiler, can + * be "stop" or "run". Default is "stop". + */ + def profilerSetState(state: String = "stop"): Unit = { + require(state2Int.contains(state)) + checkCall(_LIB.mxSetProfilerState(state2Int(state))) + } + + /** + * Dump profile and stop profiler. Use this to save profile + * in advance in case your program cannot exit normally. + */ + def dumpProfile(): Unit = { + checkCall(_LIB.mxDumpProfile()) + } +} diff --git a/scala-package/examples/scripts/profiler/run_profiler_matmul.sh b/scala-package/examples/scripts/profiler/run_profiler_matmul.sh new file mode 100644 index 000000000000..3c1b0f35d530 --- /dev/null +++ b/scala-package/examples/scripts/profiler/run_profiler_matmul.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd) +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-gpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/* + +# which gpu card to use, -1 means cpu +GPU=0 + +MODE="symbolic" +OUTPUT_PATH="." +# Just load the trace file at chrome://tracing in your Chrome browser +FILE_NAME="profile_matmul_20iter.json" + +java -Xmx4G -cp $CLASS_PATH \ + ml.dmlc.mxnet.examples.profiler.ProfilerMatMul \ + --gpu $GPU \ + --profiler-mode $MODE \ + --output-path $OUTPUT_PATH \ + --profile-filename $FILE_NAME + diff --git a/scala-package/examples/scripts/profiler/run_profiler_ndarray.sh b/scala-package/examples/scripts/profiler/run_profiler_ndarray.sh new file mode 100644 index 000000000000..f6d2ea96ab89 --- /dev/null +++ b/scala-package/examples/scripts/profiler/run_profiler_ndarray.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd) +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-gpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/* + + +MODE="all" +OUTPUT_PATH="." +# Just load the trace file at chrome://tracing in your Chrome browser +FILE_NAME="profile_ndarray.json" + +java -Xmx4G -cp $CLASS_PATH \ + ml.dmlc.mxnet.examples.profiler.ProfilerNDArray \ + --profiler-mode $MODE \ + --output-path $OUTPUT_PATH \ + --profile-filename $FILE_NAME + diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/profiler/ProfilerMatMul.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/profiler/ProfilerMatMul.scala new file mode 100644 index 000000000000..7721ef44718c --- /dev/null +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/profiler/ProfilerMatMul.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet.examples.profiler + +import org.kohsuke.args4j.{CmdLineParser, Option} +import org.slf4j.LoggerFactory +import scala.collection.JavaConverters._ +import ml.dmlc.mxnet.Context +import ml.dmlc.mxnet.Profiler +import java.io.File +import ml.dmlc.mxnet.Symbol +import ml.dmlc.mxnet.Shape +import ml.dmlc.mxnet.Random + +/** + * @author Depeng Liang + */ +object ProfilerMatMul { + private val logger = LoggerFactory.getLogger(classOf[ProfilerMatMul]) + + def main(args: Array[String]): Unit = { + val erul = new ProfilerMatMul + val parser: CmdLineParser = new CmdLineParser(erul) + try { + parser.parseArgument(args.toList.asJava) + val ctx = if (erul.gpu >= 0) Context.gpu(erul.gpu) else Context.cpu() + + val path = s"${erul.outputPath}${File.separator}${erul.profilerName}" + Profiler.profilerSetConfig(mode = erul.profilerMode, fileName = path) + logger.info(s"profile file save to $path") + + val A = Symbol.Variable("A") + val B = Symbol.Variable("B") + val C = Symbol.dot()(A, B)() + + val executor = C.simpleBind(ctx, "write", + Map("A" -> Shape(4096, 4096), "B" -> Shape(4096, 4096))) + + val a = Random.uniform(-1.0f, 1.0f, shape = Shape(4096, 4096)) + val b = Random.uniform(-1.0f, 1.0f, shape = Shape(4096, 4096)) + + a.copyTo(executor.argDict("A")) + b.copyTo(executor.argDict("B")) + + val flag = false + logger.info(s"execution begin") + var t0 = 0L + var t1 = 0L + for (i <- 0 until erul.iterNum) { + if (i == erul.beginProfilingIter) { + t0 = System.currentTimeMillis() + Profiler.profilerSetState("run") + } + if (i == erul.endProfilingIter) { + t1 = System.currentTimeMillis() + Profiler.profilerSetState("stop") + } + executor.forward() + executor.outputs(0).waitToRead() + } + logger.info(s"execution end") + val duration = t1 - t0 + logger.info(s"duration: ${duration / 1000f}s") + logger.info(s"${duration.toFloat / erul.iterNum}ms/operator") + } catch { + case ex: Exception => { + logger.error(ex.getMessage, ex) + parser.printUsage(System.err) + sys.exit(1) + } + } + } +} + +class ProfilerMatMul { + @Option(name = "--profiler-mode", usage = "the profiler mode, can be \"symbolic\" or \"all\".") + private val profilerMode: String = "symbolic" + @Option(name = "--output-path", usage = "the profile file output directory.") + private val outputPath: String = "." + @Option(name = "--profile-filename", usage = "the profile file name.") + private val profilerName: String = "profile_matmul_20iter.json" + @Option(name = "--iter-num", usage = "iterate number.") + private val iterNum: Int = 100 + @Option(name = "--begin-profiling-iter'", usage = "specific iterate to start the profiler.") + private val beginProfilingIter: Int = 50 + @Option(name = "--end-profiling-iter'", usage = "specific iterate to stop the profiler.") + private val endProfilingIter: Int = 70 + @Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu") + private val gpu: Int = -1 +} diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/profiler/ProfilerNDArray.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/profiler/ProfilerNDArray.scala new file mode 100644 index 000000000000..05cc14dc406d --- /dev/null +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/profiler/ProfilerNDArray.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet.examples.profiler + +import org.kohsuke.args4j.{CmdLineParser, Option} +import org.slf4j.LoggerFactory +import scala.collection.JavaConverters._ +import java.io.File +import ml.dmlc.mxnet.Profiler +import ml.dmlc.mxnet.Random +import ml.dmlc.mxnet.Shape +import ml.dmlc.mxnet.NDArray +import ml.dmlc.mxnet.Context + +/** + * @author Depeng Liang + */ +object ProfilerNDArray { + private val logger = LoggerFactory.getLogger(classOf[ProfilerNDArray]) + + def testBroadcast(): Unit = { + val sampleNum = 1000 + def testBroadcastTo(): Unit = { + for (i <- 0 until sampleNum) { + val nDim = scala.util.Random.nextInt(2) + 1 + val targetShape = Shape((0 until nDim).map(i => scala.util.Random.nextInt(10) + 1)) + val shape = targetShape.toArray.map { s => + if (scala.util.Random.nextInt(2) == 1) 1 + else s + } + val dat = NDArray.empty(shape: _*) + val randomRet = (0 until shape.product) + .map(r => scala.util.Random.nextFloat() - 0.5f).toArray + dat.set(randomRet) + val ndArrayRet = NDArray.broadcast_to(Map("shape" -> targetShape))(dat).get + require(ndArrayRet.shape == targetShape) + val err = { + // implementation of broadcast + val ret = { + (randomRet /: shape.zipWithIndex.reverse){ (acc, elem) => elem match { case (s, i) => + if (s != targetShape(i)) { + acc.grouped(shape.takeRight(shape.length - i).product).map {g => + (0 until targetShape(i)).map(x => g).flatten + }.flatten.toArray + } else acc + }} + } + val tmp = ndArrayRet.toArray.zip(ret).map{ case (l, r) => Math.pow(l - r, 2) } + tmp.sum / tmp.length + } + require(err < 1E-8) + ndArrayRet.dispose() + dat.dispose() + } + } + testBroadcastTo() + } + + def randomNDArray(dim: Int): NDArray = { + val tmp = Math.pow(1000, 1.0 / dim).toInt + val shape = Shape((0 until dim).map(d => scala.util.Random.nextInt(tmp) + 1)) + Random.uniform(-10f, 10f, shape) + } + + def testNDArraySaveload(): Unit = { + val maxDim = 5 + val nRepeat = 10 + val fileName = s"${System.getProperty("java.io.tmpdir")}/tmpList.bin" + for (repeat <- 0 until nRepeat) { + try { + val data = (0 until 10).map(i => randomNDArray(scala.util.Random.nextInt(4) + 1)) + NDArray.save(fileName, data) + val data2 = NDArray.load2Array(fileName) + require(data.length == data2.length) + for ((x, y) <- data.zip(data2)) { + val tmp = x - y + require(tmp.toArray.sum == 0) + tmp.dispose() + } + val dMap = data.zipWithIndex.map { case (arr, i) => + s"NDArray xx $i" -> arr + }.toMap + NDArray.save(fileName, dMap) + val dMap2 = NDArray.load2Map(fileName) + require(dMap.size == dMap2.size) + for ((k, x) <- dMap) { + val y = dMap2(k) + val tmp = x - y + require(tmp.toArray.sum == 0) + tmp.dispose() + } + data.foreach(_.dispose()) + } finally { + val file = new File(fileName) + file.delete() + } + } + } + + def testNDArrayCopy(): Unit = { + val c = Random.uniform(-10f, 10f, Shape(10, 10)) + val d = c.copyTo(Context.cpu(0)) + val tmp = c - d + require(tmp.toArray.map(Math.abs).sum == 0) + c.dispose() + d.dispose() + } + + def reldiff(a: NDArray, b: NDArray): Float = { + val diff = NDArray.sum(NDArray.abs(a - b)).toScalar + val norm = NDArray.sum(NDArray.abs(a)).toScalar + diff / norm + } + + def reldiff(a: Array[Float], b: Array[Float]): Float = { + val diff = + (a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum + val norm: Float = a.reduce(Math.abs(_) + Math.abs(_)) + diff / norm + } + + def testNDArrayNegate(): Unit = { + val rand = Random.uniform(-10f, 10f, Shape(2, 3, 4)) + val npy = rand.toArray + val arr = NDArray.empty(Shape(2, 3, 4)) + arr.set(npy) + require(reldiff(npy, arr.toArray) < 1e-6f) + val negativeArr = -arr + require(reldiff(npy.map(_ * -1f), negativeArr.toArray) < 1e-6f) + // a final check to make sure the negation (-) is not implemented + // as inplace operation, so the contents of arr does not change after + // we compute (-arr) + require(reldiff(npy, arr.toArray) < 1e-6f) + rand.dispose() + arr.dispose() + negativeArr.dispose() + } + + def testNDArrayScalar(): Unit = { + val c = NDArray.empty(10, 10) + val d = NDArray.empty(10, 10) + c.set(0.5f) + d.set(1.0f) + d -= c * 2f / 3f * 6f + c += 0.5f + require(c.toArray.sum - 100f < 1e-5f) + require(d.toArray.sum + 100f < 1e-5f) + c.set(2f) + require(c.toArray.sum - 200f < 1e-5f) + d.set(-c + 2f) + require(d.toArray.sum < 1e-5f) + c.dispose() + d.dispose() + } + + def testClip(): Unit = { + val shape = Shape(10) + val A = Random.uniform(-10f, 10f, shape) + val B = NDArray.clip(A, -2f, 2f) + val B1 = B.toArray + require(B1.forall { x => x >= -2f && x <= 2f }) + } + + def testDot(): Unit = { + val a = Random.uniform(-3f, 3f, Shape(3, 4)) + val b = Random.uniform(-3f, 3f, Shape(4, 5)) + val c = NDArray.dot(a, b) + val A = a.toArray.grouped(4).toArray + val B = b.toArray.grouped(5).toArray + val C = (Array[Array[Float]]() /: A)((acc, row) => acc :+ row.zip(B).map(z => + z._2.map(_ * z._1)).reduceLeft(_.zip(_).map(x => x._1 + x._2))).flatten + require(reldiff(c.toArray, C) < 1e-5f) + a.dispose() + b.dispose() + c.dispose() + } + + def testNDArrayOnehot(): Unit = { + val shape = Shape(100, 20) + var npy = (0 until shape.product).toArray.map(_.toFloat) + val arr = NDArray.empty(shape) + arr.set(npy) + val nRepeat = 3 + for (repeat <- 0 until nRepeat) { + val indices = (0 until shape(0)).map(i => scala.util.Random.nextInt(shape(1))) + npy = npy.map(i => 0f) + for (i <- 0 until indices.length) npy(i * shape(1) + indices(i)) = 1f + val ind = NDArray.empty(shape(0)) + ind.set(indices.toArray.map(_.toFloat)) + NDArray.onehotEncode(ind, arr) + require(arr.toArray.zip(npy).map(x => x._1 - x._2).sum == 0f) + ind.dispose() + } + arr.dispose() + } + + def main(args: Array[String]): Unit = { + val eray = new ProfilerNDArray + val parser: CmdLineParser = new CmdLineParser(eray) + try { + parser.parseArgument(args.toList.asJava) + + val path = s"${eray.outputPath}${File.separator}${eray.profilerName}" + Profiler.profilerSetConfig(mode = eray.profilerMode, fileName = path) + logger.info(s"profile file save to $path") + + Profiler.profilerSetState("run") + testBroadcast() + testNDArraySaveload() + testNDArrayCopy() + testNDArrayNegate() + testNDArrayScalar() + testClip() + testDot() + testNDArrayOnehot() + Profiler.profilerSetState("stop") + + } catch { + case ex: Exception => { + logger.error(ex.getMessage, ex) + parser.printUsage(System.err) + sys.exit(1) + } + } + } +} + +class ProfilerNDArray { + @Option(name = "--profiler-mode", usage = "the profiler mode, can be \"symbolic\" or \"all\".") + private val profilerMode: String = "all" + @Option(name = "--output-path", usage = "the profile file output directory.") + private val outputPath: String = "." + @Option(name = "--profile-filename", usage = "the profile file name.") + private val profilerName: String = "profile_ndarray.json" +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 473062a6d7e3..ea86664df120 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -2279,3 +2279,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister creatorLambda); return MXCustomOpRegister(regName, creator); } + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSetProfilerConfig + (JNIEnv *env, jobject obj, jint jmode, jstring jfilename) { + const char *fileName = env->GetStringUTFChars(jfilename, 0); + int ret = MXSetProfilerConfig(jmode, fileName); + env->ReleaseStringUTFChars(jfilename, fileName); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSetProfilerState + (JNIEnv *env, jobject obj, jint jstate) { + return MXSetProfilerState(jstate); +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDumpProfile + (JNIEnv *env, jobject obj) { + return MXDumpProfile(); +} diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h index a7430278ef7b..da8e1aea3562 100644 --- a/src/operator/cudnn_rnn-inl.h +++ b/src/operator/cudnn_rnn-inl.h @@ -479,6 +479,42 @@ class CuDNNRNNOp : public Operator { format_, 3, dim_w), CUDNN_STATUS_SUCCESS); + + // Query weight layout + // cudnnFilterDescriptor_t m_desc; + // CHECK_EQ(cudnnCreateFilterDescriptor(&m_desc), CUDNN_STATUS_SUCCESS); + // DType *p; + // int n = 2; + // int64_t last = 0; + // if (param_.mode == rnn_enum::kLstm) n = 8; + // else if (param_.mode == rnn_enum::kGru) n = 6; + + // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) { + // for (int j = 0; j < n; ++j) { + // CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(s->dnn_handle_, rnn_desc_, + // i, x_desc_vec_[0], w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS); + // LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last; + // last = ((int64_t)(p - NULL))/sizeof(DType); + // cudnnDataType_t t; + // cudnnTensorFormat_t f; + // int ndim = 5; + // int dims[5] = {0, 0, 0, 0, 0}; + // CHECK_EQ(cudnnGetFilterNdDescriptor(m_desc, ndim, &t, &f, &ndim, &dims[0]), + // CUDNN_STATUS_SUCCESS); + // LOG(INFO) << "w: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType); + // for (int i = 0; i < ndim; ++i) LOG(INFO) << dims[i]; + // } + // } + + // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) { + // for (int j = 0; j < n; ++j) { + // CHECK_EQ(cudnnGetRNNLinLayerBiasParams(s->dnn_handle_, rnn_desc_, i, x_desc_vec_[0], + // w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS); + // LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last; + // last = ((int64_t)(p - NULL))/sizeof(DType); + // LOG(INFO) << "b: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType); + // } + // } } } diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 294d92fed2a8..bb7bc8e31553 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -292,7 +292,56 @@ def test_take_with_type(): 'take_a': 'write'}, arg_params=arg_params) +def check_rnn_consistency(cell1, cell2): + dshape = (32, 5, 200) + data = mx.sym.Variable('data') + + sym1, _ = cell1.unroll(5, data, merge_outputs=True) + mod1 = mx.mod.Module(sym1, label_names=None, context=mx.gpu(0)) + mod1.bind(data_shapes=[('data', dshape)], label_shapes=None) + + sym2, _ = cell2.unroll(5, data, merge_outputs=True) + mod2 = mx.mod.Module(sym2, label_names=None, context=mx.gpu(0)) + mod2.bind(data_shapes=[('data', dshape)], label_shapes=None) + + mod1.init_params() + args, auxs = mod1.get_params() + args = cell1.unpack_weights(args) + args = cell2.pack_weights(args) + mod2.set_params(args, auxs) + + batch=mx.io.DataBatch(data=[mx.random.uniform(shape=dshape)], label=[]) + mod1.forward(batch) + mod2.forward(batch) + + assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2) + + +def test_rnn(): + fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='rnn_relu', prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.RNNCell(100, activation='relu', prefix='l0_')) + stack.add(mx.rnn.RNNCell(100, activation='relu', prefix='l1_')) + + check_rnn_consistency(fused, stack) + check_rnn_consistency(stack, fused) + + +def test_lstm(): + fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='lstm', prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.LSTMCell(100, prefix='l0_')) + stack.add(mx.rnn.LSTMCell(100, prefix='l1_')) + + check_rnn_consistency(fused, stack) + check_rnn_consistency(stack, fused) + + if __name__ == '__main__': + test_lstm() + test_rnn() test_convolution_options() test_convolution_with_type() test_pooling_with_type() @@ -300,7 +349,6 @@ def test_take_with_type(): test_batchnorm_with_type() test_deconvolution_with_type() test_upsampling_with_type() - test_upsampling_bilinear_with_type() test_concat_with_type() test_elementwisesum_with_type() test_reshape_with_type() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1da4244ec433..9b45fd4f967e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2264,7 +2264,7 @@ def test_grid_generator(): exe.arg_dict['affine'][:] = np.array([[1.0, 0, 0, 0, 1.0, 0]]) exe.forward() exe.backward(mx.nd.array(out_grad)) - assert_almost_equal(exe.grad_dict['affine'].asnumpy(), grad_est + grid_grad_npy, rtol=1e-3) + assert_almost_equal(exe.grad_dict['affine'].asnumpy(), grad_est + grid_grad_npy, rtol=1e-2) # transform_type = warp test_case = [(12,21),(4,3),(6,12)] diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py index d357d2ebf498..9e92300cabd9 100644 --- a/tests/python/unittest/test_rnn.py +++ b/tests/python/unittest/test_rnn.py @@ -3,7 +3,7 @@ def test_rnn(): cell = mx.rnn.RNNCell(100, prefix='rnn_') - outputs, _ = mx.rnn.rnn_unroll(cell, 3, input_prefix='rnn_') + outputs, _ = cell.unroll(3, input_prefix='rnn_') outputs = mx.sym.Group(outputs) assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] @@ -13,7 +13,7 @@ def test_rnn(): def test_lstm(): cell = mx.rnn.LSTMCell(100, prefix='rnn_') - outputs, _ = mx.rnn.rnn_unroll(cell, 3, input_prefix='rnn_') + outputs, _ = cell.unroll(3, input_prefix='rnn_') outputs = mx.sym.Group(outputs) assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] @@ -25,7 +25,7 @@ def test_stack(): cell = mx.rnn.SequentialRNNCell() for i in range(5): cell.add(mx.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i)) - outputs, _ = mx.rnn.rnn_unroll(cell, 3, input_prefix='rnn_') + outputs, _ = cell.unroll(3, input_prefix='rnn_') outputs = mx.sym.Group(outputs) keys = sorted(cell.params._params.keys()) for i in range(5): @@ -38,7 +38,6 @@ def test_stack(): args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] - if __name__ == '__main__': test_rnn() test_lstm()