From 52b08dfc1559b74896457d2f2e0ecaea055cf66f Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Thu, 1 Aug 2024 19:28:11 -0400 Subject: [PATCH] add flower script --- examples/getting_started/flower/README.md | 22 ++++ examples/getting_started/flower/client.py | 54 +++++++++ .../flower/fedavg_flwr_cifar10.py | 40 +++++++ .../getting_started/flower/requirements.txt | 4 + examples/getting_started/flower/server.py | 75 +++++++++++++ examples/getting_started/flower/task.py | 106 ++++++++++++++++++ .../jobs/hello-flwr-pt/app/custom/client.py | 3 +- 7 files changed, 303 insertions(+), 1 deletion(-) create mode 100644 examples/getting_started/flower/README.md create mode 100644 examples/getting_started/flower/client.py create mode 100644 examples/getting_started/flower/fedavg_flwr_cifar10.py create mode 100644 examples/getting_started/flower/requirements.txt create mode 100644 examples/getting_started/flower/server.py create mode 100644 examples/getting_started/flower/task.py diff --git a/examples/getting_started/flower/README.md b/examples/getting_started/flower/README.md new file mode 100644 index 0000000000..e307cc98b5 --- /dev/null +++ b/examples/getting_started/flower/README.md @@ -0,0 +1,22 @@ +# Flower App (PyTorch) in NVIDIA FLARE + +In this example, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator with the JobAPI. + +## Preconditions + +To run Flower code in NVFlare, simply install the requirements and run the below script. + +Note, this code is directly copied from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example. + +## Install dependencies +To run this job with NVFlare, we first need to install the dependencies. +```bash +pip install -r requirements.txt +``` + +## Run a simulation + +Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator. 
+```bash
+python fedavg_flwr_cifar10.py
+```
diff --git a/examples/getting_started/flower/client.py b/examples/getting_started/flower/client.py
new file mode 100644
index 0000000000..27fcac08aa
--- /dev/null
+++ b/examples/getting_started/flower/client.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
+from task import DEVICE, Net, get_weights, load_data, set_weights, test, train
+
+# One model instance and one pair of data loaders are created at import time and
+# shared by every FlowerClient built in this process (simple CNN, CIFAR-10).
+net = Net().to(DEVICE)
+trainloader, testloader = load_data()
+
+
+class FlowerClient(NumPyClient):
+    """NumPyClient that trains and evaluates the module-level ``net``."""
+
+    def fit(self, parameters, config):
+        """Load ``parameters`` into ``net``, train for one epoch, and return the update.
+
+        Returns the new weights, the size of the training loader's dataset, and the
+        metrics dict produced by ``train``. ``config`` is accepted but not used.
+        """
+        set_weights(net, parameters)
+        fit_metrics = train(net, trainloader, testloader, epochs=1, device=DEVICE)
+        num_examples = len(trainloader.dataset)
+        return get_weights(net), num_examples, fit_metrics
+
+    def evaluate(self, parameters, config):
+        """Load ``parameters`` into ``net`` and score it on the test loader.
+
+        Returns the loss, the size of the test loader's dataset, and
+        ``{"accuracy": ...}``. ``config`` is accepted but not used.
+        """
+        set_weights(net, parameters)
+        eval_loss, eval_accuracy = test(net, testloader)
+        num_examples = len(testloader.dataset)
+        return eval_loss, num_examples, {"accuracy": eval_accuracy}
+
+
+def client_fn(context: Context):
+    """Build a FlowerClient and convert it to a Flower ``Client``; ``context`` is unused."""
+    flower_client = FlowerClient()
+    return flower_client.to_client()
+
+
+# Flower ClientApp
+app = ClientApp(client_fn=client_fn)
+
+
+# Legacy mode
+if __name__ == "__main__":
+    from flwr.client import start_client
+
+    legacy_client = FlowerClient().to_client()
+    start_client(server_address="127.0.0.1:8080", client=legacy_client)
diff --git a/examples/getting_started/flower/fedavg_flwr_cifar10.py b/examples/getting_started/flower/fedavg_flwr_cifar10.py
new file mode 100644
index 0000000000..4da1e97d44
--- /dev/null
+++ b/examples/getting_started/flower/fedavg_flwr_cifar10.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nvflare import FedJob
+from nvflare.app_opt.flower.controller import FlowerController
+from nvflare.app_opt.flower.executor import FlowerExecutor
+
+if __name__ == "__main__":
+    # Number of simulated client sites. The number of federated rounds is not
+    # set here: it is fixed by ServerConfig(num_rounds=...) in server.py.
+    n_clients = 2
+
+    job = FedJob(name="cifar10_flwr")
+
+    # Define the controller workflow and send to server
+    controller = FlowerController(server_app="server:app")
+    job.to_server(controller)
+
+    # Add flwr server code
+    job.to_server("server.py")
+
+    # Add clients
+    executor = FlowerExecutor(client_app="client:app")
+    job.to_clients(executor)
+
+    # Add flwr client code
+    job.to_clients("client.py")
+
+    # Export the job configuration, then run it with the simulator.
+    job.export_job("/tmp/nvflare/jobs/job_config")
+    job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients)
diff --git a/examples/getting_started/flower/requirements.txt b/examples/getting_started/flower/requirements.txt
new file mode 100644
index 0000000000..1d8990f84a
--- /dev/null
+++ b/examples/getting_started/flower/requirements.txt
@@ -0,0 +1,4 @@
+nvflare~=2.5.0rc
+flwr[simulation]>=1.8.0
+torch==2.2.1
+torchvision==0.17.1
diff --git a/examples/getting_started/flower/server.py
b/examples/getting_started/flower/server.py
new file mode 100644
index 0000000000..8083a6b802
--- /dev/null
+++ b/examples/getting_started/flower/server.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Tuple
+
+from flwr.common import Metrics, ndarrays_to_parameters
+from flwr.server import ServerApp, ServerConfig
+from flwr.server.strategy import FedAvg
+from task import Net, get_weights
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    """Average the clients' fit metrics, weighting each client by its example count.
+
+    Args:
+        metrics: One ``(num_examples, metrics_dict)`` pair per client. Every dict
+            must contain ``train_loss``, ``train_accuracy``, ``val_loss`` and
+            ``val_accuracy``.
+
+    Returns:
+        A dict with the same four keys holding the example-weighted means, or an
+        empty dict when the clients report no examples at all (nothing to average).
+
+    Raises:
+        KeyError: If a client's dict lacks one of the four keys.
+    """
+    keys = ("train_loss", "train_accuracy", "val_loss", "val_accuracy")
+    total_examples = sum(num_examples for num_examples, _ in metrics)
+    if total_examples == 0:
+        # No results (or only zero-sized ones): avoid dividing by zero.
+        return {}
+
+    # Weight every client's value by the number of examples it was computed on.
+    return {
+        key: sum(num_examples * m[key] for num_examples, m in metrics) / total_examples
+        for key in keys
+    }
+
+
+# Initialize model parameters
+ndarrays = get_weights(Net())
+parameters = ndarrays_to_parameters(ndarrays)
+
+
+# Define strategy
+strategy = FedAvg(
+    fraction_fit=1.0,  # Select all available clients
+    fraction_evaluate=0.0,  # Disable evaluation
+    min_available_clients=2,
+    fit_metrics_aggregation_fn=weighted_average,
+    initial_parameters=parameters,
+)
+
+
+# Define config
+config = ServerConfig(num_rounds=3)
+
+
+# Flower ServerApp
+app = ServerApp(
+    config=config,
+    strategy=strategy,
+)
+
+
+# Legacy mode
+if __name__ == "__main__":
+    from flwr.server import start_server
+
+    start_server(
+        server_address="0.0.0.0:8080",
+        config=config,
+        strategy=strategy,
+    )
diff --git a/examples/getting_started/flower/task.py b/examples/getting_started/flower/task.py
new file mode 100644
index 0000000000..7a5c1a0514
--- /dev/null
+++ b/examples/getting_started/flower/task.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+from logging import INFO
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flwr.common.logger import log
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR10
+from torchvision.transforms import Compose, Normalize, ToTensor
+
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return the 10 class logits for a batch of 3-channel 32x32 images."""
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def load_data():
+    """Load CIFAR-10 and return ``(trainloader, testloader)``.
+
+    The dataset is downloaded to ``./data`` when missing. Both loaders use
+    batches of 32; only the training loader shuffles.
+    """
+    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    trainset = CIFAR10("./data", train=True, download=True, transform=trf)
+    testset = CIFAR10("./data", train=False, download=True, transform=trf)
+    trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
+    # Batch the test set as well: the DataLoader default of one image per batch
+    # makes evaluation needlessly slow.
+    testloader = DataLoader(testset, batch_size=32)
+    return trainloader, testloader
+
+
+def train(net, trainloader, valloader, epochs, device):
+    """Train ``net`` on ``trainloader`` and report train/validation metrics.
+
+    Args:
+        net: Model to train in place.
+        trainloader: Batches of ``(images, labels)`` used for optimisation.
+        valloader: Batches of ``(images, labels)`` used only for the final evaluation.
+        epochs: Number of passes over ``trainloader``.
+        device: Torch device used for both training and the final evaluation.
+
+    Returns:
+        Dict with ``train_loss``, ``train_accuracy``, ``val_loss`` and
+        ``val_accuracy`` measured after the last epoch.
+    """
+    log(INFO, "Starting training...")
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+    net.train()
+    for _ in range(epochs):
+        for images, labels in trainloader:
+            images, labels = images.to(device), labels.to(device)
+            optimizer.zero_grad()
+            loss = criterion(net(images), labels)
+            loss.backward()
+            optimizer.step()
+
+    # Evaluate on the same device the model was trained on.
+    train_loss, train_acc = test(net, trainloader, device)
+    val_loss, val_acc = test(net, valloader, device)
+
+    return {
+        "train_loss": train_loss,
+        "train_accuracy": train_acc,
+        "val_loss": val_loss,
+        "val_accuracy": val_acc,
+    }
+
+
+def test(net, testloader, device=None):
+    """Evaluate ``net`` on ``testloader``.
+
+    Args:
+        net: Model to evaluate; its train/eval mode is restored afterwards.
+        testloader: Loader yielding ``(images, labels)`` batches.
+        device: Device to run on. Defaults to the module-level ``DEVICE`` so
+            existing two-argument calls keep working.
+
+    Returns:
+        ``(loss, accuracy)``: mean cross-entropy per example and the fraction of
+        correct predictions, or ``(0.0, 0.0)`` if the loader yields no examples.
+    """
+    if device is None:
+        device = DEVICE
+    net.to(device)
+    # Sum per-example losses so the result does not depend on the batch size.
+    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
+    correct, loss, total = 0, 0.0, 0
+    was_training = net.training
+    net.eval()
+    try:
+        with torch.no_grad():
+            for images, labels in testloader:
+                outputs = net(images.to(device))
+                labels = labels.to(device)
+                loss += criterion(outputs, labels).item()
+                correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+                total += labels.size(0)
+    finally:
+        net.train(was_training)
+    if total == 0:
+        return 0.0, 0.0
+    return loss / total, correct / total
+
+
+def get_weights(net):
+    """Return the model's state-dict tensors as a list of NumPy arrays."""
+    return [val.cpu().numpy() for val in net.state_dict().values()]
+
+
+def set_weights(net, parameters):
+    """Load ``parameters`` into ``net``.
+
+    Arrays are paired with the state-dict keys in order, so ``parameters`` must
+    follow the ordering of ``get_weights``; loading uses ``strict=True``.
+    """
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in params_dict)
+    net.load_state_dict(state_dict, strict=True)
diff --git a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/client.py b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/client.py
index 9e674b27c3..27fcac08aa 100644
--- a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/client.py
+++ b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/client.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
 from task import DEVICE, Net, get_weights, load_data, set_weights, test, train
 
 # Load model and data (simple CNN, CIFAR-10)
@@ -32,7 +33,7 @@ def evaluate(self, parameters, config):
         return loss, len(testloader.dataset), {"accuracy": accuracy}
 
 
-def client_fn(cid: str):
+def client_fn(context: Context):
     """Create and return an instance of Flower `Client`."""
     return FlowerClient().to_client()
 