From 3aefb2d8ac7b9d8102f4da72651daa381c3a0b5e Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Fri, 21 Feb 2025 20:01:51 -0500 Subject: [PATCH] Remove meta and enable uint8 quantization (#3222) Fixes # . ### Description Bug in dequantization meta removal, also enabling uint8 quantization ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --- .../app_opt/pt/quantization/dequantizor.py | 10 +++--- nvflare/app_opt/pt/quantization/quantizor.py | 36 ++++++++++++++----- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/nvflare/app_opt/pt/quantization/dequantizor.py b/nvflare/app_opt/pt/quantization/dequantizor.py index d19a9584ec..bd63da429d 100644 --- a/nvflare/app_opt/pt/quantization/dequantizor.py +++ b/nvflare/app_opt/pt/quantization/dequantizor.py @@ -178,20 +178,18 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None) if quantization_type.upper() not in QUANTIZATION_TYPE: raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") - + source_datatype = dxo.get_meta_prop(key="source_datatype", default=None) dequantized_params = self.dequantization( params=dxo.data, quant_state=dxo.meta["quant_state"], quantization_type=quantization_type, - source_datatype=dxo.meta["source_datatype"], + source_datatype=source_datatype, fl_ctx=fl_ctx, ) # Compose new DXO with dequantized data dxo.data = dequantized_params - dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM) - dxo.remove_meta_props("quant_state") - dxo.remove_meta_props("source_datatype") + dxo.remove_meta_props([MetaKey.PROCESSED_ALGORITHM, "quant_state", "source_datatype", "quantized_flag"]) dxo.update_shareable(shareable) - self.log_info(fl_ctx, "Dequantized back") + self.log_info(fl_ctx, f"Dequantized back to {source_datatype}") return dxo diff --git a/nvflare/app_opt/pt/quantization/quantizor.py b/nvflare/app_opt/pt/quantization/quantizor.py index 10c09c001b..43f7f7c117 100644 --- a/nvflare/app_opt/pt/quantization/quantizor.py +++ b/nvflare/app_opt/pt/quantization/quantizor.py @@ -195,13 +195,33 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio """ self.log_info(fl_ctx, "Running quantization...") - quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx) - # Compose new DXO with quantized data - # Add quant_state to the new DXO meta - new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta) - new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type) - new_dxo.set_meta_prop(key="quant_state", value=quant_state) - new_dxo.set_meta_prop(key="source_datatype", value=source_datatype) - self.log_info(fl_ctx, f"Quantized to {self.quantization_type}") + + # for already quantized message, skip quantization + # The reason in this current example: + # server job in this case is 1-N communication with identical quantization operation + # the first communication to client will apply quantization and change the data on the server + # thus the subsequent communications to the rest of clients will no longer need to apply quantization + # This will not apply to client job, since the client job will be 1-1 and quantization applies to each client + # Potentially: + # If clients talks to each other, it will also be 1-N and same rule applies + # If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then + # a deep copy of the server data should be made by filter before applying a different filter + + # quantized_flag None if does not exist in meta + quantized_flag = dxo.get_meta_prop("quantized_flag") + if quantized_flag: + self.log_info(fl_ctx, "Already quantized, skip quantization") + new_dxo = dxo + else: + # apply quantization + quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx) + # Compose new DXO with quantized data + # Add quant_state to the new DXO meta + new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta) + new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type) + new_dxo.set_meta_prop(key="quant_state", value=quant_state) + new_dxo.set_meta_prop(key="source_datatype", value=source_datatype) + new_dxo.set_meta_prop(key="quantized_flag", value=True) + self.log_info(fl_ctx, f"Quantized from {source_datatype} to {self.quantization_type}") return new_dxo