diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 2eb8b1a8efc1..af7909fca230 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -225,6 +225,21 @@ def _objective(trial, checkpoint_dir=None):
     return best_run
 
 
+def get_available_reporting_integrations():
+    integrations = []
+    if is_azureml_available():
+        integrations.append("azure_ml")
+    if is_comet_available():
+        integrations.append("comet_ml")
+    if is_mlflow_available():
+        integrations.append("mlflow")
+    if is_tensorboard_available():
+        integrations.append("tensorboard")
+    if is_wandb_available():
+        integrations.append("wandb")
+    return integrations
+
+
 def rewrite_logs(d):
     new_d = {}
     eval_prefix = "eval_"
@@ -757,3 +772,21 @@ def __del__(self):
         # not let you start a new run before the previous one is killed
         if self._ml_flow.active_run is not None:
             self._ml_flow.end_run(status="KILLED")
+
+
+INTEGRATION_TO_CALLBACK = {
+    "azure_ml": AzureMLCallback,
+    "comet_ml": CometCallback,
+    "mlflow": MLflowCallback,
+    "tensorboard": TensorBoardCallback,
+    "wandb": WandbCallback,
+}
+
+
+def get_reporting_integration_callbacks(report_to):
+    for integration in report_to:
+        if integration not in INTEGRATION_TO_CALLBACK:
+            raise ValueError(
+                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
+            )
+    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index edc7a09cec58..f0a24129a226 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -31,15 +31,11 @@
 # Integrations must be imported before ML frameworks:
 from .integrations import (  # isort: split
     default_hp_search_backend,
+    get_reporting_integration_callbacks,
     hp_params,
-    is_azureml_available,
-    is_comet_available,
     is_fairscale_available,
-    is_mlflow_available,
     is_optuna_available,
     is_ray_tune_available,
-    is_tensorboard_available,
-    is_wandb_available,
     run_hp_search_optuna,
     run_hp_search_ray,
     init_deepspeed,
@@ -124,32 +120,6 @@
     import torch_xla.debug.metrics as met
     import torch_xla.distributed.parallel_loader as pl
 
-if is_tensorboard_available():
-    from .integrations import TensorBoardCallback
-
-    DEFAULT_CALLBACKS.append(TensorBoardCallback)
-
-
-if is_wandb_available():
-    from .integrations import WandbCallback
-
-    DEFAULT_CALLBACKS.append(WandbCallback)
-
-if is_comet_available():
-    from .integrations import CometCallback
-
-    DEFAULT_CALLBACKS.append(CometCallback)
-
-if is_mlflow_available():
-    from .integrations import MLflowCallback
-
-    DEFAULT_CALLBACKS.append(MLflowCallback)
-
-if is_azureml_available():
-    from .integrations import AzureMLCallback
-
-    DEFAULT_CALLBACKS.append(AzureMLCallback)
-
 if is_fairscale_available():
     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
     from fairscale.optim import OSS
@@ -300,7 +270,8 @@ def __init__(
                 "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
-        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
+        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
             callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
         )
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index d725ef1cf92c..5a7aa99bcbd4 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -231,6 +231,9 @@ class TrainingArguments:
         group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to group together samples of roughly the same legnth in the training dataset (to minimize
             padding applied and be more efficient). Only useful if applying dynamic padding.
+        report_to (:obj:`List[str]`, `optional`, defaults to the list of integrations platforms installed):
+            The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
+            :obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
     """
 
     output_dir: str = field(
@@ -413,6 +416,9 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
     )
+    report_to: Optional[List[str]] = field(
+        default=None, metadata={"help": "The list of integrations to report the results and logs to."}
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
 
     def __post_init__(self):
@@ -434,6 +440,11 @@
 
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+        if self.report_to is None:
+            # Import at runtime to avoid a circular import.
+            from .integrations import get_available_reporting_integrations
+
+            self.report_to = get_available_reporting_integrations()
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
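A minimal usage sketch of the behaviour this diff introduces (not part of the patch; it assumes a transformers build that already contains these changes, and the `tmp_output` directory is a placeholder):

from transformers import TrainingArguments
from transformers.integrations import get_reporting_integration_callbacks

# Restrict reporting to TensorBoard only, even if wandb/comet_ml/mlflow/azure_ml
# happen to be installed in the environment.
args = TrainingArguments(output_dir="tmp_output", report_to=["tensorboard"])

# Trainer.__init__ now builds its default callback list from args.report_to.
print(get_reporting_integration_callbacks(args.report_to))

# Leaving report_to unset keeps the previous behaviour: __post_init__ fills it
# with every installed integration via get_available_reporting_integrations().
default_args = TrainingArguments(output_dir="tmp_output")
print(default_args.report_to)

# An unsupported name raises a ValueError listing the supported platforms.
try:
    get_reporting_integration_callbacks(["not_a_tracker"])
except ValueError as err:
    print(err)

The explicit list lets users opt out of loggers that are merely importable in their environment, while the `None` default preserves the old "report to everything installed" behaviour.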