Simplify stabilized Sinkhorn algorithm #99
Pull Request Test Coverage Report for Build 911412545
💛 - Coveralls
A comment would be that the
Sure, this could be done but it will decrease the performance of
julia> using OptimalTransport
julia> using Distances
julia> using LinearAlgebra
julia> using BenchmarkTools
julia> M = 250
250
julia> N = 200
200
julia> μ = normalize!(rand(M), 1);
julia> ν = normalize!(rand(N), 1);
julia> C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2);
julia> ε = 0.01
0.01
julia> @benchmark sinkhorn($μ, $ν, $C, $ε)
BenchmarkTools.Trial:
memory estimate: 789.73 KiB
allocs estimate: 12
--------------
minimum time: 7.257 ms (0.00% GC)
median time: 9.361 ms (0.00% GC)
mean time: 9.343 ms (0.25% GC)
maximum time: 46.595 ms (0.00% GC)
--------------
samples: 535
evals/sample: 1
julia> @benchmark sinkhorn_stabilized($μ, $ν, $C, $ε; absorb_tol=$Inf)
BenchmarkTools.Trial:
memory estimate: 402.73 KiB
allocs estimate: 8
--------------
minimum time: 10.223 ms (0.00% GC)
median time: 12.510 ms (0.00% GC)
mean time: 12.614 ms (0.11% GC)
maximum time: 16.507 ms (0.00% GC)
--------------
samples: 396
evals/sample: 1
The allocations are lower for
An additional advantage of a separate
ν = normalize!(rand(N), 1)

# create random cost matrix
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
Just curious: is there any reason for using C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) instead of C = pairwise(SqEuclidean(), rand(M), rand(N); dims=1)?
The latter does not work since the pairwise implementation in Distances requires a matrix, not a vector. Of course, one could use rand(M, 1) etc. with dims=1; I'm just more used to dims=2 since it was the default (and still is, even though default values are deprecated).
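The two matrix layouts discussed in this thread can be compared directly. The following is a minimal sketch, assuming Distances.jl is installed; the sizes M = 3 and N = 2 and the variable names are chosen only for illustration. With dims=2 every column is one point, with dims=1 every row is one point, so both calls should produce the same M×N cost matrix.

```julia
using Distances

M, N = 3, 2
x = rand(M)   # M one-dimensional support points
y = rand(N)   # N one-dimensional support points

# dims=2: each column is one point, so the points are stored as 1×M and 1×N matrices
C_cols = pairwise(SqEuclidean(), reshape(x, 1, :), reshape(y, 1, :); dims=2)

# dims=1: each row is one point, so the points are stored as M×1 and N×1 matrices
C_rows = pairwise(SqEuclidean(), reshape(x, :, 1), reshape(y, :, 1); dims=1)

# both layouts describe the same points and yield an M×N matrix
@assert size(C_cols) == (M, N)
@assert C_cols ≈ C_rows
```

Passing dims explicitly in both calls avoids relying on the deprecated default mentioned above.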
Pull Request Test Coverage Report for Build 910472106
Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
💛 - Coveralls
This PR simplifies the stabilized Sinkhorn algorithm and extends its tests. It minimizes the number of allocations and fixes a type instability by using a cache that is mutated instead of the alpha, beta, and return_duals keyword arguments.

There are two possible follow-up PRs here:
sinkhorn_gibbs
sinkhorn_gibbs
With the caches, it is also quite natural to implement the init/solve! interface that would allow one to define for both alg = SinkhornStabilized() and alg = SinkhornGibbs(). One could maybe also wrap the other input arguments in a problem type, as suggested in #63, and include the problem and the options in the solver, such that one can just call solve!(solver) without additional arguments.
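To make the idea concrete, here is a rough sketch of what a cache-based init/solve! pair could look like for plain (unstabilized) Sinkhorn iterations. Everything in it is an assumption made for illustration: GibbsSketch, SketchSolver, plan, and the local init and solve! functions are hypothetical names, not the package's API and not the code in this PR. The solver stores the problem, the options, and the preallocated buffers, so solve!(solver) needs no further arguments and the loop itself allocates nothing.

```julia
using LinearAlgebra

# Hypothetical algorithm tag, standing in for something like SinkhornGibbs()
struct GibbsSketch end

# Hypothetical solver: bundles the problem, the options, and the mutable cache
struct SketchSolver{T}
    μ::Vector{T}
    ν::Vector{T}
    K::Matrix{T}     # Gibbs kernel exp.(-C ./ ε)
    u::Vector{T}     # scaling vectors, updated in place
    v::Vector{T}
    Kv::Vector{T}    # buffer for K * v
    Ktu::Vector{T}   # buffer for K' * u
    maxiter::Int
end

function init(μ::Vector{T}, ν::Vector{T}, C::Matrix{T}, ε::Real, ::GibbsSketch;
              maxiter::Int=1_000) where {T<:AbstractFloat}
    K = exp.(-C ./ T(ε))
    return SketchSolver(μ, ν, K, ones(T, length(μ)), ones(T, length(ν)),
                        similar(μ), similar(ν), maxiter)
end

function solve!(s::SketchSolver)
    for _ in 1:s.maxiter
        mul!(s.Kv, s.K, s.v)      # Kv = K * v
        s.u .= s.μ ./ s.Kv        # rescale to match the row marginals
        mul!(s.Ktu, s.K', s.u)    # Ktu = K' * u
        s.v .= s.ν ./ s.Ktu       # rescale to match the column marginals
    end
    return s
end

# transport plan diag(u) * K * diag(v)
plan(s::SketchSolver) = s.u .* s.K .* s.v'

# small usage example with uniform marginals
μ = fill(0.25, 4)
ν = fill(0.2, 5)
C = [abs2(i / 4 - j / 5) for i in 1:4, j in 1:5]

solver = init(μ, ν, C, 0.5, GibbsSketch())
solve!(solver)
γ = plan(solver)
```

A stabilized variant would presumably reuse the same solver layout and differ only in what the cache holds and how solve! updates it, which is what makes dispatching on the algorithm type attractive. The sketch runs a fixed number of iterations for brevity; a real implementation would also check convergence.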