Skip to content

Commit

Permalink
Merge pull request #5358 from JuliaLang/cjh/tridiag-invdet
Browse files Browse the repository at this point in the history
Provide inv and det of Tridiagonal and SymTridiagonal matrices
  • Loading branch information
jiahao committed Jan 12, 2014
2 parents bd41a76 + 5de5fb0 commit 43ec0e3
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 3 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Library improvements
matrices of generic types ([#5263])
- new algorithms for linear solvers and eigensystems of `Bidiagonal`
matrices of generic types ([#5277])
- specialized `inv` and `det` for `Tridiagonal` and `SymTridiagonal` ([#5358])
- specialized methods `transpose`, `ctranspose`, `istril`, `istriu` for
`Triangular` ([#5255]) and `Bidiagonal` ([#5277])
- new LAPACK wrappers
Expand Down Expand Up @@ -140,6 +141,7 @@ Deprecated or removed
[#2345]: https://github.com/JuliaLang/julia/issues/2345
[#5330]: https://github.com/JuliaLang/julia/issues/5330
[#4882]: https://github.com/JuliaLang/julia/issues/4882
[#5358]: https://github.com/JuliaLang/julia/pull/5358

Julia v0.2.0 Release Notes
==========================
Expand Down
73 changes: 71 additions & 2 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,68 @@ eigmin(m::SymTridiagonal) = eigvals(m, 1, 1)[1]
eigvecs(m::SymTridiagonal) = eig(m)[2]
eigvecs{Eigenvalue<:Real}(m::SymTridiagonal, eigvals::Vector{Eigenvalue}) = LAPACK.stein!(m.dv, m.ev, eigvals)

###################
# Generic methods #
###################

#Needed for inv_usmani()
type ZeroOffsetVector
data::Vector
end
getindex (a::ZeroOffsetVector, i) = a.data[i+1]
setindex!(a::ZeroOffsetVector, x, i) = a.data[i+1]=x

#Implements the inverse using the recurrence relation between principal minors
# a, b, c are assumed to be the subdiagonal, diagonal, and superdiagonal of
# a tridiagonal matrix.
#Ref:
# R. Usmani, "Inversion of a tridiagonal Jacobi matrix",
# Linear Algebra and its Applications 212-213 (1994), pp.413-414
# doi:10.1016/0024-3795(94)90414-6
function inv_usmani{T}(a::Vector{T}, b::Vector{T}, c::Vector{T})
n = length(b)
θ = ZeroOffsetVector(zeros(T, n+1)) #principal minors of A
θ[0] = 1
n>=1 && (θ[1] = b[1])
for i=2:n
θ[i] = b[i]*θ[i-1]-a[i-1]*c[i-1]*θ[i-2]
end
φ = zeros(T, n+1)
φ[n+1] = 1
n>=1 && (φ[n] = b[n])
for i=n-1:-1:1
φ[i] = b[i]*φ[i+1]-a[i]*c[i]*φ[i+2]
end
α = Array(T, n, n)
for i=1:n, j=1:n
sign = (i+j)%2==0 ? (+) : (-)
if i<j
α[i,j]=(sign)(prod(c[i:j-1]))*θ[i-1]*φ[j+1]/θ[n]
elseif i==j
α[i,i]= θ[i-1]*φ[i+1]/θ[n]
else #i>j
α[i,j]=(sign)(prod(a[j:i-1]))*θ[j-1]*φ[i+1]/θ[n]
end
end
α
end

#Implements the determinant using principal minors
#Inputs and reference are as above for inv_usmani()
function det_usmani{T}(a::Vector{T}, b::Vector{T}, c::Vector{T})
n = length(b)
θa = one(T)
n==0 && return θa
θb = b[1]
for i=2:n
θb, θa = b[i]*θb-a[i-1]*c[i-1]*θa, θb
end
return θb
end

inv(A::SymTridiagonal) = inv_usmani(A.ev, A.dv, A.ev)
det(A::SymTridiagonal) = det_usmani(A.ev, A.dv, A.ev)

## Tridiagonal matrices ##
type Tridiagonal{T} <: AbstractMatrix{T}
dl::Vector{T} # sub-diagonal
Expand Down Expand Up @@ -175,6 +237,10 @@ function diag{T}(M::Tridiagonal{T}, n::Integer=0)
end
end

###################
# Generic methods #
###################

+(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl+B.dl, A.d+B.d, A.du+B.du)
-(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl-B.dl, A.d-B.d, A.du+B.du)
*(A::Tridiagonal, B::Number) = Tridiagonal(A.dl*B, A.d*B, A.du*B)
Expand All @@ -185,6 +251,9 @@ end
==(A::Tridiagonal, B::SymTridiagonal) = (A.dl==A.du==B.ev) && (A.d==B.dv)
==(A::SymTridiagonal, B::SymTridiagonal) = B==A

inv(A::Tridiagonal) = inv_usmani(A.dl, A.d, A.du)
det(A::Tridiagonal) = det_usmani(A.dl, A.d, A.du)

# Elementary operations that mix Tridiagonal and SymTridiagonal matrices
convert(::Type{Tridiagonal}, A::SymTridiagonal) = Tridiagonal(A.ev, A.dv, A.ev)
+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.ev, A.d+B.dv, A.du+B.ev)
Expand Down Expand Up @@ -351,7 +420,7 @@ type LUTridiagonal{T} <: Factorization{T}
# end
end
lufact!{T<:BlasFloat}(A::Tridiagonal{T}) = LUTridiagonal{T}(LAPACK.gttrf!(A.dl,A.d,A.du)...)
lufact!(A::Tridiagonal) = lufact!(float(A))
lufact!{T<:Union(Rational,Integer)}(A::Tridiagonal{T}) = lufact!(float(A))
lufact(A::Tridiagonal) = lufact!(copy(A))
factorize!(A::Tridiagonal) = lufact!(A)
#show(io, lu::LUTridiagonal) = print(io, "LU decomposition of ", summary(lu.lu))
Expand All @@ -361,7 +430,7 @@ function det{T}(lu::LUTridiagonal{T})
prod(lu.d) * (bool(sum(lu.ipiv .!= 1:n) % 2) ? -one(T) : one(T))
end

det(A::Tridiagonal) = det(lufact(A))
det{T<:BlasFloat}(A::Tridiagonal{T}) = det(lufact(A))

A_ldiv_B!{T<:BlasFloat}(lu::LUTridiagonal{T}, B::StridedVecOrMat{T}) =
LAPACK.gttrs!('N', lu.dl, lu.d, lu.du, lu.du2, lu.ipiv, B)
Expand Down
36 changes: 35 additions & 1 deletion test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ end
function test_approx_eq_vecs{S<:Real,T<:Real}(a::StridedVecOrMat{S}, b::StridedVecOrMat{T}, error=nothing)
n = size(a, 1)
@test n==size(b,1) && size(a,2)==size(b,2)
if error==nothing error=n^2*(eps(S)+eps(T)) end
error==nothing && (error=n^3*(eps(S)+eps(T)))
for i=1:n
ev1, ev2 = a[:,i], b[:,i]
deviation = min(abs(norm(ev1-ev2)),abs(norm(ev1+ev2)))
Expand Down Expand Up @@ -603,7 +603,41 @@ for relty in (Float16, Float32, Float64, BigFloat), elty in (relty, Complex{relt
end
end

#Tridiagonal matrices
for relty in (Float16, Float32, Float64), elty in (relty, Complex{relty})
a = convert(Vector{elty}, randn(n-1))
b = convert(Vector{elty}, randn(n))
c = convert(Vector{elty}, randn(n-1))
if elty <: Complex
a += im*convert(Vector{elty}, randn(n-1))
b += im*convert(Vector{elty}, randn(n))
c += im*convert(Vector{elty}, randn(n-1))
end

A=Tridiagonal(a, b, c)
fA=(elty<:Complex?complex128:float64)(full(A))
for func in (det, inv)
@test_approx_eq_eps func(A) func(fA) n^2*sqrt(eps(relty))
end
end

#SymTridiagonal (symmetric tridiagonal) matrices
for relty in (Float16, Float32, Float64), elty in (relty, )#XXX Complex{relty}) doesn't work
a = convert(Vector{elty}, randn(n))
b = convert(Vector{elty}, randn(n-1))
if elty <: Complex
relty==Float16 && continue
a += im*convert(Vector{elty}, randn(n))
b += im*convert(Vector{elty}, randn(n-1))
end

A=SymTridiagonal(a, b)
fA=(elty<:Complex?complex128:float64)(full(A))
for func in (det, inv)
@test_approx_eq_eps func(A) func(fA) n^2*sqrt(eps(relty))
end
end

Ainit = randn(n)
Binit = randn(n-1)
for elty in (Float32, Float64)
Expand Down

0 comments on commit 43ec0e3

Please sign in to comment.