-
Notifications
You must be signed in to change notification settings - Fork 28.1k
/
Copy pathintegration_utils.py
executable file
·2166 lines (1824 loc) · 94.7 KB
/
integration_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import functools
import importlib.metadata
import importlib.util
import json
import numbers
import os
import pickle
import shutil
import sys
import tempfile
from dataclasses import asdict, fields
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
import numpy as np
import packaging.version
from .. import PreTrainedModel, TFPreTrainedModel
from .. import __version__ as version
from ..utils import (
PushToHubMixin,
flatten_dict,
is_datasets_available,
is_pandas_available,
is_tf_available,
is_torch_available,
logging,
)
# Module-level logger used by every integration helper in this file.
logger = logging.get_logger(__name__)

# `torch` is only bound when PyTorch is installed; code that references `torch` (e.g. the
# distributed hyperparameter-search branches below) presumably only runs in that case — TODO confirm.
if is_torch_available():
    import torch
# comet_ml requires to be imported before any ML frameworks
_MIN_COMET_VERSION = "3.43.2"
# Probe comet_ml once, at import time. If anything goes wrong (package missing, version not parseable,
# config lookup failing) every comet flag ends up in its "unavailable" state and the version is None.
try:
    _comet_version = importlib.metadata.version("comet_ml")
    _is_comet_installed = True
    _is_comet_recent_enough = packaging.version.parse(_comet_version) >= packaging.version.parse(
        _MIN_COMET_VERSION
    )

    # Comet only counts as configured when an API key can be resolved from its config.
    import comet_ml

    _is_comet_configured = comet_ml.config.get_config("comet.api_key") is not None
except (importlib.metadata.PackageNotFoundError, ImportError, ValueError, TypeError, AttributeError, KeyError):
    _comet_version = None
    _is_comet_installed = _is_comet_recent_enough = _is_comet_configured = False
# Neptune counts as available when an importable `neptune` module is found.
# NOTE(review): "neptune-client" is a distribution name containing a hyphen, not an importable module
# name, so the second `find_spec` presumably never matches — TODO confirm.
_has_neptune = (
    importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
)
# NOTE(review): `TYPE_CHECKING` is False at runtime, so this metadata/version probe only executes under
# static type checkers; at runtime `_has_neptune` is decided solely by the `find_spec` lookups above.
if TYPE_CHECKING and _has_neptune:
    try:
        _neptune_version = importlib.metadata.version("neptune")
        logger.info(f"Neptune version {_neptune_version} available.")
    except importlib.metadata.PackageNotFoundError:
        try:
            _neptune_version = importlib.metadata.version("neptune-client")
            logger.info(f"Neptune-client version {_neptune_version} available.")
        except importlib.metadata.PackageNotFoundError:
            # Neither distribution has installed metadata: treat Neptune as unavailable.
            _has_neptune = False
from .. import modelcard # noqa: E402
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402
# Integration functions:
def is_wandb_available():
    """
    Return whether Weights & Biases can be used.

    A true-like `WANDB_DISABLED` environment variable (deprecated) disables the integration and logs a
    deprecation warning; otherwise availability is decided by whether the `wandb` module can be found.
    """
    disabled_by_env = os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES
    if disabled_by_env:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging result (for instance --report_to none)."
        )
        return False
    wandb_spec = importlib.util.find_spec("wandb")
    return wandb_spec is not None
def is_clearml_available():
    """Return True when the `clearml` module can be found on the import path."""
    clearml_spec = importlib.util.find_spec("clearml")
    return clearml_spec is not None
def is_comet_available():
    """
    Return whether Comet logging can be used.

    Comet is unusable when `COMET_MODE=DISABLED` is set (deprecated, warns), when `comet_ml` is not installed,
    when the installed version is older than `_MIN_COMET_VERSION` (warns), or when no API key is configured
    (warns). The installed/version/configured flags come from the import-time probe at the top of this module.
    """
    if os.getenv("COMET_MODE", "").upper() == "DISABLED":
        logger.warning(
            "Using the `COMET_MODE=DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging result (for instance --report_to none)."
        )
        available = False
    elif not _is_comet_installed:
        available = False
    elif not _is_comet_recent_enough:
        logger.warning(
            "comet_ml version %s is installed, but version %s or higher is required. "
            "Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=%s'.",
            _comet_version,
            _MIN_COMET_VERSION,
            _MIN_COMET_VERSION,
        )
        available = False
    elif not _is_comet_configured:
        logger.warning(
            "comet_ml is installed but the Comet API Key is not configured. "
            "Please set the `COMET_API_KEY` environment variable to enable Comet logging. "
            "Check out the documentation for other ways of configuring it: "
            "https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key"
        )
        available = False
    else:
        available = True
    return available
