diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 6ab9c300d641..d95d59b26260 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1421,7 +1421,7 @@ MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle, * \param curr returns the current status. * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNDArrayIsDeferredComputeEnabled(int *curr); +MXNET_DLL int MXNDArrayIsDeferredCompute(int *curr); /*! * \brief set whether to enable deferred compute mode @@ -1429,7 +1429,7 @@ MXNET_DLL int MXNDArrayIsDeferredComputeEnabled(int *curr); * \param prev returns the previous status before this set. * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNDArraySetDeferredComputeEnabled(int deferred_compute_enabled, int *prev); +MXNET_DLL int MXNDArraySetIsDeferredCompute(int deferred_compute_enabled, int *prev); /*! * \brief Convert the graph constructed during deferred computation mode to a Symbol. diff --git a/python/mxnet/_deferred_compute.py b/python/mxnet/_deferred_compute.py index 18fc7605c6f9..c60fee309d79 100644 --- a/python/mxnet/_deferred_compute.py +++ b/python/mxnet/_deferred_compute.py @@ -27,7 +27,7 @@ def is_deferred_compute(): """Get status of deferred compute mode.""" curr = ctypes.c_bool() - check_call(_LIB.MXNDArrayIsDeferredComputeEnabled(ctypes.byref(curr))) + check_call(_LIB.MXNDArrayIsDeferredCompute(ctypes.byref(curr))) return curr.value def set_deferred_compute(is_deferred_compute): @@ -42,7 +42,7 @@ def set_deferred_compute(is_deferred_compute): Previous deferred compute state. """ prev = ctypes.c_int() - check_call(_LIB.MXNDArraySetDeferredComputeEnabled( + check_call(_LIB.MXNDArraySetIsDeferredCompute( ctypes.c_int(is_deferred_compute), ctypes.byref(prev))) return bool(prev.value) diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index c8b417b6119c..87f78ac5b449 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -437,13 +437,13 @@ int MXCachedOpRegisterOpHook(NDArrayHandle handle, API_END(); } -int MXNDArrayIsDeferredComputeEnabled(int *curr) { +int MXNDArrayIsDeferredCompute(int *curr) { API_BEGIN(); *curr = Imperative::Get()->is_deferred_compute(); API_END(); } -int MXNDArraySetDeferredComputeEnabled(int deferred_compute, int *prev) { +int MXNDArraySetIsDeferredCompute(int deferred_compute, int *prev) { API_BEGIN(); *prev = Imperative::Get()->set_is_deferred_compute(static_cast(deferred_compute)); API_END(); diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 857072f2d3ac..81e12bcf197e 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -309,8 +309,7 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs, // a[:5] = 0 } DispatchMode dispatch_mode = DispatchMode::kUndefined; - Context default_ctx = Context::CPU(); - Context ctx = imperative::GetContext(attrs, inputs, outputs, default_ctx); + Context ctx = imperative::GetContext(attrs, inputs, outputs, Context::CPU()); imperative::SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode); nnvm::ObjectPtr node = nnvm::Node::Create(); @@ -368,8 +367,8 @@ nnvm::Symbol Imperative::GetDeferredComputeSymbol( return array == std::get<0>(input); }; - // std::vector>::iterator input_search = - auto input_search = std::find_if(inputs.begin(), inputs.end(), is_equal); + std::vector>::const_iterator input_search = + std::find_if(inputs.begin(), inputs.end(), is_equal); // Create symbol variable if (input_search != inputs.end()) { NDArray *ndinput; diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index a53bca154f99..001759852fd5 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -2126,10 +2126,7 @@ void NDArray::WaitToRead() const { void NDArray::WaitToWrite() const { if (is_none()) return; Imperative::DCInfo::Compute(*this); - /*! - * Push an empty mutable function to flush all preceding reads to the - * variable. - */ + // Push an empty mutable function to flush all preceding reads to the variable. Engine::Get()->PushAsync( [](RunContext, Engine::CallbackOnComplete on_complete) { on_complete(); }, Context{}, {}, {ptr_->var});