diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index fd9209435468b..0f4460a3d5144 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -41,7 +41,10 @@ def clean_namespace(hparams: MutableMapping) -> None: del_attrs = [k for k, v in hparams.items() if not is_picklable(v)] for k in del_attrs: - rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled") + rank_zero_warn( + f"Attribute '{k}' removed from hparams because it cannot be pickled. You can suppress this warning by" + f" setting `self.save_hyperparameters(ignore=['{k}'])`.", + ) del hparams[k] diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index c22ab1228575a..25c09b8e0b46d 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -528,19 +528,26 @@ def test_hparams_pickle(tmpdir): class UnpickleableArgsBoringModel(BoringModel): """A model that has an attribute that cannot be pickled.""" - def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), **kwargs): + def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), ignore=False, **kwargs): super().__init__(**kwargs) assert not is_picklable(pickle_me) - self.save_hyperparameters() + if ignore: + self.save_hyperparameters(ignore=["pickle_me"]) + else: + self.save_hyperparameters() def test_hparams_pickle_warning(tmpdir): model = UnpickleableArgsBoringModel() trainer = Trainer(default_root_dir=tmpdir, max_steps=1) - with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"): + with pytest.warns(UserWarning, match="Attribute 'pickle_me' removed from hparams because it cannot be pickled"): trainer.fit(model) assert "pickle_me" not in model.hparams + model = UnpickleableArgsBoringModel(ignore=True) + with no_warning_call(UserWarning, match="Attribute 'pickle_me' removed from hparams because it cannot be pickled"): + trainer.fit(model) + def test_hparams_save_yaml(tmpdir): class Options(str, Enum):