Skip to content

Commit

Permalink
Add GetOpArgName, TryGetTensorValue, IsConstantTensor in pir parser; …
Browse files Browse the repository at this point in the history
…Adjust the return value of GetOpInputOutputName2Idx; Adjust GetAttr, HasInput, HasOutput, IsConstantInput in mapper.h
  • Loading branch information
0x3878f committed Oct 9, 2024
1 parent 50b4a47 commit c09c2ff
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 34 deletions.
70 changes: 50 additions & 20 deletions paddle2onnx/mapper/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,20 @@ class Mapper {
if (opset_version <= helper_->GetOpsetVersion()) {
v = false;
}
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
std::string output_name = "";
if (op.outputs(0).arguments_size() > 0) {
output_name = op.outputs(0).arguments(0);
std::string op_type = "";
if(in_pir_mode) {
auto &op = pir_parser_-> global_blocks_ops[pir_op_idx_];
output_name = GetOutput(0)[0].name;
op_type = op->name();
}
else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
if (op.outputs(0).arguments_size() > 0) {
output_name = op.outputs(0).arguments(0);
}
op_type = op.type();
}
std::string op_type = op.type();
std::string prefix = "[Paddle2ONNX] [" + op_type + ": " + output_name + "]";
return P2OLogger(v, prefix);
}
Expand Down Expand Up @@ -183,15 +191,13 @@ class Mapper {

// Returns true if the current op exposes an input slot named `name`.
// Works for both the PIR program and the legacy (old IR) program.
bool HasInput(const std::string &name) const {
  if (in_pir_mode) {
    // GetOpInputOutputName2Idx returns -1 when `name` is absent from the
    // op's yaml info, so a non-negative index means the input exists.
    return pir_parser_->GetOpInputOutputName2Idx(pir_op_idx_, name, true) != -1;
  }
  return parser_->OpHasInput(block_idx_, op_idx_, name);
}
// Returns true if the current op exposes an output slot named `name`.
// Mirrors HasInput but queries the output side of the op.
bool HasOutput(const std::string &name) const {
  if (in_pir_mode) {
    // -1 is the "not found" sentinel from GetOpInputOutputName2Idx.
    return pir_parser_->GetOpInputOutputName2Idx(pir_op_idx_, name, false) != -1;
  }
  return parser_->OpHasOutput(block_idx_, op_idx_, name);
}
Expand Down Expand Up @@ -231,10 +237,13 @@ class Mapper {
return parser_->GetOpAttrVar(block_idx_, op_idx_, name);
}

/*
* todo(wangmingkai02): add GetInputAttrVar function.
std::vector<int64_t> GetInputAttrVar(const std::string &input_name, const std::string &attr_name) const {
int32_t value_idx = pir_parser_->GetOpInputOutputName2Idx(pir_op_idx_, input_name, true);
return pir_parser_->GetOpAttrVar(pir_op_idx_, value_idx, attr_name);
}
*/


bool HasAttr(const std::string &name) const {
Expand All @@ -250,7 +259,7 @@ class Mapper {
void GetAttr(const std::string &name, int64_t *val) {
if (in_pir_mode) {
auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
pir_parser_->GetOpAttr(op, name, val);
pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
} else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
Expand All @@ -259,7 +268,7 @@ class Mapper {
void GetAttr(const std::string &name, float *val) {
if (in_pir_mode) {
auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
pir_parser_->GetOpAttr(op, name, val);
pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
} else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
Expand All @@ -268,7 +277,7 @@ class Mapper {
void GetAttr(const std::string &name, double *val) {
if (in_pir_mode) {
auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
pir_parser_->GetOpAttr(op, name, val);
pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
} else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
Expand All @@ -277,7 +286,7 @@ class Mapper {
void GetAttr(const std::string &name, bool *val) {
if (in_pir_mode) {
auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
pir_parser_->GetOpAttr(op, name, val);
pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
} else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
Expand All @@ -286,7 +295,7 @@ class Mapper {
void GetAttr(const std::string &name, std::string *val) {
if (in_pir_mode) {
auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
pir_parser_->GetOpAttr(op, name, val);
pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
} else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
Expand All @@ -295,7 +304,7 @@ class Mapper {
void GetAttr(const std::string &name, std::vector<int64_t> *val) {
if (in_pir_mode) {
auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
pir_parser_->GetOpAttr(op, name, val);
pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
} else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
Expand All @@ -304,7 +313,7 @@ class Mapper {
void GetAttr(const std::string &name, std::vector<float> *val) {
if (in_pir_mode) {
auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
pir_parser_->GetOpAttr(op, name, val);
pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
} else {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
Expand All @@ -313,16 +322,22 @@ class Mapper {
// Reads attribute `name` of the current op into `val` (vector<double>
// overload). In PIR mode the attribute name is first mapped through
// GetOpArgName, since yaml arg names can differ from legacy attr names.
void GetAttr(const std::string &name, std::vector<double> *val) {
  if (in_pir_mode) {
    auto &op = pir_parser_->global_blocks_ops[pir_op_idx_];
    pir_parser_->GetOpAttr(op, pir_parser_->GetOpArgName(pir_op_idx_, name), val);
  } else {
    auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
    parser_->GetOpAttr(op, name, val);
  }
}

// Returns true when the tensor feeding input `input_key` is a constant
// (e.g. a weight/parameter) rather than a runtime-computed value.
bool IsConstantInput(const std::string &input_key) const {
  if (in_pir_mode) {
    int32_t value_idx =
        pir_parser_->GetOpInputOutputName2Idx(pir_op_idx_, input_key, true);
    return pir_parser_->IsConstantTensor(pir_op_idx_, value_idx);
  }
  // Legacy IR: resolve the tensor name first, then ask the parser.
  auto input_info = GetInput(input_key);
  return parser_->IsConstantTensor(block_idx_, input_info[0].name);
}

bool IsConstant(const TensorInfo &info) const {
Expand All @@ -331,8 +346,23 @@ class Mapper {

// Attempts to read the statically-known value of the tensor feeding input
// `input_key` into `data`; delegates the success/failure result to the
// underlying parser (vector overload, works in both IR modes).
template <typename T>
bool TryGetInputValue(const std::string &input_key, std::vector<T> *data) {
  if (in_pir_mode) {
    return pir_parser_->TryGetTensorValue(
        pir_op_idx_,
        pir_parser_->GetOpInputOutputName2Idx(pir_op_idx_, input_key, true),
        data);
  }
  auto input_info = GetInput(input_key);
  return parser_->TryGetTensorValue(block_idx_, input_info[0].name, data);
}

// Scalar overload: reads a single constant value feeding input `input_key`.
// Only supported in PIR mode; the legacy IR path has no scalar extraction.
template <typename T>
bool TryGetInputValue(const std::string &input_key, T *data) {
  if (in_pir_mode) {
    return pir_parser_->TryGetTensorValue(
        pir_op_idx_,
        pir_parser_->GetOpInputOutputName2Idx(pir_op_idx_, input_key, true),
        data);
  }
  Assert(false, "Not support in old IR.");
  // Unreachable when Assert aborts, but keeps every control path of this
  // non-void function returning a value (the original fell off the end).
  return false;
}

template <typename T>
Expand Down
68 changes: 56 additions & 12 deletions paddle2onnx/parser/pir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ namespace paddle2onnx {
}
}

int32_t PaddlePirParser::GetOpInputOutputName2Idx(int64_t op_id, std::string name, bool is_input) const {
auto& op = global_blocks_ops[op_id];
std::string PaddlePirParser::GetOpArgName(int64_t op_id, std::string name) const {
auto& op = global_blocks_ops[op_id];
pir::IrContext* ctx = pir::IrContext::Instance();
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
Expand All @@ -210,20 +210,40 @@ namespace paddle2onnx {
}
}
}
return name;
}

// Maps an input/output arg name of op `op_id` to its positional index in
// the op's yaml info. Returns -1 (after logging) when the name is not
// found, so callers such as HasInput/HasOutput can probe for existence
// without aborting the conversion.
int32_t PaddlePirParser::GetOpInputOutputName2Idx(int64_t op_id, std::string name, bool is_input) const {
  auto& op = global_blocks_ops[op_id];
  pir::IrContext* ctx = pir::IrContext::Instance();
  std::string op_name = op->name();
  // Ops rewritten by passes keep their original name in the "op_name"
  // attribute; prefer it so the yaml lookup hits the registered op.
  if (op->attributes().count("op_name")) {
    op_name = op->attributes()
                  .at("op_name")
                  .dyn_cast<pir::StrAttribute>()
                  .AsString();
  }
  pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
  paddle::dialect::OpYamlInfoParser yaml_parser(
      op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()->get_op_info_(op_name),
      // paddle::dialect::IsLegacyOp(op_name));
      false);
  // Normalize the queried name to the yaml arg name first.
  name = GetOpArgName(op_id, name);
  bool exist = is_input ? yaml_parser.InputName2Id().count(name)
                        : yaml_parser.OutputName2Id().count(name);
  if (!exist) {
    // Intentionally a soft failure (log + sentinel) instead of
    // PADDLE_ENFORCE: a missing name is a legal probe result.
    P2OLogger() << "Cannot find input/output name '" << name
                << "' in op yaml info of " << op_name << std::endl;
    return -1;
  }
  return is_input ? yaml_parser.InputName2Id().at(name)
                  : yaml_parser.OutputName2Id().at(name);
}

bool PaddlePirParser::LoadProgram(const std::string& model) {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand Down Expand Up @@ -706,6 +726,11 @@ void PaddlePirParser::GetOpAttr(const pir::Operation* op,

std::vector<TensorInfo> PaddlePirParser::GetOpInput(
int64_t op_id, int64_t input_idx) const {
PADDLE_ENFORCE_GT(
input_idx,
-1,
common::errors::InvalidArgument(
"input_idx should be greater than -1 in GetOpInput."));
pir::Operation* op = global_blocks_ops[op_id];
PADDLE_ENFORCE_LT(input_idx, op->num_operands(),
common::errors::InvalidArgument(
Expand All @@ -716,14 +741,22 @@ std::vector<TensorInfo> PaddlePirParser::GetOpInput(

// Collects tensor metadata for output slot `output_idx` of the op at
// `op_id` in the global block.
std::vector<TensorInfo> PaddlePirParser::GetOpOutput(
    int64_t op_id, int64_t output_idx) const {
  // -1 is the "not found" sentinel from GetOpInputOutputName2Idx; reject
  // it before dereferencing the op.
  PADDLE_ENFORCE_GT(
      output_idx,
      -1,
      common::errors::InvalidArgument(
          "output_idx should be greater than -1 in GetOpOutput."));
  pir::Operation* op = global_blocks_ops[op_id];
  // Single bounds check (the duplicated check was a merge artifact).
  PADDLE_ENFORCE_LT(
      output_idx,
      op->num_results(),
      common::errors::InvalidArgument(
          "output index %d is out of range, the output size is %d",
          output_idx, op->num_results()));
  return GetTensorInfo(op->result(output_idx));
}

/**
std::vector<int64_t> PaddlePirParser::GetOpAttrVar(int64_t op_id, int64_t input_idx, const std::string &name) const {
pir::Operation* op = global_blocks_ops[op_id]->operand(input_idx).source().defining_op();
std::vector<int64_t> result;
Expand All @@ -734,4 +767,15 @@ std::vector<TensorInfo> PaddlePirParser::GetOpOutput(
}
return result;
}
*/

bool PaddlePirParser::IsConstantTensor(int64_t op_id, int64_t input_idx) const {
PADDLE_ENFORCE_GT(
input_idx,
-1,
common::errors::InvalidArgument(
"input_idx should be greater than -1 in IsConstantTensor."));
// todo(wangmingkai02): need to check
return global_blocks_ops[op_id]->operand(input_idx).source().defining_op()->num_operands() == 0;
}
} // namespace paddle2onnx
66 changes: 64 additions & 2 deletions paddle2onnx/parser/pir_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,71 @@ class PaddlePirParser {
bool OpHasAttr(pir::Operation *op, const std::string &name) const;
std::vector<TensorInfo> GetOpInput(int64_t op_id, int64_t input_idx) const;
std::vector<TensorInfo> GetOpOutput(int64_t op_id, int64_t output_idx) const;
std::vector<int64_t> GetOpAttrVar(int64_t op_id, int64_t input_idx, const std::string &name) const;
std::string GetOpArgName(int64_t op_id, std::string name) const;
int32_t GetOpInputOutputName2Idx(int64_t op_id, std::string name, bool is_input) const;

bool IsConstantTensor(int64_t op_id, int64_t input_idx) const;

// Tries to materialize the constant value of input `input_idx` of op
// `op_id` into `data` (vector overload). First consults the `params`
// map (model weights); otherwise falls back to reading the defining
// op's "value" attribute. Always returns true on the paths it handles;
// unsupported dtypes hit Assert.
template<typename T>
bool TryGetTensorValue(int64_t op_id, int64_t input_idx, std::vector<T> *data) const {
// -1 is the "not found" sentinel from GetOpInputOutputName2Idx.
PADDLE_ENFORCE_GT(
input_idx,
-1,
common::errors::InvalidArgument(
"input_idx should be greater than -1 in TryGetTensorValue."));
// NOTE(review): assumes GetTensorInfo returns at least one entry for
// this value — [0] would be out of range otherwise; confirm.
TensorInfo tensor_info = GetTensorInfo(global_blocks_ops[op_id]->operand(input_idx).source())[0];
// Fast path: the tensor is a stored weight/parameter.
auto iter = params.find(tensor_info.name);
if (iter != params.end()) {
(iter->second).get(data);
return true;
}
// Slow path: read the "value" attribute of the producing op,
// dispatching on the tensor's dtype.
pir::Operation* op = global_blocks_ops[op_id]->operand(input_idx).source().defining_op();
int32_t dtype = tensor_info.dtype;
if (dtype == P2ODataType::INT64 || dtype == P2ODataType::INT32) {
// Both integer widths are fetched as int64 then narrowed via assign.
std::vector<int64_t> value;
GetOpAttr(op, "value", &value);
data->assign(value.begin(), value.end());
} else if (dtype == P2ODataType::FP32) {
std::vector<float> value;
GetOpAttr(op, "value", &value);
data->assign(value.begin(), value.end());
} else if (dtype == P2ODataType::FP64) {
std::vector<double> value;
GetOpAttr(op, "value", &value);
data->assign(value.begin(), value.end());
} else {
Assert(false, "Only support int32/int64/float32/float64 data type now.");
}
return true;
}

// Scalar overload: extracts a single constant produced as input
// `input_idx` of op `op_id` by reading the defining op's "value"
// attribute. Aborts (PADDLE_ENFORCE / Assert) if the attribute is
// missing or of an unsupported kind; returns true otherwise.
template<typename T>
bool TryGetTensorValue(int64_t op_id, int64_t input_idx, T *data) const {
  // -1 is the "not found" sentinel from GetOpInputOutputName2Idx.
  PADDLE_ENFORCE_GT(
      input_idx,
      -1,
      common::errors::InvalidArgument(
          "input_idx should be greater than -1 in TryGetTensorValue."));
  // (The vector overload's TensorInfo lookup was dropped here: its
  // result was never used in this overload.)
  pir::Operation* op =
      global_blocks_ops[op_id]->operand(input_idx).source().defining_op();
  PADDLE_ENFORCE_EQ(
      op->HasAttribute("value"),
      true,
      common::errors::InvalidArgument(
          "Cannot found attribute 'value' in op %s", op->name()));
  auto value = op->attribute("value");
  // Dispatch on the attribute kind, not the tensor dtype.
  if (value.isa<pir::Int32Attribute>()) {
    *data = value.dyn_cast<::pir::Int32Attribute>().data();
  } else if (value.isa<pir::Int64Attribute>()) {
    *data = value.dyn_cast<::pir::Int64Attribute>().data();
  } else if (value.isa<pir::FloatAttribute>()) {
    *data = value.dyn_cast<::pir::FloatAttribute>().data();
  } else if (value.isa<pir::DoubleAttribute>()) {
    *data = value.dyn_cast<::pir::DoubleAttribute>().data();
  } else {
    Assert(false, "Only support int32/int64/float32/float64 data type now.");
  }
  return true;
}

private:
bool IsAttrVar(const pir::Operation *op, const int64_t &attr_id) const;
Expand Down

0 comments on commit c09c2ff

Please sign in to comment.