From 0b8e1fcbea0e890d64cf2b8d646c04664cad0fbb Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 7 Apr 2022 06:16:09 -0400 Subject: [PATCH 1/2] axis_keys --- src/axes.jl | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/axes.jl b/src/axes.jl index 9fdb4f787..569889d6f 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -77,8 +77,8 @@ function _non_reshaped_axis_type(::Type{A}, d::StaticInt{D}) where {A,D} end end -# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is inended to decreases -# breaking changes for this adopting this method to situations where they clearly benefit +# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is intended to decrease +# breaking changes for adapting this method to situations where there's clearly benefit # from the propagation of static axes. This creates the somewhat awkward situation of # conditionally typed (but inferrable) axes. It also means we can't depend on constant # propagation to preserve statically sized axes. This should probably be addressed before @@ -293,3 +293,37 @@ lazy_axes(x::CartesianIndices) = axes(x) @inline lazy_axes(x::VecAdjTrans) = (SOneTo{1}(), first(lazy_axes(parent(x)))) @inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(x)) +""" + axes_keys(x) + axes_keys(x, dim) + + +Returns a tuple of keys assigned to each axis or the axis at dimension `dim` for `x`. Default +is to simply return `map(keys, axes(x))`. +""" +@inline axes_keys(A) = map(keys, axes(A)) +axes_keys(A::PermutedDimsArray) = permute(axes_keys(parent(A)), to_parent_dims(A)) +axes_keys(A::MatAdjTrans) = permute(axes_keys(parent(A)), to_parent_dims(A)) +axes_keys(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1)) +function axes_axes(A::SubArray{T,N}) where {T,N} + pdims = to_parent_dims(A) + ntuple(dim -> axes_keys(parent(A), pdims[dim])[A.indices[dim], Val(N)) +end + +axes_keys(A, dim) = axes_keys(A, to_dims(A, dim)) +@inline function axes_keys(A, dim::Int) + if dim > ndims(A) + return OneTo(1) + else + return getfield(axes_keys(A), dim) + end +end +@inline function axes_keys(A, ::StaticInt{dim}) where {dim} + if dim > ndims(A) + return SOneTo{1}() + else + return getfield(axes(A), dim) + end +end +axes_keys(axis::LazyAxis{N,P}) where {N,P} = axes_keys(getfield(x, :parent), static(N)) + From 672554dc6c68be3255367f8ee2e58f26efe266ad Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 7 Apr 2022 11:13:34 -0400 Subject: [PATCH 2/2] Fix typo :/ --- src/axes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axes.jl b/src/axes.jl index 569889d6f..8b8499875 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -307,7 +307,7 @@ axes_keys(A::MatAdjTrans) = permute(axes_keys(parent(A)), to_parent_dims(A)) axes_keys(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1)) function axes_axes(A::SubArray{T,N}) where {T,N} pdims = to_parent_dims(A) - ntuple(dim -> axes_keys(parent(A), pdims[dim])[A.indices[dim], Val(N)) + ntuple(dim -> axes_keys(parent(A), pdims[dim])[A.indices[dim]], Val(N)) end axes_keys(A, dim) = axes_keys(A, to_dims(A, dim))