Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[1.x] Backport #19103 #19117

Merged
merged 3 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/gemm_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ class MyStatefulGemm : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
const MXContext& ctx,
const std::vector<std::vector<unsigned int> >& in_shapes,
const std::vector<int> in_types,
CustomStatefulOp** op_inst) {
// testing passing of keyword arguments
int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
Expand Down
6 changes: 6 additions & 0 deletions example/extensions/lib_custom_op/relu_lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,18 @@ class MyStatefulReluGPU : public CustomStatefulOp {
};

MXReturnValue createOpStateCPU(const std::unordered_map<std::string, std::string>& attrs,
const MXContext& ctx,
const std::vector<std::vector<unsigned int> >& in_shapes,
const std::vector<int> in_types,
CustomStatefulOp** op_inst) {
*op_inst = new MyStatefulReluCPU(attrs);
return MX_SUCCESS;
}

MXReturnValue createOpStateGPU(const std::unordered_map<std::string, std::string>& attrs,
const MXContext& ctx,
const std::vector<std::vector<unsigned int> >& in_shapes,
const std::vector<int> in_types,
CustomStatefulOp** op_inst) {
*op_inst = new MyStatefulReluGPU(attrs);
return MX_SUCCESS;
Expand Down
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/transposecsr_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ class MyStatefulTransposeCSR : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
const MXContext& ctx,
const std::vector<std::vector<unsigned int> >& in_shapes,
const std::vector<int> in_types,
CustomStatefulOp** op_inst) {
// testing passing of keyword arguments
int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
Expand Down
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/transposerowsp_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ class MyStatefulTransposeRowSP : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
const MXContext& ctx,
const std::vector<std::vector<unsigned int> >& in_shapes,
const std::vector<int> in_types,
CustomStatefulOp** op_inst) {
// testing passing of keyword arguments
int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
Expand Down
3 changes: 3 additions & 0 deletions example/extensions/lib_subgraph/subgraph_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ class MyStatefulOp : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
const MXContext& ctx,
const std::vector<std::vector<unsigned int> >& in_shapes,
const std::vector<int> in_types,
CustomStatefulOp** op_inst) {
std::string serialized_subgraph = "[empty]";
// MXNet subgraph is stored as Symbol in operator node attrs subgraphs field
Expand Down
15 changes: 10 additions & 5 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
#endif

/* Make sure to update the version number everytime you make changes */
#define MX_LIBRARY_VERSION 9
#define MX_LIBRARY_VERSION 10

/*!
* \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
Expand Down Expand Up @@ -732,6 +732,9 @@ typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map<std::string,
std::vector<int>* input_indices);
typedef MXReturnValue (*createOpState_t)(const std::unordered_map<std::string,
std::string>& attributes,
const MXContext& ctx,
const std::vector<std::vector<unsigned int> >& in_shapes,
const std::vector<int> in_types,
Comment on lines +735 to +737
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an API change from 1.x

Copy link
Contributor Author

@samskalicky samskalicky Sep 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @szha, youre right. The Extensions API is not backwards compatible. Every change increments the version number:
https://github.com/apache/incubator-mxnet/blob/f1acda73e0b2ca5f1b6ff89f25c17a838819be82/include/mxnet/lib_api.h#L56
And the version number is required to match:
https://github.com/apache/incubator-mxnet/blob/f1acda73e0b2ca5f1b6ff89f25c17a838819be82/src/c_api/c_api.cc#L1500-L1503
Each release of MXNet will require libraries to have the same version of lib_api.h. The versioning of the Extensions APIs is different than MXNet's C API.

CustomStatefulOp**);

/*!
Expand Down Expand Up @@ -1000,8 +1003,9 @@ typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* ke

#define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys,
const char* const* vals, int num,
void** state_op);
const char* const* vals, int num, const char* dev_type,
int dev_id, unsigned int** inshapes, int* indims,
int num_in, const int* intypes, void** state_op);

#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
Expand Down Expand Up @@ -1190,8 +1194,9 @@ extern "C" {

/*! \brief returns status of calling createStatefulOp function for operator from library */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
const char* const* vals, int num,
void** state_op);
const char* const* vals, int num, const char* dev_type,
int dev_id, unsigned int** inshapes, int* indims,
int num_in, const int* intypes, void** state_op);

