Skip to content

Commit 4a6cd6d

Browse files
committed
fix: bug when support_set is like range(k)
1 parent 0a0dc68 commit 4a6cd6d

11 files changed

+496
-208
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 680,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"n, m, k = 5, 10, 3#500, 200, 15\n",
10+
"p = k * m + k * n\n",
11+
"l = 10#100"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 681,
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"import numpy as np\n",
21+
"import random\n",
22+
"random.seed(3)\n",
23+
"np.random.seed(3)"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": 682,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"def L(A, B):\n",
33+
" k1, n = A.shape\n",
34+
" k2, m = B.shape\n",
35+
" assert n == m\n",
36+
" total = 0\n",
37+
" for i in range(k1):\n",
38+
" min_row_diff = np.inf\n",
39+
" for j in range(k2):\n",
40+
" row_diff = np.sum((A[i, :] - B[j, :]) ** 2)\n",
41+
" if row_diff < min_row_diff:\n",
42+
" min_row_diff = row_diff\n",
43+
" total += min_row_diff\n",
44+
" return total"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": 683,
50+
"metadata": {},
51+
"outputs": [],
52+
"source": [
53+
"from skscope import layer\n",
54+
"import jax.numpy as jnp\n",
55+
"import numpy as np\n",
56+
"\n",
57+
"layers = [layer.NonNegative(p)]\n",
58+
"for i in range(m): \n",
59+
" coef = np.zeros(p)\n",
60+
" coef[i*k:i*k+k] = 1.0\n",
61+
" layers.append(layer.LinearConstraint(p, jnp.array(coef)))\n",
62+
"\n",
63+
"preselect = list(range(k*m))"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": 684,
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"true_H = np.random.uniform(100,size=(k, n))\n",
73+
"np.put(true_index:=np.zeros((k, n), dtype=int), np.random.choice(n*k, l, replace=False), 1)\n",
74+
"true_H *= true_index\n",
75+
"true_W = np.random.uniform(size=(m, k))\n",
76+
"true_W /= true_W.sum(axis=1, keepdims=True)\n",
77+
"\n",
78+
"X = true_W @ true_H"
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 685,
84+
"metadata": {},
85+
"outputs": [
86+
{
87+
"data": {
88+
"text/plain": [
89+
"array([[45.47100765, 0. , 0. , 0. , 11.59825152],\n",
90+
" [11.2669842 , 87.56705426, 79.48295506, 94.90474687, 56.35982548],\n",
91+
" [ 0. , 54.77351078, 0. , 72.42975902, 33.0507647 ]])"
92+
]
93+
},
94+
"execution_count": 685,
95+
"metadata": {},
96+
"output_type": "execute_result"
97+
}
98+
],
99+
"source": [
100+
"true_H"
101+
]
102+
},
103+
{
104+
"cell_type": "code",
105+
"execution_count": 686,
106+
"metadata": {},
107+
"outputs": [],
108+
"source": [
109+
"def nmf(params):\n",
110+
" W = params[:k*m].reshape(m, k)\n",
111+
" H = params[k*m:].reshape(k, n)\n",
112+
" return jnp.sum((X - W @ H) ** 2)"
113+
]
114+
},
115+
{
116+
"cell_type": "code",
117+
"execution_count": 687,
118+
"metadata": {},
119+
"outputs": [
120+
{
121+
"data": {
122+
"text/plain": [
123+
"(Array([[ 0. , 3.73 , 0. , 4.47 , 2.49 ],\n",
124+
" [ 0. , 197.08 , 0. , 245.43999 , 126.619995 ],\n",
125+
" [ 22.06 , 1.8499999, 36.71 , 0. , 3.59 ]], dtype=float32),\n",
126+
" Array(469.97394, dtype=float32),\n",
127+
" Array(0., dtype=float32))"
128+
]
129+
},
130+
"execution_count": 687,
131+
"metadata": {},
132+
"output_type": "execute_result"
133+
}
134+
],
135+
"source": [
136+
"from skscope import ScopeSolver\n",
137+
"\n",
138+
"solver = ScopeSolver(p, k*m+l, preselect=preselect)\n",
139+
"params = solver.solve(nmf, layers=layers, jit=True)\n",
140+
"round(params[k*m:].reshape(k, n), 2), nmf(params), nmf(np.concatenate([true_W.flatten(), true_H.flatten()]))"
141+
]
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": 688,
146+
"metadata": {},
147+
"outputs": [
148+
{
149+
"data": {
150+
"text/plain": [
151+
"0.6"
152+
]
153+
},
154+
"execution_count": 688,
155+
"metadata": {},
156+
"output_type": "execute_result"
157+
}
158+
],
159+
"source": [
160+
"len(set(true_H.reshape(-1).nonzero()[0]) & set(np.array(params[k*m:].nonzero()[0]))) / l"
161+
]
162+
},
163+
{
164+
"cell_type": "code",
165+
"execution_count": 689,
166+
"metadata": {},
167+
"outputs": [
168+
{
169+
"data": {
170+
"text/plain": [
171+
"(Array([[ 19.56 , 3.05 , 10.98 , 1.15 , 6.92 ],\n",
172+
" [ 5.2799997, 59.039997 , 5.73 , 76.56 , 37.02 ],\n",
173+
" [ 33.77 , 254.87999 , 158.23999 , 295.44998 , 163.7 ]], dtype=float32),\n",
174+
" Array(1.1004886e-09, dtype=float32),\n",
175+
" Array(0., dtype=float32))"
176+
]
177+
},
178+
"execution_count": 689,
179+
"metadata": {},
180+
"output_type": "execute_result"
181+
}
182+
],
183+
"source": [
184+
"from skscope import BaseSolver\n",
185+
"solver = BaseSolver(p, p)\n",
186+
"dense_params = solver.solve(nmf, layers=layers, jit=True)\n",
187+
"round(dense_params[k*m:].reshape(k, n), 2), nmf(dense_params), nmf(np.concatenate([true_W.flatten(), true_H.flatten()]))"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": 690,
193+
"metadata": {},
194+
"outputs": [
195+
{
196+
"data": {
197+
"text/plain": [
198+
"(Array(0.8247665, dtype=float32), Array(0.2097329, dtype=float32))"
199+
]
200+
},
201+
"execution_count": 690,
202+
"metadata": {},
203+
"output_type": "execute_result"
204+
}
205+
],
206+
"source": [
207+
"L(true_H, params[k*m:].reshape(k, n)) / np.sum(true_H ** 2), L(true_H, dense_params[k*m:].reshape(k, n)) / np.sum(true_H ** 2)"
208+
]
209+
}
210+
],
211+
"metadata": {
212+
"kernelspec": {
213+
"display_name": "scope",
214+
"language": "python",
215+
"name": "python3"
216+
},
217+
"language_info": {
218+
"codemirror_mode": {
219+
"name": "ipython",
220+
"version": 3
221+
},
222+
"file_extension": ".py",
223+
"mimetype": "text/x-python",
224+
"name": "python",
225+
"nbconvert_exporter": "python",
226+
"pygments_lexer": "ipython3",
227+
"version": "3.10.13"
228+
}
229+
},
230+
"nbformat": 4,
231+
"nbformat_minor": 2
232+
}

