Skip to content

Commit

Permalink
Remove meta and enable uint8 quantization (#3222)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

Bug in dequantization meta removal, also enabling uint8 quantization

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
  • Loading branch information
ZiyueXu77 authored Feb 22, 2025
1 parent 1a55270 commit 3aefb2d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
10 changes: 4 additions & 6 deletions nvflare/app_opt/pt/quantization/dequantizor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,18 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None)
if quantization_type.upper() not in QUANTIZATION_TYPE:
raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}")

source_datatype = dxo.get_meta_prop(key="source_datatype", default=None)
dequantized_params = self.dequantization(
params=dxo.data,
quant_state=dxo.meta["quant_state"],
quantization_type=quantization_type,
source_datatype=dxo.meta["source_datatype"],
source_datatype=source_datatype,
fl_ctx=fl_ctx,
)
# Compose new DXO with dequantized data
dxo.data = dequantized_params
dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM)
dxo.remove_meta_props("quant_state")
dxo.remove_meta_props("source_datatype")
dxo.remove_meta_props([MetaKey.PROCESSED_ALGORITHM, "quant_state", "source_datatype", "quantized_flag"])
dxo.update_shareable(shareable)
self.log_info(fl_ctx, "Dequantized back")
self.log_info(fl_ctx, f"Dequantized back to {source_datatype}")

return dxo
36 changes: 28 additions & 8 deletions nvflare/app_opt/pt/quantization/quantizor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,33 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
"""

self.log_info(fl_ctx, "Running quantization...")
quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)
# Compose new DXO with quantized data
# Add quant_state to the new DXO meta
new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
new_dxo.set_meta_prop(key="quant_state", value=quant_state)
new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
self.log_info(fl_ctx, f"Quantized to {self.quantization_type}")

# for already quantized message, skip quantization
# The reason in this current example:
# server job in this case is 1-N communication with identical quantization operation
# the first communication to client will apply quantization and change the data on the server
# thus the subsequent communications to the rest of clients will no longer need to apply quantization
# This will not apply to client job, since the client job will be 1-1 and quantization applies to each client
# Potentially:
# If clients talks to each other, it will also be 1-N and same rule applies
# If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
# a deep copy of the server data should be made by filter before applying a different filter

# quantized_flag None if does not exist in meta
quantized_flag = dxo.get_meta_prop("quantized_flag")
if quantized_flag:
self.log_info(fl_ctx, "Already quantized, skip quantization")
new_dxo = dxo
else:
# apply quantization
quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)
# Compose new DXO with quantized data
# Add quant_state to the new DXO meta
new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
new_dxo.set_meta_prop(key="quant_state", value=quant_state)
new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
new_dxo.set_meta_prop(key="quantized_flag", value=True)
self.log_info(fl_ctx, f"Quantized from {source_datatype} to {self.quantization_type}")

return new_dxo

0 comments on commit 3aefb2d

Please sign in to comment.