-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrace.jl
87 lines (73 loc) · 2.94 KB
/
trace.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
    _typeof(x)

Return `x` itself when `x` is a type, otherwise `typeof(x)`.
"""
_typeof(x) = typeof(x)
_typeof(t::Type) = t
"""
    is_trace_primitive(argtypes...)

Return whether a call whose signature has the type parameters `argtypes`
(the function's type first, then the argument types) should be treated as a
primitive. `trace` calls it as `isprimitive(Ghost.get_type_parameters(sig)...)`
inside the `is_primitive` predicate it hands to `Ghost.trace`.

The fallback returns `false`; add methods to register further primitives.
"""
is_trace_primitive(_...) = false

## Default primitives
# `LinearAlgebra.norm(v::AbstractVector, p)` with any second argument type.
is_trace_primitive(::Type{typeof(LinearAlgebra.norm)},
                   ::Type{<:AbstractVector},
                   ::Type) = true
"""
    trace(f, args...; isprimitive = is_trace_primitive, submodules = [],
          ctx = Dict{Any, Any}())

Trace the call `f(args...)` with `Ghost.trace` and return the resulting tape.
The traced call's return value is discarded.

A call signature `sig` is treated as primitive when any of these holds:
`Ghost.is_primitive(sig)` is true, `isprimitive(Ghost.get_type_parameters(sig)...)`
is true, or `sig` matches one of `submodules`.

# Keywords
- `isprimitive`: predicate over the signature's type parameters.
- `submodules`: callables (or types) whose calls are primitive regardless of
  their argument types (registered as `Tuple{_typeof(m), Vararg}`).
- `ctx`: context object forwarded to `Ghost.trace`.
"""
function trace(f, args...; isprimitive = is_trace_primitive, submodules = [],
               ctx = Dict{Any, Any}())
    # One catch-all signature per submodule: any argument list matches.
    submodule_sigs = [Tuple{_typeof(m), Vararg} => true for m in submodules]
    resolver = Ghost.FunctionResolver{Bool}(submodule_sigs)

    function stop_at(sig)
        Ghost.is_primitive(sig) && return true
        isprimitive(Ghost.get_type_parameters(sig)...) && return true
        return sig ∈ resolver
    end

    _, tape = Ghost.trace(f, args...; is_primitive = stop_at, ctx = ctx)
    return tape
end
"""
    transform!(f, fctx, tape::Ghost.Tape)

Rewrite `tape` in place, one entry at a time, and return it.

`f(tape.c, entry)` is called for each `Ghost.Call` and `Ghost.Input` entry. It
must return a tuple `(new_entries, rebind_to)`:

- `Ghost.Call`:
  - If `new_entries` is empty, the call is removed with
    `deleteat!(tape, pos; rebind_to = rebind_to)`.
  - Otherwise it is replaced with
    `replace!(tape, pos => new_entries; rebind_to = rebind_to)`.
- `Ghost.Input`:
  - `new_entries` must not be empty.
  - The input entry itself stays on the tape unchanged.
  - Only `new_entries[2:end]` are inserted directly after it, and `rebind_to`
    is not used.
  - NOTE(review): `new_entries[1]` is presumably expected to be the input
    itself — confirm.

After a replacement or insertion, `fctx(tape.c, vars)` is called with the tape
ops occupying the rewritten span. For an input, that is the input followed by
the inserted ops. `fctx` is not called when a call is deleted.

Entries produced by `f` are skipped rather than revisited. Entries of any other
kind are left untouched.

# Throws
- `ErrorException` if `f` returns no entries for a `Ghost.Input`.
"""
function transform!(f, fctx, tape::Ghost.Tape)
    local entry, rebind_to
    # `idx` is the iteration state; this loop treats `idx - 1` as the tape
    # position of the current `entry` (so `idx` is the position after it).
    itr = iterate(tape)
    @debug tape
    while !isnothing(itr)
        entry, idx = itr
        @debug entry
        if entry isa Ghost.Call
            new_entries, rebind_to = f(tape.c, entry)
            if isempty(new_entries)
                deleteat!(tape, idx - 1; rebind_to = rebind_to)
            else
                replace!(tape, idx - 1 => new_entries; rebind_to = rebind_to)
                # the replacement ops now sit at positions idx-1 .. idx-2+length(new_entries)
                vars = [tape[Ghost.Variable(idx - 1 + i)] for i in 0:(length(new_entries) - 1)]
                fctx(tape.c, vars)
            end
            # Skip past the spliced-in ops. After a deletion this steps back by one,
            # presumably because the following entry has shifted into position idx - 1.
            idx += length(new_entries) - 1
        elseif entry isa Ghost.Input
            new_entries, rebind_to = f(tape.c, entry)
            isempty(new_entries) && error("Cannot delete Ghost.Input")
            # keep the input where it is and insert the remaining entries right after it
            new_vars = insert!(tape, idx, new_entries[2:end]...)
            vars = Ghost.AbstractOp[entry]
            append!(vars, [tape[v] for v in new_vars])
            fctx(tape.c, vars)
            # skip the length(new_entries) - 1 freshly inserted ops
            idx += length(new_entries) - 1
        end
        @debug tape
        itr = iterate(tape, idx)
    end
    return tape
end
"""
    transform!(f, tape::Ghost.Tape)

Rewrite `tape` in place with `f`, like the three-argument method. No context
callback is run after each rewrite.
"""
function transform!(f, tape::Ghost.Tape)
    ignore_rewrite(_...) = nothing
    return transform!(f, ignore_rewrite, tape)
end
"""
    squashable(fn)

Return `true` if `fn` is one of `+`, `-`, `*`, `/`, `÷`. Only calls to these
functions are split into chains of binary calls by `_squash_binary_vararg`.
Return `false` for anything else.
"""
squashable(fn) = false
squashable(::Union{typeof(+), typeof(-), typeof(*), typeof(/), typeof(÷)}) = true
"""
    _squash_binary_vararg(ctx, entry)

Transformation for `transform!` that rewrites a variadic call to a squashable
operator (see `squashable`), e.g. `+(a, b, c)`, into a left-folded chain of
binary calls `t1 = a + b; t2 = t1 + c`.

Returns `(new_entries, rebind_to)`, where `rebind_to` is the index of the last
entry of `new_entries`. That entry holds the final result of the chain.

The following are returned unchanged as `([entry], 1)`:
- non-`Ghost.Call` entries;
- calls to non-squashable functions;
- calls with fewer than two arguments (e.g. unary `-x`).

`ctx` is not used.
"""
_squash_binary_vararg(ctx, entry) = [entry], 1
function _squash_binary_vararg(ctx, call::Ghost.Call)
    # With fewer than two arguments there is nothing to fold. `accumulate` would
    # return an empty list, which `transform!` treats as deleting the call. A
    # zero-argument call would also fail on `call.args[1]`.
    (squashable(call.fn) && length(call.args) >= 2) || return [call], 1
    new_calls = accumulate(call.args[2:end]; init = call.args[1]) do x, y
        # after the first step the accumulator is the previous binary call;
        # reference its result through a variable
        xvar = (x isa Ghost.Call) ? Ghost.Variable(x) : x
        yvar = (y isa Ghost.Call) ? Ghost.Variable(y) : y
        return Ghost.mkcall(call.fn, xvar, yvar)
    end
    return new_calls, length(new_calls)
end
# TODO: do we need macros for defining primitive operations?
# macro operator(ex)
# @capture(ex, fdef_ => optype_(opargs__)) ||
# error("Cannot parse expression $ex in @simulatable. Expected: f(arg1, arg2, ...) => Operator(oparg1, oparg2, ...)")
# @capture(fdef, f_(args__)) || error("Cannot parse expression $f in @simulatable. Expected: f(arg1, arg2, ...)")
# argsyms = map(x -> splitarg(x)[1], args)
# argtypes = map(x -> splitarg(x)[2], args)
# return quote
# end
# end