Skip to content

Commit

Permalink
add ConvergenceWarning in do_line_search
Browse files Browse the repository at this point in the history
  • Loading branch information
Badr-MOUFAD committed Apr 24, 2022
1 parent 83c4628 commit 81f2493
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions celer/PN_logreg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,12 @@ cpdef void do_line_search(
floating[::1, :] X, floating[:] X_data,
int[:] X_indices, int[:] X_indptr, int MAX_BACKTRACK_ITR,
floating[:] y, floating[:] exp_Xw, floating[:] low_exp_Xw,
floating[:] aux, int[:] is_positive_label) nogil:
floating[:] aux, int[:] is_positive_label):

cdef int i, ind, backtrack_itr
cdef floating deriv
cdef floating step_size = 1.
cdef floating atol = 1e-7

cdef int n_samples = y.shape[0]
fcopy(&n_samples, &exp_Xw[0], &inc, &low_exp_Xw[0], &inc)
Expand All @@ -417,15 +418,18 @@ cpdef void do_line_search(
deriv = compute_derivative(
w, WS, delta_w, X_delta_w, alpha, aux, step_size, y)

if deriv < 1e-7:
if deriv < atol:
break
else:
step_size = step_size / 2.
for i in range(n_samples):
exp_Xw[i] = sqrt(exp_Xw[i] * low_exp_Xw[i])
else:
pass
# TODO what do we do in this case?
warnings.warn(
'Line search failed to converge '
f'deriv {deriv:.2e}, atol {atol:.2e}',
ConvergenceWarning
)

# a suitable step size is found, perform step:
for ind in range(WS.shape[0]):
Expand Down

0 comments on commit 81f2493

Please sign in to comment.