Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add spin_pols_iter to iterate over a process' spin/pol combinations #118

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
4 changes: 3 additions & 1 deletion src/QEDbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ export AbstractDefinitePolarization, AbstractIndefinitePolarization
export PolarizationX, PolX, PolarizationY, PolY, AllPolarization, AllPol
export AbstractDefiniteSpin, AbstractIndefiniteSpin
export SpinUp, SpinDown, AllSpin
export SyncedSpin, SyncedPolarization
export SyncedSpin, SyncedPolarization, SyncedPol
export spin_pols_iter

# probabilities
export differential_probability, unsafe_differential_probability
Expand Down Expand Up @@ -114,6 +115,7 @@ include("interfaces/phase_space_point.jl")
include("implementations/process/momenta.jl")
include("implementations/process/particles.jl")
include("implementations/process/spin_pols.jl")
include("implementations/process/spin_pol_iterator.jl")

include("implementations/cross_section/diff_probability.jl")
include("implementations/cross_section/diff_cross_section.jl")
Expand Down
132 changes: 132 additions & 0 deletions src/implementations/process/spin_pol_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
SpinPolIter

An iterator type to iterate over spin and polarization combinations. Should be used through [`spin_pols_iter`](@ref).
"""
struct SpinPolIter{I,O}
# product iterator doing the actual iterating
product_iter::Base.Iterators.ProductIterator
# lookup table for which indices go where, translating the base iterator to the actual result
indexing_lut::Tuple{NTuple{I,Int},NTuple{O,Int}}
end

"""
all_spin_pols(process::AbstractProcessDefinition)

This function returns an iterator, yielding every fully definite combination of spins and polarizations allowed by the
process' [`spin_pols`](@ref). Each returned element is a Tuple of the incoming and the outgoing spins and polarizations,
in the order of the process' own spins and polarizations.

This works together with the definite spins and polarizations, [`AllSpin`](@ref), [`AllPolarization`](@ref), and the synced versions
[`SyncedPolarization`](@ref) and [`SyncedSpin`](@ref).

