Refactor `sinkhorn` and `sinkhorn2` #100
Conversation
Pull Request Test Coverage Report for Build 928044999

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

💛 - Coveralls
Great PR! I still have to review the code you submitted, but in terms of design I have a slightly different proposal related to the "problem" and "solver" paradigm. I'd say that the "problem" would always be the original Optimal Transport problem (i.e.,
Interesting, I'll think about your comment. To me the entropically regularized OT problem is really a separate problem, not just one induced by the approximation algorithm. Of course it is derived from the exact OT problem, but it always seemed to me a separate entity, similar to how in other fields regularized problems are considered separately, e.g., the least squares problem with Tikhonov regularization, or Lasso and elastic net, which have specific mathematical properties. Intuitively, I assume that making the regularization part of the algorithm could lead to problems or inconsistencies when composing different algorithms, such as epsilon scaling with different Sinkhorn algorithms: then the regularization would only be a parameter of the subalgorithm but not of the main algorithm.
I think the approach you described is actually the more standard way. But the reason why I tend to think like this in OT is that the regularization is mostly added not to penalize errors somehow, but just to make the original problem solvable. The use of Sinkhorn instead of the unregularized problem is usually just for computational reasons (at least in ML, it's the only reason people choose it). If we look at linear regression, the addition of regularization is not there to "facilitate" solving the original regression, but to emphasize a different aspect. So I see a fundamental distinction in the motivation. But this is just a way of thinking about things, so if it makes coding problematic, I'd say we should just stick with your design. I don't see that much of a gain.
I am also yet to go through the code, but thinking about the above comments, I'd also lean towards what is implemented currently. I think in the future it might be useful to have a more general solver that can deal with an arbitrary regulariser, using some gradient method like L-BFGS, similar to the
Actually, none of the problem types or drastic interface changes are implemented in this PR; I just thought it would be worth considering. If there are problem types for entropically regularized OT, quadratically regularized OT, etc. (a single type that allows dispatching on the type of regularization would also be sufficient), then it would be immediately clear that
The nice thing is that one could define, for each algorithm, to which problem types it can be applied, and show an error message if the user applies an incorrect algorithm to the provided problem. Or one could disallow it altogether by using a type structure on the algorithms and problems (that seems a bit less flexible though). And similarly, for some algorithms one could define that they may be applied to OT problems with an arbitrary regulariser.
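As a hypothetical sketch of what such problem/algorithm dispatch could look like (all names here, `EntropicOTProblem`, `QuadraticOTProblem`, the toy `solve`, are illustrative and not part of this PR):

```julia
# Hypothetical sketch (not part of this PR): dispatch on problem and algorithm
# types so that an unsupported combination fails with a clear error message.
abstract type AbstractOTProblem end

struct EntropicOTProblem{M,TC,E} <: AbstractOTProblem
    μ::M
    ν::M
    C::TC
    ε::E
end

struct QuadraticOTProblem{M,TC,E} <: AbstractOTProblem
    μ::M
    ν::M
    C::TC
    ε::E
end

struct SinkhornGibbs end

# Sinkhorn-type algorithms only apply to entropically regularized problems...
solve(prob::EntropicOTProblem, alg::SinkhornGibbs) = "running Sinkhorn"

# ...and any other combination produces an informative error
solve(prob::AbstractOTProblem, alg) =
    error("$(nameof(typeof(alg))) cannot solve a $(nameof(typeof(prob)))")
```

With this layout the fallback method catches every unsupported pairing, so the error surfaces at the call site rather than deep inside an algorithm's iterations.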
What's the status here? Did you have a look at the changes? |
```julia
function build_solver(
    alg::Sinkhorn,
    μ::AbstractVecOrMat,
    ν::AbstractVecOrMat,
```
`mu` and `nu` here are actually not measures, but just the support, right? Maybe in a future PR we should call this `musupp` and `nusupp`.
It's not the support, it's the histograms (so also not fully specified measures...). The support is implicitly part of the cost matrix.
You are right! You see, that's why we should specify this. I always confuse them, haha.
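To illustrate the distinction with a made-up minimal example (hypothetical, not code from the package): the vectors passed to `build_solver` are histogram weights, while the support points only enter through the cost matrix:

```julia
# Minimal illustration (hypothetical, not package code): μ and ν are histogram
# weights, while the support points x and y only enter through the cost matrix.
x = [0.0, 1.0, 2.0]                       # support of the source measure
y = [0.5, 1.5]                            # support of the target measure
μ = [0.2, 0.5, 0.3]                       # histogram on x (sums to 1)
ν = [0.6, 0.4]                            # histogram on y (sums to 1)
C = [abs2(xi - yj) for xi in x, yj in y]  # 3×2 squared Euclidean cost matrix
# build_solver only ever sees (μ, ν, C, ε); x and y are gone at that point.
```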
```julia
end

    return u, v
    return nothing
```
`return nothing` is equivalent to not having a `return`? Any reason for using one instead of the other?
Yep, it's equivalent. BlueStyle wants it 🙂
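For illustration (a hedged sketch, not code from this PR): the two forms behave identically here because a `for` loop itself evaluates to `nothing`, so the implicit return value is already `nothing`; BlueStyle just asks for it to be spelled out.

```julia
# A `for` loop evaluates to `nothing`, so both functions return `nothing`;
# BlueStyle prefers the explicit variant.
function scale_explicit!(u, λ)
    for i in eachindex(u)
        u[i] *= λ
    end
    return nothing
end

function scale_implicit!(u, λ)
    for i in eachindex(u)
        u[i] *= λ
    end
end
```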
````julia
```
where ``\\lambda`` is the scaling factor and ``k`` the number of scaling steps.
"""
function SinkhornEpsilonScaling(alg::Sinkhorn; factor::Real=1//2, steps::Int=5)
````
Any advantage in using fractions here?
It does not promote, which is relevant e.g. if you want to keep `Float32` (otherwise values are promoted to `Float64`, which leads to terrible performance on GPUs).
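A quick sketch of the promotion behaviour in question (standard Julia promotion semantics, not package code):

```julia
# Multiplying a Float32 by a Rational (or Int) keeps Float32, while a Float64
# literal such as 0.5 forces promotion to Float64.
ε = Float32(0.01)
keeps_f32 = ε * (1//2)   # stays Float32
widens    = ε * 0.5      # promoted to Float64
```

This is why a default like `factor=1//2` is friendlier to `Float32` (and hence GPU) workflows than `factor=0.5`.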
src/entropic/sinkhorn_stabilized.jl (outdated)
````julia
```julia
isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
sinkhorn(
    SinkhornStabilized(; absorb_tol=absorb_tol), μ, ν, C, ε; kwargs...
````
Is the algorithm the first argument? Isn't it the last? I mean `sinkhorn(μ, ν, C, ε, SinkhornStabilized(; absorb_tol=absorb_tol))`.
Ah yeah, initially I had it as the first argument but then switched it around. Seems I forgot to update the documentation, good catch!
Fixed.
src/entropic/sinkhorn.jl (outdated)
```julia
K = @. exp(-C / ε)
function sinkhorn(μ, ν, C, ε, alg::Sinkhorn; kwargs...)
    # build solver
    solver = build_solver(alg, μ, ν, C, ε; kwargs...)
```
Wouldn't it be better to have the same order of arguments for `sinkhorn` and `build_solver`?
Probably; again, I assume it was because initially I used a different order in `sinkhorn`. I'll update it.
Fixed.
```julia
for alg in (SinkhornGibbs(), SinkhornStabilized())
    # compute optimal transport plan and cost
    γ = sinkhorn(μ, ν, C, ε, SinkhornEpsilonScaling(alg); maxiter=5_000)
    c = sinkhorn2(μ, ν, C, ε, SinkhornEpsilonScaling(alg); maxiter=5_000)
```
I think we should use `cost` and leave `c` to be the OT cost function, while `C` is the cost matrix.
After reading the rest of the code for the tests, I see that this would require a huge amount of refactoring in the tests, so I think it's better to just leave it be.
I think more consistent names are a good idea, but I assume it might be better to do this in a separate PR since it might also affect other unrelated parts.
I don't know much about Sinkhorn beyond the regular implementation. So I tried reading the whole code, but I was not very thorough, especially in the unbalanced case or the other modifications, such as epsilon scaling.
GPU tests pass now (new CUDA + NNlibCUDA versions fix some bugs in CUDA that caused timeouts when instantiating a new project environment, as done in our tests).
I apologize in advance for the amount of changes in this PR, but I didn't manage to keep it smaller without leaving the package in a half-broken state.

Basically, this PR unifies `sinkhorn`, `sinkhorn_stabilized`, and `sinkhorn_stabilized_epsscaling` with a single `sinkhorn(mu, nu, C, eps, algorithm; kwargs...)` syntax (the name `sinkhorn` seems a bit redundant, see comments below). More details:

- `sinkhorn(mu, nu, C, eps; kwargs...) = sinkhorn(mu, nu, C, eps, SinkhornGibbs(); kwargs...)`, i.e., without an algorithm the standard Sinkhorn algorithm is still used, as before
- `sinkhorn2` allows specifying an algorithm as well (default: `SinkhornGibbs()`), and hence supports all Sinkhorn variants (and also supports the additional regularization term for all algorithms)
- `SinkhornEpsilonScaling` works both with the regular Sinkhorn algorithm and the stabilized version
- `SinkhornStabilized` supports batches of histograms as well (requires `NNlib.batched_mul!` for performance and GPU support)

The name `sinkhorn(...)` seems a bit redundant since it is already implied by the names of the algorithms (apart from the default choice). However, I did not want to make too many changes in this PR. One approach would be to define a type for entropically regularized OT problems (this could also be used in the `SinkhornSolver` struct to group together source and target marginals, cost matrix, and regularization parameter) and then to use the `solve` approach (which is already used internally now, without depending on CommonSolve). Then one could, e.g., still use `SinkhornGibbs()` as the default algorithm for such problems but also write `solve(prob, SinkhornStabilized(); kwargs...)` etc. One could even keep `sinkhorn(mu, nu, C, eps, alg; kwargs...) = solve(EntropicOTProblem(mu, nu, C, eps), alg; kwargs...)`.
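The layering sketched in that last paragraph could look roughly like this (hypothetical names mirroring the description, in particular `EntropicOTProblem` and the stub `solve`; not code from this PR):

```julia
# Hypothetical sketch of the proposed layering: `sinkhorn` becomes a thin
# wrapper that builds a problem type and forwards to `solve`.
struct EntropicOTProblem{M,N,TC,E}
    μ::M
    ν::N
    C::TC
    ε::E
end

struct SinkhornGibbs end
struct SinkhornStabilized end

# each algorithm would implement `solve` for the problem types it supports;
# here a stub stands in for the actual iterations
solve(prob::EntropicOTProblem, alg; kwargs...) =
    "plan computed with $(nameof(typeof(alg)))"

# the current user-facing API reduces to a one-line forwarding definition,
# with SinkhornGibbs as the default algorithm
function sinkhorn(μ, ν, C, ε, alg=SinkhornGibbs(); kwargs...)
    return solve(EntropicOTProblem(μ, ν, C, ε), alg; kwargs...)
end
```

Under this design, keeping `sinkhorn` costs one forwarding method, while `solve(prob, alg)` becomes the extension point for new problem types and algorithms.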