Skip to content

Commit

Permalink
remove usage of eachcol
Browse files Browse the repository at this point in the history
  • Loading branch information
zsteve committed Sep 22, 2021
1 parent 73ff802 commit cee80a6
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 14 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ authors = ["zsteve <[email protected]>"]
version = "0.3.17"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -13,7 +12,6 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
Compat = "2.2, 3"
ExactOptimalTransport = "0.1"
IterativeSolvers = "0.8.4, 0.9"
LogExpFunctions = "0.2, 0.3"
Expand Down
1 change: 0 additions & 1 deletion src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using LinearAlgebra
using IterativeSolvers
using LogExpFunctions: LogExpFunctions
using NNlib: NNlib
using Compat

export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
export SinkhornBarycenterGibbs
Expand Down
5 changes: 4 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ end
dot_matwise(x::AbstractMatrix, y::AbstractArray) = dot_matwise(y, x)

function dot_vecwise(x::AbstractMatrix, y::AbstractMatrix)
Compat.@compat return [dot(u, v) for (u, v) in zip(eachcol(x), eachcol(y))]
return [
dot(u, v) for (u, v) in
zip((view(x, :, i) for i in axes(x, 2)), (view(y, :, i) for i in axes(y, 2)))
]
end

dot_vecwise(x::AbstractMatrix, y::AbstractVector) = x' * y
Expand Down
19 changes: 9 additions & 10 deletions test/entropic/sinkhorn_divergence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using PythonOT: PythonOT
using LinearAlgebra
using Random
using Test
using Compat

const POT = PythonOT

Expand Down Expand Up @@ -53,19 +52,19 @@ Random.seed!(100)
ν = hcat([normalize!(f.(x; μ=randn(), σ=0.5), 1) for _ in 1:M]...)
for reg in (true, false)
loss_batch = sinkhorn_divergence(μ, ν, C, ε; regularization=reg)
@compat @test loss_batch [
sinkhorn_divergence(x, y, C, ε; regularization=reg) for
(x, y) in zip(eachcol(μ), eachcol(ν))
@test loss_batch [
sinkhorn_divergence(μ[:, i], ν[:, i], C, ε; regularization=reg) for
i in 1:M
]
loss_batch_μ = sinkhorn_divergence(μ, ν[:, 1], C, ε; regularization=reg)
@compat @test loss_batch_μ [
sinkhorn_divergence(x, ν[:, 1], C, ε; regularization=reg) for
x in eachcol(μ)
@test loss_batch_μ [
sinkhorn_divergence(μ[:, i], ν[:, 1], C, ε; regularization=reg) for
i in 1:M
]
loss_batch_ν = sinkhorn_divergence(μ[:, 1], ν, C, ε; regularization=reg)
@compat @test loss_batch_ν [
sinkhorn_divergence(μ[:, 1], y, C, ε; regularization=reg) for
y in eachcol(ν)
@test loss_batch_ν [
sinkhorn_divergence(μ[:, 1], ν[:, i], C, ε; regularization=reg) for
i in 1:M
]
end
end
Expand Down

0 comments on commit cee80a6

Please sign in to comment.