-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
Copy pathcallback_connector.py
279 lines (242 loc) · 13.2 KB
/
callback_connector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import timedelta
from typing import Dict, List, Optional, Union
from pytorch_lightning.callbacks import (
Callback,
GradientAccumulationScheduler,
ModelCheckpoint,
ModelSummary,
ProgressBar,
ProgressBarBase,
RichProgressBar,
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
class CallbackConnector:
    """Assembles and maintains the callback list of a ``Trainer``.

    ``on_trainer_init`` turns the callback-related ``Trainer`` constructor arguments into callback
    instances stored on ``trainer.callbacks``; ``_attach_model_callbacks`` later merges in the callbacks
    returned by the ``configure_callbacks`` hook.
    """

    def __init__(self, trainer):
        # The connector reads and writes attributes directly on this trainer object
        # (e.g. ``callbacks``, ``_progress_bar_callback``, ``accumulation_scheduler``).
        self.trainer = trainer

    def on_trainer_init(
        self,
        callbacks: Optional[Union[List[Callback], Callback]],
        checkpoint_callback: Optional[bool],
        enable_checkpointing: bool,
        enable_progress_bar: bool,
        progress_bar_refresh_rate: Optional[int],
        process_position: int,
        default_root_dir: Optional[str],
        weights_save_path: Optional[str],
        weights_summary: Optional[str],
        stochastic_weight_avg: bool,
        max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
        accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
    ) -> None:
        """Populate ``trainer.callbacks`` from the ``Trainer`` constructor arguments.

        Args:
            callbacks: A single callback, a sequence of callbacks, or ``None``. The sequence is copied,
                so default callbacks appended here never leak back into the caller's list.
            checkpoint_callback: Deprecated alias of ``enable_checkpointing``; ``None`` when not set.
            enable_checkpointing: Whether a default ``ModelCheckpoint`` should be added.
            enable_progress_bar: Whether a progress bar callback should be configured at all.
            progress_bar_refresh_rate: Deprecated refresh rate for the default ``ProgressBar``.
            process_position: Deprecated position forwarded to the default ``ProgressBar``.
            default_root_dir: Output root directory; falls back to the current working directory.
            weights_save_path: Weights directory; falls back to the resolved ``default_root_dir``.
            weights_summary: Model summary mode, or ``None`` to not add a ``ModelSummary`` callback.
            stochastic_weight_avg: Deprecated flag that adds a ``StochasticWeightAveraging`` callback.
            max_time: Optional training time limit; when set, a ``Timer`` callback is added.
            accumulate_grad_batches: An int, or a dict handed to ``GradientAccumulationScheduler``.

        Raises:
            MisconfigurationException: If the arguments conflict with each other or with the
                callbacks that were passed in.
        """
        # init folder paths for checkpoint + weights save callbacks
        self.trainer._default_root_dir = default_root_dir or os.getcwd()
        self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir

        if stochastic_weight_avg:
            rank_zero_deprecation(
                "Setting `Trainer(stochastic_weight_avg=True)` is deprecated in v1.5 and will be removed in v1.7."
                " Please pass `pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`"
                " directly to the Trainer's `callbacks` argument instead."
            )
        self.trainer._stochastic_weight_avg = stochastic_weight_avg

        # init callbacks. Always work on a private list: the helpers below append default callbacks,
        # which would otherwise mutate the caller's list (or fail outright on a tuple).
        if isinstance(callbacks, Callback):
            callbacks = [callbacks]
        self.trainer.callbacks = list(callbacks) if callbacks else []

        # configure checkpoint callback
        # pass through the required args to figure out defaults
        self._configure_checkpoint_callbacks(checkpoint_callback, enable_checkpointing)

        # configure swa callback
        self._configure_swa_callbacks()

        # configure the timer callback.
        # responsible to stop the training when max_time is reached.
        self._configure_timer_callback(max_time)

        # init progress bar
        if process_position != 0:
            rank_zero_deprecation(
                f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed"
                " in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
                " `process_position` directly to the Trainer's `callbacks` argument instead."
            )

        if progress_bar_refresh_rate is not None:
            rank_zero_deprecation(
                f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
                " will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
                " `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress"
                " bar pass `enable_progress_bar = False` to the Trainer."
            )

        if enable_progress_bar:
            self.trainer._progress_bar_callback = self.configure_progress_bar(
                progress_bar_refresh_rate, process_position
            )
        else:
            self.trainer._progress_bar_callback = None

        # configure the ModelSummary callback
        # (must run after the progress bar is set up: the summary flavour depends on its type)
        self._configure_model_summary_callback(weights_summary)

        # accumulated grads
        self._configure_accumulated_gradients(accumulate_grad_batches)

        # push all checkpoint callbacks to the end
        # it is important that these are the last callbacks to run
        self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks)

    def _configure_accumulated_gradients(
        self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None
    ) -> None:
        """Resolve the gradient-accumulation scheduler and store it on the trainer.

        A ``GradientAccumulationScheduler`` already present in ``trainer.callbacks`` wins; otherwise one is
        built from ``accumulate_grad_batches`` (``None`` means 1) and appended.

        Raises:
            MisconfigurationException: If both a scheduler callback and ``accumulate_grad_batches`` are
                given, or if ``accumulate_grad_batches`` is neither an int nor a dict.
        """
        existing_schedulers = [
            cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
        ]

        if existing_schedulers:
            if accumulate_grad_batches is not None:
                raise MisconfigurationException(
                    "You have set both `accumulate_grad_batches` and passed an instance of "
                    "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
                    "from trainer or remove `GradientAccumulationScheduler` from callbacks list."
                )
            grad_accum_callback = existing_schedulers[0]
        else:
            if accumulate_grad_batches is None:
                accumulate_grad_batches = 1

            if isinstance(accumulate_grad_batches, dict):
                grad_accum_callback = GradientAccumulationScheduler(accumulate_grad_batches)
            elif isinstance(accumulate_grad_batches, int):
                grad_accum_callback = GradientAccumulationScheduler({0: accumulate_grad_batches})
            else:
                raise MisconfigurationException(
                    f"`accumulate_grad_batches` should be an int or a dict. Got {accumulate_grad_batches}."
                )

            # Only a newly created scheduler is appended; an existing one is already in the list.
            self.trainer.callbacks.append(grad_accum_callback)

        self.trainer.accumulate_grad_batches = grad_accum_callback.get_accumulate_grad_batches(0)
        self.trainer.accumulation_scheduler = grad_accum_callback

    def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], enable_checkpointing: bool) -> None:
        """Validate the checkpointing flags and add a default ``ModelCheckpoint`` when needed.

        Raises:
            MisconfigurationException: If the resolved flag is not a bool, or if checkpointing is disabled
                while a ``ModelCheckpoint`` is present in the callbacks.
        """
        if checkpoint_callback is not None:
            rank_zero_deprecation(
                f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
                f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
            )
            # if both are set then checkpoint only if both are True
            enable_checkpointing = checkpoint_callback and enable_checkpointing

        # TODO: Remove this error in v1.5 so we rely purely on the type signature
        if not isinstance(enable_checkpointing, bool):
            error_msg = (
                "Invalid type provided for `enable_checkpointing`: "
                f"Expected bool but received {type(enable_checkpointing)}."
            )
            if isinstance(enable_checkpointing, Callback):
                error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
            raise MisconfigurationException(error_msg)

        has_checkpoint_callbacks = self._trainer_has_checkpoint_callbacks()

        if has_checkpoint_callbacks and enable_checkpointing is False:
            raise MisconfigurationException(
                "Trainer was configured with `enable_checkpointing=False` but found `ModelCheckpoint` in callbacks list."
            )

        if not has_checkpoint_callbacks and enable_checkpointing is True:
            self.trainer.callbacks.append(ModelCheckpoint())

    def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
        """Add a model-summary callback for ``weights_summary`` unless one was passed in already.

        A ``RichModelSummary`` is used when the configured progress bar is a ``RichProgressBar``,
        otherwise a plain ``ModelSummary``. ``trainer.weights_summary`` is updated unless a
        ``ModelSummary`` callback already exists.

        Raises:
            MisconfigurationException: If ``weights_summary`` is not ``None`` and not a supported mode.
        """
        if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
            return

        if weights_summary is not None:
            if weights_summary not in ModelSummaryMode.supported_types():
                # No comma between the f-strings: they must concatenate into ONE message argument.
                raise MisconfigurationException(
                    f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}"
                    f" but got {weights_summary}"
                )
            max_depth = ModelSummaryMode.get_max_depth(weights_summary)
            # isinstance() is already False for None, so no separate None check is needed.
            if isinstance(self.trainer._progress_bar_callback, RichProgressBar):
                model_summary = RichModelSummary(max_depth=max_depth)
            else:
                model_summary = ModelSummary(max_depth=max_depth)
            self.trainer.callbacks.append(model_summary)
        self.trainer.weights_summary = weights_summary

    def _configure_swa_callbacks(self) -> None:
        """Prepend a default ``StochasticWeightAveraging`` callback if the deprecated flag asked for one."""
        if not self.trainer._stochastic_weight_avg:
            return

        from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging

        existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
        if not existing_swa:
            self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks

    def configure_progress_bar(
        self, refresh_rate: Optional[int] = None, process_position: int = 0
    ) -> Optional[ProgressBarBase]:
        """Return the progress bar callback to use, creating a default ``ProgressBar`` if needed.

        Args:
            refresh_rate: Refresh rate for a newly created ``ProgressBar``. ``None`` means 20 when the
                ``COLAB_GPU`` environment variable is set and 1 otherwise; a value <= 0 disables the
                default bar.
            process_position: Forwarded to a newly created ``ProgressBar``.

        Returns:
            The user-supplied progress bar, the newly appended default one, or ``None``.

        Raises:
            MisconfigurationException: If more than one progress bar callback was passed in.
        """
        if os.getenv("COLAB_GPU") and refresh_rate is None:
            # smaller refresh rate on colab causes crashes, choose a higher value
            refresh_rate = 20
        refresh_rate = 1 if refresh_rate is None else refresh_rate

        progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
        if len(progress_bars) > 1:
            raise MisconfigurationException(
                "You added multiple progress bar callbacks to the Trainer, but currently only one"
                " progress bar is supported."
            )
        if len(progress_bars) == 1:
            progress_bar_callback = progress_bars[0]
        elif refresh_rate > 0:
            progress_bar_callback = ProgressBar(refresh_rate=refresh_rate, process_position=process_position)
            self.trainer.callbacks.append(progress_bar_callback)
        else:
            progress_bar_callback = None

        return progress_bar_callback

    def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
        """Append a step-interval ``Timer`` for ``max_time`` unless it is unset or a Timer already exists."""
        if max_time is None:
            return
        if any(isinstance(cb, Timer) for cb in self.trainer.callbacks):
            rank_zero_info("Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer.")
            return
        timer = Timer(duration=max_time, interval="step")
        self.trainer.callbacks.append(timer)

    def _trainer_has_checkpoint_callbacks(self) -> bool:
        """Return whether ``trainer.checkpoint_callbacks`` is non-empty."""
        return len(self.trainer.checkpoint_callbacks) > 0

    def attach_model_logging_functions(self, model) -> None:
        """Expose ``model.log`` / ``model.log_dict`` as ``log`` / ``log_dict`` on every trainer callback."""
        for callback in self.trainer.callbacks:
            callback.log = model.log
            callback.log_dict = model.log_dict

    def _attach_model_callbacks(self) -> None:
        """Attaches the callbacks defined in the model.

        If a callback returned by the model's configure_callback method has the same type as one or several
        callbacks already present in the trainer callbacks list, it will replace them.
        In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
        will be pushed to the end of the list, ensuring they run last.
        A single callback returned by the hook is treated like a one-element list.
        """
        model_callbacks = self.trainer.call_hook("configure_callbacks")
        if not model_callbacks:
            return
        # Accept a lone callback, mirroring how ``on_trainer_init`` treats its ``callbacks`` argument.
        if isinstance(model_callbacks, Callback):
            model_callbacks = [model_callbacks]
        model_callback_types = {type(c) for c in model_callbacks}
        trainer_callback_types = {type(c) for c in self.trainer.callbacks}
        override_types = model_callback_types.intersection(trainer_callback_types)
        if override_types:
            rank_zero_info(
                "The following callbacks returned in `LightningModule.configure_callbacks` will override"
                " existing callbacks passed to Trainer:"
                f" {', '.join(sorted(t.__name__ for t in override_types))}"
            )
        # remove all callbacks with a type that occurs in model callbacks
        all_callbacks = [c for c in self.trainer.callbacks if type(c) not in override_types]
        all_callbacks.extend(model_callbacks)
        all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks)
        # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state
        self.trainer.callbacks = all_callbacks

    @staticmethod
    def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
        """Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of
        checkpoint callbacks is preserved, as well as the order of all other callbacks.

        Args:
            callbacks: A list of callbacks.

        Return:
            A new list in which the last elements are ModelCheckpoints if there were any present in the
            input.
        """
        checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
        not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
        return not_checkpoints + checkpoints