Skip to content

Commit

Permalink
Fix Segmentation Fault and ZeroDivisionError in Group Lasso (#292)
Browse files Browse the repository at this point in the history
Co-authored-by: mathurinm <[email protected]>
  • Loading branch information
Badr-MOUFAD and mathurinm authored Apr 8, 2024
1 parent 1788a4f commit 1df21cf
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion celer/dropin_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def fit(self, X, y):
% (n_samples, y.shape[0]))

X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, self.fit_intercept, copy=False)
X, y, fit_intercept=self.fit_intercept, copy=False)

if not self.warm_start or not hasattr(self, "coef_"):
self.coef_ = None
Expand Down
10 changes: 6 additions & 4 deletions celer/group_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ cpdef floating dnorm_grp(

else: # scaling only with features in C
for g_idx in range(ws_size):
g = C[g_idx]

if weights[g] == INFINITY:
continue

g = C[g_idx]
tmp = 0
for k in range(grp_ptr[g], grp_ptr[g + 1]):
j = grp_indices[k]
Expand Down Expand Up @@ -418,8 +419,11 @@ cpdef celer_grp(
&inc) / lc_groups[g]
norm_wg += w[j] ** 2
norm_wg = sqrt(norm_wg)
bst_scal = max(0.,
if norm_wg != 0.:
bst_scal = max(0.,
1. - alpha * weights[g] / lc_groups[g] * n_samples / norm_wg)
else:
bst_scal = 0.

for k in range(grp_ptr[g + 1] - grp_ptr[g]):
j = grp_indices[grp_ptr[g] + k]
Expand Down Expand Up @@ -448,5 +452,3 @@ cpdef celer_grp(
'Fitting data with very small alpha causes precision issues.',
ConvergenceWarning)
return np.asarray(w), np.asarray(theta), np.asarray(gaps[:t + 1])


2 changes: 1 addition & 1 deletion celer/tests/test_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


@pytest.mark.parametrize("sparse_X, fit_intercept",
itertools.product([0, 1], [0, 1]))
itertools.product([0, 1], [False, True]))
def test_GroupLasso_Lasso_equivalence(sparse_X, fit_intercept):
"""Check that GroupLasso with groups of size 1 gives Lasso."""
n_features = 1000
Expand Down

0 comments on commit 1df21cf

Please sign in to comment.