From 72dc8e7ab0919e0d89ea091bce5fb42801c14b38 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Tue, 27 Jun 2023 15:24:19 -0700 Subject: [PATCH] Improve type inference of non-`const` GlobalRef's GlobalRefs often do not get their types refined, so we should be careful to re-use types that we have when possible. This PR moves non-`const` GlobalRef handling out of `maparg()` so that we can keep the original type. --- src/codegen/forward_demand.jl | 16 ++++++++++------ test/forward_diff_no_inf.jl | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index acd00bd8..5f675508 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -256,11 +256,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; # TODO: Should we remember whether the callbacks wanted the arg? return transform!(ir, arg, order) elseif isa(arg, GlobalRef) - if !isconst(arg) - # Non-const GlabalRefs need to need to be accessed as seperate statements - arg = insert_node!(ir, ssa, NewInstruction(arg, Any)) - end - + @assert isconst(arg) return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any)) elseif isa(arg, QuoteNode) return ZeroBundle{order}(arg.value) @@ -302,7 +298,15 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; # TODO: New PiNode that discriminates based on primal? inst[:inst] = maparg(stmt.val, SSAValue(ssa), order) inst[:type] = Any - elseif isa(stmt, GlobalRef) || isa(stmt, SSAValue) || isa(stmt, QuoteNode) + elseif isa(stmt, GlobalRef) + if !isconst(stmt) + # Non-const GlobalRefs need to need to be accessed as seperate statements + stmt = insert_node!(ir, ssa, NewInstruction(inst)) + end + + inst[:inst] = Expr(:call, ZeroBundle{order}, stmt) + inst[:type] = Any + elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode) inst[:inst] = maparg(stmt, SSAValue(ssa), order) inst[:type] = Any elseif isa(stmt, Expr) || isa(stmt, PhiNode) || isa(stmt, PhiCNode) || diff --git a/test/forward_diff_no_inf.jl b/test/forward_diff_no_inf.jl index 572bf6a0..4e8bce57 100644 --- a/test/forward_diff_no_inf.jl +++ b/test/forward_diff_no_inf.jl @@ -28,6 +28,11 @@ module forward_diff_no_inf Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!) ir2 = Core.Compiler.compact!(ir) Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst globals correctly + # Assert that the reference to `Main._coeff` is properly typed + stmt_idx = findfirst(stmt -> isa(stmt[:inst], GlobalRef), collect(ir2.stmts)) + stmt = ir2.stmts[stmt_idx] + @test stmt[:inst].name == :_coeff + @test stmt[:type] == Float64 f = Core.OpaqueClosure(ir2; do_compile=false) @test f(3.5) == 28.0 end