Remove no return warning from val/test step #6139

Merged · 8 commits · Mar 6, 2021
Changes from 1 commit
rebase
rohitgr7 committed Mar 6, 2021
commit d5584e1d20aa2bd1c1f83358564f8fbb880b9fee
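
For context, the warning this PR removes was presumably emitted when a `validation_step` or `test_step` returned nothing. A minimal sketch of that pattern is below; the module, layer sizes, and metric name are placeholders for illustration, not code from this PR.

```python
import torch
from pytorch_lightning import LightningModule


class NoReturnModel(LightningModule):
    """Hypothetical module whose validation step only logs and returns None."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Log only; returning nothing from a val/test step is the case
        # that used to trigger the warning removed by this PR.
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```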
1 change: 0 additions & 1 deletion CHANGELOG.md
@@ -15,7 +15,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


-<<<<<<< HEAD
- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


2 changes: 1 addition & 1 deletion pytorch_lightning/overrides/base.py
@@ -46,7 +46,7 @@ def forward(self, *inputs, **kwargs):
# it is done manually in ``LightningModule.manual_backward``
# `require_backward_grad_sync` will be reset in the
# ddp_plugin ``post_training_step`` hook
-if self.module.automatic_optimization:
+if not self.module.automatic_optimization:
trainer.model.require_backward_grad_sync = False
elif trainer and trainer.testing:
output = self.module.test_step(*inputs, **kwargs)
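
The corrected branch above guards the manual-optimization path described in the comments (DDP grad sync is skipped because backward is driven by `LightningModule.manual_backward`). As a rough sketch of what that path looks like from the user side, assuming a Lightning version where `automatic_optimization` can be set in `__init__` and `manual_backward` takes only the loss; exact property and signature details vary across versions, and the model below is a placeholder:

```python
import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    """Hypothetical module using manual optimization, the case the fixed branch targets."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # backward/step are driven by the user
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        # In this mode the wrapper sets require_backward_grad_sync = False and the
        # ddp plugin's post_training_step hook resets it, per the comments above.
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```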
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -214,8 +214,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
-self.lightning_module.trainer.state == TrainerState.FITTING
-and best_model_path is not None
+self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -139,8 +139,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
-self.lightning_module.trainer.state == TrainerState.FITTING
-and best_model_path is not None
+self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
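
Both spawn plugins above derive the temporary "last" checkpoint path with the same `re.sub` call. A small standalone illustration, with a made-up checkpoint filename:

```python
import re

best_model_path = "epoch=2-step=499.ckpt"  # hypothetical checkpoint produced during fitting
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
print(last_path)  # epoch=2-step=499.tmp_end.ckpt
```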
@@ -297,9 +297,7 @@ def get_evaluate_epoch_results(self):

# log results of evaluation
if (
-self.trainer.state != TrainerState.FITTING
-and self.trainer.evaluating
-and self.trainer.is_global_zero
+self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -60,8 +60,7 @@ def get_evaluation_dataloaders(self):
self.trainer.reset_val_dataloader(model)
if self.trainer.sanity_checking:
self.trainer.num_sanity_val_batches = [
-min(self.trainer.num_sanity_val_steps, val_batches)
-for val_batches in self.trainer.num_val_batches
+min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
]
max_batches = self.trainer.num_sanity_val_batches
else:
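
The rewrapped comprehension above clamps the number of sanity-check batches per validation dataloader. With made-up numbers (these values are illustrative, not from the PR):

```python
# Hypothetical values: 2 sanity-check batches requested, three val dataloaders
# of different lengths (0 meaning an empty dataloader).
num_sanity_val_steps = 2
num_val_batches = [5, 1, 0]

num_sanity_val_batches = [
    min(num_sanity_val_steps, val_batches) for val_batches in num_val_batches
]
print(num_sanity_val_batches)  # [2, 1, 0]
```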
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -875,8 +875,7 @@ def test(
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
results = (
-self.__evaluate_given_model(model, dataloaders=test_dataloaders)
-if model_provided else
+self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else
self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders)
)

14 changes: 8 additions & 6 deletions tests/overrides/test_data_parallel.py
@@ -19,12 +19,14 @@
LightningParallelModule,
LightningDistributedModule,
])
@pytest.mark.parametrize("stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
("predicting", "predict"),
])
@pytest.mark.parametrize(
"stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
("predicting", "predict"),
]
)
def test_lightning_wrapper_module_methods(wrapper_class, stage):
""" Test that the LightningWrapper redirects .forward() to the LightningModule methods. """
pl_module = MagicMock()
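
The change above only rewraps the decorator; behaviourally, stacked `@pytest.mark.parametrize` decorators still produce the cross product of their argument lists, so the test runs once per (wrapper class, stage) pair. A tiny self-contained illustration with placeholder values:

```python
import pytest


@pytest.mark.parametrize("wrapper_name", ["parallel", "distributed"])
@pytest.mark.parametrize("stage", [
    ("training", "training_step"),
    ("testing", "test_step"),
])
def test_cross_product(wrapper_name, stage):
    # Runs 2 x 2 = 4 times, once per combination of the two parametrize lists.
    assert wrapper_name in {"parallel", "distributed"}
    assert len(stage) == 2
```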
2 changes: 2 additions & 0 deletions tests/trainer/test_states.py
@@ -34,6 +34,7 @@ def test_trainer_state_while_running(tmpdir, extra_params):
trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True)

class TestModel(BoringModel):

def __init__(self, expected_state):
super().__init__()
self.expected_state = expected_state
@@ -78,6 +79,7 @@ def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
model = BoringModel()

class InterruptCallback(Callback):

def on_batch_start(self, trainer, pl_module):
raise KeyboardInterrupt

12 changes: 8 additions & 4 deletions tests/utilities/test_parsing.py
@@ -233,15 +233,19 @@ class UnpicklableClass:


def test_parse_class_init_keys(tmpdir):

class Class:

def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
pass

assert parse_class_init_keys(Class) == ("self", "my_args", "my_kwargs")


def test_get_init_args(tmpdir):

class AutomaticArgsModel:

def __init__(self, anyarg, anykw=42, **kwargs):
super().__init__()

@@ -259,7 +263,9 @@ def get_init_args_wrapper(self):


def test_collect_init_args():

class AutomaticArgsParent:

def __init__(self, anyarg, anykw=42, **kwargs):
super().__init__()
self.get_init_args_wrapper()
@@ -269,6 +275,7 @@ def get_init_args_wrapper(self):
self.result = collect_init_args(frame, [])

class AutomaticArgsChild(AutomaticArgsParent):

def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs):
super().__init__(anyarg, anykw=anykw, **kwargs)

@@ -299,10 +306,7 @@ def test_attribute_dict(tmpdir):


def test_flatten_dict(tmpdir):
-d = {
-    '1': 1,
-    '_': {'2': 2, '_': {'3': 3, '4': 4}}
-}
+d = {'1': 1, '_': {'2': 2, '_': {'3': 3, '4': 4}}}

expected = {
'1': 1,
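
The `expected` dict in this test is cut off by the diff view. For context, a minimal sketch of one way such a flatten can behave is below, assuming nested values are merged into the top level and the placeholder `'_'` keys are discarded; this is an assumption for illustration, not necessarily the library's actual `flatten_dict` implementation.

```python
def flatten_dict_sketch(source: dict) -> dict:
    """Hypothetical flatten: merge nested dict values upward, dropping intermediate keys."""
    result = {}
    for key, value in source.items():
        if isinstance(value, dict):
            result.update(flatten_dict_sketch(value))  # recurse into nested dicts
        else:
            result[key] = value
    return result


# With the test input above, this yields {'1': 1, '2': 2, '3': 3, '4': 4}.
print(flatten_dict_sketch({'1': 1, '_': {'2': 2, '_': {'3': 3, '4': 4}}}))
```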