-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move batch to device before sending it to hooks (#7378)
* update train step * test * x * limits * val * typeo * x * x * step * min gpus * run all loops * x * limit test * profiler * clean up accelerator code * move files * rename * move tests * changelog * reorder callbacks and model hooks * add test description * replace unneccessary method * fix chlog * adjust batch_to_device for DP Plugin * update tests for dataloader idx * unused imports * hook change * switch None * clear memory * change to None * None * None * memory savings * remove redundant todo * hack * cheat * Revert "cheat" This reverts commit a8433bd. * Revert "hack" This reverts commit 43a6d1e. * update new epoch loop * remove from old loop code * update chlog * update hook test * changelog * teardown * integrate changes in new eval loop * fix hook calls * add prediction step * bad merge * Revert "bad merge" This reverts commit 4880808. * fix train batch hook test * rm -rf _notebooks * update chlog * release memory * fix type * notebooks mess * debug * Revert "debug" This reverts commit eec4ee2. * teardown * fix teardown bug * debug * x * debug * Revert "debug" This reverts commit a6e6101. Revert "debug" This reverts commit 5ddeaec. debug debug Revert "debug" This reverts commit 605be74. Revert "Revert "debug"" This reverts commit a7612d5. debug x x x s tol x tol * Fix changelog Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
- Loading branch information
1 parent
8193bae
commit ea5cfd2
Showing
11 changed files
with
133 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from pytorch_lightning import Callback, Trainer | ||
from tests.helpers import BoringModel | ||
from tests.helpers.runif import RunIf | ||
|
||
|
||
class BatchHookObserverCallback(Callback): | ||
|
||
def on_train_batch_start(self, trainer, pl_module, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
def on_validation_batch_start(self, trainer, pl_module, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
def on_test_batch_start(self, trainer, pl_module, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
def on_predict_batch_start(self, trainer, pl_module, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, *args): | ||
assert batch.device == pl_module.device | ||
|
||
|
||
class BatchHookObserverModel(BoringModel): | ||
|
||
def on_train_batch_start(self, batch, *args): | ||
assert batch.device == self.device | ||
|
||
def on_train_batch_end(self, outputs, batch, *args): | ||
assert batch.device == self.device | ||
|
||
def on_validation_batch_start(self, batch, *args): | ||
assert batch.device == self.device | ||
|
||
def on_validation_batch_end(self, outputs, batch, *args): | ||
assert batch.device == self.device | ||
|
||
def on_test_batch_start(self, batch, *args): | ||
assert batch.device == self.device | ||
|
||
def on_test_batch_end(self, outputs, batch, *args): | ||
assert batch.device == self.device | ||
|
||
def on_predict_batch_start(self, batch, *args): | ||
assert batch.device == self.device | ||
|
||
def on_predict_batch_end(self, outputs, batch, *args): | ||
assert batch.device == self.device | ||
|
||
|
||
@RunIf(min_gpus=1) | ||
def test_callback_batch_on_device(tmpdir): | ||
""" Test that the batch object sent to the on_*_batch_start/end hooks is on the right device.""" | ||
|
||
batch_callback = BatchHookObserverCallback() | ||
|
||
model = BatchHookObserverModel() | ||
trainer = Trainer( | ||
default_root_dir=tmpdir, | ||
max_steps=1, | ||
limit_train_batches=1, | ||
limit_val_batches=1, | ||
limit_test_batches=1, | ||
limit_predict_batches=1, | ||
gpus=1, | ||
callbacks=[batch_callback], | ||
) | ||
trainer.fit(model) | ||
trainer.validate(model) | ||
trainer.test(model) | ||
trainer.predict(model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters