Create staticint.jl #44538
base: master
Conversation
Avoid generating new code when `StaticInt` will ultimately be lowered to `Int` anyway.
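For context, a minimal sketch of the kind of type under discussion (names and details here are illustrative, not necessarily the PR's exact code): an integer carried in the type domain that erases to an ordinary `Int` at run time.

```julia
# Illustrative sketch only; the PR's actual definitions may differ.
struct StaticInt{N} <: Integer
    function StaticInt{N}() where {N}
        N isa Int || throw(ArgumentError("type parameter must be an Int"))
        new{N}()
    end
end
StaticInt(n::Int) = StaticInt{n}()
StaticInt(n::StaticInt) = n

# The value lives in the type parameter, so recovering it is free and
# arithmetic on two StaticInts can fold away at compile time.
Base.Int(::StaticInt{N}) where {N} = N
Base.:(+)(x::StaticInt, y::StaticInt) = StaticInt(Int(x) + Int(y))

Int(StaticInt(2) + StaticInt(3))  # → 5
```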
If there is nothing in Base that uses this, why would it not be a package?
There currently isn't, but the idea is that there should be. For example, if we had support for this in ranges … EDIT: …
Could we define …
I know of four uses for … It would be so nice if we could use …
Xref #41853, about adding static …
This is a good concrete example of where it's nice to have this in Base. Now take the approach of … Pretty much anywhere there's a new array created or subsequent calls depend on …
This is pretty close to how I feel about it. Hopefully this comment makes it more clear what we would need from the compiler for this sort of thing to be unnecessary.
I'm not sure how the folks at StaticArrays are doing it, but I think it is hard to improve on the constant propagation going on in this specific case, where subsequent calls depend on the statically known size:
a = @SVector rand(100);
b = @SVector rand(100);
f(a, b) = sum(eachindex(a, b)) + sum(sum.(axes(a)))
@code_llvm f(a, b)
; @ none:1 within `f`
define i64 @julia_f_11329([1 x [100 x double]]* nocapture nonnull readonly align 8 dereferenceable(800) %0, [1 x [100 x double]]* nocapture nonnull readonly align 8 dereferenceable(800) %1) #0 {
top:
ret i64 10100
}
But …
That's because …
How would this work? Would we parse …
I'm not saying that example is solved with … My point was just that if the same code is generated without the size in the type domain then we wouldn't need …
In theory the compiler could add a lattice element for an array with known size, and propagate constants to know the size of …
What if we had … EDIT: I'm not sure about this idea. One problem is that for concrete types T we'd be trying to subtype a concrete type.
I figured that this was the case. In addition to making it so more methods actually can propagate static info, there might be some more intelligent codegen the compiler can be aware of if there's a dedicated type for this. For example, perhaps …
Newest changes have the basic types and traits that we've found useful over in ArrayInterface so far. These don't necessarily need to be a part of this PR or have the current names. I just thought it would be helpful to provide a little more context. Some places where this is helpful to have in Base are …
Triage discussed this but we don't have any hard conclusions yet.
Really appreciate the update. Just let me know if there's anything else that's needed from me to help solve this issue (whether it be …).
I'd love to have static ints (or even static numbers in general) in Base as well. We have several use cases in which static ints are not confined to a function but become part of types that get passed around (so that "staticness" can be inferred in a different place of the code).
As I understand it, constant propagation is meant to carry small amounts of compile-time-known information small distances. Take the simple example of indexing an … There are lots of examples where things get screwy internally because Base doesn't know how to handle statically sized indices. I'm sure there are other benefits on the compiler/interpreter side of things, but I'm not too savvy about that kind of stuff.
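A tiny illustration of the point above (hypothetical functions, sketched under the assumption of current inference behavior): a literal length can be folded by constant propagation, but the same value loses its staticness once it passes through an `Int` field.

```julia
# Hypothetical example; the names are made up for illustration.
struct Holder
    n::Int  # "staticness" is erased at this boundary
end

make_literal() = ntuple(identity, 3)             # length is a propagated constant
make_dynamic(h::Holder) = ntuple(identity, h.n)  # length is a runtime Int

# With a literal, inference can conclude a concrete NTuple{3, Int};
# through the field, only a non-concrete Tuple{Vararg{Int}} is inferable.
only(Base.return_types(make_literal, Tuple{}))
only(Base.return_types(make_dynamic, Tuple{Holder}))
```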
I think one prominent example is the following (as discussed e.g. at JuliaArrays/StaticArrays.jl#807): in order to encode the size … This would allow StaticArrays.jl to drop some of the awkward juggling between … (Of course, this would be quite breaking on the StaticArrays.jl side...)
There are a handful of things in …
@JeffBezanson, were you thinking of something like this?
@inline function ntuple(f::F, n::Integer) where F
(n >= 0) || throw(ArgumentError(string("tuple length should be ≥ 0, got ", n)))
if @generated
if n <: StaticInt
quote
Base.Cartesian.@ntuple $(n.parameters[1]) i->f(i)
end
else
:(_ntuple(f, n))
end
else
_ntuple(f, n)
end
end
@inline function _ntuple(f::F, n) where F
# marked inline since this benefits from constant propagation of `n`
t = n == 0 ? () :
n == 1 ? (f(1),) :
n == 2 ? (f(1), f(2)) :
n == 3 ? (f(1), f(2), f(3)) :
n == 4 ? (f(1), f(2), f(3), f(4)) :
n == 5 ? (f(1), f(2), f(3), f(4), f(5)) :
n == 6 ? (f(1), f(2), f(3), f(4), f(5), f(6)) :
n == 7 ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7)) :
n == 8 ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7), f(8)) :
n == 9 ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7), f(8), f(9)) :
n == 10 ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7), f(8), f(9), f(10)) :
__ntuple(f, n)
return t
end
@noinline __ntuple(f::F, n) where F = ([f(i) for i = 1:n]...,)
The new …
We were able to work around it a bit by creating StaticArraysCore and then ArrayInterfaceStaticArraysCore so that there could be an interface which avoids the static behavior to allow for it to only exist much later, which cut down downstream load times by about 4 seconds. Thus with diligent programming we can work around this issue. That said, I'm not sure how many groups are going to appropriately use ArrayInterfaceStaticArraysCore, StaticArraysCore, and GPUArraysCore, and if not done correctly you get bad load times and lots of invalidations. To me, the existence of ArrayInterfaceStaticArraysCore says something is wrong here, but it appears not to be an opinion shared by some others in this thread so 🤷
I'm not completely convinced that …
Now that we're working on 1.9, is this worth discussing further?
I hope so, and also that …
Before revisiting this I think it might be helpful to share something I put together in a private discussion with Tim Holy and Chris Elrod last month. In the code below I wrote custom stride and size methods (`get_size` and `get_strides`). We can rely on constant propagation if every intervening method is extremely careful to preserve that information. However, it often fails for generic methods:
julia> function first_of_length(x, y)
StaticInt(getfield(get_size(first(x, length(y))), 1))
end
julia> @inferred first_of_length(1:20, static(1):static(10))
ERROR: return type StaticInt{10} does not match inferred return type StaticInt
This can still be resolved if we have very carefully defined methods for specific situations.
julia> function first_of_length(x::Base.OneTo, ::ArrayInterface.SUnitRange{1,N}) where {N}
StaticInt(Base.OneTo{Int}(N).stop)
end
julia> @inferred first_of_length(Base.OneTo(20), static(1):static(10))
static(10)
I think the takeaway here is that we can rely on constant propagation for passing around persistent compile-time-known parameters if we are willing to sacrifice generic programming in cases where it's important. If this is what we want to do then we should probably add more public API hooks for making this possible in certain methods without completely rewriting chunks of Base. If we want to support more generic methods that preserve this info, we need to decide if this should be supported through types (e.g., …). If we were to rely on non-type-driven solutions we'd need something that can be used to variably mark things as statically known across methods and as fields of types.
static sz = (3, 3, 3) # mark these values as statically known
rr = reshape(1:27, sz)
isstatic(size(rr)) == true # dims field of ReshapedArray is static here
If this is possible it would be preferable to making a bunch of parametric types to support this. However, before committing to something like that I'd like to know that it truly is the next step in the plan for handling this, so we aren't just spinning our wheels waiting for a solution that may never come.
# this can be plugged into https://github.com/JuliaArrays/ArrayInterface.jl/blob/bf07fa9cbdf482125ad44e8c1c4095111fe06206/test/indexing.jl#L253 to run the tests in "test/size.jl" and "test/stridelayout.jl"
struct MArray{size,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
parent::P
MArray(x::AbstractArray{T,N}) where {T,N} = new{size(x),T,N,typeof(x)}(x)
end
Base.parent(x::MArray) = x.parent
Base.size(::MArray{S}) where {S} = S
Base.getindex(x::MArray, inds...) = x.parent[inds...]
ArrayInterface.defines_strides(::Type{<:MArray}) = true
ArrayInterface.is_forwarding_wrapper(::Type{<:MArray}) = true
ArrayInterface.parent_type(::Type{<:MArray{<:Any,<:Any,<:Any,P}}) where {P} = P
#Base.strides(::MArray{size}) where {size} = ArrayInterface.size_to_strides(
get_size(a) = Base.size(a)
get_size(a, dim) = Base.size(a)[dim]
@inline function get_size(x::SubArray)
ArrayInterface.flatten_tuples(_get_size(parent(x), x.indices, ArrayInterface.IndicesInfo(x)))
end
function _get_size(p::AbstractArray, inds::Tuple, ::ArrayInterface.IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
ntuple(Val{nfields(pdims)}()) do i
pdim_i = getfield(pdims, i)
cdim_i = getfield(cdims, i)
if pdim_i isa Int && cdim_i isa Int
inds_i = getfield(inds, i)
if inds_i isa Base.Slice{Base.OneTo{Int}}
getfield(get_size(p), pdim_i)
else
get_size(inds_i)
end
else
get_size(getfield(inds, i))
end
end
end
@inline get_size(B::ArrayInterface.VecAdjTrans) = (1, getfield(get_size(parent(B)), 1))
function get_size(A::PermutedDimsArray{<:Any,N,perm}) where {N,perm}
psize = get_size(parent(A))
ntuple(i->getfield(psize, getfield(perm, i)), Val{N}())
end
function get_size(A::ArrayInterface.MatAdjTrans)
s1, s2 = get_size(parent(A))
(s2, s1)
end
@inline function get_size(a::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S}
psize = get_size(parent(a))
(div(getfield(psize, 1) * sizeof(S), sizeof(T)), Base.tail(psize)...)
end
@inline function get_size(a::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
psize = get_size(parent(a))
ss = sizeof(S)
ts = sizeof(T)
if ss === ts
return psize
elseif ss > ts
return (div(ss, ts), psize...,)
else
return Base.tail(psize)
end
end
get_size(A::AbstractRange) = (length(A),)
@inline get_strides(A) = ArrayInterface.size_to_strides(get_size(A), 1)
# `strides` for `Base.ReinterpretArray`
function get_strides(A::Base.ReinterpretArray{T,<:Any,S,<:AbstractArray{S},IsReshaped}) where {T,S,IsReshaped}
ArrayInterface._is_column_dense(parent(A)) && return ArrayInterface.size_to_strides(get_size(A), 1)
stp = get_strides(parent(A))
ET, ES = static(sizeof(T)), static(sizeof(S))
ET === ES && return stp
IsReshaped && ET < ES && return (One(), _reinterp_strides(stp, ET, ES)...)
first(stp) == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
if IsReshaped
# The wrapper tells us `A`'s parent has static size in dim1.
# We can make the next stride static if the following dim is still dense.
sr = ArrayInterface.stride_rank(parent(A))
dd = ArrayInterface.dense_dims(parent(A))
stp′ = _new_static(stp, sr, dd, ET ÷ ES)
return _reinterp_strides(Base.tail(stp′), ET, ES)
else
return (One(), _reinterp_strides(Base.tail(stp), ET, ES)...)
end
end
_new_static(P, _, _, _) = P # fallback; this should never be called, kept just in case
@generated function _new_static(p::P, ::SR, ::DD, ::StaticInt{S}) where {S,N,P<:NTuple{N,Union{Int,StaticInt}},SR<:NTuple{N,StaticInt},DD<:NTuple{N,StaticBool}}
sr = fieldtypes(SR)
j = findfirst(T -> T() == sr[1]()+1, sr)
if !isnothing(j) && !(fieldtype(P, j) <: StaticInt) && fieldtype(DD, j) === True
return :(tuple($((i == j ? :(static($S)) : :(p[$i]) for i in 1:N)...)))
else
return :(p)
end
end
@inline function _reinterp_strides(stp::Tuple, els::StaticInt, elp::StaticInt)
if elp % els == 0
N = elp ÷ els
return map(Base.Fix2(*, N), stp)
else
return map(stp) do i
d, r = divrem(elp * i, els)
iszero(r) || throw(ArgumentError("Parent's strides could not be exactly divided!"))
d
end
end
end
get_strides(@nospecialize x::AbstractRange) = (1,)
function get_strides(x::ArrayInterface.VecAdjTrans)
st = getfield(get_strides(parent(x)), 1)
return (st, st)
end
function get_strides(x::ArrayInterface.MatAdjTrans)
s1, s2 = get_strides(parent(x))
(s2, s1)
end
function get_strides(A::PermutedDimsArray{<:Any,N,perm}) where {N,perm}
psize = get_strides(parent(A))
ntuple(i->getfield(psize, getfield(perm, i)), Val{N}())
end
getmul(x::Tuple, y::Tuple, ::StaticInt{i}) where {i} = getfield(x, i) * getfield(y, i)
function get_strides(A::SubArray)
Static.eachop(getmul, ArrayInterface.to_parent_dims(typeof(A)), map(maybe_static_step, A.indices), get_strides(parent(A)))
end
maybe_static_step(x::AbstractRange) = ArrayInterface.static_step(x)
maybe_static_step(_) = nothing
static_get_size(x, ::StaticInt{dim}) where {dim} = StaticInt(getfield(get_size(x), dim))
static_get_stride(x, ::StaticInt{dim}) where {dim} = StaticInt(getfield(get_strides(x), dim))
Some time ago, I spent a month or more working with StaticInts, StaticFloats, and StaticBools to see how they might be slightly improved or better used. Unless one or more of them is put into Base, where they would benefit the many people who would never work with one of the Static*.jl packages (I think that is most devs who are not reading this), we leave ourselves an obligation to spend real time going forward beating back the issues discussed above.
@JeffBezanson, you mentioned back in June that triage was back on this. Is there some other solution in the works behind the scenes, or should I be updating this and keeping the discussion going?
@timholy, you mentioned getting this in before the feature freeze for 1.9 #44519 (comment). Is this still the plan?
I really think we need to have some of the proponents of adding this to base join the triage call some time and make the case for it—and probably more importantly to answer questions in real time. We've discussed it numerous times and never been convinced. This implementation in response to @JeffBezanson's request is pretty unsatisfying since it just takes what was previously two separate method bodies and combines them by adding a branch on the type check. What would be compelling is if there's an actual reduction in the amount of code and logic required to handle both static and non-static cases with a single code base.
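To make the objection concrete, here is a schematic of the pattern in question (all names here are hypothetical stand-ins, not the PR's code): two method bodies merged behind a type-check branch, which saves no logic compared with keeping the methods separate.

```julia
# Hypothetical sketch; `StaticInt`, `static_impl`, and `dynamic_impl`
# stand in for the real code under discussion.
struct StaticInt{N} <: Integer end
Base.Int(::StaticInt{N}) where {N} = N

dynamic_impl(n::Int) = ntuple(identity, n)
static_impl(::Val{N}) where {N} = ntuple(identity, Val(N))

# Two separate methods, dispatched on the argument type:
f(n::Int) = dynamic_impl(n)
f(::StaticInt{N}) where {N} = static_impl(Val(N))

# The "merged" version still contains both bodies, just behind a branch:
function f_merged(n::Integer)
    if n isa StaticInt
        static_impl(Val(Int(n)))
    else
        dynamic_impl(Int(n))
    end
end
```

Whether this counts as "a single code base" is exactly the point of contention: the branch merely concatenates the two implementations rather than unifying them.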
Since …
I think Jeff was referring to the fact that triage has talked about this lots of times, but always without the people involved in wanting to get this merged 🤷 Next triage meeting (other than today) is on the 13th at 20:15 UTC, I think - someone from this PR joining would be great!
I think the implementation of …
julia> using StaticNumbers
julia> my_ntuple(f, n) = Tuple(f(i) for i in Base.OneTo(n));
julia> let n = 12, ns = static(n), nv = Val(n)
f(x) = sin(x + 1)
@btime my_ntuple($f, $ns)
@btime ntuple($f, $nv)
end;
1.257 ns (0 allocations: 0 bytes)
1.258 ns (0 allocations: 0 bytes)
julia> let n = 120, ns = static(n), nv = Val(n)
f(x) = sin(x + 1)
@btime my_ntuple($f, $ns)
@btime ntuple($f, $nv)
end;
15.601 ns (0 allocations: 0 bytes)
17.599 ns (0 allocations: 0 bytes) For a julia> let n = 3, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime my_ntuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
20.779 ns (0 allocations: 0 bytes)
20.960 ns (0 allocations: 0 bytes)
julia> let n = 13, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime my_ntuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
101.147 ns (0 allocations: 0 bytes)
96.864 ns (0 allocations: 0 bytes)
julia> let n = 50, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime my_ntuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
420.814 ns (0 allocations: 0 bytes)
408.480 ns (0 allocations: 0 bytes)
julia> let n = 150, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime my_ntuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
1.317 μs (0 allocations: 0 bytes)
1.257 μs (0 allocations: 0 bytes)
I wasn't 100% sure what the end goal was for the …
StaticNumbers.jl is still doing some extra work there behind the scenes and ultimately ends up using a non-optionally generated … Feeding …
using ArrayInterface
using ArrayInterface: known_length
using BenchmarkTools
using Static
@inline itr2tuple(f::F, n::N) where {F,N} = _to_tuple(Iterators.map(f, static(1):n))
function _to_tuple(itr::I) where {I}
if @generated
N = known_length(I)
quote
@inline
v_1, s_1 = iterate(itr)
Base.Cartesian.@nexprs $(N - 1) i -> (v_{i+1}, s_{i+1}) = iterate(itr, s_i)
Base.Cartesian.@ntuple $(N) i->v_i
end
else
(itr...,)
end
end
It also has comparable performance to …
julia> let n = 12, ns = static(n), nv = Val(n)
f(x) = sin(x + 1)
@btime itr2tuple($f, $ns)
@btime ntuple($f, $nv)
end;
1.516 ns (0 allocations: 0 bytes)
1.516 ns (0 allocations: 0 bytes)
julia> let n = 120, ns = static(n), nv = Val(n)
f(x) = sin(x + 1)
@btime itr2tuple($f, $ns)
@btime ntuple($f, $nv)
end;
11.205 ns (0 allocations: 0 bytes)
11.777 ns (0 allocations: 0 bytes)
julia> let n = 3, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime itr2tuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
22.266 ns (0 allocations: 0 bytes)
22.281 ns (0 allocations: 0 bytes)
julia> let n = 13, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime itr2tuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
99.429 ns (0 allocations: 0 bytes)
101.203 ns (0 allocations: 0 bytes)
julia> let n = 50, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime itr2tuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
417.814 ns (0 allocations: 0 bytes)
418.523 ns (0 allocations: 0 bytes)
julia> let n = 150, ns = static(n), nv = Val(n)
y = 1
f = Ref(x -> sin(x + y))
@btime itr2tuple($f[], $ns)
@btime ntuple($f[], $nv)
end;
1.286 μs (0 allocations: 0 bytes)
1.290 μs (0 allocations: 0 bytes)
Of course! It would be great if it were as simple as marking some variable. Then we could just do …
In #47206 we have discussed accepting static bools (or ints) as the …
We talked about this in triage in October and I don't think there was a real solid conclusion. There are places where this would be completely fine as an independent package if we solved issues with latency (invalidations). However, it isn't clear this is the best way to tackle standardization of dynamic versus static components within a single code base. This isn't because there's another solution out there that more clearly solves this. I think this idea still has some merit, but it would probably require making some more guideline-based approaches to coding in addition to the actual implementation. For example, this gist defines a handful of reference types that could be treated like Rust's … There's some fancy compiler stuff that could happen, but it seems to be very far off from even reaching the conceptual level in Julia. Maybe once custom compiler passes become more accessible someone will jump in with something innovative. Maybe something will change, but for now I'd recommend using Static.jl.
%(@nospecialize(n::StaticInt), ::Type{Integer}) = Int(n)
eltype(@nospecialize(T::Type{<:StaticInt})) = Int
Why make `eltype(StaticInt)` be `Int`? I assume it'd be identity by default? What would be wrong with that?
The reason I'm asking is that I have a package for (static) natural numbers here, and I'm considering having it subtype `Integer`: https://juliahub.com/ui/Packages/General/TypeDomainNaturalNumbers
BTW my package seems like a good alternative for the code here, because it's safer in two ways:
- It doesn't wrap other types, so there are no issues with needing to assume that the wrapped type is `Int`, like here
- Negative numbers don't exist, which is safer for representing sizes
Yeah, the invalidations are prohibitive. It seems any new `Integer` subtypes have to be compiled together with `Base`.
See this conversation for the motivation for this.