This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7a1832c
commit 910738d
Showing
10 changed files
with
256 additions
and
282 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# coding: utf-8 | ||
""" code for executor. """ | ||
from __future__ import absolute_import | ||
|
||
import ctypes | ||
from .base import _LIB | ||
from .base import c_array, c_str, mx_uint, NArrayHandle, ExecutorHandle | ||
from .base import check_call | ||
from .narray import NArray | ||
|
||
class Executor(object): | ||
""" Executor is the actual executing object of MXNet.""" | ||
def __init__(self, handle): | ||
"""Init an executor from handle | ||
Parameters | ||
---------- | ||
handle: ExecutorHandle | ||
ExecutorHandle generated by calling Bind | ||
""" | ||
if not isinstance(handle, ExecutorHandle): | ||
raise TypeError("Handle type error") | ||
self.handle = handle | ||
|
||
def forward(self): | ||
"""Do forward.""" | ||
check_call(_LIB.MXExecutorForward(self.handle)) | ||
|
||
def backward(self, grads): | ||
"""Do backward on heads' gradient. | ||
Parameters | ||
---------- | ||
grads: Array of NArray | ||
heads' gradient | ||
""" | ||
for obj in grads: | ||
if not isinstance(obj, NArray): | ||
raise TypeError("inputs must be NArray") | ||
narray = c_array(NArrayHandle, [item.handle for item in grads]) | ||
check_call(_LIB.MXExecutorBackward(self.handle, len(grads), narray)) | ||
|
||
def heads(self): | ||
"""list all heads' output narray | ||
Returns | ||
------- | ||
A list of narray binded to the heads of executor. | ||
""" | ||
# TODO: think of access, make heads read only. | ||
# (consider support read only NArray(NArrayView)) | ||
# Otherwise some of the internal might depends on out_data | ||
# if user set the content of the head, the backward behavior can be incorrect. | ||
out_size = mx_uint() | ||
handles = ctypes.POINTER(NArrayHandle)() | ||
check_call(_LIB.MXExecutorHeads(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) | ||
return [NArray(NArrayHandle(handles[i])) for i in range(out_size.value)] |
Oops, something went wrong.