Optimizations in collaborator flow
Signed-off-by: Agrawal, Kush <[email protected]>
kagrawa2 committed Feb 11, 2025
1 parent a306683 commit 6101894
Showing 6 changed files with 56 additions and 25 deletions.
9 changes: 7 additions & 2 deletions openfl-workspace/torch/histology/src/taskrunner.py
@@ -104,8 +104,12 @@ def train_(
loss.backward()
self.optimizer.step()
losses.append(loss.detach().cpu().numpy())
- loss = np.mean(losses)
- return Metric(name=self.loss_fn.__name__, value=np.array(loss))

+ del data, target, output, loss
+ loss_arr = np.array(np.mean(losses))
+ losses = None
+ del losses
+ return Metric(name=self.loss_fn.__name__, value=loss_arr)

def validate_(
self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
@@ -135,6 +139,7 @@ def validate_(
# get the index of the max log-probability
pred = output.argmax(dim=1)
val_score += pred.eq(target).sum().cpu().numpy()
+ del data, target, output, pred

accuracy = val_score / total_samples
return Metric(name="accuracy", value=np.array(accuracy))
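Both hunks above follow the same pattern: explicitly drop the references to the last batch's tensors once they have been reduced to a small NumPy value, so large intermediate tensors do not outlive the epoch. A minimal, self-contained sketch of that pattern (not part of this commit; the model, loss function and data loader below are hypothetical stand-ins):

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(3)]

losses = []
for data, target in loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    losses.append(loss.detach().cpu().numpy())  # keep only a small scalar per batch

del data, target, output, loss        # release the final batch's tensors and loss graph
loss_arr = np.array(np.mean(losses))  # tiny array; safe to keep and return
del losses
print(loss_arr)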
9 changes: 8 additions & 1 deletion openfl/component/collaborator/collaborator.py
@@ -189,6 +189,8 @@ def run(self):
for task in tasks:
metrics = self.do_task(task, round_num)
logs.update(metrics)
+ metrics = None
+ del metrics

# Round end
self.tensor_db.clean_up(self.db_store_rounds)
@@ -335,8 +337,12 @@ def do_task(self, task, round_number) -> dict:
# send the results for this tasks; delta and compression will occur in
# this function
metrics = self.send_task_results(global_output_tensor_dict, round_number, task_name)
- return metrics

+ del global_output_tensor_dict
+ del local_output_tensor_dict
+ del input_tensor_dict
+ return metrics

def get_numpy_dict_for_tensorkeys(self, tensor_keys):
"""Get tensor dictionary for specified tensorkey set.
@@ -523,6 +529,7 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
named_tensors,
)

+ del named_tensors
return metrics

def nparray_to_named_tensor(self, tensor_key, nparray):
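These collaborator.py changes rely on CPython's reference counting: dropping the last reference to a large per-task dictionary with del lets its memory be reclaimed right away instead of lingering until the name is rebound in a later round. A small standalone experiment (not OpenFL code) that makes the effect visible; the dictionary here is a hypothetical stand-in for the per-task tensor dicts, and NumPy reports its buffers to tracemalloc on reasonably recent versions, so the drop shows up in the numbers:

import tracemalloc

import numpy as np

tracemalloc.start()

tensor_dict = {f"layer_{i}": np.zeros((1024, 1024)) for i in range(8)}  # roughly 64 MB of float64
current, _ = tracemalloc.get_traced_memory()
print(f"with dict alive: {current / 1e6:.1f} MB")

del tensor_dict  # last reference gone; the arrays can be freed immediately
current, _ = tracemalloc.get_traced_memory()
print(f"after del:       {current / 1e6:.1f} MB")

tracemalloc.stop()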
40 changes: 25 additions & 15 deletions openfl/databases/tensor_db.py
@@ -89,10 +89,17 @@ def clean_up(self, remove_older_than: int = 1) -> None:
current_round = self.tensor_db["round"].astype(int).max()
if current_round == ROUND_PLACEHOLDER:
current_round = np.sort(self.tensor_db["round"].astype(int).unique())[-2]
+ # Keep only recent records
+ old_tensor_db = self.tensor_db
self.tensor_db = self.tensor_db[
(self.tensor_db["round"].astype(int) > current_round - remove_older_than)
| self.tensor_db["report"]
- ].reset_index(drop=True)
+ ].copy()  # Avoid unnecessary memory retention

+ self.tensor_db.reset_index(drop=True, inplace=True)

+ # Delete old DataFrame
+ del old_tensor_db

def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None:
"""Insert a tensor into TensorDB (dataframe).
@@ -106,25 +113,28 @@ def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None:
"""
entries_to_add = []
with self.mutex:
old_tensor_db = self.tensor_db
for tensor_key, nparray in tensor_key_dict.items():
tensor_name, origin, fl_round, report, tags = tensor_key
entries_to_add.append(
pd.DataFrame(
new_entry = pd.DataFrame(
[
[
[
tensor_name,
origin,
fl_round,
report,
tags,
nparray,
]
],
columns=list(self.tensor_db.columns),
)
tensor_name,
origin,
fl_round,
report,
tags,
nparray,
]
],
columns=list(self.tensor_db.columns),
)
entries_to_add.append(new_entry)

self.tensor_db = pd.concat([self.tensor_db, *entries_to_add], ignore_index=True, copy=True)

self.tensor_db = pd.concat([self.tensor_db, *entries_to_add], ignore_index=True)
del old_tensor_db
entries_to_add.clear()

def get_tensor_from_cache(self, tensor_key: TensorKey) -> Optional[np.ndarray]:
"""Perform a lookup of the tensor_key in the TensorDB.
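A runnable, self-contained version of the clean_up pattern above, using a toy DataFrame in place of the real TensorDB (the column layout and data here are simplified assumptions): filter to the recent rounds, take an explicit copy, reset the index in place, and drop the old frame so nothing keeps its rows reachable. Whether the extra .copy() is strictly required depends on pandas internals; the commit uses it defensively.

import numpy as np
import pandas as pd

# Toy stand-in for the TensorDB dataframe: 5 rounds, 3 tensors per round.
tensor_db = pd.DataFrame({
    "round": np.repeat(np.arange(5), 3),
    "report": False,
    "nparray": [np.zeros(10) for _ in range(15)],
})

remove_older_than = 1
current_round = int(tensor_db["round"].astype(int).max())

old_tensor_db = tensor_db
tensor_db = tensor_db[
    (tensor_db["round"].astype(int) > current_round - remove_older_than)
    | tensor_db["report"]
].copy()                            # explicit copy: no references back into the old frame
tensor_db.reset_index(drop=True, inplace=True)
del old_tensor_db                   # the old rows are no longer reachable

print(tensor_db["round"].unique())  # only the most recent round(s) remain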
10 changes: 6 additions & 4 deletions openfl/federated/task/runner_pt.py
@@ -6,7 +6,6 @@

from copy import deepcopy
from typing import Iterator, Tuple

import numpy as np
import torch
import torch.nn as nn
@@ -476,14 +475,16 @@ def train_(self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]) -> Metric:
losses = []
for data, target in train_dataloader:
data, target = torch.tensor(data).to(self.device), torch.tensor(target).to(self.device)
- self.optimizer.zero_grad()
+ self.optimizer.zero_grad(set_to_none=True)
output = self(data)
loss = self.loss_fn(output=output, target=target)
loss.backward()
self.optimizer.step()
losses.append(loss.detach().cpu().numpy())
- loss = np.mean(losses)
- return Metric(name=self.loss_fn.__name__, value=np.array(loss))
+ loss_arr = np.array(np.mean(losses))
+ losses = None
+ del losses, output
+ return Metric(name=self.loss_fn.__name__, value=loss_arr)

def validate_(self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]) -> Metric:
"""
@@ -514,6 +515,7 @@ def validate_(self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]) -> Metric:
val_score += pred.eq(target).sum().cpu().numpy()

accuracy = val_score / total_samples
+ del output, pred
return Metric(name="accuracy", value=np.array(accuracy))


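One detail in the runner_pt.py hunk worth calling out: optimizer.zero_grad(set_to_none=True) clears gradients by replacing each parameter's .grad tensor with None instead of writing zeros into it, so the old gradient buffers can be garbage collected between steps (recent PyTorch versions already default to set_to_none=True; the explicit argument makes the intent unambiguous). A short illustration with a throwaway model, not OpenFL code:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(1, 4)).sum().backward()  # populate .grad tensors

optimizer.zero_grad(set_to_none=False)     # keep the .grad tensors allocated, filled with zeros
print(model.weight.grad)                   # a zero tensor

model(torch.randn(1, 4)).sum().backward()
optimizer.zero_grad(set_to_none=True)      # drop the .grad tensors entirely
print(model.weight.grad)                   # None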
10 changes: 7 additions & 3 deletions openfl/protocols/utils.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

"""Proto utils."""

from openfl.protocols import base_pb2
from openfl.utilities import TensorKey

@@ -205,6 +204,7 @@ def construct_model_proto(tensor_dict, round_number, tensor_pipe):
)
)

+ del bytes_data, transformer_metadata
return base_pb2.ModelProto(tensors=named_tensors)


@@ -330,6 +330,7 @@ def proto_to_datastream(proto, logger, max_buffer_size=(2 * 1024 * 1024)):
reply: Chunks of the data stream for the remote connection.
"""
npbytes = proto.SerializeToString()
+ npbytes_view = memoryview(npbytes)
data_size = len(npbytes)
buffer_size = data_size if max_buffer_size > data_size else max_buffer_size
logger.debug(
@@ -339,10 +340,13 @@
)

for i in range(0, data_size, buffer_size):
- chunk = npbytes[i : i + buffer_size]
- reply = base_pb2.DataStream(npbytes=chunk, size=len(chunk))
+ chunk = bytes(npbytes_view[i : i + buffer_size])
+ reply = base_pb2.DataStream(npbytes=chunk, size=buffer_size)
yield reply
+ reply = None
+ chunk = None

+ del npbytes, npbytes_view

def get_headers(context) -> dict:
"""Get headers from context.
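The proto_to_datastream change slices the serialized payload through a memoryview: slicing the view is zero-copy, and bytes(view[i:j]) materializes each chunk exactly once before it is wrapped in a DataStream message. A standalone sketch of that chunking pattern (the helper name and payload are hypothetical, not the OpenFL API); note that reporting len(chunk) per chunk also covers a final chunk shorter than the buffer size:

def iter_chunks(payload: bytes, max_buffer_size: int = 2 * 1024 * 1024):
    """Yield (chunk, size) pairs for a serialized payload."""
    view = memoryview(payload)
    data_size = len(payload)
    buffer_size = data_size if max_buffer_size > data_size else max_buffer_size
    for i in range(0, data_size, buffer_size):
        chunk = bytes(view[i : i + buffer_size])
        yield chunk, len(chunk)  # len(chunk) handles the shorter final chunk

for chunk, size in iter_chunks(b"x" * 10, max_buffer_size=4):
    print(size)  # 4, 4, 2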
3 changes: 3 additions & 0 deletions openfl/transport/grpc/aggregator_client.py
@@ -472,6 +472,9 @@ def send_local_task_results(

# also do other validation, like on the round_number
self.validate_response(response, collaborator_name)
+ del request
+ del stream
+ del response

def _get_trained_model(self, experiment_name, model_type):
"""Get trained model RPC.
