From b66b77645ebf1ea0c8faf34610edd4816ee2a6ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 15 Dec 2024 20:34:56 +0100 Subject: [PATCH] remove deprecated --- trl/scripts/utils.py | 63 ------------------------------- trl/trainer/cpo_trainer.py | 4 -- trl/trainer/dpo_config.py | 14 ------- trl/trainer/kto_trainer.py | 4 -- trl/trainer/online_dpo_trainer.py | 4 -- trl/trainer/rloo_trainer.py | 4 -- trl/trainer/utils.py | 3 -- 7 files changed, 96 deletions(-) diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index a6637c02a3..e386a19d37 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -18,14 +18,12 @@ import os import subprocess import sys -import warnings from dataclasses import dataclass from typing import Iterable, Optional, Union import yaml from transformers import HfArgumentParser from transformers.hf_argparser import DataClass, DataClassType -from transformers.utils.deprecation import deprecate_kwarg logger = logging.getLogger(__name__) @@ -60,45 +58,6 @@ class ScriptArguments: ignore_bias_buffers: bool = False -class YamlConfigParser: - """ """ - - def __init__(self) -> None: - warnings.warn( - "The `YamlConfigParser` class is deprecated and will be removed in version 0.14. " - "If you need to use this class, please copy the code to your own project.", - DeprecationWarning, - ) - - def parse_and_set_env(self, config_path: str) -> dict: - with open(config_path) as yaml_file: - config = yaml.safe_load(yaml_file) - - if "env" in config: - env_vars = config.pop("env") - if isinstance(env_vars, dict): - for key, value in env_vars.items(): - os.environ[key] = str(value) - else: - raise ValueError("`env` field should be a dict in the YAML file.") - - return config - - def to_string(self, config): - final_string = "" - for key, value in config.items(): - if isinstance(value, (dict, list)): - if len(value) != 0: - value = str(value) - value = value.replace("'", '"') - value = f"'{value}'" - else: - continue - - final_string += f"--{key} {value} " - return final_string - - def init_zero_verbose(): """ Perform zero verbose init - use this method on top of the CLI modules to make @@ -165,16 +124,9 @@ class MyArguments: ``` """ - @deprecate_kwarg( - "ignore_extra_args", - "0.14.0", - warn_if_greater_or_equal_version=True, - additional_message="Use the `return_remaining_strings` in the `parse_args_and_config` method instead.", - ) def __init__( self, dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, - ignore_extra_args: Optional[bool] = None, **kwargs, ): # Make sure dataclass_types is an iterable @@ -192,18 +144,6 @@ def __init__( ) super().__init__(dataclass_types=dataclass_types, **kwargs) - self._ignore_extra_args = ignore_extra_args - - def post_process_dataclasses(self, dataclasses): - """ - Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments. - """ - warnings.warn( - "The `post_process_dataclasses` method is deprecated and will be removed in version 0.14. " - "It is no longer functional and can be safely removed from your code.", - DeprecationWarning, - ) - return dataclasses def parse_args_and_config( self, args: Optional[Iterable[str]] = None, return_remaining_strings: bool = False @@ -216,9 +156,6 @@ def parse_args_and_config( default values in the dataclasses. Command line arguments can override values set by the config file. The method also sets any environment variables specified in the `env` field of the config file. """ - if self._ignore_extra_args is not None: - return_remaining_strings = not self._ignore_extra_args - args = list(args) if args is not None else sys.argv[1:] if "--config" in args: # Get the config file path from diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 2998d534cf..6d236cfb37 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -45,7 +45,6 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from .cpo_config import CPOConfig @@ -106,9 +105,6 @@ class CPOTrainer(Trainer): _tag_names = ["trl", "cpo"] - @deprecate_kwarg( - "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 88abdd4a5c..e7abe4bb6e 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings from dataclasses import dataclass from enum import Enum from typing import Any, Literal, Optional @@ -172,7 +171,6 @@ class DPOConfig(TrainingArguments): truncation_mode: str = "keep_end" max_length: Optional[int] = None max_prompt_length: Optional[int] = None - max_target_length: Optional[int] = None # deprecated in favor of max_completion_length max_completion_length: Optional[int] = None is_encoder_decoder: Optional[bool] = None disable_dropout: bool = True @@ -194,15 +192,3 @@ class DPOConfig(TrainingArguments): rpo_alpha: Optional[float] = None discopop_tau: float = 0.05 use_num_logits_to_keep: bool = False - - def __post_init__(self): - if self.max_target_length is not None: - warnings.warn( - "The `max_target_length` argument is deprecated in favor of `max_completion_length` and will be " - "removed in v0.14.", - FutureWarning, - ) - if self.max_completion_length is None: - self.max_completion_length = self.max_target_length - - return super().__post_init__() diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index fb0b39bbe9..d054d97e7d 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -49,7 +49,6 @@ ) from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset from ..models import PreTrainedModelWrapper, create_reference_model @@ -318,9 +317,6 @@ class KTOTrainer(Trainer): _tag_names = ["trl", "kto"] - @deprecate_kwarg( - "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Union[PreTrainedModel, nn.Module, str] = None, diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index dd34d19af0..68008881f5 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -45,7 +45,6 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker from transformers.training_args import OptimizerNames from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..models import create_reference_model @@ -128,9 +127,6 @@ class OnlineDPOTrainer(Trainer): _tag_names = ["trl", "online-dpo"] - @deprecate_kwarg( - "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Union[PreTrainedModel, nn.Module], diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index fa9634696a..23ea1ca21f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -44,7 +44,6 @@ from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback -from transformers.utils.deprecation import deprecate_kwarg from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( @@ -72,9 +71,6 @@ class RLOOTrainer(Trainer): _tag_names = ["trl", "rloo"] - @deprecate_kwarg( - "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, config: RLOOConfig, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 30e453532f..bf760570bc 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -50,7 +50,6 @@ is_torch_npu_available, is_torch_xpu_available, ) -from transformers.utils.deprecation import deprecate_kwarg from ..import_utils import is_unsloth_available from ..trainer.model_config import ModelConfig @@ -897,7 +896,6 @@ def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None): return kwargs -@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True) def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]: if model_args.load_in_4bit: quantization_config = BitsAndBytesConfig( @@ -926,7 +924,6 @@ def get_kbit_device_map() -> Optional[dict[str, int]]: return None -@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True) def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]": if model_args.use_peft is False: return None