Skip to content

Commit a6e9c7f

Browse files
committed
feat(numeric_solver): add a new numeric algo 'BFGS'
use scipy.optimize.minimize with method 'BFGS'
1 parent 9ad8d30 commit a6e9c7f

File tree

3 files changed

+50
-13
lines changed

3 files changed

+50
-13
lines changed

pytest/test_args.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,26 @@
2020
solvers_ids = ("scope", "Base") # , "GraHTP", "GraSP", "IHT")
2121

2222

23+
@pytest.mark.parametrize("model", models, ids=models_ids)
24+
@pytest.mark.parametrize("solver_creator", solvers, ids=solvers_ids)
25+
def test_numeric_solver(model, solver_creator):
26+
from skscope.numeric_solver import convex_solver_BFGS
27+
28+
solver = solver_creator(
29+
model["n_features"], model["n_informative"], numeric_solver=convex_solver_BFGS
30+
)
31+
solver.solve(model["loss"], jit=True)
32+
33+
assert set(model["support_set"]) == set(solver.get_support())
34+
35+
2336
@pytest.mark.parametrize("model", models, ids=models_ids)
2437
@pytest.mark.parametrize("solver_creator", solvers, ids=solvers_ids)
2538
def test_init_support_set(model, solver_creator):
2639
solver = solver_creator(model["n_features"], model["n_informative"])
2740
solver.solve(model["loss"], init_support_set=[0, 1, 2], jit=True)
28-
solver.get_result()
29-
solver.get_estimated_params()
30-
solver.get_support()
31-
assert set(model["support_set"]) == set(solver.support_set)
41+
42+
assert set(model["support_set"]) == set(solver.get_support())
3243

3344

3445
@pytest.mark.parametrize("model", models, ids=models_ids)

setup.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"win-arm64": "ARM64",
1919
}
2020

21+
2122
def get_info():
2223
# get information from `__init__.py`
2324
labels = ["__version__", "__author__"]
@@ -39,6 +40,7 @@ def __init__(self, name, sourcedir=""):
3940
self.sourcedir = os.path.abspath(sourcedir)
4041
self.parallel = 4
4142

43+
4244
class CMakeBuild(build_ext):
4345
def build_extension(self, ext):
4446
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
@@ -59,7 +61,7 @@ def build_extension(self, ext):
5961
cmake_args = [
6062
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
6163
f"-DPYTHON_EXECUTABLE={sys.executable}",
62-
f"-DCMAKE_BUILD_TYPE={cfg}"
64+
f"-DCMAKE_BUILD_TYPE={cfg}",
6365
]
6466
build_args = []
6567
# Adding CMake arguments set as environment variable
@@ -128,27 +130,29 @@ def build_extension(self, ext):
128130
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
129131
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp)
130132

131-
with open(os.path.join(CURRENT_DIR, 'README.md'), encoding='utf-8') as f:
133+
134+
with open(os.path.join(CURRENT_DIR, "README.md"), encoding="utf-8") as f:
132135
long_description = f.read()
133136

134137
package_info = get_info()
135138

136139
setup(
137-
name='skscope',
138-
version=package_info['__version__'],
139-
author=package_info['__author__'],
140+
name="skscope",
141+
version=package_info["__version__"],
142+
author=package_info["__author__"],
140143
author_email="[email protected]",
141144
maintainer="Zezhi Wang",
142145
maintainer_email="[email protected]",
143146
packages=find_packages(),
144-
description="Sparsity-Constraint OPtimization via itErative-algorithm",
147+
description="Sparsity-Constraint OPtimization via itErative-algorithm",
145148
long_description=long_description,
146-
long_description_content_type='text/markdown',
149+
long_description_content_type="text/markdown",
147150
install_requires=[
148151
"numpy",
149152
"scikit-learn>=1.2.2",
150153
"jax[cpu]",
151154
"nlopt",
155+
"scipy",
152156
],
153157
license="MIT",
154158
url="https://skscope.readthedocs.io",
@@ -176,7 +180,7 @@ def build_extension(self, ext):
176180
"Programming Language :: Python :: 3.10",
177181
"Programming Language :: Python :: 3.11",
178182
],
179-
python_requires='>=3.8',
183+
python_requires=">=3.8",
180184
ext_modules=[CMakeExtension("skscope._scope")],
181-
cmdclass={"build_ext": CMakeBuild}
185+
cmdclass={"build_ext": CMakeBuild},
182186
)

skscope/numeric_solver.py

+22
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import math
99
import nlopt
10+
from scipy.optimize import minimize
1011

1112

1213
def convex_solver_nlopt(
@@ -69,3 +70,24 @@ def cache_opt_fn(x, grad):
6970
except RuntimeError:
7071
init_params[optim_variable_set] = best_params
7172
return best_loss, init_params
73+
74+
75+
def convex_solver_BFGS(
76+
objective_func,
77+
value_and_grad,
78+
init_params,
79+
optim_variable_set,
80+
data,
81+
):
82+
def fun(x):
83+
init_params[optim_variable_set] = x
84+
return objective_func(init_params, data)
85+
86+
def jac(x):
87+
init_params[optim_variable_set] = x
88+
_, grad = value_and_grad(init_params, data)
89+
return grad[optim_variable_set]
90+
91+
res = minimize(fun, init_params[optim_variable_set], method="BFGS", jac=jac)
92+
init_params[optim_variable_set] = res.x
93+
return res.fun, init_params

0 commit comments

Comments
 (0)