From bc423e9a9c834d4e4eafe0dbf627811ca481785e Mon Sep 17 00:00:00 2001 From: Alexander Zai Date: Wed, 15 Aug 2018 12:44:07 -0400 Subject: [PATCH] Module predict API can accept NDArray as input (#12166) * forward and predict can accept nd.array np.array --- python/mxnet/module/base_module.py | 14 ++++++++++++-- tests/python/unittest/test_module.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index 654e41bf3656..08ab8fa89e49 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -22,6 +22,7 @@ import time import logging import warnings +import numpy as np from .. import metric from .. import ndarray @@ -29,7 +30,7 @@ from ..context import cpu from ..model import BatchEndParam from ..initializer import Uniform -from ..io import DataDesc +from ..io import DataDesc, DataIter, DataBatch from ..base import _as_list @@ -333,7 +334,7 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True, Parameters ---------- - eval_data : DataIter + eval_data : DataIter or NDArray or numpy array Evaluation data to run prediction on. num_batch : int Defaults to ``None``, indicates running all the batches in the data iterator. @@ -363,6 +364,15 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True, """ assert self.binded and self.params_initialized + if isinstance(eval_data, (ndarray.NDArray, np.ndarray)): + if isinstance(eval_data, np.ndarray): + eval_data = ndarray.array(eval_data) + self.forward(DataBatch([eval_data])) + return self.get_outputs()[0] + + if not isinstance(eval_data, DataIter): + raise ValueError('eval_data must be of type NDArray or DataIter') + if reset: eval_data.reset() diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index a21527a5a4ad..5e60989489f6 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -772,6 +772,8 @@ def test_forward_reshape(): for_training=False, force_rebind=True) assert mod.predict(pred_dataiter).shape == tuple([10, num_class]) +@with_seed() +def test_forward_types(): #Test forward with other data batch API Batch = namedtuple('Batch', ['data']) data = mx.sym.Variable('data') @@ -786,6 +788,18 @@ def test_forward_reshape(): mod.forward(Batch(data2)) assert mod.get_outputs()[0].shape == (3, 5) + #Test forward with other NDArray and np.ndarray inputs + data = mx.sym.Variable('data') + out = data * 2 + mod = mx.mod.Module(symbol=out, label_names=None) + mod.bind(data_shapes=[('data', (1, 10))]) + mod.init_params() + data1 = mx.nd.ones((1, 10)) + assert mod.predict(data1).shape == (1, 10) + data2 = np.ones((1, 10)) + assert mod.predict(data1).shape == (1, 10) + + if __name__ == '__main__': import nose