Lite: Support self.log from a LightningModule (#16311)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored Jan 10, 2023
1 parent a56c12c commit 91aaa53
Showing 10 changed files with 194 additions and 44 deletions.
3 changes: 2 additions & 1 deletion src/lightning_fabric/CHANGELOG.md
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `Fabric.log` for logging scalars using multiple loggers
* Added `Fabric.log_dict` for logging a dictionary of multiple metrics at once
* Added `Fabric.loggers` and `Fabric.logger` attributes to access the individual logger instances

* Added support for calling `self.log` and `self.log_dict` in a LightningModule when using Fabric
* Added access to `self.logger` and `self.loggers` in a LightningModule when using Fabric

- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))

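To illustrate the new CHANGELOG entries above: a LightningModule set up with `Fabric` can now call `self.log` from its hooks, and the value is routed to the loggers configured on `Fabric`. This is only a minimal sketch based on the tests in this commit; the `Mock` logger and the `BoringModel` demo class stand in for a real logger and model, and a CPU-only `Fabric` is assumed.

```python
from unittest.mock import Mock

import torch
from lightning_fabric import Fabric
from pytorch_lightning.demos.boring_classes import BoringModel


class LoggingModel(BoringModel):
    def training_step(self, batch, batch_idx):
        output = super().training_step(batch, batch_idx)
        # With Fabric attached, `self.log` forwards the value to `Fabric.log_dict`,
        # which calls `log_metrics` on every configured logger.
        self.log("train_loss", output["loss"])
        return output


logger = Mock()  # stand-in for any Fabric logger
fabric = Fabric(accelerator="cpu", loggers=[logger])

model = LoggingModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

loss = model.training_step(torch.rand(2, 32), batch_idx=0)["loss"]
fabric.backward(loss)
optimizer.step()

logger.log_metrics.assert_called_once()  # metrics={"train_loss": <float>}, step=None
```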
7 changes: 5 additions & 2 deletions src/lightning_fabric/fabric.py
@@ -42,7 +42,7 @@
)
from lightning_fabric.strategies.strategy import _Sharded, TBroadcast
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_to_tensors
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, convert_to_tensors
from lightning_fabric.utilities.data import (
_auto_add_worker_init_fn,
_replace_dunder_methods,
@@ -597,7 +597,8 @@ def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
Args:
name: The name of the metric to log.
value: The metric value to collect.
value: The metric value to collect. If the value is a :class:`torch.Tensor`, it gets detached from the
graph automatically.
step: Optional step number. Most Logger implementations auto-increment the step value by one with every
log call. You can specify your own value here.
"""
@@ -608,9 +609,11 @@ def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
Args:
metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged.
Any :class:`torch.Tensor` in the dictionary get detached from the graph automatically.
step: Optional step number. Most Logger implementations auto-increment this value by one with every
log call. You can specify your own value here.
"""
metrics = convert_tensors_to_scalars(metrics)
for logger in self._loggers:
logger.log_metrics(metrics=metrics, step=step)

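To make the updated docstrings above concrete, here is a small sketch (with a `Mock` standing in for a real logger) of how `Fabric.log` and `Fabric.log_dict` hand metrics to each logger, converting single-element tensors to plain Python scalars first:

```python
from unittest.mock import Mock

import torch
from lightning_fabric import Fabric

logger = Mock()
fabric = Fabric(loggers=[logger])

# A single-element tensor is detached and converted to a Python scalar.
fabric.log("loss", torch.tensor(0.25), step=10)
logger.log_metrics.assert_called_with(metrics={"loss": 0.25}, step=10)

# Plain numbers pass through unchanged; tensors in the dict are converted.
fabric.log_dict({"acc": torch.tensor([0.5]), "epoch": 3})
logger.log_metrics.assert_called_with(metrics={"acc": 0.5, "epoch": 3}, step=None)
```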
18 changes: 18 additions & 0 deletions src/lightning_fabric/utilities/apply_func.py
@@ -111,3 +111,21 @@ def _move_to_device_and_make_contiguous(t: Tensor, device: _DEVICE) -> Tensor:

# make sure existing tensors are in the correct device, also contiguous
return apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)


def convert_tensors_to_scalars(data: Any) -> Any:
"""Recursively walk through a collection and convert single-item tensors to scalar values.
Raises:
ValueError:
If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.
"""

def to_item(value: Tensor) -> Union[int, float, bool]:
if value.numel() != 1:
raise ValueError(
f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
)
return value.item()

return apply_to_collection(data, Tensor, to_item)
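A brief sketch of the new utility in isolation, assuming only the code added above: it walks collections recursively via `apply_to_collection`, converting single-element tensors and leaving everything else untouched.

```python
import torch
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars

metrics = {"loss": torch.tensor(0.5), "extra": {"acc": torch.tensor([1.0])}, "run": "baseline"}
assert convert_tensors_to_scalars(metrics) == {"loss": 0.5, "extra": {"acc": 1.0}, "run": "baseline"}

# Tensors with more than one element cannot be reduced to a scalar.
try:
    convert_tensors_to_scalars({"preds": torch.tensor([0.1, 0.9])})
except ValueError as err:
    print(err)
```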
53 changes: 37 additions & 16 deletions src/pytorch_lightning/core/module.py
@@ -403,6 +403,10 @@ def log(
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
if self._fabric is not None:
self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
return

# check for invalid values
apply_to_collection(value, dict, self.__check_not_nested, name)
apply_to_collection(
@@ -554,22 +558,39 @@ def log_dict(
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
for k, v in dictionary.items():
self.log(
name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_dist=sync_dist,
sync_dist_group=sync_dist_group,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
rank_zero_only=rank_zero_only,
)
if self._fabric is not None:
self._log_dict_through_fabric(dictionary=dictionary, logger=logger) # type: ignore[arg-type]
else:
for k, v in dictionary.items():
self.log(
name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_dist=sync_dist,
sync_dist_group=sync_dist_group,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
rank_zero_only=rank_zero_only,
)

def _log_dict_through_fabric(self, dictionary: Dict[str, Any], logger: Optional[bool] = None) -> None:
if logger is False:
# Passing `logger=False` with Fabric does not make much sense because there is no other destination to
# log to, but we support it in case the original code was written for Trainer use
return

if any(isinstance(v, dict) for v in dictionary.values()):
raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")
for name, value in dictionary.items():
apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))

assert self._fabric is not None
self._fabric.log_dict(metrics=dictionary)

@staticmethod
def __check_not_nested(value: dict, name: str) -> None:
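A condensed sketch of the Fabric branch added to `log`/`log_dict` above, mirroring the new tests later in this commit (a `Mock` logger and the `BoringModel` demo class as stand-ins): values are forwarded to `Fabric.log_dict`, `logger=False` becomes a no-op, and nested dictionaries are rejected.

```python
from unittest.mock import Mock

import torch
from lightning_fabric import Fabric
from pytorch_lightning.demos.boring_classes import BoringModel

logger = Mock()
fabric = Fabric(loggers=[logger])
model = fabric.setup(BoringModel())

model.log_dict({"loss": torch.tensor(0.25), "lr": 0.1})
logger.log_metrics.assert_called_with(metrics={"loss": 0.25, "lr": 0.1}, step=None)

# Tolerated for Trainer compatibility, but a no-op when running under Fabric.
model.log("skipped", 1.0, logger=False)

# Nested dictionaries cannot be logged through Fabric.
try:
    model.log_dict({"metrics": {"loss": 0.1}})
except ValueError as err:
    print(err)
```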
src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -19,9 +19,9 @@
import pytorch_lightning as pl
from lightning_fabric.plugins.environments import SLURMEnvironment
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
from pytorch_lightning.loggers import Logger, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.utilities.metrics import metrics_to_scalars


class LoggerConnector:
@@ -80,7 +80,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
self._logged_metrics.update(metrics)

# turn all tensors to scalars
scalar_metrics = metrics_to_scalars(metrics)
scalar_metrics = convert_tensors_to_scalars(metrics)

if step is None:
step = scalar_metrics.pop("step", None)
src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -22,13 +22,13 @@
from typing_extensions import TypedDict

from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_fabric.utilities.distributed import _distributed_available
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.warnings import PossibleUserWarning

@@ -610,7 +610,7 @@ def any_tensor(_: Any) -> None:

# populate progress_bar metrics. convert tensors to numbers
if result_metric.meta.prog_bar:
metrics["pbar"][forked_name] = metrics_to_scalars(value)
metrics["pbar"][forked_name] = convert_tensors_to_scalars(value)

return metrics

17 changes: 3 additions & 14 deletions src/pytorch_lightning/utilities/metrics.py
@@ -12,29 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to operate on metric values."""
from __future__ import annotations

from typing import Any

from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars


def metrics_to_scalars(metrics: Any) -> Any:
"""Recursively walk through a collection and convert single-item tensors to scalar values.
Raises:
MisconfigurationException:
ValueError:
If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.
"""

def to_item(value: Tensor) -> int | float | bool:
if value.numel() != 1:
raise MisconfigurationException(
f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
)
return value.item()

return apply_to_collection(metrics, Tensor, to_item)
return convert_tensors_to_scalars(metrics)
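`metrics_to_scalars` is now a thin wrapper around the Fabric utility; the docstring change above implies a multi-element tensor now surfaces as a `ValueError` rather than a `MisconfigurationException`. A hedged sketch of the resulting behavior:

```python
import torch
from pytorch_lightning.utilities.metrics import metrics_to_scalars

assert metrics_to_scalars({"loss": torch.tensor(0.5)}) == {"loss": 0.5}

try:
    metrics_to_scalars({"preds": torch.tensor([0.1, 0.9])})
except ValueError as err:
    print(err)  # "... does not contain a single element ..."
```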
25 changes: 25 additions & 0 deletions tests/tests_fabric/test_fabric.py
@@ -801,3 +801,28 @@ def test_log_dict():
fabric.log_dict({"foo": 3, "bar": 4}, step=15)
logger0.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
logger1.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)


def test_log_dict_input_parsing():
"""Test validation of input data types and preprocessing."""
logger = Mock()
fabric = Fabric(loggers=[logger])

# Tensor scalar, 0 dims
fabric.log("log", torch.tensor(1))
logger.log_metrics.assert_called_with(metrics={"log": 1}, step=None)
fabric.log_dict({"log_dict": torch.tensor(1)})
logger.log_metrics.assert_called_with(metrics={"log_dict": 1}, step=None)

# Tensor scalar, 1 dims
fabric.log("log", torch.tensor([2]))
logger.log_metrics.assert_called_with(metrics={"log": 2}, step=None)
fabric.log_dict({"log_dict": torch.tensor([2])})
logger.log_metrics.assert_called_with(metrics={"log_dict": 2}, step=None)

# Tensor, multiple dims
with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
fabric.log("log", torch.tensor([3, 4]))

with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
fabric.log_dict({"log_dict": torch.tensor([3, 4])})
19 changes: 18 additions & 1 deletion tests/tests_fabric/utilities/test_apply_func.py
@@ -15,7 +15,7 @@
import torch
from torch import Tensor

from lightning_fabric.utilities.apply_func import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device


@pytest.mark.parametrize("should_return", [False, True])
@@ -34,3 +34,20 @@ def to(self, device):
tensor = torch.tensor(0.1)
obj = TensorObject(tensor, should_return)
assert obj == move_data_to_device(obj, torch.device("cpu"))


def test_convert_tensors_to_scalars():
assert convert_tensors_to_scalars("string") == "string"
assert convert_tensors_to_scalars(1) == 1
assert convert_tensors_to_scalars(True) is True
assert convert_tensors_to_scalars({"scalar": 1.0}) == {"scalar": 1.0}

result = convert_tensors_to_scalars({"tensor": torch.tensor(2.0)})
# note: `==` comparison as above is not sufficient, since `torch.tensor(x) == x` evaluates to truth
assert not isinstance(result["tensor"], Tensor) and result["tensor"] == 2.0

result = convert_tensors_to_scalars({"tensor": torch.tensor([2.0])})
assert not isinstance(result["tensor"], Tensor) and result["tensor"] == 2.0

with pytest.raises(ValueError, match="does not contain a single element"):
convert_tensors_to_scalars({"tensor": torch.tensor([1, 2, 3])})
88 changes: 82 additions & 6 deletions tests/tests_pytorch/core/test_lightning_module.py
@@ -512,13 +512,89 @@ def test_fabric_attributes():

fabric = Fabric()
wrapped_module, wrapped_optimizer = fabric.setup(module, optimizer)
assert module.fabric is fabric
assert module._fabric_optimizers == [wrapped_optimizer]
assert wrapped_module.fabric is fabric
assert wrapped_module._fabric_optimizers == [wrapped_optimizer]

# Attribute access on LightningModule.trainer gets redirected to Fabric
assert isinstance(module.trainer, _TrainerFabricShim)
assert module.trainer.global_rank == 0
assert isinstance(wrapped_module.trainer, _TrainerFabricShim)
assert wrapped_module.trainer.global_rank == 0
with pytest.raises(AttributeError, match="Your LightningModule code tried to access `self.trainer.current_epoch`"):
_ = module.trainer.current_epoch
_ = wrapped_module.trainer.current_epoch

assert module.optimizers() == wrapped_optimizer
assert wrapped_module.optimizers() == wrapped_optimizer


def test_fabric_logger_access():
"""Test that the logger attribute can be accessed when the LightningModule is used together with Fabric."""
# No logger
module = BoringModel()
fabric = Fabric()
wrapped_module = fabric.setup(module)
assert wrapped_module.loggers == []
with pytest.raises(IndexError):
_ = wrapped_module.logger

# Single Logger
logger = Mock()
module = BoringModel()
fabric = Fabric(loggers=logger)
wrapped_module = fabric.setup(module)
assert wrapped_module.logger == logger
assert wrapped_module.loggers == [logger]

# Multiple loggers
logger1 = Mock()
logger2 = Mock()
module = BoringModel()
fabric = Fabric(loggers=[logger1, logger2])
wrapped_module = fabric.setup(module)
assert wrapped_module.logger == logger1
assert wrapped_module.loggers == [logger1, logger2]


def test_fabric_log():
logger = Mock()
module = BoringModel()
fabric = Fabric(loggers=[logger])
wrapped_module = fabric.setup(module)

# unsupported data type
with pytest.raises(ValueError, match="`list` values cannot be logged"):
wrapped_module.log("invalid", list())

# supported data types
wrapped_module.log("int", 1)
logger.log_metrics.assert_called_with(metrics={"int": 1}, step=None)
wrapped_module.log("float", 0.1)
logger.log_metrics.assert_called_with(metrics={"float": 0.1}, step=None)
wrapped_module.log("tensor", torch.tensor(0.1))
logger.log_metrics.assert_called_with(metrics={"tensor": torch.tensor(0.1)}, step=None)

# logger=False
logger.reset_mock()
wrapped_module.log("nothing", 1, logger=False)
logger.log_metrics.assert_not_called()


def test_fabric_log_dict():
logger = Mock()
module = BoringModel()
fabric = Fabric(loggers=[logger])
wrapped_module = fabric.setup(module)

# unsupported data type
with pytest.raises(ValueError, match="`list` values cannot be logged"):
wrapped_module.log_dict({"invalid": [1, 2, 3]})

# nested dicts
with pytest.raises(ValueError, match="nested dictionaries cannot be logged"):
wrapped_module.log_dict({"nested": {"nested": 1}})

# supported data types
wrapped_module.log_dict({"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)})
logger.log_metrics.assert_called_with(metrics={"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)}, step=None)

# logger=False
logger.reset_mock()
wrapped_module.log_dict({"nothing": 1}, logger=False)
logger.log_metrics.assert_not_called()
