Skip to content

Commit

Permalink
restructure(secagg): rename methods and update docstrings
Browse files Browse the repository at this point in the history
Signed-off-by: Pant, Akshay <[email protected]>
  • Loading branch information
theakshaypant committed Feb 19, 2025
1 parent e2fccff commit 86f09a4
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 105 deletions.
51 changes: 19 additions & 32 deletions openfl/callbacks/secure_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""
This file contains callbacks that help setup for secure aggregation for
both, the aggregator and collaborator.
This file contains callback that help setup for secure aggregation for the
collaborator.
"""

import json
Expand Down Expand Up @@ -30,11 +30,11 @@
class CollaboratorSecAgg(Callback):
"""
This callback is used by the collaborator to perform the setup steps
for scure aggregation on the collaborators.
for secure aggregation on the collaborators.
Required params include:
- origin: Name of the collaborator using the callback.
- client: Client for aggregator secure aggregation setup.
- client: AggregatorGRPCClient to communicate with the aggregator server.
It also requires the tensor-db client to be set.
"""
Expand All @@ -57,7 +57,7 @@ def on_experiment_begin(self, logs=None):
# aggregator.
self._decrypt_ciphertexts(collaborator_keys)
# Save the tensors which are required for masking of gradients.
self._save_tensors()
self._save_mask_tensors()

def _generate_keys(self):
"""
Expand Down Expand Up @@ -93,21 +93,16 @@ def _generate_keys(self):

def _fetch_public_keys(self):
"""
Fetches public keys from participants and identifies the index of the
current participant's public key.
This method retrieves the public keys from the aggregator's secure
aggregation mechanism. It then iterates through the fetched public
keys to find the index of the current participant's public key based
on the provided parameters.
Fetches collaborators' public keys from the aggregator and identifies
the index of the current collaborator using it's public key.
Returns:
dict: A dictionary containing the public keys of all participants,
where the keys are the participant indices and the values are
dict: A dictionary containing the public keys of all collaborators,
where the keys are the collaborator indices and the values are
the public keys.
"""
public_keys = {}
public_keys_tensor = self._fetch_from_collaborator("public_keys")
public_keys_tensor = self._fetch_from_aggregator("public_keys")
for tensor in public_keys_tensor:
# Creating a dictionary of the received public keys.
public_keys[int(tensor[0])] = [tensor[1], tensor[2]]
Expand Down Expand Up @@ -177,21 +172,21 @@ def _generate_ciphertexts(self, public_keys):

def _decrypt_ciphertexts(self, public_keys):
"""
Decrypts the ciphertexts received from participants using the provided
Decrypts the ciphertexts received from collaborators using the provided
public keys.
This method fetches the ciphertexts from the aggregator, decrypts them
using the participant's private key and the provided public keys, and
using the collaborator's private key and the provided public keys, and
then sends the decrypted seed shares and key shares back to the
participants.
aggregator.
Args:
public_keys (dict): A dictionary containing the public keys of the
participants.
collaborators.
"""
logger.debug("SecAgg: fetching addressed ciphertexts from the aggregator")

ciphertexts = self._fetch_from_collaborator("ciphertexts")
ciphertexts = self._fetch_from_aggregator("ciphertexts")
private_keys = self.params["private_key"]
ciphertext_verification = self.params["ciphertext_verification"]

Expand Down Expand Up @@ -223,21 +218,13 @@ def _generate_masks(self):

return private_mask, shared_mask

