Skip to content

Commit

Permalink
Merge pull request #127 from JuliaDiff/ox/comb_nightlyfix
Browse files Browse the repository at this point in the history
Combined nightly fixes
  • Loading branch information
oscardssmith authored Mar 28, 2023
2 parents 3a13dab + 496d032 commit 84ea690
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 115 deletions.
9 changes: 9 additions & 0 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ export ∂⃖, gradient

const CC = Core.Compiler

const GENERATORS = Expr[]

include("runtime.jl")
include("interface.jl")
include("utils.jl")
Expand Down Expand Up @@ -37,4 +39,11 @@ include("debugutils.jl")

include("stage1/termination.jl")

function reload()
@info "reloading Diffractor generators"
for generator in GENERATORS
Core.eval(@__MODULE__, generator)
end
end

end
12 changes: 7 additions & 5 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Codegen shared by both stage1 and stage2

function make_opaque_closure(interp, typ, name, meth_nargs, isva, lno, cis, revs...)
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
if interp !== nothing
cis.inferred = true
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any,
ocm, revs...)
else
oc_nargs = Int64(meth_nargs)
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, meth_nargs, isva, lno, cis), revs...)
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
end
end

Expand Down Expand Up @@ -111,7 +112,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
opaque_ci
end

nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
nfixedargs = Int(meth.nargs)
meth.isva && (nfixedargs -= 1)

extra_slotnames = Symbol[]
extra_slotflags = UInt8[]
Expand Down Expand Up @@ -157,7 +159,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
# TODO: Can we use the same method for each 2nd order of the transform
# (except the last and the first one)
for nc = 1:2:n_closures
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:(meth.nargs)]
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:Int(meth.nargs)]
accums = Union{Nothing, Vector{Any}}[nothing for i = 1:length(ir.stmts)]

opaque_ci = opaque_cis[nc]
Expand Down Expand Up @@ -375,7 +377,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
lno = LineNumberNode(1, :none)
next_oc = insert_node_rev!(make_opaque_closure(interp, Tuple{(Any for i = 1:nargs+1)...},
cname(nc+1, N, meth.name),
meth.nargs,
Int(meth.nargs),
meth.isva,
lno,
opaque_cis[nc+1],
Expand Down
5 changes: 3 additions & 2 deletions src/debugutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ using Core.Compiler: AbstractInterpreter, CodeInstance, MethodInstance, WorldVie
using InteractiveUtils

function infer_function(interp, tt)
world = Core.Compiler.get_world_counter()

# Find all methods that are applicable to these types
mthds = _methods_by_ftype(tt, -1, typemax(UInt))
mthds = _methods_by_ftype(tt, -1, world)
if mthds === false || length(mthds) != 1
error("Unable to find single applicable method for $tt")
end
Expand All @@ -17,7 +19,6 @@ function infer_function(interp, tt)
result = Core.Compiler.InferenceResult(mi)

# Create an InferenceState to begin inference, give it a world that is always newest
world = Core.Compiler.get_world_counter()
frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)

# Run type inference on this frame. Because the interpreter is embedded
Expand Down
46 changes: 18 additions & 28 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,29 @@ struct ∂⃖recurse{N}; end

include("recurse.jl")

function perform_optic_transform(@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
function perform_optic_transform(world::UInt, source::LineNumberNode,
@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
@assert N >= 1

# Check if we have an rrule for this function
mthds = Base._methods_by_ftype(Tuple{args...}, -1, typemax(UInt))
if length(mthds) != 1
return :(throw(MethodError(ff, args)))
sig = Tuple{args...}
mthds = Base._methods_by_ftype(sig, -1, world)
if mthds === nothing || length(mthds) != 1
# Core.println("[perform_optic_transform] ", sig, " => ", mthds)
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
return stub(world, source, :(throw(MethodError(ff, args))))
end
match = mthds[1]
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi)
ci = Core.Compiler.retrieve_code_info(mi, world)

ci′ = copy(ci)
ci′.edges = MethodInstance[mi]

r = transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
if isa(r, Expr)
return r
end
ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N)

ci′.ssavaluetypes = length(ci′.code)
ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)]
ci′.method_for_inference_limit_heuristics = match.method
ci′
return ci′
end

# This relies on PartialStruct to infer well
Expand Down Expand Up @@ -402,18 +400,10 @@ end
ChainRulesCore.backing(::ZeroTangent) = ZeroTangent()
ChainRulesCore.backing(::NoTangent) = NoTangent()

function reload()
Core.eval(Diffractor, quote
function (ff::∂⃖recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:perform_optic_transform,
Core.svec(:ff, :args),
Core.svec())))
end
end)
let ex = :(function (ff::∂⃖recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, perform_optic_transform))
end)
push!(GENERATORS, ex)
Core.eval(@__MODULE__, ex)
end
reload()
20 changes: 10 additions & 10 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,12 @@ function sptypes(sparams)
end
end