```jldoctest
julia> using QEDbase; using QEDcore; using QEDprocesses;

julia> proc = ScatteringProcess((Photon(), Photon(), Photon(), Electron()), (Photon(), Electron()), (SyncedPolarization(1), SyncedPolarization(2), SyncedPolarization(1), SpinUp()), (SyncedPolarization(2), AllSpin()))
generic QED process
incoming: photon (synced polarization 1), photon (synced polarization 2), photon (synced polarization 1), electron (spin up)
outgoing: photon (synced polarization 2), electron (all spins)


julia> for sp_combo in spin_pols_iter(proc) println(sp_combo) end
((x-polarized, x-polarized, x-polarized, spin up), (x-polarized, spin up))
((y-polarized, x-polarized, y-polarized, spin up), (x-polarized, spin up))
((x-polarized, y-polarized, x-polarized, spin up), (y-polarized, spin up))
((y-polarized, y-polarized, y-polarized, spin up), (y-polarized, spin up))
((x-polarized, x-polarized, x-polarized, spin up), (x-polarized, spin down))
((y-polarized, x-polarized, y-polarized, spin up), (x-polarized, spin down))
((x-polarized, y-polarized, x-polarized, spin up), (y-polarized, spin down))
((y-polarized, y-polarized, y-polarized, spin up), (y-polarized, spin down))

julia> length(spin_pols_iter(proc))
8
```
"""
function spin_pols_iter(process::AbstractProcessDefinition)
DEF_SPINS = (SpinUp(), SpinDown())
DEF_POLS = (PolX(), PolY())

in_sp = incoming_spin_pols(process)
I = length(in_sp)
out_sp = outgoing_spin_pols(process)
O = length(out_sp)

# concatenate for now for easier indices, split again later
sps = (in_sp..., out_sp...)

# keep indices of first seen SyncedSpins or SyncedPols
synced_seen = Dict{AbstractSpinOrPolarization,Int}()
index = 0
for sp in sps
index += 1
if !(sp isa SyncedSpin || sp isa SyncedPolarization)
continue
end
if !haskey(synced_seen, sp)
synced_seen[sp] = index
end
end

# keep indices of the synced spins/pols in the iterator (not necessarily the same as synced_seen)
synced_indices = Dict{AbstractSpinOrPolarization,Int}()

iter_tuples = Vector()
lut = Vector{Int}()
index = 0
for sp in sps
index += 1
if sp isa AbstractDefiniteSpin || sp isa AbstractDefinitePolarization
push!(iter_tuples, (sp,))
push!(lut, length(iter_tuples))
elseif sp isa SyncedSpin
# check if it's the first synced
if index == synced_seen[sp]
push!(iter_tuples, DEF_SPINS)
synced_indices[sp] = length(iter_tuples)
end
push!(lut, synced_indices[sp])
elseif sp isa SyncedPolarization
if index == synced_seen[sp]
push!(iter_tuples, DEF_POLS)
synced_indices[sp] = length(iter_tuples)
end
push!(lut, synced_indices[sp])
elseif sp isa AllSpin
push!(iter_tuples, DEF_SPINS)
push!(lut, length(iter_tuples))
elseif sp isa AllPol
push!(iter_tuples, DEF_POLS)
push!(lut, length(iter_tuples))
end
end

return SpinPolIter(
Iterators.product(iter_tuples...),
(tuple(lut[begin:I]...), tuple(lut[(I + 1):end]...)),
)
end

function Base.iterate(iterator::SpinPolIter, state=nothing)
local prod_iter_res
if isnothing(state)
prod_iter_res = iterate(iterator.product_iter)
else
prod_iter_res = iterate(iterator.product_iter, state)
end

if isnothing(prod_iter_res)
return nothing
end
prod_iter_res, state = prod_iter_res

# translate prod_iter_res into actual result
in_t = ((prod_iter_res[i] for i in iterator.indexing_lut[1])...,)
out_t = ((prod_iter_res[i] for i in iterator.indexing_lut[2])...,)

return (in_t, out_t), state
end

function Base.length(iterator::SpinPolIter)
return length(iterator.product_iter)
end
3 changes: 3 additions & 0 deletions src/interfaces/particles/spin_pol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ struct SyncedPolarization{N} <: AbstractIndefinitePolarization
return new{N}()
end
end
const SyncedPol = SyncedPolarization
Base.show(io::IO, ::SyncedPolarization{N}) where {N} = print(io, "synced polarization $N")

"""
SyncedSpin{N::Int} <: AbstractIndefiniteSpin
Expand All @@ -213,3 +215,4 @@ struct SyncedSpin{N} <: AbstractIndefiniteSpin
return new{N}()
end
end
Base.show(io::IO, ::SyncedSpin{N}) where {N} = print(io, "synced spin $N")
41 changes: 41 additions & 0 deletions test/particle_properties.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
using QEDbase
using QEDcore
using StaticArrays
using Random

include("test_implementation/TestImplementation.jl")
using .TestImplementation: TestParticleBoson, TestParticleFermion

# test function to test scalar broadcasting
test_broadcast(x::AbstractParticle) = x
test_broadcast(x::ParticleDirection) = x
Expand Down Expand Up @@ -29,3 +33,40 @@ test_broadcast(x::AbstractSpinOrPolarization) = x
end
end
end

TESTPROCS = (
TestImplementation.TestProcessSP(
(TestParticleBoson(), TestParticleFermion()),
(TestParticleBoson(), TestParticleFermion()),
(AllPol(), AllSpin()),
(AllPol(), AllSpin()),
),
TestImplementation.TestProcessSP(
(TestParticleBoson(), TestParticleBoson(), TestParticleFermion()),
(TestParticleBoson(), TestParticleFermion()),
(SyncedPol(1), SyncedPol(1), AllSpin()),
(AllPol(), AllSpin()),
),
TestImplementation.TestProcessSP(
(TestParticleBoson(), TestParticleBoson(), TestParticleFermion()),
(TestParticleBoson(), TestParticleFermion()),
(SyncedPol(1), SyncedPol(2), SyncedSpin(2)),
(SyncedPol(2), SyncedSpin(2)),
),
)

@testset "spin_pol iterator ($proc)" for proc in TESTPROCS
@test length(spin_pols_iter(proc)) == multiplicity(proc)

for combinations in spin_pols_iter(proc)
@test length(combinations) == 2
in_comb, out_comb = combinations

@test length(in_comb) == length(incoming_particles(proc))
@test length(out_comb) == length(outgoing_particles(proc))

for sp in Iterators.flatten((in_comb, out_comb))
@test sp isa AbstractDefiniteSpin || sp isa AbstractDefinitePolarization
end
end
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using QEDbase
using Test
using SafeTestsets

Expand Down
Loading