Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add report_to training arguments to control the integrations used #9735

Merged
merged 1 commit into from
Jan 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,21 @@ def _objective(trial, checkpoint_dir=None):
return best_run


def get_available_reporting_integrations():
integrations = []
if is_azureml_available():
integrations.append("azure_ml")
if is_comet_available():
integrations.append("comet_ml")
if is_mlflow_available():
integrations.append("mlflow")
if is_tensorboard_available():
integrations.append("tensorboard")
if is_wandb_available():
integrations.append("wandb")
return integrations


def rewrite_logs(d):
new_d = {}
eval_prefix = "eval_"
Expand Down Expand Up @@ -757,3 +772,21 @@ def __del__(self):
# not let you start a new run before the previous one is killed
if self._ml_flow.active_run is not None:
self._ml_flow.end_run(status="KILLED")


INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
"mlflow": MLflowCallback,
"tensorboard": TensorBoardCallback,
"wandb": WandbCallback,
}


def get_reporting_integration_callbacks(report_to):
for integration in report_to:
if integration not in INTEGRATION_TO_CALLBACK:
raise ValueError(
f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
)
return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
35 changes: 3 additions & 32 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,11 @@
# Integrations must be imported before ML frameworks:
from .integrations import ( # isort: split
default_hp_search_backend,
get_reporting_integration_callbacks,
hp_params,
is_azureml_available,
is_comet_available,
is_fairscale_available,
is_mlflow_available,
is_optuna_available,
is_ray_tune_available,
is_tensorboard_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
init_deepspeed,
Expand Down Expand Up @@ -124,32 +120,6 @@
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
from .integrations import TensorBoardCallback

DEFAULT_CALLBACKS.append(TensorBoardCallback)


if is_wandb_available():
from .integrations import WandbCallback

DEFAULT_CALLBACKS.append(WandbCallback)

if is_comet_available():
from .integrations import CometCallback

DEFAULT_CALLBACKS.append(CometCallback)

if is_mlflow_available():
from .integrations import MLflowCallback

DEFAULT_CALLBACKS.append(MLflowCallback)

if is_azureml_available():
from .integrations import AzureMLCallback

DEFAULT_CALLBACKS.append(AzureMLCallback)

if is_fairscale_available():
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
Expand Down Expand Up @@ -300,7 +270,8 @@ def __init__(
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
)
Expand Down
11 changes: 11 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ class TrainingArguments:
group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to group together samples of roughly the same legnth in the training dataset (to minimize
padding applied and be more efficient). Only useful if applying dynamic padding.
report_to (:obj:`List[str]`, `optional`, defaults to the list of integrations platforms installed):
The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
"""

output_dir: str = field(
Expand Down Expand Up @@ -413,6 +416,9 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
)
report_to: Optional[List[str]] = field(
default=None, metadata={"help": "The list of integrations to report the results and logs to."}
)
_n_gpu: int = field(init=False, repr=False, default=-1)

def __post_init__(self):
Expand All @@ -434,6 +440,11 @@ def __post_init__(self):

if is_torch_available() and self.device.type != "cuda" and self.fp16:
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
if self.report_to is None:
# Import at runtime to avoid a circular import.
from .integrations import get_available_reporting_integrations

self.report_to = get_available_reporting_integrations()

def __repr__(self):
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
Expand Down