Skip to content

Commit

Permalink
[CPU EP] Add blocked quantization to QuantizeLinear op kernel (#20977)
Browse files Browse the repository at this point in the history
### Description
Add blocked quantization to QuantizeLinear op kernel.

If the quantize axis is not the last axis, block the tensor using 1x128
blocks. Blocks are dispatched to multiple threads for concurrently
processing. Currently only support scalar instructions.

If the quantize axis is the last axis, block the tensor using 1 x
quant_block_size blocks. Blocks are dispatched to multiple threads for
concurrent processing. If output type is int types, call mlas kernel to
use the SIMD instructions in each block.

#### Benchmark data
20 core 2GHz CPU, RelWithDebInfo config, 196 x 4096 tensor, quantize
float to int4x2

Quantize before last axis:
 * single thread, scalar instruction: 31380900 ns
 * 8 thread, scalar instruction: 5098620 ns

Quantize last axis:
 * single thread, scalar instruction: 27927900 ns
 * 8 thread, SIMD instruction: 102261 ns

more thread, SIMD instruction, larger block size helps

### Motivation and Context
ONNX added blocked quantization to QuantizeLinear in optset 21
  • Loading branch information
fajin-corp authored Jun 12, 2024
1 parent 17d5dc5 commit 9be3034
Show file tree
Hide file tree
Showing 4 changed files with 1,957 additions and 34 deletions.
131 changes: 97 additions & 34 deletions onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ class QuantizeLinear final : public OpKernel {
block_size_ = 0;
}

// TODO(adrianlizarraga): Support the block_size attribute added in opset 21.
if (block_size_ != 0) {
ORT_THROW("QuantizeLinear does not yet support the 'block_size' attribute.");
}
ORT_ENFORCE(block_size_ >= 0, "'block_size' must be non-negative.");
}

Status Compute(OpKernelContext* context) const override;
Expand Down Expand Up @@ -700,6 +697,9 @@ void ParQuantizeLinear(const InputType* Input,
#endif
}

/**
* @brief Compute per-tensor or per-axis quantization.
*/
template <typename T, typename InT>
void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output,
int64_t process_block_count, int64_t broadcast_dim, int64_t process_block_size, bool saturate) {
Expand All @@ -714,24 +714,24 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const
}

