From aa8bdf9ac025d92c696e24992b888e68b90e366c Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 30 Jun 2022 05:46:19 -0400 Subject: [PATCH 01/16] axes_keys method and Key type indexing --- .../src/ArrayInterfaceCore.jl | 14 +++- src/ArrayInterface.jl | 2 +- src/axes.jl | 64 +++++++++++++++++++ src/indexing.jl | 7 ++ 4 files changed, 84 insertions(+), 3 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index aa68c1b20..9e19e5334 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -94,7 +94,7 @@ _flatten(::Tuple{}) = () Returns the parent array that type `T` wraps. """ parent_type(x) = parent_type(typeof(x)) -parent_type(::Type{Symmetric{T,S}}) where {T,S} = S +parent_type(@nospecialize T::Type{<:Union{Symmetric,Hermitian}}) = fieldtype(T, :data) parent_type(::Type{<:AbstractTriangular{T,S}}) where {T,S} = S parent_type(@nospecialize T::Type{<:PermutedDimsArray}) = fieldtype(T, :parent) parent_type(@nospecialize T::Type{<:Adjoint}) = fieldtype(T, :parent) @@ -560,6 +560,15 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int) end end +""" + Key(key) + +A type that clearly communicates that `key` refers to a key-index mapping. +""" +struct Key{K} <: ArrayIndex{0} + key::K +end + _cartesian_index(i::Tuple{Vararg{Int}}) = CartesianIndex(i) _cartesian_index(::Any) = nothing @@ -643,6 +652,7 @@ is returned. ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N # preserve CartesianIndices{0} as they consume a dimension. ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1 +ndims_index(@nospecialize T::Type{<:Union{Number,Key}}) = 1 ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T) ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T)) ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask)) @@ -658,7 +668,7 @@ indexing with an instance of `I`. ndims_shape(T::DataType) = ndims_index(T) ndims_shape(::Type{Colon}) = 1 ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T) -ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex}}) = 0 +ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Key}}) = 0 ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1 ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T) ndims_shape(x) = ndims_shape(typeof(x)) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index eed5789a6..bf83506a7 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -6,7 +6,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff issingular, isstructured, matrix_colors, restructure, lu_instance, safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type, ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, - map_tuple_type, flatten_tuples, GetIndex + map_tuple_type, flatten_tuples, GetIndex, Key # ArrayIndex subtypes and methods import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex diff --git a/src/axes.jl b/src/axes.jl index 89150b12d..2071a35d0 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -245,3 +245,67 @@ lazy_axes(x::Union{LinearIndices,CartesianIndices,AbstractRange}) = axes(x) @inline function lazy_axes(x::Union{PermutedDimsArray,MatAdjTrans}) map(GetIndex{false}(lazy_axes(parent(x))), to_parent_dims(x)) end + + +# TODO wait for response on https://github.com/JuliaLang/julia/issues/45872 +# struct IndexKeys <: IndexStyle end + +""" + axes_keys(x) + axes_keys(x, dim) + +Returns a tuple of keys assigned to each axis or the axis at dimension `dim` for `x`. +Default is to simply return `map(keys, axes(x))`. +""" +axes_keys(x, dim) = axes_keys(x, to_dims(x, dim)) +@inline axes_keys(x, d::CanonicalInt) = d > ndims(x) ? keys(axes(x, d)) : axes_keys(x)[d] +@inline axes_keys(x) = is_forwarding_wrapper(x) ? axes_keys(parent(x)) : map(keys, axes(x)) +function axes_keys(x::Union{MatAdjTrans,PermutedDimsArray}) + map(GetIndex{false}(axes_keys(parent(x))), to_parent_dims(x)) +end +axes_keys(A::VecAdjTrans) = (SOneTo{1}(), getfield(axes_keys(parent(A)), 1)) + +# TODO ReshapedArray - is there any approach for appropriately propagating keys? +function axes_keys(x::SubArray) + flatten_tuples(map( + Base.Fix1(_axis_key_view, (x.indices, axes_keys(parent(x)))), + map_indices_info(map_indices_info(IndicesInfo(x))) + )) +end +# TODO should we be taking views of keys instead of directly indexing them? views may be +# problematic if the keys aren't array types (e.g., tuple) +function _axis_key_view((inds, ks), ::Tuple{StaticInt{index},StaticInt{pdim},StaticInt{cdim}}) where {index,pdim,cdim} + if pdim === 0 # trailing dimension + return keys(SOneTo{1}()) + elseif cdim === 0 # dropped dimension + return () + else + i = getfield(inds, index) + if idx isa Base.Slice + return getfield(ks, pdim) + else + return @inbounds getfield(ks, pdim)[i] # TODO can we assume this is safe? + end + end +end +# if the index creates multiple dimension in the SubArray or maps to multiple dimension of +# the parent array, then we just get the keys from the index (similar to how we manage axes). +function _axis_key_view((inds, ks), x::Tuple{StaticInt{index},Any,Any}) where {index} + axes_keys(getfield(inds, index)) +end +axes_keys(x::Union{Symmetric,Hermitian}) = axes_keys(parent(x)) +axes_keys(x::LazyAxis{N,P}) where {N,P} = axes_keys(getfield(x, :parent), static(N)) +@inline function axes_keys(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} + if sizeof(S) > sizeof(T) # TODO should we check if we can cleanly convert each field name of `S` to a key? + return flatten_tuples((keys(SOneTo{div(sizeof(S), sizeof(T))}()), axes_keys(parent(x)))) + elseif sizeof(S) < sizeof(T) + return Base.tail(axes_keys(parent(x))) + else + return axes_keys(parent(x)) + end +end +@inline @inline function axes_keys(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} + ak = axes_keys(parent(x)) + ak1 = keys(StaticInt(1):div(static_length(first(ak)) * static(sizeof(S)), static(sizeof(T)))) + flatten_tuples((ak1, Base.tail(ak))) +end diff --git a/src/indexing.jl b/src/indexing.jl index 76f3effe3..2ad4ea778 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -174,6 +174,13 @@ end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) end +function to_index(x, k::Key) + index = findfirst(==(k.key), first(axes_keys(x))) + # delay throwing bounds-error if we didn't find key + index === nothing ? offset1(x) - 1: index +end +# TODO there's probably a more efficient way of doing this +to_index(x, ks::AbstractArray{<:Key}) = [to_index(x, k) for k in ks] # integer indexing to_index(x, i::AbstractArray{<:Integer}) = i to_index(x, @nospecialize(i::StaticInt)) = i From 88be1b1c8ed1819c8d3bd5a4d480965420a20c79 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 30 Jun 2022 05:49:57 -0400 Subject: [PATCH 02/16] fix typo --- src/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/indexing.jl b/src/indexing.jl index 2ad4ea778..4ba6d2a98 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -177,7 +177,7 @@ end function to_index(x, k::Key) index = findfirst(==(k.key), first(axes_keys(x))) # delay throwing bounds-error if we didn't find key - index === nothing ? offset1(x) - 1: index + index === nothing ? offset1(x) - 1 : index end # TODO there's probably a more efficient way of doing this to_index(x, ks::AbstractArray{<:Key}) = [to_index(x, k) for k in ks] From 8f1e165bcf18b427533bb028743f76810e906101 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 30 Jun 2022 09:21:05 -0400 Subject: [PATCH 03/16] Add tests --- .../src/ArrayInterfaceCore.jl | 4 ++-- src/axes.jl | 21 +++++++++++-------- src/dimensions.jl | 2 ++ src/indexing.jl | 7 +++++++ test/axes.jl | 18 ++++++++++++++++ test/setup.jl | 11 +++++++++- 6 files changed, 51 insertions(+), 12 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index 9e19e5334..c5b5f5fe9 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -652,7 +652,7 @@ is returned. ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N # preserve CartesianIndices{0} as they consume a dimension. ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1 -ndims_index(@nospecialize T::Type{<:Union{Number,Key}}) = 1 +ndims_index(@nospecialize T::Type{<:Union{Number,Key,Symbol,AbstractString}}) = 1 ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T) ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T)) ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask)) @@ -668,7 +668,7 @@ indexing with an instance of `I`. ndims_shape(T::DataType) = ndims_index(T) ndims_shape(::Type{Colon}) = 1 ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T) -ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Key}}) = 0 +ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Key,Symbol,AbstractString}}) = 0 ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1 ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T) ndims_shape(x) = ndims_shape(typeof(x)) diff --git a/src/axes.jl b/src/axes.jl index 2071a35d0..f65cdcf04 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -90,7 +90,6 @@ function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N end end - # FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is inended to decreases # breaking changes for this adopting this method to situations where they clearly benefit # from the propagation of static axes. This creates the somewhat awkward situation of @@ -246,7 +245,6 @@ lazy_axes(x::Union{LinearIndices,CartesianIndices,AbstractRange}) = axes(x) map(GetIndex{false}(lazy_axes(parent(x))), to_parent_dims(x)) end - # TODO wait for response on https://github.com/JuliaLang/julia/issues/45872 # struct IndexKeys <: IndexStyle end @@ -269,7 +267,7 @@ axes_keys(A::VecAdjTrans) = (SOneTo{1}(), getfield(axes_keys(parent(A)), 1)) function axes_keys(x::SubArray) flatten_tuples(map( Base.Fix1(_axis_key_view, (x.indices, axes_keys(parent(x)))), - map_indices_info(map_indices_info(IndicesInfo(x))) + map_indices_info(IndicesInfo(x)) )) end # TODO should we be taking views of keys instead of directly indexing them? views may be @@ -281,23 +279,28 @@ function _axis_key_view((inds, ks), ::Tuple{StaticInt{index},StaticInt{pdim},Sta return () else i = getfield(inds, index) - if idx isa Base.Slice - return getfield(ks, pdim) + if i isa Base.Slice + return (getfield(ks, pdim),) else - return @inbounds getfield(ks, pdim)[i] # TODO can we assume this is safe? + return (@inbounds(getfield(ks, pdim)[i]),) # TODO can we assume this is safe? end end end +axes_keys(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ axes_keys, axes(x)) # if the index creates multiple dimension in the SubArray or maps to multiple dimension of # the parent array, then we just get the keys from the index (similar to how we manage axes). function _axis_key_view((inds, ks), x::Tuple{StaticInt{index},Any,Any}) where {index} axes_keys(getfield(inds, index)) end axes_keys(x::Union{Symmetric,Hermitian}) = axes_keys(parent(x)) -axes_keys(x::LazyAxis{N,P}) where {N,P} = axes_keys(getfield(x, :parent), static(N)) +axes_keys(x::LazyAxis{N,P}) where {N,P} = (axes_keys(getfield(x, :parent), static(N)),) @inline function axes_keys(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} - if sizeof(S) > sizeof(T) # TODO should we check if we can cleanly convert each field name of `S` to a key? - return flatten_tuples((keys(SOneTo{div(sizeof(S), sizeof(T))}()), axes_keys(parent(x)))) + if sizeof(S) > sizeof(T) + if isstructtype(S) && div(sizeof(S), sizeof(T)) === fieldcount(S) + return flatten_tuples(((fieldnames(S),), axes_keys(parent(x)))) + else + return flatten_tuples((keys(SOneTo{}()), axes_keys(parent(x)))) + end elseif sizeof(S) < sizeof(T) return Base.tail(axes_keys(parent(x))) else diff --git a/src/dimensions.jl b/src/dimensions.jl index f77ef2347..c05e1b4c0 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -201,6 +201,8 @@ end return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x))) end end +dimnames(x::LazyAxis{:,P}) where {P} = first(dimnames(getfield(x, :parent))) +dimnames(x::LazyAxis{N,P}) where {N,P} = getfield(dimnames(getfield(x, :parent)), N) @inline function dimnames(x::X) where {X} if is_forwarding_wrapper(X) return dimnames(parent(x)) diff --git a/src/indexing.jl b/src/indexing.jl index 4ba6d2a98..dc8b2d23e 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -171,6 +171,9 @@ end @inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}}) max(canonicalize(i.x), static_first(x)):static_last(x) end +@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:Key}) + findall(i.f(i.x.key), first(axes_keys(x))) +end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) end @@ -179,6 +182,10 @@ function to_index(x, k::Key) # delay throwing bounds-error if we didn't find key index === nothing ? offset1(x) - 1 : index end +function to_index(x, k::Union{Symbol,AbstractString}) + index = findfirst(==(k), first(axes_keys(x))) + index === nothing ? offset1(x) - 1 : index +end # TODO there's probably a more efficient way of doing this to_index(x, ks::AbstractArray{<:Key}) = [to_index(x, k) for k in ks] # integer indexing diff --git a/test/axes.jl b/test/axes.jl index 325b20c69..b8d9a2718 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -104,3 +104,21 @@ if isdefined(Base, :ReshapedReinterpretArray) @inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa) end end + +@testset "axes_keys" begin + colors = KeyedArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)) + colormat = reinterpret(reshape, Float64, colors); + cmat_view1 = view(colormat, :, 4); + cmat_view2 = view(colormat, :, 4:7); + cmat_view3 = view(colormat, 2:3,:); + + @test @inferred(ArrayInterface.axes_keys(colors)) == (range(-10, 10, length=100),) + @test @inferred(ArrayInterface.axes_keys(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) + @test @inferred(ArrayInterface.axes_keys(cmat_view1)) == ((:R, :G, :B),) + @test @inferred((ArrayInterface.axes_keys(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) + # can't infer this b/c tuple is being indexed by range + @test ArrayInterface.axes_keys(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) + + @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] + @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] +end diff --git a/test/setup.jl b/test/setup.jl index 4dc7869ba..aff877d02 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -34,7 +34,16 @@ function ArrayInterface.known_dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L ArrayInterface.Static.known(L) end -Base.parent(x::NamedDimsWrapper) = x.parent +struct KeyedArray{T,N,P<:AbstractArray{T,N},K} <: ArrayInterface.AbstractArray2{T,N} + parent::P + keys::K + + KeyedArray(p::P, k::K) where {P,K} = new{eltype(P),ndims(p),P,K}(p, k) +end +ArrayInterface.is_forwarding_wrapper(::Type{<:KeyedArray}) = true +Base.parent(x::KeyedArray) = getfield(x, :parent) +ArrayInterface.parent_type(::Type{T}) where {P,T<:KeyedArray{<:Any,<:Any,P}} = P +ArrayInterface.axes_keys(x::KeyedArray) = getfield(x, :keys) # Dummy array type with undetermined contiguity properties struct DummyZeros{T,N} <: AbstractArray{T,N} From 2e4e4a2bb6b0acaf50d3712f8baa607c3cb8bd61 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 30 Jun 2022 12:02:02 -0400 Subject: [PATCH 04/16] More tests and fix transpose arrays keys --- src/axes.jl | 2 +- src/dimensions.jl | 7 +++++-- src/indexing.jl | 9 +++++---- test/axes.jl | 11 ++++++++++- test/dimensions.jl | 7 +++++++ 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/axes.jl b/src/axes.jl index f65cdcf04..a1cbe8f0f 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -261,7 +261,7 @@ axes_keys(x, dim) = axes_keys(x, to_dims(x, dim)) function axes_keys(x::Union{MatAdjTrans,PermutedDimsArray}) map(GetIndex{false}(axes_keys(parent(x))), to_parent_dims(x)) end -axes_keys(A::VecAdjTrans) = (SOneTo{1}(), getfield(axes_keys(parent(A)), 1)) +axes_keys(x::VecAdjTrans) = (keys(SOneTo{1}()), getfield(axes_keys(parent(x)), 1)) # TODO ReshapedArray - is there any approach for appropriately propagating keys? function axes_keys(x::SubArray) diff --git a/src/dimensions.jl b/src/dimensions.jl index c05e1b4c0..87d44ea86 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -142,6 +142,9 @@ end return ntuple(Compat.Returns(:_), StaticInt(ndims(T))) end end +known_dimnames(::Type{<:LazyAxis{:,P}}) where {P} = (first(known_dimnames(P)),) +known_dimnames(::Type{<:LazyAxis{N,P}}) where {N,P} = (getfield(known_dimnames(P), N),) + @inline function known_dimnames(::Type{T}) where {T} if is_forwarding_wrapper(T) return known_dimnames(parent_type(T)) @@ -201,8 +204,8 @@ end return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x))) end end -dimnames(x::LazyAxis{:,P}) where {P} = first(dimnames(getfield(x, :parent))) -dimnames(x::LazyAxis{N,P}) where {N,P} = getfield(dimnames(getfield(x, :parent)), N) +dimnames(x::LazyAxis{:,P}) where {P} = (first(dimnames(getfield(x, :parent))),) +dimnames(x::LazyAxis{N,P}) where {N,P} = (getfield(dimnames(getfield(x, :parent)), N),) @inline function dimnames(x::X) where {X} if is_forwarding_wrapper(X) return dimnames(parent(x)) diff --git a/src/indexing.jl b/src/indexing.jl index dc8b2d23e..c986c8d60 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -177,6 +177,11 @@ end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) end +to_index(x, i::AbstractArray{<:Union{Base.BitInteger,StaticInt}}) = i +to_index(x, @nospecialize(i::StaticInt)) = i +to_index(x, i::Integer) = Int(i) +@inline to_index(x, i) = to_index(IndexStyle(x), x, i) +# key indexing function to_index(x, k::Key) index = findfirst(==(k.key), first(axes_keys(x))) # delay throwing bounds-error if we didn't find key @@ -189,10 +194,6 @@ end # TODO there's probably a more efficient way of doing this to_index(x, ks::AbstractArray{<:Key}) = [to_index(x, k) for k in ks] # integer indexing -to_index(x, i::AbstractArray{<:Integer}) = i -to_index(x, @nospecialize(i::StaticInt)) = i -to_index(x, i::Integer) = Int(i) -@inline to_index(x, i) = to_index(IndexStyle(x), x, i) function to_index(S::IndexStyle, x, i) throw(ArgumentError( "invalid index: $S does not support indices of type $(typeof(i)) for instances of type $(typeof(x))." diff --git a/test/axes.jl b/test/axes.jl index b8d9a2718..b942df42e 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -106,19 +106,28 @@ if isdefined(Base, :ReshapedReinterpretArray) end @testset "axes_keys" begin - colors = KeyedArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)) + colors = KeyedArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)); + caxis = ArrayInterface.LazyAxis{1}(colors); colormat = reinterpret(reshape, Float64, colors); cmat_view1 = view(colormat, :, 4); cmat_view2 = view(colormat, :, 4:7); cmat_view3 = view(colormat, 2:3,:); @test @inferred(ArrayInterface.axes_keys(colors)) == (range(-10, 10, length=100),) + @test @inferred(ArrayInterface.axes_keys(caxis)) == (range(-10, 10, length=100),) + @test ArrayInterface.axes_keys(view(colors, :, :), 2) == keys(static(1):static(1)) + @test @inferred(ArrayInterface.axes_keys(LinearIndices((caxis,)))) == (range(-10, 10, length=100),) @test @inferred(ArrayInterface.axes_keys(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) + @test @inferred(ArrayInterface.axes_keys(colormat')) == (range(-10, 10, length=100), (:R, :G, :B)) @test @inferred(ArrayInterface.axes_keys(cmat_view1)) == ((:R, :G, :B),) @test @inferred((ArrayInterface.axes_keys(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) + @test @inferred((ArrayInterface.axes_keys(view(colormat, 1, :)'))) == (keys(static(1):static(1)), range(-10, 10, length=100)) # can't infer this b/c tuple is being indexed by range @test ArrayInterface.axes_keys(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) + @test @inferred(ArrayInterface.axes_keys(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] + @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.Key(-9.595959595959595))) == colormat[:, 3] + @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.Key(-9.595959595959595)))) == colormat[:, 1:3] end diff --git a/test/dimensions.jl b/test/dimensions.jl index 00876a383..48dd6b1c1 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -29,6 +29,9 @@ end r4 = reinterpret(reshape, Float64, x) w = Wrapper(x) dnums = ntuple(+, length(d)) + lz2 = ArrayInterface.lazy_axes(x)[2] + lzslice = ArrayInterface.LazyAxis{:}(x) + @test @inferred(ArrayInterface.has_dimnames(x)) == true @test @inferred(ArrayInterface.has_dimnames(z)) == true @test @inferred(ArrayInterface.has_dimnames(ones(2, 2))) == false @@ -36,6 +39,8 @@ end @test @inferred(ArrayInterface.has_dimnames(typeof(x))) == true @test @inferred(ArrayInterface.has_dimnames(typeof(view(x, :, 1, :)))) == true @test @inferred(ArrayInterface.dimnames(x)) === d + @test @inferred(ArrayInterface.dimnames(lz2)) === (static(:y),) + @test @inferred(ArrayInterface.dimnames(lzslice)) === (static(:x),) @test @inferred(ArrayInterface.dimnames(w)) === d @test @inferred(ArrayInterface.dimnames(r1)) === d @test @inferred(ArrayInterface.dimnames(r2)) === (static(:_), d...) @@ -64,6 +69,8 @@ end # multidmensional indices @test @inferred(ArrayInterface.known_dimnames(view(x, ones(Int, 2, 2), 1))) === (:_, :_) @test @inferred(ArrayInterface.known_dimnames(view(x, [CartesianIndex(1,1), CartesianIndex(1,1)]))) === (:_,) + @test @inferred(ArrayInterface.known_dimnames(lz2)) === (:y,) + @test @inferred(ArrayInterface.known_dimnames(lzslice)) === (:x,) @test @inferred(ArrayInterface.known_dimnames(z)) === (nothing, :y) @test @inferred(ArrayInterface.known_dimnames(reshape(x, (1, 4)))) === (:x, :y) From d05d7bc7e238d829d1b88f854240e67d7e7004a1 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 30 Jun 2022 12:34:07 -0400 Subject: [PATCH 05/16] try to preserve axes on reinterpret when size doesn't change --- src/axes.jl | 15 ++++++++------- test/axes.jl | 5 +++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/axes.jl b/src/axes.jl index a1cbe8f0f..34cf4edd2 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -245,9 +245,6 @@ lazy_axes(x::Union{LinearIndices,CartesianIndices,AbstractRange}) = axes(x) map(GetIndex{false}(lazy_axes(parent(x))), to_parent_dims(x)) end -# TODO wait for response on https://github.com/JuliaLang/julia/issues/45872 -# struct IndexKeys <: IndexStyle end - """ axes_keys(x) axes_keys(x, dim) @@ -299,7 +296,7 @@ axes_keys(x::LazyAxis{N,P}) where {N,P} = (axes_keys(getfield(x, :parent), stati if isstructtype(S) && div(sizeof(S), sizeof(T)) === fieldcount(S) return flatten_tuples(((fieldnames(S),), axes_keys(parent(x)))) else - return flatten_tuples((keys(SOneTo{}()), axes_keys(parent(x)))) + return flatten_tuples((keys(SOneTo{div(sizeof(S), sizeof(T))}()), axes_keys(parent(x)))) end elseif sizeof(S) < sizeof(T) return Base.tail(axes_keys(parent(x))) @@ -308,7 +305,11 @@ axes_keys(x::LazyAxis{N,P}) where {N,P} = (axes_keys(getfield(x, :parent), stati end end @inline @inline function axes_keys(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} - ak = axes_keys(parent(x)) - ak1 = keys(StaticInt(1):div(static_length(first(ak)) * static(sizeof(S)), static(sizeof(T)))) - flatten_tuples((ak1, Base.tail(ak))) + Ss = sizeof(S) + Ts = sizeof(T) + if Ss === Ts + return axes_keys(parent(x)) + else + return flatten_tuples((keys(StaticInt(1):size(x, 1)), Base.tail(axes_keys(parent(x))))) + end end diff --git a/test/axes.jl b/test/axes.jl index b942df42e..cc5ac25e4 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -130,4 +130,9 @@ end @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.Key(-9.595959595959595))) == colormat[:, 3] @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.Key(-9.595959595959595)))) == colormat[:, 1:3] + @test @inferred(ArrayInterface.axes_keys(reinterpret(Int8, KeyedArray(randn(2,2), ([:a, :b], ["a", "b"]))))) == (keys(Base.OneTo(16)), ["a", "b"]) + @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int8, KeyedArray(randn(2,2), ([:a, :b], ["a", "b"]))))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) + @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int64, KeyedArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) + @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Float64, KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) + @test @inferred(ArrayInterface.axes_keys(reinterpret(Float64, KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) end From 5f7434fe8a401a9c224384556532be5883cf0f65 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Fri, 1 Jul 2022 13:48:48 -0400 Subject: [PATCH 06/16] Split up reinterpreting keys to help inference on 1.6 --- .../src/ArrayInterfaceCore.jl | 4 +-- src/axes.jl | 30 ++++++++++++------- src/indexing.jl | 6 +++- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index c5b5f5fe9..24fe6fbc7 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -652,7 +652,7 @@ is returned. ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N # preserve CartesianIndices{0} as they consume a dimension. ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1 -ndims_index(@nospecialize T::Type{<:Union{Number,Key,Symbol,AbstractString}}) = 1 +ndims_index(@nospecialize T::Type{<:Union{Number,Key,Symbol,AbstractString,AbstractChar}}) = 1 ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T) ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T)) ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask)) @@ -668,7 +668,7 @@ indexing with an instance of `I`. ndims_shape(T::DataType) = ndims_index(T) ndims_shape(::Type{Colon}) = 1 ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T) -ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Key,Symbol,AbstractString}}) = 0 +ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Key,Symbol,AbstractString,AbstractChar}}) = 0 ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1 ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T) ndims_shape(x) = ndims_shape(typeof(x)) diff --git a/src/axes.jl b/src/axes.jl index 34cf4edd2..3d423a1db 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -291,19 +291,29 @@ function _axis_key_view((inds, ks), x::Tuple{StaticInt{index},Any,Any}) where {i end axes_keys(x::Union{Symmetric,Hermitian}) = axes_keys(parent(x)) axes_keys(x::LazyAxis{N,P}) where {N,P} = (axes_keys(getfield(x, :parent), static(N)),) -@inline function axes_keys(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} - if sizeof(S) > sizeof(T) - if isstructtype(S) && div(sizeof(S), sizeof(T)) === fieldcount(S) - return flatten_tuples(((fieldnames(S),), axes_keys(parent(x)))) - else - return flatten_tuples((keys(SOneTo{div(sizeof(S), sizeof(T))}()), axes_keys(parent(x)))) - end - elseif sizeof(S) < sizeof(T) - return Base.tail(axes_keys(parent(x))) +function axes_keys(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} + _reinterpret_axes_keys(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x) +end +@inline function _reinterpreted_fieldnames(@nospecialize T::Type{<:Base.ReshapedReinterpretArray}) + S = eltype(parent_type(T)) + if isstructtype(S) + return fieldnames(S) else - return axes_keys(parent(x)) + return () + end +end +function _reinterpret_axes_keys(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N} + __reinterpret_axes_keys(s, _reinterpreted_fieldnames(typeof(x)), axes_keys(parent(x))) +end +@inline function __reinterpret_axes_keys(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M} + if N === M + return flatten_tuples(((fields,), ks)) + else + return flatten_tuples((LinearIndices((SOneTo{N}(),)), ks)) end end +_reinterpret_axes_keys(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = axes_keys(parent(x)) +_reinterpret_axes_keys(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = tail(axes_keys(parent(x))) @inline @inline function axes_keys(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} Ss = sizeof(S) Ts = sizeof(T) diff --git a/src/indexing.jl b/src/indexing.jl index c986c8d60..b6e64f4c9 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -187,12 +187,16 @@ function to_index(x, k::Key) # delay throwing bounds-error if we didn't find key index === nothing ? offset1(x) - 1 : index end -function to_index(x, k::Union{Symbol,AbstractString}) +function to_index(x, k::Union{Symbol,AbstractString,AbstractChar}) index = findfirst(==(k), first(axes_keys(x))) index === nothing ? offset1(x) - 1 : index end # TODO there's probably a more efficient way of doing this to_index(x, ks::AbstractArray{<:Key}) = [to_index(x, k) for k in ks] +function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar}}) + [findfirst(==(k), axes_keys(x)) for k in ks] +end + # integer indexing function to_index(S::IndexStyle, x, i) throw(ArgumentError( From c27eb5a8ebfdcefc21ba340db28eef51b77ddc4e Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 3 Jul 2022 11:35:02 -0400 Subject: [PATCH 07/16] version bump and fix multi-key access --- Project.toml | 2 +- lib/ArrayInterfaceCore/Project.toml | 2 +- src/indexing.jl | 6 +++--- test/axes.jl | 13 ++++++++----- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 2872bc31a..31bb4fabe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "6.0.19" +version = "6.0.20" [deps] ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" diff --git a/lib/ArrayInterfaceCore/Project.toml b/lib/ArrayInterfaceCore/Project.toml index 708575a42..fb886ad15 100644 --- a/lib/ArrayInterfaceCore/Project.toml +++ b/lib/ArrayInterfaceCore/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterfaceCore" uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.14" +version = "0.1.15" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/indexing.jl b/src/indexing.jl index b6e64f4c9..757a2650d 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -187,14 +187,14 @@ function to_index(x, k::Key) # delay throwing bounds-error if we didn't find key index === nothing ? offset1(x) - 1 : index end -function to_index(x, k::Union{Symbol,AbstractString,AbstractChar}) +function to_index(x, k::Union{Symbol,AbstractString,AbstractChar,Number}) index = findfirst(==(k), first(axes_keys(x))) index === nothing ? offset1(x) - 1 : index end # TODO there's probably a more efficient way of doing this to_index(x, ks::AbstractArray{<:Key}) = [to_index(x, k) for k in ks] -function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar}}) - [findfirst(==(k), axes_keys(x)) for k in ks] +function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}}) + [to_index(x, k) for k in ks] end # integer indexing diff --git a/test/axes.jl b/test/axes.jl index cc5ac25e4..ba4ecdde9 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -112,6 +112,7 @@ end cmat_view1 = view(colormat, :, 4); cmat_view2 = view(colormat, :, 4:7); cmat_view3 = view(colormat, 2:3,:); + absym_abstr = KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],)) @test @inferred(ArrayInterface.axes_keys(colors)) == (range(-10, 10, length=100),) @test @inferred(ArrayInterface.axes_keys(caxis)) == (range(-10, 10, length=100),) @@ -126,13 +127,15 @@ end @test ArrayInterface.axes_keys(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) @test @inferred(ArrayInterface.axes_keys(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) + @test @inferred(ArrayInterface.axes_keys(reinterpret(Int8, absym_abstr))) == (keys(Base.OneTo(16)), ["a", "b"]) + @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int8, absym_abstr))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) + @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int64, KeyedArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) + @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Float64, KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) + @test @inferred(ArrayInterface.axes_keys(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) + @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.Key(-9.595959595959595))) == colormat[:, 3] @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.Key(-9.595959595959595)))) == colormat[:, 1:3] - @test @inferred(ArrayInterface.axes_keys(reinterpret(Int8, KeyedArray(randn(2,2), ([:a, :b], ["a", "b"]))))) == (keys(Base.OneTo(16)), ["a", "b"]) - @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int8, KeyedArray(randn(2,2), ([:a, :b], ["a", "b"]))))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) - @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int64, KeyedArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) - @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Float64, KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) - @test @inferred(ArrayInterface.axes_keys(reinterpret(Float64, KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) + @test @inferred(ArrayInterface.getindex(absym_abstr, :, ["a"])) == absym_abstr[:,[1]] end From b8d6e93efbd9f262b905c76a058ff1513a239568 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Fri, 5 Aug 2022 17:35:23 -0400 Subject: [PATCH 08/16] axes_keys -> axislabels --- .../src/ArrayInterfaceCore.jl | 12 +- src/ArrayInterface.jl | 2 +- src/axes.jl | 114 ++++++++---------- src/indexing.jl | 16 +-- test/axes.jl | 44 +++---- test/setup.jl | 14 +-- 6 files changed, 92 insertions(+), 110 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index 24fe6fbc7..a467ccb59 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -561,12 +561,12 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int) end """ - Key(key) + Label(label) -A type that clearly communicates that `key` refers to a key-index mapping. +A type that clearly communicates that `label` refers to a key-index mapping. """ -struct Key{K} <: ArrayIndex{0} - key::K +struct Label{L} <: ArrayIndex{0} + label::L end _cartesian_index(i::Tuple{Vararg{Int}}) = CartesianIndex(i) @@ -652,7 +652,7 @@ is returned. ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N # preserve CartesianIndices{0} as they consume a dimension. ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1 -ndims_index(@nospecialize T::Type{<:Union{Number,Key,Symbol,AbstractString,AbstractChar}}) = 1 +ndims_index(@nospecialize T::Type{<:Union{Number,Label,Symbol,AbstractString,AbstractChar}}) = 1 ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T) ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T)) ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask)) @@ -668,7 +668,7 @@ indexing with an instance of `I`. ndims_shape(T::DataType) = ndims_index(T) ndims_shape(::Type{Colon}) = 1 ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T) -ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Key,Symbol,AbstractString,AbstractChar}}) = 0 +ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Label,Symbol,AbstractString,AbstractChar}}) = 0 ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1 ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T) ndims_shape(x) = ndims_shape(typeof(x)) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index bf83506a7..b7f491b6f 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -6,7 +6,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff issingular, isstructured, matrix_colors, restructure, lu_instance, safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type, ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, - map_tuple_type, flatten_tuples, GetIndex, Key + map_tuple_type, flatten_tuples, GetIndex, Label # ArrayIndex subtypes and methods import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex diff --git a/src/axes.jl b/src/axes.jl index 3d423a1db..0ce7f0b73 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -112,7 +112,7 @@ axes(A::ReshapedArray) = Base.axes(A) @inline function axes(x::Union{MatAdjTrans,PermutedDimsArray}) map(GetIndex{false}(axes(parent(x))), to_parent_dims(x)) end -axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1)) +axes(A::VecAdjTrans) = (SOneTo{1}(), getfield(axes(parent(A)), 1)) @inline axes(x::SubArray) = flatten_tuples(map(Base.Fix1(_sub_axes, x), sub_axes_map(typeof(x)))) @inline _sub_axes(x::SubArray, axis::SOneTo) = axis @@ -246,80 +246,62 @@ lazy_axes(x::Union{LinearIndices,CartesianIndices,AbstractRange}) = axes(x) end """ - axes_keys(x) - axes_keys(x, dim) + axislabels(x) + axislabels(x, dim) -Returns a tuple of keys assigned to each axis or the axis at dimension `dim` for `x`. -Default is to simply return `map(keys, axes(x))`. +Returns a tuple of labels assigned to each axis or a collection of labels corresponding to +axis `dim` of `x`. Default is to simply return `map(keys, axes(x))`. """ -axes_keys(x, dim) = axes_keys(x, to_dims(x, dim)) -@inline axes_keys(x, d::CanonicalInt) = d > ndims(x) ? keys(axes(x, d)) : axes_keys(x)[d] -@inline axes_keys(x) = is_forwarding_wrapper(x) ? axes_keys(parent(x)) : map(keys, axes(x)) -function axes_keys(x::Union{MatAdjTrans,PermutedDimsArray}) - map(GetIndex{false}(axes_keys(parent(x))), to_parent_dims(x)) -end -axes_keys(x::VecAdjTrans) = (keys(SOneTo{1}()), getfield(axes_keys(parent(x)), 1)) - -# TODO ReshapedArray - is there any approach for appropriately propagating keys? -function axes_keys(x::SubArray) - flatten_tuples(map( - Base.Fix1(_axis_key_view, (x.indices, axes_keys(parent(x)))), - map_indices_info(IndicesInfo(x)) - )) -end -# TODO should we be taking views of keys instead of directly indexing them? views may be -# problematic if the keys aren't array types (e.g., tuple) -function _axis_key_view((inds, ks), ::Tuple{StaticInt{index},StaticInt{pdim},StaticInt{cdim}}) where {index,pdim,cdim} - if pdim === 0 # trailing dimension - return keys(SOneTo{1}()) - elseif cdim === 0 # dropped dimension - return () - else - i = getfield(inds, index) - if i isa Base.Slice - return (getfield(ks, pdim),) +axislabels(x, dim) = axislabels(x, to_dims(x, dim)) +axislabels(@nospecialize x::Number) = () +@inline axislabels(x, d::CanonicalInt) = d > ndims(x) ? keys(axes(x, d)) : axislabels(x)[d] +@inline axislabels(x) = is_forwarding_wrapper(x) ? axislabels(parent(x)) : map(keys, axes(x)) +function axislabels(x::Union{MatAdjTrans,PermutedDimsArray}) + map(GetIndex{false}(axislabels(parent(x))), to_parent_dims(x)) +end +axislabels(x::VecAdjTrans) = (keys(SOneTo{1}()), getfield(axislabels(parent(x)), 1)) +axislabels(x::SubArray) = _sub_axislabels(parent(x), x.indices, IndicesInfo(x)) +function _sub_axislabels(x::AbstractArray, inds::Tuple, ::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims} + labels = axislabels(x) + flatten_tuples(ntuple(Val{nfields(pdims)}()) do i + pdim_i = getfield(pdims, i) + cdim_i = getfield(cdims, i) + index = getfield(inds, i) + if pdim_i isa Tuple || cdim_i isa Tuple # no direct mapping to parent axes + axislabels(index) + elseif cdim_i === 0 # integer indexing drops axes + () + elseif pdim_i === 0 # trailing dimension + LinearIndices((SOneTo{1}(),)) + elseif index isa Base.Slice # index into labels where there is direct mapping to parent axis + (getfield(labels, pdim_i),) else - return (@inbounds(getfield(ks, pdim)[i]),) # TODO can we assume this is safe? + (@inbounds(getfield(labels, pdim_i)[index]),) end + end) +end +axislabels(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ axislabels, axes(x)) +axislabels(x::Union{Symmetric,Hermitian}) = axislabels(parent(x)) +axislabels(x::LazyAxis{N,P}) where {N,P} = (axislabels(getfield(x, :parent), StaticInt(N)),) +@inline @inline function axislabels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} + if sizeof(T) === sizeof(S) + return axislabels(parent(x)) + else + return flatten_tuples((keys(StaticInt(1):size(x, 1)), Base.tail(axislabels(parent(x))))) end end -axes_keys(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ axes_keys, axes(x)) -# if the index creates multiple dimension in the SubArray or maps to multiple dimension of -# the parent array, then we just get the keys from the index (similar to how we manage axes). -function _axis_key_view((inds, ks), x::Tuple{StaticInt{index},Any,Any}) where {index} - axes_keys(getfield(inds, index)) -end -axes_keys(x::Union{Symmetric,Hermitian}) = axes_keys(parent(x)) -axes_keys(x::LazyAxis{N,P}) where {N,P} = (axes_keys(getfield(x, :parent), static(N)),) -function axes_keys(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} - _reinterpret_axes_keys(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x) +function axislabels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} + _reinterpret_axislabels(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x) end @inline function _reinterpreted_fieldnames(@nospecialize T::Type{<:Base.ReshapedReinterpretArray}) S = eltype(parent_type(T)) - if isstructtype(S) - return fieldnames(S) - else - return () - end -end -function _reinterpret_axes_keys(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N} - __reinterpret_axes_keys(s, _reinterpreted_fieldnames(typeof(x)), axes_keys(parent(x))) + isstructtype(S) ? fieldnames(S) : () end -@inline function __reinterpret_axes_keys(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M} - if N === M - return flatten_tuples(((fields,), ks)) - else - return flatten_tuples((LinearIndices((SOneTo{N}(),)), ks)) - end +function _reinterpret_axislabels(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N} + __reinterpret_axislabels(s, _reinterpreted_fieldnames(typeof(x)), axislabels(parent(x))) end -_reinterpret_axes_keys(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = axes_keys(parent(x)) -_reinterpret_axes_keys(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = tail(axes_keys(parent(x))) -@inline @inline function axes_keys(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} - Ss = sizeof(S) - Ts = sizeof(T) - if Ss === Ts - return axes_keys(parent(x)) - else - return flatten_tuples((keys(StaticInt(1):size(x, 1)), Base.tail(axes_keys(parent(x))))) - end +@inline function __reinterpret_axislabels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M} + N === M ? (fields, ks...,) : (LinearIndices((SOneTo{N}(),)), ks...,) end +_reinterpret_axislabels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = axislabels(parent(x)) +_reinterpret_axislabels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = tail(axislabels(parent(x))) diff --git a/src/indexing.jl b/src/indexing.jl index 757a2650d..97b58b84f 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -171,8 +171,8 @@ end @inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}}) max(canonicalize(i.x), static_first(x)):static_last(x) end -@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:Key}) - findall(i.f(i.x.key), first(axes_keys(x))) +@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:Label}) + findall(i.f(i.x.label), first(axislabels(x))) end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) @@ -182,17 +182,17 @@ to_index(x, @nospecialize(i::StaticInt)) = i to_index(x, i::Integer) = Int(i) @inline to_index(x, i) = to_index(IndexStyle(x), x, i) # key indexing -function to_index(x, k::Key) - index = findfirst(==(k.key), first(axes_keys(x))) - # delay throwing bounds-error if we didn't find key +function to_index(x, i::Label) + index = findfirst(==(getfield(i, :label)), first(axislabels(x))) + # delay throwing bounds-error if we didn't find label index === nothing ? offset1(x) - 1 : index end -function to_index(x, k::Union{Symbol,AbstractString,AbstractChar,Number}) - index = findfirst(==(k), first(axes_keys(x))) +function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number}) + index = findfirst(==(i), getfield(axislabels(x), 1)) index === nothing ? offset1(x) - 1 : index end # TODO there's probably a more efficient way of doing this -to_index(x, ks::AbstractArray{<:Key}) = [to_index(x, k) for k in ks] +to_index(x, ks::AbstractArray{<:Label}) = [to_index(x, k) for k in ks] function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}}) [to_index(x, k) for k in ks] end diff --git a/test/axes.jl b/test/axes.jl index ba4ecdde9..5726c8880 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -105,37 +105,37 @@ if isdefined(Base, :ReshapedReinterpretArray) end end -@testset "axes_keys" begin - colors = KeyedArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)); +@testset "axislabels" begin + colors = LabelledArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)); caxis = ArrayInterface.LazyAxis{1}(colors); colormat = reinterpret(reshape, Float64, colors); cmat_view1 = view(colormat, :, 4); cmat_view2 = view(colormat, :, 4:7); cmat_view3 = view(colormat, 2:3,:); - absym_abstr = KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],)) - - @test @inferred(ArrayInterface.axes_keys(colors)) == (range(-10, 10, length=100),) - @test @inferred(ArrayInterface.axes_keys(caxis)) == (range(-10, 10, length=100),) - @test ArrayInterface.axes_keys(view(colors, :, :), 2) == keys(static(1):static(1)) - @test @inferred(ArrayInterface.axes_keys(LinearIndices((caxis,)))) == (range(-10, 10, length=100),) - @test @inferred(ArrayInterface.axes_keys(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) - @test @inferred(ArrayInterface.axes_keys(colormat')) == (range(-10, 10, length=100), (:R, :G, :B)) - @test @inferred(ArrayInterface.axes_keys(cmat_view1)) == ((:R, :G, :B),) - @test @inferred((ArrayInterface.axes_keys(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) - @test @inferred((ArrayInterface.axes_keys(view(colormat, 1, :)'))) == (keys(static(1):static(1)), range(-10, 10, length=100)) + absym_abstr = LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],)); + + @test @inferred(ArrayInterface.axislabels(colors)) == (range(-10, 10, length=100),) + @test @inferred(ArrayInterface.axislabels(caxis)) == (range(-10, 10, length=100),) + @test ArrayInterface.axislabels(view(colors, :, :), 2) == keys(static(1):static(1)) + @test @inferred(ArrayInterface.axislabels(LinearIndices((caxis,)))) == (range(-10, 10, length=100),) + @test @inferred(ArrayInterface.axislabels(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) + @test @inferred(ArrayInterface.axislabels(colormat')) == (range(-10, 10, length=100), (:R, :G, :B)) + @test @inferred(ArrayInterface.axislabels(cmat_view1)) == ((:R, :G, :B),) + @test @inferred((ArrayInterface.axislabels(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) + @test @inferred((ArrayInterface.axislabels(view(colormat, 1, :)'))) == (keys(static(1):static(1)), range(-10, 10, length=100)) # can't infer this b/c tuple is being indexed by range - @test ArrayInterface.axes_keys(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) - @test @inferred(ArrayInterface.axes_keys(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) + @test ArrayInterface.axislabels(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) + @test @inferred(ArrayInterface.axislabels(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) - @test @inferred(ArrayInterface.axes_keys(reinterpret(Int8, absym_abstr))) == (keys(Base.OneTo(16)), ["a", "b"]) - @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int8, absym_abstr))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) - @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Int64, KeyedArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) - @test @inferred(ArrayInterface.axes_keys(reinterpret(reshape, Float64, KeyedArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) - @test @inferred(ArrayInterface.axes_keys(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) + @test @inferred(ArrayInterface.axislabels(reinterpret(Int8, absym_abstr))) == (keys(Base.OneTo(16)), ["a", "b"]) + @test @inferred(ArrayInterface.axislabels(reinterpret(reshape, Int8, absym_abstr))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) + @test @inferred(ArrayInterface.axislabels(reinterpret(reshape, Int64, LabelledArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) + @test @inferred(ArrayInterface.axislabels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) + @test @inferred(ArrayInterface.axislabels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] - @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.Key(-9.595959595959595))) == colormat[:, 3] - @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.Key(-9.595959595959595)))) == colormat[:, 1:3] + @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.Label(-9.595959595959595))) == colormat[:, 3] + @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.Label(-9.595959595959595)))) == colormat[:, 1:3] @test @inferred(ArrayInterface.getindex(absym_abstr, :, ["a"])) == absym_abstr[:,[1]] end diff --git a/test/setup.jl b/test/setup.jl index aff877d02..d5d2e0736 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -34,16 +34,16 @@ function ArrayInterface.known_dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L ArrayInterface.Static.known(L) end -struct KeyedArray{T,N,P<:AbstractArray{T,N},K} <: ArrayInterface.AbstractArray2{T,N} +struct LabelledArray{T,N,P<:AbstractArray{T,N},L} <: ArrayInterface.AbstractArray2{T,N} parent::P - keys::K + labels::L - KeyedArray(p::P, k::K) where {P,K} = new{eltype(P),ndims(p),P,K}(p, k) + LabelledArray(p::P, labels::L) where {P,L} = new{eltype(P),ndims(p),P,L}(p, labels) end -ArrayInterface.is_forwarding_wrapper(::Type{<:KeyedArray}) = true -Base.parent(x::KeyedArray) = getfield(x, :parent) -ArrayInterface.parent_type(::Type{T}) where {P,T<:KeyedArray{<:Any,<:Any,P}} = P -ArrayInterface.axes_keys(x::KeyedArray) = getfield(x, :keys) +ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true +Base.parent(x::LabelledArray) = getfield(x, :parent) +ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P +ArrayInterface.axislabels(x::LabelledArray) = getfield(x, :labels) # Dummy array type with undetermined contiguity properties struct DummyZeros{T,N} <: AbstractArray{T,N} From 8fde10bae271d8283c34405d281f5ee42854a2ab Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Fri, 26 Aug 2022 21:17:54 -0400 Subject: [PATCH 09/16] axislabels -> index_labels Discourse conversation pointed out there may be some confusion with plotting libraries that often use axis tick labels. Switch to index should be more specific what is being labeled. --- src/axes.jl | 52 ++++++++++++++++++++++++------------------------- src/indexing.jl | 6 +++--- test/axes.jl | 36 +++++++++++++++++----------------- test/setup.jl | 2 +- 4 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/axes.jl b/src/axes.jl index 9cc4099b1..7e624dfcd 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -249,29 +249,29 @@ lazy_axes(x, ::StaticInt{dim}) where {dim} = ndims(x) < dim ? SOneTo{1}() : Lazy @inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims) """ - axislabels(x) - axislabels(x, dim) + index_labels(x) + index_labels(x, dim) Returns a tuple of labels assigned to each axis or a collection of labels corresponding to axis `dim` of `x`. Default is to simply return `map(keys, axes(x))`. """ -axislabels(x, dim) = axislabels(x, to_dims(x, dim)) -axislabels(@nospecialize x::Number) = () -@inline axislabels(x, d::CanonicalInt) = d > ndims(x) ? keys(axes(x, d)) : axislabels(x)[d] -@inline axislabels(x) = is_forwarding_wrapper(x) ? axislabels(parent(x)) : map(keys, axes(x)) -function axislabels(x::Union{MatAdjTrans,PermutedDimsArray}) - map(GetIndex{false}(axislabels(parent(x))), to_parent_dims(x)) +index_labels(x, dim) = index_labels(x, to_dims(x, dim)) +index_labels(@nospecialize x::Number) = () +@inline index_labels(x, d::CanonicalInt) = d > ndims(x) ? keys(axes(x, d)) : index_labels(x)[d] +@inline index_labels(x) = is_forwarding_wrapper(x) ? index_labels(parent(x)) : map(keys, axes(x)) +function index_labels(x::Union{MatAdjTrans,PermutedDimsArray}) + map(GetIndex{false}(index_labels(parent(x))), to_parent_dims(x)) end -axislabels(x::VecAdjTrans) = (keys(SOneTo{1}()), getfield(axislabels(parent(x)), 1)) -axislabels(x::SubArray) = _sub_axislabels(parent(x), x.indices, IndicesInfo(x)) -function _sub_axislabels(x::AbstractArray, inds::Tuple, ::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims} - labels = axislabels(x) +index_labels(x::VecAdjTrans) = (keys(SOneTo{1}()), getfield(index_labels(parent(x)), 1)) +index_labels(x::SubArray) = _sub_index_labels(parent(x), x.indices, IndicesInfo(x)) +function _sub_index_labels(x::AbstractArray, inds::Tuple, ::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims} + labels = index_labels(x) flatten_tuples(ntuple(Val{nfields(pdims)}()) do i pdim_i = getfield(pdims, i) cdim_i = getfield(cdims, i) index = getfield(inds, i) if pdim_i isa Tuple || cdim_i isa Tuple # no direct mapping to parent axes - axislabels(index) + index_labels(index) elseif cdim_i === 0 # integer indexing drops axes () elseif pdim_i === 0 # trailing dimension @@ -283,28 +283,28 @@ function _sub_axislabels(x::AbstractArray, inds::Tuple, ::IndicesInfo{N,pdims,cd end end) end -axislabels(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ axislabels, axes(x)) -axislabels(x::Union{Symmetric,Hermitian}) = axislabels(parent(x)) -axislabels(x::LazyAxis{N,P}) where {N,P} = (axislabels(getfield(x, :parent), StaticInt(N)),) -@inline @inline function axislabels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} +index_labels(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ index_labels, axes(x)) +index_labels(x::Union{Symmetric,Hermitian}) = index_labels(parent(x)) +index_labels(x::LazyAxis{N,P}) where {N,P} = (index_labels(getfield(x, :parent), StaticInt(N)),) +@inline @inline function index_labels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} if sizeof(T) === sizeof(S) - return axislabels(parent(x)) + return index_labels(parent(x)) else - return flatten_tuples((keys(StaticInt(1):size(x, 1)), Base.tail(axislabels(parent(x))))) + return flatten_tuples((keys(StaticInt(1):size(x, 1)), Base.tail(index_labels(parent(x))))) end end -function axislabels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} - _reinterpret_axislabels(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x) +function index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} + _reinterpret_index_labels(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x) end @inline function _reinterpreted_fieldnames(@nospecialize T::Type{<:Base.ReshapedReinterpretArray}) S = eltype(parent_type(T)) isstructtype(S) ? fieldnames(S) : () end -function _reinterpret_axislabels(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N} - __reinterpret_axislabels(s, _reinterpreted_fieldnames(typeof(x)), axislabels(parent(x))) +function _reinterpret_index_labels(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N} + __reinterpret_index_labels(s, _reinterpreted_fieldnames(typeof(x)), index_labels(parent(x))) end -@inline function __reinterpret_axislabels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M} +@inline function __reinterpret_index_labels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M} N === M ? (fields, ks...,) : (LinearIndices((SOneTo{N}(),)), ks...,) end -_reinterpret_axislabels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = axislabels(parent(x)) -_reinterpret_axislabels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = tail(axislabels(parent(x))) +_reinterpret_index_labels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = index_labels(parent(x)) +_reinterpret_index_labels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = tail(index_labels(parent(x))) diff --git a/src/indexing.jl b/src/indexing.jl index 60d884ea7..97e9795d3 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -171,7 +171,7 @@ end max(canonicalize(i.x), static_first(x)):static_last(x) end @inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:Label}) - findall(i.f(i.x.label), first(axislabels(x))) + findall(i.f(i.x.label), first(index_labels(x))) end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) @@ -182,12 +182,12 @@ to_index(x, i::Integer) = Int(i) @inline to_index(x, i) = to_index(IndexStyle(x), x, i) # key indexing function to_index(x, i::Label) - index = findfirst(==(getfield(i, :label)), first(axislabels(x))) + index = findfirst(==(getfield(i, :label)), first(index_labels(x))) # delay throwing bounds-error if we didn't find label index === nothing ? offset1(x) - 1 : index end function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number}) - index = findfirst(==(i), getfield(axislabels(x), 1)) + index = findfirst(==(i), getfield(index_labels(x), 1)) index === nothing ? offset1(x) - 1 : index end # TODO there's probably a more efficient way of doing this diff --git a/test/axes.jl b/test/axes.jl index 5726c8880..cb7d0451b 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -105,7 +105,7 @@ if isdefined(Base, :ReshapedReinterpretArray) end end -@testset "axislabels" begin +@testset "index_labels" begin colors = LabelledArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)); caxis = ArrayInterface.LazyAxis{1}(colors); colormat = reinterpret(reshape, Float64, colors); @@ -114,24 +114,24 @@ end cmat_view3 = view(colormat, 2:3,:); absym_abstr = LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],)); - @test @inferred(ArrayInterface.axislabels(colors)) == (range(-10, 10, length=100),) - @test @inferred(ArrayInterface.axislabels(caxis)) == (range(-10, 10, length=100),) - @test ArrayInterface.axislabels(view(colors, :, :), 2) == keys(static(1):static(1)) - @test @inferred(ArrayInterface.axislabels(LinearIndices((caxis,)))) == (range(-10, 10, length=100),) - @test @inferred(ArrayInterface.axislabels(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) - @test @inferred(ArrayInterface.axislabels(colormat')) == (range(-10, 10, length=100), (:R, :G, :B)) - @test @inferred(ArrayInterface.axislabels(cmat_view1)) == ((:R, :G, :B),) - @test @inferred((ArrayInterface.axislabels(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) - @test @inferred((ArrayInterface.axislabels(view(colormat, 1, :)'))) == (keys(static(1):static(1)), range(-10, 10, length=100)) + @test @inferred(ArrayInterface.index_labels(colors)) == (range(-10, 10, length=100),) + @test @inferred(ArrayInterface.index_labels(caxis)) == (range(-10, 10, length=100),) + @test ArrayInterface.index_labels(view(colors, :, :), 2) == keys(static(1):static(1)) + @test @inferred(ArrayInterface.index_labels(LinearIndices((caxis,)))) == (range(-10, 10, length=100),) + @test @inferred(ArrayInterface.index_labels(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) + @test @inferred(ArrayInterface.index_labels(colormat')) == (range(-10, 10, length=100), (:R, :G, :B)) + @test @inferred(ArrayInterface.index_labels(cmat_view1)) == ((:R, :G, :B),) + @test @inferred((ArrayInterface.index_labels(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) + @test @inferred((ArrayInterface.index_labels(view(colormat, 1, :)'))) == (keys(static(1):static(1)), range(-10, 10, length=100)) # can't infer this b/c tuple is being indexed by range - @test ArrayInterface.axislabels(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) - @test @inferred(ArrayInterface.axislabels(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) - - @test @inferred(ArrayInterface.axislabels(reinterpret(Int8, absym_abstr))) == (keys(Base.OneTo(16)), ["a", "b"]) - @test @inferred(ArrayInterface.axislabels(reinterpret(reshape, Int8, absym_abstr))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) - @test @inferred(ArrayInterface.axislabels(reinterpret(reshape, Int64, LabelledArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) - @test @inferred(ArrayInterface.axislabels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) - @test @inferred(ArrayInterface.axislabels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) + @test ArrayInterface.index_labels(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) + @test @inferred(ArrayInterface.index_labels(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) + + @test @inferred(ArrayInterface.index_labels(reinterpret(Int8, absym_abstr))) == (keys(Base.OneTo(16)), ["a", "b"]) + @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int8, absym_abstr))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) + @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int64, LabelledArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) + @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) + @test @inferred(ArrayInterface.index_labels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] diff --git a/test/setup.jl b/test/setup.jl index d5d2e0736..6c2b33128 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -43,7 +43,7 @@ end ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true Base.parent(x::LabelledArray) = getfield(x, :parent) ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P -ArrayInterface.axislabels(x::LabelledArray) = getfield(x, :labels) +ArrayInterface.index_labels(x::LabelledArray) = getfield(x, :labels) # Dummy array type with undetermined contiguity properties struct DummyZeros{T,N} <: AbstractArray{T,N} From d6cb2181fdbe2ba7c96ff7bc4dcebb162ec1952a Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 28 Aug 2022 12:48:45 -0400 Subject: [PATCH 10/16] Label -> IndexLabel --- lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl | 8 ++++---- src/ArrayInterface.jl | 2 +- src/indexing.jl | 6 +++--- test/axes.jl | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index 759c07a0d..a46dddb48 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -617,11 +617,11 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int) end """ - Label(label) + IndexLabel(label) A type that clearly communicates that `label` refers to a key-index mapping. """ -struct Label{L} <: ArrayIndex{0} +struct IndexLabel{L} <: ArrayIndex{0} label::L end @@ -724,7 +724,7 @@ julia> ArrayInterfaceCore.ndims_index([CartesianIndex(1, 2), CartesianIndex(1, 3 ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N # preserve CartesianIndices{0} as they consume a dimension. ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1 -ndims_index(@nospecialize T::Type{<:Union{Number,Label,Symbol,AbstractString,AbstractChar}}) = 1 +ndims_index(@nospecialize T::Type{<:Union{Number,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 1 ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T) ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T)) ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask)) @@ -752,7 +752,7 @@ julia> ndims(CartesianIndices((2,2))[[CartesianIndex(1, 1), CartesianIndex(1, 2) ndims_shape(T::DataType) = ndims_index(T) ndims_shape(::Type{Colon}) = 1 ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T) -ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,Label,Symbol,AbstractString,AbstractChar}}) = 0 +ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 0 ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1 ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T) ndims_shape(x) = ndims_shape(typeof(x)) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index abd87449d..5beb3bd7c 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -6,7 +6,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff issingular, isstructured, matrix_colors, restructure, lu_instance, safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type, ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, - map_tuple_type, flatten_tuples, GetIndex, SetIndex!, Label, defines_strides, + map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel, defines_strides, stride_preserving_index # ArrayIndex subtypes and methods diff --git a/src/indexing.jl b/src/indexing.jl index 97e9795d3..0d5270816 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -170,7 +170,7 @@ end @inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}}) max(canonicalize(i.x), static_first(x)):static_last(x) end -@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:Label}) +@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel}) findall(i.f(i.x.label), first(index_labels(x))) end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) @@ -181,7 +181,7 @@ to_index(x, @nospecialize(i::StaticInt)) = i to_index(x, i::Integer) = Int(i) @inline to_index(x, i) = to_index(IndexStyle(x), x, i) # key indexing -function to_index(x, i::Label) +function to_index(x, i::IndexLabel) index = findfirst(==(getfield(i, :label)), first(index_labels(x))) # delay throwing bounds-error if we didn't find label index === nothing ? offset1(x) - 1 : index @@ -191,7 +191,7 @@ function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number}) index === nothing ? offset1(x) - 1 : index end # TODO there's probably a more efficient way of doing this -to_index(x, ks::AbstractArray{<:Label}) = [to_index(x, k) for k in ks] +to_index(x, ks::AbstractArray{<:IndexLabel}) = [to_index(x, k) for k in ks] function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}}) [to_index(x, k) for k in ks] end diff --git a/test/axes.jl b/test/axes.jl index cb7d0451b..6ef6b9c24 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -135,7 +135,7 @@ end @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] - @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.Label(-9.595959595959595))) == colormat[:, 3] - @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.Label(-9.595959595959595)))) == colormat[:, 1:3] + @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.IndexLabel(-9.595959595959595))) == colormat[:, 3] + @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.IndexLabel(-9.595959595959595)))) == colormat[:, 1:3] @test @inferred(ArrayInterface.getindex(absym_abstr, :, ["a"])) == absym_abstr[:,[1]] end From bc7a635ce379fb1cde3e13054b7f0e8c78ae1b95 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 28 Aug 2022 20:26:06 -0400 Subject: [PATCH 11/16] Return `nothing` instead of the keys of an axis --- .../src/ArrayInterfaceCore.jl | 5 +- src/ArrayInterface.jl | 1 + src/axes.jl | 67 +++++++++++++++---- src/dimensions.jl | 6 +- test/axes.jl | 8 +-- 5 files changed, 66 insertions(+), 21 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index a46dddb48..2c2004383 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -619,9 +619,10 @@ end """ IndexLabel(label) -A type that clearly communicates that `label` refers to a key-index mapping. +A type that clearly communicates to internal methods to lookup the index corresponding to +for `label`. """ -struct IndexLabel{L} <: ArrayIndex{0} +struct IndexLabel{L} <: ArrayIndex{1} label::L end diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 5beb3bd7c..7bc4aba3d 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -35,6 +35,7 @@ using Base.Iterators: Pairs using LinearAlgebra import Compat +using Compat: Returns _add1(@nospecialize x) = x + oneunit(x) _sub1(@nospecialize x) = x - oneunit(x) diff --git a/src/axes.jl b/src/axes.jl index 7e624dfcd..0adda535f 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -248,24 +248,61 @@ lazy_axes(x, ::Colon) = LazyAxis{:}(x) lazy_axes(x, ::StaticInt{dim}) where {dim} = ndims(x) < dim ? SOneTo{1}() : LazyAxis{dim}(x) @inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims) +""" + has_index_labels(x) -> Bool + +Returns `true` if `x` has has any index labels. If [`index_labels`](@ref) returns a tuple of +`nothing`, this will be `false`. + +See also: [`index_labels`](@ref) +""" +has_index_labels(x) = is_forwarding_wrapper(x) ? has_index_labels(parent(x)) : false +function has_index_labels(x::Union{Base.NonReshapedReinterpretArray,Transpose,Adjoint,PermutedDimsArray,Symmetric,Hermitian}) + has_index_labels(parent(x)) +end +function has_index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} + if has_index_labels(parent(x)) + true + else + size1 = div(sizeof(S), sizeof(T)) + size1 > 1 && size1 === fieldcount(S) + end +end +function has_index_labels(x::SubArray) + has_index_labels(parent(x)) || any(has_index_labels, x.indices) +end + """ index_labels(x) index_labels(x, dim) Returns a tuple of labels assigned to each axis or a collection of labels corresponding to -axis `dim` of `x`. Default is to simply return `map(keys, axes(x))`. +each index along `dim` of `x`. Default is to simply return `nothing`. + +See also: [`has_index_labels`](@ref) """ index_labels(x, dim) = index_labels(x, to_dims(x, dim)) index_labels(@nospecialize x::Number) = () -@inline index_labels(x, d::CanonicalInt) = d > ndims(x) ? keys(axes(x, d)) : index_labels(x)[d] -@inline index_labels(x) = is_forwarding_wrapper(x) ? index_labels(parent(x)) : map(keys, axes(x)) +@inline function index_labels(x, dim::CanonicalInt) + dim > ndims(x) ? nothing : getfield(index_labels(x), Int(dim)) +end +@inline function index_labels(x) + if is_forwarding_wrapper(x) + index_labels(buffer(x)) + else + ntuple(Returns(nothing), Val{ndims(x)}()) + end +end function index_labels(x::Union{MatAdjTrans,PermutedDimsArray}) map(GetIndex{false}(index_labels(parent(x))), to_parent_dims(x)) end -index_labels(x::VecAdjTrans) = (keys(SOneTo{1}()), getfield(index_labels(parent(x)), 1)) -index_labels(x::SubArray) = _sub_index_labels(parent(x), x.indices, IndicesInfo(x)) -function _sub_index_labels(x::AbstractArray, inds::Tuple, ::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims} - labels = index_labels(x) +index_labels(x::VecAdjTrans) = (nothing, getfield(index_labels(parent(x)), 1)) +function index_labels(x::SubArray) + labels = index_labels(parent(x)) + inds = x.indices + info = IndicesInfo(x) + pdims = parentdims(info) + cdims = childdims(info) flatten_tuples(ntuple(Val{nfields(pdims)}()) do i pdim_i = getfield(pdims, i) cdim_i = getfield(cdims, i) @@ -275,22 +312,24 @@ function _sub_index_labels(x::AbstractArray, inds::Tuple, ::IndicesInfo{N,pdims, elseif cdim_i === 0 # integer indexing drops axes () elseif pdim_i === 0 # trailing dimension - LinearIndices((SOneTo{1}(),)) + nothing elseif index isa Base.Slice # index into labels where there is direct mapping to parent axis (getfield(labels, pdim_i),) else - (@inbounds(getfield(labels, pdim_i)[index]),) + labels_i = getfield(labels, pdim_i) + labels_i === nothing ? index_labels(index) : (@inbounds(labels_i[index]),) end end) end -index_labels(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ index_labels, axes(x)) +index_labels(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ index_labels, x.indices) index_labels(x::Union{Symmetric,Hermitian}) = index_labels(parent(x)) -index_labels(x::LazyAxis{N,P}) where {N,P} = (index_labels(getfield(x, :parent), StaticInt(N)),) +index_labels(@nospecialize(x::LazyAxis{:})) = (nothing,) +index_labels(x::LazyAxis{N}) where {N} = (getfield(index_labels(getfield(x, :parent)), N),) @inline @inline function index_labels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} if sizeof(T) === sizeof(S) return index_labels(parent(x)) else - return flatten_tuples((keys(StaticInt(1):size(x, 1)), Base.tail(index_labels(parent(x))))) + return (nothing, Base.tail(index_labels(parent(x)))...) end end function index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} @@ -304,7 +343,7 @@ function _reinterpret_index_labels(s::StaticInt{N}, x::Base.ReshapedReinterpretA __reinterpret_index_labels(s, _reinterpreted_fieldnames(typeof(x)), index_labels(parent(x))) end @inline function __reinterpret_index_labels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M} - N === M ? (fields, ks...,) : (LinearIndices((SOneTo{N}(),)), ks...,) + N === M ? (fields, ks...,) : (nothing, ks...,) end _reinterpret_index_labels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = index_labels(parent(x)) -_reinterpret_index_labels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = tail(index_labels(parent(x))) +_reinterpret_index_labels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = Base.tail(index_labels(parent(x))) diff --git a/src/dimensions.jl b/src/dimensions.jl index 87d44ea86..ab662b447 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -1,11 +1,15 @@ _init_dimsmap(x) = _init_dimsmap(IndicesInfo(x)) -function _init_dimsmap(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims} +function _init_dimsmap(info::IndicesInfo{<:Any,pdims,cdims}) where {pdims,cdims} ntuple(i -> static(getfield(pdims, i)), length(pdims)), ntuple(i -> static(getfield(cdims, i)), length(pdims)) end +parentdims(::IndicesInfo{<:Any,pdims}) where {pdims} = pdims + +childdims(::IndicesInfo{<:Any,<:Any,cdims}) where {cdims} = cdims + """ to_parent_dims(::Type{T}) -> Tuple{Vararg{Union{StaticInt,Tuple{Vararg{StaticInt}}}}} diff --git a/test/axes.jl b/test/axes.jl index 6ef6b9c24..0918f9eab 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -116,19 +116,19 @@ end @test @inferred(ArrayInterface.index_labels(colors)) == (range(-10, 10, length=100),) @test @inferred(ArrayInterface.index_labels(caxis)) == (range(-10, 10, length=100),) - @test ArrayInterface.index_labels(view(colors, :, :), 2) == keys(static(1):static(1)) + @test ArrayInterface.index_labels(view(colors, :, :), 2) === nothing @test @inferred(ArrayInterface.index_labels(LinearIndices((caxis,)))) == (range(-10, 10, length=100),) @test @inferred(ArrayInterface.index_labels(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) @test @inferred(ArrayInterface.index_labels(colormat')) == (range(-10, 10, length=100), (:R, :G, :B)) @test @inferred(ArrayInterface.index_labels(cmat_view1)) == ((:R, :G, :B),) @test @inferred((ArrayInterface.index_labels(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) - @test @inferred((ArrayInterface.index_labels(view(colormat, 1, :)'))) == (keys(static(1):static(1)), range(-10, 10, length=100)) + @test @inferred((ArrayInterface.index_labels(view(colormat, 1, :)'))) == (nothing, range(-10, 10, length=100)) # can't infer this b/c tuple is being indexed by range @test ArrayInterface.index_labels(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) @test @inferred(ArrayInterface.index_labels(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) - @test @inferred(ArrayInterface.index_labels(reinterpret(Int8, absym_abstr))) == (keys(Base.OneTo(16)), ["a", "b"]) - @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int8, absym_abstr))) == (keys(static(1):static(8)), [:a, :b], ["a", "b"]) + @test @inferred(ArrayInterface.index_labels(reinterpret(Int8, absym_abstr))) == (nothing, ["a", "b"]) + @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int8, absym_abstr))) == (nothing, [:a, :b], ["a", "b"]) @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int64, LabelledArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) @test @inferred(ArrayInterface.index_labels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) From 7231dbe9b1b9773872cbc102dafd00e468fcaa77 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 29 Aug 2022 12:41:30 -0400 Subject: [PATCH 12/16] test has_index_labels --- src/axes.jl | 13 +++++++++++-- test/axes.jl | 6 ++++++ test/setup.jl | 1 + 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/axes.jl b/src/axes.jl index 0adda535f..c71743d2a 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -256,7 +256,7 @@ Returns `true` if `x` has has any index labels. If [`index_labels`](@ref) return See also: [`index_labels`](@ref) """ -has_index_labels(x) = is_forwarding_wrapper(x) ? has_index_labels(parent(x)) : false +has_index_labels(x) = is_forwarding_wrapper(x) ? has_index_labels(buffer(x)) : false function has_index_labels(x::Union{Base.NonReshapedReinterpretArray,Transpose,Adjoint,PermutedDimsArray,Symmetric,Hermitian}) has_index_labels(parent(x)) end @@ -269,8 +269,17 @@ function has_index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} end end function has_index_labels(x::SubArray) - has_index_labels(parent(x)) || any(has_index_labels, x.indices) + if has_index_labels(parent(x)) + return true + else + inds = x.indices + for i in 1:nfields(inds) + has_index_labels(getfield(inds, i)) && return true + end + return false + end end +has_index_labels(x::LazyAxis) = index_labels(x) !== (nothing,) """ index_labels(x) diff --git a/test/axes.jl b/test/axes.jl index 0918f9eab..dac3a7d71 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -133,6 +133,12 @@ end @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) @test @inferred(ArrayInterface.index_labels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) + @test ArrayInterface.has_index_labels(colors) + @test ArrayInterface.has_index_labels(caxis) + @test ArrayInterface.has_index_labels(colormat) + @test ArrayInterface.has_index_labels(cmat_view1) + @test !ArrayInterface.has_index_labels(view(colors, :, :)) + @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.IndexLabel(-9.595959595959595))) == colormat[:, 3] diff --git a/test/setup.jl b/test/setup.jl index 6c2b33128..6678cbb49 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -44,6 +44,7 @@ ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true Base.parent(x::LabelledArray) = getfield(x, :parent) ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P ArrayInterface.index_labels(x::LabelledArray) = getfield(x, :labels) +ArrayInterface.has_index_labels(::LabelledArray) = true # Dummy array type with undetermined contiguity properties struct DummyZeros{T,N} <: AbstractArray{T,N} From 268576fd035522383e8670c06d188b772cb0d018 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 29 Aug 2022 12:47:25 -0400 Subject: [PATCH 13/16] Use `index_labels` as fallback for `has_index_labels` --- src/axes.jl | 5 +++-- test/setup.jl | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axes.jl b/src/axes.jl index c71743d2a..129f17581 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -256,7 +256,7 @@ Returns `true` if `x` has has any index labels. If [`index_labels`](@ref) return See also: [`index_labels`](@ref) """ -has_index_labels(x) = is_forwarding_wrapper(x) ? has_index_labels(buffer(x)) : false +has_index_labels(x) = _any_labels(index_labels(x)) function has_index_labels(x::Union{Base.NonReshapedReinterpretArray,Transpose,Adjoint,PermutedDimsArray,Symmetric,Hermitian}) has_index_labels(parent(x)) end @@ -279,7 +279,8 @@ function has_index_labels(x::SubArray) return false end end -has_index_labels(x::LazyAxis) = index_labels(x) !== (nothing,) +_any_labels(@nospecialize labels::Tuple{Vararg{Nothing}}) = false +_any_labels(@nospecialize labels::Tuple{Vararg{Any}}) = true """ index_labels(x) diff --git a/test/setup.jl b/test/setup.jl index 6678cbb49..6c2b33128 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -44,7 +44,6 @@ ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true Base.parent(x::LabelledArray) = getfield(x, :parent) ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P ArrayInterface.index_labels(x::LabelledArray) = getfield(x, :labels) -ArrayInterface.has_index_labels(::LabelledArray) = true # Dummy array type with undetermined contiguity properties struct DummyZeros{T,N} <: AbstractArray{T,N} From d7c0a27f00baa6485a92aa12a83ce4a83df850c7 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 29 Sep 2022 04:54:56 -0400 Subject: [PATCH 14/16] Simplify initial index labels PR * removed support for nested traversal of `index_labels`. Support for this can be added later but it required making a lot of decisions about managing labels that up front that would be better addressed through iterative PRs * removed `has_index_labels`. There are some odd corner cases for this one. Particularly for `SubArrays` where the presence of labels in the parent don't always have a clear way of propagating forward. Again, we can address this one but it will take some decisions about how labels are propagated. * `UnlabelledIndices` and `LabelledIndices` are types that provide a more clear structure to what a label is and how they are accessed. --- .../src/ArrayInterfaceCore.jl | 61 +++++++++++ src/ArrayInterface.jl | 4 +- src/axes.jl | 102 +----------------- src/dimensions.jl | 4 - src/indexing.jl | 29 ++--- test/axes.jl | 53 +++------ test/setup.jl | 2 +- 7 files changed, 100 insertions(+), 155 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index 2074e1f26..bcdd82f9d 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -677,6 +677,67 @@ struct IndexLabel{L} <: ArrayIndex{1} label::L end +""" + UnlabelledIndices(indices) + +A set of indices that explicitly do not have any labels and cannot be accessed with an +[`IndexLabel`](@ref). +""" +struct UnlabelledIndices{I<:AbstractUnitRange{Int}} <: AbstractUnitRange{Int} + indices::I +end + +""" + LabelledIndices(labels) + +A subtype of `AbstractUnitRange{Int}` whose associeated with labels (`labels`). +`eachindex(labels)` are the indices for `LabelledIndices`. +""" +struct LabelledIndices{L<:AbstractVector} <: AbstractUnitRange{Int} + labels::L +end +# don't nest instances of `LabelledIndices` +LabelledIndices(labels::LabelledIndices) = labels + +Base.parent(x::UnlabelledIndices) = getfield(x, :indices) +Base.parent(x::LabelledIndices) = getfield(x, :labels) + +parent_type(@nospecialize T::Type{<:UnlabelledIndices}) = fieldtype(T, :indices) +parent_type(@nospecialize T::Type{<:LabelledIndices}) = fieldtype(T, :labels) + +is_forwarding_wrapper(@nospecialize T::Type{<:Union{UnlabelledIndices,LabelledIndices}}) = true + +Base.size(x::Union{UnlabelledIndices,LabelledIndices}) = size(parent(x)) +Base.axes(x::Union{UnlabelledIndices,LabelledIndices}) = axes(parent(x)) + +Base.first(x::UnlabelledIndices) = first(parent(x)) +Base.first(x::LabelledIndices) = firstindex(parent(x)) + +Base.last(x::UnlabelledIndices) = last(parent(x)) +Base.last(x::LabelledIndices) = lastindex(parent(x)) + +""" + getlabels(x, idx) + +Given a collection of labelled indices (`x`), the subset of lablled indices corresponding +to the index `idx` are returned. +""" +Base.@propagate_inbounds function getlabels(x::LabelledIndices, idx::I) where {I} + ndims_shape(I) === 0 ? parent(x)[idx] : LabelledIndices(parent(x)[idx]) +end +function getlabels(x::UnlabelledIndices, idx::I) where {I} + @boundscheck checkbounds(parent(x), idx) + ndims_shape(I) === 0 ? nothing : UnlabelledIndices(eachindex(idx)) +end + +""" + setlabels!(x, idx, vals) + +Given a collection of labelled indices (`x`), the subset of lablled indices corresponding +to the index `idx` are returned. +""" +Base.@propagate_inbounds setlabels!(x::LabelledIndices, i, v) = setindex!(parent(x), i, v) + _cartesian_index(i::Tuple{Vararg{Int}}) = CartesianIndex(i) _cartesian_index(::Any) = nothing diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 932e58341..a9639b63c 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -6,8 +6,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff issingular, isstructured, matrix_colors, restructure, lu_instance, safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type, ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims, - parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel, defines_strides, - stride_preserving_index + parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel, + LabelledIndices, getlabels, UnlabelledIndices, defines_strides, stride_preserving_index # ArrayIndex subtypes and methods import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex diff --git a/src/axes.jl b/src/axes.jl index 129f17581..1672ee392 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -248,112 +248,18 @@ lazy_axes(x, ::Colon) = LazyAxis{:}(x) lazy_axes(x, ::StaticInt{dim}) where {dim} = ndims(x) < dim ? SOneTo{1}() : LazyAxis{dim}(x) @inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims) -""" - has_index_labels(x) -> Bool - -Returns `true` if `x` has has any index labels. If [`index_labels`](@ref) returns a tuple of -`nothing`, this will be `false`. - -See also: [`index_labels`](@ref) -""" -has_index_labels(x) = _any_labels(index_labels(x)) -function has_index_labels(x::Union{Base.NonReshapedReinterpretArray,Transpose,Adjoint,PermutedDimsArray,Symmetric,Hermitian}) - has_index_labels(parent(x)) -end -function has_index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} - if has_index_labels(parent(x)) - true - else - size1 = div(sizeof(S), sizeof(T)) - size1 > 1 && size1 === fieldcount(S) - end -end -function has_index_labels(x::SubArray) - if has_index_labels(parent(x)) - return true - else - inds = x.indices - for i in 1:nfields(inds) - has_index_labels(getfield(inds, i)) && return true - end - return false - end -end -_any_labels(@nospecialize labels::Tuple{Vararg{Nothing}}) = false -_any_labels(@nospecialize labels::Tuple{Vararg{Any}}) = true - """ index_labels(x) index_labels(x, dim) Returns a tuple of labels assigned to each axis or a collection of labels corresponding to -each index along `dim` of `x`. Default is to simply return `nothing`. - -See also: [`has_index_labels`](@ref) +each index along `dim` of `x`. Default is to return `UnlabelledIndices(axes(x, dim))`. """ index_labels(x, dim) = index_labels(x, to_dims(x, dim)) -index_labels(@nospecialize x::Number) = () @inline function index_labels(x, dim::CanonicalInt) - dim > ndims(x) ? nothing : getfield(index_labels(x), Int(dim)) + dim > ndims(x) ? UnlabelledIndices(SOneTo(1)) : getfield(index_labels(x), Int(dim)) end @inline function index_labels(x) - if is_forwarding_wrapper(x) - index_labels(buffer(x)) - else - ntuple(Returns(nothing), Val{ndims(x)}()) - end -end -function index_labels(x::Union{MatAdjTrans,PermutedDimsArray}) - map(GetIndex{false}(index_labels(parent(x))), to_parent_dims(x)) -end -index_labels(x::VecAdjTrans) = (nothing, getfield(index_labels(parent(x)), 1)) -function index_labels(x::SubArray) - labels = index_labels(parent(x)) - inds = x.indices - info = IndicesInfo(x) - pdims = parentdims(info) - cdims = childdims(info) - flatten_tuples(ntuple(Val{nfields(pdims)}()) do i - pdim_i = getfield(pdims, i) - cdim_i = getfield(cdims, i) - index = getfield(inds, i) - if pdim_i isa Tuple || cdim_i isa Tuple # no direct mapping to parent axes - index_labels(index) - elseif cdim_i === 0 # integer indexing drops axes - () - elseif pdim_i === 0 # trailing dimension - nothing - elseif index isa Base.Slice # index into labels where there is direct mapping to parent axis - (getfield(labels, pdim_i),) - else - labels_i = getfield(labels, pdim_i) - labels_i === nothing ? index_labels(index) : (@inbounds(labels_i[index]),) - end - end) -end -index_labels(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ index_labels, x.indices) -index_labels(x::Union{Symmetric,Hermitian}) = index_labels(parent(x)) -index_labels(@nospecialize(x::LazyAxis{:})) = (nothing,) -index_labels(x::LazyAxis{N}) where {N} = (getfield(index_labels(getfield(x, :parent)), N),) -@inline @inline function index_labels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S} - if sizeof(T) === sizeof(S) - return index_labels(parent(x)) - else - return (nothing, Base.tail(index_labels(parent(x)))...) - end -end -function index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} - _reinterpret_index_labels(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x) -end -@inline function _reinterpreted_fieldnames(@nospecialize T::Type{<:Base.ReshapedReinterpretArray}) - S = eltype(parent_type(T)) - isstructtype(S) ? fieldnames(S) : () -end -function _reinterpret_index_labels(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N} - __reinterpret_index_labels(s, _reinterpreted_fieldnames(typeof(x)), index_labels(parent(x))) -end -@inline function __reinterpret_index_labels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M} - N === M ? (fields, ks...,) : (nothing, ks...,) + is_forwarding_wrapper(x) ? index_labels(buffer(x)) : map(UnlabelledIndices, axes(x)) end -_reinterpret_index_labels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = index_labels(parent(x)) -_reinterpret_index_labels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = Base.tail(index_labels(parent(x))) +index_labels(axis::LazyAxis{N}) where {N} = (index_labels(getfield(axis, :parent), N),) diff --git a/src/dimensions.jl b/src/dimensions.jl index f9d1b9363..d562be58f 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -8,10 +8,6 @@ function _init_dimsmap(@nospecialize info::IndicesInfo) ntuple(i -> static(getfield(cdims, i)), length(pdims)) end -parentdims(::IndicesInfo{<:Any,pdims}) where {pdims} = pdims - -childdims(::IndicesInfo{<:Any,<:Any,cdims}) where {cdims} = cdims - """ to_parent_dims(::Type{T}) -> Tuple{Vararg{Union{StaticInt,Tuple{Vararg{StaticInt}}}}} diff --git a/src/indexing.jl b/src/indexing.jl index 0d5270816..88029b49b 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -170,9 +170,6 @@ end @inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}}) max(canonicalize(i.x), static_first(x)):static_last(x) end -@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel}) - findall(i.f(i.x.label), first(index_labels(x))) -end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) end @@ -180,20 +177,26 @@ to_index(x, i::AbstractArray{<:Union{Base.BitInteger,StaticInt}}) = i to_index(x, @nospecialize(i::StaticInt)) = i to_index(x, i::Integer) = Int(i) @inline to_index(x, i) = to_index(IndexStyle(x), x, i) -# key indexing -function to_index(x, i::IndexLabel) - index = findfirst(==(getfield(i, :label)), first(index_labels(x))) +# label indexing +to_index(x, i::IndexLabel) = to_index(getfield(index_labels(x), 1), i) +function to_index(x::LabelledIndices, i::IndexLabel) + index = findfirst(==(getfield(i, :label)), parent(x)) # delay throwing bounds-error if we didn't find label - index === nothing ? offset1(x) - 1 : index -end -function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number}) - index = findfirst(==(i), getfield(index_labels(x), 1)) - index === nothing ? offset1(x) - 1 : index + index === nothing ? typemin(Int) : index end +to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number}) = to_index(x, IndexLabel(i)) # TODO there's probably a more efficient way of doing this +to_index(x, i::LabelledIndices) = to_index(getfield(index_labels(x), 1), i) +to_index(x::LabelledIndices, i::LabelledIndices) = findall(in(parent(i)), parent(x)) to_index(x, ks::AbstractArray{<:IndexLabel}) = [to_index(x, k) for k in ks] -function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}}) - [to_index(x, k) for k in ks] +function to_index(x, i::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}}) + to_index(x, LabelledIndices(i)) +end +@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel}) + to_index(getfield(index_labels(x), 1), i) +end +@inline function to_index(x::LabelledIndices, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel}) + findall(i.f(i.x.label), parent(x)) end # integer indexing diff --git a/test/axes.jl b/test/axes.jl index dac3a7d71..8a2b333e5 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -107,41 +107,20 @@ end @testset "index_labels" begin colors = LabelledArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)); - caxis = ArrayInterface.LazyAxis{1}(colors); - colormat = reinterpret(reshape, Float64, colors); - cmat_view1 = view(colormat, :, 4); - cmat_view2 = view(colormat, :, 4:7); - cmat_view3 = view(colormat, 2:3,:); - absym_abstr = LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],)); - - @test @inferred(ArrayInterface.index_labels(colors)) == (range(-10, 10, length=100),) - @test @inferred(ArrayInterface.index_labels(caxis)) == (range(-10, 10, length=100),) - @test ArrayInterface.index_labels(view(colors, :, :), 2) === nothing - @test @inferred(ArrayInterface.index_labels(LinearIndices((caxis,)))) == (range(-10, 10, length=100),) - @test @inferred(ArrayInterface.index_labels(colormat)) == ((:R, :G, :B), range(-10, 10, length=100)) - @test @inferred(ArrayInterface.index_labels(colormat')) == (range(-10, 10, length=100), (:R, :G, :B)) - @test @inferred(ArrayInterface.index_labels(cmat_view1)) == ((:R, :G, :B),) - @test @inferred((ArrayInterface.index_labels(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787) - @test @inferred((ArrayInterface.index_labels(view(colormat, 1, :)'))) == (nothing, range(-10, 10, length=100)) - # can't infer this b/c tuple is being indexed by range - @test ArrayInterface.index_labels(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0) - @test @inferred(ArrayInterface.index_labels(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595) - - @test @inferred(ArrayInterface.index_labels(reinterpret(Int8, absym_abstr))) == (nothing, ["a", "b"]) - @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int8, absym_abstr))) == (nothing, [:a, :b], ["a", "b"]) - @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int64, LabelledArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],) - @test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],) - @test @inferred(ArrayInterface.index_labels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],) - - @test ArrayInterface.has_index_labels(colors) - @test ArrayInterface.has_index_labels(caxis) - @test ArrayInterface.has_index_labels(colormat) - @test ArrayInterface.has_index_labels(cmat_view1) - @test !ArrayInterface.has_index_labels(view(colors, :, :)) - - @test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :] - @test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1] - @test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.IndexLabel(-9.595959595959595))) == colormat[:, 3] - @test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.IndexLabel(-9.595959595959595)))) == colormat[:, 1:3] - @test @inferred(ArrayInterface.getindex(absym_abstr, :, ["a"])) == absym_abstr[:,[1]] + + @test parent(@inferred(ArrayInterface.index_labels(colors))[1]) == range(-10, 10, length=100) + @test parent(ArrayInterface.index_labels(colors, 1)) == range(-10, 10, length=100) + @test ArrayInterface.index_labels(colors, 2) == ArrayInterface.UnlabelledIndices(axes(colors, 2)) + @test ArrayInterface.index_labels(parent(colors)) == map(ArrayInterface.UnlabelledIndices, axes(colors)) + + label = ArrayInterface.IndexLabel(ArrayInterface.getlabels(ArrayInterface.index_labels(colors)[1], 3)) + @test @inferred(ArrayInterface.getindex(colors, label)) == colors[3] + @test @inferred(ArrayInterface.getindex(colors, <=(label))) == colors[1:3] end + +#= +ArrayInterface.to_indices(colors, (label,)) +axis = ArrayInterface.lazy_axes(colors)[1] +labels = ArrayInterface.index_labels(axis)[1] +ArrayInterface.to_index(labels, label) +=# diff --git a/test/setup.jl b/test/setup.jl index 6c2b33128..6e9c99f58 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -43,7 +43,7 @@ end ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true Base.parent(x::LabelledArray) = getfield(x, :parent) ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P -ArrayInterface.index_labels(x::LabelledArray) = getfield(x, :labels) +ArrayInterface.index_labels(x::LabelledArray) = map(ArrayInterface.LabelledIndices, getfield(x, :labels)) # Dummy array type with undetermined contiguity properties struct DummyZeros{T,N} <: AbstractArray{T,N} From 809839b6f033cd994ec03effc1de7e18bceeb550 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 29 Sep 2022 05:03:40 -0400 Subject: [PATCH 15/16] Simplify initial index labels PR * removed support for nested traversal of `index_labels`. Support for this can be added later but it required making a lot of decisions about managing labels that up front that would be better addressed through iterative PRs * removed `has_index_labels`. There are some odd corner cases for this one. Particularly for `SubArrays` where the presence of labels in the parent don't always have a clear way of propagating forward. Again, we can address this one but it will take some decisions about how labels are propagated. * `UnlabelledIndices` and `LabelledIndices` are types that provide a more clear structure to what a label is and how they are accessed. For now access only is supported through `getlabels` and `setlabels` --- lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index bcdd82f9d..e286446fb 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -719,8 +719,8 @@ Base.last(x::LabelledIndices) = lastindex(parent(x)) """ getlabels(x, idx) -Given a collection of labelled indices (`x`), the subset of lablled indices corresponding -to the index `idx` are returned. +Given a collection of labelled indices (`x`), returns the subset of lablled indices +corresponding to the index `idx` are returned. """ Base.@propagate_inbounds function getlabels(x::LabelledIndices, idx::I) where {I} ndims_shape(I) === 0 ? parent(x)[idx] : LabelledIndices(parent(x)[idx]) @@ -733,8 +733,7 @@ end """ setlabels!(x, idx, vals) -Given a collection of labelled indices (`x`), the subset of lablled indices corresponding -to the index `idx` are returned. +Sets new labels `vals` at the indices `idx` for the collection of labelled indices `x`. """ Base.@propagate_inbounds setlabels!(x::LabelledIndices, i, v) = setindex!(parent(x), i, v) From 0ad5cd603dc2fa32b3435f2bfa204df9336c11f8 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 29 Sep 2022 05:04:59 -0400 Subject: [PATCH 16/16] Simplify initial index labels PR * removed support for nested traversal of `index_labels`. Support for this can be added later but it required making a lot of decisions about managing labels that up front that would be better addressed through iterative PRs * removed `has_index_labels`. There are some odd corner cases for this one. Particularly for `SubArrays` where the presence of labels in the parent don't always have a clear way of propagating forward. Again, we can address this one but it will take some decisions about how labels are propagated. * `UnlabelledIndices` and `LabelledIndices` are types that provide a more clear structure to what a label is and how they are accessed. For now access only is supported through `getlabels` and `setlabels` --- src/ArrayInterface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index a9639b63c..2bae3d9ca 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -7,7 +7,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type, ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims, parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel, - LabelledIndices, getlabels, UnlabelledIndices, defines_strides, stride_preserving_index + LabelledIndices, getlabels, setlabels!, UnlabelledIndices, defines_strides, stride_preserving_index # ArrayIndex subtypes and methods import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex