Fix parameter count in ModelSummary when parameters are DTensors #20163

Merged · 8 commits · Aug 5, 2024
Changes from 1 commit
update
awaelchli committed Aug 4, 2024
commit 0a69fcff89cc353a8a80dad92ccb876021b8e825
src/lightning/fabric/utilities/distributed.py (11 additions, 1 deletion)
@@ -7,7 +7,7 @@
from contextlib import nullcontext
from datetime import timedelta
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union, TypeGuard

import torch
import torch.nn.functional as F
@@ -20,6 +20,7 @@
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

if torch.distributed.is_available():
    from torch.distributed import group
@@ -32,6 +33,7 @@ class group: # type: ignore
if TYPE_CHECKING:
    from lightning.fabric.plugins import ClusterEnvironment
    from lightning.fabric.strategies import Strategy
+    from torch.distributed._tensor import DTensor


log = logging.getLogger(__name__)
@@ -427,3 +429,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.barrier()
        if self.group is not None:
            torch.distributed.destroy_process_group(self.group)
+
+
+def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
+    if _TORCH_GREATER_EQUAL_2_4:
+        from torch.distributed._tensor import DTensor
+
+        return isinstance(tensor, DTensor)
+    return False
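
The new helper keeps the DTensor import behind Lightning's _TORCH_GREATER_EQUAL_2_4 flag, so on older torch versions it simply returns False, and the TypeGuard return type lets static type checkers narrow the argument to DTensor after a successful check. A rough usage sketch (the describe function is hypothetical, not part of the PR, and assumes a Lightning build that includes this change):

import torch
from lightning.fabric.utilities.distributed import _is_dtensor


def describe(tensor: torch.Tensor) -> str:
    # DTensor subclasses torch.Tensor and reports the global (unsharded) shape,
    # so shape and numel() can be read the same way for both kinds of tensors.
    if _is_dtensor(tensor):
        return f"DTensor, global shape {tuple(tensor.shape)}, {tensor.numel()} elements"
    return f"plain tensor, shape {tuple(tensor.shape)}, {tensor.numel()} elements"


print(describe(torch.randn(4, 8)))  # plain tensor, shape (4, 8), 32 elements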
src/lightning/pytorch/utilities/model_summary/model_summary.py (6 additions, 14 deletions)
@@ -25,9 +25,9 @@
from torch.utils.hooks import RemovableHandle

import lightning.pytorch as pl
+from lightning.fabric.utilities.distributed import _is_dtensor
from lightning.pytorch.utilities.model_helpers import _ModuleMode
from lightning.pytorch.utilities.rank_zero import WarningCache
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

log = logging.getLogger(__name__)
warning_cache = WarningCache()
@@ -136,7 +136,7 @@ def layer_type(self) -> str:
    @property
    def num_parameters(self) -> int:
        """Returns the number of parameters in this module."""
-        return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

    @property
    def training(self) -> bool:
@@ -265,13 +265,11 @@ def total_training_modes(self) -> Dict[str, int]:

    @property
    def total_parameters(self) -> int:
-        return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

    @property
    def trainable_parameters(self) -> int:
-        return sum(
-            p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
-        )
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)

    @property
    def total_layer_params(self) -> int:
@@ -471,16 +469,10 @@ def get_human_readable_count(number: int) -> str:
    return f"{number:,.1f} {labels[index]}"


-def _is_lazy_weight_tensor(p: Tensor) -> bool:
+def _tensor_has_shape(p: Tensor) -> bool:
    from torch.nn.parameter import UninitializedParameter

-    if _TORCH_GREATER_EQUAL_2_4:
-        from torch.distributed._tensor import DTensor
-
-        if isinstance(p, DTensor):
-            return False
-
-    if isinstance(p, UninitializedParameter):
+    if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
        warning_cache.warn(
            "The total number of parameters detected may be inaccurate because the model contains"
            " an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"
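
The substance of the change in this file: the lazy-parameter check is renamed from _is_lazy_weight_tensor to _tensor_has_shape, the inline DTensor special case moves into the shared _is_dtensor helper, and per-layer counting switches from math.prod(p.shape) to p.numel(). A hedged sketch of the resulting counting behaviour (the toy model is illustrative, not from the PR, and assumes a Lightning build that includes this change):

import torch
from torch import nn
from lightning.pytorch.utilities.model_summary.model_summary import _tensor_has_shape

# One materialized layer and one lazy layer whose parameters are still UninitializedParameter.
model = nn.Sequential(nn.Linear(10, 4), nn.LazyLinear(2))

# Uninitialized parameters are skipped (counted as 0); everything else goes through
# numel(), which for DTensors reports the global (unsharded) element count.
total = sum(0 if _tensor_has_shape(p) else p.numel() for p in model.parameters())
print(total)  # should print 44 = 10 * 4 + 4 from the Linear; the LazyLinear contributes 0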
src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py
@@ -25,7 +25,7 @@
    NOT_APPLICABLE,
    LayerSummary,
    ModelSummary,
-    _is_lazy_weight_tensor,
+    _tensor_has_shape,
    get_human_readable_count,
)

@@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary):
    @override
    def num_parameters(self) -> int:
        """Returns the number of parameters in this module."""
-        return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

    @property
    def average_shard_parameters(self) -> int:
@@ -49,7 +49,7 @@ def average_shard_parameters(self) -> int:
        def partitioned_size(p: Parameter) -> int:
            return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel()

-        return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())


class DeepSpeedSummary(ModelSummary):
@@ -71,13 +71,13 @@ def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[overrid
    @property
    @override
    def total_parameters(self) -> int:
-        return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+        return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

    @property
    @override
    def trainable_parameters(self) -> int:
        return sum(
-            deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
+            deepspeed_param_size(p) if not _tensor_has_shape(p) else 0
            for p in self._model.parameters()
            if p.requires_grad
        )
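
In the DeepSpeed summary only the helper name changes; the counts themselves still come from deepspeed_param_size and partitioned_size. As an end-to-end sketch of what the fix targets (the LightningModule below is a toy example, not from the PR, and assumes a Lightning build that includes this change; with a strategy that shards parameters as DTensors, total_parameters is expected to report the full logical parameter count):

import lightning.pytorch as pl
from lightning.pytorch.utilities.model_summary import ModelSummary
from torch import nn


class ToyModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 16)


# ModelSummary can be built directly from a LightningModule, outside a Trainer run.
summary = ModelSummary(ToyModel(), max_depth=1)
print(summary.total_parameters)  # 32 * 16 + 16 = 528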