Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore unknown parameters when loading from model file #6126

Merged
merged 5 commits into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,20 @@ class GBDT : public GBDTBase {
const auto pair = Common::Split(line.c_str(), ":");
if (pair[1] == " ]")
continue;
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
auto iter = param_types.find(param);
if (iter == param_types.end()) {
Log::Warning("Type for param: '%s' not found. This doesn't affect inference.", param.c_str());
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
std::string param_type = iter->second;
if (first) {
first = false;
str_buf << "\"";
} else {
str_buf << ",\"";
}
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
const auto param_type = param_types.at(param);
str_buf << param << "\": ";
if (param_type == "string") {
str_buf << "\"" << value_str << "\"";
Expand Down
19 changes: 17 additions & 2 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,7 @@ def test_feature_name_with_non_ascii():
assert feature_names == gbm2.feature_name()


def test_parameters_are_loaded_from_model_file(tmp_path):
def test_parameters_are_loaded_from_model_file(tmp_path, capsys):
X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))])
y = np.random.rand(100)
ds = lgb.Dataset(X, y)
Expand All @@ -1487,8 +1487,18 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
'num_threads': 1,
}
model_file = tmp_path / 'model.txt'
lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]).save_model(model_file)
orig_bst = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2])
orig_bst.save_model(model_file)
with model_file.open('rt') as f:
model_contents = f.readlines()
params_start = model_contents.index('parameters:\n')
model_contents.insert(params_start + 1, '[max_conflict_rate: 0]\n')
with model_file.open('wt') as f:
f.writelines(model_contents)
bst = lgb.Booster(model_file=model_file)
expected_msg = "[LightGBM] [Warning] Type for param: 'max_conflict_rate' not found. This doesn't affect inference."
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
stdout = capsys.readouterr().out
assert expected_msg in stdout
set_params = {k: bst.params[k] for k in params.keys()}
assert set_params == params
assert bst.params['categorical_feature'] == [1, 2]
Expand All @@ -1498,6 +1508,11 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
assert bst.params == bst2.params

# check inference isn't affected by unknown parameter
orig_preds = orig_bst.predict(X)
preds = bst.predict(X)
np.testing.assert_allclose(preds, orig_preds)


def test_save_load_copy_pickle():
def train_and_predict(init_model=None, return_model=False):
Expand Down