Skip to content

Commit

Permalink
fix base_score for binary classification (#566)
Browse files Browse the repository at this point in the history
Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Jun 23, 2022
1 parent facefb2 commit 24d1d99
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
3 changes: 3 additions & 0 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def convert(scope, operator, container):
attr_pairs['class_ids'] = [0 for v in attr_pairs['class_treeids']]
if js_trees[0].get('leaf', None) == 0:
attr_pairs['base_values'] = [0.5]
elif base_score != 0.5:
cst = - np.log(1 / np.float32(base_score) - 1.)
attr_pairs['base_values'] = [cst]
else:
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
attr_pairs['post_transform'] = "SOFTMAX"
Expand Down
21 changes: 21 additions & 0 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,27 @@ def test_xgb_best_tree_limit(self):
assert_almost_equal(bst_loaded.predict(dtest, output_margin=True), res[1], decimal=5)
assert_almost_equal(bst_loaded.predict(dtest), res[0])

def test_onnxrt_python_xgbclassifier(self):
x = np.random.randn(100, 10).astype(np.float32)
y = ((x.sum(axis=1) + np.random.randn(x.shape[0]) / 50 + 0.5) >= 0).astype(np.int64)
x_train, x_test, y_train, y_test = train_test_split(x, y)
bmy = np.mean(y_train)

for bm, n_est in [(None, 1), (None, 3), (bmy, 1), (bmy, 3)]:
model_skl = XGBClassifier(n_estimators=n_est,
learning_rate=0.01,
subsample=0.5, objective="binary:logistic",
base_score=bm, max_depth=2)
model_skl.fit(x_train, y_train, eval_set=[(x_test, y_test)], verbose=0)

model_onnx_skl = convert_xgboost(
model_skl, initial_types=[('X', FloatTensorType([None, x.shape[1]]))],
target_opset=TARGET_OPSET)
with self.subTest(base_score=bm, n_estimators=n_est):
oinf = InferenceSession(model_onnx_skl.SerializeToString())
res2 = oinf.run(None, {'X': x_test})
assert_almost_equal(model_skl.predict_proba(x_test), res2[1])


if __name__ == "__main__":
unittest.main()

0 comments on commit 24d1d99

Please sign in to comment.