[2.5] Remove test/predict in lightning fl example (#3185)
### Issue

When using the external process launcher (SubprocessLauncher):

(1) In the last round, the CJ considers the job finished right after
"fit" is called, so the predict/test stage of the external program is
killed partway through.

(2) The standalone Lightning training script
(`examples/hello-world/ml-to-fl/pt/src/cifar10_lightning_original.py`)
does not work because of its import (`from src.net import Net`); a minimal
sketch of the failure mode follows.
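
To make (2) concrete, here is a small standalone sketch of the failure mode (not part of this commit). It only assumes the standalone script is launched from inside the `src/` directory, where no `src` package is on `sys.path`:

```python
# Sketch only: what `from src.net import Net` amounts to when
# cifar10_lightning_original.py is run from inside
# examples/hello-world/ml-to-fl/pt/src/ itself.
import importlib

try:
    importlib.import_module("src.net")  # same resolution as `from src.net import Net`
except ModuleNotFoundError as err:
    # No "src" package is importable when src/ is the working directory,
    # so the standalone run dies before training starts.
    print(f"standalone import fails: {err}")
```

Defining `Net` directly inside `lit_net.py` (see the lit_net.py hunk below) removes this dependency on how the script is launched.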

### Description

- Remove the test/predict stage, as it is not required (see the sketch after
this list).
- Move the Net definition into the lit_net.py file so the import is no longer
an issue when running standalone Lightning training.
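
For context, a hedged sketch of the per-round loop the FL scripts keep after this change: `fit` is the last thing that runs in a round, so nothing remains to be cut off when the job ends after the final round. It assumes NVFlare's Lightning client API (`nvflare.client.lightning`: `patch`, `receive`, `is_running`) as used in the ml-to-fl examples; see `cifar10_lightning_fl.py` for the exact code.

```python
# Illustrative sketch, not the exact contents of cifar10_lightning_fl.py.
import nvflare.client.lightning as flare
from pytorch_lightning import Trainer

# CIFAR10DataModule is defined inside the example script itself; importing it
# here is just to keep the sketch self-contained.
from cifar10_lightning_fl import CIFAR10DataModule
from lit_net import LitNet  # Net now lives inside lit_net.py


def main():
    model = LitNet()
    cifar10_dm = CIFAR10DataModule()
    trainer = Trainer(max_epochs=1)
    flare.patch(trainer)  # wires FL model exchange into fit()/validate()

    while flare.is_running():
        input_model = flare.receive()  # global model for the current round
        print(f"--- round {input_model.current_round} ---")

        trainer.validate(model, datamodule=cifar10_dm)  # metrics on the received global model
        trainer.fit(model, datamodule=cifar10_dm)       # local training
        # trainer.test()/trainer.predict() used to follow here; in the last
        # round the launcher may stop the process right after fit(), so the
        # test/predict stage was removed (issue (1) above).


if __name__ == "__main__":
    main()
```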

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
YuanTingHsieh authored Jan 28, 2025
1 parent e375c1e commit 978a42d
Showing 5 changed files with 26 additions and 86 deletions.
21 changes: 0 additions & 21 deletions examples/hello-world/ml-to-fl/pt/src/cifar10_lightning_ddp_fl.py
@@ -39,7 +39,6 @@ def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):

def prepare_data(self):
torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders
@@ -49,24 +48,12 @@ def setup(self, stage: str):
)
self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

# Assign test dataset for use in dataloader(s)
if stage == "test" or stage == "predict":
self.cifar_test = torchvision.datasets.CIFAR10(
root=self.data_dir, train=False, download=False, transform=transform
)

def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)

def predict_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
model = LitNet()
@@ -95,14 +82,6 @@ def main():
print("--- train new model ---")
trainer.fit(model, datamodule=cifar10_dm)

# test local model
print("--- test new model ---")
trainer.test(ckpt_path="best", datamodule=cifar10_dm)

# get predictions
print("--- prediction with new best model ---")
trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
main()
20 changes: 0 additions & 20 deletions
@@ -46,24 +46,12 @@ def setup(self, stage: str):
)
self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

# Assign test dataset for use in dataloader(s)
if stage == "test" or stage == "predict":
self.cifar_test = torchvision.datasets.CIFAR10(
root=self.data_dir, train=False, download=False, transform=transform
)

def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)

def predict_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
model = LitNet()
@@ -75,14 +63,6 @@ def main():
print("--- train new model ---")
trainer.fit(model, datamodule=cifar10_dm)

# test local model
print("--- test new model ---")
trainer.test(ckpt_path="best", datamodule=cifar10_dm)

# get predictions
print("--- prediction with new best model ---")
trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
main()
21 changes: 0 additions & 21 deletions examples/hello-world/ml-to-fl/pt/src/cifar10_lightning_fl.py
@@ -39,7 +39,6 @@ def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):

def prepare_data(self):
torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders
@@ -49,24 +48,12 @@ def setup(self, stage: str):
)
self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

# Assign test dataset for use in dataloader(s)
if stage == "test" or stage == "predict":
self.cifar_test = torchvision.datasets.CIFAR10(
root=self.data_dir, train=False, download=False, transform=transform
)

def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)

def predict_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
model = LitNet()
@@ -95,14 +82,6 @@ def main():
print("--- train new model ---")
trainer.fit(model, datamodule=cifar10_dm)

# test local model
print("--- test new model ---")
trainer.test(ckpt_path="best", datamodule=cifar10_dm)

# get predictions
print("--- prediction with new best model ---")
trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
main()
27 changes: 4 additions & 23 deletions examples/hello-world/ml-to-fl/pt/src/cifar10_lightning_original.py
@@ -36,7 +36,6 @@ def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):

def prepare_data(self):
torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders
@@ -46,43 +45,25 @@ def setup(self, stage: str):
)
self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

# Assign test dataset for use in dataloader(s)
if stage == "test" or stage == "predict":
self.cifar_test = torchvision.datasets.CIFAR10(
root=self.data_dir, train=False, download=False, transform=transform
)

def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)

def predict_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
model = LitNet()
cifar10_dm = CIFAR10DataModule()

trainer = Trainer(max_epochs=1, devices=1 if torch.cuda.is_available() else None)
if torch.cuda.is_available():
trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1 if torch.cuda.is_available() else None)
else:
trainer = Trainer(max_epochs=1, devices=None)

# perform local training
print("--- train new model ---")
trainer.fit(model, datamodule=cifar10_dm)

# test local model
print("--- test new model ---")
trainer.test(ckpt_path="best", datamodule=cifar10_dm)

# get predictions
print("--- prediction with new best model ---")
trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
main()
23 changes: 22 additions & 1 deletion examples/hello-world/ml-to-fl/pt/src/lit_net.py
@@ -14,16 +14,37 @@

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from src.net import Net
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()


class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


class LitNet(LightningModule):
def __init__(self):
super().__init__()
