From b584bef01ad14fe1d3ba567d2b67de941fe5e57b Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Tue, 27 Dec 2022 10:55:04 +0000 Subject: [PATCH] Import stage2 demand-driven forward mode This imports a ripped-out version of the demand-driven AD code from CedarSim and hooks it into the ADInterpreter. Starting with this code has the advantage that it is working-ish, but the disadvantage that it doesn't really interact with the rest of Diffractor yet. Still, I think it's a reasonable point to start. I'm doing this as a separate commit, so we can keep better track of the subsequent refactoring. --- Manifest.toml | 357 ++++++++++++++++++++++++++++++++ src/Diffractor.jl | 6 + src/analysis/forward.jl | 30 +++ src/codegen/forward.jl | 89 ++++++++ src/codegen/forward_demand.jl | 154 ++++++++++++++ src/stage1/compiler_utils.jl | 5 + src/stage1/recurse_fwd.jl | 89 -------- src/stage2/abstractinterpret.jl | 99 +++++---- src/stage2/forward.jl | 38 ++++ src/stage2/interpreter.jl | 17 +- src/stage2/lattice.jl | 17 +- test/runtests.jl | 2 + test/stage2_fwd.jl | 7 + 13 files changed, 769 insertions(+), 141 deletions(-) create mode 100644 Manifest.toml create mode 100644 src/analysis/forward.jl create mode 100644 src/codegen/forward.jl create mode 100644 src/codegen/forward_demand.jl create mode 100644 src/stage2/forward.jl create mode 100644 test/stage2_fwd.jl diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 00000000..174e208d --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,357 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.0-DEV" +manifest_format = "2.0" +project_hash = "f6209327c3bf3625f9bce3952e420a70ebd8af82" + +[[deps.AbstractTrees]] +git-tree-sha1 = "52b3b436f8f73133d7bc3a6c71ee7ed6ab2ab754" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.3" + +[[deps.Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.4.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] +git-tree-sha1 = "99a39b0f807499510e2ea14b0eef8422082aa372" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.46.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.15.6" + +[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.4" + +[[deps.CodeTracking]] +deps = ["InteractiveUtils", "UUIDs"] +git-tree-sha1 = "3bf60ba2fae10e10f70d53c070424e40a820dac2" +uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" +version = "1.1.2" + +[[deps.Combinatorics]] +git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.2" + +[[deps.Compat]] +deps = ["Dates", "LinearAlgebra", "UUIDs"] +git-tree-sha1 = "00a2cccc7f098ff3b66806862d275ca3db9e6e5a" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.5.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.1+0" + +[[deps.Cthulhu]] +deps = ["CodeTracking", "FoldingTrees", "InteractiveUtils", "Preferences", "REPL", "UUIDs", "Unicode"] +git-tree-sha1 = "e31248559b7861339d09086e7bc5597898ae7a47" +uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f" +version = "2.7.6" + +[[deps.DataAPI]] +git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.14.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.13" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FoldingTrees]] +deps = ["AbstractTrees", "REPL"] +git-tree-sha1 = "d94efd85f2fe192cdf664aa8b7c431592faed59e" +uuid = "1eca21be-9b9b-4ed8-839a-6d8ae26b1781" +version = "1.2.1" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "6872f5ec8fd1a38880f027a26739d42dcda6691f" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.2" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.8" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "946607f84feb96220f480e0422d3484c49c00239" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.19" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.1.0" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.10.11" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OffsetArrays]] +deps = ["Adapt"] +git-tree-sha1 = "f71d8950b724e9ff6110fc948dff5a329f901d64" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.12.8" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.21+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.3.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.1.0" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "ffc098086f35909741f71ce21d03dadf0d2bfa76" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.5.11" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.0" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.5.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.21" + +[[deps.StructArrays]] +deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] +git-tree-sha1 = "b03a3b745aa49b566f128977a7dd1be8711c5e71" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.14" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+0" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.10.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+0" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.2.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/src/Diffractor.jl b/src/Diffractor.jl index ec70ddd2..f8f9a32a 100644 --- a/src/Diffractor.jl +++ b/src/Diffractor.jl @@ -4,6 +4,8 @@ using StructArrays export ∂⃖, gradient +const CC = Core.Compiler + include("runtime.jl") include("interface.jl") include("utils.jl") @@ -21,7 +23,11 @@ include("stage2/interpreter.jl") include("stage2/lattice.jl") include("stage2/abstractinterpret.jl") include("stage2/tfuncs.jl") +include("stage2/forward.jl") +include("codegen/forward.jl") +include("analysis/forward.jl") +include("codegen/forward_demand.jl") include("codegen/reverse.jl") include("extra_rules.jl") diff --git a/src/analysis/forward.jl b/src/analysis/forward.jl new file mode 100644 index 00000000..489e9318 --- /dev/null +++ b/src/analysis/forward.jl @@ -0,0 +1,30 @@ +using Core.Compiler: StmtInfo, ArgInfo, CallMeta + +function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::InferenceState, primal_call::CallMeta) + if f === ChainRulesCore.frule + # TODO: Currently, we don't have any termination analysis for the non-stratified + # forward analysis, so bail out here. + return nothing + end + + nargs = length(arginfo.argtypes)-1 + + # Here we simply check for the frule existance - we don't want to do a full + # inference with specialized argtypes and everything since the problem is + # likely sparse and we only need to do a full inference on a few calls. + # Thus, here we pick `Any` for the tangent types rather than trying to + # discover what they are. frules should be written in such a way that + # whether or not they return `nothing`, only depends on the non-tangent arguments + frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}] + frule_argtypes = append!(frule_preargtypes, arginfo.argtypes) + frule_arginfo = ArgInfo(nothing, frule_argtypes) + # turn off frule analysis in the frule to avoid cycling + interp′ = disable_forward(interp) + frule_call = CC.abstract_call_known(interp′, ChainRulesCore.frule, frule_arginfo, StmtInfo(true), sv, #=max_methods=#-1) + if frule_call.rt !== Const(nothing) + return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call)) + end + + return nothing +end diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl new file mode 100644 index 00000000..21c4f8ae --- /dev/null +++ b/src/codegen/forward.jl @@ -0,0 +1,89 @@ +function transform_fwd!(ci, meth, nargs, sparams, N) + new_code = Any[] + new_codelocs = Any[] + ssa_mapping = Int[] + loc_mapping = Int[] + + function emit!(stmt) + (isexpr(stmt, :call) || isexpr(stmt, :(=)) || isexpr(stmt, :new)) || return stmt + push!(new_code, stmt) + push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end]) + SSAValue(length(new_code)) + end + + function mapstmt!(stmt) + if isexpr(stmt, :(=)) + return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2]))) + elseif isexpr(stmt, :call) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, ∂☆{N}(), args...) + elseif isexpr(stmt, :new) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, ∂☆new{N}(), args...) + elseif isexpr(stmt, :splatnew) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, Core._apply_iterate, FwdIterate(ZeroBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) + elseif isa(stmt, SSAValue) + return SSAValue(ssa_mapping[stmt.id]) + elseif isa(stmt, Core.SlotNumber) + return SlotNumber(2 + stmt.id) + elseif isa(stmt, Argument) + return SlotNumber(2 + stmt.n) + elseif isa(stmt, NewvarNode) + return NewvarNode(SlotNumber(2 + stmt.slot.id)) + elseif isa(stmt, ReturnNode) + return ReturnNode(emit!(mapstmt!(stmt.val))) + elseif isa(stmt, GotoNode) + return stmt + elseif isa(stmt, GotoIfNot) + return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest) + elseif isexpr(stmt, :static_parameter) + return ZeroBundle{N}(sparams[stmt.args[1]]) + elseif isexpr(stmt, :foreigncall) + return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?") + elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) + # Can't trust that meta annotations are still valid in the AD'd + # version. + return nothing + else + return Expr(:call, ZeroBundle{N}, stmt) + end + end + + for i = 1:meth.nargs + if meth.isva && i == meth.nargs + args = map(i:(nargs+1)) do j + emit!(Expr(:call, getfield, SlotNumber(2), j)) + end + emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...))) + else + emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i))) + end + end + + for (stmt, codeloc) in zip(ci.code, ci.codelocs) + push!(loc_mapping, length(new_code)+1) + push!(new_codelocs, codeloc) + push!(new_code, mapstmt!(stmt)) + push!(ssa_mapping, length(new_code)) + end + + # Rewrite control flow + for (i, stmt) in enumerate(new_code) + if isa(stmt, GotoNode) + new_code[i] = GotoNode(loc_mapping[stmt.label]) + elseif isa(stmt, GotoIfNot) + new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest]) + end + end + + ci.code = new_code + ci.codelocs = new_codelocs + ci +end diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl new file mode 100644 index 00000000..031115ae --- /dev/null +++ b/src/codegen/forward_demand.jl @@ -0,0 +1,154 @@ +using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode, + is_known_call, argextype, postdominates + +function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelides::Vector{SSAValue}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}()) + Δs = SSAValue[] + rets = findall(@nospecialize(x)->isa(x, ReturnNode) && isdefined(x, :val), ir.stmts.inst) + postdomtree = construct_postdomtree(ir.cfg.blocks) + for ssa in pantelides + Δssa = forward_diff!(ir, interp, irsv, ssa; custom_diff!, diff_cache) + Δblock = block_for_inst(ir, Δssa.id) + for idx in rets + retblock = block_for_inst(ir, idx) + if !postdominates(postdomtree, retblock, Δblock) + error("Stmt %$ssa does not dominate all return blocks $(rets)") + end + end + push!(Δs, Δssa) + end + return (ir, Δs) +end + +function diff_unassigned_variable!(ir, ssa) + return insert_node!(ir, ssa, NewInstruction( + Expr(:call, GlobalRef(Intrinsics, :state_ddt), ssa), Float64), #=attach_after=#true) +end + +function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue; custom_diff!, diff_cache) + if haskey(diff_cache, ssa) + return diff_cache[ssa] + end + inst = ir[ssa] + stmt = inst[:inst] + if isa(stmt, SSAValue) + return forward_diff!(ir, interp, irsv, stmt; custom_diff!, diff_cache) + end + Δssa = forward_diff_uncached!(ir, interp, irsv, ssa, inst; custom_diff!, diff_cache) + @assert Δssa !== nothing + if isa(Δssa, SSAValue) + diff_cache[ssa] = Δssa + end + return Δssa +end +forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}; custom_diff!, diff_cache) = zero(val) +forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg); custom_diff!, diff_cache) = ChainRulesCore.NoTangent() +function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument; custom_diff!, diff_cache) + recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache) + val = custom_diff!(ir, SSAValue(0), arg, recurse) + if val !== nothing + return val + end + return ChainRulesCore.NoTangent() +end + +function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction; custom_diff!, diff_cache) + stmt = inst[:inst] + recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache) + if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing + return val + elseif isa(stmt, PiNode) + return recurse(stmt.val) + elseif isa(stmt, PhiNode) + Δphi = PhiNode(copy(stmt.edges), similar(stmt.values)) + T = Union{} + for i in 1:length(stmt.values) + isassigned(stmt.values, i) || continue + Δphi.values[i] = recurse(stmt.values[i]) + T = CC.tmerge(CC.optimizer_lattice(interp), T, argextype(Δphi.values[i], ir)) + end + return insert_node!(ir, ssa, NewInstruction(Δphi, T), true) + elseif is_known_call(stmt, tuple, ir) + Δtpl = Expr(:call, GlobalRef(Core, :tuple)) + for arg in stmt.args[2:end] + arg = recurse(arg) + push!(Δtpl.args, arg) + end + argtypes = Any[argextype(arg, ir) for arg in Δtpl.args[2:end]] + tup_typ = CC.tuple_tfunc(CC.typeinf_lattice(interp), argtypes) + Δssa = insert_node!(ir, ssa, NewInstruction(Δtpl, tup_typ), true) + return Δssa + elseif isexpr(stmt, :new) + Δtpl = Expr(:call, GlobalRef(Core, :tuple)) + for arg in stmt.args[2:end] + push!(Δtpl.args, recurse(arg)) + end + argtypes = Any[argextype(arg, ir) for arg in Δtpl.args[2:end]] + tup_typ = CC.tuple_tfunc(CC.typeinf_lattice(interp), argtypes) + Δbacking = insert_node!(ir, ssa, NewInstruction(Δtpl, tup_typ)) + newT = argextype(stmt.args[1], ir) + @assert isa(newT, Const) + tup_typ_typ = Core.Compiler.typeof_tfunc(tup_typ) + if !(newT.val <: Tuple) + tup_typ_typ = Core.Compiler.apply_type_tfunc(Const(NamedTuple{fieldnames(newT.val)}), tup_typ_typ) + Δbacking = insert_node!(ir, ssa, NewInstruction(Expr(:splatnew, widenconst(tup_typ), Δbacking), tup_typ_typ.val)) + end + tangentT = Core.Compiler.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val + Δtangent = insert_node!(ir, ssa, NewInstruction(Expr(:new, tangentT, Δbacking), tangentT)) + return Δtangent + else # general frule handling + info = inst[:info] + if !isa(info, FRuleCallInfo) + @show info + @show inst[:inst] + display(ir) + error() + end + if isexpr(stmt, :invoke) + args = stmt.args[2:end] + else + args = copy(stmt.args) + end + Δtpl = Expr(:call, GlobalRef(Core, :tuple), nothing) + for arg in args[2:end] + push!(Δtpl.args, recurse(arg)) + end + argtypes = Any[argextype(arg, ir) for arg in Δtpl.args[2:end]] + tup_T = CC.tuple_tfunc(CC.typeinf_lattice(interp), argtypes) + + Δ = insert_node!(ir, ssa, NewInstruction( + Δtpl, tup_T)) + + # Now that we know the arguments, do a proper typeinf for this particular callsite + new_spec_types = Tuple{typeof(ChainRulesCore.frule), widenconst(tup_T), (widenconst(argextype(arg, ir)) for arg in args)...} + new_match = Base._which(new_spec_types) + + # Now do proper type inference with the known arguments + interp′ = disable_forward(interp) + new_frame = Core.Compiler.typeinf_frame(interp′, new_match.method, new_match.spec_types, new_match.sparams, #=run_optimizer=#true) + + # Create :invoke expression for the newly inferred frule + frule_mi = CC.EscapeAnalysis.analyze_match(new_match, length(args)+2) + frule_call = Expr(:invoke, frule_mi, GlobalRef(ChainRulesCore, :frule), Δ, args...) + frule_flag = CC.flags_for_effects(new_frame.ipo_effects) + + result = new_frame.result.result + if isa(result, Const) && result.val === nothing + error("DAECompiler thought we had an frule at inference time, but no frule found") + end + + # Incidence analysis through the rt call + # TODO: frule_mi is wrong here, should be the mi of the caller + frule_rt = info.frule_call.rt + improve_frule_rt = CC.concrete_eval_invoke(interp, frule_call, frule_mi, irsv) + if improve_frule_rt !== nothing + frule_rt = improve_frule_rt + end + frule_result = insert_node!(ir, ssa, NewInstruction( + frule_call, frule_rt, info.frule_call.info, inst[:line], + frule_flag)) + ir[ssa][:inst] = Expr(:call, GlobalRef(Core, :getfield), frule_result, 1) + Δssa = insert_node!(ir, ssa, NewInstruction( + Expr(:call, GlobalRef(Core, :getfield), frule_result, 2), CC.getfield_tfunc(CC.typeinf_lattice(interp), frule_rt, Const(2))), #=attach_after=#true) + return Δssa + end +end diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 0670adfe..b237632b 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -7,6 +7,11 @@ function Base.push!(cfg::CFG, bb::BasicBlock) push!(cfg.index, bb.stmts.start) end +Base.getindex(ir::IRCode, ssa::SSAValue) = + Core.Compiler.getindex(ir, ssa) + +Base.copy(ir::IRCode) = Core.Compiler.copy(ir) + function Core.Compiler.NewInstruction(node) Core.Compiler.NewInstruction(node, Any) end diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index c9542acb..317770eb 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -18,95 +18,6 @@ struct ∂☆new{N}; end # the transform, so this can happen - allow it for now. (this::∂☆new{N})(B::ATB{N, <:Type}, args::ATB{N}...) where {N} = this(primal(B), args...) -function transform_fwd!(ci, meth, nargs, sparams, N) - new_code = Any[] - new_codelocs = Any[] - ssa_mapping = Int[] - loc_mapping = Int[] - - function emit!(stmt) - (isexpr(stmt, :call) || isexpr(stmt, :(=)) || isexpr(stmt, :new)) || return stmt - push!(new_code, stmt) - push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end]) - SSAValue(length(new_code)) - end - - function mapstmt!(stmt) - if isexpr(stmt, :(=)) - return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2]))) - elseif isexpr(stmt, :call) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, ∂☆{N}(), args...) - elseif isexpr(stmt, :new) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, ∂☆new{N}(), args...) - elseif isexpr(stmt, :splatnew) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, Core._apply_iterate, FwdIterate(ZeroBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) - elseif isa(stmt, SSAValue) - return SSAValue(ssa_mapping[stmt.id]) - elseif isa(stmt, Core.SlotNumber) - return SlotNumber(2 + stmt.id) - elseif isa(stmt, Argument) - return SlotNumber(2 + stmt.n) - elseif isa(stmt, NewvarNode) - return NewvarNode(SlotNumber(2 + stmt.slot.id)) - elseif isa(stmt, ReturnNode) - return ReturnNode(emit!(mapstmt!(stmt.val))) - elseif isa(stmt, GotoNode) - return stmt - elseif isa(stmt, GotoIfNot) - return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest) - elseif isexpr(stmt, :static_parameter) - return ZeroBundle{N}(sparams[stmt.args[1]]) - elseif isexpr(stmt, :foreigncall) - return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?") - elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) - # Can't trust that meta annotations are still valid in the AD'd - # version. - return nothing - else - return Expr(:call, ZeroBundle{N}, stmt) - end - end - - for i = 1:meth.nargs - if meth.isva && i == meth.nargs - args = map(i:(nargs+1)) do j - emit!(Expr(:call, getfield, SlotNumber(2), j)) - end - emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...))) - else - emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i))) - end - end - - for (stmt, codeloc) in zip(ci.code, ci.codelocs) - push!(loc_mapping, length(new_code)+1) - push!(new_codelocs, codeloc) - push!(new_code, mapstmt!(stmt)) - push!(ssa_mapping, length(new_code)) - end - - # Rewrite control flow - for (i, stmt) in enumerate(new_code) - if isa(stmt, GotoNode) - new_code[i] = GotoNode(loc_mapping[stmt.label]) - elseif isa(stmt, GotoIfNot) - new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest]) - end - end - - ci.code = new_code - ci.codelocs = new_codelocs - ci -end π(::Type{<:AbstractTangentBundle{N, B}} where N) where {B} = B diff --git a/src/stage2/abstractinterpret.jl b/src/stage2/abstractinterpret.jl index 3e850ae9..c655a19b 100644 --- a/src/stage2/abstractinterpret.jl +++ b/src/stage2/abstractinterpret.jl @@ -2,66 +2,75 @@ import Core.Compiler: abstract_call_gf_by_type, abstract_call using Core.Compiler: Const, isconstType, argtypes_to_type, tuple_tfunc, Const, getfield_tfunc, _methods_by_ftype, VarTable, cache_lookup, nfields_tfunc, ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method, - PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc + PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc, + StmtInfo using Core: PartialStruct using Base.Meta function Core.Compiler.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f), - arginfo::ArgInfo, @nospecialize(atype), sv::InferenceState, max_methods::Int) + arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), sv::InferenceState, max_methods::Int) (;argtypes) = arginfo - if f isa ∂⃖recurse - inner_argtypes = argtypes[2:end] - ft = inner_argtypes[1] - f = singleton_type(ft) - rinterp = raise_level(interp) - call = abstract_call_gf_by_type(rinterp, f, ArgInfo(nothing, inner_argtypes), argtypes_to_type(inner_argtypes), sv, max_methods) - if isa(call.info, MethodMatchInfo) - if length(call.info.results.matches) == 0 - @show inner_argtypes - error() + if interp.backward + if f isa ∂⃖recurse + inner_argtypes = argtypes[2:end] + ft = inner_argtypes[1] + f = singleton_type(ft) + rinterp = raise_level(interp) + call = abstract_call_gf_by_type(rinterp, f, ArgInfo(nothing, inner_argtypes), argtypes_to_type(inner_argtypes), sv, max_methods) + if isa(call.info, MethodMatchInfo) + if length(call.info.results.matches) == 0 + @show inner_argtypes + error() + end + mi = specialize_method(call.info.results.matches[1], preexisting=true) + ci = get(rinterp.unopt[rinterp.current_level], mi, nothing) + clos = AbstractCompClosure(rinterp.current_level, 1, call.info, ci.stmt_info) + clos = Core.PartialOpaque(Core.OpaqueClosure{<:Tuple, <:Any}, nothing, sv.linfo, clos) + elseif isa(call.info, RRuleInfo) + if rinterp.current_level == 1 + clos = getfield_tfunc(call.info.rrule_rt, Const(2)) + else + name = call.info.info.results.matches[1].method.sig.parameters[2].name.mt.name + clos = PrimClosure(name, rinterp.current_level - 1, 1, getfield_tfunc(call.info.rrule_rt, Const(2)), call.info, nothing) + end end - mi = specialize_method(call.info.results.matches[1], preexisting=true) - ci = get(rinterp.unopt[rinterp.current_level], mi, nothing) - clos = AbstractCompClosure(rinterp.current_level, 1, call.info, ci.stmt_info) - clos = Core.PartialOpaque(Core.OpaqueClosure{<:Tuple, <:Any}, nothing, sv.linfo, clos) - elseif isa(call.info, RRuleInfo) - if rinterp.current_level == 1 - clos = getfield_tfunc(call.info.rrule_rt, Const(2)) + # TODO: use abstract_new instead, when it exists + obtype = instanceof_tfunc(apply_type_tfunc(Const(OpticBundle), typeof_tfunc(call.rt)))[1] + if obtype isa DataType + rt2 = PartialStruct(obtype, Any[call.rt, clos]) else - name = call.info.info.results.matches[1].method.sig.parameters[2].name.mt.name - clos = PrimClosure(name, rinterp.current_level - 1, 1, getfield_tfunc(call.info.rrule_rt, Const(2)), call.info, nothing) + rt2 = obtype end + return CallMeta(rt2, call.effects, RecurseInfo(call.info)) end - # TODO: use abstract_new instead, when it exists - obtype = instanceof_tfunc(apply_type_tfunc(Const(OpticBundle), typeof_tfunc(call.rt)))[1] - if obtype isa DataType - rt2 = PartialStruct(obtype, Any[call.rt, clos]) - else - rt2 = obtype + + # Check if there is a rrule for this function + if interp.current_level != 0 && f !== ChainRules.rrule + rrule_argtypes = Any[Const(ChainRules.rrule); argtypes] + rrule_atype = argtypes_to_type(rrule_argtypes) + # In general we want the forward type of an rrule'd function to match + # what the function itself would have returned, but let's support this + # not being the case. + if f == accum + error() + end + call = abstract_call_gf_by_type(lower_level(interp), ChainRules.rrule, ArgInfo(nothing, rrule_argtypes), rrule_atype, sv, -1) + if call.rt != Const(nothing) + return CallMeta(getfield_tfunc(call.rt, Const(1)), call.effects, RRuleInfo(call.rt, call.info)) + end end - return CallMeta(rt2, call.effects, RecurseInfo(call.info)) end - # Check if there is a rrule for this function - if interp.current_level != 0 && f !== ChainRules.rrule - rrule_argtypes = Any[Const(ChainRules.rrule); argtypes] - rrule_atype = argtypes_to_type(rrule_argtypes) - # In general we want the forward type of an rrule'd function to match - # what the function itself would have returned, but let's support this - # not being the case. - if f == accum - error() - end - call = abstract_call_gf_by_type(lower_level(interp), ChainRules.rrule, ArgInfo(nothing, rrule_argtypes), rrule_atype, sv, -1) - if call.rt != Const(nothing) - return CallMeta(getfield_tfunc(call.rt, Const(1)), call.effects, RRuleInfo(call.rt, call.info)) + ret = @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, + arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::InferenceState, max_methods::Int) + + if interp.forward + r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret) + if r !== nothing + return r end end - ret = invoke(abstract_call_gf_by_type, - Tuple{AbstractInterpreter, Any, ArgInfo, Any, InferenceState, Int}, - interp, f, arginfo, atype, sv, max_methods) - return ret end diff --git a/src/stage2/forward.jl b/src/stage2/forward.jl new file mode 100644 index 00000000..4ec67f0d --- /dev/null +++ b/src/stage2/forward.jl @@ -0,0 +1,38 @@ +using .CC: compact! + +# Engineering entry point for the 2nd-order forward AD functionality. This is +# unlikely to be the actual interface. For now, it is used for testing. +function dontuse_nth_order_forward_stage2(tt::Type) + interp = ADInterpreter(; forward=true, backward=false) + match = Base._which(tt) + frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true) + + ir = copy((interp.opt[0][frame.linfo].inferred).ir::IRCode) + + # Find all Return Nodes + vals = SSAValue[] + for i = 1:length(ir.stmts) + if isa(ir[SSAValue(i)][:inst], ReturnNode) + push!(vals, SSAValue(i)) + end + end + + function custom_diff!(ir, ssa, stmt, recurse) + if isa(stmt, ReturnNode) + r = recurse(stmt.val) + ir[ssa][:inst] = ReturnNode(r) + return ssa + elseif isa(stmt, Argument) + return 1.0 + end + return nothing + end + + irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs]) + forward_diff!(ir, interp, irsv, vals; custom_diff!) + + display(ir) + + ir = compact!(ir) + return OpaqueClosure(ir) +end diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index e444f829..29182107 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -35,6 +35,10 @@ using Core.Compiler: AbstractInterpreter, NativeInterpreter, InferenceState, InferenceResult, CodeInstance, WorldRange struct ADInterpreter <: AbstractInterpreter + # Modes settings + forward::Bool + backward::Bool + # This cache is stratified by AD nesting level. Depending on the # nesting level of the derivative, The AD primitives may behave # differently. @@ -54,6 +58,8 @@ change_level(interp::ADInterpreter, new_level::Int) = ADInterpreter(interp.opt, raise_level(interp::ADInterpreter) = change_level(interp, interp.current_level + 1) lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - 1) +disable_forward(interp::ADInterpreter) = ADInterpreter(false, interp.backward, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.msgs) + Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) = (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi] Cthulhu.AbstractCursor(interp::ADInterpreter, mi::MethodInstance) = ADCursor(0, mi, false) @@ -213,7 +219,7 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info), argtyp interp, info, argtypes, rt, optimize) end -ADInterpreter() = ADInterpreter( +ADInterpreter(;forward = false, backward=true) = ADInterpreter(forward, backward, OffsetVector([Dict{MethodInstance, CodeInstance}(), Dict{MethodInstance, CodeInstance}()], 0:1), OffsetVector([Dict{MethodInstance, Cthulhu.InferredSource}(), Dict{MethodInstance, Cthulhu.InferredSource}()], 0:1), OffsetVector([Dict{MethodInstance, CodeInstance}(), Dict{MethodInstance, CodeInstance}()], 0:1), @@ -231,8 +237,8 @@ Core.Compiler.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native Core.Compiler.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter) # No need to do any locking since we're not putting our results into the runtime cache -lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing -unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing +Core.Compiler.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing +Core.Compiler.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing struct CodeInfoView d::Dict{MethodInstance, Any} @@ -280,9 +286,8 @@ function Core.Compiler.finish(state::InferenceState, interp::ADInterpreter) end function Core.Compiler.transform_result_for_cache(interp::ADInterpreter, - linfo::MethodInstance, valid_worlds::WorldRange, @nospecialize(inferred_result), - ipo_effects::Core.Compiler.Effects) - return Cthulhu.create_cthulhu_source(inferred_result, ipo_effects) + linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) + return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) end #@static if isdefined(Compiler, :is_stmt_inline) diff --git a/src/stage2/lattice.jl b/src/stage2/lattice.jl index 4f2153ab..a6c11c08 100644 --- a/src/stage2/lattice.jl +++ b/src/stage2/lattice.jl @@ -1,4 +1,4 @@ -using Core.Compiler: CodeInfo +using Core.Compiler: CodeInfo, CallInfo, CallMeta import Core.Compiler: widenconst struct CompClosure; opaque; end # TODO: Is this a YAKC? @@ -62,5 +62,20 @@ struct ReifyInfo info end +# Forward mode info +struct FRuleCallInfo <: CallInfo + info::CallInfo + frule_call::CallMeta + FRuleCallInfo(@nospecialize(info::CallInfo), frule_call::CallMeta) = new(info, frule_call) +end +CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info) +CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx) +CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx) + +function Base.show(io::IO, info::FRuleCallInfo) + print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")") +end + + # Helpers tuple_type_fields(rt) = isa(rt, PartialStruct) ? rt.fields : widenconst(rt).parameters diff --git a/test/runtests.jl b/test/runtests.jl index eca14ae5..03a30776 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,8 @@ const bwd = Diffractor.PrimeDerivativeBack @testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run +include("stage2_fwd.jl") + # Unit tests function tup2(f) a, b = ∂⃖{2}()(f, 1) diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl new file mode 100644 index 00000000..6aebc5de --- /dev/null +++ b/test/stage2_fwd.jl @@ -0,0 +1,7 @@ +module stage2_fwd + using Diffractor, Test + mysin(x) = sin(x) + let sin′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}) + @test sin′(1.0) == cos(1.0) + end +end