Skip to content

Commit

Permalink
Fix argmin/argmax for non-vectors
Browse files Browse the repository at this point in the history
Make sure `argmin` and `argmax` yield the correct return type when
falling back to `indmin` and `indmax`, at least for `AbstractArray` and
`Associative` arguments.
  • Loading branch information
martinholters committed Jan 25, 2018
1 parent edc9a63 commit 4a3d938
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1311,8 +1311,17 @@ if !isdefined(Base, :findall)
end

@static if !isdefined(Base, :argmin)
const argmin = indmin
const argmax = indmax
if VERSION >= v"0.7.0-DEV.1660" # indmin/indmax return key
const argmin = indmin
const argmax = indmax
else
argmin(x::AbstractArray) = CartesianIndex(ind2sub(x, indmin(x)))
argmin(x::AbstractVector) = indmin(x)
argmin(x::Associative) = first(Iterators.drop(keys(x), indmin(values(x))-1))
argmax(x::AbstractArray) = CartesianIndex(ind2sub(x, indmax(x)))
argmax(x::AbstractVector) = indmax(x)
argmax(x::Associative) = first(Iterators.drop(keys(x), indmax(values(x))-1))
end
export argmin, argmax
end

Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,11 @@ end

# 0.7.0-DEV.3516
@test argmax([10,12,9,11]) == 2
@test argmax([10 12; 9 11]) == CartesianIndex(1, 2)
@test argmax(Dict(:z=>10, :y=>12, :x=>9, :w=>11)) == :y
@test argmin([10,12,9,11]) == 3
@test argmin([10 12; 9 11]) == CartesianIndex(2, 1)
@test argmin(Dict(:z=>10, :y=>12, :x=>9, :w=>11)) == :x

# 0.7.0-DEV.3415
@test findall(x -> x==1, [1, 2, 3, 2, 1]) == [1, 5]
Expand Down

0 comments on commit 4a3d938

Please sign in to comment.