Skip to content

Commit

Permalink
raise an exception when TreeEnsemble request a feature out of boundar…
Browse files Browse the repository at this point in the history
…ies (#12859)

* Catch a potential error when the number of features is lower than the highest feature index referenced in TreeEnsemble

* add unit test

* remove extra spaces
  • Loading branch information
xadupre authored Sep 7, 2022
1 parent f856be1 commit 400195a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class TreeEnsembleCommonAttributes {
AGGREGATE_FUNCTION aggregate_function_;
int64_t n_nodes_;
int64_t max_tree_depth_;
int64_t max_feature_id_;
int64_t n_trees_;
bool same_mode_;
bool has_missing_tracks_;
Expand Down Expand Up @@ -196,12 +197,16 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(int parall
nodes_.resize(n_nodes_);
roots_.clear();
std::unordered_map<TreeNodeElementId, TreeNodeElement<ThresholdType>*, TreeNodeElementId::hash_fn> idi;
max_feature_id_ = 0;

for (i = 0, limit = nodes_treeids.size(); i < limit; ++i) {
TreeNodeElement<ThresholdType>& node = nodes_[i];
node.id.tree_id = static_cast<int>(nodes_treeids[i]);
node.id.node_id = static_cast<int>(nodes_nodeids[i]);
node.feature_id = static_cast<int>(nodes_featureids[i]);
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}
if (nodes_values_as_tensor.empty()) {
node.value = static_cast<ThresholdType>(nodes_values[i]);
} else {
Expand Down Expand Up @@ -344,8 +349,15 @@ template <typename AGG>
void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concurrency::ThreadPool* ttp,
const Tensor* X, Tensor* Z,
Tensor* label, const AGG& agg) const {
if (X->Shape().NumDimensions() > 2) {
ORT_THROW("TreeEnsemble only works on 1D, 2D tensors.");
}
int64_t stride = X->Shape().NumDimensions() == 1 ? X->Shape()[0] : X->Shape()[1];
int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
int64_t C = X->Shape().NumDimensions() == 2 ? X->Shape()[1] : 1;
if (max_feature_id_ >= C) {
ORT_THROW("One path in the graph requests feature ", max_feature_id_, " but input tensor has ", C, " features.");
}
OutputType* z_data = Z->MutableData<OutputType>();

const InputType* x_data = X->Data<InputType>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,51 @@ TEST(MLOpTest, TreeEnsembleClassifier_N1) {
test.Run();
}

// Checks that TreeEnsembleClassifier rejects an input whose feature count is
// too small for the model: the ensemble below references feature index 2 in
// `featureids`, while the input tensor X is declared with shape {1, 2}
// (valid feature indices 0 and 1 only). The run is expected to fail with the
// exact message passed to test.Run().
TEST(MLOpTest, TreeEnsembleClassifierFailShape) {
  OpTester test("TreeEnsembleClassifier", 1, onnxruntime::kMLDomain);

  // Three trees (tree ids 0, 1, 2) described node by node; all node-level
  // vectors have the same length (19 entries).
  std::vector<int64_t> lefts = {1, -1, 3, -1, -1, 1, -1, 3, 4, -1, -1, -1, 1, 2, -1, 4, -1, -1, -1};
  std::vector<int64_t> rights = {2, -1, 4, -1, -1, 2, -1, 6, 5, -1, -1, -1, 6, 3, -1, 5, -1, -1, -1};
  std::vector<int64_t> treeids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2};
  std::vector<int64_t> nodeids = {0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6};
  // The largest feature index used here is 2, one past what X provides.
  std::vector<int64_t> featureids = {2, -2, 0, -2, -2, 0, -2, 2, 1, -2, -2, -2, 0, 2, -2, 1, -2, -2, -2};
  std::vector<float> thresholds = {-172.f, -2.f, 2.5f, -2.f, -2.f, 1.5f, -2.f, -62.5f, 213.09999084f,
                                   -2.f, -2.f, -2.f, 27.5f, -172.f, -2.f, 8.10000038f, -2.f, -2.f, -2.f};
  std::vector<std::string> modes = {"BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "BRANCH_LEQ",
                                    "LEAF", "BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF",
                                    "BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF"};
  std::vector<int64_t> class_treeids = {0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2};
  std::vector<int64_t> class_nodeids = {1, 3, 4, 1, 4, 5, 6, 2, 4, 5, 6};
  std::vector<int64_t> class_classids = {2, 0, 1, 0, 2, 3, 1, 2, 0, 1, 3};
  std::vector<float> class_weights = {1.f, 4.f, 1.f, 2.f, 1.f, 1.f, 2.f, 1.f, 1.f, 1.f, 3.f};
  std::vector<int64_t> classes = {0, 1, 2, 3};

  // A single row with only two features.
  std::vector<float> X = {1.f, 0.0f};
  // Output values are still declared for Y and Z below.
  // NOTE(review): presumably they are not compared because the run is
  // expected to fail - confirm against OpTester.
  std::vector<int64_t> results = {0};
  std::vector<float> scores{7, 0, 0, 0};

  constexpr int N = 1;
  test.AddAttribute("nodes_truenodeids", lefts);
  test.AddAttribute("nodes_falsenodeids", rights);
  test.AddAttribute("nodes_treeids", treeids);
  test.AddAttribute("nodes_nodeids", nodeids);
  test.AddAttribute("nodes_featureids", featureids);
  test.AddAttribute("nodes_values", thresholds);
  test.AddAttribute("nodes_modes", modes);
  test.AddAttribute("class_treeids", class_treeids);
  test.AddAttribute("class_nodeids", class_nodeids);
  test.AddAttribute("class_ids", class_classids);
  test.AddAttribute("class_weights", class_weights);
  test.AddAttribute("classlabels_int64s", classes);

  test.AddInput<float>("X", {N, 2}, X);
  test.AddOutput<int64_t>("Y", {N}, results);
  test.AddOutput<float>("Z", {N, static_cast<int64_t>(classes.size())}, scores);
  test.Run(OpTester::ExpectResult::kExpectFailure,
           "One path in the graph requests feature 2 but input tensor has 2 features.");
}

TEST(MLOpTest, TreeEnsembleClassifierLabels) {
OpTester test("TreeEnsembleClassifier", 1, onnxruntime::kMLDomain);

Expand Down

0 comments on commit 400195a

Please sign in to comment.