From 20ef8ae45e892b889dc2f80759c6d17b57020c71 Mon Sep 17 00:00:00 2001 From: bbayukari <40588568+bbayukari@users.noreply.github.com> Date: Sun, 18 Feb 2024 19:50:49 +0800 Subject: [PATCH] feat(numeric_solver): add a new numeric algo 'BFGS' (#77) * feat(numeric_solver): add a new numeric algo 'BFGS' use scipy.optimize.minimize with method 'BFGS' --- pytest/test_args.py | 19 +++++++++++++++---- pytest/test_workflow.py | 4 ++-- setup.py | 22 +++++++++++++--------- skscope/numeric_solver.py | 22 ++++++++++++++++++++++ 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/pytest/test_args.py b/pytest/test_args.py index 4211f95..45cf03f 100644 --- a/pytest/test_args.py +++ b/pytest/test_args.py @@ -20,15 +20,26 @@ solvers_ids = ("scope", "Base") # , "GraHTP", "GraSP", "IHT") +@pytest.mark.parametrize("model", models, ids=models_ids) +@pytest.mark.parametrize("solver_creator", solvers, ids=solvers_ids) +def test_numeric_solver(model, solver_creator): + from skscope.numeric_solver import convex_solver_BFGS + + solver = solver_creator( + model["n_features"], model["n_informative"], numeric_solver=convex_solver_BFGS + ) + solver.solve(model["loss"], jit=True) + + assert set(model["support_set"]) == set(solver.get_support()) + + @pytest.mark.parametrize("model", models, ids=models_ids) @pytest.mark.parametrize("solver_creator", solvers, ids=solvers_ids) def test_init_support_set(model, solver_creator): solver = solver_creator(model["n_features"], model["n_informative"]) solver.solve(model["loss"], init_support_set=[0, 1, 2], jit=True) - solver.get_result() - solver.get_estimated_params() - solver.get_support() - assert set(model["support_set"]) == set(solver.support_set) + + assert set(model["support_set"]) == set(solver.get_support()) @pytest.mark.parametrize("model", models, ids=models_ids) diff --git a/pytest/test_workflow.py b/pytest/test_workflow.py index 82f4468..97177ff 100644 --- a/pytest/test_workflow.py +++ b/pytest/test_workflow.py @@ -55,8 +55,8 @@ def test_config(model, solver_creator): solver = solver_creator(model["n_features"], model["n_informative"]) solver.set_config(**solver.get_config()) solver.solve(model["loss"], jit=True) - - assert set(model["support_set"]) == set(solver.support_set) + res = solver.get_result() + assert set(model["support_set"]) == set(res["support_set"]) @pytest.mark.parametrize("model", models, ids=models_ids) diff --git a/setup.py b/setup.py index ca620a3..29fb6e6 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ "win-arm64": "ARM64", } + def get_info(): # get information from `__init__.py` labels = ["__version__", "__author__"] @@ -39,6 +40,7 @@ def __init__(self, name, sourcedir=""): self.sourcedir = os.path.abspath(sourcedir) self.parallel = 4 + class CMakeBuild(build_ext): def build_extension(self, ext): extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) @@ -59,7 +61,7 @@ def build_extension(self, ext): cmake_args = [ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}" + f"-DCMAKE_BUILD_TYPE={cfg}", ] build_args = [] # Adding CMake arguments set as environment variable @@ -128,27 +130,29 @@ def build_extension(self, ext): subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) -with open(os.path.join(CURRENT_DIR, 'README.md'), encoding='utf-8') as f: + +with open(os.path.join(CURRENT_DIR, "README.md"), encoding="utf-8") as f: long_description = f.read() package_info = get_info() setup( - name='skscope', - version=package_info['__version__'], - author=package_info['__author__'], + name="skscope", + version=package_info["__version__"], + author=package_info["__author__"], author_email="homura@mail.ustc.edu.cn", maintainer="Zezhi Wang", maintainer_email="homura@mail.ustc.edu.cn", packages=find_packages(), - description="Sparsity-Constraint OPtimization via itErative-algorithm", + description="Sparsity-Constraint OPtimization via itErative-algorithm", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", install_requires=[ "numpy", "scikit-learn>=1.2.2", "jax[cpu]", "nlopt", + "scipy", ], license="MIT", url="https://skscope.readthedocs.io", @@ -176,7 +180,7 @@ def build_extension(self, ext): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ], - python_requires='>=3.8', + python_requires=">=3.8", ext_modules=[CMakeExtension("skscope._scope")], - cmdclass={"build_ext": CMakeBuild} + cmdclass={"build_ext": CMakeBuild}, ) diff --git a/skscope/numeric_solver.py b/skscope/numeric_solver.py index 6c1ca8c..d8c224f 100644 --- a/skscope/numeric_solver.py +++ b/skscope/numeric_solver.py @@ -7,6 +7,7 @@ import numpy as np import math import nlopt +from scipy.optimize import minimize def convex_solver_nlopt( @@ -69,3 +70,24 @@ def cache_opt_fn(x, grad): except RuntimeError: init_params[optim_variable_set] = best_params return best_loss, init_params + + +def convex_solver_BFGS( + objective_func, + value_and_grad, + init_params, + optim_variable_set, + data, +): + def fun(x): + init_params[optim_variable_set] = x + return objective_func(init_params, data) + + def jac(x): + init_params[optim_variable_set] = x + _, grad = value_and_grad(init_params, data) + return grad[optim_variable_set] + + res = minimize(fun, init_params[optim_variable_set], method="BFGS", jac=jac) + init_params[optim_variable_set] = res.x + return res.fun, init_params