Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify stabilized Sinkhorn algorithm #99

Merged
merged 5 commits into from
Jun 6, 2021
Merged

Simplify stabilized Sinkhorn algorithm #99

merged 5 commits into from
Jun 6, 2021

Conversation

devmotion
Copy link
Member

@devmotion devmotion commented Jun 6, 2021

This PR simplifies the stabilized Sinkhorn algorithm and extend its tests. It minimizes the number of allocations and fixes a type instability by using a cache that is mutated instead of the alpha, beta, and return_duals keyword arguments.

There are two possible follow-up PRs here:

  • extend to batch computations in the same way as sinkhorn_gibbs
  • also use a more explicit cache in sinkhorn_gibbs

With the caches, it is also quite natural to implement the init/solve! interface that would allow to define

init(mu, nu, C, eps, alg; kwargs...)

for both alg = SinkhornStabilized() and alg = SinkhornGibbs(). One could maybe also wrap the other input arguments in a problem type as suggested in #63 and include the problem and the options in the solver such that one can just call solve!(solver) without additional arguments.

@devmotion devmotion marked this pull request as ready for review June 6, 2021 00:08
@coveralls
Copy link

coveralls commented Jun 6, 2021

Pull Request Test Coverage Report for Build 911412545

  • 100 of 106 (94.34%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.3%) to 94.572%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/entropic/sinkhorn_stabilized.jl 100 106 94.34%
Totals Coverage Status
Change from base Build 910296728: -0.3%
Covered Lines: 453
Relevant Lines: 479

💛 - Coveralls

@zsteve
Copy link
Member

zsteve commented Jun 6, 2021

A comment would be that the sinkhorn_stabilized algorithm in fact includes the case of sinkhorn_gibbs when we have alpha = beta = 0 and we set the threshold for absorbing u, v to Inf. In this case, getK would yield the standard Gibbs kernel and only be computed once, and at least off the top of my head there shouldn't be any additional computational cost. So I wonder if we should consider merging stabilization into sinkhorn_gibbs?

@devmotion
Copy link
Member Author

Sure, this could be done but it will decrease the performance of sinkhorn_gibbs since one has to evaluate maximum(abs, u) and maximum(abs, v) in every step and it requires an additional log and exp to compute the plan:

julia> using OptimalTransport

julia> using Distances

julia> using LinearAlgebra

julia> using BenchmarkTools

julia> M = 250
250

julia> N = 200
200

julia> μ = normalize!(rand(M), 1);

julia> ν = normalize!(rand(N), 1);

julia> C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2);

julia> ε = 0.01
0.01

julia> @benchmark sinkhorn($μ, $ν, $C, $ε)
BenchmarkTools.Trial: 
  memory estimate:  789.73 KiB
  allocs estimate:  12
  --------------
  minimum time:     7.257 ms (0.00% GC)
  median time:      9.361 ms (0.00% GC)
  mean time:        9.343 ms (0.25% GC)
  maximum time:     46.595 ms (0.00% GC)
  --------------
  samples:          535
  evals/sample:     1

julia> @benchmark sinkhorn_stabilized($μ, $ν, $C, $ε; absorb_tol=$Inf)
BenchmarkTools.Trial: 
  memory estimate:  402.73 KiB
  allocs estimate:  8
  --------------
  minimum time:     10.223 ms (0.00% GC)
  median time:      12.510 ms (0.00% GC)
  mean time:        12.614 ms (0.11% GC)
  maximum time:     16.507 ms (0.00% GC)
  --------------
  samples:          396
  evals/sample:     1

The allocations are lower for sinkhorn_stabilized since no new array is allocated for the plan but Ktilde is updated in-place.

An additional advantage of a separate sinkhorn_gibbs (or SinkhornGibbs() when one uses different algorithms) is that it is easy to read and debug and hence can be used as a reliable baseline for more complicated methods.

ν = normalize!(rand(N), 1)

# create random cost matrix
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious. Any reason why doing C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) instead of C = pairwise(SqEuclidean(), rand(M), rand(N); dims=1)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latter does not work since the pairwise implementation in Distances requires a matrix, not a vector. Of course, one could use rand(M, 1) etc., I'm just more used to dims=2 since it was the default (and still is even though default values are deprecated).

@devmotion devmotion merged commit 3674ceb into master Jun 6, 2021
@devmotion devmotion deleted the dw/stabilized branch June 6, 2021 12:42
@coveralls
Copy link

coveralls commented Jul 19, 2024

Pull Request Test Coverage Report for Build 910472106

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 100 of 106 (94.34%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.3%) to 94.572%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/entropic/sinkhorn_stabilized.jl 100 106 94.34%
Totals Coverage Status
Change from base Build 910296728: -0.3%
Covered Lines: 453
Relevant Lines: 479

💛 - Coveralls

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants