Skip to content

Commit

Permalink
[1/4] Add get_device_stats to accelerator interface (#9586)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniellepintz authored Sep 27, 2021
1 parent 83d83ab commit ab06987
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))


- Added `get_device_stats` to the Accelerator Interface and added its implementation for GPU and TPU ([#9586](https://github.com/PyTorchLightning/pytorch-lightning/pull/9586))


- Added `multifile` option to `LightningCLI` to enable/disable config save to preserve multiple files structure ([#9073](https://github.com/PyTorchLightning/pytorch-lightning/pull/9073))


Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Accelerator:
- CPU
- GPU
- TPU
- IPU
Each Accelerator gets two plugins upon initialization:
One to handle differences from the training routine and one to handle different precisions.
Expand Down Expand Up @@ -422,6 +423,17 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool:
"""
return self.training_type_plugin.restore_checkpoint_after_pre_dispatch

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for a given device.
Args:
device: device for which to get stats
Returns:
Dictionary of device stats
"""
raise NotImplementedError

def on_train_start(self) -> None:
"""Called when train begins."""
return self.training_type_plugin.on_train_start()
Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -29,3 +33,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")

return super().setup(trainer)

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Returns dummy implementation for now."""
return {}
81 changes: 81 additions & 0 deletions pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# limitations under the License.
import logging
import os
import shutil
import subprocess
from typing import Any, Dict, List, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,6 +57,83 @@ def set_nvidia_flags(local_rank: int) -> None:
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for the given GPU device.
Args:
device: GPU device for which to get stats
Returns:
A dictionary mapping the metrics to their values.
Raises:
FileNotFoundError:
If nvidia-smi installation not found
"""
if _TORCH_GREATER_EQUAL_1_8:
return torch.cuda.memory_stats(device)
return _get_nvidia_gpu_stats(device)

def teardown(self) -> None:
super().teardown()
self._move_optimizer_state(torch.device("cpu"))


def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
Args:
device: GPU device for which to get stats
Returns:
A dictionary mapping the metrics to their values.
Raises:
FileNotFoundError:
If nvidia-smi installation not found
"""
gpu_stat_metrics = [
("utilization.gpu", "%"),
("memory.used", "MB"),
("memory.free", "MB"),
("utilization.memory", "%"),
("fan.speed", "%"),
("temperature.gpu", "°C"),
("temperature.memory", "°C"),
]
gpu_stat_keys = [k for k, _ in gpu_stat_metrics]
gpu_query = ",".join(gpu_stat_keys)

gpu_id = _get_gpu_id(device.index)
nvidia_smi_path = shutil.which("nvidia-smi")
if nvidia_smi_path is None:
raise FileNotFoundError("nvidia-smi: command not found")
result = subprocess.run(
[nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
encoding="utf-8",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE, # for backward compatibility with python version 3.6
check=True,
)

def _to_float(x: str) -> float:
try:
return float(x)
except ValueError:
return 0.0

s = result.stdout.strip()
stats = [_to_float(x) for x in s.split(", ")]

gpu_stats = {}
for i, (x, unit) in enumerate(gpu_stat_metrics):
gpu_stats[f"{x} ({unit})"] = stats[i]
return gpu_stats


def _get_gpu_id(device_id: int) -> str:
"""Get the unmasked real GPU IDs."""
# All devices if `CUDA_VISIBLE_DEVICES` unset
default = ",".join(str(i) for i in range(torch.cuda.device_count()))
cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
return cuda_visible_devices[device_id].strip()
20 changes: 19 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch.optim import Optimizer
Expand Down Expand Up @@ -61,3 +61,21 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
for opt in self.optimizers:
for p, v in opt.state.items():
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for the given TPU device.
Args:
device: TPU device for which to get stats
Returns:
A dictionary mapping the metrics (free memory and peak memory) to their values.
"""
memory_info = xm.get_memory_info(device)
free_memory = memory_info["kb_free"]
peak_memory = memory_info["kb_total"] - free_memory
device_stats = {
"avg. free memory (MB)": free_memory,
"avg. peak memory (MB)": peak_memory,
}
return device_stats
36 changes: 36 additions & 0 deletions tests/accelerators/test_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch

from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from tests.helpers.runif import RunIf


@RunIf(min_torch="1.8")
@RunIf(min_gpus=1)
def test_get_torch_gpu_stats(tmpdir):
"""Test GPU get_device_stats with Pytorch >= 1.8.0."""
current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
GPUAccel = GPUAccelerator(
training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
)
gpu_stats = GPUAccel.get_device_stats(current_device)
fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]

for f in fields:
assert any(f in h for h in gpu_stats.keys())


@RunIf(max_torch="1.7")
@RunIf(min_gpus=1)
def test_get_nvidia_gpu_stats(tmpdir):
"""Test GPU get_device_stats with Pytorch < 1.8.0."""
current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
GPUAccel = GPUAccelerator(
training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
)
gpu_stats = GPUAccel.get_device_stats(current_device)
fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]

for f in fields:
assert any(f in h for h in gpu_stats.keys())
16 changes: 16 additions & 0 deletions tests/accelerators/test_tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.plugins import SingleTPUPlugin
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from tests.helpers.runif import RunIf


@RunIf(tpu=True)
def test_device_stats_tpu(tmpdir):
"""Test TPU get_device_stats."""
plugin = SingleTPUPlugin(1)
TPUAccel = TPUAccelerator(training_type_plugin=TPUSpawnPlugin(), precision_plugin=plugin)
tpu_stats = TPUAccel.get_device_stats("1")
fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]

for f in fields:
assert any(f in h for h in tpu_stats.keys())

0 comments on commit ab06987

Please sign in to comment.