[JAX backend + possibly others] generate() decode error #1149

Closed
abheesht17 opened this issue Jul 13, 2023 · 1 comment

Comments

@abheesht17
Collaborator

abheesht17 commented Jul 13, 2023

I do not see the same error in legacy Keras. There is also a large difference in loss: JAX is at 5.5, whereas legacy Keras is at 3.something. So my guess is that the model is not training as well as the legacy Keras one and is producing gibberish. I do not see this error when I train on more data. In any case, we really need to sort out this decode issue.

Repro: https://colab.research.google.com/drive/1TyEUS9fCJBq9duZMgOXJlchYsvH2r5Vx?usp=sharing

@abheesht17 abheesht17 changed the title [JAX backend + possible others] generate() decode error [JAX backend + possibly others] generate() decode error Jul 13, 2023
@mattdangerw
Member

I think we fixed this with a combo of #1150 and landing a loss scale optimizer in Keras Core/Keras 3. Let's reopen if I am mistaken!
