Skip to content

Commit

Permalink
undo changes to host_callback (not needed anymore)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 22, 2021
1 parent fe4d12c commit 9cc69ca
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
2 changes: 0 additions & 2 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,6 @@ def _outside_call_jvp_rule(primals, tangents, **params):
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents))
tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated))

arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
Expand All @@ -947,7 +946,6 @@ def _outside_call_jvp_rule(primals, tangents, **params):
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped)
return tuple(out_primals_tapped), tuple(out_tangents_tapped)


Expand Down
6 changes: 3 additions & 3 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ def func(x, yint):
2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00
0 )""", testing_stream.output)
False )""", testing_stream.output)
testing_stream.reset()

def test_tap_vmap(self):
Expand Down Expand Up @@ -1590,8 +1590,8 @@ def padded_sum(x):
( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4]
[0. 0.2 0.4 0.6 0.8] )
( ( 0 )
( 0 ) ) ) )""", testing_stream.output)
( ( False )
( False ) ) ) )""", testing_stream.output)
testing_stream.reset()

# Now with JIT
Expand Down

0 comments on commit 9cc69ca

Please sign in to comment.