Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

axis_keys #260

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ function _non_reshaped_axis_type(::Type{A}, d::StaticInt{D}) where {A,D}
end
end

# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is inended to decreases
# breaking changes for this adopting this method to situations where they clearly benefit
# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is intended to decrease
# breaking changes for adapting this method to situations where there's clearly benefit
# from the propagation of static axes. This creates the somewhat awkward situation of
# conditionally typed (but inferrable) axes. It also means we can't depend on constant
# propagation to preserve statically sized axes. This should probably be addressed before
Expand Down Expand Up @@ -293,3 +293,37 @@ lazy_axes(x::CartesianIndices) = axes(x)
@inline lazy_axes(x::VecAdjTrans) = (SOneTo{1}(), first(lazy_axes(parent(x))))
@inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(x))

"""
axes_keys(x)
axes_keys(x, dim)


Returns a tuple of keys assigned to each axis or the axis at dimension `dim` for `x`. Default
is to simply return `map(keys, axes(x))`.
"""
@inline axes_keys(A) = map(keys, axes(A))
axes_keys(A::PermutedDimsArray) = permute(axes_keys(parent(A)), to_parent_dims(A))
axes_keys(A::MatAdjTrans) = permute(axes_keys(parent(A)), to_parent_dims(A))
axes_keys(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1))
function axes_axes(A::SubArray{T,N}) where {T,N}
pdims = to_parent_dims(A)
ntuple(dim -> axes_keys(parent(A), pdims[dim])[A.indices[dim]], Val(N))
end

axes_keys(A, dim) = axes_keys(A, to_dims(A, dim))
@inline function axes_keys(A, dim::Int)
if dim > ndims(A)
return OneTo(1)
else
return getfield(axes_keys(A), dim)
end
end
@inline function axes_keys(A, ::StaticInt{dim}) where {dim}
if dim > ndims(A)
return SOneTo{1}()
else
return getfield(axes(A), dim)
end
end
axes_keys(axis::LazyAxis{N,P}) where {N,P} = axes_keys(getfield(x, :parent), static(N))