diff --git a/src/special-operators.jl b/src/special-operators.jl index c58e857a..d23e9393 100644 --- a/src/special-operators.jl +++ b/src/special-operators.jl @@ -184,12 +184,16 @@ The operation `Z * v` is equivalent to `v[I]`. `I` can be `:`. Alias for `opRestriction([k], ncol)`. """ -function opRestriction(Idx::LinearOperatorIndexType{I}, ncol::I) where {I <: Integer} +function opRestriction(Idx::LinearOperatorIndexType{I}, ncol::I; S = nothing) where {I <: Integer} all(1 .≤ Idx .≤ ncol) || throw(LinearOperatorException("indices should be between 1 and $ncol")) nrow = length(Idx) prod! = @closure (res, v, α, β) -> mulRestrict!(res, Idx, v, α, β) tprod! = @closure (res, u, α, β) -> multRestrict!(res, Idx, u, α, β) - return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!) + if isnothing(S) + return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!) + else + return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!; S = S) + end end opRestriction(::Colon, ncol::I) where {I <: Integer} = opEye(I, ncol) @@ -209,8 +213,8 @@ The operation `w = Z * v` is equivalent to `w = zeros(ncol); w[I] = v`. Alias for `opExtension([k], ncol)`. """ -opExtension(Idx::LinearOperatorIndexType{I}, ncol::I) where {I <: Integer} = - opRestriction(Idx, ncol)' +opExtension(Idx::LinearOperatorIndexType{I}, ncol::I; S = nothing) where {I <: Integer} = + opRestriction(Idx, ncol; S = S)' opExtension(::Colon, ncol::I) where {I <: Integer} = opEye(I, ncol)