Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 28, 2023
1 parent 4438c41 commit 640b315
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 22 deletions.
9 changes: 5 additions & 4 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Codegen shared by both stage1 and stage2

function make_opaque_closure(interp, typ, name, meth_nargs, isva, lno, cis, revs...)
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
if interp !== nothing
cis.inferred = true
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
Expand Down Expand Up @@ -112,7 +112,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
opaque_ci
end

nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
nfixedargs = Int(meth.nargs)
meth.isva && (nfixedargs -= 1)

extra_slotnames = Symbol[]
extra_slotflags = UInt8[]
Expand Down Expand Up @@ -158,7 +159,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
# TODO: Can we use the same method for each 2nd order of the transform
# (except the last and the first one)
for nc = 1:2:n_closures
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:(meth.nargs)]
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:Int(meth.nargs)]
accums = Union{Nothing, Vector{Any}}[nothing for i = 1:length(ir.stmts)]

opaque_ci = opaque_cis[nc]
Expand Down Expand Up @@ -376,7 +377,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
lno = LineNumberNode(1, :none)
next_oc = insert_node_rev!(make_opaque_closure(interp, Tuple{(Any for i = 1:nargs+1)...},
cname(nc+1, N, meth.name),
meth.nargs,
Int(meth.nargs),
meth.isva,
lno,
opaque_cis[nc+1],
Expand Down
10 changes: 2 additions & 8 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,9 @@ function perform_optic_transform(@nospecialize(ff::Type{∂⃖recurse{N}}), @nos
ci′ = copy(ci)
ci′.edges = MethodInstance[mi]

r = transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
if isa(r, Expr)
return r
end
ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N)

ci′.ssavaluetypes = length(ci′.code)
ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)]
ci′.method_for_inference_limit_heuristics = match.method
ci′
return ci′
end

# This relies on PartialStruct to infer well
Expand Down
20 changes: 10 additions & 10 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,12 @@ function sptypes(sparams)
end
end

function transform!(ci, meth, nargs, sparams, N)
function diffract_transform!(ci, meth, nargs, sparams, N)
code = ci.code
cfg = compute_basic_blocks(code)
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
slottypes = ci.slottypes === nothing ? nothing : UInt8[(Any for i = 1:2)..., ci.slottypes...]
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...]

meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
Expand All @@ -273,19 +273,19 @@ function transform!(ci, meth, nargs, sparams, N)
domtree = construct_domtree(ir.cfg.blocks)
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
ir = construct_ssa!(ci, ir, domtree, defuse_insts, Any[Any for i = 1:length(slotnames)], Core.Compiler.OptimizerLattice())
ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
ir = compact!(ir)

nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
meth.isva || @assert nfixedargs == nargs+1

ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)

Core.Compiler.replace_code_newstyle!(ci, ir, nargs+1)
Core.Compiler.replace_code_newstyle!(ci, ir)

ci.ssavaluetypes = length(ci.code)
ci.slotnames = slotnames
ci.slotflags = slotflags
ci.slottypes = slottypes
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
ci.method_for_inference_limit_heuristics = meth

ci
return ci
end

0 comments on commit 640b315

Please sign in to comment.