Skip to content

Commit

Permalink
Fix conversion of Booster for xgboost>=1.6.1 (onnx#567)
Browse files Browse the repository at this point in the history
* fix base_score for binary classification

Signed-off-by: xadupre <[email protected]>

* Update for xgboost 1.6.1

Signed-off-by: xadupre <[email protected]>

* update ci

Signed-off-by: xadupre <[email protected]>

* fix for catboost

Signed-off-by: xadupre <[email protected]>

* fix for svm

Signed-off-by: xadupre <[email protected]>

* lint

Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Jun 30, 2022
1 parent a699341 commit d7db0ff
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 44 deletions.
13 changes: 12 additions & 1 deletion .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@ jobs:
strategy:
matrix:

Python39-1120-RT1110:
Python39-1120-RT1110-xgb161:
python.version: '3.9'
ONNX_PATH: 'onnx==1.12.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4'
ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003'
COREML_PATH: git+https://github.com/apple/[email protected]
xgboost.version: '>=1.6.1'

Python39-1120-RT1110-xgb142:
python.version: '3.9'
ONNX_PATH: 'onnx==1.12.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4'
ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003'
COREML_PATH: git+https://github.com/apple/[email protected]
xgboost.version: '==1.4.2'

Python39-1110-RT1110:
python.version: '3.9'
Expand Down Expand Up @@ -126,7 +134,10 @@ jobs:
export PYTHONPATH=.
python -c "import onnxruntime;print('onnx:',onnx.__version__)"
python -c "import onnxconverter_common;print('cc:',onnxconverter_common.__version__)"
python -c "import onnx;print('onnx:',onnx.__version__)"
python -c "import onnxruntime;print('ort:',onnxruntime.__version__)"
python -c "import xgboost;print('xgboost:',xgboost.__version__)"
python -c "import lightgbm;print('lightgbm:',lightgbm.__version__)"
displayName: 'version'
- script: |
Expand Down
3 changes: 3 additions & 0 deletions .azure-pipelines/win32-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ jobs:
python -m pip install -e .
export PYTHONPATH=.
python -c "import onnxconverter_common;print(onnxconverter_common.__version__)"
python -c "import onnx;print(onnx.__version__)"
python -c "import onnxruntime;print(onnxruntime.__version__)"
python -c "import xgboost;print(xgboost.__version__)"
python -c "import lightgbm;print(lightgbm.__version__)"
displayName: 'version'
- script: |
Expand Down
66 changes: 45 additions & 21 deletions onnxmltools/convert/xgboost/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import json
import re
import pprint
from packaging.version import Version
import numpy as np
from xgboost import XGBRegressor, XGBClassifier
from xgboost import XGBRegressor, XGBClassifier, __version__
from onnxconverter_common.data_types import FloatTensorType
from ..common._container import XGBoostModelContainer
from ..common._topology import Topology
Expand All @@ -27,23 +29,36 @@ def _append_covers(node):


def _get_attributes(booster):
# num_class
state = booster.__getstate__()
bstate = bytes(state['handle'])
reg = re.compile(b'("tree_info":\\[[0-9,]*\\])')
objs = list(set(reg.findall(bstate)))
assert len(objs) == 1, 'Missing required property "tree_info".'
tree_info = json.loads("{{{}}}".format(objs[0].decode('ascii')))['tree_info']
num_class = len(set(tree_info))

atts = booster.attributes()
dp = booster.get_dump(dump_format='json', with_stats=True)
res = [json.loads(d) for d in dp]
trees = len(res)
try:

# num_class
if Version(__version__) < Version('1.5'):
state = booster.__getstate__()
bstate = bytes(state['handle'])
reg = re.compile(b'("tree_info":\\[[0-9,]*\\])')
objs = list(set(reg.findall(bstate)))
if len(objs) != 1:
raise RuntimeError(
"Unable to retrieve the tree coefficients from\n%s"
"" % bstate.decode("ascii", errors="ignore"))
tree_info = json.loads("{{{}}}".format(objs[0].decode('ascii')))['tree_info']
num_class = len(set(tree_info))
trees = len(res)
try:
ntrees = booster.best_ntree_limit
except AttributeError:
ntrees = trees // num_class if num_class > 0 else trees
else:
trees = len(res)
ntrees = booster.best_ntree_limit
except AttributeError:
ntrees = trees // num_class if num_class > 0 else trees
num_class = trees // ntrees
if num_class == 0:
raise RuntimeError(
"Unable to retrieve the number of classes, trees=%d, ntrees=%d." % (
trees, ntrees))

kwargs = atts.copy()
kwargs['feature_names'] = booster.feature_names
kwargs['n_estimators'] = ntrees
Expand All @@ -62,14 +77,23 @@ def _get_attributes(booster):
# classification
kwargs['num_class'] = num_class
if num_class != 1:
reg = re.compile(b'(multi:[a-z]{1,15})')
objs = list(set(reg.findall(bstate)))
if len(objs) == 1:
kwargs["objective"] = objs[0].decode('ascii')
if Version(__version__) < Version('1.5'):
reg = re.compile(b'(multi:[a-z]{1,15})')
objs = list(set(reg.findall(bstate)))
if len(objs) == 1:
kwargs["objective"] = objs[0].decode('ascii')
else:
raise RuntimeError(
"Unable to guess objective in %r (trees=%r, ntrees=%r, num_class=%r)"
"." % (objs, trees, ntrees, kwargs['num_class']))
else:
raise RuntimeError(
"Unable to guess objective in %r (trees=%r, ntrees=%r, num_class=%r)"
"." % (objs, trees, ntrees, kwargs['num_class']))
att = json.loads(booster.save_config())
kwargs["objective"] = att['learner']['objective']['name']
nc = int(att['learner']['learner_model_param']['num_class'])
if nc != num_class:
raise RuntimeError(
"Mismatched value %r != %r from\n%s" % (
nc, num_class, pprint.pformat(att)))
else:
kwargs["objective"] = "binary:logistic"

Expand Down
13 changes: 12 additions & 1 deletion onnxmltools/utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,18 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
prediction = [model.predict(datax)]
elif hasattr(model, "predict_proba"):
# Classifier
prediction = [model.predict(data), model.predict_proba(data)]
if hasattr(model, 'get_params'):
params = model.get_params()
if 'objective' in params:
objective = params['objective']
if objective == "multi:softmax":
prediction = [model.predict(data)]
else:
prediction = [model.predict(data), model.predict_proba(data)]
else:
prediction = [model.predict(data), model.predict_proba(data)]
else:
prediction = [model.predict(data), model.predict_proba(data)]
elif hasattr(model, "predict_with_probabilities"):
# Classifier that returns all in one go
prediction = model.predict_with_probabilities(data)
Expand Down
7 changes: 1 addition & 6 deletions onnxmltools/utils/utils_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
Helpers to test runtimes.
"""
import os
import sys
import glob
import pickle
import packaging.version as pv
Expand Down Expand Up @@ -75,13 +74,9 @@ def compare_backend(backend, test, decimal=5, options=None, verbose=False, conte
if the comparison failed.
"""
if backend == "onnxruntime":
if sys.version_info[0] == 2:
# onnxruntime is not available on Python 2.
return
from .utils_backend_onnxruntime import compare_runtime
return compare_runtime(test, decimal, options, verbose)
else:
raise ValueError("Does not support backend '{0}'.".format(backend))
raise ValueError("Does not support backend '{0}'.".format(backend))


def search_converted_models(root=None):
Expand Down
37 changes: 22 additions & 15 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,10 @@ def test_xgb_classifier_reglog(self):
x_test, xgb, conv_model,
basename="SklearnXGBClassifierRegLog")