def is_tensorboard_available():
    """Return True when either `tensorboard` or `tensorboardX` can be found on the import path."""
    return any(importlib.util.find_spec(module_name) is not None for module_name in ("tensorboard", "tensorboardX"))
def is_optuna_available():
    """Return True when the `optuna` module can be found on the import path."""
    optuna_spec = importlib.util.find_spec("optuna")
    return optuna_spec is not None
def is_ray_available():
    """Return True when the `ray` module can be found on the import path."""
    ray_spec = importlib.util.find_spec("ray")
    return ray_spec is not None
def is_ray_tune_available():
    """
    Return True when Ray Tune can be found.

    `ray` is checked first because looking up the dotted name `ray.tune` requires the parent package.
    """
    return is_ray_available() and importlib.util.find_spec("ray.tune") is not None
def is_sigopt_available():
    """Return True when the `sigopt` module can be found on the import path."""
    sigopt_spec = importlib.util.find_spec("sigopt")
    return sigopt_spec is not None
def is_azureml_available():
    """
    Return True when `azureml.core.run` can be found.

    Each level of the package path is checked in order and the search stops at the first missing one,
    since looking up a dotted name needs its parent package to exist.
    """
    return all(
        importlib.util.find_spec(module_name) is not None
        for module_name in ("azureml", "azureml.core", "azureml.core.run")
    )
def is_mlflow_available():
    """
    Return True when MLflow can be used.

    Setting `DISABLE_MLFLOW_INTEGRATION` to "TRUE" (case-insensitive) turns the integration off; otherwise
    availability depends on whether the `mlflow` module can be found.
    """
    if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() != "TRUE":
        return importlib.util.find_spec("mlflow") is not None
    return False
