Skip to content

Commit

Permalink
Merge branch 'main' into lightning_load_state_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth authored Feb 13, 2025
2 parents bc6f202 + 4efdf78 commit ff7a27b
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 8 deletions.
25 changes: 24 additions & 1 deletion examples/hello-world/hello-flower/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ If you haven't already, we recommend creating a virtual environment.
python3 -m venv nvflare_flwr
source nvflare_flwr/bin/activate
```

We recommend installing an older version of NumPy as torch/torchvision doesn't support NumPy 2 at this time.
```bash
pip install numpy==1.26.4
```
## 2.1 Run a simulation

To run flwr-pt job with NVFlare, we first need to install its dependencies.
Expand All @@ -49,3 +52,23 @@ the TensorBoard metrics to the server at each iteration using NVFlare's metric s
```bash
python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics
```

You can visualize the metrics streamed to the server using TensorBoard.
```bash
tensorboard --logdir /tmp/nvflare/hello-flower
```
![tensorboard training curve](./train.png)

## Notes
Make sure your `pyproject.toml` files in the Flower apps contain an "address" field. This needs to be present as the `--federation-config` option of the `flwr run` command tries to override the `“address”` field.
Your `pyproject.toml` should include a section similar to this:
```
[tool.flwr.federations]
default = "xxx"
[tool.flwr.federations.xxx]
options.num-supernodes = 2
address = "127.0.0.1:9093"
insecure = false
```
The number `options.num-supernodes` should match the number of NVFlare clients defined in [job.py](./job.py), e.g., `job.simulator_run(args.workdir, gpu="0", n_clients=2)`.
14 changes: 12 additions & 2 deletions examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
# initializes NVFlare interface
from nvflare.client.tracking import SummaryWriter

flare.init()


# Define FlowerClient and client_fn
class FlowerClient(NumPyClient):
Expand Down Expand Up @@ -81,3 +79,15 @@ def client_fn(context: Context):
app = ClientApp(
client_fn=client_fn,
)


@app.enter()
def enter(ctxt: Context) -> None:
flare.init()
print("ClientApp entering. Flare initialized.")


@app.exit()
def exit(ctxt: Context) -> None:
flare.shutdown()
print("ClientApp exiting. Flare shutdown.")
6 changes: 4 additions & 2 deletions examples/hello-world/hello-flower/flwr-pt-tb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.11.0,<2.0",
"nvflare~=2.5.0rc",
"flwr[simulation]>=1.15.2,<2.0",
"nvflare~=2.6.0rc",
"torch==2.2.1",
"torchvision==0.17.1",
"tensorboard"
Expand All @@ -33,3 +33,5 @@ default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 2
address = "127.0.0.1:9093"
insecure = true
7 changes: 5 additions & 2 deletions examples/hello-world/hello-flower/flwr-pt/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.11.0,<2.0",
"nvflare~=2.5.0rc",
"flwr[simulation]>=1.15.2,<2.0",
"nvflare~=2.6.0rc",
"torch==2.2.1",
"torchvision==0.17.1",
"tensorboard"
]

[tool.hatch.build.targets.wheel]
Expand All @@ -32,3 +33,5 @@ default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 2
address = "127.0.0.1:9093"
insecure = true
Binary file added examples/hello-world/hello-flower/train.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion nvflare/app_opt/flower/applet.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _run_flower_command(self, command: str):

success = result.get("success", False)
if not success:
err = f"failed command '{command}': {success=}"
err = f"failed command '{command}': {success=} {result=}"
self.logger.error(err)
raise RuntimeError(err)

Expand Down

0 comments on commit ff7a27b

Please sign in to comment.