[ADMM] Improved documentation + code restructuring.
matthieumeo committed Jan 27, 2023
1 parent d15c949 commit 206fca2
Showing 1 changed file with 44 additions and 46 deletions.
90 changes: 44 additions & 46 deletions src/pycsou/opt/solver/pds.py
@@ -1202,7 +1202,9 @@ class ADMM(_PDS):
**Remark 1:**
The algorithm is still valid if one of the terms :math:`\mathcal{F}`, :math:`\mathcal{H}` or :math:`\mathbf{K}` is
zero or identity.
zero or identity, respectively. Note that when :math:`\mathbf{K}` is identity, ADMM is equivalent to the :py:func:`~pycsou.opt.solver.pds.DouglasRachford`
method (up to a change of variable, see [PSA]_ for a derivation).
**Remark 2:**
This is an implementation of the algorithm described in Section 5.4 of [PSA]_, which handles the non-smooth
@@ -1254,7 +1256,8 @@ class ADMM(_PDS):
**Remark 3:**
Note that this algorithm **does not require** the diff-lipschitz constant of :math:`\mathcal{F}` to be known!
Note that, unlike traditional implementations of ADMM, this implementation supports relaxation
(i.e., :math:`\rho \neq 1`); a reference form of the relaxed updates is sketched below.
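For reference, a standard form of the relaxed ADMM iterates (scaled form, in the convention of Boyd et al.) for minimizing :math:`\mathcal{F}(\mathbf{x}) + \mathcal{H}(\mathbf{K}\mathbf{x})` is sketched below; the sign convention for the scaled dual variable :math:`\mathbf{z}` may differ from the one used in this implementation:

.. math::

   \mathbf{x}^{k+1} &= \arg\min_{\mathbf{x}} \; \mathcal{F}(\mathbf{x}) + \frac{1}{2\tau}\left\|\mathbf{K}\mathbf{x} - \mathbf{u}^{k} + \mathbf{z}^{k}\right\|_2^2, \\
   \bar{\mathbf{u}}^{k+1} &= \rho\,\mathbf{K}\mathbf{x}^{k+1} + (1-\rho)\,\mathbf{u}^{k}, \\
   \mathbf{u}^{k+1} &= \text{prox}_{\tau\mathcal{H}}\left(\bar{\mathbf{u}}^{k+1} + \mathbf{z}^{k}\right), \\
   \mathbf{z}^{k+1} &= \mathbf{z}^{k} + \bar{\mathbf{u}}^{k+1} - \mathbf{u}^{k+1},

with :math:`\rho = 1` recovering the unrelaxed updates.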
**Initialization parameters of the class:**
@@ -1264,32 +1267,30 @@ class ADMM(_PDS):
Proximable functional, instance of :py:class:`~pycsou.abc.operator.ProxFunc`.
K: LinOp | None
Linear operator, instance of :py:class:`~pycsou.abc.operator.LinOp`.
solver: Callable[ndarray, float] | None
Callable function that solves the x-minimization step :math:numref:`eq:x_minimization`.
solver: Callable[[NDArray, float], NDArray] | None
Optional callable function (with `Numpy signature <https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html>`_
``(n), (1) -> (n)``) that solves the x-minimization step :math:numref:`eq:x_minimization` with a custom solver (a usage sketch is given below, after the parameter lists).
solver_kwargs: dict | None
Optional keyword arguments to be passed to the ``__init__`` method of the sub-iterative :py:class:`~pycsou.opt.solver.cg.CG` or
:py:class:`~pycsou.opt.solver.cg.NLCG` solvers (see Remark 2). Ignored if custom ``solver`` is provided.
**Parameterization** of the ``fit()`` method:
x0: NDArray
(..., N) initial point(s) for the primal variable.
z0: NDArray
z0: NDArray | None
(..., N) initial point(s) for the dual variable.
If ``None`` (default), then use ``K(x0)`` as the initial point for the dual variable.
tau: Real | None
Primal step size.
rho: Real | None
Momentum parameter.
Momentum parameter for relaxation.
tuning_strategy: [1, 2, 3]
Strategy to be employed when setting the hyperparameters (defaults to 1). See base class for more details.
cg_kwargs: dict | None
Initialization parameters of the inner-loop :py:class:`~pycsou.opt.solver.cg.CG` algorithm (see Remark 2).
cg_fit_kwargs: dict | None
Parameters of the ``fit()`` method of the inner-loop :py:class:`~pycsou.opt.solver.cg.CG` algorithm (see Remark
2).
nlcg_kwargs: dict | None
Initialization parameters of the inner-loop :py:class:`~pycsou.opt.solver.nlcg.NLCG` algorithm (see Remark 2).
nlcg_fit_kwargs: dict | None
Parameters of the ``fit()`` method of the inner-loop :py:class:`~pycsou.opt.solver.nlcg.NLCG` algorithm (see
Remark 2).
solver_kwargs: dict | None
Optional keyword arguments to be passed to the ``fit()`` method of the sub-iterative :py:class:`~pycsou.opt.solver.cg.CG` or
:py:class:`~pycsou.opt.solver.cg.NLCG` solvers (see Remark 2). Ignored if custom ``solver`` is provided.
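As a usage illustration (not part of this commit), the following sketch wires a custom ``solver`` callable into ``ADMM`` for a dense quadratic data-fidelity term. The arrays ``Q``, ``c``, ``K_mat`` and the commented operators ``my_quadratic_f``, ``my_prox_h``, ``my_linop_K`` are hypothetical placeholders; only the callable signature ``(arr, tau) -> NDArray`` follows the parameter description above.

.. code-block:: python

   # Hypothetical sketch: custom x-minimization solver for ADMM.
   # Assumes f(x) = 0.5 * x @ Q @ x + c @ x and a dense matrix K_mat;
   # the normal equations mirror the CG branch of `_x_update` below.
   import numpy as np

   rng = np.random.default_rng(0)
   N = 8
   B = rng.standard_normal((N, N))
   Q = B.T @ B + np.eye(N)       # symmetric positive-definite Hessian of f
   c = rng.standard_normal(N)    # linear term of f
   K_mat = np.eye(N)             # toy linear operator

   def my_x_solver(arr: np.ndarray, tau: float) -> np.ndarray:
       """Solve argmin_x f(x) + 1/(2*tau) * ||K_mat @ x - arr||**2 exactly."""
       lhs = Q + (1.0 / tau) * K_mat.T @ K_mat
       rhs = (1.0 / tau) * K_mat.T @ arr - c
       return np.linalg.solve(lhs, rhs)

   # Hypothetical wiring (argument names follow the docstring above):
   # admm = ADMM(f=my_quadratic_f, h=my_prox_h, K=my_linop_K, solver=my_x_solver)
   # admm.fit(x0=np.zeros(N), tau=1.0, rho=1.0)
   # x_opt = admm.solution()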
See Also
--------
@@ -1303,31 +1304,30 @@ def __init__(
h: typ.Optional[pyca.ProxFunc] = None,
K: typ.Optional[pyca.DiffMap] = None,
solver: typ.Callable[[pyct.NDArray, float], pyct.NDArray] = None,
solver_kwargs: typ.Optional[dict] = None,
**kwargs,
):
kwargs.update(log_var=kwargs.get("log_var", ("x", "u", "z")))

x_update_method = "solver" # Method for the x-minimization step
x_update_solver = "custom" # Method for the x-minimization step
g = None
beta = 1 # The value of beta is irrelevant in the cg and nlcg scenarios
if solver is None:
if f.has(pyco.Property.PROXIMABLE) and K is None:
x_update_method = "prox"
x_update_solver = "prox"
g = f # In this case, f corresponds to g in the _PDS terminology
f = None
beta = 0 # Beta does not apply to the prox x_update_method since f is None
elif isinstance(f, pycf.QuadraticFunc):
x_update_method = "cg"
x_update_solver = "cg"
warnings.warn(
"An inner-loop conjugate gradient algorithm will be applied for the x-minimization step "
"of ADMM. This might lead to slow convergence.",
"A sub-iterative conjugate gradient algorithm is used for the x-minimization step "
"of ADMM. This might be computationally expensive.",
UserWarning,
)
elif f.has(pyco.Property.DIFFERENTIABLE_FUNCTION):
x_update_method = "nlcg"
x_update_solver = "nlcg"
warnings.warn(
"An inner-loop non-linear conjugate gradient algorithm will be applied for the "
"x-minimization step of ADMM. This might lead to slow convergence.",
"An sub-iterative non-linear conjugate gradient algorithm is used for the "
"x-minimization step of ADMM. This might be computationally expensive.",
UserWarning,
)
else:
@@ -1336,15 +1336,15 @@ def __init__(
"QuadraticFunc, or a DiffMap. If neither of these scenarios hold, a solver must be provided for the"
"x-minimization step of ADMM."
)
self.solver = solver
self.x_update_method = x_update_method
self._solver = solver
self._x_update_solver = x_update_solver
self._init_kwargs = solver_kwargs if solver_kwargs is not None else dict(show_progress=False)

super().__init__(
f=f,
g=g,
h=h,
K=K,
beta=beta,
**kwargs,
)

@@ -1356,21 +1356,19 @@ def m_init(
tau: typ.Optional[pyct.Real] = None,
rho: typ.Optional[pyct.Real] = None,
tuning_strategy: typ.Literal[1, 2, 3] = 1,
solver_kwargs: typ.Optional[dict] = None,
**kwargs,
):
super().m_init(x0=x0, z0=z0, tau=tau, sigma=None, rho=rho, tuning_strategy=tuning_strategy)
mst = self._mstate # shorthand
mst["u"] = self._K(x0) if x0.ndim > 1 else self._K(x0).reshape(1, -1)
# Conjugate gradient parameters
mst["cg_kwargs"] = kwargs.get("cg_kwargs", dict(show_progress=False))
mst["cg_fit_kwargs"] = kwargs.get("cg_fit_kwargs", dict())
# Nonlinear conjugate gradient parameters
mst["nlcg_kwargs"] = kwargs.get("nlcg_kwargs", dict(show_progress=False))
mst["nlcg_fit_kwargs"] = kwargs.get("nlcg_fit_kwargs", dict())

# Fit kwargs of sub-iterative solver
self._fit_kwargs = solver_kwargs if solver_kwargs is not None else dict()

def m_step(
self,
): # Algorithm (130) in [PSA]. Paper -> code correspondence: L -> K, K -> -Id, c -> 0, y -> u, v -> z, g -> h
): # Algorithm (145) in [PSA]. Paper -> code correspondence: L -> K, K -> -Id, c -> 0, y -> u, v -> z, g -> h
mst = self._mstate
mst["x"] = self._x_update(mst["u"] - mst["z"], tau=mst["tau"])
z_temp = mst["z"] + self._K(mst["x"]) - mst["u"]
@@ -1379,24 +1377,24 @@ def m_step(
mst["z"] = z_temp + (mst["rho"] - 1) * (self._K(mst["x"]) - mst["u"])

def _x_update(self, arr: pyct.NDArray, tau: float) -> pyct.NDArray:
if self.x_update_method == "solver":
return self.solver(arr, tau)
elif self.x_update_method == "prox":
if self._x_update_solver == "custom":
return self._solver(arr, tau)
elif self._x_update_solver == "prox":
return self._g.prox(arr, tau=tau)
elif self.x_update_method == "cg":
elif self._x_update_solver == "cg":
from pycsou.opt.solver import CG

b = (1 / tau) * self._K.adjoint(arr) - self._f._c.grad(arr)
A = self._f._Q + (1 / tau) * self._K.gram()
slvr = CG(A=A, **self._mstate["cg_kwargs"])
slvr.fit(b=b, **self._mstate["cg_fit_kwargs"])
slvr = CG(A=A, **self._init_kwargs)
slvr.fit(b=b, x0=self._mstate["x"], **self._fit_kwargs) # Initialize CG with previous iterate
return slvr.solution()
elif self.x_update_method == "nlcg":
elif self._x_update_solver == "nlcg":
from pycsou.opt.solver import NLCG

quad_func = pycf.QuadraticFunc(self._K.gram(), pyca.LinFunc.from_array(-self._K.adjoint(arr)))
slvr = NLCG(f=self._f + (1 / tau) * quad_func, **self._mstate["nlcg_kwargs"])
slvr.fit(x0=self._mstate["x"], **self._mstate["nlcg_fit_kwargs"]) # Initialize NLCG with previous iterate
quad_func = pycf.QuadraticFunc(2 * self._K.gram(), pyca.LinFunc.from_array(-self._K.adjoint(arr)))
slvr = NLCG(f=self._f + (1 / tau) * quad_func, **self._init_kwargs)
slvr.fit(x0=self._mstate["x"], **self._fit_kwargs) # Initialize NLCG with previous iterate
return slvr.solution()
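For the ``cg`` branch above, a short derivation (a sketch, assuming the convention :math:`\mathcal{F}(\mathbf{x}) = \tfrac{1}{2}\mathbf{x}^{T}\mathbf{Q}\mathbf{x} + \mathbf{c}^{T}\mathbf{x} + t` for ``QuadraticFunc``) shows where ``A`` and ``b`` come from: setting the gradient of the x-minimization objective to zero yields

.. math::

   \nabla\left[\mathcal{F}(\mathbf{x}) + \frac{1}{2\tau}\left\|\mathbf{K}\mathbf{x} - \mathbf{a}\right\|_2^2\right]
   = \mathbf{Q}\mathbf{x} + \mathbf{c} + \frac{1}{\tau}\mathbf{K}^{\ast}\left(\mathbf{K}\mathbf{x} - \mathbf{a}\right) = 0
   \quad\Longleftrightarrow\quad
   \left(\mathbf{Q} + \frac{1}{\tau}\mathbf{K}^{\ast}\mathbf{K}\right)\mathbf{x} = \frac{1}{\tau}\mathbf{K}^{\ast}\mathbf{a} - \mathbf{c},

where :math:`\mathbf{a}` stands for ``arr``.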

def solution(self, which: typ.Literal["primal", "primal_h", "dual"] = "primal") -> pyct.NDArray:
