minor BCO fixes (#1923)
* checkpointing BCO UDM classifier

* kto_config remove unused parameters

* BCO fix loading

* kto_config remove unused parameters

* kto_config remove unused parameters

---------

Co-authored-by: Clara Luise Pohland <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
3 people authored Aug 14, 2024
1 parent f05f63c commit c1b272f
Showing 3 changed files with 15 additions and 37 deletions.
8 changes: 0 additions & 8 deletions tests/test_trainers_args.py
@@ -197,10 +197,6 @@ def test_kto(self):
             model_init_kwargs={"trust_remote_code": True},
             ref_model_init_kwargs={"trust_remote_code": True},
             dataset_num_proc=4,
-            loss_type="bco",
-            prompt_sample_size=512,
-            min_density_ratio=0.2,
-            max_density_ratio=20.0,
         )
         trainer = KTOTrainer(model="gpt2", ref_model="gpt2", args=args, train_dataset=dataset, tokenizer=tokenizer)
         self.assertEqual(trainer.args.max_length, 256)
@@ -218,10 +214,6 @@ def test_kto(self):
         self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
         self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
         self.assertEqual(trainer.args.dataset_num_proc, 4)
-        self.assertEqual(trainer.args.loss_type, "bco")
-        self.assertEqual(trainer.args.prompt_sample_size, 512)
-        self.assertEqual(trainer.args.min_density_ratio, 0.2)
-        self.assertEqual(trainer.args.max_density_ratio, 20.0)

     def test_online_dpo(self):
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
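With `loss_type="bco"` removed from `KTOConfig`, the KTO test no longer exercises BCO arguments; BCO training is configured through its own classes. A minimal sketch of the equivalent setup, assuming the standalone `BCOConfig`/`BCOTrainer` pair that TRL ships for BCO (model names are placeholders, and `dataset` stands in for an unpaired preference dataset, as in the test above):

from transformers import AutoTokenizer
from trl import BCOConfig, BCOTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# The BCO-specific knobs dropped from KTOConfig live on BCOConfig instead.
args = BCOConfig(
    output_dir="bco-model",
    prompt_sample_size=512,   # prompts fed to the density-ratio (UDM) classifier
    min_density_ratio=0.2,    # clamp the estimated density ratio from below
    max_density_ratio=20.0,   # ...and from above
)

# `dataset` is assumed to be an unpaired preference dataset.
trainer = BCOTrainer(model="gpt2", ref_model="gpt2", args=args, train_dataset=dataset, tokenizer=tokenizer)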
16 changes: 14 additions & 2 deletions trl/trainer/bco_trainer.py
@@ -72,7 +72,8 @@
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer

-RUNNING_NAME = "running.pt"
+RUNNING_NAME = "running.json"
+CLF_NAME = "clf.pt"


 def _tokenize(
@@ -822,6 +823,9 @@ def _save_optimizer_and_scheduler(self, output_dir):

         self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))

+        if self.match_underlying_distribution:
+            torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         super()._load_optimizer_and_scheduler(checkpoint)

@@ -831,7 +835,15 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         running_file = os.path.join(checkpoint, RUNNING_NAME)
         if not os.path.isfile(running_file):
             warnings.warn(f"Missing file {running_file}. Will use a new running delta value for BCO loss calculation")
-        self.running = RunningMoments.load_from_json(self.accelerator, running_file)
+        else:
+            self.running = RunningMoments.load_from_json(self.accelerator, running_file)
+
+        if self.match_underlying_distribution:
+            clf_file = os.path.join(checkpoint, CLF_NAME)
+            if not os.path.isfile(clf_file):
+                warnings.warn(f"Missing file {clf_file}. Will use a new UDM classifier for BCO loss calculation")
+            else:
+                self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))

     @contextmanager
     def null_ref_context(self):
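Two things change in this file: `RUNNING_NAME` now points at `running.json`, matching the fact that `RunningMoments` serializes itself as JSON via `save_to_json`/`load_from_json`, and the underlying-distribution-matching (UDM) classifier is round-tripped through the checkpoint directory with `torch.save`/`torch.load`. A standalone sketch of that classifier pattern, assuming the classifier is a scikit-learn estimator (whose `get_params`/`set_params` API the trainer calls); the `LogisticRegression` choice and the paths are illustrative:

import os

import torch
from sklearn.linear_model import LogisticRegression

CLF_NAME = "clf.pt"
checkpoint = "checkpoint-100"
os.makedirs(checkpoint, exist_ok=True)

# Save: get_params() returns a plain dict of constructor hyperparameters,
# which torch.save pickles; the values are all primitives, so the file can
# later be read back with weights_only=True.
clf = LogisticRegression(class_weight="balanced", max_iter=1000)
torch.save(clf.get_params(), os.path.join(checkpoint, CLF_NAME))

# Load: restore the hyperparameters onto a fresh estimator. Note that
# set_params carries configuration, not fitted coefficients.
clf_file = os.path.join(checkpoint, CLF_NAME)
restored = LogisticRegression()
restored.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))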
28 changes: 1 addition & 27 deletions trl/trainer/kto_config.py
@@ -12,12 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Literal, Optional
+from typing import Dict, Optional

 from transformers import TrainingArguments

-from ..import_utils import is_sklearn_available
-

 @dataclass
 class KTOConfig(TrainingArguments):
@@ -60,14 +58,6 @@ class KTOConfig(TrainingArguments):
             Dict of Optional kwargs to pass when instantiating the ref model from a string.
         dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the datasets.
-        loss_type: (`Literal["kto", "bco"]`, *optional*):
-            The type of loss to use. Either `"kto"` the default KTO loss, `"bco"` loss from [BCO](https://huggingface.co/papers/2404.04656) paper.
-        prompt_sample_size: (`int`, defaults to 1024):
-            Number of prompts that are fed to density ratio classifier.
-        min_density_ratio: (`float`, defaults to 0.5):
-            The minimum value of the density ratio. The estimated density ratio is clamped to this value.
-        max_density_ratio: (`float`, defaults to 10.0):
-            The maximum value of the density ratio. The estimated density ratio is clamped to this value.
     """

     max_length: Optional[int] = None
@@ -92,19 +82,3 @@ class KTOConfig(TrainingArguments):
     model_init_kwargs: Optional[Dict] = None
     ref_model_init_kwargs: Optional[Dict] = None
     dataset_num_proc: Optional[int] = None
-
-    loss_type: Literal["kto", "bco"] = "kto"
-
-    # BCO config
-    prompt_sample_size: int = 1024
-    min_density_ratio: float = 0.5
-    max_density_ratio: float = 10.0
-
-    def __post_init__(self):
-        super().__post_init__()
-
-        if self.loss_type == "bco" and not is_sklearn_available():
-            raise ImportError(
-                "You need to install scikit-learn to use loss_type='bco' "
-                "You can install it with `pip install scikit-learn`."
-            )
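With these fields and the `__post_init__` guard gone, `KTOConfig` carries only KTO settings, and the scikit-learn availability check no longer has anything to protect here. As of this change, passing the old BCO keywords fails at construction time with a dataclass `TypeError` instead of reaching the previous `ImportError` path; a quick illustration (argument values are arbitrary):

from trl import KTOConfig

# Still valid: KTO-only configuration.
args = KTOConfig(output_dir="kto-model", beta=0.1)

# Now rejected: the BCO-specific keywords were removed, so the dataclass
# __init__ raises "unexpected keyword argument".
try:
    KTOConfig(output_dir="kto-model", loss_type="bco", prompt_sample_size=512)
except TypeError as err:
    print(err)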
