Skip to content

Commit

Permalink
Improve type inference of non-const GlobalRef's
Browse files Browse the repository at this point in the history
GlobalRefs often do not get their types refined, so we should be careful
to re-use types that we have when possible.  This PR moves non-`const`
GlobalRef handling out of `maparg()` so that we can keep the original
type.
  • Loading branch information
staticfloat committed Jun 28, 2023
1 parent b23337a commit 72dc8e7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
# TODO: Should we remember whether the callbacks wanted the arg?
return transform!(ir, arg, order)
elseif isa(arg, GlobalRef)
if !isconst(arg)
# Non-const GlabalRefs need to need to be accessed as seperate statements
arg = insert_node!(ir, ssa, NewInstruction(arg, Any))
end

@assert isconst(arg)
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
elseif isa(arg, QuoteNode)
return ZeroBundle{order}(arg.value)
Expand Down Expand Up @@ -302,7 +298,15 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
# TODO: New PiNode that discriminates based on primal?
inst[:inst] = maparg(stmt.val, SSAValue(ssa), order)
inst[:type] = Any
elseif isa(stmt, GlobalRef) || isa(stmt, SSAValue) || isa(stmt, QuoteNode)
elseif isa(stmt, GlobalRef)
if !isconst(stmt)
# Non-const GlobalRefs need to need to be accessed as seperate statements
stmt = insert_node!(ir, ssa, NewInstruction(inst))
end

inst[:inst] = Expr(:call, ZeroBundle{order}, stmt)
inst[:type] = Any
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
inst[:type] = Any
elseif isa(stmt, Expr) || isa(stmt, PhiNode) || isa(stmt, PhiCNode) ||
Expand Down
5 changes: 5 additions & 0 deletions test/forward_diff_no_inf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ module forward_diff_no_inf
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
ir2 = Core.Compiler.compact!(ir)
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst globals correctly
# Assert that the reference to `Main._coeff` is properly typed
stmt_idx = findfirst(stmt -> isa(stmt[:inst], GlobalRef), collect(ir2.stmts))
stmt = ir2.stmts[stmt_idx]
@test stmt[:inst].name == :_coeff
@test stmt[:type] == Float64
f = Core.OpaqueClosure(ir2; do_compile=false)
@test f(3.5) == 28.0
end
Expand Down

0 comments on commit 72dc8e7

Please sign in to comment.