minor BCO fixes #1923

Merged: 7 commits, Aug 14, 2024
8 changes: 0 additions & 8 deletions tests/test_trainers_args.py
@@ -197,10 +197,6 @@ def test_kto(self):
             model_init_kwargs={"trust_remote_code": True},
             ref_model_init_kwargs={"trust_remote_code": True},
             dataset_num_proc=4,
-            loss_type="bco",
-            prompt_sample_size=512,
-            min_density_ratio=0.2,
-            max_density_ratio=20.0,
         )
         trainer = KTOTrainer(model="gpt2", ref_model="gpt2", args=args, train_dataset=dataset, tokenizer=tokenizer)
         self.assertEqual(trainer.args.max_length, 256)
@@ -218,10 +214,6 @@ def test_kto(self):
         self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
         self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
         self.assertEqual(trainer.args.dataset_num_proc, 4)
-        self.assertEqual(trainer.args.loss_type, "bco")
-        self.assertEqual(trainer.args.prompt_sample_size, 512)
-        self.assertEqual(trainer.args.min_density_ratio, 0.2)
-        self.assertEqual(trainer.args.max_density_ratio, 20.0)

     def test_online_dpo(self):
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
16 changes: 14 additions & 2 deletions trl/trainer/bco_trainer.py
@@ -72,7 +72,8 @@
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer

-RUNNING_NAME = "running.pt"
+RUNNING_NAME = "running.json"
+CLF_NAME = "clf.pt"


 def _tokenize(
@@ -822,6 +823,9 @@ def _save_optimizer_and_scheduler(self, output_dir):

         self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))

+        if self.match_underlying_distribution:
+            torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         super()._load_optimizer_and_scheduler(checkpoint)
@@ -831,7 +835,15 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         running_file = os.path.join(checkpoint, RUNNING_NAME)
         if not os.path.isfile(running_file):
             warnings.warn(f"Missing file {running_file}. Will use a new running delta value for BCO loss calculation")
-        self.running = RunningMoments.load_from_json(self.accelerator, running_file)
+        else:
+            self.running = RunningMoments.load_from_json(self.accelerator, running_file)
+
+        if self.match_underlying_distribution:
+            clf_file = os.path.join(checkpoint, CLF_NAME)
+            if not os.path.isfile(clf_file):
+                warnings.warn(f"Missing file {clf_file}. Will use a new UDM classifier for BCO loss calculation")
+            else:
+                self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))

     @contextmanager
     def null_ref_context(self):
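The save/load pair above checkpoints two pieces of BCO state: the RunningMoments go to JSON (hence the rename from running.pt to running.json, matching RunningMoments.save_to_json and load_from_json), and, when match_underlying_distribution is enabled, the UDM classifier's parameters go through torch.save. Below is a minimal standalone sketch of that classifier round-trip; LogisticRegression and the checkpoint path are illustrative stand-ins (the trainer's actual classifier is not shown in this diff), while get_params()/set_params() are the generic scikit-learn estimator API used above.

# Standalone sketch of the classifier checkpoint round-trip, mirroring
# _save_optimizer_and_scheduler / _load_optimizer_and_scheduler above.
import os

import torch
from sklearn.linear_model import LogisticRegression  # illustrative estimator

CLF_NAME = "clf.pt"
checkpoint = "checkpoint-500"  # hypothetical checkpoint directory
os.makedirs(checkpoint, exist_ok=True)

clf = LogisticRegression(class_weight="balanced")

# Save: get_params() returns the estimator's constructor arguments as a
# plain dict, which torch.save pickles like any other Python object.
torch.save(clf.get_params(), os.path.join(checkpoint, CLF_NAME))

# Load: weights_only=True restricts unpickling to plain data types, and
# set_params(**...) re-applies the stored arguments to a fresh estimator.
clf_file = os.path.join(checkpoint, CLF_NAME)
restored = LogisticRegression()
restored.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))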
28 changes: 1 addition & 27 deletions trl/trainer/kto_config.py
@@ -12,12 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Literal, Optional
+from typing import Dict, Optional

 from transformers import TrainingArguments
-
-from ..import_utils import is_sklearn_available


 @dataclass
 class KTOConfig(TrainingArguments):
@@ -60,14 +58,6 @@ class KTOConfig(TrainingArguments):
             Dict of Optional kwargs to pass when instantiating the ref model from a string.
         dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the datasets.
-        loss_type: (`Literal["kto", "bco"]`, *optional*):
-            The type of loss to use. Either `"kto"` the default KTO loss, `"bco"` loss from [BCO](https://huggingface.co/papers/2404.04656) paper.
-        prompt_sample_size: (`int`, defaults to 1024):
-            Number of prompts that are fed to density ratio classifier.
-        min_density_ratio: (`float`, defaults to 0.5):
-            The minimum value of the density ratio. The estimated density ratio is clamped to this value.
-        max_density_ratio: (`float`, defaults to 10.0):
-            The maximum value of the density ratio. The estimated density ratio is clamped to this value.
     """

     max_length: Optional[int] = None
@@ -92,19 +82,3 @@ class KTOConfig(TrainingArguments):
     model_init_kwargs: Optional[Dict] = None
     ref_model_init_kwargs: Optional[Dict] = None
     dataset_num_proc: Optional[int] = None
-
-    loss_type: Literal["kto", "bco"] = "kto"
-
-    # BCO config
-    prompt_sample_size: int = 1024
-    min_density_ratio: float = 0.5
-    max_density_ratio: float = 10.0
-
-    def __post_init__(self):
-        super().__post_init__()
-
-        if self.loss_type == "bco" and not is_sklearn_available():
-            raise ImportError(
-                "You need to install scikit-learn to use loss_type='bco' "
-                "You can install it with `pip install scikit-learn`."
-            )
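With these fields and the __post_init__ check gone, KTOConfig describes only the KTO loss; the BCO path lives in its own trainer (trl/trainer/bco_trainer.py above) with its own configuration. A hedged sketch of what a KTOConfig call looks like after this change, using only fields that appear in this diff plus output_dir from TrainingArguments:

from trl import KTOConfig

args = KTOConfig(
    output_dir="./kto-out",  # inherited from TrainingArguments
    max_length=256,
    dataset_num_proc=4,
)
# The removed options are no longer accepted; since KTOConfig is a
# dataclass, e.g. KTOConfig(output_dir="./kto-out", loss_type="bco")
# now fails with a TypeError for an unexpected keyword argument.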