diff --git a/base/multidimensional.jl b/base/multidimensional.jl
index 883476c965d17..bfa0ce5851e09 100644
--- a/base/multidimensional.jl
+++ b/base/multidimensional.jl
@@ -410,3 +410,92 @@ for (V, PT, BT) in [((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)]
         return P
     end
 end
+
+## unique across dim
+
+# Wrapper around a precomputed hash value, used as a Dict key so that hashing
+# the key simply returns the stored value.
+immutable Prehashed
+    hash::Uint
+end
+hash(x::Prehashed) = x.hash
+
+# unique(A, dim): return A restricted to its unique slices along dimension
+# `dim`. The first occurrence of each distinct slice is kept and the original
+# order is preserved. If `dim` is not in 1:N, a copy of A is returned.
+#
+# Slices are bucketed by a combined hash of their elements and then verified
+# element by element with `isequal`, the equality that `hash` is consistent
+# with. (With `!=`, a slice containing NaN differs from itself, gets flagged
+# as a collision and is emitted a second time.)
+@ngenerate N typeof(A) function unique{T,N}(A::AbstractArray{T,N}, dim::Int)
+    1 <= dim <= N || return copy(A)
+    hashes = zeros(Uint, size(A, dim))
+
+    # Compute hash for each row
+    k = 0
+    @nloops N i A d->(if d == dim; k = i_d; end) begin
+        @inbounds hashes[k] = bitmix(hashes[k], hash((@nref N A i)))
+    end
+
+    # Collect index of first row for each hash
+    uniquerow = Array(Int, size(A, dim))
+    firstrow = Dict{Prehashed,Int}()
+    for k = 1:size(A, dim)
+        uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k)
+    end
+    uniquerows = collect(values(firstrow))
+
+    # Check for collisions: a row collides when it differs from the first row
+    # that shares its hash
+    collided = falses(size(A, dim))
+    @inbounds begin
+        @nloops N i A d->(if d == dim
+            k = i_d
+            j_d = uniquerow[k]
+        else
+            j_d = i_d
+        end) begin
+            if !isequal((@nref N A j), (@nref N A i))
+                collided[k] = true
+            end
+        end
+    end
+
+    if any(collided)
+        nowcollided = BitArray(size(A, dim))
+        while any(collided)
+            # Collect index of first row for each collided hash
+            empty!(firstrow)
+            for j = 1:size(A, dim)
+                collided[j] || continue
+                uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j)
+            end
+            for v in values(firstrow)
+                push!(uniquerows, v)
+            end
+
+            # Check for collisions among the rows that collided last round,
+            # skipping rows that are their own representative
+            fill!(nowcollided, false)
+            @nloops N i A d->begin
+                if d == dim
+                    k = i_d
+                    j_d = uniquerow[k]
+                    (!collided[k] || j_d == k) && continue
+                else
+                    j_d = i_d
+                end
+            end begin
+                if !isequal((@nref N A j), (@nref N A i))
+                    nowcollided[k] = true
+                end
+            end
+            (collided, nowcollided) = (nowcollided, collided)
+        end
+    end
+
+    # Index the unique rows (in original order) and everything along the
+    # other dimensions
+    @nref N A d->d == dim ? sort!(uniquerows) : (1:size(A, d))
+end
diff --git a/test/arrayops.jl b/test/arrayops.jl
index 0917e69bb73ee..dc9ae60d6d3a3 100644
--- a/test/arrayops.jl
+++ b/test/arrayops.jl
@@ -331,6 +331,36 @@ for i = tensors
     @test isequal(i,permutedims(ipermutedims(i,perm),perm))
 end
 
+## unique across dim ##
+
+# All rows and columns unique
+A = ones(10, 10)
+A[diagind(A)] = shuffle!([1:10])
+@test unique(A, 1) == A
+@test unique(A, 2) == A
+
+# 10 repeats of each row
+B = A[shuffle!(repmat(1:10, 10)), :]
+C = unique(B, 1)
+@test sortrows(C) == sortrows(A)
+@test unique(B, 2) == B
+@test unique(B.', 2).' == C
+
+# Along third dimension
+D = cat(3, B, B)
+@test unique(D, 1) == cat(3, C, C)
+@test unique(D, 3) == cat(3, B)
+
+# With hash collisions
+immutable HashCollision
+    x::Float64
+end
+Base.hash(::HashCollision) = uint(0)
+@test map(x->x.x, unique(map(HashCollision, B), 1)) == C
+
+# Rows containing NaN: slices are compared with isequal, so repeats collapse
+E = [NaN 1.0; NaN 1.0; 2.0 3.0]
+@test isequal(unique(E, 1), [NaN 1.0; 2.0 3.0])
 
 ## reduce ##
 