function transform!(ci, meth, nargs, sparams, N)
function diffract_transform!(ci, meth, nargs, sparams, N)
code = ci.code
cfg = compute_basic_blocks(code)
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
slottypes = ci.slottypes === nothing ? nothing : UInt8[(Any for i = 1:2)..., ci.slottypes...]
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...]

meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
Expand All @@ -273,19 +273,19 @@ function transform!(ci, meth, nargs, sparams, N)
domtree = construct_domtree(ir.cfg.blocks)
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
ir = construct_ssa!(ci, ir, domtree, defuse_insts, Any[Any for i = 1:length(slotnames)], Core.Compiler.OptimizerLattice())
ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
ir = compact!(ir)

nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
meth.isva || @assert nfixedargs == nargs+1

ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)

Core.Compiler.replace_code_newstyle!(ci, ir, nargs+1)
Core.Compiler.replace_code_newstyle!(ci, ir)

ci.ssavaluetypes = length(ci.code)
ci.slotnames = slotnames
ci.slotflags = slotflags
ci.slottypes = slottypes
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
ci.method_for_inference_limit_heuristics = meth

ci
return ci
end
32 changes: 16 additions & 16 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,24 @@ function ∂☆nomethd(@nospecialize(args))
throw(MethodError(primal(args[1]), map(primal, Base.tail(args))))
end

function perform_fwd_transform(@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
function perform_fwd_transform(world::UInt, source::LineNumberNode,
@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
if all(x->x <: ZeroBundle, args)
return :(∂☆passthrough(args))
end

# Check if we have an rrule for this function
sig = Tuple{map(π, args)...}
mthds = Base._methods_by_ftype(sig, -1, typemax(UInt))
if length(mthds) != 1
return :(∂☆nomethd(args))
mthds = Base._methods_by_ftype(sig, -1, world)
if mthds === nothing || length(mthds) != 1
# Core.println("[perform_fwd_transform] ", sig, " => ", mthds)
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
return stub(world, source, :(∂☆nomethd(args)))
end
match = mthds[1]
match = only(mthds)::Core.MethodMatch

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi)
ci = Core.Compiler.retrieve_code_info(mi, world)

ci′ = copy(ci)
ci′.edges = MethodInstance[mi]
Expand All @@ -59,16 +62,13 @@ function perform_fwd_transform(@nospecialize(ff::Type{∂☆recurse{N}}), @nospe
ci′.slotflags = slotflags
ci′.slottypes = slottypes

ci′
return ci′
end

@eval function (ff::∂☆recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:perform_fwd_transform,
Core.svec(:ff, :args),
Core.svec())))
let ex = :(function (ff::∂☆recurse)(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, perform_fwd_transform))
end)
push!(GENERATORS, ex)
Core.eval(@__MODULE__, ex)
end
95 changes: 47 additions & 48 deletions src/stage1/termination.jl
Original file line number Diff line number Diff line change
@@ -1,73 +1,72 @@
if :recursion_relation in fieldnames(Method)

first(methods(Diffractor.∂⃖recurse{1}())).recursion_relation = function(method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
if method2 !== nothing && isdefined(method2, :recursion_relation)
# TODO: What if method2 is itself a generated function.
return method2.recursion_relation(method2, nothing, wrapped_parent_sig, wrapped_new_sig)
end
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
if method2 !== nothing && isdefined(method2, :recursion_relation)
# TODO: What if method2 is itself a generated function.
return method2.recursion_relation(method2, nothing, wrapped_parent_sig, wrapped_new_sig)
end
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
end

first(methods(PrimeDerivativeBack(sin))).recursion_relation = function(method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
end

which(Tuple{∂⃖{N}, T, Vararg{Any}} where {T,N}).recursion_relation = function(_, _, parent_sig, new_sig)
# Any actual recursion will always be caught be one of the functions we're
# recursing into.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
# Any actual recursion will always be caught be one of the functions we're
# recursing into.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
end

which(Tuple{∂⃖{N}, ∂⃖{1}, Vararg{Any}} where {N}).recursion_relation = function(_, _, parent_sig, new_sig)
# Allowed as long as both parent and new sig have concrete integers. In that
# case, actual recursion will be caught elsewhere.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
# Allowed as long as both parent and new sig have concrete integers. In that
# case, actual recursion will be caught elsewhere.
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
end

for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
#@Core.Main.Base.show (parent_order, child_order)
if parent_order > child_order
return true
end
@show (parent_sig, new_sig)
return false
end
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
# Recursion from a higher to a lower order is always allowed
parent_order = parent_sig.parameters[1].parameters[1]
child_order = new_sig.parameters[1].parameters[1]
if parent_order > child_order
return true
end
Core.Compiler.@show (parent_sig, new_sig)
return false
end
end

for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆internal{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
end

for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
return true
end
end

end
Loading

0 comments on commit 84ea690

Please sign in to comment.