From 887526f4507feae6ef044ec793a1602a811c6454 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 13 Feb 2024 17:23:44 -0500 Subject: [PATCH] fixup! add MultilineFusion example --- test/compiler/plugins.jl | 24 +++++++++++++++++++ .../MultilineFusion/src/MultilineFusion.jl | 4 ++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/test/compiler/plugins.jl b/test/compiler/plugins.jl index c15d24d953638..e375f27fa3eaa 100644 --- a/test/compiler/plugins.jl +++ b/test/compiler/plugins.jl @@ -28,5 +28,29 @@ let tr = trace(fib, 2) @test length(tr.children) == 6 end + +using MultilineFusion + +# XXX: should these be in `MultilineFusion/test/runtests.jl`? + +function multiline(A, B) + C = A .* B + D = C .+ A + end + +let A = ones(3,3) + B = ones(3) + @test (@inferred multiline_fusion(multiline, A, B))::Matrix{Float64} == multiline(A, B) +end + +let ir, _ = only(Base.code_ircode(multiline, (Matrix{Float64}, Vector{Float64}), optimize_until="compact 1")) + @test length(ir.stmts) == 5 + @test ir.stmts[2][:stmt].args[2] == GlobalRef(Base, materialize) +end + +let ir, _ = only(Base.code_ircode(multiline, (Matrix{Float64}, Vector{Float64}), optimize_until="compact 1", interp=MultilineFusion.MLFInterp())) + @test length(ir.stmts) == 4 +end + empty!(Base.LOAD_PATH) append!(Base.LOAD_PATH, original_load_path) \ No newline at end of file diff --git a/test/compiler/plugins/MultilineFusion/src/MultilineFusion.jl b/test/compiler/plugins/MultilineFusion/src/MultilineFusion.jl index dc758b289d9c9..6f88b63878e58 100644 --- a/test/compiler/plugins/MultilineFusion/src/MultilineFusion.jl +++ b/test/compiler/plugins/MultilineFusion/src/MultilineFusion.jl @@ -3,7 +3,7 @@ module MultilineFusion export multiline_fusion function multiline_fusion(f, args...) - Base.invoke_within(MultilineFusion(), f, args...) + Base.invoke_within(MLFCompiler(), f, args...) end const CC = Core.Compiler @@ -25,7 +25,7 @@ struct MLFInterp <: CC.AbstractInterpreter inf_params::CC.InferenceParams opt_params::CC.OptimizationParams inf_cache::Vector{CC.InferenceResult} - function MLFInterp(compiler::MLFCompiler; + function MLFInterp(compiler::MLFCompiler = MLFCompiler(); world::UInt = Base.get_world_counter(), inf_params::CC.InferenceParams = CC.InferenceParams(), opt_params::CC.OptimizationParams = CC.OptimizationParams(),