Skip to content

Commit

Permalink
Revert "Fix overflow in collectors" (openvinotoolkit#2384)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximProshin authored Jan 10, 2024
1 parent ae721fd commit c760b3d
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 173 deletions.
1 change: 0 additions & 1 deletion nncf/common/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


class Dtype(Enum):
BOOL = "bool"
FLOAT = "float"
INTEGER = "int"

Expand Down
2 changes: 1 addition & 1 deletion nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:

@staticmethod
@abstractmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
"""
Computes the mean of elements across given dimensions of NNCFTensor.
Expand Down
61 changes: 15 additions & 46 deletions nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,12 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
return ONNXNNCFTensor(np.maximum(x1.tensor, x2.tensor))

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
return ONNXNNCFTensor(
np.mean(x.tensor, axis=axis, keepdims=keepdims, dtype=comp_dtype).astype(dtype=out_dtype, copy=False)
)
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
return ONNXNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims))

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
return ONNXNNCFTensor(np.median(t, axis=axis, keepdims=keepdims).astype(dtype=out_dtype, copy=False))
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
return ONNXNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims))

@classmethod
def masked_mean(
Expand All @@ -74,12 +69,8 @@ def masked_mean(
) -> NNCFTensor:
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims, dtype=comp_dtype)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return ONNXNNCFTensor(result.astype(dtype=out_dtype, copy=False))
return ONNXNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=False).data)

@classmethod
def masked_median(
Expand All @@ -91,21 +82,8 @@ def masked_median(
) -> NNCFTensor:
if mask is None:
return cls.median(x, axis=axis, keepdims=keepdims)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
masked_x = np.ma.array(t, mask=mask.tensor)
result = np.ma.median(masked_x, axis=axis, keepdims=keepdims)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return ONNXNNCFTensor(result.astype(dtype=out_dtype, copy=False))

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return ONNXNNCFTensor.mean(x, axis=0)
x = np.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return ONNXNNCFCollectorTensorProcessor.mean(ONNXNNCFTensor(t), axis=(0, 2))
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
return ONNXNNCFTensor(np.ma.median(masked_x, axis=axis, keepdims=keepdims).data)

@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
Expand Down Expand Up @@ -152,6 +130,14 @@ def percentile(
) -> List[TensorElementsType]:
raise NotImplementedError()

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return ONNXNNCFTensor(np.mean(x.tensor, axis=0))
x = np.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return ONNXNNCFTensor(np.mean(t, axis=(0, 2)))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
return ONNXNNCFTensor(np.transpose(x.tensor, axes))
Expand Down Expand Up @@ -233,20 +219,3 @@ def _register_input(self, x: ONNXNNCFTensor):

def _get_statistics(self) -> ONNXRawTensorStatistic:
return ONNXRawTensorStatistic(self._all_values)


def _get_computing_dtype(dtype: np.dtype) -> Tuple[Optional[np.dtype], Optional[np.dtype]]:
"""
Determines the appropriate dtypes for intermediate computations and the final output,
aiming to prevent overflow while maintaining precision.
:param dtype: The dtype of the processed tensor.
:return:
- comp_dtype: The recommended dtype for intermediate computations to avoid overflow.
If None, no dtype change is necessary for intermediate computations.
- out_dtype: The recommended dtype for the final output, balancing precision and memory usage.
If None, the input dtype is preserved for the output.
"""
if dtype in [np.float32, np.float16]:
return (np.float64, dtype)
return (None, None)
45 changes: 10 additions & 35 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,11 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
return OVNNCFTensor(
np.mean(x.tensor, axis=axis, keepdims=keepdims, dtype=comp_dtype).astype(dtype=out_dtype, copy=False)
)
return OVNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims))

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
return OVNNCFTensor(np.median(t, axis=axis, keepdims=keepdims).astype(dtype=out_dtype, copy=False))
return OVNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims))

@classmethod
def masked_mean(
Expand All @@ -91,12 +86,11 @@ def masked_mean(
) -> NNCFTensor:
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims, dtype=comp_dtype)
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return OVNNCFTensor(result.astype(dtype=out_dtype, copy=False))
return OVNNCFTensor(result.data)
return OVNNCFTensor(result)

@classmethod
def masked_median(
Expand All @@ -108,21 +102,19 @@ def masked_median(
) -> NNCFTensor:
if mask is None:
return cls.median(x, axis=axis, keepdims=keepdims)
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
t = x.tensor.astype(dtype=comp_dtype, copy=False)
masked_x = np.ma.array(t, mask=mask.tensor)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
result = np.ma.median(masked_x, axis=axis, keepdims=keepdims)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return OVNNCFTensor(result.astype(dtype=out_dtype, copy=False))
return OVNNCFTensor(result.data)
return OVNNCFTensor(result)

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return OVNNCFCollectorTensorProcessor.mean(x, axis=0)
return OVNNCFTensor(np.mean(x.tensor, axis=0))
x = np.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return OVNNCFCollectorTensorProcessor.mean(OVNNCFTensor(t), axis=(0, 2))
return OVNNCFTensor(np.mean(t, axis=(0, 2)))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
Expand Down Expand Up @@ -344,20 +336,3 @@ def get_raw_stat_collector(num_samples, inplace=False):
StatisticsType.QUANTILE: OVQuantileReducer,
StatisticsType.ABS_QUANTILE: OVAbsQuantileReducer,
}


def _get_computing_dtype(dtype: np.dtype) -> Tuple[Optional[np.dtype], Optional[np.dtype]]:
"""
Determines the appropriate dtypes for intermediate computations and the final output,
aiming to prevent overflow while maintaining precision.
:param dtype: The dtype of the processed tensor.
:return:
- comp_dtype: The recommended dtype for intermediate computations to avoid overflow.
If None, no dtype change is necessary for intermediate computations.
- out_dtype: The recommended dtype for the final output, balancing precision and memory usage.
If None, the input dtype is preserved for the output.
"""
if dtype in [np.float32, np.float16]:
return (np.float64, dtype)
return (None, None)
8 changes: 4 additions & 4 deletions nncf/tensorflow/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@ def max(x1: tf.Tensor, x2: tf.Tensor) -> NNCFTensor:
return TFNNCFTensor(tf.math.maximum(x1.tensor, x2.tensor))

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
return TFNNCFTensor(tf.math.reduce_mean(x.tensor, axis=axis, keepdims=keepdims))

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
raise NotImplementedError()

@classmethod
def masked_mean(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
raise NotImplementedError()

@classmethod
def masked_median(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
raise NotImplementedError()

Expand Down
59 changes: 16 additions & 43 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,66 +71,56 @@ def max(cls, *args) -> NNCFTensor:
return cls.reduce_max(stacked, axis=0, keepdims=False)

@staticmethod
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
return PTNNCFTensor(
torch.mean(x.tensor, axis=axis, keepdims=keepdims, dtype=comp_dtype).to(dtype=out_dtype, copy=False)
)
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
return PTNNCFTensor(x.tensor.mean(dim=axis, keepdim=keepdims))

@staticmethod
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
# See https://github.com/pytorch/pytorch/issues/61582
if not isinstance(axis, int):
device = x.tensor.device
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
np_tensor = x.tensor.detach().cpu().to(dtype=comp_dtype, copy=False).numpy()
result = torch.tensor(np.median(np_tensor, axis=axis, keepdims=keepdims))
return PTNNCFTensor(result.type(out_dtype).to(device))
result = torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims))
return PTNNCFTensor(result.type(x.tensor.dtype).to(device))
return PTNNCFCollectorTensorProcessor.quantile(x, quantile=[0.5], axis=axis, keepdims=keepdims)[0]

@classmethod
def masked_mean(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
device = x.tensor.device
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
np_tensor = x.tensor.detach().cpu().to(dtype=comp_dtype, copy=False).numpy()
masked_x = np.ma.array(np_tensor, mask=mask.tensor.detach().cpu().numpy())
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims)
masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy())
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return PTNNCFTensor(torch.tensor(result).to(device=device, dtype=out_dtype))
return PTNNCFTensor(torch.tensor(result).to(device=device))

@classmethod
def masked_median(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims: bool = False
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
# Implemented in numpy as torch.masked.median is not implemented yet
# Implemented in numy as torch.masked.median is not implemented yet
if mask is None:
return cls.median(x, axis=axis, keepdims=keepdims)

device = x.tensor.device
comp_dtype, out_dtype = _get_computing_dtype(x.tensor.dtype)
np_tensor = x.tensor.detach().cpu().to(dtype=comp_dtype, copy=False).numpy()
masked_x = np.ma.array(np_tensor, mask=mask.tensor.detach().cpu().numpy())
masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy())
result = np.ma.median(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
if isinstance(result, np.ma.MaskedArray):
result = result.data
return PTNNCFTensor(torch.tensor(result).to(device=device, dtype=out_dtype))
return PTNNCFTensor(torch.tensor(result).to(device=device))

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
if len(x.shape) < 3:
return PTNNCFCollectorTensorProcessor.mean(x, axis=0)
return PTNNCFTensor(torch.mean(x.tensor, axis=0))
x = torch.moveaxis(x.tensor, axis, 1)
t = x.reshape(x.shape[0], x.shape[1], -1)
return PTNNCFCollectorTensorProcessor.mean(PTNNCFTensor(t), axis=(0, 2))
return PTNNCFTensor(torch.mean(t, axis=(0, 2)))

@staticmethod
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return PTNNCFCollectorTensorProcessor.mean(x, axis=0, keepdims=True)
return PTNNCFTensor(torch.mean(x.tensor, axis=0, keepdims=True))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
Expand Down Expand Up @@ -564,20 +554,3 @@ def get_mean_statistic_collector(
StatisticsType.QUANTILE: PTQuantileReducer,
StatisticsType.ABS_QUANTILE: PTAbsQuantileReducer,
}


def _get_computing_dtype(dtype: torch.dtype) -> Tuple[Optional[torch.dtype], Optional[torch.dtype]]:
"""
Determines the appropriate dtypes for intermediate computations and the final output,
aiming to prevent overflow while maintaining precision.
:param dtype: The dtype of the processed tensor.
:return:
- comp_dtype: The recommended dtype for intermediate computations to avoid overflow.
If None, no dtype change is necessary for intermediate computations.
- out_dtype: The recommended dtype for the final output, balancing precision and memory usage.
If None, the input dtype is preserved for the output.
"""
if dtype in [torch.float32, torch.float16]:
return (torch.float64, dtype)
return (None, None)
Loading

0 comments on commit c760b3d

Please sign in to comment.