def _save_tensors(self):
def _save_mask_tensors(self):
"""
Generates private and shared masks, stores them in a local tensor
dictionary, and caches the dictionary in the tensor database.
These tensors are then added to the gradient before sharing them
with the aggregator during trainign task.
This method performs the following steps:
1. Generates private and shared masks by calling the `_generate_masks`
method.
2. Creates a local tensor dictionary with the generated masks.
3. Caches the local tensor dictionary in the tensor database.
4. Logs an informational message indicating the completion of the
setup and the saving of required tensors to the database.
with the aggregator during training task.
"""
private_mask, shared_mask = self._generate_masks()
local_tensor_dict = {
Expand Down Expand Up @@ -277,9 +264,9 @@ def _send_to_aggregator(self, tensor_dict: dict, stage: str):

self.client.send_local_task_results(self.name, -1, f"secagg_{stage}", -1, named_tensors)

def _fetch_from_collaborator(self, key_name):
def _fetch_from_aggregator(self, key_name):
"""
Fetches the aggregated tensor data from a collaborator.
Fetches the aggregated tensor data from a aggregator.
Args:
key_name (str): The name of the key to fetch the tensor for.
Expand Down
89 changes: 16 additions & 73 deletions openfl/utilities/secagg/setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""
Thsi file contains the Setup class used on the server side for secure aggregation setup.
This file contains the Setup class used on the server side for secure
aggregation setup.
"""

import logging
Expand Down Expand Up @@ -88,10 +89,8 @@ def aggregate_tensor(self, tensor_name):
self._aggregate_public_keys()
elif tensor_name == "ciphertext":
self._aggregate_ciphertexts()
elif tensor_name == "seed_share":
self._aggregate_seed_shares()
elif tensor_name == "key_share":
self._aggregate_key_shares()
elif tensor_name in ["seed_share", "key_share"]:
self._aggregate_secret_shares(tensor_name)

if "seed_shares" in self._results and "key_shares" in self._results:
self._reconstruct_secrets()
Expand Down Expand Up @@ -175,80 +174,24 @@ def _aggregate_ciphertexts(self):
}
)

def _aggregate_seed_shares(self):
def _aggregate_secret_shares(self, key_name):
"""
Aggregates seed shares for each collaborator from the tensor database
and stores them in the results dictionary.
Aggregates secret shares for a given key name from the tensor database.
This method fetches seed shares for each collaborator from the tensor
database using the `get_tensor_from_cache` method.
It then creates a map of seed shares for local use and stores it in
the `self._results["seed_shares"]` dictionary.
The structure of `self._results["seed_shares"]` is as follows:
{
collaborator_id: {
share_id: share_value,
...
},
...
}
"""
self._results["seed_shares"] = {}

for collaborator in self._collaborator_list:
# Seed shares
# Fetching seed shares for each collaborator from tensor db.
nparray = self._tensor_db.get_tensor_from_cache(
TensorKey(
"seed_share",
self._aggregator_uuid,
-1,
False,
(
collaborator,
"secagg",
),
)
)
for share in nparray:
# Creating a map for local use.
if int(share[1]) not in self._results["seed_shares"]:
self._results["seed_shares"][int(share[1])] = {}
self._results["seed_shares"][int(share[1])][int(share[0])] = share[2][2:-1]
database and organizes them into a dictionary for local use.
def _aggregate_key_shares(self):
"""
Aggregates key shares from the tensor database for each collaborator
and stores them in the results dictionary.
This method fetches key shares for each collaborator from the tensor
database and creates a local map of these key shares. The key shares
are stored in the `self._results["key_shares"]` dictionary, where the
keys are the first elements of the shares and the values are
dictionaries mapping the second elements of the shares to the third
elements.
The structure of `self._results["key_shares"]` is as follows:
{
share[0]: {
share[1]: share[2]
}
}
The method assumes that `self._collaborator_list` is a list of
collaborators and `self._tensor_db` is an instance of a tensor
database that has a method `get_tensor_from_cache` which takes a
`TensorKey` object as an argument.
Args:
key_name (str): The name of the key for which secret shares are to
be aggregated.
"""
self._results["key_shares"] = {}
self._results[f"{key_name}s"] = {}

for collaborator in self._collaborator_list:
# Key shares
# Fetching key shares for each collaborator from tensor db.
# Fetching seed shares for each collaborator from tensor db.
nparray = self._tensor_db.get_tensor_from_cache(
TensorKey(
"key_share",
key_name,
self._aggregator_uuid,
-1,
False,
Expand All @@ -260,9 +203,9 @@ def _aggregate_key_shares(self):
)
for share in nparray:
# Creating a map for local use.
if int(share[1]) not in self._results["key_shares"]:
self._results["key_shares"][int(share[1])] = {}
self._results["key_shares"][int(share[1])][int(share[0])] = share[2][2:-1]
if int(share[1]) not in self._results[f"{key_name}s"]:
self._results[f"{key_name}s"][int(share[1])] = {}
self._results[f"{key_name}s"][int(share[1])][int(share[0])] = share[2][2:-1]

def _reconstruct_secrets(self):
"""
Expand Down

0 comments on commit 86f09a4

Please sign in to comment.