Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Address comments 2
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Feb 15, 2020
1 parent 593248e commit 96c6213
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 14 deletions.
4 changes: 2 additions & 2 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1421,15 +1421,15 @@ MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
* \param curr returns the current status.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayIsDeferredComputeEnabled(int *curr);
MXNET_DLL int MXNDArrayIsDeferredCompute(int *curr);

/*!
* \brief set whether to enable deferred compute mode
* \param deferred_compute_enabled 1 to enable, 0 to disable.
* \param prev returns the previous status before this set.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySetDeferredComputeEnabled(int deferred_compute_enabled, int *prev);
MXNET_DLL int MXNDArraySetIsDeferredCompute(int deferred_compute_enabled, int *prev);

/*!
* \brief Convert the graph constructed during deferred computation mode to a Symbol.
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/_deferred_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def is_deferred_compute():
"""Get status of deferred compute mode."""
curr = ctypes.c_bool()
check_call(_LIB.MXNDArrayIsDeferredComputeEnabled(ctypes.byref(curr)))
check_call(_LIB.MXNDArrayIsDeferredCompute(ctypes.byref(curr)))
return curr.value

def set_deferred_compute(is_deferred_compute):
Expand All @@ -42,7 +42,7 @@ def set_deferred_compute(is_deferred_compute):
Previous deferred compute state.
"""
prev = ctypes.c_int()
check_call(_LIB.MXNDArraySetDeferredComputeEnabled(
check_call(_LIB.MXNDArraySetIsDeferredCompute(
ctypes.c_int(is_deferred_compute), ctypes.byref(prev)))
return bool(prev.value)

Expand Down
4 changes: 2 additions & 2 deletions src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,13 @@ int MXCachedOpRegisterOpHook(NDArrayHandle handle,
API_END();
}

int MXNDArrayIsDeferredComputeEnabled(int *curr) {
int MXNDArrayIsDeferredCompute(int *curr) {
API_BEGIN();
*curr = Imperative::Get()->is_deferred_compute();
API_END();
}

int MXNDArraySetDeferredComputeEnabled(int deferred_compute, int *prev) {
int MXNDArraySetIsDeferredCompute(int deferred_compute, int *prev) {
API_BEGIN();
*prev = Imperative::Get()->set_is_deferred_compute(static_cast<bool>(deferred_compute));
API_END();
Expand Down
7 changes: 3 additions & 4 deletions src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,7 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs,
// a[:5] = 0
}
DispatchMode dispatch_mode = DispatchMode::kUndefined;
Context default_ctx = Context::CPU();
Context ctx = imperative::GetContext(attrs, inputs, outputs, default_ctx);
Context ctx = imperative::GetContext(attrs, inputs, outputs, Context::CPU());
imperative::SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode);

nnvm::ObjectPtr node = nnvm::Node::Create();
Expand Down Expand Up @@ -368,8 +367,8 @@ nnvm::Symbol Imperative::GetDeferredComputeSymbol(
return array == std::get<0>(input);
};

// std::vector<std::pair<NDArray *, std::string>>::iterator input_search =
auto input_search = std::find_if(inputs.begin(), inputs.end(), is_equal);
std::vector<std::pair<NDArray *, std::string>>::const_iterator input_search =
std::find_if(inputs.begin(), inputs.end(), is_equal);
// Create symbol variable
if (input_search != inputs.end()) {
NDArray *ndinput;
Expand Down
5 changes: 1 addition & 4 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2126,10 +2126,7 @@ void NDArray::WaitToRead() const {
void NDArray::WaitToWrite() const {
if (is_none()) return;
Imperative::DCInfo::Compute(*this);
/*!
* Push an empty mutable function to flush all preceding reads to the
* variable.
*/
// Push an empty mutable function to flush all preceding reads to the variable.
Engine::Get()->PushAsync(
[](RunContext, Engine::CallbackOnComplete on_complete) { on_complete(); },
Context{}, {}, {ptr_->var});
Expand Down

0 comments on commit 96c6213

Please sign in to comment.