Skip to content

Commit

Permalink
code review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eladc-git committed Mar 10, 2024
1 parent fe61768 commit 5e33f22
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,20 @@ def _kl_error_function_wrapper(x: np.ndarray,
n_bins: int = 2048,
n_bits: int = 8) -> np.ndarray:
"""
Compute the error function between a tensor to its quantized version per channel.
The error is computed based on the KL-divergence the distributions have.
Number of bins to use when computing the histogram of the float tensor is passed.
The threshold and number of bits that were used to quantize the tensor are needed to compute the
histograms boundaries and the number of quantized bins.
Computes the error function between a tensor and its quantized version for each channel.
The error is based on the KL-divergence between the distributions.
The function uses a specified number of bins to compute the histogram of the float tensor.
It requires the threshold and number of bits used for quantization to determine the histogram's boundaries and the number of quantized bins.
Args:
x: Float tensor.
range_min: array of min bound on the quantization range.
range_max: array of max bound on the quantization range.
range_min: Array specifying the minimum bound of the quantization range for each channel.
range_max: Array specifying the maximum bound of the quantization range for each channel.
n_bins: Number of bins for the float histogram.
n_bits: Number of bits the quantized tensor was quantized by.
n_bits: Number of bits used for quantization.
Returns:
The KL-divergence of the float histogram and the quantized histogram of the tensors, per channel
An array containing the KL-divergence between the float and quantized histograms of the tensor for each channel.
"""

Expand Down Expand Up @@ -378,13 +377,13 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat
Returns the error function compatible to the provided threshold method,
to be used in the threshold optimization search for tensor quantization.
Args:
quantization_method: Quantization method for threshold selection
quant_error_method: the requested error function type.
p: p-norm to use for the Lp-norm distance.
axis: Axis along which the operator has been computed.
norm: whether to normalize the error function result.
n_bits: Number of bits to quantize the tensor.
signed: signed input
quantization_method: Method used for selecting the quantization threshold.
quant_error_method: Type of error function requested.
p: P-norm to use for calculating the Lp-norm distance.
axis: Axis along which the operation has been performed.
norm: Indicates whether to normalize the result of the error function.
n_bits: Number of bits used to quantize the tensor.
signed: Indicates whether the input is signed.
Returns: a Callable method that calculates the error between a tensor and a quantized tensor.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def _error_function_wrapper(error_function: Callable,
q_tensor: Numpy array with quantized tensor's content.
in_params: Quantization params the tensor is quantized by (used in specific error functions only).
Returns: A array of error values per-channel for the quantized tensor, according to the error function.
Returns: An array of error values for each channel of the quantized tensor, as determined by the specified error function.
"""
return error_function(float_tensor, q_tensor, in_params)

0 comments on commit 5e33f22

Please sign in to comment.