Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix astype
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Feb 20, 2020
1 parent 052f864 commit f35c20e
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 8 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2606,12 +2606,12 @@ def astype(self, dtype, copy=True):
<type 'numpy.int32'>
"""

if dtype is None:
dtype = _np.float32
if not copy and np.dtype(dtype) == self.dtype:
return self

res = empty(self.shape, ctx=self.ctx, dtype=dtype)
self.copyto(res)
return res
return op.cast(self, dtype=dtype)

def copyto(self, other):
"""Copies the value of this array to another array.
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/ndarray/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def astype(self, dtype, copy=True):
if not copy and np.dtype(dtype) == self.dtype:
return self

# Use copyto for casting, as op.cast(self, dtype=dtype) doesn't support sparse stype
res = zeros(shape=self.shape, ctx=self.context,
dtype=dtype, stype=self.stype)
self.copyto(res)
Expand Down
6 changes: 3 additions & 3 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,12 +1247,12 @@ def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): #
raise ValueError('casting must be equal to \'unsafe\'')
if not subok:
raise ValueError('subok must be equal to True')
if dtype is None:
dtype = _np.float32
if not copy and _np.dtype(dtype) == self.dtype:
return self

res = empty(self.shape, dtype=dtype, ctx=self.ctx)
self.copyto(res)
return res
return _npi.cast(self, dtype=dtype)

def copyto(self, other):
"""Copies the value of this array to another array.
Expand Down
5 changes: 3 additions & 2 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2177,8 +2177,9 @@ bool CopyToType(const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
int in_type = in_attrs->at(0);
int out_type = in_type;
TYPE_ASSIGN_CHECK(*out_attrs, 0, out_type);
if (out_attrs->at(0) == -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_type);
}
return out_attrs->at(0) != -1;
}

Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_deferred_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ def f(a, *, nd):
_assert_dc(_dc_simple_setup, f)


def test_dc_astype():
def f(a, *, nd):
a = a.astype(np.int32)
b = nd.zeros_like(a)
return [a + b]

_assert_dc(_dc_simple_setup, f)

###############################################################################
# Gluon
###############################################################################
Expand Down

0 comments on commit f35c20e

Please sign in to comment.