new RLHF benchmark #273

Merged · 2 commits · Sep 10, 2024
31 changes: 31 additions & 0 deletions benchmarks/rlhf/Makefile
@@ -0,0 +1,31 @@
# Use global base if possible
ifndef MILABENCH_BASE
	MILABENCH_BASE="base"
endif

export MILABENCH_BASE

BENCH_NAME=rlhf
MILABENCH_CONFIG=dev.yaml
MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE)

all: install prepare single gpus nodes

install:
	milabench install $(MILABENCH_ARGS) --force

prepare:
	milabench prepare $(MILABENCH_ARGS)

tests: install prepare
	milabench run $(MILABENCH_ARGS)

single:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single

gpus:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus

nodes:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes
4 changes: 4 additions & 0 deletions benchmarks/rlhf/README.md
@@ -0,0 +1,4 @@

# RLHF

This benchmark runs a PPOv2 (TRL) RLHF training loop: a causal-LM policy is optimized against a sequence-classification reward model on the trl-internal-testing/descriptiveness-sentiment-trl-style dataset, with the EleutherAI/pythia-1b-deduped tokenizer. dev.yaml defines single-GPU (`rlhf-single`) and multi-GPU (`rlhf-gpus`) variants.
41 changes: 41 additions & 0 deletions benchmarks/rlhf/benchfile.py
@@ -0,0 +1,41 @@
from milabench.pack import Package


class Rlhf(Package):
    # Requirements file installed by install(). It can be empty or absent.
    base_requirements = "requirements.in"

    # The preparation script called by prepare(). It must be executable,
    # but it can be any type of script. It can be empty or absent.
    prepare_script = "prepare.py"

    # The main script called by run(). It must be a Python file. It has to
    # be present.
    main_script = "main.py"

    # You can remove the functions below if you don't need to modify them.

    def make_env(self):
        # Return a dict of environment variables for prepare_script and
        # main_script.
        return super().make_env()

    async def install(self):
        await super().install()  # super() call installs the requirements

    async def prepare(self):
        await super().prepare()  # super() call executes prepare_script

    def build_run_plan(self):
        from milabench.commands import PackCommand, VoirCommand, AccelerateAllNodes

        main = self.dirs.code / self.main_script
        plan = PackCommand(self, *self.argv, lazy=True)

        # Voir instrumentation is disabled; main.py reports its own metrics
        # through benchmate's BenchObserver.
        if False:
            plan = VoirCommand(plan, cwd=main.parent)

        return AccelerateAllNodes(plan).use_stdout()


__pack__ = Rlhf
29 changes: 29 additions & 0 deletions benchmarks/rlhf/dev.yaml
@@ -0,0 +1,29 @@

rlhf_:
  inherits: _defaults
  definition: .
  install-variant: unpinned
  install_group: torch
  plan:
    method: per_gpu

  argv:
    --output_dir: "{milabench_extra}/output"
    --model_name_or_path: EleutherAI/pythia-1b-deduped
    --per_device_train_batch_size: 64
    --logging_strategy: "no"
    --log_level: "critical"
    --bf16: true


rlhf-single:
  inherits: rlhf_
  plan:
    method: per_gpu


rlhf-gpus:
  inherits: rlhf_
  plan:
    method: njobs
    n: 1
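
For reference (this note and sketch are not part of the diff): the `argv` mapping above is assumed to be flattened by milabench into plain `--flag value` pairs and handed to main.py, where `HfArgumentParser` splits them across `PPOv2Config` and `ModelConfig`. A minimal sketch of that parsing, with a placeholder output directory standing in for `{milabench_extra}/output` and `--bf16` omitted so it also runs on CPU-only machines:

```python
# Sketch only: how the dev.yaml argv is assumed to reach main.py's parser.
from transformers import HfArgumentParser
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config

argv = [
    "--output_dir", "/tmp/rlhf-output",  # placeholder for "{milabench_extra}/output"
    "--model_name_or_path", "EleutherAI/pythia-1b-deduped",
    "--per_device_train_batch_size", "64",
    "--logging_strategy", "no",
    "--log_level", "critical",
]

parser = HfArgumentParser((PPOv2Config, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses(args=argv)

print(config.per_device_train_batch_size)  # 64
print(model_config.model_name_or_path)     # EleutherAI/pythia-1b-deduped
```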
136 changes: 136 additions & 0 deletions benchmarks/rlhf/main.py
@@ -0,0 +1,136 @@
#!/usr/bin/env python

import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
)

from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


class PPOv2TrainerInstrumented(PPOv2Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def batch_size_fn(batch):
            x, y = batch['input_ids'].shape
            return x * y

        from benchmate.observer import BenchObserver
        observer = BenchObserver(
            batch_size_fn=batch_size_fn,
            earlystop=70,
            raise_stop_program=True,
            stdout=True,
        )

        self.dataloader = observer.iterate(self.dataloader)

    def generate_completions(self, sampling: bool = False):
        # Disabled: completion generation is not part of the measured workload.
        pass

    def _save_checkpoint(self, *args, **kwargs):
        # Disabled: no checkpoints are written during the benchmark.
        pass

    def save_model(self, *args, **kwargs):
        # Disabled: the benchmark does not persist the trained model.
        pass


def main():
    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()
    # remove output_dir if it exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    ################
    # Model & Tokenizer
    ################
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
    value_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
    )
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
    policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
    ################
    # Dataset
    ################
    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
    eval_samples = 20
    train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
    eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
    dataset_text_field = "prompt"

    def prepare_dataset(dataset, tokenizer):
        """pre-tokenize the dataset before training; only collate during training"""

        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field],
                padding=False,
            )
            return {"input_ids": outputs["input_ids"]}

        return dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=config.dataset_num_proc,
        )

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        train_dataset = prepare_dataset(train_dataset, tokenizer)
        eval_dataset = prepare_dataset(eval_dataset, tokenizer)

    ################
    # Training
    ################
    trainer = PPOv2TrainerInstrumented(
        config=config,
        tokenizer=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_model(config.output_dir)
    if config.push_to_hub:
        trainer.push_to_hub()
    trainer.generate_completions()


if __name__ == "__main__":
    from voir.phase import StopProgram
    from benchmate.monitor import bench_monitor

    try:
        with bench_monitor():
            main()
    except StopProgram:
        pass
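
The instrumentation above relies on benchmate's `BenchObserver`: it wraps the PPO dataloader, reports `rows * sequence_length` of each batch's `input_ids` as the batch size, and, with `raise_stop_program=True`, ends the run after `earlystop=70` measured batches by raising `StopProgram`, which the `__main__` block swallows. A rough conceptual stand-in (not benchmate's actual implementation, shown only to illustrate the assumed behaviour):

```python
# Conceptual sketch only -- not the real benchmate.observer.BenchObserver.
from voir.phase import StopProgram


class ToyObserver:
    """Wrap a dataloader, report a size per batch, and stop early."""

    def __init__(self, batch_size_fn, earlystop=70):
        self.batch_size_fn = batch_size_fn
        self.earlystop = earlystop

    def iterate(self, loader):
        for step, batch in enumerate(loader):
            # For this benchmark, batch_size_fn returns rows * sequence length
            # of batch["input_ids"], i.e. tokens processed in the step.
            print(f"step {step}: {self.batch_size_fn(batch)} tokens")
            yield batch
            if step + 1 >= self.earlystop:
                raise StopProgram("earlystop reached")
```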
54 changes: 54 additions & 0 deletions benchmarks/rlhf/prepare.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python

import shutil

from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
)
from datasets import load_dataset
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


if __name__ == "__main__":
    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()

    # remove output_dir if it exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )

    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE

    value_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
        num_labels=1
    )
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
        num_labels=1
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path,
        trust_remote_code=model_config.trust_remote_code
    )
    policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path,
        trust_remote_code=model_config.trust_remote_code
    )

    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
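
prepare.py mirrors main.py's tokenizer, model, and dataset loading so that every artifact is downloaded into the local Hugging Face cache before the timed run starts. A hypothetical post-prepare sanity check (not part of the PR) for the pieces named in dev.yaml:

```python
# Hypothetical check that `milabench prepare` warmed the local cache;
# not part of the benchmark itself.
from datasets import load_dataset
from transformers import AutoTokenizer

# Should resolve from the local cache without hitting the network.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-1b-deduped", local_files_only=True
)
dataset = load_dataset(
    "trl-internal-testing/descriptiveness-sentiment-trl-style",
    split="descriptiveness",
    download_mode="reuse_cache_if_exists",
)
print(len(dataset), "examples cached;", type(tokenizer).__name__, "loaded offline")
```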