Skip to content

Commit

Permalink
Simplify stabilized Sinkhorn algorithm (#99)
Browse files Browse the repository at this point in the history
* Move file

* Move test file

* Improve stabilized Sinkhorn algorithm

* Add test and remove redundant computations

* Bump version
  • Loading branch information
devmotion authored Jun 6, 2021
1 parent 24ecf5d commit 3674ceb
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 202 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimalTransport"
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
authors = ["zsteve <[email protected]>"]
version = "0.3.10"
version = "0.3.11"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand Down
3 changes: 2 additions & 1 deletion src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ include("distances/bures.jl")
include("utils.jl")
include("exact.jl")
include("wasserstein.jl")
include("entropic.jl")
include("entropic/sinkhorn.jl")
include("entropic/sinkhorn_stabilized.jl")
include("quadratic.jl")

end
130 changes: 0 additions & 130 deletions src/entropic.jl → src/entropic/sinkhorn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -461,136 +461,6 @@ function sinkhorn_unbalanced2(
return dot(γ, C)
end

"""
sinkhorn_stabilized_epsscaling(μ, ν, C, ε; lambda = 0.5, k = 5, kwargs...)
Compute the optimal transport plan for the entropically regularized optimal transport problem
with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic regularisation parameter `ε`. Employs the log-domain stabilized algorithm of Schmitzer et al. [^S19] with ε-scaling.
`k` ε-scaling steps are used with scaling factor `lambda`, i.e. sequentially solve Sinkhorn using `sinkhorn_stabilized` with regularisation parameters
``ε_i \\in [λ^{1-k}, \\ldots, λ^{-1}, 1] \\times ε``.
See also: [`sinkhorn_stabilized`](@ref), [`sinkhorn`](@ref)
"""
function sinkhorn_stabilized_epsscaling(μ, ν, C, ε; lambda=0.5, k=5, kwargs...)
α = zero(μ)
β = zero(ν)
for ε_i in* lambda^(1 - j) for j in k:-1:1)
@debug "Epsilon-scaling Sinkhorn algorithm: ε = $ε_i"
α, β = sinkhorn_stabilized(
μ, ν, C, ε_i; alpha=α, beta=β, return_duals=true, kwargs...
)
end
gamma = similar(C)
getK!(gamma, C, α, β, ε, μ, ν)
return gamma
end

function getK!(K, C, α, β, ε, μ, ν)
@. K = exp(-(C - α - β') / ε) * μ * ν'
return K
end

"""
sinkhorn_stabilized(μ, ν, C, ε; absorb_tol = 1e3, alpha_0 = zero(μ), beta = zero(ν), maxiter = 1_000, atol = tol, rtol=nothing, return_duals = false)
Compute the optimal transport plan for the entropically regularized optimal transport problem
with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic regularisation parameter `ε`. Employs the log-domain stabilized algorithm of Schmitzer et al. [^S19]
`alpha` and `beta` are initial scalings for the stabilized Gibbs kernel. If not specified, `alpha` and `beta` are initialised to zero.
If `return_duals = true`, then the optimal dual variables `(u, v)` corresponding to `(μ, ν)` are returned. Otherwise, the coupling `γ` is returned.
[^S19]: Schmitzer, B., 2019. Stabilized sparse scaling algorithms for entropy regularized transport problems. SIAM Journal on Scientific Computing, 41(3), pp.A1443-A1481.
See also: [`sinkhorn`](@ref)
"""
function sinkhorn_stabilized(
μ,
ν,
C,
ε;
absorb_tol=1e3,
maxiter=1_000,
tol=nothing,
atol=tol,
rtol=nothing,
check_convergence=10,
alpha=zero(μ),
beta=zero(ν),
return_duals=false,
)
if tol !== nothing
Base.depwarn(
"keyword argument `tol` is deprecated, please use `atol` and `rtol`",
:sinkhorn_stabilized,
)
end
sum(μ) sum(ν) ||
throw(ArgumentError("source and target marginals must have the same mass"))

T = float(Base.promote_eltype(μ, ν, C))
_atol = atol === nothing ? 0 : atol
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol

norm_μ = sum(abs, μ)
isconverged = false

K = similar(C)
gamma = similar(C)

getK!(K, C, alpha, beta, ε, μ, ν)
u = μ ./ sum(K; dims=2)
v = ν ./ (K' * u)
tmp_u = similar(u)
for iter in 0:maxiter
if (max(norm(u, Inf), norm(v, Inf)) > absorb_tol)
@debug "Absorbing (u, v) into (alpha, beta)"
# absorb into α, β
alpha += ε * log.(u)
beta += ε * log.(v)
u .= 1
v .= 1
getK!(K, C, alpha, beta, ε, μ, ν)
end
if iter % check_convergence == 0
# check marginal
getK!(gamma, C, alpha, beta, ε, μ, ν)
@. gamma *= u * v'
norm_diff = sum(abs, gamma * ones(size(ν)) - μ)
norm_uKv = sum(abs, gamma)
@debug "Stabilized Sinkhorn algorithm (" *
string(iter) *
"/" *
string(maxiter) *
": error of source marginal = " *
string(norm_diff)

if norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
@debug "Stabilized Sinkhorn algorithm ($iter/$maxiter): converged"
isconverged = true
break
end
end
mul!(tmp_u, K, v)
u = μ ./ tmp_u
mul!(v, K', u)
v = ν ./ v
end

if !isconverged
@warn "Stabilized Sinkhorn algorithm ($maxiter/$maxiter): not converged"
end

alpha = alpha + ε * log.(u)
beta = beta + ε * log.(v)
if return_duals
return alpha, beta
end
getK!(gamma, C, alpha, beta, ε, μ, ν)
return gamma
end

"""
sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max_iter=1000)
Expand Down
Loading

2 comments on commit 3674ceb

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/38277

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.11 -m "<description of version>" 3674ceb465938c7553d7e77188ff1d5597d30c2b
git push origin v0.3.11

Please sign in to comment.