Lite: Support self.log from a LightningModule #16311

Merged: 23 commits, Jan 10, 2023
Changes from 15 commits
Commits (23 total)
fa7e864
support logging from LM
lightningforever Dec 20, 2022
97df628
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2023
9cf55af
test
lightningforever Jan 10, 2023
adc0ce5
changelog
lightningforever Jan 10, 2023
fba38cf
test
lightningforever Jan 10, 2023
5871244
refactor
lightningforever Jan 10, 2023
b30b51e
types
lightningforever Jan 10, 2023
a6b576a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2023
7268aac
validation
lightningforever Jan 10, 2023
7543472
Merge remote-tracking branch 'origin/lite/logger-lm' into lite/logger-lm
lightningforever Jan 10, 2023
af3ca7a
Merge branch 'master' into lite/logger-lm
lightningforever Jan 10, 2023
d3109ea
reuse the log_dict_through_fabric
lightningforever Jan 10, 2023
2ae8ceb
Update src/lightning_fabric/fabric.py
awaelchli Jan 10, 2023
789d413
reuse metrics_to_scalars without deprecation
lightningforever Jan 10, 2023
117acdb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2023
c72085b
reuse utility code in fabric
lightningforever Jan 10, 2023
b691e2f
Merge remote-tracking branch 'origin/lite/logger-lm' into lite/logger-lm
lightningforever Jan 10, 2023
53fa4f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2023
6f10f9d
add test for utility function
lightningforever Jan 10, 2023
78573a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2023
52f9625
fix type
lightningforever Jan 10, 2023
29a5489
Merge remote-tracking branch 'origin/lite/logger-lm' into lite/logger-lm
lightningforever Jan 10, 2023
c5b9fa7
fix docs parsing error
lightningforever Jan 10, 2023
3 changes: 2 additions & 1 deletion src/lightning_fabric/CHANGELOG.md
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `Fabric.log` for logging scalars using multiple loggers
* Added `Fabric.log_dict` for logging a dictionary of multiple metrics at once
* Added `Fabric.loggers` and `Fabric.logger` attributes to access the individual logger instances

* Added support for calling `self.log` and `self.log_dict` in a LightningModule when using Fabric
* Added access to `self.logger` and `self.loggers` in a LightningModule when using Fabric

- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))

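To make the new changelog entries concrete, here is a minimal usage sketch: a LightningModule that calls `self.log` in its `training_step` while the training loop is written by hand with Fabric. The model, optimizer, and loop are illustrative only and not part of this diff, and the import paths assume the package layout at the time of this PR.

import torch
from lightning_fabric import Fabric
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # With Fabric attached via `fabric.setup`, this call is forwarded to
        # `Fabric.log_dict` instead of the Trainer's logger connector.
        self.log("train_loss", loss)
        return loss


fabric = Fabric()  # pass `loggers=...` to actually record the values
module = LitModel()
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
model, optimizer = fabric.setup(module, optimizer)

for batch_idx, batch in enumerate(torch.randn(4, 8, 32)):
    optimizer.zero_grad()
    loss = model.training_step(batch, batch_idx)
    fabric.backward(loss)
    optimizer.step()
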
11 changes: 10 additions & 1 deletion src/lightning_fabric/fabric.py
@@ -597,7 +597,8 @@ def log(self, name: str, value: Any, step: Optional[int] = None) -> None:

Args:
name: The name of the metric to log.
value: The metric value to collect.
value: The metric value to collect. If the value is a :class:`torch.Tensor`, it gets detached from the
graph automatically.
step: Optional step number. Most Logger implementations auto-increment the step value by one with every
log call. You can specify your own value here.
"""
@@ -608,9 +609,17 @@ def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:

Args:
metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged.
Any :class:`torch.Tensor`s in the dictionary get detached from the graph automatically.
step: Optional step number. Most Logger implementations auto-increment this value by one with every
log call. You can specify your own value here.
"""

def to_item(value: Tensor) -> Union[int, float, bool]:
if value.numel() != 1:
raise ValueError("Logging tensors with more than one element is not supported.")
return value.item()

metrics = apply_to_collection(metrics, dtype=Tensor, function=to_item)
for logger in self._loggers:
logger.log_metrics(metrics=metrics, step=step)

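A small sketch of the conversion these docstrings describe, mirroring the new unit tests further down in this diff (the `Mock` stands in for any Fabric logger, and the `lightning_fabric` import path is assumed from the package layout at the time of this PR):

from unittest.mock import Mock

import torch
from lightning_fabric import Fabric

logger = Mock()
fabric = Fabric(loggers=[logger])

# single-element tensors are detached and converted to plain Python scalars
fabric.log("loss", torch.tensor(0.25))
logger.log_metrics.assert_called_with(metrics={"loss": 0.25}, step=None)

fabric.log_dict({"epoch": torch.tensor([3]), "lr": 0.1}, step=7)
logger.log_metrics.assert_called_with(metrics={"epoch": 3, "lr": 0.1}, step=7)

# tensors with more than one element cannot be logged
# fabric.log_dict({"confmat": torch.zeros(2, 2)})  # raises ValueError
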
18 changes: 18 additions & 0 deletions src/lightning_fabric/utilities/apply_func.py
@@ -111,3 +111,21 @@ def _move_to_device_and_make_contiguous(t: Tensor, device: _DEVICE) -> Tensor:

# make sure existing tensors are in the correct device, also contiguous
return apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)


def convert_tensors_to_scalars(data: Any) -> Any:
"""Recursively walk through a collection and convert single-item tensors to scalar values.

Raises:
ValueError:
If tensors inside ``data`` contain multiple elements, hence preventing conversion to a scalar.
"""

def to_item(value: Tensor) -> Union[int, float, bool]:
if value.numel() != 1:
raise ValueError(
f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
)
return value.item()

return apply_to_collection(data, Tensor, to_item)
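
A quick sketch of what the new utility does on a nested collection (the example input is illustrative):

import torch
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars

metrics = {"loss": torch.tensor(0.5), "extra": {"count": torch.tensor([3]), "run": "baseline"}}
print(convert_tensors_to_scalars(metrics))
# {'loss': 0.5, 'extra': {'count': 3, 'run': 'baseline'}}

# a tensor with more than one element cannot be reduced to a scalar
# convert_tensors_to_scalars({"vec": torch.tensor([1, 2])})  # raises ValueError
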
53 changes: 37 additions & 16 deletions src/pytorch_lightning/core/module.py
@@ -403,6 +403,10 @@ def log(
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
if self._fabric is not None:
self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
return

# check for invalid values
apply_to_collection(value, dict, self.__check_not_nested, name)
apply_to_collection(
@@ -554,22 +558,39 @@ def log_dict(
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
for k, v in dictionary.items():
self.log(
name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_dist=sync_dist,
sync_dist_group=sync_dist_group,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
rank_zero_only=rank_zero_only,
)
if self._fabric is not None:
self._log_dict_through_fabric(dictionary=dictionary, logger=logger) # type: ignore[arg-type]
else:
for k, v in dictionary.items():
self.log(
name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_dist=sync_dist,
sync_dist_group=sync_dist_group,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
rank_zero_only=rank_zero_only,
)

def _log_dict_through_fabric(self, dictionary: Dict[str, Any], logger: Optional[bool] = None) -> None:
if logger is False:
# Passing `logger=False` with Fabric does not make much sense because there is no other destination to
# log to, but we support it in case the original code was written for Trainer use
return

if any(isinstance(v, dict) for v in dictionary.values()):
raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")
for name, value in dictionary.items():
apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))

assert self._fabric is not None
self._fabric.log_dict(metrics=dictionary)

@staticmethod
def __check_not_nested(value: dict, name: str) -> None:
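
The redirection added above can be sketched as follows, mirroring the new tests at the end of this diff (`BoringModel` and the `Mock` logger are stand-ins, and the import paths are assumed from the package layout at the time of this PR):

from unittest.mock import Mock

from lightning_fabric import Fabric
from pytorch_lightning.demos.boring_classes import BoringModel

logger = Mock()
fabric = Fabric(loggers=[logger])
model = fabric.setup(BoringModel())

model.log("loss", 0.25)  # forwarded to `Fabric.log_dict`
logger.log_metrics.assert_called_with(metrics={"loss": 0.25}, step=None)

model.log_dict({"skipped": 1}, logger=False)  # no-op: under Fabric there is no other destination

# nested dictionaries are rejected before reaching the loggers
# model.log_dict({"nested": {"x": 1}})  # raises ValueError
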
@@ -19,9 +19,9 @@
import pytorch_lightning as pl
from lightning_fabric.plugins.environments import SLURMEnvironment
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
from pytorch_lightning.loggers import Logger, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.utilities.metrics import metrics_to_scalars


class LoggerConnector:
@@ -80,7 +80,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
self._logged_metrics.update(metrics)

# turn all tensors to scalars
scalar_metrics = metrics_to_scalars(metrics)
scalar_metrics = convert_tensors_to_scalars(metrics)

if step is None:
step = scalar_metrics.pop("step", None)
@@ -22,13 +22,13 @@
from typing_extensions import TypedDict

from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_fabric.utilities.distributed import _distributed_available
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.warnings import PossibleUserWarning

@@ -610,7 +610,7 @@ def any_tensor(_: Any) -> None:

# populate progress_bar metrics. convert tensors to numbers
if result_metric.meta.prog_bar:
metrics["pbar"][forked_name] = metrics_to_scalars(value)
metrics["pbar"][forked_name] = convert_tensors_to_scalars(value)

return metrics

17 changes: 3 additions & 14 deletions src/pytorch_lightning/utilities/metrics.py
@@ -12,29 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to operate on metric values."""
from __future__ import annotations

from typing import Any

from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars


def metrics_to_scalars(metrics: Any) -> Any:
"""Recursively walk through a collection and convert single-item tensors to scalar values.

Raises:
MisconfigurationException:
ValueError:
If tensors inside ``metrics`` contain multiple elements, hence preventing conversion to a scalar.
"""

def to_item(value: Tensor) -> int | float | bool:
if value.numel() != 1:
raise MisconfigurationException(
f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
)
return value.item()

return apply_to_collection(metrics, Tensor, to_item)
return convert_tensors_to_scalars(metrics)
25 changes: 25 additions & 0 deletions tests/tests_fabric/test_fabric.py
@@ -800,3 +800,28 @@ def test_log_dict():
fabric.log_dict({"foo": 3, "bar": 4}, step=15)
logger0.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
logger1.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)


def test_log_dict_input_parsing():
"""Test validation of input data types and preprocessing."""
logger = Mock()
fabric = Fabric(loggers=[logger])

# Tensor scalar, 0 dims
fabric.log("log", torch.tensor(1))
logger.log_metrics.assert_called_with(metrics={"log": 1}, step=None)
fabric.log_dict({"log_dict": torch.tensor(1)})
logger.log_metrics.assert_called_with(metrics={"log_dict": 1}, step=None)

# Tensor scalar, 1 dims
fabric.log("log", torch.tensor([2]))
logger.log_metrics.assert_called_with(metrics={"log": 2}, step=None)
fabric.log_dict({"log_dict": torch.tensor([2])})
logger.log_metrics.assert_called_with(metrics={"log_dict": 2}, step=None)

# Tensor, multiple dims
with pytest.raises(ValueError, match="Logging tensors with more than one element is not supported"):
fabric.log("log", torch.tensor([3, 4]))

with pytest.raises(ValueError, match="Logging tensors with more than one element is not supported"):
fabric.log_dict({"log_dict": torch.tensor([3, 4])})
88 changes: 82 additions & 6 deletions tests/tests_pytorch/core/test_lightning_module.py
@@ -512,13 +512,89 @@ def test_fabric_attributes():

fabric = Fabric()
wrapped_module, wrapped_optimizer = fabric.setup(module, optimizer)
assert module.fabric is fabric
assert module._fabric_optimizers == [wrapped_optimizer]
assert wrapped_module.fabric is fabric
assert wrapped_module._fabric_optimizers == [wrapped_optimizer]

# Attribute access on LightningModule.trainer gets redirected to Fabric
assert isinstance(module.trainer, _TrainerFabricShim)
assert module.trainer.global_rank == 0
assert isinstance(wrapped_module.trainer, _TrainerFabricShim)
assert wrapped_module.trainer.global_rank == 0
with pytest.raises(AttributeError, match="Your LightningModule code tried to access `self.trainer.current_epoch`"):
_ = module.trainer.current_epoch
_ = wrapped_module.trainer.current_epoch

assert module.optimizers() == wrapped_optimizer
assert wrapped_module.optimizers() == wrapped_optimizer


def test_fabric_logger_access():
"""Test that the logger attribute can be accessed when the LightningModule is used together with Fabric."""
# No logger
module = BoringModel()
fabric = Fabric()
wrapped_module = fabric.setup(module)
assert wrapped_module.loggers == []
with pytest.raises(IndexError):
_ = wrapped_module.logger

# Single Logger
logger = Mock()
module = BoringModel()
fabric = Fabric(loggers=logger)
wrapped_module = fabric.setup(module)
assert wrapped_module.logger == logger
assert wrapped_module.loggers == [logger]

# Multiple loggers
logger1 = Mock()
logger2 = Mock()
module = BoringModel()
fabric = Fabric(loggers=[logger1, logger2])
wrapped_module = fabric.setup(module)
assert wrapped_module.logger == logger1
assert wrapped_module.loggers == [logger1, logger2]


def test_fabric_log():
logger = Mock()
module = BoringModel()
fabric = Fabric(loggers=[logger])
wrapped_module = fabric.setup(module)

# unsupported data type
with pytest.raises(ValueError, match="`list` values cannot be logged"):
wrapped_module.log("invalid", list())

# supported data types
wrapped_module.log("int", 1)
logger.log_metrics.assert_called_with(metrics={"int": 1}, step=None)
wrapped_module.log("float", 0.1)
logger.log_metrics.assert_called_with(metrics={"float": 0.1}, step=None)
wrapped_module.log("tensor", torch.tensor(0.1))
logger.log_metrics.assert_called_with(metrics={"tensor": torch.tensor(0.1)}, step=None)

# logger=False
logger.reset_mock()
wrapped_module.log("nothing", 1, logger=False)
logger.log_metrics.assert_not_called()


def test_fabric_log_dict():
logger = Mock()
module = BoringModel()
fabric = Fabric(loggers=[logger])
wrapped_module = fabric.setup(module)

# unsupported data type
with pytest.raises(ValueError, match="`list` values cannot be logged"):
wrapped_module.log_dict({"invalid": [1, 2, 3]})

# nested dicts
with pytest.raises(ValueError, match="nested dictionaries cannot be logged"):
wrapped_module.log_dict({"nested": {"nested": 1}})

# supported data types
wrapped_module.log_dict({"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)})
logger.log_metrics.assert_called_with(metrics={"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)}, step=None)

# logger=False
logger.reset_mock()
wrapped_module.log_dict({"nothing": 1}, logger=False)
logger.log_metrics.assert_not_called()