From 662993ffb67162957842f04598542007b79a62bd Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Thu, 5 Dec 2024 11:07:41 -0700 Subject: [PATCH 1/7] Add test_bases from QuantumOpticsBase and reorganize a bit --- src/QuantumInterface.jl | 36 +++++- src/abstract_types.jl | 32 +++-- src/bases.jl | 274 +--------------------------------------- src/embed_permute.jl | 26 ++-- src/identityoperator.jl | 4 +- src/julia_base.jl | 39 +++--- src/julia_linalg.jl | 6 +- src/linalg.jl | 198 ++++++++++++++++++++++++++++- src/show.jl | 69 ++++++++++ src/sparse.jl | 2 +- test/runtests.jl | 1 + test/test_bases.jl | 55 ++++++++ 12 files changed, 415 insertions(+), 327 deletions(-) create mode 100644 src/show.jl create mode 100644 test/test_bases.jl diff --git a/src/QuantumInterface.jl b/src/QuantumInterface.jl index 34efa04..89d3c9a 100644 --- a/src/QuantumInterface.jl +++ b/src/QuantumInterface.jl @@ -1,14 +1,39 @@ module QuantumInterface -import Base: ==, +, -, *, /, ^, length, one, exp, conj, conj!, transpose, copy -import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! -import Base: show, summary -import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension +## +# Basis specific +## + +""" + basis(a) + +Return the basis of an object. + +If it's ambiguous, e.g. if an operator has a different left and right basis, +an [`IncompatibleBases`](@ref) error is thrown. +""" +function basis end + +""" +Exception that should be raised for an illegal algebraic operation. +""" +mutable struct IncompatibleBases <: Exception end + + +## +# Standard methods +## function apply! end function dagger end +""" + directsum(x, y, z...) + +Direct sum of the given objects. Alternatively, the unicode +symbol ⊕ (\\oplus) can be used. +""" function directsum end const ⊕ = directsum directsum() = GenericBasis(0) @@ -86,8 +111,9 @@ function squeeze end function wigner end -include("bases.jl") include("abstract_types.jl") +include("bases.jl") +include("show.jl") include("linalg.jl") include("tensor.jl") diff --git a/src/abstract_types.jl b/src/abstract_types.jl index f8667c9..0650290 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -1,3 +1,18 @@ +""" +Abstract base class for all specialized bases. + +The Basis class is meant to specify a basis of the Hilbert space of the +studied system. Besides basis specific information all subclasses must +implement a shape variable which indicates the dimension of the used +Hilbert space. For a spin-1/2 Hilbert space this would be the +vector `[2]`. A system composed of two spins would then have a +shape vector `[2 2]`. + +Composite systems can be defined with help of the [`CompositeBasis`](@ref) +class. +""" +abstract type Basis end + """ Abstract base class for `Bra` and `Ket` states. @@ -38,20 +53,3 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} ``` """ abstract type AbstractSuperOperator{B1,B2} end - -function summary(stream::IO, x::AbstractOperator) - print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") - if samebases(x) - print(stream, " basis: ") - show(stream, basis(x)) - else - print(stream, " basis left: ") - show(stream, x.basis_l) - print(stream, "\n basis right: ") - show(stream, x.basis_r) - end -end - -show(stream::IO, x::AbstractOperator) = summary(stream, x) - -traceout!(s::StateVector, i) = ptrace(s,i) diff --git a/src/bases.jl b/src/bases.jl index 6e4b077..2c059e0 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -1,17 +1,6 @@ -""" -Abstract base class for all specialized bases. - -The Basis class is meant to specify a basis of the Hilbert space of the -studied system. Besides basis specific information all subclasses must -implement a shape variable which indicates the dimension of the used -Hilbert space. For a spin-1/2 Hilbert space this would be the -vector `[2]`. A system composed of two spins would then have a -shape vector `[2 2]`. - -Composite systems can be defined with help of the [`CompositeBasis`](@ref) -class. -""" -abstract type Basis end +## +# GenericBasis, CompositeBasis +## """ length(b::Basis) @@ -20,17 +9,6 @@ Total dimension of the Hilbert space. """ Base.length(b::Basis) = prod(b.shape) -""" - basis(a) - -Return the basis of an object. - -If it's ambiguous, e.g. if an operator has a different left and right basis, -an [`IncompatibleBases`](@ref) error is thrown. -""" -function basis end - - """ GenericBasis(N) @@ -67,39 +45,6 @@ CompositeBasis(bases::Vector) = CompositeBasis((bases...,)) Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape) -tensor(b::Basis) = b - -""" - tensor(x::Basis, y::Basis, z::Basis...) - -Create a [`CompositeBasis`](@ref) from the given bases. - -Any given CompositeBasis is expanded so that the resulting CompositeBasis never -contains another CompositeBasis. -""" -tensor(b1::Basis, b2::Basis) = CompositeBasis([length(b1); length(b2)], (b1, b2)) -tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis([b1.shape; b2.shape], (b1.bases..., b2.bases...)) -function tensor(b1::CompositeBasis, b2::Basis) - N = length(b1.bases) - shape = vcat(b1.shape, length(b2)) - bases = (b1.bases..., b2) - CompositeBasis(shape, bases) -end -function tensor(b1::Basis, b2::CompositeBasis) - N = length(b2.bases) - shape = vcat(length(b1), b2.shape) - bases = (b1, b2.bases...) - CompositeBasis(shape, bases) -end -tensor(bases::Basis...) = reduce(tensor, bases) - -function Base.:^(b::Basis, N::Integer) - if N < 1 - throw(ArgumentError("Power of a basis is only defined for positive integers.")) - end - tensor([b for i=1:N]...) -end - """ equal_shape(a, b) @@ -137,130 +82,6 @@ function equal_bases(a, b) return true end -""" -Exception that should be raised for an illegal algebraic operation. -""" -mutable struct IncompatibleBases <: Exception end - -const BASES_CHECK = Ref(true) - -""" - @samebases - -Macro to skip checks for same bases. Useful for `*`, `expect` and similar -functions. -""" -macro samebases(ex) - return quote - BASES_CHECK.x = false - local val = $(esc(ex)) - BASES_CHECK.x = true - val - end -end - -""" - samebases(a, b) - -Test if two objects have the same bases. -""" -samebases(b1::Basis, b2::Basis) = b1==b2 -samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators - -""" - check_samebases(a, b) - -Throw an [`IncompatibleBases`](@ref) error if the objects don't have -the same bases. -""" -function check_samebases(b1, b2) - if BASES_CHECK[] && !samebases(b1, b2) - throw(IncompatibleBases()) - end -end - - -""" - multiplicable(a, b) - -Check if two objects are multiplicable. -""" -multiplicable(b1::Basis, b2::Basis) = b1==b2 - -function multiplicable(b1::CompositeBasis, b2::CompositeBasis) - if !equal_shape(b1.shape,b2.shape) - return false - end - for i=1:length(b1.shape) - if !multiplicable(b1.bases[i], b2.bases[i]) - return false - end - end - return true -end - -""" - check_multiplicable(a, b) - -Throw an [`IncompatibleBases`](@ref) error if the objects are -not multiplicable. -""" -function check_multiplicable(b1, b2) - if BASES_CHECK[] && !multiplicable(b1, b2) - throw(IncompatibleBases()) - end -end - -""" - reduced(a, indices) - -Reduced basis, state or operator on the specified subsystems. - -The `indices` argument, which can be a single integer or a vector of integers, -specifies which subsystems are kept. At least one index must be specified. -""" -function reduced(b::CompositeBasis, indices) - if length(indices)==0 - throw(ArgumentError("At least one subsystem must be specified in reduced.")) - elseif length(indices)==1 - return b.bases[indices[1]] - else - return CompositeBasis(b.shape[indices], b.bases[indices]) - end -end - -""" - ptrace(a, indices) - -Partial trace of the given basis, state or operator. - -The `indices` argument, which can be a single integer or a vector of integers, -specifies which subsystems are traced out. The number of indices has to be -smaller than the number of subsystems, i.e. it is not allowed to perform a -full trace. -""" -function ptrace(b::CompositeBasis, indices) - J = [i for i in 1:length(b.bases) if i ∉ indices] - length(J) > 0 || throw(ArgumentError("Tracing over all indices is not allowed in ptrace.")) - reduced(b, J) -end - - -""" - permutesystems(a, perm) - -Change the ordering of the subsystems of the given object. - -For a permutation vector `[2,1,3]` and a given object with basis `[b1, b2, b3]` -this function results in `[b2, b1, b3]`. -""" -function permutesystems(b::CompositeBasis, perm) - @assert length(b.bases) == length(perm) - @assert isperm(perm) - CompositeBasis(b.shape[perm], b.bases[perm]) -end - - ## # Common bases ## @@ -366,89 +187,6 @@ SumBasis(shape, bases::Vector) = (tmp = (bases...,); SumBasis(shape, tmp)) SumBasis(bases::Vector) = SumBasis((bases...,)) SumBasis(bases::Basis...) = SumBasis((bases...,)) -==(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) -==(b1::SumBasis, b2::SumBasis) = false -length(b::SumBasis) = sum(b.shape) - -""" - directsum(b1::Basis, b2::Basis) - -Construct the [`SumBasis`](@ref) out of two sub-bases. -""" -directsum(b1::Basis, b2::Basis) = SumBasis(Int[length(b1); length(b2)], Basis[b1, b2]) -directsum(b::Basis) = b -directsum(b::Basis...) = reduce(directsum, b) -function directsum(b1::SumBasis, b2::Basis) - shape = [b1.shape;length(b2)] - bases = [b1.bases...;b2] - return SumBasis(shape, (bases...,)) -end -function directsum(b1::Basis, b2::SumBasis) - shape = [length(b1);b2.shape] - bases = [b1;b2.bases...] - return SumBasis(shape, (bases...,)) -end -function directsum(b1::SumBasis, b2::SumBasis) - shape = [b1.shape;b2.shape] - bases = [b1.bases...;b2.bases...] - return SumBasis(shape, (bases...,)) -end - -embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) - -## -# show methods -## - -function show(stream::IO, x::GenericBasis) - if length(x.shape) == 1 - write(stream, "Basis(dim=$(x.shape[1]))") - else - s = replace(string(x.shape), " " => "") - write(stream, "Basis(shape=$s)") - end -end - -function show(stream::IO, x::CompositeBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊗ ") - end - end - write(stream, "]") -end - -function show(stream::IO, x::SpinBasis) - d = denominator(x.spinnumber) - n = numerator(x.spinnumber) - if d == 1 - write(stream, "Spin($n)") - else - write(stream, "Spin($n/$d)") - end -end - -function show(stream::IO, x::FockBasis) - if iszero(x.offset) - write(stream, "Fock(cutoff=$(x.N))") - else - write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") - end -end - -function show(stream::IO, x::NLevelBasis) - write(stream, "NLevel(N=$(x.N))") -end - -function show(stream::IO, x::SumBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊕ ") - end - end - write(stream, "]") -end +Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) +Base.:(==)(b1::SumBasis, b2::SumBasis) = false +Base.length(b::SumBasis) = sum(b.shape) diff --git a/src/embed_permute.jl b/src/embed_permute.jl index c2cc4ca..297ab58 100644 --- a/src/embed_permute.jl +++ b/src/embed_permute.jl @@ -67,8 +67,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, ops_sb = [x[2] for x in idxop_sb] for (idxsb, opsb) in zip(indices_sb, ops_sb) - (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) - (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) + (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 + (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 end S = length(operators) > 0 ? mapreduce(eltype, promote_type, operators) : Any @@ -83,10 +83,20 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, return embed_op end -permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a) +embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) + +""" + permutesystems(a, perm) + +Change the ordering of the subsystems of the given object. -nsubsystems(s::AbstractKet) = nsubsystems(basis(s)) -nsubsystems(s::AbstractOperator) = nsubsystems(basis(s)) -nsubsystems(b::CompositeBasis) = length(b.bases) -nsubsystems(b::Basis) = 1 -nsubsystems(::Nothing) = 1 # TODO Exists because of QuantumSavory; Consider removing this and reworking the functions that depend on it. E.g., a reason to have it when performing a project_traceout measurement on a state that contains only one subsystem +For a permutation vector `[2,1,3]` and a given object with basis `[b1, b2, b3]` +this function results in `[b2, b1, b3]`. +""" +function permutesystems(b::CompositeBasis, perm) + @assert length(b.bases) == length(perm) + @assert isperm(perm) + CompositeBasis(b.shape[perm], b.bases[perm]) +end + +permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a) diff --git a/src/identityoperator.jl b/src/identityoperator.jl index 5959882..aa031ff 100644 --- a/src/identityoperator.jl +++ b/src/identityoperator.jl @@ -1,4 +1,4 @@ -one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) +Base.one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) """ identityoperator(a::Basis[, b::Basis]) @@ -22,4 +22,4 @@ identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:Abstract identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2) """Prepare the identity superoperator over a given space.""" -function identitysuperoperator end \ No newline at end of file +function identitysuperoperator end diff --git a/src/julia_base.jl b/src/julia_base.jl index 9a0532d..2d8e085 100644 --- a/src/julia_base.jl +++ b/src/julia_base.jl @@ -1,3 +1,5 @@ +import Base: +, -, *, /, ^, length, exp, conj, conj!, adjoint, transpose, copy + # Common error messages arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse().")) arithmetic_binary_error(funcname, a::AbstractOperator, b::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this combination of types of operators: $(typeof(a)), $(typeof(b)).\nTry to convert to a common operator type first with e.g. dense() or sparse().")) @@ -8,33 +10,31 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op # States ## --(a::T) where {T<:StateVector} = T(a.basis, -a.data) +-(a::T) where {T<:StateVector} = T(a.basis, -a.data) # FIXME issue #12 *(a::StateVector, b::Number) = b*a -copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) -length(a::StateVector) = length(a.basis)::Int -basis(a::StateVector) = a.basis -directsum(x::StateVector...) = reduce(directsum, x) +copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12 +length(a::StateVector) = length(a.basis)::Int # FIXME issue #12 +basis(a::StateVector) = a.basis # FIXME issue #12 +adjoint(a::StateVector) = dagger(a) + # Array-like functions -Base.size(x::StateVector) = size(x.data) -@inline Base.axes(x::StateVector) = axes(x.data) +Base.size(x::StateVector) = size(x.data) # FIXME issue #12 +@inline Base.axes(x::StateVector) = axes(x.data) # FIXME issue #12 Base.ndims(x::StateVector) = 1 Base.ndims(::Type{<:StateVector}) = 1 -Base.eltype(x::StateVector) = eltype(x.data) +Base.eltype(x::StateVector) = eltype(x.data) # FIXME issue #12 # Broadcasting Base.broadcastable(x::StateVector) = x -Base.adjoint(a::StateVector) = dagger(a) - - ## # Operators ## -length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int -basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) -basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) +length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12 +basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12 +basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 # Ensure scalar broadcasting Base.broadcastable(x::AbstractOperator) = Ref(x) @@ -60,14 +60,19 @@ Operator exponential. """ exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense().")) -Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) +Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) # FIXME issue #12 function Base.size(op::AbstractOperator, i::Int) i < 1 && throw(ErrorException("dimension index is < 1")) i > 2 && return 1 - i==1 ? length(op.basis_l) : length(op.basis_r) + i==1 ? length(op.basis_l) : length(op.basis_r) # FIXME issue #12 end -Base.adjoint(a::AbstractOperator) = dagger(a) +dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a) + +adjoint(a::AbstractOperator) = dagger(a) + +transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) + conj(a::AbstractOperator) = arithmetic_unary_error("Complex conjugate", a) conj!(a::AbstractOperator) = conj(a::AbstractOperator) diff --git a/src/julia_linalg.jl b/src/julia_linalg.jl index d2f4d3d..3087d0a 100644 --- a/src/julia_linalg.jl +++ b/src/julia_linalg.jl @@ -1,3 +1,5 @@ +import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! + """ ishermitian(op::AbstractOperator) @@ -17,7 +19,7 @@ tr(x::AbstractOperator) = arithmetic_unary_error("Trace", x) Norm of the given bra or ket state. """ -norm(x::StateVector) = norm(x.data) +norm(x::StateVector) = norm(x.data) # FIXME issue #12 """ normalize(x::StateVector) @@ -31,7 +33,7 @@ normalize(x::StateVector) = x/norm(x) In-place normalization of the given bra or ket so that `norm(x)` is one. """ -normalize!(x::StateVector) = (normalize!(x.data); x) +normalize!(x::StateVector) = (normalize!(x.data); x) # FIXME issue #12 """ normalize(op) diff --git a/src/linalg.jl b/src/linalg.jl index 8bb47cd..076ec32 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1,10 +1,194 @@ -samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool -samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool -check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) -multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) -dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a) -transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) -directsum(a::AbstractOperator...) = reduce(directsum, a) +## +# Basis checks +## + +const BASES_CHECK = Ref(true) + +""" + @samebases + +Macro to skip checks for same bases. Useful for `*`, `expect` and similar +functions. +""" +macro samebases(ex) + return quote + BASES_CHECK.x = false + local val = $(esc(ex)) + BASES_CHECK.x = true + val + end +end + +""" + samebases(a, b) + +Test if two objects have the same bases. +""" +samebases(b1::Basis, b2::Basis) = b1==b2 +samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators + +""" + check_samebases(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects don't have +the same bases. +""" +function check_samebases(b1, b2) + if BASES_CHECK[] && !samebases(b1, b2) + throw(IncompatibleBases()) + end +end + + +""" + multiplicable(a, b) + +Check if two objects are multiplicable. +""" +multiplicable(b1::Basis, b2::Basis) = b1==b2 + +function multiplicable(b1::CompositeBasis, b2::CompositeBasis) + if !equal_shape(b1.shape,b2.shape) + return false + end + for i=1:length(b1.shape) + if !multiplicable(b1.bases[i], b2.bases[i]) + return false + end + end + return true +end + +""" + check_multiplicable(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not multiplicable. +""" +function check_multiplicable(b1, b2) + if BASES_CHECK[] && !multiplicable(b1, b2) + throw(IncompatibleBases()) + end +end + +samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12 +samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12 +check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12 +multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12 + +## +# tensor, reduce, ptrace +## + +tensor(b::Basis) = b + +""" + tensor(x::Basis, y::Basis, z::Basis...) + +Create a [`CompositeBasis`](@ref) from the given bases. + +Any given CompositeBasis is expanded so that the resulting CompositeBasis never +contains another CompositeBasis. +""" +tensor(b1::Basis, b2::Basis) = CompositeBasis([length(b1); length(b2)], (b1, b2)) +tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis([b1.shape; b2.shape], (b1.bases..., b2.bases...)) +function tensor(b1::CompositeBasis, b2::Basis) + N = length(b1.bases) + shape = vcat(b1.shape, length(b2)) + bases = (b1.bases..., b2) + CompositeBasis(shape, bases) +end +function tensor(b1::Basis, b2::CompositeBasis) + N = length(b2.bases) + shape = vcat(length(b1), b2.shape) + bases = (b1, b2.bases...) + CompositeBasis(shape, bases) +end +tensor(bases::Basis...) = reduce(tensor, bases) + +function Base.:^(b::Basis, N::Integer) + if N < 1 + throw(ArgumentError("Power of a basis is only defined for positive integers.")) + end + tensor([b for i=1:N]...) +end + +""" + reduced(a, indices) + +Reduced basis, state or operator on the specified subsystems. + +The `indices` argument, which can be a single integer or a vector of integers, +specifies which subsystems are kept. At least one index must be specified. +""" +function reduced(b::CompositeBasis, indices) + if length(indices)==0 + throw(ArgumentError("At least one subsystem must be specified in reduced.")) + elseif length(indices)==1 + return b.bases[indices[1]] + else + return CompositeBasis(b.shape[indices], b.bases[indices]) + end +end + +""" + ptrace(a, indices) + +Partial trace of the given basis, state or operator. + +The `indices` argument, which can be a single integer or a vector of integers, +specifies which subsystems are traced out. The number of indices has to be +smaller than the number of subsystems, i.e. it is not allowed to perform a +full trace. +""" +function ptrace(b::CompositeBasis, indices) + J = [i for i in 1:length(b.bases) if i ∉ indices] + length(J) > 0 || throw(ArgumentError("Tracing over all indices is not allowed in ptrace.")) + reduced(b, J) +end + ptrace(a::AbstractOperator, index) = arithmetic_unary_error("Partial trace", a) _index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices) reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices)) +traceout!(s::StateVector, i) = ptrace(s,i) + +## +# nsubsystems +## + +nsubsystems(s::AbstractKet) = nsubsystems(basis(s)) +nsubsystems(s::AbstractOperator) = nsubsystems(basis(s)) +nsubsystems(b::CompositeBasis) = length(b.bases) +nsubsystems(b::Basis) = 1 +nsubsystems(::Nothing) = 1 # TODO Exists because of QuantumSavory; Consider removing this and reworking the functions that depend on it. E.g., a reason to have it when performing a project_traceout measurement on a state that contains only one subsystem + +## +# directsum +## + +""" + directsum(b1::Basis, b2::Basis) + +Construct the [`SumBasis`](@ref) out of two sub-bases. +""" +directsum(b1::Basis, b2::Basis) = SumBasis(Int[length(b1); length(b2)], Basis[b1, b2]) +directsum(b::Basis) = b +directsum(b::Basis...) = reduce(directsum, b) +function directsum(b1::SumBasis, b2::Basis) + shape = [b1.shape;length(b2)] + bases = [b1.bases...;b2] + return SumBasis(shape, (bases...,)) +end +function directsum(b1::Basis, b2::SumBasis) + shape = [length(b1);b2.shape] + bases = [b1;b2.bases...] + return SumBasis(shape, (bases...,)) +end +function directsum(b1::SumBasis, b2::SumBasis) + shape = [b1.shape;b2.shape] + bases = [b1.bases...;b2.bases...] + return SumBasis(shape, (bases...,)) +end + +directsum(x::StateVector...) = reduce(directsum, x) +directsum(a::AbstractOperator...) = reduce(directsum, a) diff --git a/src/show.jl b/src/show.jl new file mode 100644 index 0000000..38607b0 --- /dev/null +++ b/src/show.jl @@ -0,0 +1,69 @@ +import Base: show, summary + +function summary(stream::IO, x::AbstractOperator) + print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") + if samebases(x) + print(stream, " basis: ") + show(stream, basis(x)) + else + print(stream, " basis left: ") + show(stream, x.basis_l) + print(stream, "\n basis right: ") + show(stream, x.basis_r) + end +end + +show(stream::IO, x::AbstractOperator) = summary(stream, x) + +function show(stream::IO, x::GenericBasis) + if length(x.shape) == 1 + write(stream, "Basis(dim=$(x.shape[1]))") + else + s = replace(string(x.shape), " " => "") + write(stream, "Basis(shape=$s)") + end +end + +function show(stream::IO, x::CompositeBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊗ ") + end + end + write(stream, "]") +end + +function show(stream::IO, x::SpinBasis) + d = denominator(x.spinnumber) + n = numerator(x.spinnumber) + if d == 1 + write(stream, "Spin($n)") + else + write(stream, "Spin($n/$d)") + end +end + +function show(stream::IO, x::FockBasis) + if iszero(x.offset) + write(stream, "Fock(cutoff=$(x.N))") + else + write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") + end +end + +function show(stream::IO, x::NLevelBasis) + write(stream, "NLevel(N=$(x.N))") +end + +function show(stream::IO, x::SumBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊕ ") + end + end + write(stream, "]") +end diff --git a/src/sparse.jl b/src/sparse.jl index 2ba8f5f..d6b301c 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -1,4 +1,4 @@ -# TODO make an extension? +import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension # dense(a::AbstractOperator) = arithmetic_unary_error("Conversion to dense", a) diff --git a/test/runtests.jl b/test/runtests.jl index 0bccf25..826fe33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ end println("Starting tests with $(Threads.nthreads()) threads out of `Sys.CPU_THREADS = $(Sys.CPU_THREADS)`...") @doset "sortedindices" +@doset "bases" #VERSION >= v"1.9" && @doset "doctests" get(ENV,"JET_TEST","")=="true" && @doset "jet" VERSION >= v"1.9" && @doset "aqua" diff --git a/test/test_bases.jl b/test/test_bases.jl new file mode 100644 index 0000000..1d91673 --- /dev/null +++ b/test/test_bases.jl @@ -0,0 +1,55 @@ +using Test +using QuantumInterface: tensor, ⊗, ptrace, reduced, permutesystems, equal_bases, multiplicable +using QuantumInterface: GenericBasis, CompositeBasis, NLevelBasis, FockBasis + +@testset "basis" begin + +shape1 = [5] +shape2 = [2, 3] +shape3 = [6] + +b1 = GenericBasis(shape1) +b2 = GenericBasis(shape2) +b3 = GenericBasis(shape3) + +@test b1.shape == shape1 +@test b2.shape == shape2 +@test b1 != b2 +@test b1 != FockBasis(2) +@test b1 == b1 + +@test tensor(b1) == b1 +comp_b1 = tensor(b1, b2) +comp_uni = b1 ⊗ b2 +comp_b2 = tensor(b1, b1, b2) +@test comp_b1.shape == [prod(shape1), prod(shape2)] +@test comp_uni.shape == [prod(shape1), prod(shape2)] +@test comp_b2.shape == [prod(shape1), prod(shape1), prod(shape2)] + +@test b1^3 == CompositeBasis(b1, b1, b1) +@test (b1⊗b2)^2 == CompositeBasis(b1, b2, b1, b2) +@test_throws ArgumentError b1^(0) + +comp_b1_b2 = tensor(comp_b1, comp_b2) +@test comp_b1_b2.shape == [prod(shape1), prod(shape2), prod(shape1), prod(shape1), prod(shape2)] +@test comp_b1_b2 == CompositeBasis(b1, b2, b1, b1, b2) + +@test_throws ArgumentError tensor() +@test comp_b2.shape == tensor(b1, comp_b1).shape +@test comp_b2 == tensor(b1, comp_b1) + +@test_throws ArgumentError ptrace(comp_b1, [1, 2]) +@test ptrace(comp_b2, [1]) == ptrace(comp_b2, [2]) == comp_b1 == ptrace(comp_b2, 1) +@test ptrace(comp_b2, [1, 2]) == ptrace(comp_b1, [1]) +@test ptrace(comp_b2, [2, 3]) == ptrace(comp_b1, [2]) +@test ptrace(comp_b2, [2, 3]) == reduced(comp_b2, [1]) +@test_throws ArgumentError reduced(comp_b1, []) + +comp1 = tensor(b1, b2, b3) +comp2 = tensor(b2, b1, b3) +@test permutesystems(comp1, [2,1,3]) == comp2 + +@test !equal_bases([b1, b2], [b1, b3]) +@test !multiplicable(comp1, b1 ⊗ b2 ⊗ NLevelBasis(prod(b3.shape))) + +end # testset From 1b31070ecdc7a3c56886166a99f1871c355d180d Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Thu, 5 Dec 2024 18:14:56 -0700 Subject: [PATCH 2/7] Deprecate `PauliBasis` and `equal_bases` --- CHANGELOG.md | 4 ++++ Project.toml | 2 +- src/QuantumInterface.jl | 1 + src/bases.jl | 29 +++++------------------------ src/deprecated.jl | 14 ++++++++++++++ test/test_bases.jl | 4 ++-- 6 files changed, 27 insertions(+), 27 deletions(-) create mode 100644 src/deprecated.jl diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c30de4..f086a09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # News +## v0.3.7 - 2024-12-05 + +- Rename `PauliBasis` to `NQubitBasis` with warning, and add deprecation to `equal_bases`. + ## v0.3.6 - 2024-09-08 - Add `coherentstate`, `thermalstate`, `displace`, `squeeze`, `wigner`, previously from QuantumOptics. diff --git a/Project.toml b/Project.toml index d8b70c0..a02a99b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QuantumInterface" uuid = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5" authors = ["QuantumInterface.jl contributors"] -version = "0.3.6" +version = "0.3.7" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/QuantumInterface.jl b/src/QuantumInterface.jl index 89d3c9a..c1bc651 100644 --- a/src/QuantumInterface.jl +++ b/src/QuantumInterface.jl @@ -126,5 +126,6 @@ include("julia_linalg.jl") include("sparse.jl") include("sortedindices.jl") +include("deprecated.jl") end # module diff --git a/src/bases.jl b/src/bases.jl index 2c059e0..532ad65 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -65,23 +65,6 @@ function equal_shape(a, b) return true end -""" - equal_bases(a, b) - -Check if two subbases vectors are identical. -""" -function equal_bases(a, b) - if a===b - return true - end - for i=1:length(a) - if a[i]!=b[i] - return false - end - end - return true -end - ## # Common bases ## @@ -126,25 +109,23 @@ end Base.:(==)(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N - """ - PauliBasis(num_qubits::Int) + NQubitBasis(num_qubits::Int) Basis for an N-qubit space where `num_qubits` specifies the number of qubits. -The dimension of the basis is 2²ᴺ. +The dimension of the basis is 2ᴺ. """ -struct PauliBasis{S,B} <: Basis +struct NQubitBasis{S,B} <: Basis shape::S bases::B - function PauliBasis(num_qubits::T) where {T<:Integer} + function NQubitBasis(num_qubits::T) where {T<:Integer} shape = [2 for _ in 1:num_qubits] bases = Tuple(SpinBasis(1//2) for _ in 1:num_qubits) return new{typeof(shape),typeof(bases)}(shape, bases) end end -Base.:(==)(pb1::PauliBasis, pb2::PauliBasis) = length(pb1.bases) == length(pb2.bases) - +Base.:(==)(pb1::NQubitBasis, pb2::NQubitBasis) = length(pb1.bases) == length(pb2.bases) """ SpinBasis(n) diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 0000000..d4aadd5 --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,14 @@ +function equal_bases(a, b) + Base.depwarn("`==` should be preferred over `equal_bases`!", :equal_bases) + if a===b + return true + end + for i=1:length(a) + if a[i]!=b[i] + return false + end + end + return true +end + +Base.@deprecate PauliBasis(num_qubits) NQubitBasis(num_qubits) false diff --git a/test/test_bases.jl b/test/test_bases.jl index 1d91673..cce4cd2 100644 --- a/test/test_bases.jl +++ b/test/test_bases.jl @@ -1,5 +1,5 @@ using Test -using QuantumInterface: tensor, ⊗, ptrace, reduced, permutesystems, equal_bases, multiplicable +using QuantumInterface: tensor, ⊗, ptrace, reduced, permutesystems, multiplicable using QuantumInterface: GenericBasis, CompositeBasis, NLevelBasis, FockBasis @testset "basis" begin @@ -49,7 +49,7 @@ comp1 = tensor(b1, b2, b3) comp2 = tensor(b2, b1, b3) @test permutesystems(comp1, [2,1,3]) == comp2 -@test !equal_bases([b1, b2], [b1, b3]) +@test [b1, b2] != [b1, b3] @test !multiplicable(comp1, b1 ⊗ b2 ⊗ NLevelBasis(prod(b3.shape))) end # testset From 1672e2f1263242893ca463e912d48459e79832fb Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Thu, 5 Dec 2024 18:14:56 -0700 Subject: [PATCH 3/7] Move current basis checking to new file deprecated.jl This is in anticipation of new basis checking interface and eventual future deprecation of current basis checking --- src/bases.jl | 20 ------------- src/deprecated.jl | 58 ++++++++++++++++++++++++++++++++++++++ src/linalg.jl | 72 ----------------------------------------------- 3 files changed, 58 insertions(+), 92 deletions(-) diff --git a/src/bases.jl b/src/bases.jl index 532ad65..e178a66 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -45,26 +45,6 @@ CompositeBasis(bases::Vector) = CompositeBasis((bases...,)) Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape) -""" - equal_shape(a, b) - -Check if two shape vectors are the same. -""" -function equal_shape(a, b) - if a === b - return true - end - if length(a) != length(b) - return false - end - for i=1:length(a) - if a[i]!=b[i] - return false - end - end - return true -end - ## # Common bases ## diff --git a/src/deprecated.jl b/src/deprecated.jl index d4aadd5..680b8f3 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -12,3 +12,61 @@ function equal_bases(a, b) end Base.@deprecate PauliBasis(num_qubits) NQubitBasis(num_qubits) false + +function equal_shape(a, b) + if a === b + return true + end + if length(a) != length(b) + return false + end + for i=1:length(a) + if a[i]!=b[i] + return false + end + end + return true +end + +macro samebases(ex) + return quote + BASES_CHECK.x = false + local val = $(esc(ex)) + BASES_CHECK.x = true + val + end +end + +function check_samebases(b1, b2) + if BASES_CHECK[] && !samebases(b1, b2) + throw(IncompatibleBases()) + end +end + +function check_multiplicable(b1, b2) + if BASES_CHECK[] && !multiplicable(b1, b2) + throw(IncompatibleBases()) + end +end + +samebases(b1::Basis, b2::Basis) = b1==b2 +samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators +samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12 +samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12 +check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12 + +multiplicable(b1::Basis, b2::Basis) = b1==b2 + +function multiplicable(b1::CompositeBasis, b2::CompositeBasis) + if !equal_shape(b1.shape,b2.shape) + return false + end + for i=1:length(b1.shape) + if !multiplicable(b1.bases[i], b2.bases[i]) + return false + end + end + return true +end + +multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12 diff --git a/src/linalg.jl b/src/linalg.jl index 076ec32..b48f458 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -4,78 +4,6 @@ const BASES_CHECK = Ref(true) -""" - @samebases - -Macro to skip checks for same bases. Useful for `*`, `expect` and similar -functions. -""" -macro samebases(ex) - return quote - BASES_CHECK.x = false - local val = $(esc(ex)) - BASES_CHECK.x = true - val - end -end - -""" - samebases(a, b) - -Test if two objects have the same bases. -""" -samebases(b1::Basis, b2::Basis) = b1==b2 -samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators - -""" - check_samebases(a, b) - -Throw an [`IncompatibleBases`](@ref) error if the objects don't have -the same bases. -""" -function check_samebases(b1, b2) - if BASES_CHECK[] && !samebases(b1, b2) - throw(IncompatibleBases()) - end -end - - -""" - multiplicable(a, b) - -Check if two objects are multiplicable. -""" -multiplicable(b1::Basis, b2::Basis) = b1==b2 - -function multiplicable(b1::CompositeBasis, b2::CompositeBasis) - if !equal_shape(b1.shape,b2.shape) - return false - end - for i=1:length(b1.shape) - if !multiplicable(b1.bases[i], b2.bases[i]) - return false - end - end - return true -end - -""" - check_multiplicable(a, b) - -Throw an [`IncompatibleBases`](@ref) error if the objects are -not multiplicable. -""" -function check_multiplicable(b1, b2) - if BASES_CHECK[] && !multiplicable(b1, b2) - throw(IncompatibleBases()) - end -end - -samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12 -samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12 -check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12 -multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12 - ## # tensor, reduce, ptrace ## From 31982ff6ce5d725bde355f40c25e6b0a1fa19790 Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Thu, 5 Dec 2024 14:43:29 -0700 Subject: [PATCH 4/7] Implement basis interface proposed in #40 --- src/QuantumInterface.jl | 71 +++++++++++++++++++++++++++++++++++++++-- src/abstract_types.jl | 41 ++++++++++++++---------- src/bases.jl | 16 ++++++++-- src/deprecated.jl | 11 ++----- src/expect_variance.jl | 32 ++++++++++--------- src/julia_base.jl | 3 -- src/linalg.jl | 44 +++++++++++++++++++++++++ 7 files changed, 169 insertions(+), 49 deletions(-) diff --git a/src/QuantumInterface.jl b/src/QuantumInterface.jl index c1bc651..80c7b1a 100644 --- a/src/QuantumInterface.jl +++ b/src/QuantumInterface.jl @@ -7,23 +7,88 @@ module QuantumInterface """ basis(a) -Return the basis of an object. +Return the basis of a quantum object. -If it's ambiguous, e.g. if an operator has a different left and right basis, -an [`IncompatibleBases`](@ref) error is thrown. +If it's ambiguous, e.g. if an operator has a different +left and right basis, an [`IncompatibleBases`](@ref) error is thrown. + +See [`StateVector`](@ref) and [`AbstractOperator`](@ref) """ function basis end +""" + basis_l(a) + +Return the left basis of an operator. +""" +function basis_l end + +""" + basis_r(a) + +Return the right basis of an operator. +""" +function basis_r end + """ Exception that should be raised for an illegal algebraic operation. """ mutable struct IncompatibleBases <: Exception end +#function bases end + +function spinnumber end + +function cutoff end + +function offset end ## # Standard methods ## +""" + multiplicable(a, b) + +Check if any two subtypes of `StateVector` or `AbstractOperator`, +can be multiplied in the given order. +""" +function multiplicable end + +""" + check_multiplicable(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not multiplicable as determined by `multiplicable(a, b)`. + +If the macro `@compatiblebases` is used anywhere up the call stack, +this check is disabled. +""" +function check_multiplicable end + +""" + addible(a, b) + +Check if any two subtypes of `StateVector` or `AbstractOperator` + can be added together. + +Spcefically this checks whether the left basis of a is equal +to the left basis of b and whether the right basis of a is equal +to the right basis of b. +""" +function addible end + +""" + check_addible(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not addible as determined by `addible(a, b)`. + +If the macro `@compatiblebases` is used anywhere up the call stack, +this check is disabled. +""" +function check_addible end + function apply! end function dagger end diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 0650290..7e63a53 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -1,35 +1,40 @@ """ -Abstract base class for all specialized bases. +Abstract type for all specialized bases of a Hilbert space. -The Basis class is meant to specify a basis of the Hilbert space of the -studied system. Besides basis specific information all subclasses must -implement a shape variable which indicates the dimension of the used -Hilbert space. For a spin-1/2 Hilbert space this would be the -vector `[2]`. A system composed of two spins would then have a -shape vector `[2 2]`. +The `Basis` type specifies an orthonormal basis for the Hilbert +space of the studied system. All subtypes must implement `Base.:(==)`, +and `Base.size`. `size` should return a tuple representing the total dimension +of the Hilbert space with any tensor product structure the basis has such that +`length(b::Basis) = prod(size(b))` gives the total Hilbert dimension -Composite systems can be defined with help of the [`CompositeBasis`](@ref) -class. +Composite systems can be defined with help of [`CompositeBasis`](@ref). + +All relevant properties of subtypes of `Basis` defined in `QuantumInterface` +should be accessed using their documented functions and should not +assume anything about the internal representation of instances of these +types (i.e. don't access the struct's fields directly). """ abstract type Basis end """ -Abstract base class for `Bra` and `Ket` states. +Abstract type for `Bra` and `Ket` states. -The state vector class stores the coefficients of an abstract state -in respect to a certain basis. These coefficients are stored in the -`data` field and the basis is defined in the `basis` -field. +The state vector class stores an abstract state with respect +to a certain basis. All subtypes must implement the `basis` +method which should this basis as a subtype of `Basis`. """ abstract type StateVector{B,T} end abstract type AbstractKet{B,T} <: StateVector{B,T} end abstract type AbstractBra{B,T} <: StateVector{B,T} end """ -Abstract base class for all operators. +Abstract type for all operators and super operators. -All deriving operator classes have to define the fields -`basis_l` and `basis_r` defining the left and right side bases. +All subtypes must implement the methods `basis_l` and +`basis_r` which return subtypes of `Basis` and +represent the left and right bases that the operator +maps between and thus is compatible with a `Bra` defined +in the left basis and a `Ket` defined in the right basis. For fast time evolution also at least the function `mul!(result::Ket,op::AbstractOperator,x::Ket,alpha,beta)` should be @@ -53,3 +58,5 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} ``` """ abstract type AbstractSuperOperator{B1,B2} end + +const AbstractQObjType = Union{<:StateVector,<:AbstractOperator} diff --git a/src/bases.jl b/src/bases.jl index e178a66..9f583e6 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -7,7 +7,7 @@ Total dimension of the Hilbert space. """ -Base.length(b::Basis) = prod(b.shape) +Base.length(b::Basis) = prod(b.shape) # change to prod(size(b)) when downstream Bases are updated """ GenericBasis(N) @@ -24,7 +24,7 @@ end GenericBasis(N::Integer) = GenericBasis([N]) Base.:(==)(b1::GenericBasis, b2::GenericBasis) = equal_shape(b1.shape, b2.shape) - +Base.size(b::GenericBasis) = b.shape """ CompositeBasis(b1, b2...) @@ -42,8 +42,11 @@ end CompositeBasis(bases) = CompositeBasis([length(b) for b ∈ bases], bases) CompositeBasis(bases::Basis...) = CompositeBasis((bases...,)) CompositeBasis(bases::Vector) = CompositeBasis((bases...,)) +#bases(b::CompositeBasis) = b.bases Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape) +Base.size(b::CompositeBasis) = length.(b.bases) +Base.getindex(b::CompositeBasis, i) = getindex(b.bases, i) ## # Common bases @@ -69,6 +72,9 @@ struct FockBasis{T} <: Basis end Base.:(==)(b1::FockBasis, b2::FockBasis) = (b1.N==b2.N && b1.offset==b2.offset) +Base.size(b::FockBasis) = (b.N - b.offset + 1,) +cutoff(b::FockBasis) = b.N +offset(b::FockBasis) = b.offset """ @@ -88,6 +94,7 @@ struct NLevelBasis{T} <: Basis end Base.:(==)(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N +Base.size(b::NLevelBasis) = (b.N,) """ NQubitBasis(num_qubits::Int) @@ -106,6 +113,7 @@ struct NQubitBasis{S,B} <: Basis end Base.:(==)(pb1::NQubitBasis, pb2::NQubitBasis) = length(pb1.bases) == length(pb2.bases) +Base.size(b::NQubitBasis) = b.shape """ SpinBasis(n) @@ -132,7 +140,8 @@ SpinBasis(spinnumber::Rational) = SpinBasis{spinnumber}(spinnumber) SpinBasis(spinnumber) = SpinBasis(convert(Rational{Int}, spinnumber)) Base.:(==)(b1::SpinBasis, b2::SpinBasis) = b1.spinnumber==b2.spinnumber - +Base.size(b::SpinBasis) = (numerator(b.spinnumber*2 + 1),) +spinnumber(b::SpinBasis) = b.spinnumber """ SumBasis(b1, b2...) @@ -151,3 +160,4 @@ SumBasis(bases::Basis...) = SumBasis((bases...,)) Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) Base.:(==)(b1::SumBasis, b2::SumBasis) = false Base.length(b::SumBasis) = sum(b.shape) +# TODO how should `.bases` be accessed? `getindex` or a `sumbases` method? diff --git a/src/deprecated.jl b/src/deprecated.jl index 680b8f3..242844c 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,3 +1,6 @@ + +basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 + function equal_bases(a, b) Base.depwarn("`==` should be preferred over `equal_bases`!", :equal_bases) if a===b @@ -43,12 +46,6 @@ function check_samebases(b1, b2) end end -function check_multiplicable(b1, b2) - if BASES_CHECK[] && !multiplicable(b1, b2) - throw(IncompatibleBases()) - end -end - samebases(b1::Basis, b2::Basis) = b1==b2 samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12 @@ -68,5 +65,3 @@ function multiplicable(b1::CompositeBasis, b2::CompositeBasis) end return true end - -multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12 diff --git a/src/expect_variance.jl b/src/expect_variance.jl index 86f5e64..678fb5f 100644 --- a/src/expect_variance.jl +++ b/src/expect_variance.jl @@ -3,33 +3,35 @@ If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number. """ -function expect(indices, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} - N = length(state.basis_l.shape) - indices_ = complement(N, indices) - expect(op, ptrace(state, indices_)) -end +expect(indices, op::AbstractOperator, state::AbstractOperator) = + expect(op, ptrace(state, complement(nsubsystems(state), indices))) + +expect(index::Integer, op::AbstractOperator, state::AbstractOperator) = expect([index], op, state) -expect(index::Integer, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = expect([index], op, state) expect(op::AbstractOperator, states::Vector) = [expect(op, state) for state=states] + expect(indices, op::AbstractOperator, states::Vector) = [expect(indices, op, state) for state=states] -expect(op::AbstractOperator{B1,B2}, state::AbstractOperator{B2,B2}) where {B1,B2} = tr(op*state) +expect(op::AbstractOperator, state::AbstractOperator) = + (check_multiplicable(state, state); check_multiplicable(op,state); tr(op*state)) """ variance(index, op, state) If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number """ -function variance(indices, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} - N = length(state.basis_l.shape) - indices_ = complement(N, indices) - variance(op, ptrace(state, indices_)) -end +variance(indices, op::AbstractOperator, state::AbstractOperator) = + variance(op, ptrace(state, complement(nsubsystems(state), indices))) + +variance(index::Integer, op::AbstractOperator, state::AbstractOperator) = variance([index], op, state) -variance(index::Integer, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} = variance([index], op, state) variance(op::AbstractOperator, states::Vector) = [variance(op, state) for state=states] + variance(indices, op::AbstractOperator, states::Vector) = [variance(indices, op, state) for state=states] -function variance(op::AbstractOperator{B,B}, state::AbstractOperator{B,B}) where B - expect(op*op, state) - expect(op, state)^2 +function variance(op::AbstractOperator, state::AbstractOperator) + check_multiplicable(op,op) + check_multiplicable(state,state) + check_multiplicable(op,state) + @compatiblebases expect(op*op, state) - expect(op, state)^2 end diff --git a/src/julia_base.jl b/src/julia_base.jl index 2d8e085..2eb11fa 100644 --- a/src/julia_base.jl +++ b/src/julia_base.jl @@ -14,7 +14,6 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op *(a::StateVector, b::Number) = b*a copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12 length(a::StateVector) = length(a.basis)::Int # FIXME issue #12 -basis(a::StateVector) = a.basis # FIXME issue #12 adjoint(a::StateVector) = dagger(a) @@ -33,8 +32,6 @@ Base.broadcastable(x::StateVector) = x ## length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12 -basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12 -basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 # Ensure scalar broadcasting Base.broadcastable(x::AbstractOperator) = Ref(x) diff --git a/src/linalg.jl b/src/linalg.jl index b48f458..0fd8a26 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -4,6 +4,50 @@ const BASES_CHECK = Ref(true) +""" + @compatiblebases + +Macro to skip checks for compatible bases. Useful for `*`, `expect` and similar +functions. +""" +macro compatiblebases(ex) + return quote + BASES_CHECK.x = false + local val = $(esc(ex)) + BASES_CHECK.x = true + val + end +end + +function check_addible(b1, b2) + if BASES_CHECK[] && !addible(b1, b2) + throw(IncompatibleBases()) + end +end + +function check_multiplicable(b1, b2) + if BASES_CHECK[] && !multiplicable(b1, b2) + throw(IncompatibleBases()) + end +end + +addible(a::AbstractQObjType, b::AbstractQObjType) = false +addible(a::AbstractBra, b::AbstractBra) = (basis(a) == basis(b)) +addible(a::AbstractKet, b::AbstractKet) = (basis(a) == basis(b)) +addible(a::AbstractOperator, b::AbstractOperator) = + (basis_l(a) == basis_l(b)) && (basis_r(a) == basis_r(b)) + +multiplicable(a::AbstractQObjType, b::AbstractQObjType) = false +multiplicable(a::AbstractBra, b::AbstractKet) = (basis(a) == basis(b)) +multiplicable(a::AbstractOperator, b::AbstractKet) = (basis_r(a) == basis(b)) +multiplicable(a::AbstractBra, b::AbstractOperator) = (basis(a) == basis_l(b)) +multiplicable(a::AbstractOperator, b::AbstractOperator) = (basis_r(a) == basis_l(b)) + +basis(a::StateVector) = throw(ArgumentError("basis() is not defined for this type of state vector: $(typeof(a)).")) +basis_l(a::AbstractOperator) = throw(ArgumentError("basis_l() is not defined for this type of operator: $(typeof(a)).")) +basis_r(a::AbstractOperator) = throw(ArgumentError("basis_r() is not defined for this type of operator: $(typeof(a)).")) +basis(a::AbstractOperator) = (basis_l(a) == basis_r(a); basis_l(a)) + ## # tensor, reduce, ptrace ## From 1514e3128c135bf55f259af93d04cde991452eb1 Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Mon, 25 Nov 2024 14:46:19 -0700 Subject: [PATCH 5/7] Eliminate type parameters from abstract types --- src/abstract_types.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 7e63a53..9feac57 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -23,9 +23,9 @@ The state vector class stores an abstract state with respect to a certain basis. All subtypes must implement the `basis` method which should this basis as a subtype of `Basis`. """ -abstract type StateVector{B,T} end -abstract type AbstractKet{B,T} <: StateVector{B,T} end -abstract type AbstractBra{B,T} <: StateVector{B,T} end +abstract type StateVector end +abstract type AbstractKet <: StateVector end +abstract type AbstractBra <: StateVector end """ Abstract type for all operators and super operators. @@ -41,7 +41,7 @@ For fast time evolution also at least the function implemented. Many other generic multiplication functions can be defined in terms of this function and are provided automatically. """ -abstract type AbstractOperator{BL,BR} end +abstract type AbstractOperator end """ Base class for all super operator classes. @@ -57,6 +57,6 @@ A_{bl_1,bl_2} = S_{(bl_1,bl_2) ↔ (br_1,br_2)} B_{br_1,br_2} A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} ``` """ -abstract type AbstractSuperOperator{B1,B2} end +abstract type AbstractSuperOperator end const AbstractQObjType = Union{<:StateVector,<:AbstractOperator} From f91978bbebb375aaeb98ba7bdbf4141f4089ab35 Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Mon, 9 Dec 2024 11:18:15 -0700 Subject: [PATCH 6/7] Add bases for superoperators and move AbstractSuperOperatorType to deprecated.jl --- src/abstract_types.jl | 16 --------------- src/bases.jl | 48 +++++++++++++++++++++++++++++++++++++++++++ src/deprecated.jl | 2 ++ 3 files changed, 50 insertions(+), 16 deletions(-) diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 9feac57..b318491 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -43,20 +43,4 @@ terms of this function and are provided automatically. """ abstract type AbstractOperator end -""" -Base class for all super operator classes. - -Super operators are bijective mappings from operators given in one specific -basis to operators, possibly given in respect to another, different basis. -To embed super operators in an algebraic framework they are defined with a -left hand basis `basis_l` and a right hand basis `basis_r` where each of -them again consists of a left and right hand basis. -```math -A_{bl_1,bl_2} = S_{(bl_1,bl_2) ↔ (br_1,br_2)} B_{br_1,br_2} -\\\\ -A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} -``` -""" -abstract type AbstractSuperOperator end - const AbstractQObjType = Union{<:StateVector,<:AbstractOperator} diff --git a/src/bases.jl b/src/bases.jl index 9f583e6..8249a58 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -161,3 +161,51 @@ Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) Base.:(==)(b1::SumBasis, b2::SumBasis) = false Base.length(b::SumBasis) = sum(b.shape) # TODO how should `.bases` be accessed? `getindex` or a `sumbases` method? + +## +# Operator Bases +## + +""" + KetBraBasis(BL,BR) + +Typical "Ket-Bra" outter-product Basis. +TODO: write more... +""" +struct KetBraBasis <: Basis + left::Basis + right::Basis +end +KetBraBasis(b::Basis) = KetBraBasis(b,b) +basis_l(b::KetBraBasis) = b.left +basis_r(b::KetBraBasis) = b.right +Base.:(==)(b1::KetBraBasis, b2::KetBraBasis) = (b1.left == b2.left && b1.right == b2.right) +Base.length(b::KetBraBasis) = length(b.left)*length(b.right) +Base.size(b::KetBraBasis) = (length(b.left), length(b.right)) + +struct ChoiRefSysBasis <: Basis + basis::Basis +end +Base.:(==)(b1::ChoiRefSysBasis, b2::ChoiRefSysBasis) = (b1.basis == b2.basis) +Base.length(b::ChoiRefSysBasis) = length(b.basis) +Base.size(b::ChoiRefSysBasis) = (length(b.basis),) + +struct ChoiOutSysBasis <: Basis + basis::Basis +end +Base.:(==)(b1::ChoiOutSysBasis, b2::ChoiOutSysBasis) = (b1.basis == b2.basis) +Base.length(b::ChoiOutSysBasis) = length(b.basis) +Base.size(b::ChoiOutSysBasis) = (length(b.basis),) + + +""" + _PauliBasis() + +Pauli operator basis consisting of I, Z, X, Y, in that order. +""" +struct _PauliBasis <: Basis +end + +Base.:(==)(pb1::_PauliBasis, pb2::_PauliBasis) = true +Base.length(b::_PauliBasis) = 4 +Base.size(b::_PauliBasis) = (4,) diff --git a/src/deprecated.jl b/src/deprecated.jl index 242844c..5da87c8 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,3 +1,5 @@ +# TODO: figure out how to deprecate abstract type +abstract type AbstractSuperOperator end basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 From 278d10a1b861a7ce704025dfcaefb78d4070aa8d Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Thu, 5 Dec 2024 14:51:15 -0700 Subject: [PATCH 7/7] Add to changelog and bump version to 0.4.0 --- CHANGELOG.md | 11 +++++++++++ Project.toml | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f086a09..0c5965e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # News +## v0.4.0 - 2024-12-10 + +This version implements the RFC in #40 without deprecating anything. Future versions will first add deprecation warnings before removing any existing interfaces. + +- Eliminate all type parameters from `StateVector`, `AbstractKet`, `AbstractBra`, `AbstractOperator`, and `AbstractSuperOperator`. +- Implement new basis checking interface consisting of `multiplicable`, `addible`, `check_multiplicable`, `check_addible`, and `@compatiblebases`. +- Add `basis_l` and `basis_r` to get left and right bases of operators. +- Implement `size` for all bases and implement `getindex` for `CompositeBasis`. +- Add method interface to existing subtypes of `Bases`: `spinnumber`, `cutoff`, `offset`. +- Add `KetBraBasis`, `ChoiRefSysBasis`, `ChoiOutSysBasis`, and `_PauliBasis`. Note that eventually `_PauliBasis` will be renamed to `PauliBasis`. + ## v0.3.7 - 2024-12-05 - Rename `PauliBasis` to `NQubitBasis` with warning, and add deprecation to `equal_bases`. diff --git a/Project.toml b/Project.toml index a02a99b..9f743ed 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QuantumInterface" uuid = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5" authors = ["QuantumInterface.jl contributors"] -version = "0.3.7" +version = "0.4.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"