Skip to content

Commit

Permalink
feat(numeric_solver): add a new numeric algo 'BFGS' (#77)
Browse files Browse the repository at this point in the history
* feat(numeric_solver): add a new numeric algo 'BFGS'

use scipy.optimize.minimize with method 'BFGS'
  • Loading branch information
bbayukari authored Feb 18, 2024
1 parent 804c798 commit 20ef8ae
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 15 deletions.
19 changes: 15 additions & 4 deletions pytest/test_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,26 @@
solvers_ids = ("scope", "Base") # , "GraHTP", "GraSP", "IHT")


@pytest.mark.parametrize("model", models, ids=models_ids)
@pytest.mark.parametrize("solver_creator", solvers, ids=solvers_ids)
def test_numeric_solver(model, solver_creator):
from skscope.numeric_solver import convex_solver_BFGS

solver = solver_creator(
model["n_features"], model["n_informative"], numeric_solver=convex_solver_BFGS
)
solver.solve(model["loss"], jit=True)

assert set(model["support_set"]) == set(solver.get_support())


@pytest.mark.parametrize("model", models, ids=models_ids)
@pytest.mark.parametrize("solver_creator", solvers, ids=solvers_ids)
def test_init_support_set(model, solver_creator):
solver = solver_creator(model["n_features"], model["n_informative"])
solver.solve(model["loss"], init_support_set=[0, 1, 2], jit=True)
solver.get_result()
solver.get_estimated_params()
solver.get_support()
assert set(model["support_set"]) == set(solver.support_set)

assert set(model["support_set"]) == set(solver.get_support())


@pytest.mark.parametrize("model", models, ids=models_ids)
Expand Down
4 changes: 2 additions & 2 deletions pytest/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_config(model, solver_creator):
solver = solver_creator(model["n_features"], model["n_informative"])
solver.set_config(**solver.get_config())
solver.solve(model["loss"], jit=True)

assert set(model["support_set"]) == set(solver.support_set)
res = solver.get_result()
assert set(model["support_set"]) == set(res["support_set"])


@pytest.mark.parametrize("model", models, ids=models_ids)
Expand Down
22 changes: 13 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"win-arm64": "ARM64",
}


def get_info():
# get information from `__init__.py`
labels = ["__version__", "__author__"]
Expand All @@ -39,6 +40,7 @@ def __init__(self, name, sourcedir=""):
self.sourcedir = os.path.abspath(sourcedir)
self.parallel = 4


class CMakeBuild(build_ext):
def build_extension(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
Expand All @@ -59,7 +61,7 @@ def build_extension(self, ext):
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={cfg}"
f"-DCMAKE_BUILD_TYPE={cfg}",
]
build_args = []
# Adding CMake arguments set as environment variable
Expand Down Expand Up @@ -128,27 +130,29 @@ def build_extension(self, ext):
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp)

with open(os.path.join(CURRENT_DIR, 'README.md'), encoding='utf-8') as f:

with open(os.path.join(CURRENT_DIR, "README.md"), encoding="utf-8") as f:
long_description = f.read()

package_info = get_info()

setup(
name='skscope',
version=package_info['__version__'],
author=package_info['__author__'],
name="skscope",
version=package_info["__version__"],
author=package_info["__author__"],
author_email="[email protected]",
maintainer="Zezhi Wang",
maintainer_email="[email protected]",
packages=find_packages(),
description="Sparsity-Constraint OPtimization via itErative-algorithm",
description="Sparsity-Constraint OPtimization via itErative-algorithm",
long_description=long_description,
long_description_content_type='text/markdown',
long_description_content_type="text/markdown",
install_requires=[
"numpy",
"scikit-learn>=1.2.2",
"jax[cpu]",
"nlopt",
"scipy",
],
license="MIT",
url="https://skscope.readthedocs.io",
Expand Down Expand Up @@ -176,7 +180,7 @@ def build_extension(self, ext):
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
python_requires='>=3.8',
python_requires=">=3.8",
ext_modules=[CMakeExtension("skscope._scope")],
cmdclass={"build_ext": CMakeBuild}
cmdclass={"build_ext": CMakeBuild},
)
22 changes: 22 additions & 0 deletions skscope/numeric_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import math
import nlopt
from scipy.optimize import minimize


def convex_solver_nlopt(
Expand Down Expand Up @@ -69,3 +70,24 @@ def cache_opt_fn(x, grad):
except RuntimeError:
init_params[optim_variable_set] = best_params
return best_loss, init_params


def convex_solver_BFGS(
objective_func,
value_and_grad,
init_params,
optim_variable_set,
data,
):
def fun(x):
init_params[optim_variable_set] = x
return objective_func(init_params, data)

def jac(x):
init_params[optim_variable_set] = x
_, grad = value_and_grad(init_params, data)
return grad[optim_variable_set]

res = minimize(fun, init_params[optim_variable_set], method="BFGS", jac=jac)
init_params[optim_variable_set] = res.x
return res.fun, init_params

0 comments on commit 20ef8ae

Please sign in to comment.