Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flower script #2754

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/getting_started/flower/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Flower App (PyTorch) in NVIDIA FLARE

In this example, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator with the JobAPI.

## Preconditions

To run Flower code in NVFlare, simply install the requirements and run the below script.

Note, this code is directly copied from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example.

## Install dependencies
To run this job with NVFlare, we first need to install the dependencies.
```bash
pip install -r requirements.txt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might need to ping specific version here as Flower is still changing their api implementation?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we do need to pin a specific version of flower.
I asked them for a fix version number but have not got response from them. Their main developer Pan is attending a conference. Not sure when he will be able to get back to us.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I talked to pan, he said already respond to you ( he message from LinkedIn, He is done with 6 hr training workshop). This week will be the final change on CLI

```

## Run a simulation

Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator.
```bash
python fedavg_flwr_cifar10.py
```
54 changes: 54 additions & 0 deletions examples/getting_started/flower/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from task import DEVICE, Net, get_weights, load_data, set_weights, test, train

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()


# Define FlowerClient and client_fn
class FlowerClient(NumPyClient):
def fit(self, parameters, config):
set_weights(net, parameters)
results = train(net, trainloader, testloader, epochs=1, device=DEVICE)
return get_weights(net), len(trainloader.dataset), results

def evaluate(self, parameters, config):
set_weights(net, parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(context: Context):
"""Create and return an instance of Flower `Client`."""
return FlowerClient().to_client()


# Flower ClientApp
app = ClientApp(
client_fn=client_fn,
)


# Legacy mode
if __name__ == "__main__":
from flwr.client import start_client

start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
40 changes: 40 additions & 0 deletions examples/getting_started/flower/fedavg_flwr_cifar10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare import FedJob
from nvflare.app_opt.flower.controller import FlowerController
from nvflare.app_opt.flower.executor import FlowerExecutor

if __name__ == "__main__":
n_clients = 2
num_rounds = 2

job = FedJob(name="cifar10_flwr")

# Define the controller workflow and send to server
controller = FlowerController(server_app="server:app")
job.to_server(controller)

# Add flwr server code
job.to_server("server.py")

# Add clients
executor = FlowerExecutor(client_app="client:app")
job.to_clients(executor)

# Add flwr client code
job.to_clients("client.py")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This syntax is strange, I don't think we should do this. should the client.py goes to the FlowerExecutor,

executor = FlowerExecutor(client_app ="client:app", task_scripts = "client.py")
job.to_clients(executor)

the syntax

Job.to(, ) , seems very strange to me, and in most cases, doesn't make sense. I didnt know until QA highlight this.

Not all executor has task executor file. Add special syntax, take any string path seem to be a big assumption.
And also different from the rest of syntax, job.to(x, ), where x is component.

    elif isinstance(obj, str):  # treat the str type object as external script
            if target not in self._deploy_map:
                raise ValueError(
                    f"{target} doesn't have a `Controller` or `Executor`. Deploy one first before adding external script!"
                )

            self._deploy_map[target].add_external_scripts([obj])

I strongly suggest disasble this syntax. @yanchengnv @YuanTingHsieh @yhwen @holgerroth what are you opinions on this ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding @yhwen. Supporting placing of scripts to server/clients with to() is a general purpose solution. We have that built into the JobAPI to support any component that might refer to some code that needs to be deployed as part of the job. For example, now flower might also require a . toml file to be place at clients/server.

So, to me it makes sense and is an extensible solution for different use cases.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the Git Rendering wipe out the original string, what wrote was "job.to( left bracket "<" string path right bracket ">", site_name)
not comment to job.to() but specific take any string and add external path.

job.to(project.toml, site-1)
job.to(train.py, site-2)
job.to("my name", site-3)

all goes to an external directory feel strange to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could do it so python modules are supported directly

import client
job.to_clients(client)

but then you still cannot add config files that might be needed by a component.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now, I am looking at .to_client() syntax, I also confused. which site_name this one add to ? it is suppose to be an syntax sugar wrapper. If the Client is Client model intend to have script executor, why the component doesnt include script or configure file as argument ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's to_clients(). Only use when sending the component to all clients.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, to_clients() make sense, but this is not address the original question though

Copy link
Collaborator

@yanchengnv yanchengnv Aug 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The job.to(thing, target) syntax is confusing since the 1st arg could be many things. Perhaps we could improve by defining multiple input args, one arg for each type of things, like this:

job.to(scripts=[list of scripts], dirs=[list of dirs], objects=[list of objs], targets=[list of targets])

It is possible or maybe even desirable to send multiple types of things in one call - you shouldn't have to call job.to() many times.


job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients)
4 changes: 4 additions & 0 deletions examples/getting_started/flower/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
nvflare~=2.5.0rc
flwr[simulation]>=1.8.0
torch==2.2.1
torchvision==0.17.1
75 changes: 75 additions & 0 deletions examples/getting_started/flower/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple

from flwr.common import Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from task import Net, get_weights


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
examples = [num_examples for num_examples, _ in metrics]

# Multiply accuracy of each client by number of examples used
train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics]
val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

# Aggregate and return custom metric (weighted average)
return {
"train_loss": sum(train_losses) / sum(examples),
"train_accuracy": sum(train_accuracies) / sum(examples),
"val_loss": sum(val_losses) / sum(examples),
"val_accuracy": sum(val_accuracies) / sum(examples),
}


# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)


# Define strategy
strategy = FedAvg(
fraction_fit=1.0, # Select all available clients
fraction_evaluate=0.0, # Disable evaluation
min_available_clients=2,
fit_metrics_aggregation_fn=weighted_average,
initial_parameters=parameters,
)


# Define config
config = ServerConfig(num_rounds=3)


# Flower ServerApp
app = ServerApp(
config=config,
strategy=strategy,
)


# Legacy mode
if __name__ == "__main__":
from flwr.server import start_server

start_server(
server_address="0.0.0.0:8080",
config=config,
strategy=strategy,
)
106 changes: 106 additions & 0 deletions examples/getting_started/flower/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from logging import INFO

import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.common.logger import log
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


def load_data():
"""Load CIFAR-10 (training and test set)."""
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=trf)
testset = CIFAR10("./data", train=False, download=True, transform=trf)
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)


def train(net, trainloader, valloader, epochs, device):
"""Train the model on the training set."""
log(INFO, "Starting training...")
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
net.train()
for _ in range(epochs):
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()

train_loss, train_acc = test(net, trainloader)
val_loss, val_acc = test(net, valloader)

results = {
"train_loss": train_loss,
"train_accuracy": train_acc,
"val_loss": val_loss,
"val_accuracy": val_acc,
}
return results


def test(net, testloader):
"""Validate the model on the test set."""
net.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in testloader:
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def get_weights(net):
return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from task import DEVICE, Net, get_weights, load_data, set_weights, test, train

# Load model and data (simple CNN, CIFAR-10)
Expand All @@ -32,7 +33,7 @@ def evaluate(self, parameters, config):
return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
def client_fn(context: Context):
"""Create and return an instance of Flower `Client`."""
return FlowerClient().to_client()

Expand Down
Loading