Thanks. I think in this case (five arguments, the last four callable) there's no way to disambiguate between the old signature and the new one. This shim is old enough that it's probably time we deprecated the old cond signature. @mattjj, what do you think?
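To see why the call is ambiguous, note that a five-argument call binds cleanly to both parameter lists. The sketch below uses inspect.signature on two stub functions: old_style copies the old signature quoted in the issue, and new_style assumes the new-style form cond(pred, true_fun, false_fun, *operands). The stubs and the binds helper are illustrative only, not JAX's implementation.

```python
import inspect
from typing import Callable


def old_style(pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable):
    """Stub with the old per-branch-operand signature quoted in the issue."""


def new_style(pred, true_fun: Callable, false_fun: Callable, *operands):
    """Stub with an assumed new-style signature: shared operands."""


def binds(fn, *args) -> bool:
    """Return True if `args` can be bound to `fn`'s parameters."""
    try:
        inspect.signature(fn).bind(*args)
    except TypeError:
        return False
    return True


# Two branch functions plus two operands that are themselves callable
# (plain lambdas standing in for callable pytrees).
f = lambda *ops: "f"
g = lambda *ops: "g"
operand_a = lambda x: x
operand_b = lambda x: x
args = (True, f, g, operand_a, operand_b)

# The same call is a valid binding for both signatures...
print(binds(old_style, *args), binds(new_style, *args))  # True True
# ...and every argument after pred is callable, so callable() checks on the
# "function" slots of either signature cannot tell the two readings apart.
print(all(callable(a) for a in args[1:]))  # True
```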
That sounds reasonable to me, @jakevdp, but it may require some Google monorepo fixes.
Another possibility: we could raise an error if an argument is both callable and a pytree node, and require the user to be more explicit in that case (though I'm not sure how).
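One way the "raise an error" idea could look, as a minimal sketch: a hypothetical wrapper, checked_cond, that rejects a two-operand call when either operand is a callable pytree node, and asks for the operands to be packed into a single tuple instead (which gives lax.cond only four positional arguments). It assumes jax.tree_util.Partial as the example of a callable pytree; the wrapper and helper names are illustrative, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_structure, treedef_is_leaf


def is_callable_pytree_node(x) -> bool:
    """True if x is callable and is a pytree node rather than a single leaf."""
    return callable(x) and not treedef_is_leaf(tree_structure(x))


def checked_cond(pred, true_fun, false_fun, *operands):
    """Hypothetical wrapper around lax.cond that refuses the ambiguous form.

    Exactly two operands means five positional arguments, the shape that can
    also match the old signature; reject it when an operand is a callable
    pytree and ask for an explicit tuple instead.
    """
    if len(operands) == 2 and any(is_callable_pytree_node(op) for op in operands):
        raise TypeError(
            "ambiguous cond call: an operand is both callable and a pytree; "
            "pass the operands as a single tuple instead")
    return jax.lax.cond(pred, true_fun, false_fun, *operands)


add_one = Partial(jnp.add, 1.0)  # callable pytree: jnp.add is static, 1.0 is a leaf
times_three = Partial(jnp.multiply, 3.0)

try:
    checked_cond(True, lambda a, b: a(1.0), lambda a, b: b(1.0), add_one, times_three)
except TypeError as err:
    print("rejected:", err)

# The explicit form: one tuple operand, so only four positional arguments.
out = checked_cond(True, lambda ops: ops[0](1.0), lambda ops: ops[1](1.0),
                   (add_one, times_three))
print(out)
```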
14f3265 fixes this by resolving to the new-style signature whenever this ambiguity arises.
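A simplified, pure-Python model of "resolve to new-style when ambiguous": if the values in the would-be true_operand and true_fun positions are both callable, the call is read as new-style. This models the idea described in the comment; the helper names are hypothetical and this is not the code from 14f3265.

```python
import inspect
from typing import Any, Callable


def _old_cond(pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable):
    # Old style: each branch gets its own operand.
    return true_fun(true_operand) if pred else false_fun(false_operand)


def _new_cond(pred, true_fun: Callable, false_fun: Callable, *operands):
    # New style: both branches get the same operands.
    return true_fun(*operands) if pred else false_fun(*operands)


def resolving_cond(*args: Any) -> Any:
    """Hypothetical dispatcher that prefers new-style when a call is ambiguous."""
    try:
        bound = inspect.signature(_old_cond).bind(*args)
    except TypeError:
        return _new_cond(*args)  # wrong arity for the old signature
    _, true_operand, true_fun, _, false_fun = bound.args
    if callable(true_operand) and callable(true_fun):
        # Arguments 1 and 2 are both functions: read it as a new-style call.
        return _new_cond(*args)
    if callable(true_fun) and callable(false_fun):
        return _old_cond(*bound.args)
    return _new_cond(*args)


def true_branch(*ops):
    return ("true", len(ops))


def false_branch(*ops):
    return ("false", len(ops))


op_a = lambda x: x + 1  # callable operands, standing in for callable pytrees
op_b = lambda x: x * 2

print(resolving_cond(True, true_branch, false_branch, op_a, op_b))  # ('true', 2)
# A genuine old-style call with non-callable operands still works:
print(resolving_cond(True, 10, lambda x: x + 1, 20, lambda x: x - 1))  # 11
```

In this model the price is that an old-style call whose true_operand is itself callable now resolves as new-style, which is one argument for deprecating the old signature rather than refining the heuristic further.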
At some point, we should deprecate the old-style behavior of lax.cond altogether. We can instead offer an explicit function for per-branch-operand conditionals, under a different name. Maybe lax.ifelse?
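lax.ifelse is only a name floated here, not an existing JAX function. A minimal sketch of what such a per-branch-operand helper could look like, assuming it mirrors the old argument order and is built on new-style lax.cond by packing both operands into one tuple, so the underlying call never has five positional arguments:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def ifelse(pred, true_operand, true_fun, false_operand, false_fun):
    """Hypothetical per-branch-operand conditional; not part of jax.lax.

    Mirrors the old lax.cond argument order, but forwards to new-style
    lax.cond with a single tuple operand, so the call has four positional
    arguments and cannot be mistaken for the five-argument old signature.
    Both branches must still return outputs of the same structure and dtype.
    """
    return jax.lax.cond(
        pred,
        lambda ops: true_fun(ops[0]),   # true branch sees only its operand
        lambda ops: false_fun(ops[1]),  # false branch sees only its operand
        (true_operand, false_operand),
    )


# Array operands.
x = ifelse(True, jnp.array(1.0), lambda t: t + 1.0, jnp.array(5.0), lambda f: f * 2.0)

# Callable-pytree operands are just data here, so they cause no ambiguity.
y = ifelse(False, Partial(jnp.add, 1.0), lambda p: p(1.0),
           Partial(jnp.multiply, 3.0), lambda p: p(1.0))
```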
Description
If you call lax.cond with the format lax.cond(predicate, function, function, callable_pytree, callable_pytree), then lax.cond will bind to the old function signature <Signature (pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable)> and swap the arguments in an unexpected way. Here is a reproducing example:
What jax/jaxlib version are you using?
0.4.11
Which accelerator(s) are you using?
CPU
Additional system info
Linux
NVIDIA GPU info
No response