// Quantizes float32 to INT4 (in-place) using MLAS kernel.
#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \
template <> \
void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \
INT4_TYPE* output, int64_t N, int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \
ORT_UNUSED_PARAMETER(saturate); \
size_t output_index = 0; \
for (size_t n = 0; n < static_cast<size_t>(N); n++) { \
for (size_t bd = 0; bd < static_cast<size_t>(axis_dim_val); bd++) { \
size_t bd_i = bd >> 1; /*bd / 2*/ \
size_t bd_j = bd & 0x1; /*bd % 2*/ \
INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \
QUANT_FUNC(input, output, output_index, output_index + static_cast<size_t>(quant_block_size), \
scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \
input += quant_block_size; \
output_index += static_cast<size_t>(quant_block_size); \
} \
} \
assert(output_index == static_cast<size_t>(N * axis_dim_val * quant_block_size)); \
#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \
template <> \
void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \
INT4_TYPE* output, int64_t M, int64_t K, int64_t N, bool saturate) { \
ORT_UNUSED_PARAMETER(saturate); \
size_t output_index = 0; \
for (size_t m = 0; m < static_cast<size_t>(M); m++) { \
for (size_t bd = 0; bd < static_cast<size_t>(K); bd++) { \
size_t bd_i = bd >> 1; /*bd / 2*/ \
size_t bd_j = bd & 0x1; /*bd % 2*/ \
INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \
QUANT_FUNC(input, output, output_index, output_index + static_cast<size_t>(N), \
scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \
input += N; \
output_index += static_cast<size_t>(N); \
} \
} \
assert(output_index == static_cast<size_t>(M * K * N)); \
}

DEFINE_COMPUTE_LOOP_FP32_TO_INT4(Int4x2, ParQuantizeLinearStdS4)
Expand All @@ -743,24 +743,24 @@ DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4)
#define DEFINE_COMPUTE_LOOP_FP16_TO_INT4(INT4_TYPE) \
template <> \
void ComputeLoop<INT4_TYPE, MLFloat16>(OpKernelContext * ctx, const MLFloat16* input, const MLFloat16* scale, \
const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t N, \
int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \
const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t M, \
int64_t K, int64_t N, bool saturate) { \
ORT_UNUSED_PARAMETER(saturate); \
\
size_t total_size = static_cast<size_t>(N * axis_dim_val * quant_block_size); \
size_t total_size = static_cast<size_t>(M * K * N); \
auto tmp_buf = std::make_unique<INT4_TYPE::UnpackedType[]>(total_size); \
size_t tmp_buf_index = 0; \
\
for (size_t n = 0; n < static_cast<size_t>(N); n++) { \
for (size_t bd = 0; bd < static_cast<size_t>(axis_dim_val); bd++) { \
for (size_t m = 0; m < static_cast<size_t>(M); m++) { \
for (size_t bd = 0; bd < static_cast<size_t>(K); bd++) { \
size_t bd_i = bd >> 1; /*bd / 2*/ \
size_t bd_j = bd & 0x1; /*bd % 2*/ \
INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \
ParQuantizeLinearStd<INT4_TYPE::UnpackedType>(input, tmp_buf.get() + tmp_buf_index, \
static_cast<size_t>(quant_block_size), scale[bd], \
static_cast<size_t>(N), scale[bd], \
zp, ctx->GetOperatorThreadPool()); \
input += quant_block_size; \
tmp_buf_index += static_cast<size_t>(quant_block_size); \
input += N; \
tmp_buf_index += static_cast<size_t>(N); \
} \
} \
\
Expand Down Expand Up @@ -797,12 +797,75 @@ Status QuantizeLinear<T>::Compute(OpKernelContext* ctx) const {
const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data<T>() : nullptr;
T* output = y.MutableData<T>();

constexpr int output_type_group_ =
boost::mp11::mp_contains<TypeList<Int4x2, UInt4x2>, T>::value ? 2
#if !defined(DISABLE_FLOAT8_TYPES)
: boost::mp11::mp_contains<element_type_lists::AllFloat8, T>::value ? 1
#endif
: 0;

if (x.IsDataType<float>()) {
ComputeLoop<T, float>(ctx, x.Data<float>(), y_scale.Data<float>(), zero_point, output,
process_block_count, broadcast_dim, process_block_size, saturate_);
if (block_size_) {
if (process_block_size > 1) {
BlockedQuantizeLinear<float, T, output_type_group_>::opNotLastAxis(
ctx->GetOperatorThreadPool(),
x.Data<float>(),
y_scale.Data<float>(),
zero_point,
output,
static_cast<std::ptrdiff_t>(process_block_count),
static_cast<std::ptrdiff_t>(broadcast_dim),
static_cast<std::ptrdiff_t>(process_block_size),
static_cast<std::ptrdiff_t>(block_size_),
128,
saturate_);
} else {
BlockedQuantizeLinear<float, T, output_type_group_>::opLastAxis(
ctx->GetOperatorThreadPool(),
x.Data<float>(),
y_scale.Data<float>(),
zero_point,
output,
static_cast<std::ptrdiff_t>(process_block_count),
static_cast<std::ptrdiff_t>(broadcast_dim),
static_cast<std::ptrdiff_t>(block_size_),
saturate_);
}
} else {
ComputeLoop<T, float>(ctx, x.Data<float>(), y_scale.Data<float>(), zero_point, output,
process_block_count, broadcast_dim, process_block_size, saturate_);
}
} else if (x.IsDataType<MLFloat16>()) {
ComputeLoop<T, MLFloat16>(ctx, x.Data<MLFloat16>(), y_scale.Data<MLFloat16>(), zero_point, output,
process_block_count, broadcast_dim, process_block_size, saturate_);
if (block_size_) {
if (process_block_size > 1) {
BlockedQuantizeLinear<MLFloat16, T, output_type_group_>::opNotLastAxis(
ctx->GetOperatorThreadPool(),
x.Data<MLFloat16>(),
y_scale.Data<MLFloat16>(),
zero_point,
output,
static_cast<std::ptrdiff_t>(process_block_count),
static_cast<std::ptrdiff_t>(broadcast_dim),
static_cast<std::ptrdiff_t>(process_block_size),
static_cast<std::ptrdiff_t>(block_size_),
128,
saturate_);
} else {
BlockedQuantizeLinear<MLFloat16, T, output_type_group_>::opLastAxis(
ctx->GetOperatorThreadPool(),
x.Data<MLFloat16>(),
y_scale.Data<MLFloat16>(),
zero_point,
output,
static_cast<std::ptrdiff_t>(process_block_count),
static_cast<std::ptrdiff_t>(broadcast_dim),
static_cast<std::ptrdiff_t>(block_size_),
saturate_);
}
} else {
ComputeLoop<T, MLFloat16>(ctx, x.Data<MLFloat16>(), y_scale.Data<MLFloat16>(), zero_point, output,
process_block_count, broadcast_dim, process_block_size, saturate_);
}
} else {
ORT_THROW("Unsupported input type.");
}
Expand Down
Loading

0 comments on commit 9be3034

Please sign in to comment.