Skip to content

Commit

Permalink
Fix overflow in collectors (#2337)
Browse files Browse the repository at this point in the history
### Changes

Cast float16 and float32 to float64 for mean and median.

### Reason for changes

Error on calculation parameters of FQ for large scale.


### Tests

test_reducers_and_aggregators.py::test_overflow
  • Loading branch information
AlexanderDokuchaev authored Jan 8, 2024
1 parent 3f83e1a commit 22b6e58
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 69 deletions.
1 change: 1 addition & 0 deletions nncf/common/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


class Dtype(Enum):
BOOL = "bool"
FLOAT = "float"
INTEGER = "int"

Expand Down
2 changes: 1 addition & 1 deletion nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:

@staticmethod
@abstractmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
"""
Computes the mean of elements across given dimensions of NNCFTensor.
Expand Down
61 changes: 46 additions & 15 deletions nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,17 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
return ONNXNNCFTensor(np.maximum(x1.tensor, x2.tensor))

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
return ONNXNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims))
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
return ONNXNNCFTensor(
np.mean(x.tensor, axis=axis, keepdims=keepdims, dtype=comp_dtype).astype(dtype=out_dtype, copy=False)
)

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
return ONNXNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims))
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
return ONNXNNCFTensor(np.median(t, axis=axis, keepdims=keepdims).astype(dtype=out_dtype, copy=False))

@classmethod
def masked_mean(
Expand All @@ -69,8 +74,12 @@ def masked_mean(
) -> NNCFTensor:
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
return ONNXNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=False).data)
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims, dtype=comp_dtype)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return ONNXNNCFTensor(result.astype(dtype=out_dtype, copy=False))

@classmethod
def masked_median(
Expand All @@ -82,8 +91,21 @@ def masked_median(
) -> NNCFTensor:
if mask is None:
return cls.median(x, axis=axis, keepdims=keepdims)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
return ONNXNNCFTensor(np.ma.median(masked_x, axis=axis, keepdims=keepdims).data)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
masked_x = np.ma.array(t, mask=mask.tensor)
result = np.ma.median(masked_x, axis=axis, keepdims=keepdims)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return ONNXNNCFTensor(result.astype(dtype=out_dtype, copy=False))

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return ONNXNNCFTensor.mean(x, axis=0)
x = np.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return ONNXNNCFCollectorTensorProcessor.mean(ONNXNNCFTensor(t), axis=(0, 2))

@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
Expand Down Expand Up @@ -130,14 +152,6 @@ def percentile(
) -> List[TensorElementsType]:
raise NotImplementedError()

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return ONNXNNCFTensor(np.mean(x.tensor, axis=0))
x = np.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return ONNXNNCFTensor(np.mean(t, axis=(0, 2)))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
return ONNXNNCFTensor(np.transpose(x.tensor, axes))
Expand Down Expand Up @@ -219,3 +233,20 @@ def _register_input(self, x: ONNXNNCFTensor):

def _get_statistics(self) -> ONNXRawTensorStatistic:
return ONNXRawTensorStatistic(self._all_values)


def _get_computing_dtype(dtype: np.dtype) -> Tuple[Optional[np.dtype], Optional[np.dtype]]:
"""
Determines the appropriate dtypes for intermediate computations and the final output,
aiming to prevent overflow while maintaining precision.
:param dtype: The dtype of the processed tensor.
:return:
- comp_dtype: The recommended dtype for intermediate computations to avoid overflow.
If None, no dtype change is necessary for intermediate computations.
- out_dtype: The recommended dtype for the final output, balancing precision and memory usage.
If None, the input dtype is preserved for the output.
"""
if dtype in [np.float32, np.float16]:
return (np.float64, dtype)
return (None, None)
45 changes: 35 additions & 10 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,16 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
return OVNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims))
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
return OVNNCFTensor(
np.mean(x.tensor, axis=axis, keepdims=keepdims, dtype=comp_dtype).astype(dtype=out_dtype, copy=False)
)

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
return OVNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims))
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
return OVNNCFTensor(np.median(t, axis=axis, keepdims=keepdims).astype(dtype=out_dtype, copy=False))

@classmethod
def masked_mean(
Expand All @@ -86,11 +91,12 @@ def masked_mean(
) -> NNCFTensor:
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims)
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims, dtype=comp_dtype)
if isinstance(result, np.ma.MaskedArray):
return OVNNCFTensor(result.data)
return OVNNCFTensor(result)
result = result.data
return OVNNCFTensor(result.astype(dtype=out_dtype, copy=False))

@classmethod
def masked_median(
Expand All @@ -102,19 +108,21 @@ def masked_median(
) -> NNCFTensor:
if mask is None:
return cls.median(x, axis=axis, keepdims=keepdims)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
masked_x = np.ma.array(t, mask=mask.tensor)
result = np.ma.median(masked_x, axis=axis, keepdims=keepdims)
if isinstance(result, np.ma.MaskedArray):
return OVNNCFTensor(result.data)
return OVNNCFTensor(result)
result = result.data
return OVNNCFTensor(result.astype(dtype=out_dtype, copy=False))

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return OVNNCFTensor(np.mean(x.tensor, axis=0))
return OVNNCFCollectorTensorProcessor.mean(x, axis=0)
x = np.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return OVNNCFTensor(np.mean(t, axis=(0, 2)))
return OVNNCFCollectorTensorProcessor.mean(OVNNCFTensor(t), axis=(0, 2))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
Expand Down Expand Up @@ -336,3 +344,20 @@ def get_raw_stat_collector(num_samples, inplace=False):
StatisticsType.QUANTILE: OVQuantileReducer,
StatisticsType.ABS_QUANTILE: OVAbsQuantileReducer,
}


def _get_computing_dtype(dtype: np.dtype) -> Tuple[Optional[np.dtype], Optional[np.dtype]]:
"""
Determines the appropriate dtypes for intermediate computations and the final output,
aiming to prevent overflow while maintaining precision.
:param dtype: The dtype of the processed tensor.
:return:
- comp_dtype: The recommended dtype for intermediate computations to avoid overflow.
If None, no dtype change is necessary for intermediate computations.
- out_dtype: The recommended dtype for the final output, balancing precision and memory usage.
If None, the input dtype is preserved for the output.
"""
if dtype in [np.float32, np.float16]:
return (np.float64, dtype)
return (None, None)
8 changes: 4 additions & 4 deletions nncf/tensorflow/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@ def max(x1: tf.Tensor, x2: tf.Tensor) -> NNCFTensor:
return TFNNCFTensor(tf.math.maximum(x1.tensor, x2.tensor))

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
return TFNNCFTensor(tf.math.reduce_mean(x.tensor, axis=axis, keepdims=keepdims))

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
raise NotImplementedError()

@classmethod
def masked_mean(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
) -> NNCFTensor:
raise NotImplementedError()

@classmethod
def masked_median(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
) -> NNCFTensor:
raise NotImplementedError()

Expand Down
59 changes: 43 additions & 16 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,56 +71,66 @@ def max(cls, *args) -> NNCFTensor:
return cls.reduce_max(stacked, axis=0, keepdims=False)

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
return PTNNCFTensor(x.tensor.mean(dim=axis, keepdim=keepdims))
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
return PTNNCFTensor(
torch.mean(x.tensor, axis=axis, keepdims=keepdims, dtype=comp_dtype).to(dtype=out_dtype, copy=False)
)

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
# See https://github.com/pytorch/pytorch/issues/61582
if not isinstance(axis, int):
device = x.tensor.device
result = torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims))
return PTNNCFTensor(result.type(x.tensor.dtype).to(device))
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
np_tensor = x.tensor.detach().cpu().to(dtype=comp_dtype, copy=False).numpy()
result = torch.tensor(np.median(np_tensor, axis=axis, keepdims=keepdims))
return PTNNCFTensor(result.type(out_dtype).to(device))
return PTNNCFCollectorTensorProcessor.quantile(x, quantile=[0.5], axis=axis, keepdims=keepdims)[0]

@classmethod
def masked_mean(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
) -> NNCFTensor:
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
device = x.tensor.device
masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy())
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
np_tensor = x.tensor.detach().cpu().to(dtype=comp_dtype, copy=False).numpy()
masked_x = np.ma.array(np_tensor, mask=mask.tensor.detach().cpu().numpy())
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return PTNNCFTensor(torch.tensor(result).to(device=device))
return PTNNCFTensor(torch.tensor(result).to(device=device, dtype=out_dtype))

@classmethod
def masked_median(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
) -> NNCFTensor:
# Implemented in numy as torch.masked.median is not implemented yet
# Implemented in numpy as torch.masked.median is not implemented yet
if mask is None:
return cls.median(x, axis=axis, keepdims=keepdims)

device = x.tensor.device
masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy())
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
np_tensor = x.tensor.detach().cpu().to(dtype=comp_dtype, copy=False).numpy()
masked_x = np.ma.array(np_tensor, mask=mask.tensor.detach().cpu().numpy())
result = np.ma.median(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return PTNNCFTensor(torch.tensor(result).to(device=device))
return PTNNCFTensor(torch.tensor(result).to(device=device, dtype=out_dtype))

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return PTNNCFTensor(torch.mean(x.tensor, axis=0))
return PTNNCFCollectorTensorProcessor.mean(x, axis=0)
x = torch.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return PTNNCFTensor(torch.mean(t, axis=(0, 2)))
return PTNNCFCollectorTensorProcessor.mean(PTNNCFTensor(t), axis=(0, 2))

@staticmethod
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return PTNNCFTensor(torch.mean(x.tensor, axis=0, keepdims=True))
return PTNNCFCollectorTensorProcessor.mean(x, axis=0, keepdims=True)

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
Expand Down Expand Up @@ -554,3 +564,20 @@ def get_mean_statistic_collector(
StatisticsType.QUANTILE: PTQuantileReducer,
StatisticsType.ABS_QUANTILE: PTAbsQuantileReducer,
}


def _get_computing_dtype(dtype: torch.dtype) -> Tuple[Optional[torch.dtype], Optional[torch.dtype]]:
"""
Determines the appropriate dtypes for intermediate computations and the final output,
aiming to prevent overflow while maintaining precision.
:param dtype: The dtype of the processed tensor.
:return:
- comp_dtype: The recommended dtype for intermediate computations to avoid overflow.
If None, no dtype change is necessary for intermediate computations.
- out_dtype: The recommended dtype for the final output, balancing precision and memory usage.
If None, the input dtype is preserved for the output.
"""
if dtype in [torch.float32, torch.float16]:
return (torch.float64, dtype)
return (None, None)
Loading

0 comments on commit 22b6e58

Please sign in to comment.