Skip to content

Commit

Permalink
Feature: Add Jax support in CUDA sync and ZeusMonitor (#97)
Browse files Browse the repository at this point in the history
Co-authored-by: Jae-Won Chung <[email protected]>
  • Loading branch information
HGangloff and jaywonchung authored Jul 11, 2024
1 parent defed56 commit 37cfa1d
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 9 deletions.
35 changes: 35 additions & 0 deletions examples/jax/simple_monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import time
import jax
import jax.numpy as jnp
from zeus.monitor import ZeusMonitor

@jax.jit
def mat_prod(B: jax.Array) -> jax.Array:
    """Multiply a 1000x1000 matrix of ones by `B`.

    Dummy workload used to warm up the GPU; the function is wrapped in
    `jax.jit`, so it is compiled on first use.

    Args:
        B: Right-hand operand of the matrix product. Its leading dimension
            must be 1000 to be compatible with the all-ones matrix.

    Returns:
        The product `A @ B`, where `A` is `jnp.ones((1000, 1000))`.
    """
    A = jnp.ones((1000, 1000))
    return A @ B

if __name__ == "__main__":
    # Time/Energy measurements for every GPU listed here begin and end at the
    # same time. This example monitors a single GPU (index 0).
    gpu_indices = [0]

    # `backend="jax"` selects JAX as the framework used to synchronize GPU
    # computations at window boundaries.
    monitor = ZeusMonitor(gpu_indices=gpu_indices, backend="jax")

    # Mark the beginning of a measurement window. You can use any string
    # as the window name, but make sure it's unique.
    monitor.begin_window("all_computations")

    # Actual work: repeatedly apply the jitted matrix product.
    key = jax.random.PRNGKey(0)
    B = jax.random.uniform(key, (1000, 1000))
    for _ in range(100):
        B = mat_prod(B)

    # Mark the end of a measurement window and retrieve the measurement result.
    result = monitor.end_window("all_computations")

    # Print the measurement result.
    print(f"Training took {result.time} seconds.")
    print(f"Training consumed {result.total_energy} Joules.")
    for gpu_idx, gpu_energy in result.gpu_energy.items():
        print(f"GPU {gpu_idx} consumed {gpu_energy} Joules.")
9 changes: 7 additions & 2 deletions zeus/monitor/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import os
import warnings
from typing import Literal
from time import time
from pathlib import Path
from dataclasses import dataclass
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(
cpu_indices: list[int] | None = None,
approx_instant_energy: bool = False,
log_file: str | Path | None = None,
backend: Literal["torch", "jax"] = "torch",
) -> None:
"""Instantiate the monitor.
Expand All @@ -179,9 +181,12 @@ def __init__(
instantaneous power consumption with the window's execution time. This should
be a better estimate than zero, but it's still an approximation.
log_file: Path to the log CSV file. If `None`, logging will be disabled.
backend: Deep learning framework to use to synchronize GPU computations.
Defaults to `"torch"`, in which case `torch.cuda.synchronize` will be used.
"""
# Save arguments.
self.approx_instant_energy = approx_instant_energy
self.backend: Literal["torch", "jax"] = backend

# Get gpus
try:
Expand Down Expand Up @@ -274,7 +279,7 @@ def begin_window(self, key: str, sync_cuda: bool = True) -> None:
# Call cudaSynchronize to make sure we freeze at the right time.
if sync_cuda:
for gpu_index in self.gpu_indices:
cuda_sync(gpu_index)
cuda_sync(gpu_index, self.backend)

# Freeze the start time of the profiling window.
timestamp: float = time()
Expand Down Expand Up @@ -338,7 +343,7 @@ def end_window(
# Call cudaSynchronize to make sure we freeze at the right time.
if sync_cuda:
for gpu_index in self.gpu_indices:
cuda_sync(gpu_index)
cuda_sync(gpu_index, self.backend)

# If the measurement window is cancelled, return an empty Measurement object.
if cancel:
Expand Down
49 changes: 42 additions & 7 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import types
from typing import Literal
from functools import lru_cache

from zeus.utils.logging import get_logger
Expand All @@ -26,7 +27,7 @@


@lru_cache(maxsize=1)
def torch_is_available():
def torch_is_available(ensure_available: bool = False):
"""Check if PyTorch is available."""
try:
import torch
Expand All @@ -37,23 +38,57 @@ def torch_is_available():
MODULE_CACHE["torch"] = torch
logger.info("PyTorch with CUDA support is available.")
return True
except ImportError:
except ImportError as e:
logger.info("PyTorch is not available.")
if ensure_available:
raise RuntimeError("Failed to import Pytorch") from e
return False


def cuda_sync(device: int | None = None) -> None:
"""Synchronize CPU and CUDA.
@lru_cache(maxsize=1)
def jax_is_available(ensure_available: bool = False) -> bool:
    """Check if JAX is available.

    On success, the imported `jax` module is stored in `MODULE_CACHE["jax"]`.

    Args:
        ensure_available: If `True`, raise a `RuntimeError` instead of
            returning `False` when JAX cannot be imported.

    Returns:
        `True` if JAX was imported and reports at least one GPU device,
        `False` if JAX cannot be imported and `ensure_available` is `False`.

    Raises:
        RuntimeError: If JAX cannot be imported and `ensure_available` is
            `True`, or if JAX is importable but reports no GPU device.
    """
    # Keep the `try` body minimal: only the import can raise `ImportError`.
    try:
        import jax  # type: ignore
    except ImportError as e:
        logger.info("JAX is not available")
        if ensure_available:
            raise RuntimeError("Failed to import JAX") from e
        return False

    # Use an explicit check rather than `assert`, which is stripped when
    # Python runs with `-O` and would let a GPU-less JAX pass as available.
    if not jax.devices("gpu"):
        raise RuntimeError("JAX is available but does not have CUDA support.")

    MODULE_CACHE["jax"] = jax
    logger.info("JAX with CUDA support is available.")
    return True


def cuda_sync(
    device: int | None = None, backend: Literal["torch", "jax"] = "torch"
) -> None:
    """Synchronize CPU with CUDA.

    Note: `cupy.cuda.Device.synchronize` may be a good choice to make
        CUDA device synchronization more general. Haven't tested it yet.

    Args:
        device: The device to synchronize.
        backend: Deep learning framework to use to synchronize GPU computations.
            Defaults to `"torch"`, in which case `torch.cuda.synchronize` will be used.

    Raises:
        RuntimeError: If the requested backend's framework cannot be imported,
            or if `backend` is not one of the supported values.
    """
    if backend == "torch" and torch_is_available(ensure_available=True):
        torch = MODULE_CACHE["torch"]
        torch.cuda.synchronize(device)

    elif backend == "jax" and jax_is_available(ensure_available=True):
        jax = MODULE_CACHE["jax"]
        target = None if device is None else jax.devices("gpu")[device]

        # Place a scalar on the target device, run a trivial addition on it,
        # and block until that result is ready.
        # NOTE(review): presumably this also waits for work dispatched to the
        # device earlier -- confirm against JAX's asynchronous dispatch docs.
        (jax.device_put(0.0, device=target) + 0).block_until_ready()

    else:
        # Reached when `backend` is neither "torch" nor "jax"; name the
        # offending value so the caller can tell what went wrong.
        raise RuntimeError(f"No framework is available for backend {backend!r}.")

0 comments on commit 37cfa1d

Please sign in to comment.