diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index e93f8dac2808..7493f1aa1acc 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2606,12 +2606,12 @@ def astype(self, dtype, copy=True): """ + if dtype is None: + dtype = _np.float32 if not copy and np.dtype(dtype) == self.dtype: return self - res = empty(self.shape, ctx=self.ctx, dtype=dtype) - self.copyto(res) - return res + return op.cast(self, dtype=dtype) def copyto(self, other): """Copies the value of this array to another array. diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py index b0238e369abc..eddf8406fa0d 100644 --- a/python/mxnet/ndarray/sparse.py +++ b/python/mxnet/ndarray/sparse.py @@ -230,6 +230,7 @@ def astype(self, dtype, copy=True): if not copy and np.dtype(dtype) == self.dtype: return self + # Use copyto for casting, as op.cast(self, dtype=dtype) doesn't support sparse stype res = zeros(shape=self.shape, ctx=self.context, dtype=dtype, stype=self.stype) self.copyto(res) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 382dbc0ea472..0409805b25ef 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -1247,12 +1247,12 @@ def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): # raise ValueError('casting must be equal to \'unsafe\'') if not subok: raise ValueError('subok must be equal to True') + if dtype is None: + dtype = _np.float32 if not copy and _np.dtype(dtype) == self.dtype: return self - res = empty(self.shape, dtype=dtype, ctx=self.ctx) - self.copyto(res) - return res + return _npi.cast(self, dtype=dtype) def copyto(self, other): """Copies the value of this array to another array. diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 5efa73294aed..707fcf58b297 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -2177,8 +2177,9 @@ bool CopyToType(const nnvm::NodeAttrs &attrs, std::vector *in_attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); int in_type = in_attrs->at(0); - int out_type = in_type; - TYPE_ASSIGN_CHECK(*out_attrs, 0, out_type); + if (out_attrs->at(0) == -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_type); + } return out_attrs->at(0) != -1; } diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py index 97af9c46ed8f..390b57960e18 100644 --- a/tests/python/unittest/test_deferred_compute.py +++ b/tests/python/unittest/test_deferred_compute.py @@ -260,6 +260,14 @@ def f(a, *, nd): _assert_dc(_dc_simple_setup, f) +def test_dc_astype(): + def f(a, *, nd): + a = a.astype(np.int32) + b = nd.zeros_like(a) + return [a + b] + + _assert_dc(_dc_simple_setup, f) + ############################################################################### # Gluon ###############################################################################