src/Algorithm.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ void Algorithm::get_A(UniversalData &X, MatrixXd &y, VectorXi &A, VectorXi &I, i
241241
}
242242

243243
// If A_U does not change, U will not change either, so we can stop.
244-
if (A_U.size() == 0 || A_U.maxCoeff() == T0 - 1)
244+
if (this->U_size < N && (A_U.size() == 0 || A_U.maxCoeff() == T0 - 1))
245245
break;
246246

247247
// Update & Restore beta, A from U
@@ -366,6 +366,8 @@ VectorXi Algorithm::inital_screening(UniversalData &X, MatrixXd &y, VectorXd &be
366366
{
367367
if (bd.size() == 0)
368368
{
369+
SPDLOG_DEBUG("init active set is ", A.transpose());
370+
SPDLOG_DEBUG("init params is ", beta.transpose());
369371
// variable initialization
370372
int beta_size = X.cols();
371373
bd = VectorXd::Zero(N);

src/Data.h

+7-8
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@
66

77
#pragma once
88

9-
109
#include <Eigen/Eigen>
1110
#include <vector>
1211

1312
#include "utilities.h"
1413
using namespace std;
1514
using namespace Eigen;
1615

17-
18-
class Data {
19-
public:
16+
class Data
17+
{
18+
public:
2019
UniversalData x;
2120
Eigen::MatrixXd y;
2221
Eigen::VectorXd weight;
@@ -34,7 +33,8 @@ class Data {
3433
Data() = default;
3534

3635
Data(UniversalData &x, Eigen::MatrixXd &y, int normalize_type, Eigen::VectorXd &weight, Eigen::VectorXi &g_index, bool sparse_matrix,
37-
int beta_size) {
36+
int beta_size)
37+
{
3838
this->x = x;
3939
this->y = y;
4040
this->normalize_type = normalize_type;
@@ -49,10 +49,9 @@ class Data {
4949
this->g_index = g_index;
5050
this->g_num = g_index.size();
5151
Eigen::VectorXi temp = Eigen::VectorXi::Zero(this->g_num);
52-
for (int i = 0; i < g_num - 1; i++) temp(i) = g_index(i + 1);
52+
for (int i = 0; i < g_num - 1; i++)
53+
temp(i) = g_index(i + 1);
5354
temp(g_num - 1) = beta_size;
5455
this->g_size = temp - g_index;
5556
};
56-
5757
};
58-

src/OpenMP.h

-2
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,3 @@ inline int omp_get_num_procs() { return 1; }
1616
// No-op stand-in for omp_set_num_threads: the requested thread count is ignored.
// NOTE(review): presumably compiled only when the real OpenMP runtime is
// unavailable -- the guarding #if is outside this view; confirm.
inline void omp_set_num_threads(int nthread)
{
    (void)nthread; // intentionally unused
}
1717
// No-op stand-in for omp_set_dynamic: the flag is accepted and discarded.
// NOTE(review): presumably compiled only when the real OpenMP runtime is
// unavailable -- the guarding #if is outside this view; confirm.
inline void omp_set_dynamic(int flag)
{
    (void)flag; // intentionally unused
}
1818
#endif
19-
20-

src/UniversalData.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ void UniversalData::gradient_and_hessian(const VectorXd &effective_para, VectorX
122122

123123
double UniversalData::optimize(VectorXd &effective_para)
124124
{
125-
if (effective_para.size() == 0){
125+
if (effective_para.size() == 0)
126+
{
126127
return model->loss(VectorXd::Zero(this->model_size), *this->data);
127128
}
128129
auto value_and_grad = [this](const VectorXd &complete_para, pybind11::object data) -> pair<double, VectorXd>

0 commit comments

Comments
 (0)