Skip to content

Commit

Permalink
Module predict API can accept NDArray as input (apache#12166)
Browse files Browse the repository at this point in the history
* forward and predict can accept nd.array np.array
  • Loading branch information
azai91 authored and nswamy committed Aug 15, 2018
1 parent 80d0ce5 commit bc423e9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
14 changes: 12 additions & 2 deletions python/mxnet/module/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
import time
import logging
import warnings
import numpy as np

from .. import metric
from .. import ndarray

from ..context import cpu
from ..model import BatchEndParam
from ..initializer import Uniform
from ..io import DataDesc
from ..io import DataDesc, DataIter, DataBatch
from ..base import _as_list


Expand Down Expand Up @@ -333,7 +334,7 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
Parameters
----------
eval_data : DataIter
eval_data : DataIter or NDArray or numpy array
Evaluation data to run prediction on.
num_batch : int
Defaults to ``None``, indicates running all the batches in the data iterator.
Expand Down Expand Up @@ -363,6 +364,15 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
"""
assert self.binded and self.params_initialized

if isinstance(eval_data, (ndarray.NDArray, np.ndarray)):
if isinstance(eval_data, np.ndarray):
eval_data = ndarray.array(eval_data)
self.forward(DataBatch([eval_data]))
return self.get_outputs()[0]

if not isinstance(eval_data, DataIter):
raise ValueError('eval_data must be of type NDArray or DataIter')

if reset:
eval_data.reset()

Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,8 @@ def test_forward_reshape():
for_training=False, force_rebind=True)
assert mod.predict(pred_dataiter).shape == tuple([10, num_class])

@with_seed()
def test_forward_types():
#Test forward with other data batch API
Batch = namedtuple('Batch', ['data'])
data = mx.sym.Variable('data')
Expand All @@ -786,6 +788,18 @@ def test_forward_reshape():
mod.forward(Batch(data2))
assert mod.get_outputs()[0].shape == (3, 5)

#Test forward with other NDArray and np.ndarray inputs
data = mx.sym.Variable('data')
out = data * 2
mod = mx.mod.Module(symbol=out, label_names=None)
mod.bind(data_shapes=[('data', (1, 10))])
mod.init_params()
data1 = mx.nd.ones((1, 10))
assert mod.predict(data1).shape == (1, 10)
data2 = np.ones((1, 10))
assert mod.predict(data1).shape == (1, 10)



if __name__ == '__main__':
import nose
Expand Down

0 comments on commit bc423e9

Please sign in to comment.