Skip to content

Commit

Permalink
make fastdigest an optional
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Feb 20, 2025
1 parent 95d4e27 commit 6e3e666
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@


class DFStatistics(DFStatisticsCore):

def __init__(self, filename, data_root_dir="/tmp/nvflare/df_stats/data"):
super().__init__()
self.data_root_dir = data_root_dir
Expand Down
1 change: 0 additions & 1 deletion examples/advanced/streaming/src/simple_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


class SimpleController(Controller):

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
logger.info(f"Entering control loop of {self.__class__.__name__}")
engine = fl_ctx.get_engine()
Expand Down
2 changes: 0 additions & 2 deletions examples/advanced/streaming/src/standalone_file_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@


class FileSender(FLComponent):

def __init__(self):
super().__init__()
self.seq = 0
Expand Down Expand Up @@ -73,7 +72,6 @@ def _sending_file(self, fl_ctx):


class FileReceiver(FLComponent):

def __init__(self):
super().__init__()
self.done = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@


class FedAvgV1(BaseFedAvg):

def __init__(
self,
*args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@


class HelloController(Controller):

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
# Create the task with name "hello"
task = Task(name="hello", data=Shareable())
Expand All @@ -43,7 +42,6 @@ def stop_controller(self, fl_ctx: FLContext):


class HelloExecutor(Executor):

def execute(
self,
task_name: str,
Expand All @@ -57,7 +55,6 @@ def execute(


class HelloDataController(Controller):

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
# Prepare any extra parameters to send to the clients
data = DXO(
Expand All @@ -84,7 +81,6 @@ def stop_controller(self, fl_ctx: FLContext):


class HelloDataExecutor(Executor):

def execute(
self,
task_name: str,
Expand All @@ -100,7 +96,6 @@ def execute(


class HelloResponseController(Controller):

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
# Prepare any extra parameters to send to the clients
dxo = DXO(
Expand Down Expand Up @@ -136,7 +131,6 @@ def _process_client_response(self, client_task, fl_ctx: FLContext) -> None:


class HelloResponseExecutor(Executor):

def execute(
self,
task_name: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@


class BasicController(Controller):

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
self.broadcast_and_wait(
task=Task(name="talk", data=Shareable()),
Expand All @@ -40,7 +39,6 @@ def stop_controller(self, fl_ctx: FLContext):


class P2PExecutor(Executor):

def execute(
self,
task_name: str,
Expand Down
39 changes: 35 additions & 4 deletions nvflare/app_common/statistics/numeric_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
from math import sqrt
from typing import Dict, List, TypeVar

from fastdigest import TDigest

from nvflare.app_common.abstract.statistics_spec import Bin, BinRange, DataType, Feature, Histogram, HistogramType
from nvflare.app_common.app_constant import StatisticsConstants as StC
from nvflare.app_common.statistics.statistics_config_utils import get_target_quantiles
from nvflare.fuel.utils.import_utils import optional_import

T = TypeVar("T")

logger = logging.getLogger(__name__)


def get_global_feature_data_types(
client_feature_dts: Dict[str, Dict[str, List[Feature]]]
client_feature_dts: Dict[str, Dict[str, List[Feature]]],
) -> Dict[str, Dict[str, DataType]]:
global_feature_data_types = {}
for client_name in client_feature_dts:
Expand Down Expand Up @@ -86,6 +89,9 @@ def get_global_stats(

global_metrics[StC.STATS_STDDEV] = ds_stddev
elif metric == StC.STATS_QUANTILE:
if not check_fastdigest_installed():
continue

global_digest = {}
for client_name in stats:
global_digest = merge_quantiles(stats[client_name], global_digest)
Expand Down Expand Up @@ -233,6 +239,12 @@ def filter_numeric_features(ds_features: Dict[str, List[Feature]]) -> Dict[str,


def merge_quantiles(metrics: Dict[str, Dict[str, Dict]], g_digest: dict) -> dict:
TDigest = check_and_import_tdigest()
if TDigest is None:
return g_digest

from fastdigest import TDigest

for ds_name in metrics:
if ds_name not in g_digest:
g_digest[ds_name] = {}
Expand All @@ -250,7 +262,26 @@ def merge_quantiles(metrics: Dict[str, Dict[str, Dict]], g_digest: dict) -> dict
return g_digest


def compute_quantiles(g_digest: Dict[str, Dict[str, TDigest]], quantile_config: Dict, precision: int = 4) -> dict:
def check_fastdigest_installed():
fastdigest, flag = optional_import("fastdigest")
if not flag:
logger.error("fastdigest is not installed. Please install it using 'pip install fastdigest'.")
return False
return True


def check_and_import_tdigest():
TDigest, flag = optional_import("fastdigest", name="TDigest")
if not flag:
logger.error("TDigest is not installed. Please install it using 'pip install fastdigest'.")
return None
return TDigest


def compute_quantiles(g_digest, quantile_config: Dict, precision: int = 4) -> dict:
if not check_fastdigest_installed():
return {}

g_ds_metrics = {}
for ds_name in g_digest:
if ds_name not in g_ds_metrics:
Expand Down
6 changes: 5 additions & 1 deletion nvflare/app_common/workflows/statistics_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from nvflare.app_common.abstract.statistics_spec import Bin, Histogram, StatisticConfig
from nvflare.app_common.abstract.statistics_writer import StatisticsWriter
from nvflare.app_common.app_constant import StatisticsConstants as StC
from nvflare.app_common.statistics.numeric_stats import get_global_stats
from nvflare.app_common.statistics.numeric_stats import check_fastdigest_installed, get_global_stats
from nvflare.app_common.statistics.statisitcs_objects_decomposer import fobs_registration
from nvflare.fuel.utils import fobs

Expand Down Expand Up @@ -410,6 +410,10 @@ def _combine_all_statistics(self):
buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision)
result[feature_name][statistic][client][ds] = buckets
elif statistic == StC.STATS_QUANTILE:

if not check_fastdigest_installed():
continue

quantiles = self.client_statistics[statistic][client][ds][feature_name][StC.STATS_QUANTILE]
formatted_quantiles = {}
for p in quantiles:
Expand Down
6 changes: 5 additions & 1 deletion nvflare/app_opt/statistics/df/df_core_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

import numpy as np
import pandas as pd
from fastdigest import TDigest
from pandas.core.series import Series

from nvflare.app_common.abstract.statistics_spec import BinRange, Feature, Histogram, HistogramType, Statistics
from nvflare.app_common.app_constant import StatisticsConstants
from nvflare.app_common.statistics.numeric_stats import check_and_import_tdigest
from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type, get_std_histogram_buckets


Expand Down Expand Up @@ -95,6 +95,10 @@ def min_value(self, dataset_name: str, feature_name: str) -> float:
return df[feature_name].min()

def quantiles(self, dataset_name: str, feature_name: str, percents: List) -> Dict:
TDigest = check_and_import_tdigest()
if TDigest is None:
return {}

df = self.data[dataset_name]
data = df[feature_name]
max_bin = self.max_bin if self.max_bin else round(sqrt(len(data)))
Expand Down
23 changes: 21 additions & 2 deletions tests/unit_test/app_common/statistics/quantile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@

import numpy as np
import pandas as pd
from fastdigest import TDigest
import pytest

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.app_constant import StatisticsConstants
from nvflare.app_common.statistics.numeric_stats import compute_quantiles, merge_quantiles
from nvflare.app_opt.statistics.df.df_core_statistics import DFStatisticsCore
from nvflare.fuel.utils.import_utils import optional_import

try:
from fastdigest import TDigest

TDIGEST_AVAILABLE = True
except ImportError:
TDIGEST_AVAILABLE = False


class MockDFStats(DFStatisticsCore):
Expand Down Expand Up @@ -62,7 +70,7 @@ def load_data(self):


class TestQuantile:

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest1(self):
# Small dataset
data = [1, 2, 3, 4, 5]
Expand All @@ -75,6 +83,7 @@ def test_tdigest1(self):
assert fd.quantile(0.5) == np_data.quantile(0.5)
assert fd.quantile(0.75) == np_data.quantile(0.75)

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest2(self):
# Small dataset
data = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
Expand All @@ -86,6 +95,7 @@ def test_tdigest2(self):
assert fd.quantile(0.5) == np_data.quantile(0.5)
assert fd.quantile(0.75) == np_data.quantile(0.75)

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest3(self):
# Small dataset
data = [-50.0, -40.4, -30.3, -20.3, -10.1, 0, 1.1, 2.2, 3.3, 4.4, 5.5]
Expand All @@ -97,6 +107,7 @@ def test_tdigest3(self):
assert round(fd.quantile(0.5), 2) == np_data.quantile(0.5)
assert round(fd.quantile(0.75), 2) == np_data.quantile(0.75)

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest4(self):
# Small dataset
data = [-5, -4, -3, -2, -1, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
Expand All @@ -108,6 +119,7 @@ def test_tdigest4(self):
assert round(fd.quantile(0.5), 2) == np_data.quantile(0.5)
assert round(fd.quantile(0.75), 2) == np_data.quantile(0.75)

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest5(self):
# Small dataset
data1 = [x for x in range(-5, 0)]
Expand All @@ -119,6 +131,7 @@ def test_tdigest5(self):
assert fd.quantile(0.1) == -4
assert fd.quantile(0.9) == 4

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest6(self):
# Small dataset
data1 = [x for x in range(-10000, 0)]
Expand All @@ -133,6 +146,7 @@ def test_tdigest6(self):
assert fdx.quantile(0.5) == np_data.quantile(0.5)
assert merged_fd.quantile(0.5) == np_data.quantile(0.5)

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest7(self):
median = 10
data = np.concatenate((np.arange(0, median), [median], np.arange(median + 1, median * 2 + 1)))
Expand All @@ -147,6 +161,7 @@ def test_tdigest7(self):

assert v == median

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest8(self):

median = 10
Expand All @@ -170,6 +185,7 @@ def test_tdigest8(self):

assert v == median

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest_merge_serde(self):

median = 10
Expand All @@ -193,6 +209,7 @@ def test_tdigest_merge_serde(self):

assert v == median

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_tdigest_compress(self):

digest = TDigest(range(101))
Expand All @@ -217,6 +234,7 @@ def test_tdigest_compress(self):
assert before_75 == after_75
assert len(digest) == 10

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="fastdigest is not installed")
def test_percentile_metrics(self):
stats_generator = MockDFStats(given_median=100)
stats_generator.load_data()
Expand All @@ -229,6 +247,7 @@ def test_percentile_metrics(self):

assert result.get(0.5) == stats_generator.median

@pytest.mark.skipif(not TDIGEST_AVAILABLE, reason="TDigest package not installed")
def test_percentile_metrics_aggregation(self):
stats_generators = [
MockDFStats2(data_array=[0, 1, 2, 3, 4, 5, 6]),
Expand Down

0 comments on commit 6e3e666

Please sign in to comment.