/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
Expand Down
30 changes: 28 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,28 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
attr_vals.push_back(kv.second.c_str());
}

// string repr of supported context for custom library, currently only "cpu" and "gpu"
const char* ctx_str = ctx.dev_mask() == Context::kCPU ? "cpu" : "gpu";

std::vector<uint32_t*> inshapes(in_shapes.size());
std::vector<int> indims(in_shapes.size());

// determine amount of memory needed to store all the input shapes
size_t buff_size = 0;
for (size_t i = 0; i < in_shapes.size(); ++i)
buff_size += in_shapes[i].ndim();

// copy input shapes to raw memory layout
std::vector<uint32_t> inbuff(buff_size);
uint32_t *ptr = inbuff.data();
for (size_t i = 0; i < in_shapes.size(); ++i) {
inshapes[i] = ptr;
indims[i] = in_shapes[i].ndim();
for (int j = 0; j < in_shapes[i].ndim(); ++j, ++ptr) {
*ptr = static_cast<uint32_t>(in_shapes[i][j]);
}
}

// convert subgraph symbol from node attributes to char*
std::string subgraph_json;
if (!attrs.subgraphs.empty()) {
Expand All @@ -1110,15 +1132,19 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
CHECK(createop_map.count("cpu") > 0)
<< "CPU CreateOpState not implemented for '" << name_str << "'";
int retval = callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(),
attr_keys.size(), &state_op_inst);
attr_keys.size(), ctx_str, ctx.real_dev_id(),
inshapes.data(), indims.data(),
in_shapes.size(), in_types.data(), &state_op_inst);
std::string msgs = getExtensionMsgs(msgSize, msgGet);
CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'"
<< msgs;
} else if (ctx.dev_mask() == Context::kGPU) {
CHECK(createop_map.count("gpu") > 0)
<< "GPU CreateOpState not implemented for '" << name_str << "'";
int retval = callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(),
attr_keys.size(), &state_op_inst);
attr_keys.size(), ctx_str, ctx.real_dev_id(),
inshapes.data(), indims.data(),
in_shapes.size(), in_types.data(), &state_op_inst);
std::string msgs = getExtensionMsgs(msgSize, msgGet);
CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'"
<< msgs;
Expand Down
23 changes: 20 additions & 3 deletions src/lib_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1210,19 +1210,36 @@ MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* co

/*! \brief returns status of calling createStatefulOp function for operator from library */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
const char* const* vals, int num,
void** state_op) {
const char* const* vals, int num, const char* dev_type,
int dev_id, unsigned int** inshapes, int* indims,
int num_in, const int* intypes, void** state_op) {
// create map of attributes from list
std::unordered_map<std::string, std::string> attrs;
for (int i = 0; i < num; i++) {
attrs[std::string(keys[i])] = std::string(vals[i]);
}

mxnet::ext::MXContext ctx(dev_type, dev_id);

// create a vector of shapes for inputs
std::vector<std::vector<unsigned int> > in_shapes(num_in);
for (int i = 0; i < num_in; i++) {
for (int j = 0; j < indims[i]; j++) {
in_shapes[i].push_back(inshapes[i][j]);
}
}

// create a vector of types for inputs
std::vector<int> in_types(num_in);
for (int i = 0; i < num_in; i++) {
in_types[i] = intypes[i];
}

// void pointer to hold custom state op instance created in custom library
// eventually state_op pointer is populated by instance from custom library
mxnet::ext::CustomStatefulOp** op_ptr =
reinterpret_cast<mxnet::ext::CustomStatefulOp**>(state_op);
return create_op(attrs, op_ptr);
return create_op(attrs, ctx, in_shapes, in_types, op_ptr);
}

/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
Expand Down