-
Notifications
You must be signed in to change notification settings - Fork 207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better tests for utils #465
Conversation
Hi @acforvs, many thanks for your PR and apologies for the belated review! |
Hi @hbq1 , thank you for the review! Sure thing, I'll add the tests for the remaining functions, but I have some questions about them:
As a result, examples like a = jnp.ones(shape=(3, 2, 2))
diags = jnp.array([[[3, 3, 3, 3, 3, 3]]])
print(a.shape, diags.shape) # prints (3, 2, 2) (1, 1, 6)
optax._src.utils.set_diags(a, diags) or a = jnp.ones(shape=(3, 2, 2))
diags = jnp.array([[3, 3, 3], [3, 3, 3]])
print(a.shape, diags.shape) # prints (3, 2, 2) (2, 3)
optax._src.utils.set_diags(a, diags) work just fine Is this by design or should we raise an error in case of a shape mismatch?
tree = dict(a=2.5, b=dict(c=-2.5))
tree = jax.tree_util.tree_map(lambda x : x, tree)
optax._src.utils.cast_tree(tree, int) Would it be okay to only test this function for dicts of jnp.arrays of differerent dtypes? Something like tree = dict(a=jnp.array(2.5), b=dict(c=jnp.array(-2.5)))
tree = jax.tree_util.tree_map(lambda x : x, tree)
optax._src.utils.cast_tree(tree, int) |
Great, thank you!
|
I've fixed the errors and added tests for the rest of the utils.py functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes and for the nice PR 👍
Hi, hopefully this one closes #404