def test_xgb_classifier_multi_str_labels(self):
xgb, x_test = _fit_classification_model(
XGBClassifier(n_estimators=4), 5, is_str=True)
conv_model = convert_xgboost(
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))],
target_opset=TARGET_OPSET)
self.assertTrue(conv_model is not None)
dump_data_and_model(
x_test, xgb, conv_model,
basename="SklearnXGBClassifierMultiStrLabels")

def test_xgb_classifier_multi_discrete_int_labels(self):
iris = load_iris()
x = iris.data[:, :2]
y = iris.target
y[y == 0] = 10
y[y == 1] = 20
y[y == 2] = -30
x_train, x_test, y_train, _ = train_test_split(x,
y,
test_size=0.5,
Expand Down Expand Up @@ -241,7 +227,7 @@ def test_xgboost_10(self):
X_test.astype(np.float32), regressor, model_onnx,
basename="XGBBoosterRegBug")

def test_xgboost_classifier_i5450(self):
def test_xgboost_classifier_i5450_softmax(self):
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
Expand All @@ -255,6 +241,26 @@ def test_xgboost_classifier_i5450(self):
predict_list = [1., 20., 466., 0.]
predict_array = np.array(predict_list).reshape((1,-1)).astype(np.float32)
pred_onx = sess.run([label_name], {input_name: predict_array})[0]
bst = clr.get_booster()
bst.dump_model('dump.raw.txt')
dump_data_and_model(
X_test.astype(np.float32) + 1e-5, clr, onx,
basename="XGBClassifierIris-Out0")

def test_xgboost_classifier_i5450(self):
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
clr = XGBClassifier(objective="multi:softprob", max_depth=1, n_estimators=2)
clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40)
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_xgboost(clr, initial_types=initial_type, target_opset=TARGET_OPSET)
sess = InferenceSession(onx.SerializeToString())
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[1].name
predict_list = [1., 20., 466., 0.]
predict_array = np.array(predict_list).reshape((1,-1)).astype(np.float32)
pred_onx = sess.run([label_name], {input_name: predict_array})[0]
pred_xgboost = sessresults=clr.predict_proba(predict_array)
bst = clr.get_booster()
bst.dump_model('dump.raw.txt')
Expand Down Expand Up @@ -364,4 +370,5 @@ def test_onnxrt_python_xgbclassifier(self):


if __name__ == "__main__":
# TestXGBoostModels().test_xgboost_booster_classifier_multiclass_softprob()
unittest.main()

0 comments on commit d7db0ff

Please sign in to comment.