diff --git a/README.md b/README.md
index 6425f431ab..1ea8b9d86d 100644
--- a/README.md
+++ b/README.md
@@ -138,11 +138,10 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
 model_ref = create_reference_model(model)
 
 tokenizer = AutoTokenizer.from_pretrained('gpt2')
+tokenizer.pad_token = tokenizer.eos_token
 
 # initialize trainer
-ppo_config = PPOConfig(
-    batch_size=1,
-)
+ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
 
 # encode a query
 query_txt = "This morning I went to the "