diff --git a/setup.py b/setup.py
index dff8fdb61494..cc653bd82d17 100644
--- a/setup.py
+++ b/setup.py
@@ -149,7 +149,7 @@
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.8.0",
-    "ray[tune]",
+    "ray[tune]>=2.7.0",
     "regex!=2019.12.17",
     "requests",
     "rhoknp>=1.1.0,<1.3.1",
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index a26a93df9627..fcace1826ac4 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -55,7 +55,7 @@
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.8.0",
-    "ray[tune]": "ray[tune]",
+    "ray[tune]": "ray[tune]>=2.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "rhoknp": "rhoknp>=1.1.0,<1.3.1",
diff --git a/src/transformers/hyperparameter_search.py b/src/transformers/hyperparameter_search.py
index 8dfd60cc39cd..c14165165ca1 100644
--- a/src/transformers/hyperparameter_search.py
+++ b/src/transformers/hyperparameter_search.py
@@ -15,7 +15,7 @@
 
 from .integrations import (
     is_optuna_available,
-    is_ray_available,
+    is_ray_tune_available,
     is_sigopt_available,
     is_wandb_available,
     run_hp_search_optuna,
@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):
 
     @staticmethod
     def is_available():
-        return is_ray_available()
+        return is_ray_tune_available()
 
     def run(self, trainer, n_trials: int, direction: str, **kwargs):
         return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index 5eef480ac93f..dbcbe0bc551e 100644
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -236,8 +236,9 @@ def _objective(trial, checkpoint_dir=None):
 
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray
+    import ray.train
 
-    def _objective(trial, local_trainer, checkpoint_dir=None):
+    def _objective(trial: dict, local_trainer):
         try:
             from transformers.utils.notebook import NotebookProgressCallback
 
@@ -246,19 +247,34 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         except ModuleNotFoundError:
             pass
 
-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
+
+        checkpoint = ray.train.get_checkpoint()
+        if checkpoint:
+            # Upon trial resume, the local_trainer's objective gets reset to None.
+            # If `local_trainer.train` is a noop (training has already reached
+            # the target number of epochs/steps), then this would
+            # trigger an unnecessary extra checkpoint at the end of training.
+            # -> Set the objective to a dummy value upon resume as a workaround.
+            local_trainer.objective = "objective"
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
+                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
+        else:
+            local_trainer.train(trial=trial)
+
         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()
             local_trainer.objective = local_trainer.compute_objective(metrics)
-            local_trainer._tune_save_checkpoint()
-            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
+
+            metrics.update({"objective": local_trainer.objective, "done": True})
+
+        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+            local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+            ray.train.report(metrics, checkpoint=checkpoint)
 
     if not trainer._memory_tracker.skip_memory_metrics:
         from ..trainer_utils import TrainerMemoryTracker
@@ -296,28 +312,10 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         from ray.tune import CLIReporter
 
         kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
-    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-        # `keep_checkpoints_num=0` would disabled checkpointing
-        trainer.use_tune_checkpoints = True
-        if kwargs["keep_checkpoints_num"] > 1:
-            logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
-                "Checkpoints are usually huge, "
-                "consider setting `keep_checkpoints_num=1`."
-            )
+
     if "scheduler" in kwargs:
         from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
 
-        # Check if checkpointing is enabled for PopulationBasedTraining
-        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-            if not trainer.use_tune_checkpoints:
-                logger.warning(
-                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                    "This means your trials will train from scratch everytime they are exploiting "
-                    "new configurations. Consider enabling checkpointing by passing "
-                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                )
-
         # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
         if isinstance(
             kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
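The rewritten `_objective` above swaps the removed `tune.checkpoint_dir` / `tune.report` calls for the Ray 2.7 `ray.train` checkpoint-and-report flow. A minimal sketch of that flow outside of `Trainer` is shown below; the `trainable` function, the `step.txt` file name, the dummy `objective` metric, and the `Tuner` driver are illustrative assumptions for the sketch, not part of this patch.

```python
import os
import tempfile

import ray.train
import ray.tune


def trainable(config: dict):
    # Resume from the last reported checkpoint, if any (e.g. after a trial restart).
    start_step = 0
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "step.txt")) as f:
                start_step = int(f.read())

    for step in range(start_step, 3):
        metrics = {"objective": config["x"] * step}
        # Write the trial state to a temporary directory and hand it to Ray
        # together with the metrics for this step, mirroring the pattern above.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            with open(os.path.join(temp_checkpoint_dir, "step.txt"), "w") as f:
                f.write(str(step + 1))
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)


tuner = ray.tune.Tuner(trainable, param_space={"x": ray.tune.uniform(0.0, 1.0)})
results = tuner.fit()
```

Reporting the checkpoint alongside the metrics is what lets `ray.train.get_checkpoint()` hand the most recent checkpoint back to a restarted trial, which is exactly the resume path handled at the top of the new `_objective`.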
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 584459907127..742bb3392986 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -28,6 +28,7 @@
 import re
 import shutil
 import sys
+import tempfile
 import time
 import warnings
 from collections.abc import Mapping
@@ -595,7 +596,6 @@ def __init__(
         # returned to 0 every time flos need to be logged
         self.current_flos = 0
         self.hp_search_backend = None
-        self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.can_return_loss = can_return_loss(self.model.__class__)
@@ -1201,7 +1201,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
             return
-        self.objective = self.compute_objective(metrics.copy())
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
 
@@ -1211,24 +1212,23 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], ste
                 self.callback_handler.on_train_end(self.args, self.state, self.control)
                 raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
-
-            if self.control.should_save:
-                self._tune_save_checkpoint()
-            tune.report(objective=self.objective, **metrics)
-
-    def _tune_save_checkpoint(self):
-        from ray import tune
-
-        if not self.use_tune_checkpoints:
-            return
-        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-            self.save_model(output_dir, _internal_call=True)
-            if self.args.should_save:
-                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
-                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+            import ray.train
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)
+
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+        self.save_model(output_dir, _internal_call=True)
+        if self.args.should_save:
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
     def call_model_init(self, trial=None):
         model_init_argcount = number_of_arguments(self.model_init)
@@ -2004,9 +2004,9 @@ def _get_output_dir(self, trial):
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             run_id = trial.number
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
+            import ray.train
 
-            run_id = tune.get_trial_id()
+            run_id = ray.train.get_context().get_trial_id()
         elif self.hp_search_backend == HPSearchBackend.SIGOPT:
             run_id = trial.id
         elif self.hp_search_backend == HPSearchBackend.WANDB:
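End to end, the updated backend is still driven through `Trainer.hyperparameter_search`. A hedged usage sketch follows; `model_init`, `train_dataset`, `eval_dataset`, and the search-space values are placeholders assumed to be defined for the task at hand, not part of this patch.

```python
from ray import tune

from transformers import Trainer, TrainingArguments

# Placeholders: assume `model_init`, `train_dataset` and `eval_dataset`
# are defined elsewhere for the task at hand.
trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(
        output_dir="hpo_output",
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
    ),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)


def ray_hp_space(trial):
    # The hp_space callable returns a Ray Tune search-space dict per trial.
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
    }


best_run = trainer.hyperparameter_search(
    hp_space=ray_hp_space,
    backend="ray",
    n_trials=4,
    direction="minimize",
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)
```

Any extra keyword arguments given to `hyperparameter_search` are forwarded to `run_hp_search_ray`, as shown in the `RayTuneBackend.run` method above, so Ray-level options can be supplied there.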