Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Jun 30, 2015
1 parent 59a7cee commit 480604f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
15 changes: 9 additions & 6 deletions api/python/mxnet/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from .base import lib
from .base import c_array
from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle
from .base import check_call
from .base import check_call, MXNetError
from .narray import NArray, _new_empty_handle

class _Function:
# constants for type masks
NARRAY_ARG_BEFORE_SCALAR = 1
SCALAR_ARG_BEFORE_NARRAY = 2
ACCEPT_EMPTY_MUTATE_TARGET = 3
SCALAR_ARG_BEFORE_NARRAY = 1 << 1
ACCEPT_EMPTY_MUTATE_TARGET = 1 << 2

def __init__(self, handle, name):
"""Initialize the function with handle
Expand All @@ -27,6 +27,7 @@ def __init__(self, handle, name):
the name of the function
"""
self.handle = handle
self.name = name
n_used_vars = mx_uint()
n_scalars = mx_uint()
n_mutate_vars = mx_uint()
Expand Down Expand Up @@ -70,14 +71,16 @@ def __call__(self, *args, **kwargs):
"""
if 'mutate_vars' in kwargs:
mutate_vars = kwargs['mutate_vars']
if isinstance(mutate_vars, NArray):
mutate_vars = (mutate_vars,)
if len(mutate_vars) != self.n_mutate_vars:
raise MXNetError('expect %d mutate_vars in function %s', self.n_mutate_vars, self.name)
raise MXNetError('expect %d mutate_vars in op.%s', self.n_mutate_vars, self.name)
else:
if self.accept_empty_mutate:
mutate_vars = tuple(
NArray(_new_empty_handle()) for i in range(self.n_mutate_vars))
else:
raise MXNetError('mutate_vars argument is required to call this function')
else:
raise MXNetError('mutate_vars argument is required to call op.%s' % self.name)

self.invoke_with_handle_([args[i].handle for i in self.use_vars_range],
[args[i] for i in self.scalar_range],
Expand Down
2 changes: 2 additions & 0 deletions api/python/mxnet/narray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .base import MXNetError
from .context import Context

# op is implicitly imported from .function
# as a singleton of _FunctionRegistry
global op

def _new_empty_handle():
Expand Down
1 change: 0 additions & 1 deletion api/python/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

print(c.context)
print(cc.numpy)

d = c.copyto(mx.Context('cpu', 0))

print(d.numpy)
Expand Down

0 comments on commit 480604f

Please sign in to comment.