Commit a501e56: GlobalPowerLimitOptimizer FSDP example (#147)
Co-authored-by: Jae-Won Chung <[email protected]>
1 parent 3875315
Showing 5 changed files with 330 additions and 30 deletions.
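The main addition is an example training script that integrates Zeus's GlobalPowerLimitOptimizer into the standard PyTorch FSDP MNIST example: each rank creates a ZeusMonitor for its own GPU, and the optimizer's on_epoch_begin/on_step_begin/on_step_end callbacks are invoked around the training loop so the GPU power limit can be profiled and then fixed at the selected optimum. The training script is shown below; a brief launch note follows the listing.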
@@ -0,0 +1,199 @@
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torch.optim.lr_scheduler import StepLR

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer
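# ZeusMonitor measures the time and energy consumption of the GPU(s) it is
# given, and GlobalPowerLimitOptimizer uses those measurements to profile
# training under different GPU power limits before settling on the optimal
# one. It is driven entirely through the on_epoch_*/on_step_* callbacks
# called in the training loop below.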

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def train(args, model, rank, world_size, train_loader, optimizer, epoch, plo, sampler=None):
    model.train()
    ddp_loss = torch.zeros(2).to(rank)
    if sampler:
        sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
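        # Step boundary callbacks for the power limit optimizer: on_step_begin()
        # here and on_step_end() after the optimizer step let it measure the
        # time and energy of each training step under the power limit
        # currently being profiled.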
        plo.on_step_begin()

        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, reduction='sum')
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)

        plo.on_step_end()

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

def test(model, rank, world_size, test_loader):
    model.eval()
    ddp_loss = torch.zeros(3).to(rank)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(rank), target.to(rank)
            output = model(data)
            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        test_loss = ddp_loss[0] / ddp_loss[2]
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
            100. * ddp_loss[1] / ddp_loss[2]))

def fsdp_main(args):
    # If the user wants to explicitly set MASTER_ADDR and MASTER_PORT:
    if args.master_addr is not None:
        os.environ['MASTER_ADDR'] = args.master_addr
    if args.master_port is not None:
        os.environ['MASTER_PORT'] = args.master_port

    # The following environment variables are provided by torchrun:
    # MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE
    # We can now initialize the process group using these env variables.
    dist.init_process_group(backend="nccl", init_method="env://")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # torchrun provides the local rank via the LOCAL_RANK environment variable;
    # fall back to the --local-rank argument for launchers that pass it.
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))

    # Set the device using local rank
    torch.cuda.set_device(local_rank)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST('./data', train=False, transform=transform)

    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

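    # Wrap the whole model as a single FSDP unit. For larger models, an
    # auto-wrap policy (e.g. the size_based_auto_wrap_policy imported above)
    # could be passed to FSDP to shard at submodule granularity.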
    model = Net().to(local_rank)
    model = FSDP(model)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)
    init_start_event.record()

    # Init ZeusMonitor and GPLO
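    # Each rank monitors only the GPU it drives (gpu_indices=[local_rank]).
    # During the first training steps the optimizer profiles time and energy
    # under candidate power limits (profile_steps controls the profiling
    # window), then fixes the limit its optimum selector picks.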
    monitor = ZeusMonitor(gpu_indices=[local_rank])
    plo = GlobalPowerLimitOptimizer(monitor, profile_steps=200)

    for epoch in range(1, args.epochs + 1):
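        # Notify the power limit optimizer of epoch boundaries.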
        plo.on_epoch_begin()
        train(args, model, local_rank, world_size, train_loader, optimizer, epoch, plo, sampler=sampler1)
        plo.on_epoch_end()

        test(model, local_rank, world_size, test_loader)
        scheduler.step()

    init_end_event.record()

    if rank == 0:
        print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
        print(f"{model}")

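    # Gather the model state dict and write the checkpoint from rank 0 only.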
    if args.save_model:
        dist.barrier()
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    dist.destroy_process_group()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch MNIST FSDP with torchrun')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--master_addr', type=str, default=None,
                        help='Master address for distributed training (optional, otherwise taken from env)')
    parser.add_argument('--master_port', type=str, default=None,
                        help='Master port for distributed training (optional, otherwise taken from env)')
    parser.add_argument('--local-rank', type=int, default=0,
                        help='Local rank for the process (fallback when LOCAL_RANK is not set by the launcher)')

    args = parser.parse_args()
    torch.manual_seed(args.seed)

    fsdp_main(args)
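Launch note: the script expects to be started with torchrun so that MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK are set for every process, for example torchrun --nnodes 1 --nproc_per_node <num_gpus> train_fsdp.py --epochs 10 (the script name here is illustrative; use the path this file is committed under). Setting GPU power limits through NVML typically requires elevated privileges (root or the SYS_ADMIN capability), so the optimizer can only apply the limits it selects when run with those privileges.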