-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsimulatable.jl
176 lines (153 loc) · 6.35 KB
/
simulatable.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
    SimulatableContext

Tape context carried while rewriting a traced tape into a bit-level simulator.

# Fields
- `popmap`: maps an original tape variable to the variable that holds the bits
  popped (or simulated) for it, so later calls can reuse those bits instead of
  inserting another `getbit` call.
- `opmap`: maps a call signature `(fn, args...)` to the variable bound to the
  already-emitted call with that signature, so identical calls can be
  deduplicated.
"""
struct SimulatableContext
    # map from original variable to popped variable
    popmap::Dict{Ghost.Variable, Ghost.Variable}
    # map from signatures to variable bindings
    opmap::Dict{Tuple, Ghost.Variable}
end

# Empty context. Both dictionaries are created with exactly the declared field
# types, so no implicit conversion (and throwaway allocation) happens in the
# default constructor.
SimulatableContext() = SimulatableContext(Dict{Ghost.Variable, Ghost.Variable}(),
                                          Dict{Tuple, Ghost.Variable}())
"""
    Ghost.rebind_context!(tape::Ghost.Tape{SimulatableContext}, subs::Dict)

Apply the variable substitutions in `subs` to the tape's `SimulatableContext`.

- Every key and value of `popmap` is replaced by `get(subs, v, v)`.
- In `opmap`, every argument position of each signature key and every bound
  variable is replaced the same way.

The function slot of a signature (its first element) is never substituted.
Variables without an entry in `subs` are kept as they are.
"""
function Ghost.rebind_context!(tape::Ghost.Tape{SimulatableContext}, subs::Dict)
    ctx = tape.c
    # replacement for `v`, or `v` itself when no substitution is registered
    substitute(v) = get(subs, v, v)

    replace!(ctx.popmap) do (src, dst)
        return substitute(src) => substitute(dst)
    end

    replace!(ctx.opmap) do (signature, bound)
        fn = first(signature)
        rebound_args = map(substitute, Base.tail(signature))
        return (fn, rebound_args...) => substitute(bound)
    end
end
"""
    _update_ctx!(ctx::SimulatableContext, vars)

Record bookkeeping for the ops `vars` emitted by one `_simtransform` step. The
first element is the original op; the rest are ops inserted for it.

When the original op is simulatable and at least one op was inserted:
- for a `Ghost.Input`, the input is mapped in `popmap` to the op right after it
  (the `getbit` call inserted by `_simtransform`);
- for a `Ghost.Call`, the call is mapped to the second-to-last op (the
  simulator call, which precedes the trailing `setbit!`). For a broadcasted
  call, the materialized result `vars[2]` is mapped to that same op as well.

Every `Ghost.Call` in `vars` is then registered in `opmap` under its signature
`(fn, args...)` so an identical later call can be deduplicated.

An empty `vars` (what `_simtransform` returns for a deduplicated call) is a
no-op rather than a `BoundsError`.
"""
function _update_ctx!(ctx::SimulatableContext, vars)
    # nothing was emitted (e.g. a deduplicated call): nothing to record
    isempty(vars) && return nothing

    original = vars[1]
    if _issimulatable(original) && length(vars) > 1
        if original isa Ghost.Input
            ctx.popmap[Ghost.Variable(original)] = Ghost.Variable(vars[2])
        elseif original isa Ghost.Call
            # the simulator output sits just before the trailing `setbit!` call
            ctx.popmap[Ghost.Variable(original)] = Ghost.Variable(vars[end - 1])
            if _isbcast(original.fn)
                # users of the materialized result can reuse the simulated bits too
                ctx.popmap[Ghost.Variable(vars[2])] = Ghost.Variable(vars[end - 1])
            end
        end
    end

    for var in vars
        var isa Ghost.Call || continue
        ctx.opmap[(var.fn, var.args...)] = Ghost.Variable(var)
    end
    return nothing
end
"""
    _popcalls(ctx, args)

For each argument whose tape value is an `SBitstreamLike`, produce a variable
holding its popped bits:
- reuse the entry in `ctx.popmap` when one exists;
- otherwise create a new `getbit` call.

All other arguments pass through unchanged.

Returns `(calls, new_args)`. `calls` are the newly created `getbit` calls that
must be inserted into the tape. `new_args` mirrors `args` position by position.
"""
function _popcalls(ctx, args)
    calls = []
    new_args = []
    for a in args
        if !(_gettapeval(a) isa SBitstreamLike)
            push!(new_args, a)
        elseif haskey(ctx.popmap, a)
            # bits for this argument were already popped earlier on the tape
            push!(new_args, ctx.popmap[a])
        else
            pop_call = Ghost.mkcall(getbit, a)
            push!(calls, pop_call)
            push!(new_args, Ghost.Variable(pop_call))
        end
    end
    return calls, new_args
end
"""
    _wrap_bcast_bits(popbits)

Shield scalar bits from broadcasting. For every element of `popbits` whose tape
value is an `SBit`, create a `Ref(bit)` call and use its variable in place of
the bit. Other elements are kept as-is.

Returns `(wrap_calls, wrap_bits)`: the new `Ref` calls, and the
position-by-position replacement for `popbits`.
"""
function _wrap_bcast_bits(popbits)
    wrap_calls = []
    wrap_bits = []
    foreach(popbits) do b
        if _gettapeval(b) isa SBit
            ref_call = Ghost.mkcall(Ref, b)
            push!(wrap_calls, ref_call)
            b = Ghost.Variable(ref_call)
        end
        push!(wrap_bits, b)
    end
    return wrap_calls, wrap_bits
end
"""
    _unbroadcasted_transform(ctx, call, sim)

Rewrite a non-broadcasted simulatable `call`:
1. keep `call`;
2. pop bits from its `SBitstreamLike` arguments;
3. run the simulator `sim` on those bits;
4. push the result onto the call's output with `setbit!`.

Returns the op list that replaces `call` (with `call` first) and `1`.
"""
function _unbroadcasted_transform(ctx, call, sim)
    popcalls, popbits = _popcalls(ctx, call.args)
    simcall = Ghost.mkcall(sim, popbits...)
    out = Ghost.Variable(call)
    pushcall = Ghost.mkcall(setbit!, out, Ghost.Variable(simcall))
    return [call, popcalls..., simcall, pushcall], 1
end
"""
    _broadcasted_transform(ctx, call, sim)

Rewrite a broadcasted simulatable `call` using a single simulator `sim`:
1. materialize the broadcast result;
2. pop bits from the broadcast arguments;
3. run `sim` on those bits;
4. push the result onto the materialized output with `setbit!`.

Returns the op list that replaces `call` (with `call` first) and `1`.
"""
function _broadcasted_transform(ctx, call, sim)
    # the first argument of the broadcasted call is the broadcast function itself
    bcast_args = call.args[2:end]
    popcalls, popbits = _popcalls(ctx, bcast_args)
    # bits are pushed onto the materialized result of the broadcast
    matcall = Ghost.mkcall(Base.materialize, Ghost.Variable(call))
    simcall = Ghost.mkcall(sim, popbits...)
    pushcall = Ghost.mkcall(setbit!, Ghost.Variable(matcall), Ghost.Variable(simcall))
    return [call, matcall, popcalls..., simcall, pushcall], 1
end
"""
    _broadcasted_transform(ctx, call, sims::AbstractArray)

Rewrite a broadcasted simulatable `call` that has one simulator per output
element (`sims`):
1. materialize the broadcast result;
2. pop bits from the broadcast arguments and wrap scalar bits in `Ref`;
3. apply each simulator to its bits elementwise (`broadcasted` + `materialize`);
4. push the resulting bits onto the materialized output with `setbit!`.

Returns the op list that replaces `call` (with `call` first) and `1`.
"""
function _broadcasted_transform(ctx, call, sims::AbstractArray)
    # the first argument of the broadcasted call is the broadcast function itself
    bcast_args = call.args[2:end]
    popcalls, popbits = _popcalls(ctx, bcast_args)
    matcall = Ghost.mkcall(Base.materialize, Ghost.Variable(call))
    # scalar bits are wrapped so they broadcast as scalars against `sims`
    wrapcalls, wrapbits = _wrap_bcast_bits(popbits)
    apply_sim = (f, a...) -> f(a...)
    simcall = Ghost.mkcall(Base.broadcasted, apply_sim, sims, wrapbits...)
    simbits = Ghost.mkcall(Base.materialize, Ghost.Variable(simcall))
    pushcall = Ghost.mkcall(setbit!, Ghost.Variable(matcall), Ghost.Variable(simbits))
    return [call, matcall, popcalls..., wrapcalls..., simcall, simbits, pushcall], 1
end
"""
    _handle_bcast_and_transform(ctx, call, sim)

Dispatch to `_broadcasted_transform` when `call.fn` is a broadcast (per
`_isbcast`), otherwise to `_unbroadcasted_transform`.
"""
function _handle_bcast_and_transform(ctx, call, sim)
    if _isbcast(call.fn)
        return _broadcasted_transform(ctx, call, sim)
    end
    return _unbroadcasted_transform(ctx, call, sim)
end
"""
    _simtransform(ctx, input::Ghost.Input)

Keep a tape input. If its value is an `SBitstreamLike`, also emit a `getbit`
call on it right after the input.

Returns the op list and `1`.
"""
function _simtransform(ctx, input::Ghost.Input)
    invar = Ghost.Variable(input)
    if !(_gettapeval(invar) isa SBitstreamLike)
        return [input], 1
    end
    return [input, Ghost.mkcall(getbit, Ghost.Variable(input))], 1
end
"""
    _simtransform(ctx, call::Ghost.Call)

Transform one tape call for simulation.

- If a call with the same signature `(fn, args...)` was already emitted (it is
  in `ctx.opmap`), emit nothing and return the id of the existing variable so
  this call is rebound to it.
- If the call's type signature is not a simulatable primitive, keep it
  unchanged.
- Otherwise, look up the simulator with `getsimulator` and rewrite the call
  (broadcasted or not) via `_handle_bcast_and_transform`.
"""
function _simtransform(ctx, call::Ghost.Call)
    key = (call.fn, call.args...)
    # already transformed: delete this call and rebind to the existing one
    existing = get(ctx.opmap, key, nothing)
    isnothing(existing) || return [], existing.id

    # tape values of the arguments, shared by the signature and simulator lookup
    argvals = _gettapeval.(call.args)
    sig = Ghost.get_type_parameters(Ghost.call_signature(call.fn, argvals...))
    # skip calls that are not simulatable primitives
    is_simulatable_primitive(sig...) || return [call], 1

    sim = getsimulator(call.fn, argvals...)
    return _handle_bcast_and_transform(ctx, call, sim)
end
"""
    getbit(x)

Return the popped bit(s) of `x` when it is an `SBitstreamLike` (by
broadcasting `pop!` over it). Any other value is returned unchanged.
"""
function getbit(x)
    return x
end

function getbit(x::SBitstreamLike)
    return pop!.(x)
end
"""
    setbit!(x, bits)

Push `bits` onto the bitstream `x`. For an array of bitstreams, push
elementwise by broadcasting `push!`.
"""
function setbit!(x::SBitstream, bit)
    return push!(x, bit)
end

function setbit!(x::AbstractArray{<:SBitstream}, bits)
    return push!.(x, bits)
end
"""
    is_simulatable_primitive(sig...)

Decide whether a call with type signature `sig` is treated as a primitive
during simulation. This currently forwards to `is_trace_primitive`.
"""
function is_simulatable_primitive(sig...)
    return is_trace_primitive(sig...)
end
"""
    simulator(f, args...)

Build a `Ghost.Tape` for `f(args...)` rewritten for bit-level simulation.

- If the call `f(args...)` is itself a simulatable primitive, a one-call tape is
  built by hand. For callable structs, the call goes through the tape input
  bound to `f`.
- Otherwise, `f` is traced, with `is_simulatable_primitive` deciding which calls
  are primitives.

The tape is then rewritten with `_squash_binary_vararg`, and then with
`_simtransform`, using `_update_ctx!` to maintain the `SimulatableContext`.
"""
function simulator(f, args...)
    sig = Ghost.get_type_parameters(Ghost.call_signature(f, args...))
    tape = if is_simulatable_primitive(sig...)
        # `f` is a primitive itself, so build the tape manually instead of tracing
        manual = Ghost.Tape(SimulatableContext())
        inputs = Ghost.inputs!(manual, f, args...)
        # callable structs are invoked through their own tape input (`inputs[1]`)
        op = _isstruct(f) ? Ghost.mkcall(inputs...) : Ghost.mkcall(f, inputs[2:end]...)
        manual.result = push!(manual, op)
        manual
    else
        trace(f, args...;
              isprimitive = is_simulatable_primitive,
              ctx = SimulatableContext())
    end
    transform!(_squash_binary_vararg, tape)
    transform!(_simtransform, _update_ctx!, tape)
    return tape
end
"""
    simulatable(f, args...)

Compile the simulation tape built by `simulator(f, args...)` with
`Ghost.compile`.
"""
function simulatable(f, args...)
    tape = simulator(f, args...)
    return Ghost.compile(tape)
end