Skip to content

Commit

Permalink
LinearAlgebra: adjoint for bidiag/tridiag may preserve structure (#54027
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jishnub authored and pull[bot] committed Oct 1, 2024
1 parent 0d47f7c commit 95872b9
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ for func in (:conj, :copy, :real, :imag)
@eval ($func)(M::Bidiagonal) = Bidiagonal(($func)(M.dv), ($func)(M.ev), M.uplo)
end

adjoint(B::Bidiagonal{<:Number}) = Bidiagonal(conj(B.dv), conj(B.ev), B.uplo == 'U' ? :L : :U)
adjoint(B::Bidiagonal{<:Number}) = Bidiagonal(vec(adjoint(B.dv)), vec(adjoint(B.ev)), B.uplo == 'U' ? :L : :U)
adjoint(B::Bidiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
Bidiagonal(adjoint(parent(B.dv)), adjoint(parent(B.ev)), B.uplo == 'U' ? :L : :U)
transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U')
function permutedims(B::Bidiagonal, perm)
Expand Down
9 changes: 7 additions & 2 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ for func in (:conj, :copy, :real, :imag)
end

transpose(S::SymTridiagonal) = S
adjoint(S::SymTridiagonal{<:Real}) = S
adjoint(S::SymTridiagonal{<:Number}) = SymTridiagonal(vec(adjoint(S.dv)), vec(adjoint(S.ev)))
adjoint(S::SymTridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
SymTridiagonal(adjoint(parent(S.dv)), adjoint(parent(S.ev)))

permutedims(S::SymTridiagonal) = S
function permutedims(S::SymTridiagonal, perm)
Base.checkdims_perm(S, S, perm)
Expand Down Expand Up @@ -622,7 +625,9 @@ for func in (:conj, :copy, :real, :imag)
end
end

adjoint(S::Tridiagonal{<:Real}) = Tridiagonal(S.du, S.d, S.dl)
adjoint(S::Tridiagonal{<:Number}) = Tridiagonal(vec(adjoint(S.du)), vec(adjoint(S.d)), vec(adjoint(S.dl)))
adjoint(S::Tridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
Tridiagonal(adjoint(parent(S.du)), adjoint(parent(S.d)), adjoint(parent(S.dl)))
transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl)
permutedims(T::Tridiagonal) = Tridiagonal(T.du, T.d, T.dl)
function permutedims(T::Tridiagonal, perm)
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ Random.seed!(1)

@testset for func in (conj, transpose, adjoint)
@test func(func(T)) == T
if func (transpose, adjoint)
@test func(func(T)) === T
end
end

@testset "permutedims(::Bidiagonal)" begin
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ end
@testset "Idempotent tests" begin
for func in (conj, transpose, adjoint)
@test func(func(A)) == A
if func (transpose, adjoint)
@test func(func(A)) === A
end
end
end
@testset "permutedims(::[Sym]Tridiagonal)" begin
Expand Down

0 comments on commit 95872b9

Please sign in to comment.