lax.cond can bind to unexpected function signature #16413

Closed
packquickly opened this issue Jun 14, 2023 · 3 comments
Labels
bug Something isn't working


@packquickly

Description

If you call lax.cond as lax.cond(predicate, function, function, callable_pytree, callable_pytree), it binds to the old function signature <Signature (pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable)> and reassigns the arguments unexpectedly: the first branch becomes true_operand, the second branch becomes true_fun, and the two callable pytrees become false_operand and false_fun.

Here is a reproducing example:

import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu

def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))

def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))

add_one = jtu.Partial(jnp.add, jnp.array(1.)) # A callable pytree
add_two = jtu.Partial(jnp.add, jnp.array(2.))
four = lax.cond(True, true_branch, false_branch, add_one, add_two) # TypeError on jax 0.4.11
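To see where the TypeError comes from, the sketch below binds the reproducer's five positional arguments to the old signature quoted above using Python's `inspect` module. It is a minimal sketch: `old_style_cond` is a hypothetical stub that only carries the old parameter list, and the branches and adders are plain-Python stand-ins for the JAX objects, so nothing here calls JAX.

```python
import inspect

# Hypothetical stub carrying the old lax.cond parameter list quoted above.
# Only its signature is inspected; the body never runs.
def old_style_cond(pred, true_operand, true_fun, false_operand, false_fun):
    raise NotImplementedError

# Plain-Python stand-ins for the reproducer's branches and callable pytrees.
def true_branch(add_one, add_two):
    return add_one(add_two(1.0))

def false_branch(add_one, add_two):
    return add_two(add_one(1.0))

def add_one(x):
    return x + 1.0

def add_two(x):
    return x + 2.0

# Bind the reproducer's five positional arguments to the old parameter names.
bound = inspect.signature(old_style_cond).bind(
    True, true_branch, false_branch, add_one, add_two
)
for name, value in bound.arguments.items():
    print(f"{name:>13} <- {getattr(value, '__name__', value)}")

# Old-style semantics evaluate true_fun(true_operand), i.e.
# false_branch(true_branch), which is missing its second positional argument.
args = bound.arguments
try:
    args["true_fun"](args["true_operand"])
    old_style_failed = False
except TypeError as err:
    old_style_failed = True
    print("TypeError:", err)
```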

What jax/jaxlib version are you using?

0.4.11

Which accelerator(s) are you using?

CPU

Additional system info

Linux

NVIDIA GPU info

No response

@packquickly packquickly added the bug Something isn't working label Jun 14, 2023
@jakevdp
Collaborator

jakevdp commented Jun 15, 2023

Thanks. I think in this case (five arguments, the last four callable) there's no way to disambiguate between the old signature and the new one. This shim is old enough that it's probably time we deprecated the old cond signature. @mattjj, what do you think?
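A minimal sketch of the ambiguity, assuming only the two parameter lists: the old one quoted in the issue and the documented new-style form `cond(pred, true_fun, false_fun, *operands)`. The stubs `old_style_cond` and `new_style_cond` and the helper `binds` are hypothetical names; nothing here calls JAX.

```python
import inspect

# Hypothetical stubs mirroring the two calling conventions; never called.
def old_style_cond(pred, true_operand, true_fun, false_operand, false_fun):
    raise NotImplementedError

def new_style_cond(pred, true_fun, false_fun, *operands):
    raise NotImplementedError

def binds(fn, *args):
    """Return True if args can be bound to fn's signature."""
    try:
        inspect.signature(fn).bind(*args)
        return True
    except TypeError:
        return False

def some_callable(*args):
    return None

# Five arguments, the last four callable, as in the reproducer.
call = (True,) + (some_callable,) * 4

print("5-arg call binds old:", binds(old_style_cond, *call))
print("5-arg call binds new:", binds(new_style_cond, *call))
# Every slot a callable-based check could inspect holds a callable either way.
print("all trailing args callable:", all(callable(a) for a in call[1:]))
# With four arguments only the new form can bind.
print("4-arg call binds old:", binds(old_style_cond, *call[:4]))
```

With any argument count other than five, only the new form binds, so the ambiguity is confined to five-argument calls whose relevant slots are callable.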

@mattjj
Collaborator

mattjj commented Jun 19, 2023

That sounds reasonable to me, @jakevdp, but it may require some Google monorepo fixes.

Another possibility: we could raise an error if an argument is both callable and a pytree node, and require the user to be more explicit in that case (not sure how).
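One way to detect the case described here, sketched under the assumption that "pytree node" means anything `jax.tree_util` flattens instead of treating as a leaf. The helpers `is_callable_pytree_node` and `check_unambiguous` are hypothetical, not part of JAX, and the sketch does not settle how the user would then make the call explicit.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

def is_callable_pytree_node(x):
    """True if x is callable and jax.tree_util flattens it instead of keeping it as a leaf."""
    if not callable(x):
        return False
    leaves = jtu.tree_leaves(x)
    # A leaf flattens to exactly [x]; a registered node flattens to its children's leaves.
    return not (len(leaves) == 1 and leaves[0] is x)

def check_unambiguous(*operands):
    """Hypothetical guard: reject operands that are both callable and pytree nodes."""
    for i, op in enumerate(operands):
        if is_callable_pytree_node(op):
            raise TypeError(
                f"operand {i} is both callable and a pytree node, "
                "so the cond call is ambiguous"
            )

def square(x):
    return x * x

add_one = jtu.Partial(jnp.add, jnp.array(1.0))  # callable pytree, as in the issue

print(is_callable_pytree_node(add_one))  # True
print(is_callable_pytree_node(square))   # False: plain functions are leaves
```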

@froystig could you take a look?

@froystig froystig assigned froystig and unassigned mattjj Jun 19, 2023
@froystig
Member

Commit 14f3265 fixes this by resolving to the new-style signature whenever the call is ambiguous in this way.
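The "prefer new-style when ambiguous" rule can be sketched as a small dispatch function. This is a hypothetical illustration of the idea, not the code of commit 14f3265; `old_style_cond` and `resolve_cond_style` are made-up names, and the checks only look at callability.

```python
import inspect

# Hypothetical stub carrying the old lax.cond parameter list; never called.
def old_style_cond(pred, true_operand, true_fun, false_operand, false_fun):
    raise NotImplementedError

def resolve_cond_style(*args):
    """Hypothetical rule: decide whether a cond call is "old" or "new" style."""
    try:
        bound = inspect.signature(old_style_cond).bind(*args)
    except TypeError:
        return "new"  # wrong arity for the old signature, so it must be new-style
    _, true_operand, true_fun, false_operand, false_fun = bound.args
    if callable(true_operand) and callable(true_fun):
        # Ambiguous: positions 1 and 2 both look like branch functions.
        return "new"
    if callable(true_fun) and callable(false_fun):
        return "old"
    return "new"

def identity(x):
    return x

print(resolve_cond_style(True, identity, identity, identity, identity))  # new
print(resolve_cond_style(True, 1.0, identity, 2.0, identity))            # old
print(resolve_cond_style(True, identity, identity, 1.0))                 # new
```

A consequence of the rule as sketched: a genuine old-style call whose true_operand happens to be callable is now read as new-style.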

At some point we should deprecate the old-style behavior of lax.cond altogether and instead offer an explicit function, under a different name, for conditionals with per-branch operands. Maybe lax.ifelse?
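If such a function were added, it could be a thin wrapper over new-style `lax.cond`. The sketch below assumes the name `ifelse` and an argument order mirroring the old signature; neither exists in JAX. Packing both operands into one tuple gives a four-argument `lax.cond` call, which cannot bind to the five-parameter old signature.

```python
import jax.numpy as jnp
from jax import lax

def ifelse(pred, true_operand, true_fun, false_operand, false_fun):
    """Hypothetical per-branch-operand conditional (not an existing JAX API).

    Each branch receives only its own operand. Both operands travel in a
    single tuple, so lax.cond is called with four positional arguments.
    """
    return lax.cond(
        pred,
        lambda ops: true_fun(ops[0]),
        lambda ops: false_fun(ops[1]),
        (true_operand, false_operand),
    )

x = jnp.array(0.0)
print(ifelse(True, x, jnp.sin, x, jnp.cos))   # sin(0.0)
print(ifelse(False, x, jnp.sin, x, jnp.cos))  # cos(0.0)
```

As with `lax.cond` itself, both branches must return outputs with the same structure, shapes, and dtypes.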


4 participants