Create staticint.jl #44538

Draft · Tokazama wants to merge 8 commits into master

Conversation

@Tokazama (Contributor) commented Mar 9, 2022

See this conversation for the motivation for this.

@ViralBShah marked this pull request as draft March 9, 2022 14:17
Avoid generating new code when `StaticInt` will ultimately be lowered to
`Int` anyway.
@KristofferC (Member)

If there is nothing in Base that uses this, why would it not be a package?

@Tokazama (Contributor, Author) commented Mar 9, 2022

If there is nothing in Base that uses this, why would it not be a package?

There currently isn't, but the idea is that there should be. For example, if we had support for this in ranges, eachindex(::IndexLinear, x::AbstractArray...) could return a statically sized set of indices if any of the arrays had them. I figured this would be a good start before trying to add everything at once.

EDIT:
Or rather, there should be if we can't expect the compiler to soon provide support for the problems discussed in #44519. If this gets solved in the compiler, then I'm open to that being used instead.

@LilithHafner (Member)

Could we define eachindex in Base, specialize eachindex for StaticArrays in StaticArrays.jl, and have StaticArrays.jl depend on StaticInts.jl?

@LilithHafner (Member)

I know of four uses for Val in Base: tuples, array dimensions, literal exponentiation, and base/special, which I think does fancy math optimizations. Of these, the first three use cases could all be entirely replaced by StaticInt, while the usage in base/special includes things of the form logbL(::Type{Float64}, ::Val{:ℯ}) = 0.0, which cannot be.

It would be so nice if we could use Int instead of StaticInt. I'm rooting for the compiler people here, but in the interim, I think this solution fits a real-world use case in base and elsewhere.

@JeffBezanson added the "triage: This should be discussed on a triage call" label Mar 9, 2022
@mcabbott (Contributor)

Xref #41853, about adding static Zero & One.

@Tokazama (Contributor, Author) commented Mar 10, 2022

Could we define eachindex in Base, specialize eachindex for StaticArrays in StaticArrays.jl, and have StaticArrays.jl depend on StaticInts.jl?

This is a good concrete example of where it's nice to have this in Base.
If you take the definition you're interested in here, it's a bit hard to ensure that we propagate static info without jumping through some hoops (defining methods for nested static arrays, plus methods where the statically sized array is first, second, or third).

Now take the approach of ArrayInterface.indices. It's using _pick_range internally to propagate static lengths.

Pretty much anywhere a new array is created, or where subsequent calls depend on eachindex or axes, is a place we could improve if we were aware of this info.

It would be so nice if we could use Int instead of StaticInt. I'm rooting for the compiler people here, but in the interim, I think this solution fits a real-world use case in base and elsewhere.

This is pretty close to how I feel about it. Hopefully this comment makes it more clear what we would need from the compiler for this sort of thing to be unnecessary.

I know of four uses for Val in Base: tuples, array dimensions, literal exponentiation, and base/special, which I think does fancy math optimizations. Of these, the first three use cases could all be entirely replaced by StaticInt, while the usage in base/special includes things of the form logbL(::Type{Float64}, ::Val{:ℯ}) = 0.0, which cannot be.

Val is a bit of a footgun approach to propagating information when the only types that need to be in the type domain are Int, Symbol, and Bool. Val doesn't subtype any of these so people just keep putting things in the type domain next to other stuff it interacts with. That's the entirety of StaticArrays, putting the size information next to a buffer. A lot of traits I've seen out in the wild could just be a static true and false instead of a whole new set of types. The type system is fun to play with but it would probably be better if most people didn't worry about manually putting these things in the type domain. Corresponding types in base could be a more structured way of talking to the compiler, but there's a lot going on there and perhaps there's a more robust workflow in development.
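
To make the trait point concrete, here is a minimal sketch of the "static true and false" style; the types are hypothetical stand-ins mirroring Static.jl, not part of this PR:

# Hypothetical static booleans in the spirit of Static.jl.
struct StaticBool{B} <: Integer end
const STrue = StaticBool{true}()
const SFalse = StaticBool{false}()
Base.Bool(::StaticBool{B}) where {B} = B

# Trait style common today: a fresh singleton type per trait value.
struct IsContiguous end
struct NotContiguous end

# With static booleans, the trait value behaves like a Bool while
# still allowing dispatch on it:
contiguous(::Type{<:Array}) = STrue
contiguous(::Type{<:AbstractArray}) = SFalse
uses_fast_path(A) = Bool(contiguous(typeof(A)))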

@LilithHafner (Member)

Pretty much anywhere a new array is created, or where subsequent calls depend on eachindex or axes, is a place we could improve if we were aware of this info.

I'm not sure how the folks at StaticArrays are doing it, but I think it is hard to improve on the constant propagation going on in this specific case, where subsequent calls depend on eachindex and axes:

using StaticArrays

a = @SVector rand(100);
b = @SVector rand(100);
f(a, b) = sum(eachindex(a, b)) + sum(sum.(axes(a)))
@code_llvm f(a, b)
;  @ none:1 within `f`
define i64 @julia_f_11329([1 x [100 x double]]* nocapture nonnull readonly align 8 dereferenceable(800) %0, [1 x [100 x double]]* nocapture nonnull readonly align 8 dereferenceable(800) %1) #0 {
top:
  ret i64 10100
}

@Tokazama (Contributor, Author)

Pretty much anywhere a new array is created, or where subsequent calls depend on eachindex or axes, is a place we could improve if we were aware of this info.

I'm not sure how the folks at StaticArrays are doing it, but I think it is hard to improve on the constant propagation going on in this specific case, where subsequent calls depend on eachindex and axes: [code example quoted above]

But eachindex(rand(100), a) returns OneTo, and other methods will produce a dynamically sized array when it could be inferred as statically sized. Similarly, combining axes when broadcasting can be tricky. If we can get the same code generated without @SVector, then we don't need statically sized arrays and don't need to worry about static types at all.

@LilithHafner (Member) commented Mar 10, 2022

eachindex(rand(100), a) returns OneTo

That's because eachindex performs bounds checking and then returns its first argument (i.e., eachindex(a, rand(100)) === SOneTo(100) and eachindex(rand(100), a) === Base.OneTo(100)). How would switching to StaticInt fix this?

@LilithHafner (Member)

If we can get the same code generated without @SVector

How would this work? Would we parse 100 as a StaticInt?

@Tokazama (Contributor, Author) commented Mar 10, 2022

eachindex(rand(100), a) returns OneTo

That's because eachindex performs bounds checking and then returns its first argument (i.e., eachindex(a, rand(100)) === SOneTo(100) and eachindex(rand(100), a) === Base.OneTo(100)). How would switching to StaticInt fix this?

The link above to ArrayInterface.indices is essentially eachindex being aware of staticness. In order for eachindex to combine these appropriately, it has to be aware that static sizing is even a thing.
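
For a self-contained illustration of that kind of combination (simplified, and not the actual ArrayInterface code):

# A minimal static-aware axis type and an eachindex-style combiner.
struct SOneTo{N} <: AbstractUnitRange{Int} end
Base.first(::SOneTo) = 1
Base.last(::SOneTo{N}) where {N} = N

# When equal-length axes meet, keep the statically sized one.
_pick(x::SOneTo, y::Base.OneTo) = (last(x) == last(y) || throw(DimensionMismatch()); x)
_pick(x::Base.OneTo, y::SOneTo) = _pick(y, x)
_pick(x::SOneTo{N}, ::SOneTo{N}) where {N} = x
_pick(x::Base.OneTo, y::Base.OneTo) = (x.stop == y.stop || throw(DimensionMismatch()); x)

# Unlike Base's eachindex, the static axis wins regardless of
# argument order:
static_eachindex(a, b) = _pick(only(axes(a)), only(axes(b)))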

If we can get the same code generated without @SVector

How would this work? Would we parse 100 as a StaticInt?

I'm not saying that example is solved with StaticInt. If StaticInt were in Base, perhaps you could pass something like rand(StaticInt(100)), but I was trying to keep this PR succinct so it was easy to discuss what's here. I could add in some of the traits and range-related behavior if it helps to see some more concrete applications.

My point was just that if the same code is generated without the size in the type domain, then we wouldn't need SVector and might not need StaticInt.

@JeffBezanson (Member)

In theory the compiler could add a lattice element for an array with known size, and propagate constants to know the size of rand(100) at compile time. The only problem is that we'd then do it for every constant array size, which would likely cause excessive compile times. It is possible there aren't that many constant-size arrays, and that the size being a constant is a hint we should specialize on it, so I'm not sure. Very much an empirical question. But then, you still need a way to request size specialization even when the size is non-constant, and rand(StaticInt(n)) might be a good way to do that.
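
For illustration, a hedged sketch of what that opt-in might look like (static_rand is a made-up name; SVector comes from StaticArrays.jl):

using StaticArrays, Static  # assumed packages

# The caller explicitly requests size specialization even though `n`
# is runtime data; one specialization is compiled per distinct n.
static_rand(::StaticInt{N}) where {N} = rand(SVector{N,Float64})

n = 4                       # not a compile-time constant
v = static_rand(static(n))  # dynamic dispatch once; v isa SVector{4,Float64}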

@LilithHafner (Member) commented Mar 10, 2022

What if we had static take anything into the compile-time domain? With TypedVal{T,x} <: T, static(x) === TypedVal{typeof(x), x}(), and TypedVal{T,x}() == x.

EDIT: I'm not sure about this idea. One problem is that for concrete types T we'd be trying to subtype a concrete type.
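
The problem actually bites even earlier: Julia rejects a declared supertype that is a type parameter, so the definition cannot be written at all (a quick REPL check):

julia> struct TypedVal{T,x} <: T end
ERROR: invalid subtyping in definition of TypedVal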

@Tokazama (Contributor, Author)

In theory the compiler could add a lattice element for an array with known size, and propagate constants to know the size of rand(100) at compile time. The only problem is that we'd then do it for every constant array size, which would likely cause excessive compile times. It is possible there aren't that many constant-size arrays, and that the size being a constant is a hint we should specialize on it, so I'm not sure. Very much an empirical question. But then, you still need a way to request size specialization even when the size is non-constant, and rand(StaticInt(n)) might be a good way to do that.

I figured that this was the case. In addition to making it so more methods can actually propagate static info, there might be some more intelligent codegen the compiler could do if there's a dedicated type for this. For example, perhaps foo(x::Integer) could avoid generating an additional method unless it's explicitly foo(::StaticInt{N}) where {N}.

@Tokazama (Contributor, Author)

The newest changes have the basic types and traits that we've found useful over in ArrayInterface so far. These don't necessarily need to be a part of this PR or have the current names; I just thought it would be helpful to provide a little more context. Some places where this is helpful to have in Base are Broadcast.combine_axes, promote_shape, and eachindex.

@JeffBezanson (Member)

Triage discussed this but we don't have any hard conclusions yet.

@Tokazama (Contributor, Author)

Triage discussed this but we don't have any hard conclusions yet.

Really appreciate the update. Just let me know if there's anything else needed from me to help solve this issue (whether it be StaticInt or some entirely different solution). I know there are a lot of things being worked on right now that may be more urgent, but a lot of what I'm working on in Julia is related, and I just want to make sure this continues to have some momentum.

@oschulz (Contributor) commented Jun 23, 2022

I think it is hard to improve on the constant propagation [...]

I'd love to have static ints (or even static numbers in general) in Base as well. We have several use cases in which static ints are not confined to a function but become part of types that get passed around (so that "staticness" can be inferred in a different place in the code).
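
As an illustration of that pattern (the names here are made up for the example; StaticInt and static are assumed from Static.jl or this PR):

using Static  # assumed; provides StaticInt and static

# Staticness lives in a field's type, so code far away (even in
# another package) can recover it from the type alone:
# Histogram{StaticInt{16}} versus Histogram{Int}.
struct Histogram{N<:Union{Int,StaticInt}}
    nbins::N
    counts::Vector{Int}
end
Histogram(n) = Histogram(n, zeros(Int, Int(n)))

h_dyn = Histogram(16)          # bin count is runtime data
h_sta = Histogram(static(16))  # bin count inferable wherever h_sta flows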

@Tokazama (Contributor, Author)

As I understand it, constant propagation is meant to carry small amounts of compile-time-known information short distances.
For example, we often know what ndims will return at compile time, but we wouldn't expect (or even want) that information to fully propagate through some complicated function passed to ntuple(foo, ndims(A)).
Simply preserving that information at checkpoints isn't the biggest problem, though (after all, we have Val for that).
The issue is that when compile-time-known information is part of the structure of some object, there's no natural way to communicate that information between methods without manually dispatching on something like StaticArray or passing Val to a bunch of intermediate functions.

Take the simple example of indexing an AbstractUnitRange.
first and last have no concept of statically known indices, so there can't be a generic reconstruction of a statically known set of indices.
At a certain point constant propagation just won't help once the value is stored as an Int.
The absence of StaticInt from Base ensures this sort of thing will continue to be an issue, because it's not just a problem for ranges.
We keep creating iterables that explicitly store the number of iterations/indices/axes as an Int.
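
A simplified sketch of where that information gets dropped (this is not the actual Base code, just the shape of the problem):

# Generic range slicing funnels through Int-valued first/last, so any
# statically known endpoint is erased the moment it is stored:
function slice_range(r::AbstractUnitRange, i::AbstractUnitRange{<:Integer})
    @boundscheck checkbounds(r, i)
    f = first(r) + first(i) - 1   # an Int; the static origin is gone
    return f:(f + length(i) - 1)  # a fully dynamic UnitRange{Int}
end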

There are lots of examples where things get screwy internally because Base doesn't know how to handle statically sized indices.
Combining heterogeneous axes/indices when broadcasting or using eachindex is also less than optimal.
We test that these indices have the same start and stop positions, but we miss out on the opportunity to promote them to statically known values when they aren't all static.
Sometimes getting this to work requires rigging together a minimally workable solution that tinkers with methods that aren't part of the public API (e.g., broadcast_shape).
Some methods are clearly created with the intention of only working for Int, and overriding them to support StaticInt is brittle and subject to fluctuations in invalidations between versions of Julia (which is a problem for any new integer type explicitly intended for indexing).

I'm sure there are other benefits on the compiler/interpreter side of things, but I'm not too savvy about that kind of stuff.
I'm assuming static information could also somehow be useful in providing a canonical form for passing information to compiler passes, for things like loop unrolling or something like Chris Elrod's LoopModels.

@thchr (Contributor) commented Jun 23, 2022

  • [...] Commenters here have alluded to StaticArrays needing this type in some cases, but I don't quite get it yet. I would like to see that spelled out in a detailed example.

I think one prominent example is the following (as discussed e.g. at JuliaArrays/StaticArrays.jl#807): in order to encode the size S of a StaticArray{S, ...} type, it is currently the tuple-type S = Tuple{N, M, ...}. One unappealing consequence of this is that S doesn't have an instance. If we had StaticInt, S could be a Tuple{StaticInt{N}, StaticInt{M}, ...} which would have the instance (StaticInt(N), StaticInt(M), ...).

This would allow StaticArrays.jl to drop some of the awkward juggling between Size types and size: with a StaticInt, we could have size(::StaticArray) return this instance in a natural manner (and so keep the static information around; right now, it simply returns (N, M, ...) and so drops the compile-time information; to retain the static information, one currently needs to use Size).

(Of course, this would be quite breaking on the StaticArrays.jl side...)
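
A hedged sketch of the difference, using Static.jl's static for illustration (this is not how StaticArrays is currently defined):

using Static  # assumed; provides StaticInt and static

# Today: the size parameter is the instance-less type Tuple{3, 4}, so
# size(::StaticArray) must return plain Ints and drop the static info.
# With StaticInt, the same information has a value-domain instance:
s = (static(3), static(4))  # isa Tuple{StaticInt{3}, StaticInt{4}}

# The instance flows through ordinary code without a separate Size trait:
len = prod(s)               # static(12), still known at compile time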

@Tokazama (Contributor, Author)

There are a handful of things in StaticArrays that make it difficult to share a generic interface with other packages. For example, similar_type is nice when working with StaticArray subtypes in isolation, but it would be preferable to have a generic method that can utilize both static and dynamic information. The accumulation of similar constructs in StaticArrays makes it difficult to design lightweight packages that are compatible without a full dependency on StaticArrays. Once we get methods with generic support for static and dynamic sizing, it should be less treacherous to manage the related breaking changes for StaticArrays.

@Tokazama (Contributor, Author)

@JeffBezanson, were you thinking of something like this?

@inline function ntuple(f::F, n::Integer) where F
    (n >= 0) || throw(ArgumentError(string("tuple length should be ≥ 0, got ", n)))
    if @generated
        if n <: StaticInt
            quote
                Base.Cartesian.@ntuple $(n.parameters[1]) i->f(i)
            end
        else
            :(_ntuple(f, n))
        end
    else
        _ntuple(f, n)
    end
end
@inline function _ntuple(f::F, n) where F
    # marked inline since this benefits from constant propagation of `n`
    t = n == 0  ? () :
        n == 1  ? (f(1),) :
        n == 2  ? (f(1), f(2)) :
        n == 3  ? (f(1), f(2), f(3)) :
        n == 4  ? (f(1), f(2), f(3), f(4)) :
        n == 5  ? (f(1), f(2), f(3), f(4), f(5)) :
        n == 6  ? (f(1), f(2), f(3), f(4), f(5), f(6)) :
        n == 7  ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7)) :
        n == 8  ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7), f(8)) :
        n == 9  ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7), f(8), f(9)) :
        n == 10 ? (f(1), f(2), f(3), f(4), f(5), f(6), f(7), f(8), f(9), f(10)) :
        __ntuple(f, n)
    return t
end

@noinline __ntuple(f::F, n) where F = ([f(i) for i = 1:n]...,)
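
For what it's worth, with this definition the call sites would look like this (illustrative):

ntuple(identity, StaticInt(3))  # hits the @generated branch; unrolls to (1, 2, 3)
ntuple(identity, 3)             # falls through to _ntuple, as today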

@oschulz (Contributor) commented Jun 24, 2022

The new Slices would make a great use case for static numbers in Base, to make the dimension/slicing order inferable (important for many use cases): #32310 (comment)

@ChrisRackauckas (Member)

We were able to work around it a bit by creating StaticArraysCore and then ArrayInterfaceStaticArraysCore, so that there could be an interface that avoids the static behavior and allows it to be loaded only much later, which cut downstream load times by about 4 seconds. Thus, with diligent programming, we can work around this issue. That said, I'm not sure how many groups are going to appropriately use ArrayInterfaceStaticArraysCore, StaticArraysCore, and GPUArraysCore, and if not done correctly you get bad load times and lots of invalidations. To me, the existence of ArrayInterfaceStaticArraysCore says something is wrong here, but that appears not to be an opinion shared by some others in this thread, so 🤷

@Tokazama (Contributor, Author)

I'm not completely convinced that StaticArraysCore actually solves the problem it is intended to solve, but I haven't had time to really dig deep on that effort lately.

@Tokazama (Contributor, Author)

Now that we're working on 1.9, is this worth discussing further?

@cscherrer

I hope so, and also that StaticFloat64 might still be a possibility for 1.9. I discussed the need for that as part of this post, starting at the basemeasure(Normal()) code block.

@Tokazama (Contributor, Author)

Before revisiting this I think it might be helpful to share something I put together in a private discussion with Tim Holy and Chris Elrod last month.

In the code below I wrote custom stride and size methods (get_size and get_strides) that carefully preserve constant propagation. I then went through all of our tests in ArrayInterface.jl and for every dimension that returned a StaticInt I ensured that static_get_size(A, StaticInt(dim)) and static_get_stride(A, StaticInt(dim)) were inferable.

We can rely on constant propagation if every intervening method is extremely careful to preserve that information. However, it often fails for generic methods:

julia> function first_of_length(x, y)
           StaticInt(getfield(get_size(first(x, length(y))), 1))
       end

julia> @inferred first_of_length(1:20, static(1):static(10))
ERROR: return type StaticInt{10} does not match inferred return type StaticInt

This can still be resolved if we have very carefully defined methods for specific situations.
I assume the logic for preserving types with static parameters is similar. You would just need to overload a lot of methods.

julia> function first_of_length(x::Base.OneTo, ::ArrayInterface.SUnitRange{1,N}) where {N}
           StaticInt(Base.OneTo{Int}(N).stop)
       end

julia> @inferred first_of_length(Base.OneTo(20), static(1):static(10))
static(10)

I think the takeaway here is that we can rely on constant propagation for passing around persistent compile-time-known parameters if we are willing to sacrifice generic programming in cases where it's important. If this is what we want to do, then we should probably add more public API hooks for making this possible in certain methods without completely rewriting chunks of Base.

If we want to support more generic methods that preserve this info, we need to decide whether this should be supported through types (e.g., StaticInt, SArray, etc.) or the compiler. If we want to rely on the compiler, I'm not sure we want traditional constant propagation to be the solution. Not only does it complicate the constant balance between compilation time and inference, but the sort of information we are talking about is explicitly declared static by users, and constant propagation isn't something the user actively interacts with.

If we were to rely on non-type-driven solutions, we'd need something that can variably mark things as statically known across methods and as fields of types.

static sz = (3, 3, 3)  # mark these values as statically known

rr = reshape(1:27, sz)

isstatic(size(rr)) == true  # dims field of ReshapedArray is static here

If this is possible, it would be preferable to making a bunch of parametric types to support this. However, before committing to something like that, I'd like to know that it truly is the next step in the plan for handling this, so we aren't just spinning our wheels waiting for a solution that may never come.

# this can be plugged into https://github.com/JuliaArrays/ArrayInterface.jl/blob/bf07fa9cbdf482125ad44e8c1c4095111fe06206/test/indexing.jl#L253 to run the tests in "test/size.jl" and "test/stridelayout.jl"
struct MArray{size,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::P

    MArray(x::AbstractArray{T,N}) where {T,N} = new{size(x),T,N,typeof(x)}(x)
end
Base.parent(x::MArray) = x.parent
Base.size(::MArray{S}) where {S} = S
Base.getindex(x::MArray, inds...) = x.parent[inds...]
ArrayInterface.defines_strides(::Type{<:MArray}) = true
ArrayInterface.is_forwarding_wrapper(::Type{<:MArray}) = true
ArrayInterface.parent_type(::Type{<:MArray{<:Any,<:Any,<:Any,P}}) where {P} = P
#Base.strides(::MArray{size}) where {size} = ArrayInterface.size_to_strides(

get_size(a) = Base.size(a)
get_size(a, dim) = Base.size(a)[dim]
@inline function get_size(x::SubArray)
    ArrayInterface.flatten_tuples(_get_size(parent(x), x.indices, ArrayInterface.IndicesInfo(x)))
end
function _get_size(p::AbstractArray, inds::Tuple, ::ArrayInterface.IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
    ntuple(Val{nfields(pdims)}()) do i
        pdim_i = getfield(pdims, i)
        cdim_i = getfield(cdims, i)
        if pdim_i isa Int && cdim_i isa Int
            inds_i = getfield(inds, i)
            if inds_i isa Base.Slice{Base.OneTo{Int}}
                getfield(get_size(p), pdim_i)
            else
                get_size(inds_i)
            end
        else
            get_size(getfield(inds, i))
        end
    end
end
@inline get_size(B::ArrayInterface.VecAdjTrans) = (1, getfield(get_size(parent(B)), 1))
function get_size(A::PermutedDimsArray{<:Any,N,perm}) where {N,perm}
    psize = get_size(parent(A))
    ntuple(i->getfield(psize, getfield(perm, i)), Val{N}())
end
function get_size(A::ArrayInterface.MatAdjTrans)
    s1, s2 = get_size(parent(A))
    (s2, s1)
end

@inline function get_size(a::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S}
    psize = get_size(parent(a))
    (div(getfield(psize, 1) * sizeof(S), sizeof(T)), Base.tail(psize)...)
end
@inline function get_size(a::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
    psize = get_size(parent(a))
    ss = sizeof(S)
    ts = sizeof(T)
    if ss === ts
        return psize
    elseif ss > ts
        return (div(ss, ts), psize...,)
    else
        return Base.tail(psize)
    end
end
get_size(A::AbstractRange) = (length(A),)

@inline get_strides(A) = ArrayInterface.size_to_strides(get_size(A), 1)
# `strides` for `Base.ReinterpretArray`
function get_strides(A::Base.ReinterpretArray{T,<:Any,S,<:AbstractArray{S},IsReshaped}) where {T,S,IsReshaped}
    ArrayInterface._is_column_dense(parent(A)) && return ArrayInterface.size_to_strides(get_size(A), 1)
    stp = get_strides(parent(A))
    ET, ES = static(sizeof(T)), static(sizeof(S))
    ET === ES && return stp
    IsReshaped && ET < ES && return (One(), _reinterp_strides(stp, ET, ES)...)
    first(stp) == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
    if IsReshaped
        # The wrapper tell us `A`'s parent has static size in dim1.
        # We can make the next stride static if the following dim is still dense.
        sr = ArrayInterface.stride_rank(parent(A))
        dd = ArrayInterface.dense_dims(parent(A))
        stp′ = _new_static(stp, sr, dd, ET ÷ ES)
        return _reinterp_strides(Base.tail(stp′), ET, ES)
    else
        return (One(), _reinterp_strides(Base.tail(stp), ET, ES)...)
    end
end
_new_static(P,_,_,_) = P # This should never be called, just in case.
@generated function _new_static(p::P, ::SR, ::DD, ::StaticInt{S}) where {S,N,P<:NTuple{N,Union{Int,StaticInt}},SR<:NTuple{N,StaticInt},DD<:NTuple{N,StaticBool}}
    sr = fieldtypes(SR)
    j = findfirst(T -> T() == sr[1]()+1, sr)
    if !isnothing(j) && !(fieldtype(P, j) <: StaticInt) && fieldtype(DD, j) === True
        return :(tuple($((i == j ? :(static($S)) : :(p[$i]) for i in 1:N)...)))
    else
        return :(p)
    end
end
@inline function _reinterp_strides(stp::Tuple, els::StaticInt, elp::StaticInt)
    if elp % els == 0
        N = elp ÷ els
        return map(Base.Fix2(*, N), stp)
    else
        return map(stp) do i
            d, r = divrem(elp * i, els)
            iszero(r) || throw(ArgumentError("Parent's strides could not be exactly divided!"))
            d
        end
    end
end
get_strides(@nospecialize x::AbstractRange) = (1,)
function get_strides(x::ArrayInterface.VecAdjTrans)
    st = getfield(get_strides(parent(x)), 1)
    return (st, st)
end
function get_strides(x::ArrayInterface.MatAdjTrans)
    s1, s2 = get_strides(parent(x))
    (s2, s1)
end
function get_strides(A::PermutedDimsArray{<:Any,N,perm}) where {N,perm}
    psize = get_strides(parent(A))
    ntuple(i->getfield(psize, getfield(perm, i)), Val{N}())
end

getmul(x::Tuple, y::Tuple, ::StaticInt{i}) where {i} = getfield(x, i) * getfield(y, i)
function get_strides(A::SubArray)
    Static.eachop(getmul, ArrayInterface.to_parent_dims(typeof(A)), map(maybe_static_step, A.indices), get_strides(parent(A)))
end

maybe_static_step(x::AbstractRange) = ArrayInterface.static_step(x)
maybe_static_step(_) = nothing

static_get_size(x, ::StaticInt{dim}) where {dim} = StaticInt(getfield(get_size(x), dim))
static_get_stride(x, ::StaticInt{dim}) where {dim} = StaticInt(getfield(get_strides(x), dim))

@JeffreySarnoff (Contributor)

Some time ago, I spent a month or more working with StaticInts, StaticFloats, and StaticBools to see how they might be slightly improved or better used. Unless one or more of them is put into Base, where they would benefit many who would never work with one of the Static*.jl packages [I think that is most devs, who are not reading this], we leave ourselves an obligation to spend real time going forward beating back the issues discussed above.
Really, I became weary of always considering where code would benefit from my use of them, and stopped.

@Tokazama (Contributor, Author)

@JeffBezanson, you mentioned back in June that triage was back on this. Is there some other solution in the works behind the scenes, or should I be updating this and keeping the discussion going?

@Tokazama (Contributor, Author)

@timholy, you mentioned getting this in before the feature freeze for 1.9 (#44519 (comment)). Is this still the plan?

@StefanKarpinski (Member)

I really think we need to have some of the proponents of adding this to Base join the triage call some time and make the case for it—and, probably more importantly, answer questions in real time. We've discussed it numerous times and never been convinced. This implementation in response to @JeffBezanson's request is pretty unsatisfying, since it just takes what were previously two separate method bodies and combines them by adding a branch on the type check. What would be compelling is an actual reduction in the amount of code and logic required to handle both the static and non-static cases with a single code base.

@Seelengrab (Contributor) commented Sep 29, 2022

Since StaticFloat64 has already come up, and my money is on cases like that just getting more common, adding a special case for Int feels off. From my POV, just in terms of specializing on sizes, a more general language-level feature for expressing staticness (possibly also with dispatch and/or restrictions if it isn't Const?) would be better.

you mentioned back in June that triage was back on this.

I think Jeff was referring to the fact that triage has talked about this lots of times, but always without the people who want to get this merged 🤷 The next triage meeting (other than today's) is on the 13th at 20:15 UTC, I think - someone from this PR joining would be great!

@MasonProtter (Contributor)

I think the implementation of ntuple that @Tokazama showed is more complicated than it needs to be. With StaticNumbers.jl I can implement ntuple as a one-liner and still constant-fold everything:

julia> using StaticNumbers, BenchmarkTools  # BenchmarkTools provides @btime

julia> my_ntuple(f, n) = Tuple(f(i) for i in Base.OneTo(n));

julia> let n = 12, ns = static(n), nv = Val(n)
           f(x) = sin(x + 1)
           @btime my_ntuple($f, $ns)
           @btime    ntuple($f, $nv)
       end;
  1.257 ns (0 allocations: 0 bytes)
  1.258 ns (0 allocations: 0 bytes)

julia> let n = 120, ns = static(n), nv = Val(n)
           f(x) = sin(x + 1)
           @btime my_ntuple($f, $ns)
           @btime    ntuple($f, $nv)
       end;
  15.601 ns (0 allocations: 0 bytes)
  17.599 ns (0 allocations: 0 bytes)

For an f that it can't compute completely at compile time, e.g. a closure or something impure, this is slightly slower than ntuple, but there's probably some relatively simple way to make it just as fast.

julia> let n = 3, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime my_ntuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  20.779 ns (0 allocations: 0 bytes)
  20.960 ns (0 allocations: 0 bytes)

julia> let n = 13, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime my_ntuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  101.147 ns (0 allocations: 0 bytes)
  96.864 ns (0 allocations: 0 bytes)

julia> let n = 50, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime my_ntuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  420.814 ns (0 allocations: 0 bytes)
  408.480 ns (0 allocations: 0 bytes)

julia> let n = 150, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime my_ntuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  1.317 μs (0 allocations: 0 bytes)
  1.257 μs (0 allocations: 0 bytes)

@Tokazama (Contributor, Author)

I wasn't 100% sure what the end goal for the ntuple example was, which is why I asked whether my hastily put-together example was on the right track.

With StaticNumbers.jl I can implement ntuple as a one liner, and still constant fold everything:

StaticNumbers.jl is still doing some extra work there behind the scenes, and ultimately ends up using a non-optionally-generated ntuple(f, ::StaticInteger) method here. I'm not saying the approach is objectively wrong, but it doesn't exactly consolidate ntuple.

Feeding ntuple into Tuple(itr) seems like the right approach, since it could potentially benefit other methods too. The following isn't necessarily pretty, but I'm pretty sure it could carry the bulk of the work of unrolling statically sized iterators into tuples. It might be a good starting point.

using ArrayInterface
using ArrayInterface: known_length
using BenchmarkTools
using Static

@inline itr2tuple(f::F, n::N) where {F,N} = _to_tuple(Iterators.map(f, static(1):n))
function _to_tuple(itr::I) where {I}
    if @generated
        N = known_length(I)
        quote
            @inline
            v_1, s_1 = iterate(itr)
            Base.Cartesian.@nexprs $(N - 1) i -> (v_{i+1}, s_{i+1}) = iterate(itr, s_i)
            Base.Cartesian.@ntuple $(N) i->v_i
        end
    else
        (itr...,)
    end
end

It also has comparable performance to Base.ntuple.

julia> let n = 12, ns = static(n), nv = Val(n)
           f(x) = sin(x + 1)
           @btime itr2tuple($f, $ns)
           @btime    ntuple($f, $nv)
       end;
  1.516 ns (0 allocations: 0 bytes)
  1.516 ns (0 allocations: 0 bytes)

julia> let n = 120, ns = static(n), nv = Val(n)
           f(x) = sin(x + 1)
           @btime itr2tuple($f, $ns)
           @btime    ntuple($f, $nv)
       end;
  11.205 ns (0 allocations: 0 bytes)
  11.777 ns (0 allocations: 0 bytes)

julia> let n = 3, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime itr2tuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  22.266 ns (0 allocations: 0 bytes)
  22.281 ns (0 allocations: 0 bytes)

julia> let n = 13, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime itr2tuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  99.429 ns (0 allocations: 0 bytes)
  101.203 ns (0 allocations: 0 bytes)


julia> let n = 50, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime itr2tuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  417.814 ns (0 allocations: 0 bytes)
  418.523 ns (0 allocations: 0 bytes)

julia> let n = 150, ns = static(n), nv = Val(n)
           y = 1
           f = Ref(x -> sin(x + y))
           @btime itr2tuple($f[], $ns)
           @btime    ntuple($f[], $nv)
       end;
  1.286 μs (0 allocations: 0 bytes)
  1.290 μs (0 allocations: 0 bytes)

a more general language-level feature for expressing staticness (possibly also with dispatch and/or restrictions if it isn't Const?) would be better.

Of course! It would be great if it were as simple as marking some variable. Then we could just do ntuple(f, @persistent_static_parameter(10)). I believe I said as much here, but there hasn't been any indication that this will ever be made possible in Julia.

@brenhinkeller added the "compiler:latency: Compiler latency" label Nov 21, 2022
@amilsted (Contributor) commented Jan 13, 2023

In #47206 we have discussed accepting static bools (or ints) as the alpha and beta arguments in LinearAlgebra.mul!(C, A, B, alpha, beta), allowing compile-time selection of native mul!() variants without relying on constant propagation, which seems pretty flaky here in practice.
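
A hedged sketch of the idea (mymul! is a made-up name; the actual proposal in #47206 concerns LinearAlgebra.mul! itself):

using LinearAlgebra, Static  # Static assumed for StaticInt

# Dispatch, not constant propagation, selects the unscaled fast path
# for the common C = A*B case:
mymul!(C, A, B, ::StaticInt{1}, ::StaticInt{0}) = mul!(C, A, B)
mymul!(C, A, B, alpha::Number, beta::Number) = mul!(C, A, B, alpha, beta)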

@Tokazama (Contributor, Author)

We talked about this in triage in October and I don't think there was a real solid conclusion.

There are places where this would be completely fine as an independent package if we solved the latency issues (invalidations). However, it isn't clear this is the best way to tackle standardization of dynamic versus static components within a single code base, though that isn't because some other solution out there more clearly solves this.

I think this idea still has some merit, but it would probably require developing some guideline-based approaches to coding in addition to the actual implementation. For example, this gist defines a handful of reference types that could be treated like Rust's Option type, but with static, immutable, mutable, and atomic reference types. The code itself isn't sufficient to solve this problem. We'd need a group of people to go in and write a bunch of very opinionated guides on best practices for writing code that uses these types instead of directly interacting with raw primitive types. That would take a lot of discussion and time for something I think is important but uninteresting.

There's some fancy compiler stuff that could happen, but it seems to be very far off from even reaching the conceptual level in Julia. Maybe once custom compiler passes become more accessible someone will jump in with something innovative.

Maybe something will change but for now I'd recommend using Static.jl.


%(@nospecialize(n::StaticInt), ::Type{Integer}) = Int(n)

eltype(@nospecialize(T::Type{<:StaticInt})) = Int
Contributor (review comment)

Why make eltype(StaticInt) be Int? I assume it'd be identity by default? What would be wrong with that?

The reason I'm asking is that I have a package for (static) natural numbers here, and I'm considering having it subtype Integer: https://juliahub.com/ui/Packages/General/TypeDomainNaturalNumbers

BTW my package seems like a good alternative for the code here, because it's safer in two ways:

  • It doesn't wrap other types, so there are no issues with needing to assume that the wrapped type is Int, like here.
  • Negative numbers don't exist, which is safer for representing sizes.

Contributor (review comment)

Yeah, the invalidations are prohibitive. It seems any new Integer subtypes have to be compiled together with Base.

Labels: compiler:latency (Compiler latency), triage (This should be discussed on a triage call)