[bug] Update broadcast + reduce decision ModelCheckpoint #6410

Merged (70 commits, Mar 14, 2021)
Changes from 1 commit
Commits (70)
597ae27
resolve bug
tchaton Mar 4, 2021
ef11927
update
tchaton Mar 4, 2021
85b327d
update changelog
tchaton Mar 4, 2021
47f0b2c
update PR
tchaton Mar 4, 2021
bbe4255
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 4, 2021
1c33b48
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Mar 4, 2021
6cd4713
add todo
tchaton Mar 4, 2021
45d7239
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 4, 2021
b58d7fb
resolve issues
tchaton Mar 4, 2021
e3a084a
resolve flake8
tchaton Mar 4, 2021
77edbed
update
tchaton Mar 4, 2021
6bcc88d
add coverage for reduce
tchaton Mar 4, 2021
c63bca5
wip
tchaton Mar 4, 2021
e26d301
restore back to brodbact
tchaton Mar 4, 2021
ce239fd
remove test.py
tchaton Mar 4, 2021
d8f1dc9
resolve flake8
tchaton Mar 4, 2021
237bbd2
update
tchaton Mar 4, 2021
f546ae4
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 4, 2021
6fbe70d
check world size
tchaton Mar 4, 2021
5f25fc5
resolve test
tchaton Mar 4, 2021
46cf2c6
update
tchaton Mar 4, 2021
7029b31
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 5, 2021
8523167
use pytorch version when defined
tchaton Mar 5, 2021
f28f950
update on comments
tchaton Mar 5, 2021
6eae79d
update on comments
tchaton Mar 5, 2021
1cd9431
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 5, 2021
9448964
flake8
tchaton Mar 5, 2021
1b5c90a
resolve bugs
tchaton Mar 5, 2021
a1264d9
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 5, 2021
9f3eb41
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 5, 2021
e88ef07
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 5, 2021
c21f148
Update CHANGELOG.md
tchaton Mar 5, 2021
94e9aa9
update
tchaton Mar 5, 2021
4626310
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 5, 2021
b260bf6
update
tchaton Mar 5, 2021
dd60ed1
update
tchaton Mar 5, 2021
45b65f1
update
tchaton Mar 6, 2021
dcd6884
remove test
tchaton Mar 6, 2021
2e046e8
update
tchaton Mar 6, 2021
68ffb5b
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 6, 2021
23b2c10
resolve flake8
tchaton Mar 6, 2021
b4c663b
update
tchaton Mar 6, 2021
aa89d5d
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 6, 2021
73e83f7
update
tchaton Mar 6, 2021
c060444
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 6, 2021
2eb6db4
update
tchaton Mar 6, 2021
060992b
proxy
tchaton Mar 6, 2021
5bad135
update
tchaton Mar 6, 2021
4579842
update
tchaton Mar 6, 2021
5276cd0
Merge branch 'master' into bugfix/broadcast_2
tchaton Mar 8, 2021
8027838
resolve typo
tchaton Mar 8, 2021
aa9a6ca
prune
tchaton Mar 8, 2021
4b6a6c5
update parallel
tchaton Mar 8, 2021
4b55c52
update
tchaton Mar 8, 2021
cbacf48
update changelog
tchaton Mar 8, 2021
057fbf3
update
tchaton Mar 9, 2021
7f515ea
Merge branch 'master' into bugfix/broadcast_2
tchaton Mar 9, 2021
7cbf38b
try running pipe
tchaton Mar 9, 2021
928cf2c
Merge branch 'master' into bugfix/broadcast_2
carmocca Mar 10, 2021
690b61f
Update pytorch_lightning/callbacks/model_checkpoint.py
tchaton Mar 10, 2021
300a632
update on comments
tchaton Mar 10, 2021
5e30377
Update pytorch_lightning/callbacks/model_checkpoint.py
tchaton Mar 11, 2021
015fbac
update on comennts
tchaton Mar 12, 2021
c213716
Merge branch 'bugfix/broadcast_2' of https://github.com/PyTorchLightn…
tchaton Mar 12, 2021
f668c3a
update
tchaton Mar 12, 2021
30feb40
update
tchaton Mar 12, 2021
a4bf623
update
tchaton Mar 12, 2021
b482589
fix
tchaton Mar 12, 2021
1ad9c62
remove comments
tchaton Mar 12, 2021
b64e105
resolve bugs
tchaton Mar 14, 2021
Commit 057fbf3ddc89fb38a511822d08f285e3487e0ea4 ("update"), committed by tchaton on Mar 9, 2021
pytorch_lightning/callbacks/model_checkpoint.py (2 changes: 1 addition, 1 deletion)

@@ -326,7 +326,7 @@ def _do_save(self, trainer, filepath: str):
         else:
             raise ValueError(".save_function() not set")

-    def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor]) -> bool:
+    def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
         if current is None:
             return False
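The practical effect of the new default is that check_monitor_top_k can be called without a metric value and simply reports that there is nothing to compare. A minimal sketch of that behaviour, not taken from the PR itself (the "val_loss" monitor key is an arbitrary illustration, and None stands in for a real Trainer because the early return never touches it):

from pytorch_lightning.callbacks import ModelCheckpoint

# With current defaulting to None, the method short-circuits to False before
# the trainer argument is used, so None is acceptable for this illustration.
checkpoint_cb = ModelCheckpoint(monitor="val_loss")
assert checkpoint_cb.check_monitor_top_k(trainer=None) is False

# With a real metric it is still called explicitly, e.g.:
#   checkpoint_cb.check_monitor_top_k(trainer, current=some_metric_tensor)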
pytorch_lightning/distributed/dist.py (4 changes: 2 additions, 2 deletions)

@@ -13,8 +13,8 @@
 # limitations under the License.
 from typing import Any

-from pytorch_lightning.utilities.distributed import broadcast_object_list
 from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.torch_distributed import broadcast_object_list


 class LightningDistributed:

@@ -28,7 +28,7 @@ def broadcast(self, obj: Any, group=_group.WORLD):
         obj = [obj]

         if self.rank != 0:
-            obj = [None for _ in range(len(obj))]
+            obj = [None] * len(obj)

         broadcast_object_list(obj, 0, group=group or _group.WORLD)
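For context, the pattern LightningDistributed.broadcast implements is: wrap the object in a one-element list, blank the list out on non-zero ranks, then broadcast the list from rank 0. A minimal sketch of that pattern follows; it is not part of the PR, and the single-process gloo group, the TCP init address, and the checkpoint-path payload are all illustrative assumptions:

import torch.distributed as dist

from pytorch_lightning.utilities.torch_distributed import broadcast_object_list

# A single-process "gloo" group is enough to exercise the code path, since
# rank 0 is both the source and the only consumer of the broadcast.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

obj = ["epoch=3-step=100.ckpt"]  # rank 0 holds the real value
# A non-zero rank would instead start from: obj = [None] * len(obj)
broadcast_object_list(obj, 0)  # group defaults to the WORLD process group
print(obj[0])  # after the call, every rank holds rank 0's value

dist.destroy_process_group()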
pytorch_lightning/utilities/distributed.py (86 changes: 1 addition, 85 deletions)

@@ -14,19 +14,16 @@

 import logging
 import os
-import pickle
 import warnings
 from functools import wraps
 from typing import Any, Optional, Union

 import torch

-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
-
 log = logging.getLogger(__name__)

 if torch.distributed.is_available():
-    from torch.distributed import Backend, broadcast, get_backend, get_rank, group, GroupMember, ReduceOp
+    from torch.distributed import group, ReduceOp

 else:

@@ -37,87 +34,6 @@ class group:
         WORLD = None


-# This part is used to provide broadcast support for PyTorch 1.5 and lower.
-# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
-def _rank_not_in_group(group):
-    """
-    Helper that checks if the current process's rank is not in a given group.
-    """
-    if group is None:
-        return False
-    return group == GroupMember.NON_GROUP_MEMBER
-
-
-# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
-def _object_to_tensor(obj):
-    buffer = pickle.dumps(obj)
-    byte_storage = torch.ByteStorage.from_buffer(buffer)  # type: ignore[attr-defined]
-    byte_tensor = torch.ByteTensor(byte_storage)
-    local_size = torch.LongTensor([byte_tensor.numel()])
-    return byte_tensor, local_size
-
-
-# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
-def _tensor_to_object(tensor, tensor_size):
-    buf = tensor.numpy().tobytes()[:tensor_size]
-    out = pickle.loads(buf)
-    return out
-
-
-# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
-def _broadcast_object_list(object_list, src=0, group=None):
-    if _rank_not_in_group(group):
-        return
-
-    my_rank = get_rank()
-    # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
-        tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
-        object_sizes_tensor = torch.cat(size_list)
-    else:
-        object_sizes_tensor = torch.LongTensor(len(object_list))
-
-    group_backend = get_backend(group)
-    is_nccl_backend = group_backend == Backend.NCCL
-    current_device = torch.device("cpu")
-    if is_nccl_backend:
-        # See note about using torch.cuda.current_device() here in docstring.
-        # We cannot simply use my_rank since rank == device is not necessarily
-        # true.
-        current_device = torch.device('cuda', torch.cuda.current_device())
-        object_sizes_tensor = object_sizes_tensor.to(current_device)
-        object_sizes_tensor = object_sizes_tensor.to(current_device)
-
-    # Broadcast object sizes
-    broadcast(object_sizes_tensor, src=src, group=group)
-
-    # Concatenate and broadcast serialized object tensors
-    if my_rank == src:
-        object_tensor = torch.cat(tensor_list)
-    else:
-        object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())
-
-    if is_nccl_backend:
-        object_tensor = object_tensor.to(current_device)
-
-    broadcast(object_tensor, src=src, group=group)
-
-    # Deserialize objects using their stored sizes.
-    offset = 0
-    if my_rank != src:
-        for i, obj_size in enumerate(object_sizes_tensor):
-            obj_view = object_tensor[offset:offset + obj_size]
-            obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload]
-            offset += obj_size
-            object_list[i] = _tensor_to_object(obj_view, obj_size)
-
-
-if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
-    from torch.distributed.distributed_c10d import broadcast_object_list
-else:
-    broadcast_object_list = _broadcast_object_list
-
-
 def rank_zero_only(fn):

     @wraps(fn)
pytorch_lightning/utilities/torch_distributed.py (new file, 94 additions)

import logging
import pickle

import torch

from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7

log = logging.getLogger(__name__)

if torch.distributed.is_available():
    from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember

# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py``
# and enable broadcasting for PyTorch 1.6 and lower.


# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
def _rank_not_in_group(group):
    """
    Helper that checks if the current process's rank is not in a given group.
    """
    if group is None:
        return False
    return group == GroupMember.NON_GROUP_MEMBER


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
def _object_to_tensor(obj):
    buffer = pickle.dumps(obj)
    byte_storage = torch.ByteStorage.from_buffer(buffer)  # type: ignore[attr-defined]
    byte_tensor = torch.ByteTensor(byte_storage)
    local_size = torch.LongTensor([byte_tensor.numel()])
    return byte_tensor, local_size


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
def _tensor_to_object(tensor, tensor_size):
    buf = tensor.numpy().tobytes()[:tensor_size]
    out = pickle.loads(buf)
    return out


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
def _broadcast_object_list(object_list, src=0, group=None):
    if _rank_not_in_group(group):
        return

    my_rank = get_rank()
    # Serialize object_list elements to tensors on src rank.
    if my_rank == src:
        tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
        object_sizes_tensor = torch.cat(size_list)
    else:
        object_sizes_tensor = torch.LongTensor(len(object_list))

    group_backend = get_backend(group)
    is_nccl_backend = group_backend == Backend.NCCL
    current_device = torch.device("cpu")
    if is_nccl_backend:
        # See note about using torch.cuda.current_device() here in docstring.
        # We cannot simply use my_rank since rank == device is not necessarily
        # true.
        current_device = torch.device('cuda', torch.cuda.current_device())
        object_sizes_tensor = object_sizes_tensor.to(current_device)
        object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
    broadcast(object_sizes_tensor, src=src, group=group)

    # Concatenate and broadcast serialized object tensors
    if my_rank == src:
        object_tensor = torch.cat(tensor_list)
    else:
        object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())

    if is_nccl_backend:
        object_tensor = object_tensor.to(current_device)

    broadcast(object_tensor, src=src, group=group)

    # Deserialize objects using their stored sizes.
    offset = 0
    if my_rank != src:
        for i, obj_size in enumerate(object_sizes_tensor):
            obj_view = object_tensor[offset:offset + obj_size]
            obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload]
            offset += obj_size
            object_list[i] = _tensor_to_object(obj_view, obj_size)


if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
    from torch.distributed.distributed_c10d import broadcast_object_list
else:
    broadcast_object_list = _broadcast_object_list
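The serialization half of the backport can be checked in isolation, with no process group at all. The snippet below is a round-trip sketch using the two helpers defined in this new module; the payload dictionary is only an illustrative value:

from pytorch_lightning.utilities.torch_distributed import _object_to_tensor, _tensor_to_object

# Round-trip sketch of the pickle-based serialization used by _broadcast_object_list:
# the object becomes a ByteTensor plus its length, and is rebuilt from exactly that pair.
payload = {"best_model_path": "epoch=3.ckpt", "best_model_score": 0.123}
byte_tensor, local_size = _object_to_tensor(payload)
restored = _tensor_to_object(byte_tensor, int(local_size.item()))
assert restored == payload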