From fa7e864ff2159764dc2c4049a100fee08c61f8f5 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 20 Dec 2022 14:46:14 +0100
Subject: [PATCH 01/19] support logging from LM

---
 src/pytorch_lightning/core/module.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index aa71bf2fab120..afb36a98f52fc 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -403,6 +403,14 @@ def log(
             rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
                 would produce a deadlock as not all processes would perform this log call.
         """
+        if self._fabric is not None and logger is not False:
+            apply_to_collection(
+                value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor)
+            )
+            # TODO(fabric): Warn if on_epoch, on_step, etc. are set
+            self._fabric.log(name=name, value=value)
+            return
+
         # check for invalid values
         apply_to_collection(value, dict, self.__check_not_nested, name)
         apply_to_collection(
@@ -554,6 +562,12 @@ def log_dict(
             rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
                 would produce a deadlock as not all processes would perform this log call.
         """
+        if self._fabric is not None and logger is not False:
+            # TODO(fabric): Warn if on_epoch, on_step, etc. are set
+            # TODO(fabric): Restrict the type, check that it's not nested
+            self._fabric.log_dict(dictionary)
+            return
+
         for k, v in dictionary.items():
             self.log(
                 name=k,

From 97df628d88966a0ce4240dc161614281ed42f15e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 10 Jan 2023 01:00:03 +0000
Subject: [PATCH 02/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/core/module.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index afb36a98f52fc..7c53349fb15c8 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -404,9 +404,7 @@ def log(
                 would produce a deadlock as not all processes would perform this log call.
         """
         if self._fabric is not None and logger is not False:
-            apply_to_collection(
-                value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor)
-            )
+            apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
             # TODO(fabric): Warn if on_epoch, on_step, etc. are set
             self._fabric.log(name=name, value=value)
             return

From 9cf55aff3d2d5d0eed0baf408564cc8f6befa2e2 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 02:11:29 +0100
Subject: [PATCH 03/19] test

---
 .../core/test_lightning_module.py | 40 ++++++++++++++++---
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index e907640c8a9e3..c0fade800882c 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -512,13 +512,41 @@ def test_fabric_attributes():
     fabric = Fabric()
     wrapped_module, wrapped_optimizer = fabric.setup(module, optimizer)
 
-    assert module.fabric is fabric
-    assert module._fabric_optimizers == [wrapped_optimizer]
+    assert wrapped_module.fabric is fabric
+    assert wrapped_module._fabric_optimizers == [wrapped_optimizer]
 
     # Attribute access on LightningModule.trainer gets redirected to Fabric
-    assert isinstance(module.trainer, _TrainerFabricShim)
-    assert module.trainer.global_rank == 0
+    assert isinstance(wrapped_module.trainer, _TrainerFabricShim)
+    assert wrapped_module.trainer.global_rank == 0
     with pytest.raises(AttributeError, match="Your LightningModule code tried to access `self.trainer.current_epoch`"):
-        _ = module.trainer.current_epoch
+        _ = wrapped_module.trainer.current_epoch
 
-    assert module.optimizers() == wrapped_optimizer
+    assert wrapped_module.optimizers() == wrapped_optimizer
+
+
+def test_fabric_logger_access():
+    """Test that the logger attribute can be accessed when the LightningModule is used together with Fabric."""
+    # No logger
+    module = BoringModel()
+    fabric = Fabric()
+    wrapped_module = fabric.setup(module)
+    assert wrapped_module.loggers == []
+    with pytest.raises(IndexError):
+        _ = wrapped_module.logger
+
+    # Single Logger
+    logger = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=logger)
+    wrapped_module = fabric.setup(module)
+    assert wrapped_module.logger == logger
+    assert wrapped_module.loggers == [logger]
+
+    # Multiple loggers
+    logger1 = Mock()
+    logger2 = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=[logger1, logger2])
+    wrapped_module = fabric.setup(module)
+    assert wrapped_module.logger == logger1
+    assert wrapped_module.loggers == [logger1, logger2]

From adc0ce54f8d0178ee80a65eb6754349dcb24c3b0 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 02:13:20 +0100
Subject: [PATCH 04/19] changelog

---
 src/lightning_fabric/CHANGELOG.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/lightning_fabric/CHANGELOG.md b/src/lightning_fabric/CHANGELOG.md
index 0a51fc3858cc4..271e1bff87304 100644
--- a/src/lightning_fabric/CHANGELOG.md
+++ b/src/lightning_fabric/CHANGELOG.md
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Added `Fabric.log` for logging scalars using multiple loggers
   * Added `Fabric.log_dict` for logging a dictionary of multiple metrics at once
   * Added `Fabric.loggers` and `Fabric.logger` attributes to access the individual logger instances
-
+  * Added support for calling `self.log` and `self.log_dict` in a LightningModule when using Fabric
+  * Added access to `self.logger` and `self.loggers` in a LightningModule when using Fabric
 
 - Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))
 

From fba38cf45a306b5c09fe0423cfdf558d51392e25 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 02:39:11 +0100
Subject: [PATCH 05/19] test

---
 .../core/test_lightning_module.py | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index c0fade800882c..919ebbae2df67 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -550,3 +550,26 @@ def test_fabric_logger_access():
     wrapped_module = fabric.setup(module)
     assert wrapped_module.logger == logger1
     assert wrapped_module.loggers == [logger1, logger2]
+
+
+def test_fabric_log():
+    logger = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=[logger])
+    wrapped_module = fabric.setup(module)
+
+    # unsupported data type
+    with pytest.raises(ValueError, match="`dict` values cannot be logged"):
+        wrapped_module.log("invalid", dict())
+
+    # unsupported data type
+    # with pytest.raises(ValueError, match="`list` values cannot be logged"):
+    #     wrapped_module.log_dict("invalid", [1, 2, 3])
+
+    # self.log()
+    wrapped_module.log("loss", 0.1)
+    logger.log_metrics.assert_called_with(metrics={'loss': 0.1}, step=None)
+
+    # self.log_dict()
+    wrapped_module.log_dict({"x": 1, "y": 2})
+    logger.log_metrics.assert_called_with(metrics={"x": 1, "y": 2}, step=None)

From 5871244b7d6ea607b9e0f38af21e22f8c0b168eb Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 11:51:00 +0100
Subject: [PATCH 06/19] refactor

---
 src/pytorch_lightning/core/module.py | 62 ++++++++++++-------
 .../core/test_lightning_module.py    | 41 +++++++++---
 2 files changed, 71 insertions(+), 32 deletions(-)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 7c53349fb15c8..b92971843e912 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -403,10 +403,8 @@ def log(
             rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
                 would produce a deadlock as not all processes would perform this log call.
         """
-        if self._fabric is not None and logger is not False:
-            apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
-            # TODO(fabric): Warn if on_epoch, on_step, etc. are set
-            self._fabric.log(name=name, value=value)
+        if self._fabric is not None:
+            self._log_through_fabric(name=name, value=value, logger=logger)
             return
 
         # check for invalid values
@@ -557,28 +555,44 @@ def log_dict(
             rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
                 would produce a deadlock as not all processes would perform this log call.
         """
-        if self._fabric is not None and logger is not False:
-            # TODO(fabric): Warn if on_epoch, on_step, etc. are set
-            # TODO(fabric): Restrict the type, check that it's not nested
-            self._fabric.log_dict(dictionary)
-            return
-
-        for k, v in dictionary.items():
-            self.log(
-                name=k,
-                value=v,
-                prog_bar=prog_bar,
-                logger=logger,
-                on_step=on_step,
-                on_epoch=on_epoch,
-                reduce_fx=reduce_fx,
-                enable_graph=enable_graph,
-                sync_dist=sync_dist,
-                sync_dist_group=sync_dist_group,
-                add_dataloader_idx=add_dataloader_idx,
-                batch_size=batch_size,
-                rank_zero_only=rank_zero_only,
-            )
+        if self._fabric is not None:
+            self._log_dict_through_fabric(dictionary=dictionary, logger=logger)
+        else:
+            for k, v in dictionary.items():
+                self.log(
+                    name=k,
+                    value=v,
+                    prog_bar=prog_bar,
+                    logger=logger,
+                    on_step=on_step,
+                    on_epoch=on_epoch,
+                    reduce_fx=reduce_fx,
+                    enable_graph=enable_graph,
+                    sync_dist=sync_dist,
+                    sync_dist_group=sync_dist_group,
+                    add_dataloader_idx=add_dataloader_idx,
+                    batch_size=batch_size,
+                    rank_zero_only=rank_zero_only,
+                )
+
+    def _log_through_fabric(self, name: str, value: _METRIC_COLLECTION, logger: Optional[bool] = None) -> None:
+        if logger is False:
+            # Passing `logger=False` with Fabric does not make much sense because there is no other destination to
+            # log to, but we support it in case the original code was written for Trainer use
+            return
+        apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
+        self._fabric.log(name=name, value=value)
+
+    def _log_dict_through_fabric(self, dictionary, logger: Optional[bool] = None)-> None:
+        if logger is False:
+            return
+
+        if any(isinstance(v, dict) for v in dictionary.values()):
+            raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")
+        for name, value in dictionary.items():
+            apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
+
+        self._fabric.log_dict(metrics=dictionary)
 
     @staticmethod
     def __check_not_nested(value: dict, name: str) -> None:
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 919ebbae2df67..3e0d2d75b734c 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -562,14 +562,39 @@ def test_fabric_log():
     with pytest.raises(ValueError, match="`dict` values cannot be logged"):
         wrapped_module.log("invalid", dict())
 
-    # unsupported data type
-    # with pytest.raises(ValueError, match="`list` values cannot be logged"):
-    #     wrapped_module.log_dict("invalid", [1, 2, 3])
+    # supported data types
+    wrapped_module.log("int", 1)
+    logger.log_metrics.assert_called_with(metrics={'int': 1}, step=None)
+    wrapped_module.log("float", 0.1)
+    logger.log_metrics.assert_called_with(metrics={'float': 0.1}, step=None)
+    wrapped_module.log("tensor", torch.tensor(0.1))
+    logger.log_metrics.assert_called_with(metrics={'tensor': torch.tensor(0.1)}, step=None)
 
-    # self.log()
-    wrapped_module.log("loss", 0.1)
-    logger.log_metrics.assert_called_with(metrics={'loss': 0.1}, step=None)
+    # logger=False
+    logger.reset_mock()
+    wrapped_module.log("nothing", 1, logger=False)
+    logger.log_metrics.assert_not_called()
+
+
+def test_fabric_log_dict():
+    logger = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=[logger])
+    wrapped_module = fabric.setup(module)
+
+    # unsupported data type
+    with pytest.raises(ValueError, match="`list` values cannot be logged"):
+        wrapped_module.log_dict({"invalid": [1, 2, 3]})
+
+    # nested dicts
+    with pytest.raises(ValueError, match="nested dictionaries cannot be logged"):
+        wrapped_module.log_dict({"nested": {"nested": 1}})
 
-    # self.log_dict()
-    wrapped_module.log_dict({"x": 1, "y": 2})
-    logger.log_metrics.assert_called_with(metrics={"x": 1, "y": 2}, step=None)
+    # supported data types
+    wrapped_module.log_dict({"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)})
+    logger.log_metrics.assert_called_with(metrics={"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)}, step=None)
+
+    # logger=False
+    logger.reset_mock()
+    wrapped_module.log_dict({"nothing": 1}, logger=False)
+    logger.log_metrics.assert_not_called()

From b30b51e19811ff9b107afc576dcc0145fd7bfe4a Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 11:56:45 +0100
Subject: [PATCH 07/19] types

---
 src/pytorch_lightning/core/module.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index b92971843e912..de1bab5c30c89 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -559,7 +559,10 @@ def log_dict(
                 would produce a deadlock as not all processes would perform this log call.
         """
         if self._fabric is not None:
-            self._log_dict_through_fabric(dictionary=dictionary, logger=logger)
+            self._log_dict_through_fabric(
+                dictionary=dictionary,  # type: ignore[arg-type]
+                logger=logger
+            )
         else:
             for k, v in dictionary.items():
                 self.log(
@@ -584,9 +587,11 @@ def _log_through_fabric(self, name: str, value: _METRIC_COLLECTION, logger: Opti
             # log to, but we support it in case the original code was written for Trainer use
             return
         apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
+
+        assert self._fabric is not None
         self._fabric.log(name=name, value=value)
 
-    def _log_dict_through_fabric(self, dictionary, logger: Optional[bool] = None)-> None:
+    def _log_dict_through_fabric(self, dictionary: Dict[str, Any], logger: Optional[bool] = None) -> None:
         if logger is False:
             return
 
@@ -597,6 +602,7 @@ def _log_dict_through_fabric(self, dictionary, logger: Optional[bool] = None)->
         for name, value in dictionary.items():
             apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
 
+        assert self._fabric is not None
         self._fabric.log_dict(metrics=dictionary)
 
     @staticmethod

From a6b576ad7d41573fd9568541ad6145a0f6036781 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 10 Jan 2023 10:57:55 +0000
Subject: [PATCH 08/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/core/module.py              | 5 +----
 tests/tests_pytorch/core/test_lightning_module.py | 6 +++---
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index de1bab5c30c89..a54e406a9f9a5 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -559,10 +559,7 @@ def log_dict(
                 would produce a deadlock as not all processes would perform this log call.
         """
         if self._fabric is not None:
-            self._log_dict_through_fabric(
-                dictionary=dictionary,  # type: ignore[arg-type]
-                logger=logger
-            )
+            self._log_dict_through_fabric(dictionary=dictionary, logger=logger)  # type: ignore[arg-type]
         else:
             for k, v in dictionary.items():
                 self.log(
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 3e0d2d75b734c..25f87e055d415 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -564,11 +564,11 @@ def test_fabric_log():
 
     # supported data types
     wrapped_module.log("int", 1)
-    logger.log_metrics.assert_called_with(metrics={'int': 1}, step=None)
+    logger.log_metrics.assert_called_with(metrics={"int": 1}, step=None)
     wrapped_module.log("float", 0.1)
-    logger.log_metrics.assert_called_with(metrics={'float': 0.1}, step=None)
+    logger.log_metrics.assert_called_with(metrics={"float": 0.1}, step=None)
     wrapped_module.log("tensor", torch.tensor(0.1))
-    logger.log_metrics.assert_called_with(metrics={'tensor': torch.tensor(0.1)}, step=None)
+    logger.log_metrics.assert_called_with(metrics={"tensor": torch.tensor(0.1)}, step=None)
 
     # logger=False
     logger.reset_mock()

From 7268aac932958e91345f15ed30a5a7e0c739df34 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 12:42:00 +0100
Subject: [PATCH 09/19] validation

---
 src/lightning_fabric/fabric.py    | 11 ++++++++++-
 tests/tests_fabric/test_fabric.py | 25 +++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py
index fbfe572ebbdb5..c45159c7efd76 100644
--- a/src/lightning_fabric/fabric.py
+++ b/src/lightning_fabric/fabric.py
@@ -597,7 +597,8 @@ def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
 
         Args:
             name: The name of the metric to log.
-            value: The metric value to collect.
+            value: The metric value to collect. If the value is a :class:`torch.Tensor`, it gets detached from the
+                graph automatically.
             step: Optional step number. Most Logger implementations auto-increment the step value by one with every
                 log call. You can specify your own value here.
         """
@@ -608,9 +609,17 @@ def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
 
         Args:
             metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged.
+                Any :class:`torch.Tensor`s in the dictionary get detached from the graph automatically.
             step: Optional step number. Most Logger implementations auto-increment this value by one with every
                 log call. You can specify your own value here.
         """
+
+        def to_item(value: Tensor) -> int | float | bool:
+            if value.numel() != 1:
+                raise ValueError("Logging tensors with more than one element is not supported.")
+            return value.item()
+
+        metrics = apply_to_collection(metrics, dtype=Tensor, function=to_item)
         for logger in self._loggers:
             logger.log_metrics(metrics=metrics, step=step)
 
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 88ac3d9fb7d82..671257a42a440 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -800,3 +800,28 @@ def test_log_dict():
     fabric.log_dict({"foo": 3, "bar": 4}, step=15)
     logger0.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
     logger1.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
+
+
+def test_log_dict_input_parsing():
+    """Test validation of input data types and preprocessing."""
+    logger = Mock()
+    fabric = Fabric(loggers=[logger])
+
+    # Tensor scalar, 0 dims
+    fabric.log("log", torch.tensor(1))
+    logger.log_metrics.assert_called_with(metrics={"log": 1}, step=None)
+    fabric.log_dict({"log_dict": torch.tensor(1)})
+    logger.log_metrics.assert_called_with(metrics={"log_dict": 1}, step=None)
+
+    # Tensor scalar, 1 dims
+    fabric.log("log", torch.tensor([2]))
+    logger.log_metrics.assert_called_with(metrics={"log": 2}, step=None)
+    fabric.log_dict({"log_dict": torch.tensor([2])})
+    logger.log_metrics.assert_called_with(metrics={"log_dict": 2}, step=None)
+
+    # Tensor, multiple dims
+    with pytest.raises(ValueError, match="Logging tensors with more than one element is not supported"):
+        fabric.log("log", torch.tensor([3, 4]))
+
+    with pytest.raises(ValueError, match="Logging tensors with more than one element is not supported"):
+        fabric.log_dict({"log_dict": torch.tensor([3, 4])})

From d3109ea89eaec9f529fcbedbe48bc4bc6b39a691 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 14:40:03 +0100
Subject: [PATCH 10/19] reuse the log_dict_through_fabric

---
 src/pytorch_lightning/core/module.py              | 12 ++----------
 tests/tests_pytorch/core/test_lightning_module.py |  4 ++--
 2 files changed, 4 insertions(+), 12 deletions(-)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index a54e406a9f9a5..878313b675bdc 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -404,7 +404,7 @@ def log(
                 would produce a deadlock as not all processes would perform this log call.
         """
         if self._fabric is not None:
-            self._log_through_fabric(name=name, value=value, logger=logger)
+            self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
             return
 
         # check for invalid values
@@ -578,19 +578,11 @@ def log_dict(
                 rank_zero_only=rank_zero_only,
             )
 
-    def _log_through_fabric(self, name: str, value: _METRIC_COLLECTION, logger: Optional[bool] = None) -> None:
+    def _log_dict_through_fabric(self, dictionary: Dict[str, Any], logger: Optional[bool] = None) -> None:
         if logger is False:
             # Passing `logger=False` with Fabric does not make much sense because there is no other destination to
             # log to, but we support it in case the original code was written for Trainer use
             return
-        apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
-
-        assert self._fabric is not None
-        self._fabric.log(name=name, value=value)
-
-    def _log_dict_through_fabric(self, dictionary: Dict[str, Any], logger: Optional[bool] = None) -> None:
-        if logger is False:
-            return
 
         if any(isinstance(v, dict) for v in dictionary.values()):
             raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 25f87e055d415..95be3ccb0eb4c 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -559,8 +559,8 @@ def test_fabric_log():
     wrapped_module = fabric.setup(module)
 
     # unsupported data type
-    with pytest.raises(ValueError, match="`dict` values cannot be logged"):
-        wrapped_module.log("invalid", dict())
+    with pytest.raises(ValueError, match="`list` values cannot be logged"):
+        wrapped_module.log("invalid", list())
 
     # supported data types
     wrapped_module.log("int", 1)

From 2ae8ceb5974c084cd7119fec8fe9a22a02ca4c0e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 10 Jan 2023 08:48:17 -0500
Subject: [PATCH 11/19] Update src/lightning_fabric/fabric.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 src/lightning_fabric/fabric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py
index c45159c7efd76..661d62c019daf 100644
--- a/src/lightning_fabric/fabric.py
+++ b/src/lightning_fabric/fabric.py
@@ -614,7 +614,7 @@ def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
                 log call. You can specify your own value here.
         """
 
-        def to_item(value: Tensor) -> int | float | bool:
+        def to_item(value: Tensor) -> Union[int, float, bool]:
             if value.numel() != 1:
                 raise ValueError("Logging tensors with more than one element is not supported.")
             return value.item()

From 789d41306b004d33b559f47f85ee0c84c5635926 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 15:01:28 +0100
Subject: [PATCH 12/19] reuse metrics_to_scalars without deprecation

---
 src/lightning_fabric/utilities/apply_func.py  | 18 ++++++++++++++++++
 .../logger_connector/logger_connector.py      |  4 ++--
 .../connectors/logger_connector/result.py     |  4 ++--
 src/pytorch_lightning/utilities/metrics.py    | 17 +++--------------
 4 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/src/lightning_fabric/utilities/apply_func.py b/src/lightning_fabric/utilities/apply_func.py
index 29732731a54ff..0fc535215057d 100644
--- a/src/lightning_fabric/utilities/apply_func.py
+++ b/src/lightning_fabric/utilities/apply_func.py
@@ -111,3 +111,21 @@ def _move_to_device_and_make_contiguous(t: Tensor, device: _DEVICE) -> Tensor:
 
     # make sure existing tensors are in the correct device, also contiguous
     return apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)
+
+
+def convert_tensors_to_scalars(data: Any) -> Any:
+    """Recursively walk through a collection and convert single-item tensors to scalar values.
+
+    Raises:
+        ValueError:
+            If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.
+    """
+
+    def to_item(value: Tensor) -> Union[int | float | bool]:
+        if value.numel() != 1:
+            raise ValueError(
+                f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
+            )
+        return value.item()
+
+    return apply_to_collection(data, Tensor, to_item)
diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index a8b28866ea4c2..70c04f454bedb 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -19,9 +19,9 @@
 import pytorch_lightning as pl
 from lightning_fabric.plugins.environments import SLURMEnvironment
 from lightning_fabric.utilities import move_data_to_device
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from pytorch_lightning.loggers import Logger, TensorBoardLogger
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
-from pytorch_lightning.utilities.metrics import metrics_to_scalars
 
 
 class LoggerConnector:
@@ -80,7 +80,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
         self._logged_metrics.update(metrics)
 
         # turn all tensors to scalars
-        scalar_metrics = metrics_to_scalars(metrics)
+        scalar_metrics = convert_tensors_to_scalars(metrics)
 
         if step is None:
             step = scalar_metrics.pop("step", None)
diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 4cd72cc0e0e29..853350ab53e31 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -24,11 +24,11 @@ from typing_extensions import TypedDict
 
 from lightning_fabric.utilities import move_data_to_device
 from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning_fabric.utilities.distributed import _distributed_available
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
@@ -610,7 +610,7 @@ def any_tensor(_: Any) -> None:
 
         # populate progress_bar metrics. convert tensors to numbers
         if result_metric.meta.prog_bar:
-            metrics["pbar"][forked_name] = metrics_to_scalars(value)
+            metrics["pbar"][forked_name] = convert_tensors_to_scalars(value)
 
     return metrics
 
diff --git a/src/pytorch_lightning/utilities/metrics.py b/src/pytorch_lightning/utilities/metrics.py
index bbc27e4e672a9..62734e1a47342 100644
--- a/src/pytorch_lightning/utilities/metrics.py
+++ b/src/pytorch_lightning/utilities/metrics.py
@@ -12,29 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Helper functions to operate on metric values."""
-from __future__ import annotations
 
 from typing import Any
 
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
-
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 
 
 def metrics_to_scalars(metrics: Any) -> Any:
     """Recursively walk through a collection and convert single-item tensors to scalar values.
 
     Raises:
-        MisconfigurationException:
+        ValueError:
             If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.
     """
 
-    def to_item(value: Tensor) -> int | float | bool:
-        if value.numel() != 1:
-            raise MisconfigurationException(
-                f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
-            )
-        return value.item()
-
-    return apply_to_collection(metrics, Tensor, to_item)
+    return convert_tensors_to_scalars(metrics)

From 117acdb7f2112ed33bec578b6006a807585b5272 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 10 Jan 2023 14:03:33 +0000
Subject: [PATCH 13/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../trainer/connectors/logger_connector/result.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 853350ab53e31..1de991900b886 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -22,9 +22,9 @@ from typing_extensions import TypedDict
 
 from lightning_fabric.utilities import move_data_to_device
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning_fabric.utilities.distributed import _distributed_available
-from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.memory import recursive_detach

From c72085b466bf1dc116e42f25e9c5a19943185147 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 15:07:39 +0100
Subject: [PATCH 14/19] reuse utility code in fabric

---
 src/lightning_fabric/fabric.py    | 10 ++--------
 tests/tests_fabric/test_fabric.py |  4 ++--
 2 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py
index 661d62c019daf..f4f750f6b777e 100644
--- a/src/lightning_fabric/fabric.py
+++ b/src/lightning_fabric/fabric.py
@@ -42,7 +42,7 @@
 )
 from lightning_fabric.strategies.strategy import _Sharded, TBroadcast
 from lightning_fabric.utilities import move_data_to_device
-from lightning_fabric.utilities.apply_func import convert_to_tensors
+from lightning_fabric.utilities.apply_func import convert_to_tensors, convert_tensors_to_scalars
 from lightning_fabric.utilities.data import (
     _auto_add_worker_init_fn,
     _replace_dunder_methods,
@@ -613,13 +613,7 @@ def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
             step: Optional step number. Most Logger implementations auto-increment this value by one with every
                 log call. You can specify your own value here.
         """
-
-        def to_item(value: Tensor) -> Union[int, float, bool]:
-            if value.numel() != 1:
-                raise ValueError("Logging tensors with more than one element is not supported.")
-            return value.item()
-
-        metrics = apply_to_collection(metrics, dtype=Tensor, function=to_item)
+        metrics = convert_tensors_to_scalars(metrics)
         for logger in self._loggers:
             logger.log_metrics(metrics=metrics, step=step)
 
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 671257a42a440..b5981265e2bcf 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -820,8 +820,8 @@ def test_log_dict_input_parsing():
     logger.log_metrics.assert_called_with(metrics={"log_dict": 2}, step=None)
 
     # Tensor, multiple dims
-    with pytest.raises(ValueError, match="Logging tensors with more than one element is not supported"):
+    with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
         fabric.log("log", torch.tensor([3, 4]))
 
-    with pytest.raises(ValueError, match="Logging tensors with more than one element is not supported"):
+    with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
         fabric.log_dict({"log_dict": torch.tensor([3, 4])})

From 53fa4f48e13c1820b9d1b5b866cccbb2bd5493e7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 10 Jan 2023 14:09:36 +0000
Subject: [PATCH 15/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning_fabric/fabric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py
index f4f750f6b777e..70886ab63d7fb 100644
--- a/src/lightning_fabric/fabric.py
+++ b/src/lightning_fabric/fabric.py
@@ -42,7 +42,7 @@
 )
 from lightning_fabric.strategies.strategy import _Sharded, TBroadcast
 from lightning_fabric.utilities import move_data_to_device
-from lightning_fabric.utilities.apply_func import convert_to_tensors, convert_tensors_to_scalars
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, convert_to_tensors
 from lightning_fabric.utilities.data import (
     _auto_add_worker_init_fn,
     _replace_dunder_methods,

From 6f10f9d33bb0f904282bc8b03c6b82a9bd90e5ca Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 15:16:20 +0100
Subject: [PATCH 16/19] add test for utility function

---
 .../tests_fabric/utilities/test_apply_func.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/tests/tests_fabric/utilities/test_apply_func.py b/tests/tests_fabric/utilities/test_apply_func.py
index 3a265942769b2..683ff90f5789b 100644
--- a/tests/tests_fabric/utilities/test_apply_func.py
+++ b/tests/tests_fabric/utilities/test_apply_func.py
@@ -15,7 +15,7 @@
 import torch
 from torch import Tensor
 
-from lightning_fabric.utilities.apply_func import move_data_to_device
+from lightning_fabric.utilities.apply_func import move_data_to_device, convert_tensors_to_scalars
 
 
 @pytest.mark.parametrize("should_return", [False, True])
@@ -34,3 +34,20 @@ def to(self, device):
     tensor = torch.tensor(0.1)
     obj = TensorObject(tensor, should_return)
     assert obj == move_data_to_device(obj, torch.device("cpu"))
+
+
+def test_convert_tensors_to_scalars():
+    assert convert_tensors_to_scalars("string") == "string"
+    assert convert_tensors_to_scalars(1) == 1
+    assert convert_tensors_to_scalars(True) is True
+    assert convert_tensors_to_scalars({"scalar": 1.0}) == {"scalar": 1.0}
+
+    result = convert_tensors_to_scalars({"tensor": torch.tensor(2.0)})
+    # note: `==` comparison as above is not sufficient, since `torch.tensor(x) == x` evaluates to truth
+    assert not isinstance(result["tensor"], Tensor) and result["tensor"] == 2.0
+
+    result = convert_tensors_to_scalars({"tensor": torch.tensor([2.0])})
+    assert not isinstance(result["tensor"], Tensor) and result["tensor"] == 2.0
+
+    with pytest.raises(ValueError, match="does not contain a single element"):
+        convert_tensors_to_scalars({"tensor": torch.tensor([1, 2, 3])})

From 78573a02233409d43e65246c0450e23991c5a16f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 10 Jan 2023 14:17:41 +0000
Subject: [PATCH 17/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/tests_fabric/utilities/test_apply_func.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_fabric/utilities/test_apply_func.py b/tests/tests_fabric/utilities/test_apply_func.py
index 683ff90f5789b..a2763270b9fa9 100644
--- a/tests/tests_fabric/utilities/test_apply_func.py
+++ b/tests/tests_fabric/utilities/test_apply_func.py
@@ -15,7 +15,7 @@
 import torch
 from torch import Tensor
 
-from lightning_fabric.utilities.apply_func import move_data_to_device, convert_tensors_to_scalars
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device
 
 
 @pytest.mark.parametrize("should_return", [False, True])

From 52f962523a370083536aea6f8c144a5bfeb39b1e Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 15:47:31 +0100
Subject: [PATCH 18/19] fix type

---
 src/lightning_fabric/utilities/apply_func.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_fabric/utilities/apply_func.py b/src/lightning_fabric/utilities/apply_func.py
index 0fc535215057d..5a55c1f8cac9b 100644
--- a/src/lightning_fabric/utilities/apply_func.py
+++ b/src/lightning_fabric/utilities/apply_func.py
@@ -121,7 +121,7 @@ def convert_tensors_to_scalars(data: Any) -> Any:
             If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.
     """
 
-    def to_item(value: Tensor) -> Union[int | float | bool]:
+    def to_item(value: Tensor) -> Union[int, float, bool]:
         if value.numel() != 1:
             raise ValueError(
                 f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."

From c5b9fa7eb98bfcb79fcabd2c2f0cd379fc295f28 Mon Sep 17 00:00:00 2001
From: lightningforever
Date: Tue, 10 Jan 2023 16:37:23 +0100
Subject: [PATCH 19/19] fix docs parsing error

---
 src/lightning_fabric/fabric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py
index 70886ab63d7fb..42ae8f0e931f6 100644
--- a/src/lightning_fabric/fabric.py
+++ b/src/lightning_fabric/fabric.py
@@ -609,7 +609,7 @@ def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
 
         Args:
             metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged.
-                Any :class:`torch.Tensor`s in the dictionary get detached from the graph automatically.
+                Any :class:`torch.Tensor` in the dictionary get detached from the graph automatically.
             step: Optional step number. Most Logger implementations auto-increment this value by one with every
                 log call. You can specify your own value here.
         """