Wma/precompile and docs #176

Merged · 29 commits · May 3, 2023
Commits (29)
1c77140
first pass, seems slower for some reason.
willow-ahrens Apr 26, 2023
490b923
Changing get to get! so as not to recompile haha
willow-ahrens Apr 26, 2023
863a505
make tests pass
willow-ahrens Apr 26, 2023
665de4c
cleanup algebra docs
willow-ahrens Apr 26, 2023
4047e24
remove duplicate def
willow-ahrens Apr 26, 2023
528a185
working on a general infrastructure
willow-ahrens Apr 27, 2023
0282e97
generalize calling infrastructure for runtime generated functions
willow-ahrens Apr 27, 2023
e3aa32a
push more code into the abstraction
willow-ahrens Apr 27, 2023
113352c
continue rewriting generated functions
willow-ahrens Apr 27, 2023
0c7435e
Merge branch 'main' into wma/precompile-and-docs
willow-ahrens May 1, 2023
3fb6d3a
Merge branch 'main' into wma/precompile-and-docs
willow-ahrens May 1, 2023
5d54627
Merge branch 'simplify-benchmarks' into wma/precompile-and-docs
willow-ahrens May 1, 2023
19658bf
fixes
willow-ahrens May 1, 2023
f810d11
update compiled functions to use typed arguments
willow-ahrens May 1, 2023
92783ad
rename
willow-ahrens May 1, 2023
f716a8b
feels better.
willow-ahrens May 1, 2023
00ac620
guard idempotency optimization with check for nonempty extent
willow-ahrens May 1, 2023
1290cfe
closer
willow-ahrens May 1, 2023
0fe0885
Merge branch 'simplify-benchmarks' into wma/precompile-and-docs
willow-ahrens May 1, 2023
17aa354
finish staging transform
willow-ahrens May 1, 2023
566e978
no longer need rgfs
willow-ahrens May 1, 2023
2163a49
Merge branch 'wma/cleanup-benchmarks' into wma/precompile-and-docs
willow-ahrens May 2, 2023
e567d8e
reorganize
willow-ahrens May 2, 2023
1a80142
Merge remote-tracking branch 'origin/main' into wma/precompile-and-docs
willow-ahrens May 2, 2023
ca666d3
rm ExprTools
willow-ahrens May 2, 2023
cb61f2d
Merge branch 'main' into wma/precompile-and-docs
willow-ahrens May 3, 2023
b96a36c
update docs a bit
willow-ahrens May 3, 2023
05742ea
cleanup tests
willow-ahrens May 3, 2023
ce208ed
cleanup docs a bit
willow-ahrens May 3, 2023
137 changes: 84 additions & 53 deletions docs/src/algebra.md
@@ -1,76 +1,54 @@
```@meta
CurrentModule = Finch
```

# Custom Functions

Finch supports arbitrary Julia Base functions over [`isbits`](@ref) types. For your convenience,
Finch defines a few useful functions that help express common array operations inside Finch:

```@docs
choose
minby
maxby
```
## User Functions

Finch supports arbitrary Julia Base functions over [`isbits`](@ref) types. You
can also define your own functions and use them in Finch! Just remember to
declare any special algebraic properties of your functions so that Finch can
optimize them better. You must declare these properties before you call any
Finch functions on them.

Finch only supports incrementing assignments to arrays such as `+=` or `*=`. If
you would like to increment `A[i...]` by the value of `ex` with a custom
reduction operator `op`, you may use the following syntax: `A[i...] <<op>>= ex`.
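
For example, a minimal sketch of an incrementing assignment with `max` as the
reduction operator (the fiber constructors here mirror the examples below and
are illustrative assumptions, not part of the syntax being described):

```
v = @fiber(sl(e(0)), [3, 1, 6, 1, 9])
A = @fiber sl(e(0))

# A[i] <<max>>= v[i] means A[i] = max(A[i], v[i])
@finch (A .= 0; @loop i A[i] <<max>>= v[i])
```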

# User Functions

Users can also define their own functions, and declare their properties to the
Finch compiler as follows:

## Register User Functions

Finch uses generated functions to compile kernels. If any functions have been
defined after Finch was loaded, Finch needs to be notified about them. The most
correct approach is to create a trait datatype that subtypes
`Finch.AbstractAlgebra` and call `Finch.register` on that type. After you call
`register`, that subtype reflects the methods you know to be currently defined
at that world age. You can pass your algebra to Finch to run Finch in that world
age.

## Declare Algebraic Properties

Users can help Finch optimize expressions over new functions by declaring key
function properties in the algebra. Finch kernels can then be executed using the
algebra.

As an example, suppose we wanted to declare some properties for the greatest
common divisor function `gcd`. This function is associative and commutative, and
the greatest common divisor of 1 and anything else is 1, so 1 is an annihilator.

We can express this by subtyping `AbstractAlgebra` and defining properties as
follows:
Consider the greatest common divisor function `gcd`. This function is
associative and commutative, and the greatest common divisor of 1 and anything
else is 1, so 1 is an annihilator. We declare these properties by overloading
trait functions on Finch's default algebra as follows:
```
struct MyAlgebra <: AbstractAlgebra end

Finch.isassociative(::MyAlgebra, ::typeof(gcd)) = true
Finch.iscommutative(::MyAlgebra, ::typeof(gcd)) = true
Finch.isannihilator(::MyAlgebra, ::typeof(gcd), x) = x == 1
Finch.isassociative(::Finch.DefaultAlgebra, ::typeof(gcd)) = true
Finch.iscommutative(::Finch.DefaultAlgebra, ::typeof(gcd)) = true
Finch.isannihilator(::Finch.DefaultAlgebra, ::typeof(gcd), x) = x == 1
```

When you're all done defining functions that dispatch on your algebra, call
`Finch.register` to register your new algebra in Finch.
Then, the following code will only call `gcd` when neither `u[i]` nor `v[i]`
is 1 (just once!):
```
Finch.register(MyAlgebra)
u = @fiber(sl(e(1)), [3, 1, 6, 1, 9, 1, 4, 1, 8, 1])
v = @fiber(sl(e(1)), [1, 2, 3, 1, 1, 1, 1, 4, 1, 1])
w = @fiber sl(e(1))

@finch MyAlgebra() (w .= 1; @loop i w[i] = gcd(u[i], v[i]))
```

Then, we can call a kernel that uses our algebra!
## A Few Convenient Functions

```
u = @fiber sl(e(1)) #TODO add some data
v = @fiber sl(e(1)) #TODO add some data
w = @fiber sl(e(1))
For your convenience, Finch defines a few useful functions that help express common array operations inside Finch:

@finch MyAlgebra() (w .= 1; @loop i w[i] = gcd(u[i], v[i]))
```@docs
choose
minby
maxby
```

## Properties

The full list of properties recognized by Finch is as follows:
The full list of properties recognized by Finch is as follows (use these to declare the properties of your own functions):

```@docs
isassociative
@@ -83,13 +61,66 @@ isinverse
isinvolution
```

## Rewriting
## Finch Kernel Caching

Finch code is cached when you first run it. Thus, if you run a Finch
function once, then make changes to the Finch compiler (like defining new
properties), the cached code will be used and the changes will not be reflected.

It's best to design your code so that modifications to the Finch compiler occur
before any Finch functions are called. However, if you really need to modify a
precompiled Finch kernel, you can call `Finch.refresh()` to invalidate the
code cache.
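
A hedged sketch of that workflow (the function `myop` is hypothetical):

```
myop(x, y) = x + y - x * y

# Kernels compiled before this point won't see the new property...
Finch.isassociative(::Finch.DefaultAlgebra, ::typeof(myop)) = true

# ...so invalidate the cache to force recompilation.
Finch.refresh()
```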

```@docs
refresh
```

### (Advanced) On World-Age and Generated Functions
Julia uses a "world age" to describe the set of methods defined at a point in time. Generated functions run in the world age in which they were defined, so they cannot call functions defined after the generated function. This means that if Finch used normal generated functions, users could not define their own functions without first redefining all of Finch's generated functions.
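
The underlying Julia behavior can be sketched with a plain generated function (a
toy example unrelated to Finch's internals; the exact error behavior varies
across Julia versions):

```
@generated calls_later(x) = :(later(x))  # generator captures the current world age

later(x) = x + 1  # defined after the generated function

# On versions without backedge tracking for generators, this can throw a
# world-age MethodError, because the generator's world predates `later`.
calls_later(1)
```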

Finch uses special generators that run in the current world age, but do not
update with subsequent compiler function invalidations. If two packages modify
the behavior of Finch in different ways, and call those Finch functions during
precompilation, the resulting behavior is undefined.

There are several packages that take similar, but different, approaches to
allow user participation in staged Julia programming (not to mention Base `eval` or `@generated`): [StagedFunctions.jl](https://github.com/NHDaly/StagedFunctions.jl),
[GeneralizedGenerated.jl](https://github.com/JuliaStaging/GeneralizedGenerated.jl),
[RuntimeGeneratedFunctions.jl](https://github.com/SciML/RuntimeGeneratedFunctions.jl),
or [Zygote.jl](https://github.com/FluxML/Zygote.jl).

Our approach is most similar to that of StagedFunctions.jl or Zygote.jl. We
chose our approach to be simple and flexible while keeping the kernel call
overhead low.

## (Advanced) Separate Algebras
If you want to define non-standard properties or custom rewrite rules for some
functions in a separate context, you can represent these changes with your own
algebra type. We express this by subtyping `AbstractAlgebra` and defining
properties as follows:
```
struct MyAlgebra <: AbstractAlgebra end

Finch.isassociative(::MyAlgebra, ::typeof(gcd)) = true
Finch.iscommutative(::MyAlgebra, ::typeof(gcd)) = true
Finch.isannihilator(::MyAlgebra, ::typeof(gcd), x) = x == 1
```

We pass the algebra to Finch as an optional first argument:

```
@finch MyAlgebra() (w .= 1; @loop i w[i] = gcd(u[i], v[i]))
```


### Rewriting

One can also define custom rewrite rules by overloading the `getrules` function
Define custom rewrite rules by overloading the `getrules` function
on your algebra. Unless you want to write the full rule set from scratch, be
sure to append your new rules to the old rules, which can be obtained by calling
`base_rules`. Rules can be specified directly on Finch IR using
[RewriteTools.jl](https://github.com/willow-ahrens/RewriteTools.jl)
[RewriteTools.jl](https://github.com/willow-ahrens/RewriteTools.jl).
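
A hedged sketch of the pattern (the exact `getrules` signature and the rule
syntax are assumptions; consult the `base_rules` docstring below):

```
struct MyAlgebra <: AbstractAlgebra end

function Finch.getrules(alg::MyAlgebra, ctx)
    # Append custom RewriteTools.jl rules over Finch IR to the defaults,
    # rather than replacing the whole rule set.
    vcat(Finch.base_rules(alg, ctx), [
        # your rules here
    ])
end
```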

```@docs
base_rules
4 changes: 0 additions & 4 deletions ext/SparseArraysExt.jl
@@ -236,8 +236,4 @@ Finch.virtual_eltype(tns::VirtualSparseVector) = tns.Tv

SparseArrays.nnz(fbr::Fiber) = countstored(fbr)

function __init__()
Finch.register(Finch.DefaultAlgebra)
end

end
14 changes: 2 additions & 12 deletions src/Finch.jl
@@ -29,8 +29,6 @@ export default, AsArray

include("util.jl")

registry = []

include("semantics.jl")
include("FinchNotation/FinchNotation.jl")
using .FinchNotation
@@ -97,14 +95,6 @@ module h
generate_embed_docs()
end

function register(algebra)
for r in registry
@eval Finch $(r(algebra))
end
end

register(DefaultAlgebra)

include("base/abstractarrays.jl")
include("base/abstractunitranges.jl")
include("base/broadcast.jl")
@@ -114,8 +104,8 @@ include("base/compare.jl")
include("base/copy.jl")
include("base/fsparse.jl")

@static if !isdefined(Base, :get_extension)
function __init__()
function __init__()
@static if !isdefined(Base, :get_extension)
@require SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" include("../ext/SparseArraysExt.jl")
end
end
20 changes: 10 additions & 10 deletions src/base/broadcast.jl
@@ -55,8 +55,8 @@ function Base.similar(bc::Broadcast.Broadcasted{FinchStyle{N}}, ::Type{T}, dims)
similar_broadcast_helper(lift_broadcast(bc))
end

@generated function similar_broadcast_helper(bc::Broadcast.Broadcasted{FinchStyle{N}}) where {N}
idxs = [index(Symbol(:i, n)) for n = 1:N]
@staged_function similar_broadcast_helper bc begin
idxs = [index(Symbol(:i, n)) for n = 1:ndims(bc)]
ctx = LowerJulia()
rep = pointwise_finch_traits(:bc, bc, idxs)
rep = PointwiseRep(ctx)(rep, reverse(idxs))
@@ -161,20 +161,20 @@ function pointwise_finch_expr(ex, T, ctx, idxs)
:($src[$(idxs[1:ndims(T)]...)])
end

@generated function Base.copyto!(out, bc::Broadcasted{FinchStyle{N}}) where {N}
copyto_helper!(:out, out, :bc, bc)
function Base.copyto!(out, bc::Broadcasted{<:FinchStyle})
copyto_broadcast_helper!(out, bc)
end

function copyto_helper!(out_ex, out, bc_ex, bc::Type{<:Broadcasted{FinchStyle{N}}}) where {N}
@staged_function copyto_broadcast_helper! out bc begin
contain(LowerJulia()) do ctx
idxs = [ctx.freshen(:idx, n) for n = 1:N]
pw_ex = pointwise_finch_expr(bc_ex, bc, ctx, idxs)
idxs = [ctx.freshen(:idx, n) for n = 1:ndims(bc)]
pw_ex = pointwise_finch_expr(:bc, bc, ctx, idxs)
quote
@finch begin
$out_ex .= $(default(out))
@loop($(reverse(idxs)...), $out_ex[$(idxs...)] = $pw_ex)
out .= $(default(out))
@loop($(reverse(idxs)...), out[$(idxs...)] = $pw_ex)
end
$out_ex
out
end
end
end
4 changes: 2 additions & 2 deletions src/base/compare.jl
@@ -1,4 +1,4 @@
@generated function helper_equal(A, B)
@staged_function helper_equal A B begin
idxs = [Symbol(:i_, n) for n = 1:ndims(A)]
return quote
size(A) == size(B) || return false
@@ -20,7 +20,7 @@ function Base.:(==)(A::AbstractArray, B::Fiber)
return helper_equal(A, B)
end

@generated function helper_isequal(A, B)
@staged_function helper_isequal A B begin
idxs = [Symbol(:i_, n) for n = 1:ndims(A)]
return quote
size(A) == size(B) || return false
6 changes: 4 additions & 2 deletions src/base/copy.jl
@@ -1,4 +1,4 @@
@generated function copyto_helper!(dst, src)
@staged_function copyto_helper! dst src begin
ndims(dst) > ndims(src) && throw(DimensionMismatch("more dimensions in destination than source"))
ndims(dst) < ndims(src) && throw(DimensionMismatch("less dimensions in destination than source"))
idxs = [Symbol(:i_, n) for n = 1:ndims(dst)]
@@ -21,7 +21,9 @@ end

dropdefaults(src) = dropdefaults!(similar(src), src)

@generated function dropdefaults!(dst::Fiber, src)
dropdefaults!(dst::Fiber, src) = dropdefaults_helper!(dst, src)

@staged_function dropdefaults_helper! dst src begin
ndims(dst) > ndims(src) && throw(DimensionMismatch("more dimensions in destination than source"))
ndims(dst) < ndims(src) && throw(DimensionMismatch("less dimensions in destination than source"))
idxs = [Symbol(:i_, n) for n = 1:ndims(dst)]
4 changes: 2 additions & 2 deletions src/base/index.jl
@@ -26,7 +26,7 @@ getindex_rep_def(lvl::RepeatData, idx) = SolidData(ElementData(lvl.default, lvl.
getindex_rep_def(lvl::RepeatData, idx::Type{<:AbstractUnitRange}) = SolidData(ElementData(lvl.default, lvl.eltype))

Base.getindex(arr::Fiber, inds...) = getindex_helper(arr, to_indices(arr, inds)...)
@generated function getindex_helper(arr::Fiber, inds...)
@staged_function getindex_helper arr inds... begin
@assert ndims(arr) == length(inds)
N = ndims(arr)

@@ -67,7 +67,7 @@ Base.getindex(arr::Fiber, inds...) = getindex_helper(arr, to_indices(arr, inds).
end

Base.setindex!(arr::Fiber, src, inds...) = setindex_helper(arr, src, to_indices(arr, inds)...)
@generated function setindex_helper(arr::Fiber, src, inds...)
@staged_function setindex_helper arr src inds... begin
@assert ndims(arr) == length(inds)
@assert sum(ndims.(inds)) == 0 || (ndims(src) == sum(ndims.(inds)))
N = ndims(arr)
6 changes: 5 additions & 1 deletion src/base/mapreduce.jl
@@ -76,7 +76,11 @@ function Base.reduce(op::Function, bc::Broadcasted{FinchStyle{N}}; dims=:, init
reduce_helper(Callable{op}(), lift_broadcast(bc), Val(dims), Val(init))
end

@generated function reduce_helper(::Callable{op}, bc::Broadcasted{FinchStyle{N}}, ::Val{dims}, ::Val{init}) where {op, dims, init, N}
@staged_function reduce_helper op bc dims init begin
reduce_helper_code(op, bc, dims, init)
end

function reduce_helper_code(::Type{Callable{op}}, bc::Type{<:Broadcasted{FinchStyle{N}}}, ::Type{Val{dims}}, ::Type{Val{init}}) where {op, dims, init, N}
contain(LowerJulia()) do ctx
idxs = [ctx.freshen(:idx, n) for n = 1:N]
rep = pointwise_finch_traits(:bc, bc, index.(idxs))
12 changes: 4 additions & 8 deletions src/execute.jl
@@ -1,14 +1,10 @@
execute(ex) = execute(ex, DefaultAlgebra())

push!(registry, (algebra)-> quote
@generated function execute(ex, a::$algebra)
quote
@inbounds begin
$(execute_code(:ex, ex, a()))
end
end
@staged_function execute ex a quote
@inbounds begin
$(execute_code(:ex, ex, a()) |> unblock)
end
end)
end

function execute_code(ex, T, algebra = DefaultAlgebra())
prgm = nothing
20 changes: 9 additions & 11 deletions src/fibers.jl
@@ -353,18 +353,16 @@ macro fiber(ex, arg)
return :($dropdefaults!($Fiber!($(f_decode(ex))), $(esc(arg))))
end

push!(registry, (algebra) -> quote
@generated function Fiber!(lvl)
contain(LowerJulia()) do ctx
lvl = virtualize(:lvl, lvl, ctx)
lvl = resolve(lvl, ctx)
lvl = declare_level!(lvl, ctx, literal(0), literal(virtual_level_default(lvl)))
push!(ctx.preamble, assemble_level!(lvl, ctx, literal(1), literal(1)))
lvl = freeze_level!(lvl, ctx, literal(1))
:(Fiber($(ctx(lvl))))
end
@staged_function Fiber! lvl begin
contain(LowerJulia()) do ctx
lvl = virtualize(:lvl, lvl, ctx)
lvl = resolve(lvl, ctx)
lvl = declare_level!(lvl, ctx, literal(0), literal(virtual_level_default(lvl)))
push!(ctx.preamble, assemble_level!(lvl, ctx, literal(1), literal(1)))
lvl = freeze_level!(lvl, ctx, literal(1))
:(Fiber($(ctx(lvl))))
end
end)
end

@inline f_code(@nospecialize ::Any) = nothing
