All over Mava we have statements like:
tree.map(lambda x: x[jnp.newaxis], pytree)
I think these would be a lot clearer if we had an add_batch_dim helper
in utils.jax_utils.
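A minimal sketch of what the proposed helper might look like. It assumes the helper does nothing more than wrap the existing pattern. The name add_batch_dim comes from this proposal, and the module utils.jax_utils is only where it is suggested to live. It is not an existing Mava function, and jax.tree_util.tree_map is used here in place of tree.map so the snippet runs on older JAX versions as well.

```python
from typing import Any

import jax
import jax.numpy as jnp


def add_batch_dim(pytree: Any) -> Any:
    """Prepend a leading axis of size 1 to every array leaf of `pytree`.

    Hypothetical helper: equivalent to
    tree.map(lambda x: x[jnp.newaxis], pytree).
    Leaves are expected to be arrays (NumPy or JAX).
    """
    return jax.tree_util.tree_map(lambda x: x[jnp.newaxis], pytree)


# Usage: a dict-structured observation gains a batch axis on every leaf.
obs = {"agents": jnp.zeros((3, 4)), "step": jnp.array(0)}
batched = add_batch_dim(obs)
# batched["agents"].shape == (1, 3, 4); batched["step"].shape == (1,)
```

Keeping the exact x[jnp.newaxis] indexing preserves the current behaviour at each call site, so the helper could replace the inline lambdas one by one without changing results.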