Add flower script #2754
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Flower App (PyTorch) in NVIDIA FLARE | ||
|
||
In this example, we run two Flower clients and a Flower server in parallel using NVFlare's simulator with the Job API. | ||
|
||
## Preconditions | ||
|
||
To run Flower code in NVFlare, install the requirements and run the script below. | ||
|
||
Note, this code is directly copied from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example. | ||
|
||
## Install dependencies | ||
To run this job with NVFlare, we first need to install the dependencies. | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Run a simulation | ||
|
||
Next, we run two Flower clients and a Flower server in parallel using NVFlare's simulator. | ||
```bash | ||
python fedavg_flwr_cifar10.py | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from flwr.client import ClientApp, NumPyClient | ||
from flwr.common import Context | ||
from task import DEVICE, Net, get_weights, load_data, set_weights, test, train | ||
|
||
# Load model and data (simple CNN, CIFAR-10) | ||
net = Net().to(DEVICE) | ||
trainloader, testloader = load_data() | ||
|
||
|
||
# Define FlowerClient and client_fn | ||
class FlowerClient(NumPyClient): | ||
    def fit(self, parameters, config): | ||
        set_weights(net, parameters) | ||
        results = train(net, trainloader, testloader, epochs=1, device=DEVICE) | ||
        return get_weights(net), len(trainloader.dataset), results | ||
|
||
    def evaluate(self, parameters, config): | ||
        set_weights(net, parameters) | ||
        loss, accuracy = test(net, testloader) | ||
        return loss, len(testloader.dataset), {"accuracy": accuracy} | ||
|
||
|
||
def client_fn(context: Context): | ||
    """Create and return an instance of Flower `Client`.""" | ||
    return FlowerClient().to_client() | ||
|
||
|
||
# Flower ClientApp | ||
app = ClientApp( | ||
    client_fn=client_fn, | ||
) | ||
|
||
|
||
# Legacy mode | ||
if __name__ == "__main__": | ||
    from flwr.client import start_client | ||
|
||
    start_client( | ||
        server_address="127.0.0.1:8080", | ||
        client=FlowerClient().to_client(), | ||
    ) |
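The `app` object defined in this file is what the job script later refers to with the string `client:app`, which follows the `<module>:<attribute>` convention. The sketch below shows one way such a reference can be resolved with the standard library. It is only an illustration of the convention: `load_app` is a hypothetical helper, not the loader that Flower or NVFlare actually uses.

```python
import importlib


def load_app(ref: str):
    """Resolve a '<module>:<attribute>' reference to the object it names.

    Hypothetical helper for illustration only.
    """
    module_name, sep, attr_name = ref.partition(":")
    if not sep or not module_name or not attr_name:
        raise ValueError(f"expected '<module>:<attribute>', got {ref!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Demonstrated with a standard-library module so the sketch runs anywhere.
# With client.py on the import path, load_app("client:app") would return `app`.
dumps = load_app("json:dumps")
print(dumps({"ok": True}))  # prints {"ok": true}
```

Under this convention the module must be importable at load time, which is one reason the script file has to be shipped alongside the component that names it.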
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nvflare import FedJob | ||
from nvflare.app_opt.flower.controller import FlowerController | ||
from nvflare.app_opt.flower.executor import FlowerExecutor | ||
|
||
if __name__ == "__main__": | ||
    n_clients = 2 | ||
    num_rounds = 2 | ||
|
||
    job = FedJob(name="cifar10_flwr") | ||
|
||
    # Define the controller workflow and send to server | ||
    controller = FlowerController(server_app="server:app") | ||
    job.to_server(controller) | ||
|
||
    # Add flwr server code | ||
    job.to_server("server.py") | ||
|
||
    # Add clients | ||
    executor = FlowerExecutor(client_app="client:app") | ||
    job.to_clients(executor) | ||
|
||
    # Add flwr client code | ||
    job.to_clients("client.py") | ||
This syntax is strange; I don't think we should do this. Shouldn't client.py go to the FlowerExecutor, e.g. `executor = FlowerExecutor(client_app="client:app", task_scripts="client.py")`? The syntax `Job.to(, )` seems very strange to me and, in most cases, doesn't make sense. I didn't know about it until QA highlighted it. Not every executor has a task script file. Adding special syntax that takes any string path seems to be a big assumption. I strongly suggest disabling this syntax. @yanchengnv @YuanTingHsieh @yhwen @holgerroth what are your opinions on this?
Adding @yhwen. Supporting placing of scripts to server/clients with [...] So, to me it makes sense and is an extensible solution for different use cases.
The Git rendering wiped out the original string. What I wrote was `job.to(<string path>, site_name)`. `job.to(project.toml, site-1)`: everything going to an external directory feels strange to me.
We could do it so that Python modules are supported directly, but then you still cannot add config files that might be needed by a component.
Now that I am looking at the `.to_client()` syntax, I am also confused. Which site_name does this one add to? It is supposed to be a syntactic-sugar wrapper. If the client model is intended to have a script executor, why doesn't the component take the script or config file as an argument?
It's `to_clients()`. Only use it when sending the component to all clients.
OK, `to_clients()` makes sense, but this does not address the original question.
The `job.to(thing, target)` syntax is confusing since the first argument could be many things. Perhaps we could improve it by defining multiple input arguments, one for each type of thing, like this: `job.to(scripts=[list of scripts], dirs=[list of dirs], objects=[list of objs], targets=[list of targets])`. It is possible, or maybe even desirable, to send multiple types of things in one call; you shouldn't have to call `job.to()` many times. |
||
|
||
    job.export_job("/tmp/nvflare/jobs/job_config") | ||
    job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients) |
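The last comment in the review thread on this file proposes one keyword argument per kind of thing instead of a positional `job.to(thing, target)`. A minimal sketch of that idea follows, to make the trade-off concrete. `JobSketch` and `Deployment` are hypothetical names invented for this illustration; this is not the NVFlare `FedJob` API and says nothing about how the maintainers resolved the discussion.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence


@dataclass
class Deployment:
    """Everything assigned to one target site (hypothetical container)."""

    scripts: List[str] = field(default_factory=list)
    dirs: List[str] = field(default_factory=list)
    objects: List[Any] = field(default_factory=list)


class JobSketch:
    """Hypothetical stand-in for a job builder; NOT the NVFlare FedJob API."""

    def __init__(self) -> None:
        self.deployments: Dict[str, Deployment] = {}

    def to(
        self,
        *,
        targets: Sequence[str],
        scripts: Sequence[str] = (),
        dirs: Sequence[str] = (),
        objects: Sequence[Any] = (),
    ) -> None:
        """Assign several kinds of things to several targets in one call."""
        if not targets:
            raise ValueError("at least one target is required")
        for target in targets:
            dep = self.deployments.setdefault(target, Deployment())
            dep.scripts.extend(scripts)
            dep.dirs.extend(dirs)
            dep.objects.extend(objects)


job = JobSketch()
# A placeholder string stands in for a real executor component.
job.to(scripts=["client.py"], objects=["executor-placeholder"], targets=["site-1", "site-2"])
job.to(scripts=["server.py"], targets=["server"])
print(job.deployments["site-1"].scripts)  # ['client.py']
```

Because every argument is keyword-only, a bare string can no longer be mistaken for a component or a directory, which is the ambiguity the reviewers object to.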
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
nvflare~=2.5.0rc | ||
flwr[simulation]>=1.8.0 | ||
torch==2.2.1 | ||
torchvision==0.17.1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import List, Tuple | ||
|
||
from flwr.common import Metrics, ndarrays_to_parameters | ||
from flwr.server import ServerApp, ServerConfig | ||
from flwr.server.strategy import FedAvg | ||
from task import Net, get_weights | ||
|
||
|
||
# Define metric aggregation function | ||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
    examples = [num_examples for num_examples, _ in metrics] | ||
|
||
    # Multiply accuracy of each client by number of examples used | ||
    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] | ||
    train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics] | ||
    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] | ||
    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] | ||
|
||
    # Aggregate and return custom metric (weighted average) | ||
    return { | ||
        "train_loss": sum(train_losses) / sum(examples), | ||
        "train_accuracy": sum(train_accuracies) / sum(examples), | ||
        "val_loss": sum(val_losses) / sum(examples), | ||
        "val_accuracy": sum(val_accuracies) / sum(examples), | ||
    } | ||
|
||
|
||
# Initialize model parameters | ||
ndarrays = get_weights(Net()) | ||
parameters = ndarrays_to_parameters(ndarrays) | ||
|
||
|
||
# Define strategy | ||
strategy = FedAvg( | ||
    fraction_fit=1.0, # Select all available clients | ||
    fraction_evaluate=0.0, # Disable evaluation | ||
    min_available_clients=2, | ||
    fit_metrics_aggregation_fn=weighted_average, | ||
    initial_parameters=parameters, | ||
) | ||
|
||
|
||
# Define config | ||
config = ServerConfig(num_rounds=3) | ||
|
||
|
||
# Flower ServerApp | ||
app = ServerApp( | ||
    config=config, | ||
    strategy=strategy, | ||
) | ||
|
||
|
||
# Legacy mode | ||
if __name__ == "__main__": | ||
    from flwr.server import start_server | ||
|
||
    start_server( | ||
        server_address="0.0.0.0:8080", | ||
        config=config, | ||
        strategy=strategy, | ||
    ) |
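The `weighted_average` function in this file weights each client's metric by the number of examples that client reported. The dependency-free sketch below works through the same arithmetic with made-up numbers. The two client reports are hypothetical, and the generic dictionary comprehension is a simplification of the four explicit metrics used in the file.

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Average each metric across clients, weighted by example count."""
    total = sum(n for n, _ in metrics)
    keys = metrics[0][1].keys()
    return {k: sum(n * m[k] for n, m in metrics) / total for k in keys}


# Two hypothetical clients: 100 examples at 0.5 accuracy, 300 at 0.75.
reports = [(100, {"train_accuracy": 0.5}), (300, {"train_accuracy": 0.75})]
result = weighted_average(reports)
print(result)  # {'train_accuracy': 0.6875}
```

The result, (100 * 0.5 + 300 * 0.75) / 400 = 0.6875, sits closer to the larger client's value than the unweighted mean of 0.625 would.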
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import OrderedDict | ||
from logging import INFO | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from flwr.common.logger import log | ||
from torch.utils.data import DataLoader | ||
from torchvision.datasets import CIFAR10 | ||
from torchvision.transforms import Compose, Normalize, ToTensor | ||
|
||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
class Net(nn.Module): | ||
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" | ||
|
||
    def __init__(self) -> None: | ||
        super(Net, self).__init__() | ||
        self.conv1 = nn.Conv2d(3, 6, 5) | ||
        self.pool = nn.MaxPool2d(2, 2) | ||
        self.conv2 = nn.Conv2d(6, 16, 5) | ||
        self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
        self.fc2 = nn.Linear(120, 84) | ||
        self.fc3 = nn.Linear(84, 10) | ||
|
||
    def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
        x = self.pool(F.relu(self.conv1(x))) | ||
        x = self.pool(F.relu(self.conv2(x))) | ||
        x = x.view(-1, 16 * 5 * 5) | ||
        x = F.relu(self.fc1(x)) | ||
        x = F.relu(self.fc2(x)) | ||
        return self.fc3(x) | ||
|
||
|
||
def load_data(): | ||
    """Load CIFAR-10 (training and test set).""" | ||
    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | ||
    trainset = CIFAR10("./data", train=True, download=True, transform=trf) | ||
    testset = CIFAR10("./data", train=False, download=True, transform=trf) | ||
    return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) | ||
|
||
|
||
def train(net, trainloader, valloader, epochs, device): | ||
    """Train the model on the training set.""" | ||
    log(INFO, "Starting training...") | ||
    net.to(device) # move model to GPU if available | ||
    criterion = torch.nn.CrossEntropyLoss().to(device) | ||
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | ||
    net.train() | ||
    for _ in range(epochs): | ||
        for images, labels in trainloader: | ||
            images, labels = images.to(device), labels.to(device) | ||
            optimizer.zero_grad() | ||
            loss = criterion(net(images), labels) | ||
            loss.backward() | ||
            optimizer.step() | ||
|
||
    train_loss, train_acc = test(net, trainloader) | ||
    val_loss, val_acc = test(net, valloader) | ||
|
||
    results = { | ||
        "train_loss": train_loss, | ||
        "train_accuracy": train_acc, | ||
        "val_loss": val_loss, | ||
        "val_accuracy": val_acc, | ||
    } | ||
    return results | ||
|
||
|
||
def test(net, testloader): | ||
    """Validate the model on the test set.""" | ||
    net.to(DEVICE) | ||
    criterion = torch.nn.CrossEntropyLoss() | ||
    correct, loss = 0, 0.0 | ||
    with torch.no_grad(): | ||
        for images, labels in testloader: | ||
            outputs = net(images.to(DEVICE)) | ||
            labels = labels.to(DEVICE) | ||
            loss += criterion(outputs, labels).item() | ||
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() | ||
    accuracy = correct / len(testloader.dataset) | ||
    return loss, accuracy | ||
|
||
|
||
def get_weights(net): | ||
    return [val.cpu().numpy() for _, val in net.state_dict().items()] | ||
|
||
|
||
def set_weights(net, parameters): | ||
    params_dict = zip(net.state_dict().keys(), parameters) | ||
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) | ||
    net.load_state_dict(state_dict, strict=True) |
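The `16 * 5 * 5` input size of `fc1` in `Net` follows from the spatial size of a CIFAR-10 image after the two convolution and pooling stages. The short calculation below reproduces that arithmetic without PyTorch. `conv_out` is a hypothetical helper written for this illustration; it uses the standard output-size formula for unpadded, undilated windows.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


side = 32                    # CIFAR-10 images are 32x32
side = conv_out(side, 5)     # conv1, 5x5 kernel      -> 28
side = conv_out(side, 2, 2)  # 2x2 max pool, stride 2 -> 14
side = conv_out(side, 5)     # conv2, 5x5 kernel      -> 10
side = conv_out(side, 2, 2)  # 2x2 max pool, stride 2 -> 5
flat_features = 16 * side * side  # 16 output channels from conv2
print(side, flat_features)  # 5 400
```

Changing the input resolution or a kernel size would change `flat_features`, so `fc1` and the `view` call in `forward` would need to be updated together.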
We might need to pin a specific version here, as Flower is still changing its API implementation?
Yes, we do need to pin a specific version of Flower. I asked them for a fixed version number but have not gotten a response. Their main developer, Pan, is attending a conference. Not sure when he will be able to get back to us.
I talked to Pan; he said he already responded to you (he sent a message on LinkedIn; he is done with a 6-hour training workshop). This week will be the final change to the CLI.