Skip to content

Commit

Permalink
Small improvements / fixes to toxicity example (#266)
Browse files Browse the repository at this point in the history
* fixes during debugging

* Update examples/toxicity/scripts/gpt-j-6b-toxicity.py

Co-authored-by: Younes Belkada <[email protected]>

---------

Co-authored-by: Younes Belkada <[email protected]>
  • Loading branch information
Nathan Lambert and younesbelkada authored Apr 10, 2023
1 parent 131e5cd commit fc468e0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
13 changes: 9 additions & 4 deletions examples/toxicity/scripts/gpt-j-6b-toxicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ class ScriptArguments:
model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=256, metadata={"help": "the batch size"})
mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
model_save_path: Optional[str] = field(
default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final",
metadata={"help": "the path to save the model"},
)


parser = HfArgumentParser(ScriptArguments)
Expand All @@ -81,6 +85,7 @@ class ScriptArguments:
model_name=script_args.model_name,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
ppo_epochs=100,
mini_batch_size=script_args.mini_batch_size,
batch_size=script_args.batch_size,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
Expand Down Expand Up @@ -199,12 +204,12 @@ def collator(data):
output_max_length = 30
output_length_sampler = LengthSampler(output_min_length, output_max_length)

model_save_path = "/mnt/disks/younes-disk/models/gpt-j-6B-detoxified-long-context-26-shl-1e4-final"
model_save_path = script_args.model_save_path

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
query_tensors = batch["input_ids"]

# Get response from gpt2
# Get response from the policy model
response_tensors = []
for query in query_tensors:
gen_len = output_length_sampler()
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ class PPOConfig(object):
seed: Optional[int] = field(default=0, metadata={"help": "Seed value for random generations"})
optimize_cuda_cache: Optional[bool] = field(
default=False,
metadata={"help": "Optimize CUDA cache for slightly more memory-effcient training"},
metadata={"help": "Optimize CUDA cache for slightly more memory-efficient training"},
)
early_stopping: Optional[bool] = field(
default=False, metadata={"help": "Whether to stop the PPO opimization loop early is the KL too high"}
default=False, metadata={"help": "Whether to stop the PPO optimization loop early is the KL too high"}
)
target_kl: Optional[float] = field(
default=0.1, metadata={"help": "Stop early if we exceed this value by over 50%"}
Expand Down

0 comments on commit fc468e0

Please sign in to comment.