t5-sentiment example collapses on master #256

Closed
janpawlowskiof opened this issue Mar 28, 2023 · 10 comments · Fixed by #262

Comments

@janpawlowskiof

janpawlowskiof commented Mar 28, 2023

I just reran the t5-sentiment example, and on current master it shows negative KL divergence and does not learn in general. This does not seem to happen on the v0.4.1 release.

The t5-sentiment.py script itself does not seem to be the culprit: I tested master with the script reverted to its v0.4.1 version and the behavior is identical.

[image: training curves showing the negative KL divergence]

@lvwerra
Member

lvwerra commented Mar 29, 2023

@younesbelkada can you reproduce? negative KL is very suspicious!

@younesbelkada
Contributor

I was able to reproduce, will investigate!

@younesbelkada
Contributor

The culprit seems to be b5cce0d.
That commit came from the PR that introduced batched generation in the t5 example; as can be seen in the wandb logs, the KL is negative after it. I can confirm the KL was always positive before that commit.
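
For intuition only (this is not a description of the actual fix in #262), here is a minimal, self-contained sketch of one way batched generation can yield a spurious negative KL estimate: if the padding positions introduced by batching are not masked out when comparing per-token log-probs of the policy and the reference model, the batch average can go negative even though the estimate over the real tokens is positive. The helper name and the tensor values below are illustrative assumptions, not trl internals.

import torch

def masked_kl_estimate(logprobs, ref_logprobs, mask):
    # Per-token estimate log p(x) - log p_ref(x), averaged over real tokens only.
    per_token_kl = logprobs - ref_logprobs            # shape: (batch, seq_len)
    return (per_token_kl * mask).sum() / mask.sum()   # padding positions excluded

# Toy numbers: two real tokens followed by two padding positions.
logprobs     = torch.tensor([[-1.0, -1.2, -5.0, -5.0]])
ref_logprobs = torch.tensor([[-1.5, -1.4, -0.1, -0.1]])
mask         = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

print(masked_kl_estimate(logprobs, ref_logprobs, mask))  # tensor(0.3500) -> positive over real tokens
print((logprobs - ref_logprobs).mean())                  # tensor(-2.2750) -> negative if padding is counted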

@younesbelkada
Contributor

#262 should fix the issue!

@GauravVirmani
Contributor

@younesbelkada Understood the bug. I should have checked more diligently.

@younesbelkada
Contributor

younesbelkada commented Mar 30, 2023

No problem at all @GauravVirmani! Don't worry about it, it can happen to anyone! It is also my fault, as I did not flag that the KL was negative when running the experiment.

@chizhikchi

I was following the t5-sentiment example to run RL training on a custom dataset with a custom metric, and it also showed negative KL. So I looked into this issue and the associated pull request, which led me to rerun my experiments the same way @younesbelkada did in the pull request.
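
(For context, a custom reward metric plugs into the example's PPO loop roughly as in the sketch below. It assumes the variables from the example script (ppo_trainer, batch, query_tensors, response_tensors), and my_custom_metric is a hypothetical stand-in for the custom metric mentioned above, not code from the example or the PR.)

# Sketch of a single PPO step with a custom reward (illustrative only):
rewards = [torch.tensor(my_custom_metric(response)) for response in batch["response"]]
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards)  # the KL statistics reach the wandb logs via log_stats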

Unfortunately, none of this fixed the issue. My knowledge of PPO is limited, so I cannot contribute much to the discussion about the underlying cause, but I hope this information is useful. Also, I'd be grateful if you could point out any error I might have made.

Thanks a lot for your effort working on this amazing library!

@younesbelkada
Contributor

Hi @chizhikchi
Thanks for the heads up and for your words!
Sadly your wandb reports seem to be private, so we can't see them.
We will definitely investigate. In the meantime, could you double-check the solution proposed by @janpawlowskiof, i.e. try the 0.4.1 release? I would also give it a try without batched generation.
Let us know how it goes!

@chizhikchi

Hi @younesbelkada, thank you for the suggestions!
I ran the same experiment on the tix-t5-neg-kl branch without batched generation and it seemed to work somewhat better, but the KL still went negative on some batches, as can be seen in the graph. I aborted this experiment because it was giving unpromising results.

Then I ran the same experiment on the 0.4.1 version. The KL wasn't negative this time, so the problem does seem to be related to batched generation.

[W&B chart, 11_4_2023 11_36_48]

My model didn't improve much, though, but I think that's more a problem of the reward definition and the complexity of the task.

[W&B chart, 11_4_2023 11_41_43]

Hope this information helps! Have a nice day :)

@hecongqing

Interesting experiment. In this code, should the responses be generated individually instead of with batched generation?

response_tensors = ppo_trainer.generate(
query_tensors, return_prompt=False, length_sampler=output_length_sampler, **generation_kwargs
)
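
(For reference, a possible non-batched variant is sketched below. It mirrors the per-sample loop style of the older examples and is only a sketch under that assumption, not the exact code of the example or of #262.)

# Sketch: generate for each query separately, so batching never introduces padding.
response_tensors = []
for query in query_tensors:
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    response = ppo_trainer.generate(query, **generation_kwargs)
    # For an encoder-decoder model like T5 the output contains only response tokens,
    # so no prompt stripping is needed (a leading decoder-start token may remain).
    response_tensors.append(response.squeeze())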