def is_dagshub_available():
    """Return True when both `dagshub` and `mlflow` can be found (DagsHub logging goes through MLflow)."""
    # Both lookups are performed before the check, exactly as before.
    specs = (importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow"))
    return all(spec is not None for spec in specs)
def is_neptune_available():
    """Return the Neptune availability flag computed once when this module was imported."""
    neptune_found = _has_neptune
    return neptune_found
def is_codecarbon_available():
    """Return True when the `codecarbon` module can be found on the import path."""
    codecarbon_spec = importlib.util.find_spec("codecarbon")
    return codecarbon_spec is not None
def is_flytekit_available():
    """Return True when the `flytekit` module can be found on the import path."""
    flytekit_spec = importlib.util.find_spec("flytekit")
    return flytekit_spec is not None
def is_flyte_deck_standard_available():
    """
    Return True when the `flytekitplugins.deck` plugin can be found.

    `flytekit` is checked first; the dotted lookup is only attempted when it is present.
    """
    return is_flytekit_available() and importlib.util.find_spec("flytekitplugins.deck") is not None
def is_dvclive_available():
    """Return True when the `dvclive` module can be found on the import path."""
    dvclive_spec = importlib.util.find_spec("dvclive")
    return dvclive_spec is not None
def hp_params(trial):
    """
    Extract the hyperparameter dict from a backend-specific trial object.

    An Optuna trial yields its `.params`. A plain dict is returned unchanged, as long as at least one of
    the dict-based backends (Ray Tune, SigOpt, W&B) is available; the backends are probed in that order.

    Raises:
        RuntimeError: If the trial type is not recognised by any available backend.
    """
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.trial.BaseTrial):
            return trial.params

    for backend_is_available in (is_ray_tune_available, is_sigopt_available, is_wandb_available):
        if backend_is_available() and isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with Optuna.

    Process 0 creates and drives the Optuna study. When training is distributed (`world_size > 1`), the
    parameters sampled for each trial are broadcast to the other processes so they train the same trial.

    Args:
        trainer: The trainer whose `hp_space`, `train`, `evaluate` and `compute_objective` drive each trial.
        n_trials (`int`): Number of trials passed to `study.optimize`; non-zero ranks follow that many trials.
        direction (`str` or `list`): Study direction; a list switches to a multi-objective study (`directions=`).
        kwargs: `timeout`, `n_jobs` and `gc_after_trial` are forwarded to `study.optimize`; every other keyword
            goes to `optuna.create_study`.

    Returns:
        On process 0: a [`BestRun`] for a single-objective study, or a list of [`BestRun`] (one per entry of
        `study.best_trials`) for a multi-objective study. `None` on every other process.
    """
    import optuna
    from accelerate.utils.memory import release_memory

    if trainer.args.process_index == 0:

        def _objective(trial: optuna.Trial, checkpoint_dir=None):
            """Train (optionally resuming from a checkpoint) for one Optuna trial and return its objective."""
            checkpoint = None
            if checkpoint_dir:
                # If several entries match the prefix, the last one listed by `os.listdir` wins.
                for subdir in os.listdir(checkpoint_dir):
                    if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                        checkpoint = os.path.join(checkpoint_dir, subdir)
            trainer.objective = None
            if trainer.args.world_size > 1:
                if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                    raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
                # Sample this trial's parameters, freeze them into a FixedTrial and broadcast it so that the
                # other ranks (see the `else` branch below) train with identical hyperparameters.
                trainer.hp_space(trial)
                fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
                trial_main_rank_list = [fixed_trial]
                torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            else:
                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)

            # Free GPU memory
            trainer.model_wrapped, trainer.model = release_memory(trainer.model_wrapped, trainer.model)
            trainer.accelerator.clear()

            return trainer.objective

        # Pull the `study.optimize` options out of kwargs; the rest are forwarded to `create_study`.
        timeout = kwargs.pop("timeout", None)
        n_jobs = kwargs.pop("n_jobs", 1)
        gc_after_trial = kwargs.pop("gc_after_trial", False)
        # A list of directions means a multi-objective study; exactly one of the two arguments is set.
        directions = direction if isinstance(direction, list) else None
        direction = None if directions is not None else direction
        study = optuna.create_study(direction=direction, directions=directions, **kwargs)
        study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=gc_after_trial)
        if not study._is_multi_objective():
            best_trial = study.best_trial
            return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
        else:
            best_trials = study.best_trials
            return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
    else:
        # Non-zero ranks: receive each trial's FixedTrial from rank 0 and train with it.
        # NOTE(review): this loops exactly `n_trials` times, so it presumably assumes rank 0's study also runs
        # exactly `n_trials` trials (e.g. no `timeout` stopping it early) — TODO confirm.
        for i in range(n_trials):
            trainer.objective = None
            trial_main_rank_list = [None]
            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
            torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
            trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
        return None
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with Ray Tune.

    Before the search, the trainer is made picklable: memory tracking is disabled, the TensorBoard callback
    is detached and `trainer.model` is set to `None`. Only the TensorBoard callback is re-attached afterwards.

    Args:
        trainer: The trainer to tune; `trainer.hp_space(None)` supplies the Ray Tune `config`.
        n_trials (`int`): Passed to `ray.tune.run` as `num_samples`.
        direction (`str`): Its first three characters are used as the Ray `mode` (presumably
            "minimize"/"maximize" -> "min"/"max").
        kwargs: Forwarded to `ray.tune.run`. `resources_per_trial` and `progress_reporter` get defaults when
            missing.

    Returns:
        [`BestRun`]: Best trial id, its last reported "objective", its config and the Ray analysis object.

    Raises:
        RuntimeError: If an early-stopping/PBT scheduler is passed but evaluation during training is disabled.
    """
    import ray
    import ray.train

    def _objective(trial: dict, local_trainer):
        """Ray Tune trainable: train one trial (resuming from Ray's checkpoint if any) and report results."""
        # Notebook progress bars do not work inside Ray workers; swap in the plain ProgressCallback.
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        local_trainer.objective = None

        checkpoint = ray.train.get_checkpoint()
        if checkpoint:
            # Upon trial resume, the local_trainer's objective gets reset to None.
            # If `local_trainer.train` is a noop (training has already reached
            # the target number of epochs/steps), then this would
            # trigger an unnecessary extra checkpoint at the end of training.
            # -> Set the objective to a dummy value upon resume as a workaround.
            local_trainer.objective = "objective"

            with checkpoint.as_directory() as checkpoint_dir:
                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
        else:
            local_trainer.train(trial=trial)

        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)

            metrics.update({"objective": local_trainer.objective, "done": True})

            # Report the final metrics to Ray together with a checkpoint of the trained state.
            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
                ray.train.report(metrics, checkpoint=checkpoint)

    if not trainer._memory_tracker.skip_memory_metrics:
        from ..trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])

    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.eval_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `eval_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        if is_datasets_available():
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=trainer.args.ray_scope)
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config, analysis)
    # Re-attach the TensorBoard callback removed above (`trainer.model` stays None).
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with SigOpt.

    Process 0 creates the SigOpt experiment and runs it. It uses the `sigopt.create_experiment` API for
    sigopt >= 8.0.0, and the legacy `Connection` API otherwise. In distributed training (`world_size > 1`),
    rank 0 broadcasts the pickled training arguments of each run, and the other ranks apply them before
    training alongside it.

    Args:
        trainer: The trainer to tune; `trainer.hp_space(None)` supplies the SigOpt parameter definitions.
        n_trials (`int`): Experiment budget; non-zero ranks follow this many runs.
        direction (`str`): Objective direction passed to SigOpt's metric definition.
        kwargs: `proxies` (legacy API only) is passed to `Connection.set_proxies`.

    Returns:
        [`BestRun`] on process 0, `None` on every other process.

    Raises:
        RuntimeError: If distributed training is not `ParallelMode.DISTRIBUTED`.
    """
    import sigopt

    if trainer.args.process_index == 0:
        # Compare parsed versions: a plain string comparison would rank e.g. "10.0.0" below "8.0.0".
        if packaging.version.parse(importlib.metadata.version("sigopt")) >= packaging.version.parse("8.0.0"):
            sigopt.set_project("huggingface")

            experiment = sigopt.create_experiment(
                name="huggingface-tune",
                type="offline",
                parameters=trainer.hp_space(None),
                metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
                parallel_bandwidth=1,
                budget=n_trials,
            )

            logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

            for run in experiment.loop():
                with run:
                    trainer.objective = None
                    if trainer.args.world_size > 1:
                        if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                            raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
                        trainer._hp_search_setup(run.run)
                        # NOTE(review): a `bytes` object (not a list) is broadcast here; it is iterated
                        # element-wise and only lines up with the receiving ranks' `list(pickle.dumps(...))`
                        # when both pickles have the same length — confirm.
                        torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
                        trainer.train(resume_from_checkpoint=None)
                    else:
                        trainer.train(resume_from_checkpoint=None, trial=run.run)
                    # If there hasn't been any evaluation during the training loop.
                    if getattr(trainer, "objective", None) is None:
                        metrics = trainer.evaluate()
                        trainer.objective = trainer.compute_objective(metrics)
                    run.log_metric("objective", trainer.objective)

            best = list(experiment.get_best_runs())[0]
            best_run = BestRun(best.id, best.values["objective"].value, best.assignments)
        else:
            from sigopt import Connection

            conn = Connection()
            proxies = kwargs.pop("proxies", None)
            if proxies is not None:
                conn.set_proxies(proxies)

            experiment = conn.experiments().create(
                name="huggingface-tune",
                parameters=trainer.hp_space(None),
                metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
                parallel_bandwidth=1,
                observation_budget=n_trials,
                project="huggingface",
            )
            logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

            while experiment.progress.observation_count < experiment.observation_budget:
                suggestion = conn.experiments(experiment.id).suggestions().create()
                trainer.objective = None
                if trainer.args.world_size > 1:
                    if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                        raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
                    trainer._hp_search_setup(suggestion)
                    torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
                    trainer.train(resume_from_checkpoint=None)
                else:
                    trainer.train(resume_from_checkpoint=None, trial=suggestion)
                # If there hasn't been any evaluation during the training loop.
                if getattr(trainer, "objective", None) is None:
                    metrics = trainer.evaluate()
                    trainer.objective = trainer.compute_objective(metrics)

                values = [{"name": "objective", "value": trainer.objective}]
                obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
                logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
                # Refresh the experiment so `progress.observation_count` reflects the new observation.
                experiment = conn.experiments(experiment.id).fetch()

            best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
            best_run = BestRun(best.id, best.value, best.assignments)
        return best_run
    else:
        # Non-zero ranks: receive rank 0's pickled TrainingArguments for each run, copy every field except
        # `local_rank` onto the local args, then train with them.
        for _ in range(n_trials):
            trainer.objective = None
            args_main_rank = list(pickle.dumps(trainer.args))
            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
            torch.distributed.broadcast_object_list(args_main_rank, src=0)
            args = pickle.loads(bytes(args_main_rank))
            for key, value in asdict(args).items():
                if key != "local_rank":
                    setattr(trainer.args, key, value)
            trainer.train(resume_from_checkpoint=None)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
        return None
def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with a Weights & Biases sweep.

    A `WandbCallback` is added to the trainer if none is registered, and `report_to` is set to `["wandb"]`.

    Args:
        trainer: The trainer to tune; `trainer.hp_space(None)` must return a W&B sweep configuration dict.
        n_trials (`int`): Number of runs executed by `wandb.agent`.
        direction (`str`): Written as the sweep metric's `goal`; "minimize"/"maximize" are also used to track
            the best run locally.
        kwargs: Optional `sweep_id` (reuse an existing sweep), `project`, `name`, `entity` and `metric`
            (defaults to `"eval/loss"`).

    Returns:
        [`BestRun`]: Run id, objective and hyperparameters of the best run executed by this process.

    Raises:
        ImportError: If wandb is not available.
    """
    from ..integrations import is_wandb_available

    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    reporting_to_wandb = any(isinstance(callback, WandbCallback) for callback in trainer.callback_handler.callbacks)
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = ["wandb"]
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    sweep_config = trainer.hp_space(None)
    # A user-supplied search space may omit the `metric` section; create it rather than raising KeyError.
    sweep_config.setdefault("metric", {})
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():
        """Train one sweep run with the parameters chosen by the agent and return its objective."""
        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None

        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available"
                    f" metrics are {format_metrics.keys()}"
                )

        # Track the best run seen so far in this process.
        best_score = False
        if best_trial["run_id"] is not None:
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        if best_score or best_trial["run_id"] is None:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
def get_available_reporting_integrations():
    """
    Return the names of the reporting integrations that are currently usable, in a fixed order.

    `azure_ml` is only reported when MLflow is not available. Several availability checks log warnings,
    so they are always evaluated in the same order.
    """
    candidates = (
        ("azure_ml", lambda: is_azureml_available() and not is_mlflow_available()),
        ("comet_ml", is_comet_available),
        ("dagshub", is_dagshub_available),
        ("dvclive", is_dvclive_available),
        ("mlflow", is_mlflow_available),
        ("neptune", is_neptune_available),
        ("tensorboard", is_tensorboard_available),
        ("wandb", is_wandb_available),
        ("codecarbon", is_codecarbon_available),
        ("clearml", is_clearml_available),
    )
    return [integration for integration, is_usable in candidates if is_usable()]
def rewrite_logs(d):
    """
    Namespace Trainer log keys for experiment trackers.

    Keys starting with `eval_` become `eval/<rest>`, keys starting with `test_` become `test/<rest>`,
    and every other key is prefixed with `train/`. Values are kept as-is.
    """

    def _namespaced(key):
        for prefix, section in (("eval_", "eval/"), ("test_", "test/")):
            if key.startswith(prefix):
                return section + key[len(prefix):]
        return "train/" + key

    return {_namespaced(key): value for key, value in d.items()}
class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        if not is_tensorboard_available():
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                " install tensorboardX."
            )
        # Prefer PyTorch's bundled SummaryWriter and fall back to tensorboardX; if neither imports, no writer
        # class is recorded and no writer will ever be created.
        writer_cls = None
        try:
            from torch.utils.tensorboard import SummaryWriter as writer_cls  # noqa: F401
        except ImportError:
            try:
                from tensorboardX import SummaryWriter as writer_cls
            except ImportError:
                writer_cls = None
        self._SummaryWriter = writer_cls
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        """Create a writer for `log_dir` (default: `args.logging_dir`) when a writer class is available."""
        if self._SummaryWriter is None:
            return
        self.tb_writer = self._SummaryWriter(log_dir=log_dir or args.logging_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        # During a hyperparameter search each named trial logs into its own sub-directory.
        trial_log_dir = None
        if state.is_hyper_param_search and state.trial_name is not None:
            trial_log_dir = os.path.join(args.logging_dir, state.trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, trial_log_dir)

        if self.tb_writer is None:
            return
        self.tb_writer.add_text("args", args.to_json_string())
        model_config = getattr(kwargs.get("model"), "config", None)
        if model_config is not None:
            self.tb_writer.add_text("model_config", model_config.to_json_string())

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)
        if self.tb_writer is None:
            return

        # Numbers become scalars, strings become text; anything else is dropped with a warning.
        for tag, value in rewrite_logs(logs).items():
            if isinstance(value, (int, float)):
                self.tb_writer.add_scalar(tag, value, state.global_step)
            elif isinstance(value, str):
                self.tb_writer.add_text(tag, value, state.global_step)
            else:
                logger.warning(
                    "Trainer is attempting to log a value of "
                    f'"{value}" of type {type(value)} for key "{tag}" as a scalar. '
                    "This invocation of Tensorboard's writer.add_scalar() "
                    "is incorrect so we dropped this attribute."
                )
        self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        writer = self.tb_writer
        if not writer:
            return
        writer.close()
        self.tb_writer = None
def save_model_architecture_to_file(model: Any, output_dir: str):
    """
    Write a text description of `model` to `<output_dir>/model_architecture.txt`.

    PyTorch `PreTrainedModel`s, and PyTorch modules / `PushToHubMixin` objects exposing `base_model`, are
    written via their `repr`. TensorFlow `TFPreTrainedModel`s are written via `model.summary`. For any other
    object the file is still created but left empty.
    """
    target_path = f"{output_dir}/model_architecture.txt"
    with open(target_path, "w+") as handle:
        if isinstance(model, PreTrainedModel):
            print(model, file=handle)
            return
        if is_tf_available() and isinstance(model, TFPreTrainedModel):
            model.summary(print_fn=lambda line: print(line, file=handle))
            return
        if (
            is_torch_available()
            and isinstance(model, (torch.nn.Module, PushToHubMixin))
            and hasattr(model, "base_model")
        ):
            print(model, file=handle)
class WandbLogModel(str, Enum):
    """Enum of possible log model values in W&B."""

    CHECKPOINT = "checkpoint"
    END = "end"
    FALSE = "false"

    @property
    def is_enabled(self) -> bool:
        """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
        return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)

    @classmethod
    def _missing_(cls, value: Any) -> "WandbLogModel":
        """
        Resolve a `WANDB_LOG_MODEL` value that is not an exact member value.

        - Member values are matched case-insensitively (e.g. `"END"` -> `END`).
        - Deprecated boolean-like values (in `ENV_VARS_TRUE_VALUES`) log a deprecation warning and map to `END`.
        - Anything else logs a warning and maps to `FALSE`.

        Raises:
            ValueError: If `value` is not a string.
        """
        if not isinstance(value, str):
            raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
        for member in cls:
            if member.value == value.lower():
                return member
        if value.upper() in ENV_VARS_TRUE_VALUES:
            # Deprecated but still supported: warn instead of raising, so the fallback to `end` is reachable
            # (raising here made `WandbLogModel("true")` fail instead of falling back).
            logger.warning(
                f"Setting `WANDB_LOG_MODEL` as {value} is deprecated and will be removed in "
                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
            )
            logger.info(f"Setting `WANDB_LOG_MODEL` from {value} to `end` instead")
            return WandbLogModel.END
        logger.warning(
            f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
        )
        return WandbLogModel.FALSE
class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
"""
def __init__(self):
has_wandb = is_wandb_available()
if not has_wandb:
raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
if has_wandb:
import wandb
self._wandb = wandb
self._initialized = False
self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
def setup(self, args, state, model, **kwargs):
    """
    Setup the optional Weights & Biases (*wandb*) integration.

    One can subclass and override this method to customize the setup if needed. Find more information
    [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
    variables:

    Environment:
    - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
        Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
        to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
        will be uploaded every `args.save_steps`. If set to `"false"`, the model will not be uploaded. Use along
        with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.

        <Deprecated version="5.0">

        Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers.

        </Deprecated>
    - **WANDB_WATCH** (`str`, *optional* defaults to `"false"`):
        Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
        parameters. Ignored when torch XLA is available (TPU).
    - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
        Set this to a custom string to store results in a different project.
    - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`):
        Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.

    Args:
        args: The training arguments. `to_dict()`, `run_name` and `output_dir` are read here.
        state: The trainer state. `is_world_process_zero`, `trial_name` and `logging_steps` are read here.
        model: The model being trained. When present, its `config` (a plain `dict` or an object with
            `to_dict()`) and `peft_config` are added to the run config.
        **kwargs: Ignored by this implementation.
    """
    if self._wandb is None:
        return
    self._initialized = True

    # prepare to handle potential configuration issues during setup
    from wandb.sdk.lib.config_util import ConfigError as WandbConfigError

    # Only the main process talks to W&B.
    if not state.is_world_process_zero:
        return

    logger.info(
        'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
    )
    combined_dict = {**args.to_dict()}

    # Normalize the model config once. It may be a plain dict or an object exposing `to_dict()`.
    # The same value is reused for the initial-model artifact metadata below, so both places accept
    # the same config shapes.
    model_config = None
    if hasattr(model, "config") and model.config is not None:
        model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
        # training arguments take precedence over model-config keys with the same name
        combined_dict = {**model_config, **combined_dict}
    if hasattr(model, "peft_config") and model.peft_config is not None:
        combined_dict = {"peft_config": model.peft_config, **combined_dict}

    trial_name = state.trial_name
    init_args = {}
    if trial_name is not None:
        # A trial name is present: name the run after the trial and group trials under `run_name`.
        init_args["name"] = trial_name
        init_args["group"] = args.run_name
    elif args.run_name is not None:
        init_args["name"] = args.run_name
        if args.run_name == args.output_dir:
            self._wandb.termwarn(
                "The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was "
                "not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.",
                repeat=False,
            )

    # Only start a run if none is active (the user may have called `wandb.init` themselves).
    if self._wandb.run is None:
        self._wandb.init(
            project=os.getenv("WANDB_PROJECT", "huggingface"),
            **init_args,
        )
    # add config parameters (run may have been created manually)
    self._wandb.config.update(combined_dict, allow_val_change=True)

    # define default x-axis (for latest wandb versions)
    if getattr(self._wandb, "define_metric", None):
        self._wandb.define_metric("train/global_step")
        self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

    # keep track of model topology and gradients, unsupported on TPU
    _watch_model = os.getenv("WANDB_WATCH", "false")
    if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
        self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
    self._wandb.run._label(code="transformers_trainer")

    # add number of model parameters to wandb config
    try:
        self._wandb.config["model/num_parameters"] = model.num_parameters()
    except AttributeError:
        logger.info(
            "Could not log the number of model parameters in Weights & Biases due to an AttributeError."
        )
    except WandbConfigError:
        logger.warning(
            "A ConfigError was raised whilst setting the number of model parameters in Weights & Biases config."
        )

    # log the initial model architecture to an artifact
    if self._log_model.is_enabled:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Fall back to the run id when no distinct run name was chosen.
            model_name = (
                f"model-{self._wandb.run.id}"
                if (args.run_name is None or args.run_name == args.output_dir)
                else f"model-{self._wandb.run.name}"
            )
            model_artifact = self._wandb.Artifact(
                name=model_name,
                type="model",
                metadata={
                    # Reuse the normalized config. Calling `model.config.to_dict()` here would raise for
                    # dict or None configs, which are accepted above.
                    "model_config": model_config,
                    "num_parameters": self._wandb.config.get("model/num_parameters"),
                    "initial_model": True,
                },
            )
            # add the architecture to a separate text file
            save_model_architecture_to_file(model, temp_dir)
            for f in Path(temp_dir).glob("*"):
                if f.is_file():
                    with model_artifact.new_file(f.name, mode="wb") as fa:
                        fa.write(f.read_bytes())
            self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])

        badge_markdown = (
            f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
            f'-28.svg" alt="Visualize in Weights & Biases" width="20'
            f'0" height="32"/>]({self._wandb.run.get_url()})'
        )
        # Appends to a module-level string, so every call to setup adds another badge line.
        modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def on_train_begin(self, args, state, control, model=None, **kwargs):
    """
    Make sure the W&B integration is set up before training starts.

    During hyperparameter search, the active run is finished and the callback is marked uninitialized.
    `args.run_name` is also cleared, so the following `setup()` call starts from a clean state for the
    new trial. Extra keyword arguments are forwarded to `setup()`.
    """
    wandb_module = self._wandb
    if wandb_module is None:
        return

    starting_new_trial = state.is_hyper_param_search
    if starting_new_trial:
        # Close the previous trial's run and force a fresh setup for this one.
        wandb_module.finish()
        self._initialized = False
        args.run_name = None

    needs_setup = not self._initialized
    if needs_setup:
        self.setup(args, state, model, **kwargs)
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
    """
    Upload the final model as a W&B artifact (alias `final_model`) when model logging is enabled.

    The model is written to a temporary directory through a throwaway `Trainer.save_model` call. A text dump of
    its architecture is written alongside it, and every file in that directory is added to the artifact. Nothing
    happens unless `WANDB_LOG_MODEL` enables logging, `setup()` has run, and this is the main process.

    Args:
        args: Training arguments. They are forwarded to the throwaway `Trainer`; `load_best_model_at_end`,
            `metric_for_best_model`, `run_name` and `output_dir` are also read here.
        state: Trainer state. It supplies `best_metric` / `total_flos` for the metadata.
        control: Unused.
        model: The trained model to save.
        tokenizer: Passed to the throwaway `Trainer` as `processing_class`. When it is `None`, a
            `processing_class` keyword argument (if any) is used instead.
    """
    if self._wandb is None:
        return
    if not (self._log_model.is_enabled and self._initialized and state.is_world_process_zero):
        return

    # NOTE(review): depending on the transformers version, callbacks may receive the processor as
    # `processing_class` rather than `tokenizer`. Accept either so it is saved together with the model.
    processing_class = tokenizer if tokenizer is not None else kwargs.get("processing_class")

    from ..trainer import Trainer

    # This trainer is only used for `save_model`; the one-element eval dataset is a placeholder.
    fake_trainer = Trainer(args=args, model=model, processing_class=processing_class, eval_dataset=["fake"])
    with tempfile.TemporaryDirectory() as temp_dir:
        fake_trainer.save_model(temp_dir)
        if args.load_best_model_at_end:
            # NOTE(review): "floss" looks like a typo of "flos". The key is left unchanged so that existing
            # consumers of the artifact metadata keep working.
            metadata = {
                f"eval/{args.metric_for_best_model}": state.best_metric,
                "train/total_floss": state.total_flos,
                "model/num_parameters": self._wandb.config.get("model/num_parameters"),
            }
        else:
            # Numeric run-summary values, skipping keys that start with "_".
            metadata = {
                k: v
                for k, v in dict(self._wandb.summary).items()
                if isinstance(v, numbers.Number) and not k.startswith("_")
            }
        metadata["final_model"] = True
        logger.info("Logging model artifacts. ...")
        # Fall back to the run id when no distinct run name was chosen.
        model_name = (
            f"model-{self._wandb.run.id}"
            if (args.run_name is None or args.run_name == args.output_dir)
            else f"model-{self._wandb.run.name}"
        )
        # add the model architecture to a separate text file
        save_model_architecture_to_file(model, temp_dir)
        artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
        for f in Path(temp_dir).glob("*"):
            if f.is_file():
                with artifact.new_file(f.name, mode="wb") as fa:
                    fa.write(f.read_bytes())
        self._wandb.run.log_artifact(artifact, aliases=["final_model"])
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
    """
    Forward Trainer logs to W&B.

    Runs `setup()` first if it has not run yet. On the main process, the keys listed in `summary_only_keys` are
    written to the run summary. Every other entry is passed through `rewrite_logs` and logged together with
    `train/global_step`.
    """
    summary_only_keys = (
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
        "train_loss",
        "total_flos",
    )
    if self._wandb is None:
        return
    if not self._initialized:
        self.setup(args, state, model)
    if not state.is_world_process_zero:
        return

    # Single pass: summary-only values go to the run summary, the rest is collected for `wandb.log`.
    remaining = {}
    for key, value in logs.items():
        if key in summary_only_keys:
            self._wandb.run.summary[key] = value
        else:
            remaining[key] = value
    self._wandb.log({**rewrite_logs(remaining), "train/global_step": state.global_step})
def on_save(self, args, state, control, **kwargs):
    """
    Log the checkpoint the Trainer has just saved as a W&B model artifact.

    Only acts when the log-model mode is `WandbLogModel.CHECKPOINT`, `setup()` has already run, and this is the
    main process. The directory `<args.output_dir>/checkpoint-<global_step>` is uploaded and tagged with epoch and
    global-step aliases. Its metadata is the numeric run-summary values plus the parameter count.
    """
    wants_checkpoint = (
        self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero
    )
    if not wants_checkpoint:
        return

    # Numeric run-summary values, skipping keys that start with "_".
    numeric_summary = {
        name: value
        for name, value in dict(self._wandb.summary).items()
        if isinstance(value, numbers.Number) and not name.startswith("_")
    }
    numeric_summary["model/num_parameters"] = self._wandb.config.get("model/num_parameters")

    step_folder = f"checkpoint-{state.global_step}"
    checkpoint_path = os.path.join(args.output_dir, step_folder)
    logger.info(f"Logging checkpoint artifacts in {step_folder}. ...")

    # Fall back to the run id when no distinct run name was chosen.
    named_by_id = args.run_name is None or args.run_name == args.output_dir
    artifact_name = f"model-{self._wandb.run.id}" if named_by_id else f"model-{self._wandb.run.name}"

    checkpoint_artifact = self._wandb.Artifact(name=artifact_name, type="model", metadata=numeric_summary)
    checkpoint_artifact.add_dir(checkpoint_path)
    self._wandb.log_artifact(
        checkpoint_artifact,
        aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"],
    )