forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LRN coding style changes (apache#21)
* LRN coding style change * Add const for local variables * Add req for LRN forward * rebase code * align API interface * revert modification in test_executor.
- Loading branch information
PatricZhao
authored and
Olivier
committed
Feb 6, 2018
1 parent
e9fd871
commit 0753b19
Showing
2 changed files
with
75 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,20 +21,21 @@ | |
* Copyright (c) 2015 by Contributors | ||
* \file lrn.cc | ||
* \brief | ||
* \author Bing Xu | ||
* \author Bing Xu, Patric Zhao ([email protected]) | ||
*/ | ||
|
||
#include "./lrn-inl.h" | ||
#include "../operator_common.h" | ||
#if MXNET_USE_MKLDNN == 1 | ||
#include "./mkldnn/mkldnn_lrn-inl.h" | ||
#endif | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
static bool LRNShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape) { | ||
bool LRNShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape) { | ||
using namespace mshadow; | ||
CHECK_EQ(in_shape->size(), 1U) << "Input:[data]"; | ||
const TShape &dshape = in_shape->at(0); | ||
|
@@ -45,13 +46,13 @@ static bool LRNShape(const nnvm::NodeAttrs& attrs, | |
return true; | ||
} | ||
|
||
static inline std::vector<std::string> ListArguments() { | ||
inline std::vector<std::string> ListArguments() { | ||
return {"data"}; | ||
} | ||
|
||
static bool LRNType(const nnvm::NodeAttrs& attrs, | ||
std::vector<int> *in_type, | ||
std::vector<int> *out_type) { | ||
bool LRNType(const nnvm::NodeAttrs& attrs, | ||
std::vector<int> *in_type, | ||
std::vector<int> *out_type) { | ||
CHECK_GE(in_type->size(), 1U); | ||
int dtype = (*in_type)[0]; | ||
CHECK_NE(dtype, -1) << "First input must have specified type"; | ||
|
@@ -80,37 +81,39 @@ struct LRNGrad { | |
} | ||
}; | ||
|
||
inline static bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
*dispatch_mode = DispatchMode::kFCompute; | ||
#if MXNET_USE_MKLDNN == 1 | ||
bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
CHECK(!in_attrs->empty()); | ||
#if MXNET_USE_MKLDNN == 1 | ||
if (dev_mask == mshadow::cpu::kDevMask) { | ||
*dispatch_mode = DispatchMode::kFComputeEx; | ||
storage_type_assign(out_attrs, mxnet::kDefaultStorage, | ||
dispatch_mode, DispatchMode::kFComputeEx); | ||
return true; | ||
} | ||
#endif | ||
for (size_t i = 0; i < out_attrs->size(); i++) | ||
(*out_attrs)[i] = kDefaultStorage; | ||
storage_type_assign(out_attrs, mxnet::kDefaultStorage, | ||
dispatch_mode, DispatchMode::kFCompute); | ||
return true; | ||
} | ||
|
||
inline static bool LRNBackwardInferStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
*dispatch_mode = DispatchMode::kFCompute; | ||
#if MXNET_USE_MKLDNN == 1 | ||
bool LRNBackwardInferStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
CHECK(!in_attrs->empty()); | ||
#if MXNET_USE_MKLDNN == 1 | ||
if (dev_mask == mshadow::cpu::kDevMask) { | ||
*dispatch_mode = DispatchMode::kFComputeEx; | ||
storage_type_assign(out_attrs, mxnet::kDefaultStorage, | ||
dispatch_mode, DispatchMode::kFComputeEx); | ||
return true; | ||
} | ||
#endif | ||
for (size_t i = 0; i < out_attrs->size(); i++) | ||
(*out_attrs)[i] = kDefaultStorage; | ||
storage_type_assign(out_attrs, mxnet::kDefaultStorage, | ||
dispatch_mode, DispatchMode::kFCompute); | ||
return true; | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters