Gradient actually works
malmaud committed Oct 23, 2017
1 parent 9fa6338 · commit bf84a38
Showing 2 changed files with 16 additions and 13 deletions.
src/core.jl: 15 changes (8 additions, 7 deletions)
@@ -333,7 +333,7 @@ function get_collection end
     return g.collections[name]
 end
 
-const DEBUG_EXTEND_GRAPH = false
+const DEBUG_EXTEND_GRAPH = true
 
 function Base.convert(::Type{tensorflow.NodeDef}, proto::Vector{UInt8})
     b = IOBuffer()
@@ -351,14 +351,15 @@ end
     ph_names = Set{String}()
     for node_bytes in node_defs
         node_def = convert(tensorflow.NodeDef, node_bytes)
+        @show node_def
         if isnull(get_node_by_name(graph, node_def.name))
             # First try to directly add this node to the graph
-            try
-                new_op = Operation(node_def)
-                continue
-            catch err
-                DEBUG_EXTEND_GRAPH && warn(err)
-            end
+            # try
+            #     new_op = Operation(node_def)
+            #     continue
+            # catch err
+            #     DEBUG_EXTEND_GRAPH && warn(err)
+            # end
 
             # If that doesn't work (for example, the node has a
             # back edge), then import the node instead.
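
The core.jl hunk disables the Operation(node_def) fast path, so extend_graph now routes every unknown node through the import machinery, which the existing comment says can handle back edges, like the ones while-loop gradients introduce. A rough sketch of the resulting control flow, as a paraphrase under assumptions rather than the file's actual code (import_node is a hypothetical stand-in for the import path, which this diff references but does not show):

# Sketch of the extend-graph loop after this hunk. Paraphrase only;
# `import_node` is a hypothetical name for the import path.
for node_bytes in node_defs
    node_def = convert(tensorflow.NodeDef, node_bytes)
    @show node_def  # debug dump of each incoming NodeDef, added in this commit
    if isnull(get_node_by_name(graph, node_def.name))
        # The direct Operation(node_def) path is commented out above, so every
        # unknown node (e.g. one with a back edge) is imported instead.
        import_node(graph, node_def)
    end
end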
src/ops/control_flow.jl: 14 changes (8 additions, 6 deletions)
@@ -248,7 +248,7 @@ Example using shape_invariants:
     shape_invariants=[i0.get_shape(), tensor_shape.TensorShape([None, 2])])
 ```
 """
-@op function while_loop(condition, body, variables; name=nothing, shape_invariants=nothing,
+@op function jl_while_loop(condition, body, variables; name=nothing, shape_invariants=nothing,
     parallel_iterations=10, back_prop=true, swap_memory=false)
     g = Graph()
     def_graph = get_def_graph()
@@ -445,7 +445,8 @@ function WhileLoopOptions(;parallel_iterations=10, back_prop=true, swap_memory=false)
     WhileLoopOptions(parallel_iterations, back_prop, swap_memory)
 end
 
-function c_while(condition, body, variables; name=nothing, options=WhileLoopOptions())
+function while_loop(condition, body, variables; name=nothing, options=WhileLoopOptions())
+    variables = Tensor.(variables)
     name === nothing && (name = "while")
     name = String(name)
     graph = get_def_graph()
@@ -485,7 +486,7 @@ function create_while_context(graph, name, n_inputs; options=WhileLoopOptions())
         loop_exit_names=String[])
     context_matcher = Regex("^$(name)/")
     for op in get_operations(graph)
-        @show op
+        # @show op
         if ismatch(context_matcher, get_def(op).name)
             def = get_def(op)
             n_outputs = length(get_op_def(def.op).output_arg)
@@ -494,18 +495,19 @@ function create_while_context(graph, name, n_inputs; options=WhileLoopOptions())
             end
         end
     end
-    push!(ctx.values_def.values, "$(name)/merge0:1")
-    push!(ctx.values_def.values, "$(name)/switch0:1")
+    # push!(ctx.values_def.values, "$(name)/merge0:1")
+    # push!(ctx.values_def.values, "$(name)/switch0:1")
     set_field!(ctx, :pivot_for_pred_name, "$(name)/merge0:0")
     switch_name = "$(name)/switch0"
     switch_op = get_node_by_name(switch_name) |> get |> get_def
     # We assume the pivot tensor is the second input to the switch statement.
-    # The first input is the result of the merge.
+    # The first input is the result of the merge.
     cond_op = switch_op.input[2]
     set_field!(ctx, :pivot_for_body_name, "$(switch_name):0")
     set_field!(ctx, :pivot_name, "$(cond_op):0")
     for i in 1:n_inputs
         push!(ctx.loop_exit_names, "$(name)/exit$(i-1):0")
     end
+    # dump(ctx)
     return ctx
 end
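
With this commit, while_loop is the C-API-backed implementation (the old Julia-level version survives as jl_while_loop, and c_while takes over the public name), and the new Tensor.(variables) line converts the initial loop variables to tensors up front. A minimal usage sketch, assuming the package's documented calling convention in which the condition and the body receive the loop variables and the body returns their updated values:

using TensorFlow

sess = Session(Graph())
i = constant(1)
# Iterate i -> i + 1 while i < 10; `Tensor.(variables)` means plain Julia
# values would also be accepted as initial loop variables here.
result = while_loop((i,) -> i < 10, (i,) -> [i + 1], [i])
run(sess, result)  # expected: [10]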
