GlobalPowerLimitOptimizer FSDP example (#147)
Co-authored-by: Jae-Won Chung <[email protected]>
parthraut and jaywonchung authored Jan 26, 2025
1 parent 3875315 commit a501e56
Showing 5 changed files with 330 additions and 30 deletions.
26 changes: 23 additions & 3 deletions examples/power_limit_optimizer/README.md
@@ -2,8 +2,9 @@

This example demonstrates how to integrate Zeus with `torchvision` and the ImageNet dataset.

[`train_single.py`](train_single.py) and [`train_dp.py`](train_dp.py) were adapted and simplified from [PyTorch's example training code for ImageNet](https://github.com/pytorch/examples/blob/main/imagenet/main.py). [`train_fsdp.py`](train_fsdp.py) was adapted from [Getting Started with Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html).

[`train_single.py`](train_single.py) is for simple single-GPU training, [`train_dp.py`](train_dp.py) is for data parallel training with PyTorch DDP and [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html), and [`train_fsdp.py`](train_fsdp.py) is for Fully Sharded Data Parallel (FSDP) training.

## Dependencies

@@ -23,6 +24,17 @@ You just need to download and extract the ImageNet data and mount it to the Dock
- [`ZeusMonitor`](http://ml.energy/zeus/reference/monitor/#zeus.monitor.ZeusMonitor): Measures the GPU time and energy consumption of arbitrary code blocks (see the short sketch after this list).
- [`GlobalPowerLimitOptimizer`](https://ml.energy/zeus/reference/optimizer/power_limit/#zeus.optimizer.power_limit.GlobalPowerLimitOptimizer): Online-profiles each power limit with `ZeusMonitor` and finds the cost-optimal power limit.
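
As a quick illustration of the measurement API, the following is a minimal sketch; it assumes the `begin_window`/`end_window` interface of `ZeusMonitor`, and the window name and GPU index are arbitrary:

```python
from zeus.monitor import ZeusMonitor

monitor = ZeusMonitor(gpu_indices=[0])  # Measure GPU 0 only.

monitor.begin_window("block")
# ... arbitrary code to measure, e.g. a few training steps ...
measurement = monitor.end_window("block")

print(f"Time (s): {measurement.time}, Energy (J): {measurement.total_energy}")
```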

## Multi-GPU Distributed Training (PyTorch DDP and FSDP)

When using `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` in a multi-GPU distributed context, construct one instance of each per local rank (one per GPU on each node), and pass the local rank to `ZeusMonitor` as shown below:

```python
monitor = ZeusMonitor(gpu_indices=[local_rank]) # pass in local rank to gpu_indices.
plo = GlobalPowerLimitOptimizer(monitor)
```

Ensure that only one GPU is monitored per `ZeusMonitor`. Internally, `GlobalPowerLimitOptimizer` performs an [AllReduce](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html) to aggregate time and energy measurements across all GPUs before making a power limit decision.
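
Putting this together, below is a minimal sketch of how the per-rank optimizer hooks into a training loop. It is not a complete script: `local_rank`, `num_epochs`, `train_loader`, `model`, and `train_one_step` are placeholders for your own setup. See [`train_dp.py`](train_dp.py) and [`train_fsdp.py`](train_fsdp.py) for complete examples.

```python
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

# One monitor and one optimizer per process, each watching only its own GPU.
monitor = ZeusMonitor(gpu_indices=[local_rank])
plo = GlobalPowerLimitOptimizer(monitor)

for epoch in range(num_epochs):
    plo.on_epoch_begin()
    for batch in train_loader:
        plo.on_step_begin()           # Mark the beginning of the training step.
        train_one_step(model, batch)  # Your forward/backward/optimizer step.
        plo.on_step_end()             # Mark the end of the training step.
    plo.on_epoch_end()
```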

## Example commands

You can specify the maximum training time slowdown factor (1.0 means no slowdown) by setting `ZEUS_MAX_SLOWDOWN`. The default is set to 1.1 in this example script, meaning the lowest power limit that keeps training time inflation within 10% will be automatically found.
@@ -34,11 +46,19 @@ python train_single.py \
[DATA_DIR] \
--gpu 0 `# Specify the GPU id to use`

# Multi-GPU Distributed Data Parallel
torchrun \
--nnodes 1 \
--nproc_per_node gpu `# Number of processes per node, should be equal to the number of GPUs.` \
`# When set to 'gpu', it means use all the GPUs available.` \
train_dp.py \
[DATA_DIR]

# Multi-GPU Fully Sharded Data Parallel
torchrun \
--nnodes 1 \
--nproc_per_node=gpu `# Number of processes per node, should be equal to the number of GPUs.` \
train_fsdp.py \
[DATA_DIR]
```
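
To use a different slowdown budget, set `ZEUS_MAX_SLOWDOWN` when launching the script. For example, to allow up to a 20% training time slowdown with the single-GPU script:

```sh
ZEUS_MAX_SLOWDOWN=1.2 python train_single.py [DATA_DIR] --gpu 0
```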
41 changes: 17 additions & 24 deletions examples/power_limit_optimizer/train_dp.py
@@ -17,13 +17,11 @@
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import Subset

# ZEUS
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import MaxSlowdownConstraint, GlobalPowerLimitOptimizer
from zeus.utils.env import get_env
from zeus.callback import Callback, CallbackSet


def parse_args() -> argparse.Namespace:
@@ -197,37 +195,32 @@ def main():
sampler=val_sampler,
)

# The rank 0 process will monitor and optimize the power limit of all GPUs.
if args.gpu == 0:
callback_set: list[Callback] = [
GlobalPowerLimitOptimizer(
monitor=ZeusMonitor(gpu_indices=None), # All visible GPUs.
optimum_selector=MaxSlowdownConstraint(
factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
),
warmup_steps=10,
profile_steps=40,
pl_step=25,
)
]
else:
callback_set = []
callbacks = CallbackSet(callback_set)
# ZEUS
plo = GlobalPowerLimitOptimizer(
# Each process manages and monitors exactly one GPU in DDP training.
monitor=ZeusMonitor(gpu_indices=[args.gpu]),
optimum_selector=MaxSlowdownConstraint(
factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
),
warmup_steps=10,
profile_steps=40,
pl_step=25,
)

for epoch in range(args.epochs):
train_sampler.set_epoch(epoch)

callbacks.on_epoch_begin()
train(train_loader, model, criterion, optimizer, epoch, args, callbacks)
callbacks.on_epoch_end()
plo.on_epoch_begin()
train(train_loader, model, criterion, optimizer, epoch, args, plo)
plo.on_epoch_end()

acc1 = validate(val_loader, model, criterion, args)
print(f"Top-1 accuracy: {acc1}")

scheduler.step()


def train(train_loader, model, criterion, optimizer, epoch, args, callbacks):
def train(train_loader, model, criterion, optimizer, epoch, args, plo):
batch_time = AverageMeter("Time", ":6.3f")
data_time = AverageMeter("Data", ":6.3f")
losses = AverageMeter("Loss", ":.4e")
@@ -245,7 +238,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args, callbacks):

end = time.time()
for i, (images, target) in enumerate(train_loader):
callbacks.on_step_begin() # Mark the beginning of the training step.
plo.on_step_begin() # Mark the beginning of the training step.

# Load data to GPU
images = images.cuda(args.gpu)
Expand Down Expand Up @@ -273,7 +266,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args, callbacks):
batch_time.update(time.time() - end)
end = time.time()

callbacks.on_step_end()
plo.on_step_end()

if i % args.print_freq == 0:
progress.display(i + 1)
199 changes: 199 additions & 0 deletions examples/power_limit_optimizer/train_fsdp.py
@@ -0,0 +1,199 @@
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torch.optim.lr_scheduler import StepLR

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output

def train(args, model, rank, world_size, train_loader, optimizer, epoch, plo, sampler=None):
model.train()
ddp_loss = torch.zeros(2).to(rank)
if sampler:
sampler.set_epoch(epoch)
for batch_idx, (data, target) in enumerate(train_loader):
plo.on_step_begin()

data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target, reduction='sum')
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(data)

plo.on_step_end()

dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

def test(model, rank, world_size, test_loader):
model.eval()
ddp_loss = torch.zeros(3).to(rank)
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(rank), target.to(rank)
output = model(data)
ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
ddp_loss[2] += len(data)

dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
test_loss = ddp_loss[0] / ddp_loss[2]
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
100. * ddp_loss[1] / ddp_loss[2]))

def fsdp_main(args):
# If the user wants to explicitly set MASTER_ADDR and MASTER_PORT:
if args.master_addr is not None:
os.environ['MASTER_ADDR'] = args.master_addr
if args.master_port is not None:
os.environ['MASTER_PORT'] = args.master_port

# The following environment variables are provided by torchrun:
# MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE
# We can now initialize the process group using these env variables.
dist.init_process_group(backend="nccl", init_method="env://")

rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = args.local_rank # Get local rank from the arguments

# Set the device using local rank
torch.cuda.set_device(local_rank)

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('./data', train=False, transform=transform)

sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(local_rank)
model = FSDP(model)

optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

init_start_event = torch.cuda.Event(enable_timing=True)
init_end_event = torch.cuda.Event(enable_timing=True)
init_start_event.record()

# Init ZeusMonitor and GPLO
monitor = ZeusMonitor(gpu_indices=[local_rank])
plo = GlobalPowerLimitOptimizer(monitor, profile_steps=200)

for epoch in range(1, args.epochs + 1):
plo.on_epoch_begin()
train(args, model, local_rank, world_size, train_loader, optimizer, epoch, plo, sampler=sampler1)
plo.on_epoch_end()

test(model, local_rank, world_size, test_loader)
scheduler.step()

init_end_event.record()

if rank == 0:
print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
print(f"{model}")

if args.save_model:
dist.barrier()
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")

dist.destroy_process_group()

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch MNIST FSDP with torchrun')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
parser.add_argument('--master_addr', type=str, default=None,
help='Master address for distributed training (optional, otherwise taken from env)')
parser.add_argument('--master_port', type=str, default=None,
help='Master port for distributed training (optional, otherwise taken from env)')
parser.add_argument('--local-rank', type=int, default=0,
help='Local rank for the process (required for torchrun)')

args = parser.parse_args()
torch.manual_seed(args.seed)

fsdp_main(args)