Skip to content

Commit 9a48a40

Browse files
khotilov authored and tqchen committed
Fixes for multiple and default metric (#1239)
* fix multiple evaluation metrics * create DefaultEvalMetric only when really necessary * py test for #1239 * make travis happy
1 parent 9ef8607 commit 9a48a40

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

src/c_api/c_api.cc

+4-1
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,10 @@ class Booster {
3333

3434
inline void SetParam(const std::string& name, const std::string& val) {
3535
auto it = std::find_if(cfg_.begin(), cfg_.end(),
36-
[&name](decltype(*cfg_.begin()) &x) {
36+
[&name, &val](decltype(*cfg_.begin()) &x) {
37+
if (name == "eval_metric") {
38+
return x.first == name && x.second == val;
39+
}
3740
return x.first == name;
3841
});
3942
if (it == cfg_.end()) {

src/learner.cc

+3-6
Original file line number | Diff line number | Diff line change
@@ -256,9 +256,6 @@ class LearnerImpl : public Learner {
256256
attributes_ = std::map<std::string, std::string>(
257257
attr.begin(), attr.end());
258258
}
259-
if (metrics_.size() == 0) {
260-
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
261-
}
262259
this->base_score_ = mparam.base_score;
263260
gbm_->ResetPredBuffer(pred_buffer_size_);
264261
cfg_["num_class"] = common::ToString(mparam.num_class);
@@ -307,6 +304,9 @@ class LearnerImpl : public Learner {
307304
std::ostringstream os;
308305
os << '[' << iter << ']'
309306
<< std::setiosflags(std::ios::fixed);
307+
if (metrics_.size() == 0) {
308+
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
309+
}
310310
for (size_t i = 0; i < data_sets.size(); ++i) {
311311
this->PredictRaw(data_sets[i], &preds_);
312312
obj_->EvalTransform(&preds_);
@@ -445,9 +445,6 @@ class LearnerImpl : public Learner {
445445

446446
// reset the base score
447447
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
448-
if (metrics_.size() == 0) {
449-
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
450-
}
451448

452449
this->base_score_ = mparam.base_score;
453450
gbm_->ResetPredBuffer(pred_buffer_size_);

tests/python/test_basic_models.py

+10
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,16 @@ def neg_evalerror(preds, dtrain):
105105
if int(preds2[i] > 0.5) != labels[i]) / float(len(preds2))
106106
assert err == err2
107107

108+
def test_multi_eval_metric(self):
109+
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
110+
param = {'max_depth': 2, 'eta': 0.2, 'silent': 1, 'objective': 'binary:logistic'}
111+
param['eval_metric'] = ["auc", "logloss", 'error']
112+
evals_result = {}
113+
bst = xgb.train(param, dtrain, 4, watchlist, evals_result=evals_result)
114+
assert isinstance(bst, xgb.core.Booster)
115+
assert len(evals_result['eval']) == 3
116+
assert set(evals_result['eval'].keys()) == {'auc', 'error', 'logloss'}
117+
108118
def test_fpreproc(self):
109119
param = {'max_depth': 2, 'eta': 1, 'silent': 1,
110120
'objective': 'binary:logistic'}

0 commit comments

Comments
 (0)