Update fit with val hook test #8060

Merged: 27 commits from tests/improve-hook-test-full-fit into master on Jun 21, 2021

Commits
7500fd0  Add callback to hook tests and add predict test (carmocca, May 27, 2021)
27e2dcf  Fix lambda callback test (carmocca, May 27, 2021)
174be4c  Simplify lambda call test (carmocca, May 28, 2021)
e99e711  Use LambdaCallback (carmocca, May 28, 2021)
c52ab79  Dynamically append to called for the model (carmocca, May 28, 2021)
fcfe381  Remove print (carmocca, May 28, 2021)
2ca39c4  Consistency (carmocca, May 28, 2021)
aa8eea0  Consistency (carmocca, May 28, 2021)
1de5fbd  Prepare args/kwargs testing (carmocca, May 28, 2021)
736f1c2  yapf doesn't like dict literals (carmocca, May 28, 2021)
020d98d  Add arguments for fit no val test (carmocca, May 28, 2021)
c069e2d  Add arguments for fit no val test (carmocca, May 28, 2021)
6245149  Merge branch 'master' into tests/improve-hook-tests (carmocca, Jun 10, 2021)
deb67fb  Test arguments (carmocca, Jun 11, 2021)
4554003  Datamodule refactor (carmocca, Jun 11, 2021)
6c92649  Merge branch 'master' into tests/improve-hook-tests (carmocca, Jun 17, 2021)
8c8e059  Fix eval test (carmocca, Jun 17, 2021)
af39b28  Merge branch 'master' into tests/improve-hook-tests (carmocca, Jun 21, 2021)
6e9bcf8  Update full fit + val test (carmocca, Jun 21, 2021)
037100e  Update test (carmocca, Jun 21, 2021)
a5511f1  Remove FIXME (carmocca, Jun 21, 2021)
78b4062  Remove FIXME (carmocca, Jun 21, 2021)
fd65bb8  Undo change (carmocca, Jun 21, 2021)
c32e5e0  Fix (carmocca, Jun 21, 2021)
8f664a6  Fix save_checkpoint signature inspection (carmocca, Jun 21, 2021)
640360b  Update tests/models/test_hooks.py (carmocca, Jun 21, 2021)
ca3ec93  Merge branch 'master' into tests/improve-hook-test-full-fit (carmocca, Jun 21, 2021)
Files changed
pytorch_lightning/trainer/callback_hook.py (1 addition, 1 deletion)

@@ -254,7 +254,7 @@ def on_keyboard_interrupt(self):
     @staticmethod
     def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
         parameters = list(signature(fn).parameters)
-        return len(parameters) == 2 and parameters[1] != "args"
+        return len(parameters) == 2 and parameters[0] != "args"

     @staticmethod
     def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
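This one-line fix targets the backward-compatibility shim for the deprecated on_save_checkpoint(trainer, pl_module) signature. signature() on a bound method drops self, so the old form has exactly two parameters; a *args/**kwargs passthrough also has two, but its first one is literally named "args", which is why the guard must inspect parameters[0] rather than parameters[1] (the buggy version looked at "kwargs" and could never match). A minimal standalone sketch of the corrected check, with illustrative callback classes:

from inspect import signature
from typing import Callable


def is_old_save_checkpoint_signature(fn: Callable) -> bool:
    # signature() on a bound method omits `self`, so the deprecated
    # on_save_checkpoint(trainer, pl_module) form has exactly two parameters.
    # A `*args, **kwargs` passthrough also has two, but its first parameter
    # is named "args", hence the parameters[0] guard.
    parameters = list(signature(fn).parameters)
    return len(parameters) == 2 and parameters[0] != "args"


class OldStyle:
    def on_save_checkpoint(self, trainer, pl_module):  # deprecated form
        pass


class NewStyle:
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):  # current form
        pass


class Passthrough:
    def on_save_checkpoint(self, *args, **kwargs):  # parameters: "args", "kwargs"
        pass


assert is_old_save_checkpoint_signature(OldStyle().on_save_checkpoint)
assert not is_old_save_checkpoint_signature(NewStyle().on_save_checkpoint)
assert not is_old_save_checkpoint_signature(Passthrough().on_save_checkpoint)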
(second changed file; the file header was not captured in this view)

@@ -366,7 +366,7 @@ def extra(self) -> Dict[str, Any]:
         return self.get('_extra', {})

     @extra.setter
-    def extra(self, extra: Mapping[str, Any]) -> None:
+    def extra(self, extra: Dict[str, Any]) -> None:

         def check_fn(v):
             if v.grad_fn is not None:
@@ -378,7 +378,8 @@ def check_fn(v):
                 return v.detach()
             return v

-        extra = apply_to_collection(extra, torch.Tensor, check_fn)
+        # update instead of replace to keep the extra dict reference. TODO: remove with v1.6 deprecation removal
+        extra.update(apply_to_collection(extra, torch.Tensor, check_fn))
         self['_extra'] = extra

     def log(
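The extra.update(...) change is about object identity: rebinding the name (extra = apply_to_collection(...)) would build a new dict and silently orphan any reference the caller still holds, while update() mutates the dict in place so every holder of that reference sees the detached tensors. A minimal sketch of the distinction, with illustrative names:

import torch


def detach_in_place(extra: dict) -> None:
    # Mutate the caller's dict instead of rebinding the local name, so
    # external references to the same object observe the detached values.
    extra.update({
        k: v.detach() if isinstance(v, torch.Tensor) else v
        for k, v in extra.items()
    })


logged = {'logits': torch.ones(2, requires_grad=True) * 2}
alias = logged  # a second reference, as calling code might keep
detach_in_place(logged)
assert alias is logged
assert alias['logits'].grad_fn is None  # detachment is visible through the alias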
tests/callbacks/test_callbacks.py (1 addition, 84 deletions)

@@ -11,95 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest import mock
-from unittest.mock import ANY, call, MagicMock, Mock
+from unittest.mock import call, Mock

 from pytorch_lightning import Trainer
 from tests.helpers import BoringModel


-@mock.patch("torch.save")  # need to mock torch.save or we get pickle error
-def test_trainer_callback_hook_system_fit(_, tmpdir):
-    """Test the callback hook system for fit."""
-
-    model = BoringModel()
-    callback_mock = MagicMock()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[callback_mock],
-        max_epochs=1,
-        limit_val_batches=1,
-        limit_train_batches=3,
-        progress_bar_refresh_rate=0,
-    )
-
-    # check that only the to calls exists
-    assert trainer.callbacks[0] == callback_mock
-    assert callback_mock.method_calls == [
-        call.on_init_start(trainer),
-        call.on_init_end(trainer),
-    ]
-
-    # fit model
-    trainer.fit(model)
-
-    assert callback_mock.method_calls == [
-        call.on_init_start(trainer),
-        call.on_init_end(trainer),
-        call.on_before_accelerator_backend_setup(trainer, model),
-        call.setup(trainer, model, stage='fit'),
-        call.on_configure_sharded_model(trainer, model),
-        call.on_fit_start(trainer, model),
-        call.on_pretrain_routine_start(trainer, model),
-        call.on_pretrain_routine_end(trainer, model),
-        call.on_sanity_check_start(trainer, model),
-        call.on_validation_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_validation_epoch_start(trainer, model),
-        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
-        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_validation_end(trainer, model),
-        call.on_sanity_check_end(trainer, model),
-        call.on_train_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_train_epoch_start(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 0, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_batch_end(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 1, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
-        call.on_batch_end(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 2, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
-        call.on_batch_end(trainer, model),
-        call.on_validation_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_validation_epoch_start(trainer, model),
-        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
-        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_validation_end(trainer, model),
-        call.on_save_checkpoint(trainer, model),  # should take ANY but we are inspecting signature for BC
-        call.on_train_epoch_end(trainer, model, ANY),
-        call.on_epoch_end(trainer, model),
-        call.on_train_end(trainer, model),
-        call.on_fit_end(trainer, model),
-        call.teardown(trainer, model, stage='fit'),
-    ]


 def test_callbacks_configured_in_model(tmpdir):
     """ Test the callback system with callbacks added through the model hook. """
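The deleted test pinned hook order through a MagicMock's method_calls list; that coverage moves to tests/models/test_hooks.py, where hooks are recorded as dict(name=..., args=..., kwargs=...) entries so that names, positional arguments, and keyword arguments can all be asserted against a single expected list. A minimal sketch of that recording pattern, not the PR's exact HookedCallback:

from functools import partial

from pytorch_lightning import Callback


class RecordingCallback(Callback):

    def __init__(self, called: list):
        # Shadow a few hooks with recorders that append one dict per call;
        # the hook names below are an illustrative subset.
        def record(hook_name, *args, **kwargs):
            entry = {'name': f'Callback.{hook_name}'}
            if args:
                entry['args'] = args
            if kwargs:
                entry['kwargs'] = kwargs
            called.append(entry)

        for name in ('setup', 'on_fit_start', 'on_fit_end', 'teardown'):
            setattr(self, name, partial(record, name))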
tests/models/test_hooks.py (115 additions, 82 deletions)

@@ -283,35 +283,38 @@ def test_epoch_end(self, *args, **kwargs):
         pass

     @staticmethod
-    def _train_batch():
-        return [
-            'on_train_batch_start',
-            'on_before_batch_transfer',
-            'transfer_batch_to_device',
-            'on_after_batch_transfer',
-            'forward',
-            'training_step',
-            'training_step_end',
-            'on_before_zero_grad',
-            'optimizer_zero_grad',
-            'backward',
-            'on_after_backward',
-            'optimizer_step',
-            'on_train_batch_end',
-        ]
-
-    @staticmethod
-    def _val_batch():
-        return [
-            'on_validation_batch_start',
-            'on_before_batch_transfer',
-            'transfer_batch_to_device',
-            'on_after_batch_transfer',
-            'forward',
-            'validation_step',
-            'validation_step_end',
-            'on_validation_batch_end',
-        ]
+    def _train_batch(trainer, model, batches):
+        out = []
+        for i in range(batches):
+            out.extend([
+                # TODO: `on_batch_{start,end}`
+                dict(name='Callback.on_batch_start', args=(trainer, model)),
+                dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
+                dict(name='on_train_batch_start', args=(ANY, i, 0)),
+                dict(name='on_before_batch_transfer', args=(ANY, None)),
+                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
+                dict(name='on_after_batch_transfer', args=(ANY, None)),
+                dict(name='forward', args=(ANY, )),
+                dict(name='training_step', args=(ANY, i)),
+                dict(name='training_step_end', args=(dict(loss=ANY), )),
+                dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)),
+                dict(name='on_before_zero_grad', args=(ANY, )),
+                dict(name='optimizer_zero_grad', args=(0, i, ANY, 0)),
+                # TODO: `on_before_backward`
+                dict(name='backward', args=(ANY, ANY, 0)),
+                dict(name='Callback.on_after_backward', args=(trainer, model)),
+                dict(name='on_after_backward'),
+                # TODO: `on_before_optimizer_step`
+                dict(
+                    name='optimizer_step',
+                    args=(0, i, ANY, 0, ANY),
+                    kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False)
+                ),
+                dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
+                dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
+                dict(name='Callback.on_batch_end', args=(trainer, model)),
+            ])
+        return out

     @staticmethod
     def _eval_epoch(fn, trainer, model, batches, key):
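The rewritten _train_batch leans on unittest.mock.ANY, which compares equal to any object. That lets each expected record pin down the stable parts of a hook call (hook name, batch index, optimizer index) while wildcarding values that vary from run to run, such as batch tensors and losses. A small self-contained illustration:

from unittest.mock import ANY

# ANY == <anything> holds, so an expected record can mix exact values with
# wildcards and still compare equal under a plain ==.
recorded = dict(name='training_step', args=(object(), 0))  # object() stands in for the batch
expected = dict(name='training_step', args=(ANY, 0))
assert recorded == expected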
@@ -372,6 +375,7 @@ def _predict_batch(trainer, model, batches):
 def test_trainer_model_hook_system_fit(tmpdir):
     called = []
     model = HookedModel(called)
+    callback = HookedCallback(called)
     train_batches = 2
     val_batches = 2
     trainer = Trainer(
@@ -381,63 +385,89 @@ def test_trainer_model_hook_system_fit(tmpdir):
         limit_val_batches=val_batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
+        callbacks=[callback]
     )
-    assert called == []
+    assert called == [
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+    ]
     trainer.fit(model)
+    saved_ckpt = {
+        'callbacks': ANY,
+        'epoch': 1,
+        'global_step': train_batches,
+        'lr_schedulers': ANY,
+        'optimizer_states': ANY,
+        'pytorch-lightning_version': __version__,
+        'state_dict': ANY,
+    }
     expected = [
-        'prepare_data',
-        'configure_callbacks',
-        'setup',
-        'configure_sharded_model',
-        'configure_optimizers',
-        'on_fit_start',
-        'on_pretrain_routine_start',
-        'on_pretrain_routine_end',
-        'on_val_dataloader',
-        'val_dataloader',
-        'train',  # eval() == train(False)
-        'on_validation_model_eval',
-        'zero_grad',
-        'on_validation_start',
-        'on_epoch_start',
-        'on_validation_epoch_start',
-        *(HookedModel._val_batch() * val_batches),
-        'validation_epoch_end',
-        'on_validation_epoch_end',
-        'on_epoch_end',
-        'on_validation_end',
-        'train',
-        'on_validation_model_train',
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+        dict(name='prepare_data'),
+        dict(name='configure_callbacks'),
+        dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
+        dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
+        dict(name='setup', kwargs=dict(stage='fit')),
+        dict(name='configure_sharded_model'),
+        dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
+        dict(name='configure_optimizers'),
+        dict(name='Callback.on_fit_start', args=(trainer, model)),
+        dict(name='on_fit_start'),
+        dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
+        dict(name='on_pretrain_routine_start'),
+        dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
+        dict(name='on_pretrain_routine_end'),
+        dict(name='Callback.on_sanity_check_start', args=(trainer, model)),
+        dict(name='on_val_dataloader'),
+        dict(name='val_dataloader'),
+        dict(name='train', args=(False, )),
+        dict(name='on_validation_model_eval'),
+        dict(name='zero_grad'),
+        dict(name='Callback.on_validation_start', args=(trainer, model)),
+        dict(name='on_validation_start'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        dict(name='Callback.on_validation_end', args=(trainer, model)),
+        dict(name='on_validation_end'),
+        dict(name='train'),
+        dict(name='on_validation_model_train'),
+        dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
         # duplicate `train` because `_run_train` calls it again in case validation wasn't run
-        'train',
-        'on_train_dataloader',
-        'train_dataloader',
-        'on_train_start',
-        'on_epoch_start',
-        'on_train_epoch_start',
-        *(HookedModel._train_batch() * train_batches),
-        'train',  # eval() == train(False)
-        'on_validation_model_eval',
-        'zero_grad',
-        'on_validation_start',
-        'on_epoch_start',
-        'on_validation_epoch_start',
-        *(HookedModel._val_batch() * val_batches),
-        'validation_epoch_end',
-        'on_validation_epoch_end',
-        'on_epoch_end',
-        'on_save_checkpoint',
-        'on_validation_end',
-        'train',
-        'on_validation_model_train',
-        'training_epoch_end',
-        'on_train_epoch_end',
-        'on_epoch_end',
-        'on_train_end',
-        'on_fit_end',
-        'teardown',
+        dict(name='train'),
+        dict(name='on_train_dataloader'),
+        dict(name='train_dataloader'),
+        dict(name='Callback.on_train_start', args=(trainer, model)),
+        dict(name='on_train_start'),
+        dict(name='Callback.on_epoch_start', args=(trainer, model)),
+        dict(name='on_epoch_start'),
+        dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
+        dict(name='on_train_epoch_start'),
+        *model._train_batch(trainer, model, train_batches),
+        dict(name='train', args=(False, )),
+        dict(name='on_validation_model_eval'),
+        dict(name='zero_grad'),
+        dict(name='Callback.on_validation_start', args=(trainer, model)),
+        dict(name='on_validation_start'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        dict(name='Callback.on_validation_end', args=(trainer, model)),
+        # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
+        dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
+        dict(name='on_save_checkpoint', args=(saved_ckpt, )),
+        dict(name='on_validation_end'),
+        dict(name='train'),
+        dict(name='on_validation_model_train'),
+        dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
+        dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
+        dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
+        dict(name='Callback.on_epoch_end', args=(trainer, model)),
+        dict(name='on_epoch_end'),
+        dict(name='Callback.on_train_end', args=(trainer, model)),
+        dict(name='on_train_end'),
+        dict(name='Callback.on_fit_end', args=(trainer, model)),
+        dict(name='on_fit_end'),
+        dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='fit')),
+        dict(name='teardown', kwargs=dict(stage='fit')),
     ]
-    called = [c['name'] for c in called]
     assert called == expected
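The saved_ckpt template above also documents the checkpoint payload: these are the top-level keys that checkpointing writes, matched against both the callback-level and model-level on_save_checkpoint hooks. A hedged sketch of inspecting such a file directly; the path is hypothetical:

import torch

# Hypothetical path; ModelCheckpoint names files like 'epoch=0-step=1.ckpt'.
ckpt = torch.load('example.ckpt', map_location='cpu')
expected_keys = {
    'callbacks', 'epoch', 'global_step', 'lr_schedulers',
    'optimizer_states', 'pytorch-lightning_version', 'state_dict',
}
assert expected_keys <= ckpt.keys()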


@@ -488,7 +518,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         'on_train_start',
         'on_epoch_start',
         'on_train_epoch_start',
-        *(HookedModel._train_batch() * train_batches),
+        *[
+            h['name']
+            for h in HookedModel._train_batch(trainer, model, train_batches) if not h['name'].startswith('Callback')
+        ],
         'training_epoch_end',
         'on_train_epoch_end',
         'on_epoch_end',
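This last hunk reuses the richer _train_batch trace for the no-validation test by keeping only the model-level hook names and discarding the Callback.* records, which belong to the callback-focused assertions. A standalone illustration of that filter over hypothetical trace entries:

trace = [
    dict(name='Callback.on_train_batch_start', args=('trainer', 'model', 'batch', 0, 0)),
    dict(name='on_train_batch_start', args=('batch', 0, 0)),
    dict(name='training_step', args=('batch', 0)),
]

# Keep only the model hooks, mirroring the list comprehension in the diff.
model_hooks = [h['name'] for h in trace if not h['name'].startswith('Callback')]
assert model_hooks == ['on_train_batch_start', 'training_step']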