Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combined nightly fixes #127

Merged
merged 4 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ export ∂⃖, gradient

const CC = Core.Compiler

const GENERATORS = Expr[]

include("runtime.jl")
include("interface.jl")
include("utils.jl")
Expand Down Expand Up @@ -37,4 +39,11 @@ include("debugutils.jl")

include("stage1/termination.jl")

function reload()
@info "reloading Diffractor generators"
for generator in GENERATORS
Core.eval(@__MODULE__, generator)
end
end

end
12 changes: 7 additions & 5 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Codegen shared by both stage1 and stage2

function make_opaque_closure(interp, typ, name, meth_nargs, isva, lno, cis, revs...)
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
if interp !== nothing
cis.inferred = true
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any,
ocm, revs...)
else
oc_nargs = Int64(meth_nargs)
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, meth_nargs, isva, lno, cis), revs...)
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
end
end

Expand Down Expand Up @@ -111,7 +112,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
opaque_ci
end

nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
nfixedargs = Int(meth.nargs)
meth.isva && (nfixedargs -= 1)

extra_slotnames = Symbol[]
extra_slotflags = UInt8[]
Expand Down Expand Up @@ -157,7 +159,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
# TODO: Can we use the same method for each 2nd order of the transform
# (except the last and the first one)
for nc = 1:2:n_closures
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:(meth.nargs)]
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:Int(meth.nargs)]
accums = Union{Nothing, Vector{Any}}[nothing for i = 1:length(ir.stmts)]

opaque_ci = opaque_cis[nc]
Expand Down Expand Up @@ -375,7 +377,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
lno = LineNumberNode(1, :none)
next_oc = insert_node_rev!(make_opaque_closure(interp, Tuple{(Any for i = 1:nargs+1)...},
cname(nc+1, N, meth.name),
meth.nargs,
Int(meth.nargs),
meth.isva,
lno,
opaque_cis[nc+1],
Expand Down
5 changes: 3 additions & 2 deletions src/debugutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ using Core.Compiler: AbstractInterpreter, CodeInstance, MethodInstance, WorldVie
using InteractiveUtils

function infer_function(interp, tt)
world = Core.Compiler.get_world_counter()

# Find all methods that are applicable to these types
mthds = _methods_by_ftype(tt, -1, typemax(UInt))
mthds = _methods_by_ftype(tt, -1, world)
if mthds === false || length(mthds) != 1
error("Unable to find single applicable method for $tt")
end
Expand All @@ -17,7 +19,6 @@ function infer_function(interp, tt)
result = Core.Compiler.InferenceResult(mi)

# Create an InferenceState to begin inference, give it a world that is always newest
world = Core.Compiler.get_world_counter()
frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)

# Run type inference on this frame. Because the interpreter is embedded
Expand Down
46 changes: 18 additions & 28 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,29 @@ struct ∂⃖recurse{N}; end

include("recurse.jl")

function perform_optic_transform(@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
function perform_optic_transform(world::UInt, source::LineNumberNode,
@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
@assert N >= 1

# Check if we have an rrule for this function
mthds = Base._methods_by_ftype(Tuple{args...}, -1, typemax(UInt))
if length(mthds) != 1
return :(throw(MethodError(ff, args)))
sig = Tuple{args...}
mthds = Base._methods_by_ftype(sig, -1, world)
if mthds === nothing || length(mthds) != 1
# Core.println("[perform_optic_transform] ", sig, " => ", mthds)
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
return stub(world, source, :(throw(MethodError(ff, args))))
end
match = mthds[1]
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi)
ci = Core.Compiler.retrieve_code_info(mi, world)

ci′ = copy(ci)
ci′.edges = MethodInstance[mi]

r = transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
if isa(r, Expr)
return r
end
ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N)

ci′.ssavaluetypes = length(ci′.code)
ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)]
ci′.method_for_inference_limit_heuristics = match.method
ci′
return ci′
end

# This relies on PartialStruct to infer well
Expand Down Expand Up @@ -402,18 +400,10 @@ end
ChainRulesCore.backing(::ZeroTangent) = ZeroTangent()
ChainRulesCore.backing(::NoTangent) = NoTangent()

function reload()
Core.eval(Diffractor, quote
function (ff::∂⃖recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:perform_optic_transform,
Core.svec(:ff, :args),
Core.svec())))
end
end)
let ex = :(function (ff::∂⃖recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, perform_optic_transform))
end)
push!(GENERATORS, ex)
Core.eval(@__MODULE__, ex)
end
reload()
20 changes: 10 additions & 10 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,12 @@ function sptypes(sparams)
end
end

function transform!(ci, meth, nargs, sparams, N)
function diffract_transform!(ci, meth, nargs, sparams, N)
code = ci.code
cfg = compute_basic_blocks(code)
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
slottypes = ci.slottypes === nothing ? nothing : UInt8[(Any for i = 1:2)..., ci.slottypes...]
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...]

meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
Expand All @@ -273,19 +273,19 @@ function transform!(ci, meth, nargs, sparams, N)
domtree = construct_domtree(ir.cfg.blocks)
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
ir = construct_ssa!(ci, ir, domtree, defuse_insts, Any[Any for i = 1:length(slotnames)], Core.Compiler.OptimizerLattice())
ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
ir = compact!(ir)

nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
meth.isva || @assert nfixedargs == nargs+1

ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)

Core.Compiler.replace_code_newstyle!(ci, ir, nargs+1)
Core.Compiler.replace_code_newstyle!(ci, ir)

ci.ssavaluetypes = length(ci.code)
ci.slotnames = slotnames
ci.slotflags = slotflags
ci.slottypes = slottypes
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
ci.method_for_inference_limit_heuristics = meth

ci
return ci
end
32 changes: 16 additions & 16 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,24 @@ function ∂☆nomethd(@nospecialize(args))
throw(MethodError(primal(args[1]), map(primal, Base.tail(args))))
end

function perform_fwd_transform(@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
function perform_fwd_transform(world::UInt, source::LineNumberNode,
@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
if all(x->x <: ZeroBundle, args)
return :(∂☆passthrough(args))
end

# Check if we have an rrule for this function
sig = Tuple{map(π, args)...}
mthds = Base._methods_by_ftype(sig, -1, typemax(UInt))
if length(mthds) != 1
return :(∂☆nomethd(args))
mthds = Base._methods_by_ftype(sig, -1, world)
if mthds === nothing || length(mthds) != 1
# Core.println("[perform_fwd_transform] ", sig, " => ", mthds)
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
return stub(world, source, :(∂☆nomethd(args)))
end
match = mthds[1]
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi)
ci = Core.Compiler.retrieve_code_info(mi, world)

ci′ = copy(ci)
ci′.edges = MethodInstance[mi]
Expand All @@ -59,16 +62,13 @@ function perform_fwd_transform(@nospecialize(ff::Type{∂☆recurse{N}}), @nospe
ci′.slotflags = slotflags
ci′.slottypes = slottypes

ci′
return ci′
end

@eval function (ff::∂☆recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:perform_fwd_transform,
Core.svec(:ff, :args),
Core.svec())))
let ex = :(function (ff::∂☆recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, perform_fwd_transform))
end)
push!(GENERATORS, ex)
Core.eval(@__MODULE__, ex)
end
95 changes: 47 additions & 48 deletions src/stage1/termination.jl
Original file line number Diff line number Diff line change
@@ -1,73 +1,72 @@
if :recursion_relation in fieldnames(Method)

first(methods(Diffractor.∂⃖recurse{1}())).recursion_relation = function(method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
if method2 !== nothing && isdefined(method2, :recursion_relation)
# TODO: What if method2 is itself a generated function.
return method2.recursion_relation(method2, nothing, wrapped_parent_sig, wrapped_new_sig)
end
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
if method2 !== nothing && isdefined(method2, :recursion_relation)
# TODO: What if method2 is itself a generated function.
return method2.recursion_relation(method2, nothing, wrapped_parent_sig, wrapped_new_sig)
end
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
end

first(methods(PrimeDerivativeBack(sin))).recursion_relation = function(method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
end

which(Tuple{∂⃖{N}, T, Vararg{Any}} where {T,N}).recursion_relation = function(_, _, parent_sig, new_sig)
# Any actual recursion will always be caught be one of the functions we're
# recursing into.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
# Any actual recursion will always be caught be one of the functions we're
# recursing into.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
end

which(Tuple{∂⃖{N}, ∂⃖{1}, Vararg{Any}} where {N}).recursion_relation = function(_, _, parent_sig, new_sig)
# Allowed as long as both parent and new sig have concrete integers. In that
# case, actual recursion will be caught elsewhere.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
# Allowed as long as both parent and new sig have concrete integers. In that
# case, actual recursion will be caught elsewhere.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
end

for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
@show (parent_sig, new_sig)
return false
end
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
if parent_order > child_order
return true
end
Core.Compiler.@show (parent_sig, new_sig)
return false
end
end

for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆internal{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
end

for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
end

end
Loading