diff --git a/src/pycsou/opt/solver/pds.py b/src/pycsou/opt/solver/pds.py index 4f356f0ee..d55d4e9b0 100644 --- a/src/pycsou/opt/solver/pds.py +++ b/src/pycsou/opt/solver/pds.py @@ -1202,7 +1202,9 @@ class ADMM(_PDS): **Remark 1:** The algorithm is still valid if one of the terms :math:`\mathcal{F}`, :math:`\mathcal{H}` or :math:`\mathbf{K}` is - zero or identity. + zero or identity, respectivley. Note that when :math:`\mathbf{K}` is identity, ADMM is equivalent to the :py:func:`~pycsou.opt.solver.pds.DouglasRachford` + method (up to a change of variable, see [PSA]_ for a derivation). + **Remark 2:** This is an implementation of the algorithm described in Section 5.4 of [PSA]_, which handles the non-smooth @@ -1254,7 +1256,8 @@ class ADMM(_PDS): **Remark 3:** - Note that this algorithm **does not require** the diff-lipschitz constant of :math:`\mathcal{F}` to be known! + Note that, unlike traditional implementations of ADMM, this implementation supports the + possibility of relaxation (i.e. :math:`\rho\neq 1), . **Initialization parameters of the class:** @@ -1264,32 +1267,30 @@ class ADMM(_PDS): Proximable functional, instance of :py:class:`~pycsou.abc.operator.ProxFunc`. K: LinOp | None Linear operator, instance of :py:class:`~pycsou.abc.operator.LinOp`. - solver: Callable[ndarray, float] | None - Callable function that solves the x-minimization step :math:numref:`eq:x_minimization`. + solver: Callable[[NDArray, float], NDArray] | None + Optional callable function (with `Numpy signature `_ + ``(n), (1) -> (n)`` ) that solves the x-minimization step :math:numref:`eq:x_minimization` with a custom solver. + solver_kwargs: dict | None + Optional keyword arguments to be passed to the ``__init__`` method of the sub-iterative :py:class:`~pycsou.opt.solver.cg.CG` or + :py:class:`~pycsou.opt.solver.cg.NLCG` solvers (see Remark 2). Ignored if custom ``solver`` is provided. + **Parameterization** of the ``fit()`` method: x0: NDArray (..., N) initial point(s) for the primal variable. - z0: NDArray + z0: NDArray | None (..., N) initial point(s) for the dual variable. If ``None`` (default), then use ``K(x0)`` as the initial point for the dual variable. tau: Real | None Primal step size. rho: Real | None - Momentum parameter. + Momentum parameter for relaxation. tuning_strategy: [1, 2, 3] Strategy to be employed when setting the hyperparameters (default to 1). See base class for more details. - cg_kwargs: dict | None - Initialization parameters of the inner-loop :py:class:`~pycsou.opt.solver.cg.CG` algorithm (see Remark 2). - cg_fit_kwargs: dict | None - Parameters of the ``fit()`` method of the inner-loop :py:class:`~pycsou.opt.solver.cg.CG` algorithm (see Remark - 2). - nlcg_kwargs: dict | None - Initialization parameters of the inner-loop :py:class:`~pycsou.opt.solver.nlcg.NLCG` algorithm (see Remark 2). - nlcg_fit_kwargs: dict | None - Parameters of the ``fit()`` method of the inner-loop :py:class:`~pycsou.opt.solver.nlcg.NLCG` algorithm (see - Remark 2). + solver_kwargs: dict | None + Optional keyword arguments to be passed to the ``fit()`` method of the sub-iterative :py:class:`~pycsou.opt.solver.cg.CG` or + :py:class:`~pycsou.opt.solver.cg.NLCG` solvers (see Remark 2). Ignored if custom ``solver`` is provided. See Also -------- @@ -1303,31 +1304,30 @@ def __init__( h: typ.Optional[pyca.ProxFunc] = None, K: typ.Optional[pyca.DiffMap] = None, solver: typ.Callable[[pyct.NDArray, float], pyct.NDArray] = None, + solver_kwargs: typ.Optional[dict] = None, **kwargs, ): kwargs.update(log_var=kwargs.get("log_var", ("x", "u", "z"))) - x_update_method = "solver" # Method for the x-minimization step + x_update_solver = "custom" # Method for the x-minimization step g = None - beta = 1 # The value of beta is irrelevant in the cg and nlcg scenarios if solver is None: if f.has(pyco.Property.PROXIMABLE) and K is None: - x_update_method = "prox" + x_update_solver = "prox" g = f # In this case, f corresponds to g in the _PDS terminology f = None - beta = 0 # Beta does not apply to the prox x_update_method since f is None elif isinstance(f, pycf.QuadraticFunc): - x_update_method = "cg" + x_update_solver = "cg" warnings.warn( - "An inner-loop conjugate gradient algorithm will be applied for the x-minimization step " - "of ADMM. This might lead to slow convergence.", + "A sub-iterative conjugate gradient algorithm is used for the x-minimization step " + "of ADMM. This might be computationally expensive.", UserWarning, ) elif f.has(pyco.Property.DIFFERENTIABLE_FUNCTION): - x_update_method = "nlcg" + x_update_solver = "nlcg" warnings.warn( - "An inner-loop non-linear conjugate gradient algorithm will be applied for the " - "x-minimization step of ADMM. This might lead to slow convergence.", + "An sub-iterative non-linear conjugate gradient algorithm is used for the " + "x-minimization step of ADMM. This might be computationally expensive.", UserWarning, ) else: @@ -1336,15 +1336,15 @@ def __init__( "QuadraticFunc, or a DiffMap. If neither of these scenarios hold, a solver must be provided for the" "x-minimization step of ADMM." ) - self.solver = solver - self.x_update_method = x_update_method + self._solver = solver + self._x_update_solver = x_update_solver + self._init_kwargs = solver_kwargs if solver_kwargs is not None else dict(show_progress=False) super().__init__( f=f, g=g, h=h, K=K, - beta=beta, **kwargs, ) @@ -1356,21 +1356,19 @@ def m_init( tau: typ.Optional[pyct.Real] = None, rho: typ.Optional[pyct.Real] = None, tuning_strategy: typ.Literal[1, 2, 3] = 1, + solver_kwargs: typ.Optional[dict] = None, **kwargs, ): super().m_init(x0=x0, z0=z0, tau=tau, sigma=None, rho=rho, tuning_strategy=tuning_strategy) mst = self._mstate # shorthand mst["u"] = self._K(x0) if x0.ndim > 1 else self._K(x0).reshape(1, -1) - # Conjugate gradient parameters - mst["cg_kwargs"] = kwargs.get("cg_kwargs", dict(show_progress=False)) - mst["cg_fit_kwargs"] = kwargs.get("cg_fit_kwargs", dict()) - # Nonlinear conjugate gradient parameters - mst["nlcg_kwargs"] = kwargs.get("nlcg_kwargs", dict(show_progress=False)) - mst["nlcg_fit_kwargs"] = kwargs.get("nlcg_fit_kwargs", dict()) + + # Fit kwargs of sub-iterative solver + self._fit_kwargs = solver_kwargs def m_step( self, - ): # Algorithm (130) in [PSA]. Paper -> code correspondence: L -> K, K -> -Id, c -> 0, y -> u, v -> z, g -> h + ): # Algorithm (145) in [PSA]. Paper -> code correspondence: L -> K, K -> -Id, c -> 0, y -> u, v -> z, g -> h mst = self._mstate mst["x"] = self._x_update(mst["u"] - mst["z"], tau=mst["tau"]) z_temp = mst["z"] + self._K(mst["x"]) - mst["u"] @@ -1379,24 +1377,24 @@ def m_step( mst["z"] = z_temp + (mst["rho"] - 1) * (self._K(mst["x"]) - mst["u"]) def _x_update(self, arr: pyct.NDArray, tau: float) -> pyct.NDArray: - if self.x_update_method == "solver": - return self.solver(arr, tau) - elif self.x_update_method == "prox": + if self._x_update_solver == "custom": + return self._solver(arr, tau) + elif self._x_update_solver == "prox": return self._g.prox(arr, tau=tau) - elif self.x_update_method == "cg": + elif self._x_update_solver == "cg": from pycsou.opt.solver import CG b = (1 / tau) * self._K.adjoint(arr) - self._f._c.grad(arr) A = self._f._Q + (1 / tau) * self._K.gram() - slvr = CG(A=A, **self._mstate["cg_kwargs"]) - slvr.fit(b=b, **self._mstate["cg_fit_kwargs"]) + slvr = CG(A=A, **self._init_kwargs) + slvr.fit(b=b, x0=self._mstate["x"], **self._fit_kwargs) # Initialize CG with previous iterate return slvr.solution() - elif self.x_update_method == "nlcg": + elif self._x_update_solver == "nlcg": from pycsou.opt.solver import NLCG - quad_func = pycf.QuadraticFunc(self._K.gram(), pyca.LinFunc.from_array(-self._K.adjoint(arr))) - slvr = NLCG(f=self._f + (1 / tau) * quad_func, **self._mstate["nlcg_kwargs"]) - slvr.fit(x0=self._mstate["x"], **self._mstate["nlcg_fit_kwargs"]) # Initialize NLCG with previous iterate + quad_func = pycf.QuadraticFunc(2 * self._K.gram(), pyca.LinFunc.from_array(-self._K.adjoint(arr))) + slvr = NLCG(f=self._f + (1 / tau) * quad_func, **self._init_kwargs) + slvr.fit(x0=self._mstate["x"], **self._fit_kwargs) # Initialize NLCG with previous iterate return slvr.solution() def solution(self, which: typ.Literal["primal", "primal_h", "dual"] = "primal") -> pyct.NDArray: