⚰️ Remove deprecated #2485

Merged
merged 1 commit on Dec 15, 2024
63 changes: 0 additions & 63 deletions trl/scripts/utils.py
@@ -18,14 +18,12 @@
import os
import subprocess
import sys
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Union

import yaml
from transformers import HfArgumentParser
from transformers.hf_argparser import DataClass, DataClassType
from transformers.utils.deprecation import deprecate_kwarg


logger = logging.getLogger(__name__)
@@ -60,45 +58,6 @@ class ScriptArguments:
ignore_bias_buffers: bool = False


class YamlConfigParser:
""" """

def __init__(self) -> None:
warnings.warn(
"The `YamlConfigParser` class is deprecated and will be removed in version 0.14. "
"If you need to use this class, please copy the code to your own project.",
DeprecationWarning,
)

def parse_and_set_env(self, config_path: str) -> dict:
with open(config_path) as yaml_file:
config = yaml.safe_load(yaml_file)

if "env" in config:
env_vars = config.pop("env")
if isinstance(env_vars, dict):
for key, value in env_vars.items():
os.environ[key] = str(value)
else:
raise ValueError("`env` field should be a dict in the YAML file.")

return config

def to_string(self, config):
final_string = ""
for key, value in config.items():
if isinstance(value, (dict, list)):
if len(value) != 0:
value = str(value)
value = value.replace("'", '"')
value = f"'{value}'"
else:
continue

final_string += f"--{key} {value} "
return final_string


def init_zero_verbose():
"""
Perform zero verbose init - use this method on top of the CLI modules to make
@@ -165,16 +124,9 @@ class MyArguments:
```
"""

@deprecate_kwarg(
"ignore_extra_args",
"0.14.0",
warn_if_greater_or_equal_version=True,
additional_message="Use the `return_remaining_strings` in the `parse_args_and_config` method instead.",
)
def __init__(
self,
dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None,
ignore_extra_args: Optional[bool] = None,
**kwargs,
):
# Make sure dataclass_types is an iterable
@@ -192,18 +144,6 @@ def __init__(
)

super().__init__(dataclass_types=dataclass_types, **kwargs)
self._ignore_extra_args = ignore_extra_args

def post_process_dataclasses(self, dataclasses):
"""
Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments.
"""
warnings.warn(
"The `post_process_dataclasses` method is deprecated and will be removed in version 0.14. "
"It is no longer functional and can be safely removed from your code.",
DeprecationWarning,
)
return dataclasses

def parse_args_and_config(
self, args: Optional[Iterable[str]] = None, return_remaining_strings: bool = False
@@ -216,9 +156,6 @@ def parse_args_and_config
default values in the dataclasses. Command line arguments can override values set by the config file. The
method also sets any environment variables specified in the `env` field of the config file.
"""
if self._ignore_extra_args is not None:
return_remaining_strings = not self._ignore_extra_args

args = list(args) if args is not None else sys.argv[1:]
if "--config" in args:
# Get the config file path from
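The two removals in this file name their replacements in the deleted warnings: users of `YamlConfigParser` are asked to copy the class into their own project (`TrlParser.parse_args_and_config` still handles `--config` files, including the `env` field), and the `ignore_extra_args` flag is superseded by the `return_remaining_strings` argument of `parse_args_and_config`. A minimal sketch of the latter migration, assuming an illustrative dataclass that is not part of TRL:

```python
from dataclasses import dataclass

from trl import TrlParser


@dataclass
class MyArguments:  # illustrative dataclass, not part of TRL
    learning_rate: float = 1e-5


parser = TrlParser(dataclass_types=[MyArguments])

# Before (removed in this PR): TrlParser(..., ignore_extra_args=...)
# After: request the unparsed strings explicitly instead of erroring on them.
my_args, remaining = parser.parse_args_and_config(
    args=["--learning_rate", "2e-5", "--unknown_flag", "1"],
    return_remaining_strings=True,
)
print(my_args.learning_rate)  # 2e-05
print(remaining)              # ['--unknown_flag', '1']
```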
4 changes: 0 additions & 4 deletions trl/trainer/cpo_trainer.py
@@ -45,7 +45,6 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .cpo_config import CPOConfig
@@ -106,9 +105,6 @@ class CPOTrainer(Trainer):

_tag_names = ["trl", "cpo"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
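With the `deprecate_kwarg` shim gone, passing `tokenizer=` to `CPOTrainer` raises a `TypeError` for an unexpected keyword instead of a deprecation warning. A minimal sketch of the rename, with an illustrative checkpoint and dataset; the same change applies to `KTOTrainer`, `OnlineDPOTrainer`, and `RLOOTrainer` below:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")  # illustrative dataset

# Before (removed in this PR): CPOTrainer(model, ..., tokenizer=tokenizer)
# After: the tokenizer is passed as `processing_class`.
trainer = CPOTrainer(
    model=model,
    args=CPOConfig(output_dir="cpo-output"),  # illustrative output dir
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
```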
14 changes: 0 additions & 14 deletions trl/trainer/dpo_config.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional
@@ -172,7 +171,6 @@ class DPOConfig(TrainingArguments):
truncation_mode: str = "keep_end"
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_target_length: Optional[int] = None # deprecated in favor of max_completion_length
max_completion_length: Optional[int] = None
is_encoder_decoder: Optional[bool] = None
disable_dropout: bool = True
@@ -194,15 +192,3 @@ class DPOConfig(TrainingArguments):
rpo_alpha: Optional[float] = None
discopop_tau: float = 0.05
use_num_logits_to_keep: bool = False

def __post_init__(self):
if self.max_target_length is not None:
warnings.warn(
"The `max_target_length` argument is deprecated in favor of `max_completion_length` and will be "
"removed in v0.14.",
FutureWarning,
)
if self.max_completion_length is None:
self.max_completion_length = self.max_target_length

return super().__post_init__()
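`DPOConfig` likewise drops its alias outright: setting `max_target_length` now raises a `TypeError` rather than warning and forwarding. A minimal sketch of the rename, with illustrative values:

```python
from trl import DPOConfig

# Before (removed in this PR): DPOConfig(..., max_target_length=256)
# After: use `max_completion_length`, which the deleted alias forwarded to.
training_args = DPOConfig(
    output_dir="dpo-output",    # illustrative
    max_prompt_length=512,      # illustrative
    max_completion_length=256,  # replaces max_target_length
)
```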
4 changes: 0 additions & 4 deletions trl/trainer/kto_trainer.py
@@ -49,7 +49,6 @@
)
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.utils import is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ..models import PreTrainedModelWrapper, create_reference_model
@@ -318,9 +317,6 @@ class KTOTrainer(Trainer):

_tag_names = ["trl", "kto"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
4 changes: 0 additions & 4 deletions trl/trainer/online_dpo_trainer.py
@@ -45,7 +45,6 @@
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..models import create_reference_model
@@ -128,9 +127,6 @@ class OnlineDPOTrainer(Trainer):

_tag_names = ["trl", "online-dpo"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module],
4 changes: 0 additions & 4 deletions trl/trainer/rloo_trainer.py
@@ -44,7 +44,6 @@
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils.deprecation import deprecate_kwarg

from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
@@ -72,9 +71,6 @@
class RLOOTrainer(Trainer):
_tag_names = ["trl", "rloo"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
config: RLOOConfig,
3 changes: 0 additions & 3 deletions trl/trainer/utils.py
@@ -50,7 +50,6 @@
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.utils.deprecation import deprecate_kwarg

from ..import_utils import is_unsloth_available
from ..trainer.model_config import ModelConfig
@@ -897,7 +896,6 @@ def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
return kwargs


@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True)
def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
@@ -926,7 +924,6 @@ def get_kbit_device_map() -> Optional[dict[str, int]]:
return None


@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True)
def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":
if model_args.use_peft is False:
return None
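Finally, `get_quantization_config` and `get_peft_config` no longer accept the old `model_config` keyword; the parameter is now `model_args` only. A minimal sketch, assuming illustrative `ModelConfig` settings (4-bit loading needs `bitsandbytes` installed, PEFT needs `peft`):

```python
from trl import ModelConfig, get_peft_config, get_quantization_config

model_args = ModelConfig(load_in_4bit=True, use_peft=True)  # illustrative settings

# Before (removed in this PR): get_quantization_config(model_config=model_args)
# After: pass the config positionally or as `model_args=`.
quantization_config = get_quantization_config(model_args)
peft_config = get_peft_config(model_args)
```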