Skip to content

Commit 0bb66d7

Browse files
committed
Handle tol rtol renaming in scipy solvers
1 parent f8fe345 commit 0bb66d7

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

smt/utils/linear_solvers.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -264,24 +264,44 @@ def _setup(self, mtx, printer, mg_matrices=[]):
264264
self.callback_func = self.callback._print_sol
265265
self.solver_kwargs = {
266266
"atol": 0.0,
267-
"tol": self.options["atol"],
267+
"rtol": self.options["atol"],
268268
"maxiter": self.options["ilimit"],
269269
}
270270
elif self.options["solver"] == "bicgstab":
271271
self.solver = scipy.sparse.linalg.bicgstab
272272
self.callback_func = self.callback._print_sol
273273
self.solver_kwargs = {
274-
"tol": self.options["atol"],
274+
"rtol": self.options["rtol"],
275275
"maxiter": self.options["ilimit"],
276276
}
277277
elif self.options["solver"] == "gmres":
278278
self.solver = scipy.sparse.linalg.gmres
279279
self.callback_func = self.callback._print_res
280280
self.solver_kwargs = {
281-
"tol": self.options["atol"],
281+
"rtol": self.options["rtol"],
282282
"maxiter": self.options["ilimit"],
283283
"restart": min(self.options["ilimit"], mtx.shape[0]),
284284
}
285+
self._patch_when_scipy_lessthan_v111()
286+
287+
def _patch_when_scipy_lessthan_v111(self):
288+
"""
289+
From scipy 1.11.0 release notes
290+
The tol argument of scipy.sparse.linalg.{bcg,bicstab,cg,cgs,gcrotmk,gmres,lgmres,minres,qmr,tfqmr}
291+
is now deprecated in favour of rtol and will be removed in SciPy 1.14.
292+
Furthermore, the default value of atol for these functions is due to change to 0.0 in SciPy 1.14.
293+
"""
294+
import scipy
295+
296+
scipy_version = scipy.__version__
297+
version_tuple = tuple(map(int, scipy_version.split(".")))
298+
is_greater_than_1_11 = version_tuple[0] > 1 or (
299+
version_tuple[0] == 1 and version_tuple[1] >= 11
300+
)
301+
302+
if not is_greater_than_1_11:
303+
self.solver_kwargs["tol"] = self.solver_kwargs["rtol"]
304+
del self.solver_kwargs["rtol"]
285305

286306
def _solve(self, rhs, sol=None, ind_y=0):
287307
with self._active(self.options["print_solve"]) as printer:

0 commit comments

Comments
 (0)