Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Zaida Zhou <[email protected]>
  • Loading branch information
hanhaowen-mt and zhouzaida committed Jan 10, 2024
1 parent 658a2f2 commit 0a15a75
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 29 deletions.
4 changes: 2 additions & 2 deletions mmengine/dist/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,11 +790,11 @@ def _gather_object(obj: Any,
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
current_device = torch.device('', torch.cuda.current_device())
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
current_device = torch.device('', torch.musa.current_device())
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
Expand Down
4 changes: 2 additions & 2 deletions mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
from mmengine.device import is_mlu_available, is_npu_available
from mmengine.device import is_musa_available
from mmengine.device import (is_mlu_available, is_npu_available,
is_musa_available)

from collections.abc import Iterable, Mapping

Expand Down
7 changes: 1 addition & 6 deletions mmengine/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,7 @@ def _get_device_id():
musa_visible_devices = list(range(num_device))
else:
musa_visible_devices = musa_visible_devices.split(',')
try:
return int(musa_visible_devices[local_rank])
except ValueError:
# handle case for Multi-Instance GPUs
# see #1148 for details
return musa_visible_devices[local_rank]
return int(musa_visible_devices[local_rank])
else:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
# TODO: return device id of npu and mlu.
Expand Down
1 change: 0 additions & 1 deletion mmengine/runner/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def autocast(device_type: Optional[str] = None,
with torch.musa.amp.autocast(
enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
yield
return
else:
# Device like MPS does not support fp16 training or testing.
# If an inappropriate device is set and fp16 is enabled, an error
Expand Down
11 changes: 4 additions & 7 deletions mmengine/runner/log_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,13 @@ def get_log_after_iter(self, runner, batch_idx: int,
log_tag.pop('time')
log_tag.pop('data_time')

# If cuda is available, the max memory occupied should be calculated.
if is_cuda_available():
max_memory = self._get_max_memory(runner)
log_str += f'memory: {max_memory} '
tag['memory'] = max_memory
# If musa is available, the max memory occupied should be calculated.
if is_musa_available():
# If cuda/musa is available,
# the max memory occupied should be calculated.
if is_cuda_available() or is_musa_available():
max_memory = self._get_max_memory(runner)
log_str += f'memory: {max_memory} '
tag['memory'] = max_memory

# Loop left keys to fill `log_str`.
if mode in ('train', 'val'):
log_items = []
Expand Down
5 changes: 3 additions & 2 deletions mmengine/runner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch.utils.data import DataLoader

from mmengine.device import is_musa_available
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
from mmengine.utils import digit_version, is_list_of
Expand Down Expand Up @@ -70,7 +70,8 @@ def set_random_seed(seed: Optional[int] = None,
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if is_cuda_available():
torch.cuda.manual_seed_all(seed)
if is_musa_available():
torch.musa.manual_seed_all(seed)
# os.environ['PYTHONHASHSEED'] = str(seed)
Expand Down
2 changes: 1 addition & 1 deletion mmengine/structures/base_data_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def cuda(self) -> 'BaseDataElement':

# Tensor-like methods
def musa(self) -> 'BaseDataElement':
"""Convert all tensors to GPU in data."""
"""Convert all tensors to musa in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
Expand Down
4 changes: 2 additions & 2 deletions mmengine/utils/dl_utils/collect_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch

import mmengine
from mmengine.device import is_musa_available
from mmengine.device import is_cuda_available, is_musa_available
from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch


Expand Down Expand Up @@ -57,7 +57,7 @@ def collect_env():
env_info['sys.platform'] = sys.platform
env_info['Python'] = sys.version.replace('\n', '')

cuda_available = torch.cuda.is_available()
cuda_available = is_cuda_available()
musa_available = is_musa_available()
env_info['CUDA available'] = cuda_available
env_info['MUSA available'] = musa_available
Expand Down
6 changes: 3 additions & 3 deletions mmengine/utils/dl_utils/time_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from mmengine.device import is_musa_available
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log

Expand Down Expand Up @@ -86,7 +86,7 @@ def wrapper(*args, **kwargs):
self.__count += 1

if self.with_sync:
if torch.cuda.is_available():
if is_cuda_available():
torch.cuda.synchronize()
elif is_musa_available():
torch.musa.synchronize()
Expand All @@ -95,7 +95,7 @@ def wrapper(*args, **kwargs):
result = fn(*args, **kwargs)

if self.with_sync:
if torch.cuda.is_available():
if is_cuda_available():
torch.cuda.synchronize()
elif is_musa_available():
torch.musa.synchronize()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_runner/test_log_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from parameterized import parameterized

from mmengine.device import is_musa_available
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
from mmengine.runner import LogProcessor
from mmengine.testing import RunnerTestCase
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")

if torch.cuda.is_available() or is_musa_available():
if is_cuda_available() or is_musa_available():
log_str += 'memory: 100 '
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")

if torch.cuda.is_available() or is_musa_available():
if is_cuda_available() or is_musa_available():
log_str += 'memory: 100 '

if mode == 'train':
Expand Down

0 comments on commit 0a15a75

Please sign in to comment.