diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 91e21b3690b27..3d3e831a12d13 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -50,10 +50,7 @@ class QuantizeLinear final : public OpKernel { block_size_ = 0; } - // TODO(adrianlizarraga): Support the block_size attribute added in opset 21. - if (block_size_ != 0) { - ORT_THROW("QuantizeLinear does not yet support the 'block_size' attribute."); - } + ORT_ENFORCE(block_size_ >= 0, "'block_size' must be non-negative."); } Status Compute(OpKernelContext* context) const override; @@ -700,6 +697,9 @@ void ParQuantizeLinear(const InputType* Input, #endif } +/** + * @brief Compute per-tensor or per-axis quantization. + */ template void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t process_block_count, int64_t broadcast_dim, int64_t process_block_size, bool saturate) { @@ -714,24 +714,24 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const } // Quantizes float32 to INT4 (in-place) using MLAS kernel. -#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \ - template <> \ - void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \ - INT4_TYPE* output, int64_t N, int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \ - ORT_UNUSED_PARAMETER(saturate); \ - size_t output_index = 0; \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ - QUANT_FUNC(input, output, output_index, output_index + static_cast(quant_block_size), \ - scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ - input += quant_block_size; \ - output_index += static_cast(quant_block_size); \ - } \ - } \ - assert(output_index == static_cast(N * axis_dim_val * quant_block_size)); \ +#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \ + template <> \ + void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \ + INT4_TYPE* output, int64_t M, int64_t K, int64_t N, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + size_t output_index = 0; \ + for (size_t m = 0; m < static_cast(M); m++) { \ + for (size_t bd = 0; bd < static_cast(K); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + INT4_TYPE::UnpackedType zp = zero_point ? 
zero_point[bd_i].GetElem(bd_j) : 0; \ + QUANT_FUNC(input, output, output_index, output_index + static_cast(N), \ + scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ + input += N; \ + output_index += static_cast(N); \ + } \ + } \ + assert(output_index == static_cast(M * K * N)); \ } DEFINE_COMPUTE_LOOP_FP32_TO_INT4(Int4x2, ParQuantizeLinearStdS4) @@ -743,24 +743,24 @@ DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4) #define DEFINE_COMPUTE_LOOP_FP16_TO_INT4(INT4_TYPE) \ template <> \ void ComputeLoop(OpKernelContext * ctx, const MLFloat16* input, const MLFloat16* scale, \ - const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t N, \ - int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \ + const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t M, \ + int64_t K, int64_t N, bool saturate) { \ ORT_UNUSED_PARAMETER(saturate); \ \ - size_t total_size = static_cast(N * axis_dim_val * quant_block_size); \ + size_t total_size = static_cast(M * K * N); \ auto tmp_buf = std::make_unique(total_size); \ size_t tmp_buf_index = 0; \ \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + for (size_t m = 0; m < static_cast(M); m++) { \ + for (size_t bd = 0; bd < static_cast(K); bd++) { \ size_t bd_i = bd >> 1; /*bd / 2*/ \ size_t bd_j = bd & 0x1; /*bd % 2*/ \ INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ - static_cast(quant_block_size), scale[bd], \ + static_cast(N), scale[bd], \ zp, ctx->GetOperatorThreadPool()); \ - input += quant_block_size; \ - tmp_buf_index += static_cast(quant_block_size); \ + input += N; \ + tmp_buf_index += static_cast(N); \ } \ } \ \ @@ -797,12 +797,75 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data() : nullptr; T* output = y.MutableData(); + constexpr int output_type_group_ = + boost::mp11::mp_contains, T>::value ? 2 +#if !defined(DISABLE_FLOAT8_TYPES) + : boost::mp11::mp_contains::value ? 
1 +#endif + : 0; + if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, - process_block_count, broadcast_dim, process_block_size, saturate_); + if (block_size_) { + if (process_block_size > 1) { + BlockedQuantizeLinear::opNotLastAxis( + ctx->GetOperatorThreadPool(), + x.Data(), + y_scale.Data(), + zero_point, + output, + static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + static_cast(block_size_), + 128, + saturate_); + } else { + BlockedQuantizeLinear::opLastAxis( + ctx->GetOperatorThreadPool(), + x.Data(), + y_scale.Data(), + zero_point, + output, + static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(block_size_), + saturate_); + } + } else { + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, + process_block_count, broadcast_dim, process_block_size, saturate_); + } } else if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, - process_block_count, broadcast_dim, process_block_size, saturate_); + if (block_size_) { + if (process_block_size > 1) { + BlockedQuantizeLinear::opNotLastAxis( + ctx->GetOperatorThreadPool(), + x.Data(), + y_scale.Data(), + zero_point, + output, + static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + static_cast(block_size_), + 128, + saturate_); + } else { + BlockedQuantizeLinear::opLastAxis( + ctx->GetOperatorThreadPool(), + x.Data(), + y_scale.Data(), + zero_point, + output, + static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(block_size_), + saturate_); + } + } else { + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, + process_block_count, broadcast_dim, process_block_size, saturate_); + } } else { ORT_THROW("Unsupported input type."); } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 235ecfde0954a..fcd1db31f95ef 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -10,6 +10,8 @@ #include "core/framework/float8.h" #include "core/framework/int4.h" #include +#include +#include namespace onnxruntime { @@ -305,4 +307,612 @@ ParQuantizeLinearSat(const MLFloat16* Input, #endif +/** + * @brief compute blocked quantization + * + * @tparam TIn + * @tparam TOut + * @tparam output_type_group 0: int other than int4. + * 1: float8 + * 2: int4 + * @method op0 baseline implementation. Single thread. Scalar instructions. + * @method op1 multi-threading implementation. Vector instructions. + */ +template +struct BlockedQuantizeLinear { + /** + * @brief Compute blocked quantization using multi-threading and vector instructions. + * Quantize axis is not the last axis. Block the last axis using thread_block_size. + * N is usually large. Within a block, scale's index increments along with output's index. 
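+   * For example (illustrative values only): with M = 2, K = 5, N = 3 and quant_block_size = 2, the
+   * scale tensor has shape [2, 3, 3] (the quantize axis is split into ceil(K / quant_block_size)
+   * blocks), and element input[m, k, n] is quantized with scale[m, k / quant_block_size, n].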
+ * + * @param thread_pool thread pool + * @param input input tensor + * @param scale scale tensor + * @param zero_point zero point tensor + * @param output output tensor + * @param M total size of dimensions before quantize axis + * @param K size of dimension on quantize axis + * @param N total size of dimensions after quantize axis + * @param quant_block_size quantization block size + * @param thread_block_size task block size + * @param saturate used by float8 + */ + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate); + + /** + * @brief Compute blocked quantization using multi-threading and vector instructions. + * Quantize axis is the last axis. Block along quantize axis using quant_block_size + * as thread_block_size. quant_block_size is usually 2's power between 16 and 256. + * Within a block, scale index does not change. + * + * @param thread_pool thread pool + * @param input input tensor + * @param scale scale tensor + * @param zero_point zero point tensor + * @param output output tensor + * @param M total size of dimensions before quantize axis + * @param K size of dimension on quantize axis + * @param quant_block_size quantization block size + * @param saturate used by float8 + */ + static void opLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate); +}; + +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(std::numeric_limits::lowest()); + constexpr auto high = static_cast(std::numeric_limits::max()); + const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; + const auto num_thread_block = M * K * num_thread_block_N; + const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(float) * 2), + static_cast(thread_block_size * sizeof(TOut)), + static_cast(thread_block_size) * 2.0}; + auto KN = K * N; + auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; + const auto num_thread_block_KN = K * num_thread_block_N; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; + auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; + auto output_idx = m * KN + k * N + n; + auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; + auto quant_param_idx_t = quant_param_idx + n; + + for (; begin < end; ++begin) { + auto n_end = std::min(N, n + thread_block_size); + // TODO(fajin): 1> use SIMD, 2> set block to quant_block_size * thread_block_size + for (; n < n_end; ++n, ++output_idx, ++quant_param_idx_t) { + // TODO(fajin): perf difference + auto zp = zero_point ? 
static_cast(zero_point[quant_param_idx_t]) : 0; + auto sc = scale[quant_param_idx_t]; + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx] / sc)) + zp, low, high); + output[output_idx] = static_cast(v); + } + + // A simpler approach is to calculate m, k, n in every block. + // This approach tries to reduce division and modulo in the block loop. Not benchmarked. + if (n == N) { + n = 0; + ++k; + if (k == K) { + k = 0; + quant_param_idx += N; + } else if (k % quant_block_size == 0) { + quant_param_idx += N; + } + + quant_param_idx_t = quant_param_idx; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + // quant block size is used as thread block size + const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; + const auto num_thread_block = num_thread_block_K * M; + const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(float)), + static_cast(quant_block_size * sizeof(TOut)), + static_cast(quant_block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; + auto output_idx = m * K + k; + + for (; begin < end; ++begin) { + // each thread block is also a quantization block + auto zp = zero_point ? zero_point[begin] : static_cast(0); + auto sc = scale[begin]; + size_t output_size = std::min(K - k, quant_block_size); + MlasQuantizeLinear(input + output_idx, output + output_idx, output_size, sc, zp); + output_idx += output_size; + k = output_idx % K; + } + }); + } +}; + +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(std::numeric_limits::lowest()); + constexpr auto high = static_cast(std::numeric_limits::max()); + const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; + const auto num_thread_block = M * K * num_thread_block_N; + const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(MLFloat16) * 2), + static_cast(thread_block_size * sizeof(TOut)), + static_cast(thread_block_size) * 2.0}; + auto KN = K * N; + auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; + const auto num_thread_block_KN = K * num_thread_block_N; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; + auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; + auto output_idx = m * KN + k * N + n; + auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; + auto quant_param_idx_t = quant_param_idx + n; + + for (; begin < end; ++begin) { + auto n_end = std::min(N, n + thread_block_size); + // TODO(fajin): 1> use SIMD, 2> set block to quant_block_size * thread_block_size + for (; n < n_end; ++n, 
++output_idx, ++quant_param_idx_t) { + // TODO(fajin): perf difference + auto zp = zero_point ? static_cast(zero_point[quant_param_idx_t]) : 0; + auto sc = scale[quant_param_idx_t].ToFloat(); + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, + low, high); + output[output_idx] = static_cast(v); + } + + if (n == N) { + n = 0; + ++k; + if (k == K) { + k = 0; + quant_param_idx += N; + } else if (k % quant_block_size == 0) { + quant_param_idx += N; + } + + quant_param_idx_t = quant_param_idx; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(std::numeric_limits::lowest()); + constexpr auto high = static_cast(std::numeric_limits::max()); + // quant block size is used as thread block size + const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; + const auto num_thread_block = num_thread_block_K * M; + const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(MLFloat16)), + static_cast(quant_block_size * sizeof(TOut)), + static_cast(quant_block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; + auto output_idx = m * K + k; + + for (; begin < end; ++begin) { + // each thread block is also a quantization block + auto zp = zero_point ? static_cast(zero_point[begin]) : 0; + auto sc = scale[begin].ToFloat(); + auto output_idx_end = std::min(K - k, quant_block_size) + output_idx; + for (; output_idx < output_idx_end; ++output_idx) { + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, + low, high); + output[output_idx] = static_cast(v); + } + k = output_idx % K; + } + }); + } +}; + +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; + const auto num_thread_block = M * K * num_thread_block_N; + const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(float) * 2), + static_cast(thread_block_size * sizeof(typename TOut::UnpackedType)), + static_cast(thread_block_size) * 2.0}; + auto KN = K * N; + auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; + const auto num_thread_block_KN = K * num_thread_block_N; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; + auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; + auto output_idx = m * KN + k * N + n; + auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; + auto quant_param_idx_t = quant_param_idx + n; + + 
for (; begin < end; ++begin) { + auto n_end = std::min(N, n + thread_block_size); + // TODO(fajin): 1> use SIMD, 2> set block to quant_block_size * thread_block_size + // TODO(fajin): process 2 elements at a time + for (; n < n_end; ++n, ++output_idx, ++quant_param_idx_t) { + // TODO(fajin): perf difference + auto zp = zero_point + ? static_cast(zero_point[quant_param_idx_t >> 1].GetElem(quant_param_idx_t & 1)) + : 0; + auto sc = scale[quant_param_idx_t]; + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx] / sc)) + zp, low, high); + output[output_idx >> 1].SetElem(output_idx & 1, static_cast(v)); + } + + if (n == N) { + n = 0; + ++k; + if (k == K) { + k = 0; + quant_param_idx += N; + } else if (k % quant_block_size == 0) { + quant_param_idx += N; + } + + quant_param_idx_t = quant_param_idx; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + // quant block size is used as thread block size + const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; + const auto num_thread_block = num_thread_block_K * M; + const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(float)), + static_cast(quant_block_size * sizeof(typename TOut ::UnpackedType)), + static_cast(quant_block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; + auto output_idx = m * K + k; + + for (; begin < end; ++begin) { + auto zp = zero_point ? 
static_cast(zero_point[begin >> 1].GetElem(begin & 1)) : 0; + auto sc = scale[begin]; + size_t output_idx_end = std::min(K - k, quant_block_size) + output_idx; + size_t out_start = output_idx, out_end = output_idx_end; + + if (out_start & 1) { + auto v = std::clamp(static_cast(std::nearbyint(input[out_start] / sc)) + zp, low, high); + output[out_start >> 1].SetElem(1, static_cast(v)); + ++out_start; + } + + if (out_end & 1) { + --out_end; + auto v = std::clamp(static_cast(std::nearbyint(input[out_end] / sc)) + zp, low, high); + output[out_end >> 1].SetElem(0, static_cast(v)); + } + + if constexpr (std::is_same::value) { + MlasQuantizeLinearS4(input + out_start, reinterpret_cast(&(output[out_start >> 1])), + out_end - out_start, sc, static_cast(zp)); + } else { + MlasQuantizeLinearU4(input + out_start, reinterpret_cast(&(output[out_start >> 1])), + out_end - out_start, sc, static_cast(zp)); + } + + output_idx = output_idx_end; + k = output_idx % K; + } + }); + } +}; + +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; + const auto num_thread_block = M * K * num_thread_block_N; + const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(MLFloat16) * 2), + static_cast(thread_block_size * sizeof(typename TOut::UnpackedType)), + static_cast(thread_block_size) * 2.0}; + auto KN = K * N; + auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; + const auto num_thread_block_KN = K * num_thread_block_N; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; + auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; + auto output_idx = m * KN + k * N + n; + auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; + auto quant_param_idx_t = quant_param_idx + n; + + for (; begin < end; ++begin) { + auto n_end = std::min(N, n + thread_block_size); + // TODO(fajin): 1> use SIMD, 2> set block to quant_block_size * thread_block_size + // TODO(fajin): process 2 elements at a time + for (; n < n_end; ++n, ++output_idx, ++quant_param_idx_t) { + // TODO(fajin): perf difference + auto zp = zero_point + ? 
static_cast(zero_point[quant_param_idx_t >> 1].GetElem(quant_param_idx_t & 1)) + : 0; + auto sc = scale[quant_param_idx_t].ToFloat(); + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, + low, high); + output[output_idx >> 1].SetElem(output_idx & 1, static_cast(v)); + } + + if (n == N) { + n = 0; + ++k; + if (k == K) { + k = 0; + quant_param_idx += N; + } else if (k % quant_block_size == 0) { + quant_param_idx += N; + } + + quant_param_idx_t = quant_param_idx; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + // quant block size is used as thread block size + const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; + const auto num_thread_block = num_thread_block_K * M; + const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(MLFloat16)), + static_cast(quant_block_size * sizeof(typename TOut::UnpackedType)), + static_cast(quant_block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; + auto output_idx = m * K + k; + + for (; begin < end; ++begin) { + // each thread block is also a quantization block + auto zp = zero_point ? static_cast(zero_point[begin >> 1].GetElem(begin & 1)) : 0; + auto sc = scale[begin].ToFloat(); + auto output_idx_end = std::min(K - k, quant_block_size) + output_idx; + for (; output_idx < output_idx_end; ++output_idx) { + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, + low, high); + output[output_idx >> 1].SetElem(output_idx & 1, static_cast(v)); + } + + k = output_idx % K; + } + }); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(zero_point); + const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; + const auto num_thread_block = M * K * num_thread_block_N; + const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(float) * 2), + static_cast(thread_block_size * sizeof(uint8_t)), + static_cast(thread_block_size) * 2.0}; + auto KN = K * N; + auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; + const auto num_thread_block_KN = K * num_thread_block_N; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; + auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; + auto output_idx = m * KN + k * N + n; + auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; + auto quant_param_idx_t = quant_param_idx + n; + + for (; begin < end; ++begin) { + auto n_end = std::min(N, n + 
thread_block_size); + for (; n < n_end; ++n, ++output_idx, ++quant_param_idx_t) { + output[output_idx] = TOut(input[output_idx] / scale[quant_param_idx_t], saturate); + } + + if (n == N) { + n = 0; + ++k; + if (k == K) { + k = 0; + quant_param_idx += N; + } else if (k % quant_block_size == 0) { + quant_param_idx += N; + } + + quant_param_idx_t = quant_param_idx; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(zero_point); + // quant block size is used as thread block size + const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; + const auto num_thread_block = num_thread_block_K * M; + const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(float)), + static_cast(quant_block_size * sizeof(uint8_t)), + static_cast(quant_block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; + auto output_idx = m * K + k; + + for (; begin < end; ++begin) { + auto sc = scale[begin]; + auto output_idx_end = std::min(K - k, quant_block_size) + output_idx; + for (; output_idx < output_idx_end; ++output_idx) { + output[output_idx] = TOut(input[output_idx] / sc, saturate); + } + k = output_idx % K; + } + }); + } +}; + +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(zero_point); + const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; + const auto num_thread_block = M * K * num_thread_block_N; + const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(MLFloat16) * 2), + static_cast(thread_block_size * sizeof(uint8_t)), + static_cast(thread_block_size) * 2.0}; + auto KN = K * N; + auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; + const auto num_thread_block_KN = K * num_thread_block_N; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; + auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; + auto output_idx = m * KN + k * N + n; + auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; + auto quant_param_idx_t = quant_param_idx + n; + + for (; begin < end; ++begin) { + auto n_end = std::min(N, n + thread_block_size); + for (; n < n_end; ++n, ++output_idx, ++quant_param_idx_t) { + output[output_idx] = TOut(input[output_idx].ToFloat() / scale[quant_param_idx_t].ToFloat(), saturate); + } + + if (n == N) { + n = 0; + ++k; + if (k == K) { + k = 0; + quant_param_idx += N; + } else if (k % quant_block_size == 0) { + quant_param_idx += N; + } + + quant_param_idx_t = quant_param_idx; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, 
std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(zero_point); + const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; + const auto num_thread_block = num_thread_block_K * M; + const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(MLFloat16)), + static_cast(quant_block_size * sizeof(uint8_t)), + static_cast(quant_block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; + auto output_idx = m * K + k; + + for (; begin < end; ++begin) { + auto sc = scale[begin].ToFloat(); + auto output_idx_end = std::min(K - k, quant_block_size) + output_idx; + for (; output_idx < output_idx_end; ++output_idx) { + output[output_idx] = TOut(input[output_idx].ToFloat() / sc, saturate); + } + k = output_idx % K; + } + }); + } +}; + +#endif + } // namespace onnxruntime diff --git a/onnxruntime/test/onnx/microbenchmark/quantize.cc b/onnxruntime/test/onnx/microbenchmark/quantize.cc index d61c9db68d23c..fda4324c0e83d 100644 --- a/onnxruntime/test/onnx/microbenchmark/quantize.cc +++ b/onnxruntime/test/onnx/microbenchmark/quantize.cc @@ -3,6 +3,7 @@ #include #include "core/util/qmath.h" #include "core/util/thread_utils.h" +#include "core/framework/int4.h" static void BenchSize(benchmark::internal::Benchmark* b) { for (int size : {80000, 160000, 320000, 640000, 1280000}) { @@ -77,3 +78,105 @@ BENCHMARK(BM_Quantize) ->UseRealTime() ->Unit(benchmark::TimeUnit::kNanosecond) ->Apply(BenchSize); + +static void BM_BlockedQuantize_NotLastAxis(benchmark::State& state) { + using Int4 = onnxruntime::Int4x2; + using UnpackedType = Int4::UnpackedType; + const std::ptrdiff_t M[] = {96, 192, 192}; + const std::ptrdiff_t N[] = {2048, 2048, 4096}; + const int64_t size_idx = state.range(0); + const int64_t threads = state.range(1); + const int64_t block_size = state.range(2); + size_t batch_size = M[size_idx] * N[size_idx]; + size_t quant_block_size = 64; + size_t scale_size = batch_size / quant_block_size; + + float* a_data = GenerateArrayWithRandomValue(batch_size, -16, 14); + size_t a_quant_size = sizeof(Int4::UnpackedType) * Int4::CalcNumInt4Pairs(batch_size); + float* scale = GenerateArrayWithRandomValue(scale_size, 1.95f, 2.33f); + UnpackedType* zero_point = GenerateArrayWithRandomValue(Int4::CalcNumInt4Pairs(scale_size), -1, 1); + UnpackedType* a_data_quant = static_cast(aligned_alloc(a_quant_size, 64)); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(threads); + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + for (auto _ : state) { + benchmark::DoNotOptimize(a_data_quant); + onnxruntime::BlockedQuantizeLinear::opNotLastAxis( + tp.get(), a_data, scale, reinterpret_cast(zero_point), reinterpret_cast(a_data_quant), + 1, M[size_idx], N[size_idx], static_cast(quant_block_size), + static_cast(block_size), true); + benchmark::ClobberMemory(); + } + aligned_free(a_data_quant); + aligned_free(a_data); + aligned_free(scale); + aligned_free(zero_point); +} + +static void BM_BlockedQuantize_LastAxis(benchmark::State& state) { + using Int4 = onnxruntime::Int4x2; + using UnpackedType = Int4::UnpackedType; + const std::ptrdiff_t M[] = {96, 192, 192}; + const std::ptrdiff_t 
N[] = {2048, 2048, 4096}; + const int64_t size_idx = state.range(0); + const int64_t threads = state.range(1); + const int64_t quant_block_size = state.range(2); + size_t batch_size = M[size_idx] * N[size_idx]; + size_t scale_size = batch_size / quant_block_size; + + float* a_data = GenerateArrayWithRandomValue(batch_size, -16, 14); + size_t a_quant_size = sizeof(Int4::UnpackedType) * Int4::CalcNumInt4Pairs(batch_size); + float* scale = GenerateArrayWithRandomValue(scale_size, 1.95f, 2.33f); + UnpackedType* zero_point = GenerateArrayWithRandomValue(Int4::CalcNumInt4Pairs(scale_size), -1, 1); + UnpackedType* a_data_quant = static_cast(aligned_alloc(a_quant_size, 64)); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(threads); + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + for (auto _ : state) { + benchmark::DoNotOptimize(a_data_quant); + onnxruntime::BlockedQuantizeLinear::opLastAxis( + tp.get(), a_data, scale, reinterpret_cast(zero_point), reinterpret_cast(a_data_quant), + M[size_idx], N[size_idx], static_cast(quant_block_size), true); + benchmark::ClobberMemory(); + } + aligned_free(a_data_quant); + aligned_free(a_data); + aligned_free(scale); + aligned_free(zero_point); +} + +BENCHMARK(BM_BlockedQuantize_NotLastAxis) + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kNanosecond) + ->Apply([](benchmark::internal::Benchmark* b) { + for (int size_idx : {0, 1, 2}) { + for (int thread : {2, 4, 8}) { + for (int block_size : {64, 128}) { + b->Args({size_idx, thread, block_size}); + } + } + } + }); + +BENCHMARK(BM_BlockedQuantize_LastAxis) + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kNanosecond) + ->Apply([](benchmark::internal::Benchmark* b) { + for (int size_idx : {0, 1, 2}) { + for (int thread : {2, 4, 8}) { + for (int quant_block_size : {16, 64, 256}) { + b->Args({size_idx, thread, quant_block_size}); + } + } + } + }); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 054dcfc75b92e..386bd7d5f7311 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -1549,5 +1549,1152 @@ TEST(DequantizeLinearOp21BlockedTest, Float8_NoZeroPoint_LastAxis) { #endif } // namespace blocked_dequantization +namespace blocked_quantization { + +template +void QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(int64_t block_size, + int64_t scale_block_count, + int64_t zero_point_block_count) { + OpTester test("QuantizeLinear", 21); + std::vector dims{2, 4}; + std::vector y, x_zero_point; + std::vector x, x_scale; + SessionOptions so; + std::vector log_msgs; // redirect error messages + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, + const char* logid, const char* code_location, const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_logid = "QuantizeLinear"; + so.use_per_session_threads = false; + so.session_log_verbosity_level = 1; + so.graph_optimization_level = 
TransformerLevel::Default; + + for (int64_t i = 0, n = 2 * zero_point_block_count; i < n; ++i) x_zero_point.push_back(0); + for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tin(2.0f)); + for (int i = 0; i < 8; ++i) { + x.push_back(Tin(static_cast(i) * 2.0f)); + y.push_back(i); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", 1); + test.AddAttribute("block_size", block_size); + test.AddInput("scale", {2, scale_block_count}, x_scale); + test.AddInput("zero_point", {2, zero_point_block_count}, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(so, OpTester::ExpectResult::kExpectFailure, "", {}, nullptr, &eps); +} + +template +void QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(int64_t block_size, + int64_t scale_block_count, + int64_t zero_point_block_count) { + OpTester test("QuantizeLinear", 21); + std::vector dims{2, 4}; + std::vector x_zero_point, y; + std::vector x, x_scale; + SessionOptions so; + std::vector log_msgs; // redirect error messages + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, + const char* logid, const char* code_location, const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_logid = "QuantizeLinear"; + so.use_per_session_threads = false; + so.session_log_verbosity_level = 1; + so.graph_optimization_level = TransformerLevel::Default; + + for (int64_t i = 0, n = zero_point_block_count; i < n; ++i) x_zero_point.push_back(Tout(0, 0)); + for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tin(2.0f)); + for (int i = 0; i < 8; ++i) { + if (i & 1) y.push_back(Tout(i - 1, i)); + x.push_back(Tin(static_cast(i) * 2.0f)); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", 1); + test.AddAttribute("block_size", block_size); + test.AddInput("scale", {2, scale_block_count}, x_scale); + test.AddInput("zero_point", {2, zero_point_block_count}, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(so, OpTester::ExpectResult::kExpectFailure, "", {}, nullptr, &eps); +} + +template +void QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(int64_t block_size, + int64_t scale_block_count, + int64_t zero_point_block_count) { + OpTester test("QuantizeLinear", 21); + std::vector dims{2, 4}; + std::vector x_zero_point, y; + std::vector x, x_scale; + SessionOptions so; + std::vector log_msgs; // redirect error messages + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, + const char* logid, const char* code_location, const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_logid = "QuantizeLinear"; + so.use_per_session_threads = false; + so.session_log_verbosity_level = 1; + so.graph_optimization_level = TransformerLevel::Default; + + for (int64_t i = 0, n = 2 * zero_point_block_count; i < n; i++) 
x_zero_point.push_back(Tout(0.0f)); + for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tin(2.0f)); + for (int i = 0; i < 8; ++i) y.push_back(Tout(static_cast(i))); + for (int i = 0; i < 8; ++i) x.push_back(Tin(static_cast(i) * 2.0f)); + + test.AddInput("x", dims, x); + test.AddAttribute("axis", 1); + test.AddAttribute("block_size", block_size); + test.AddInput("scale", {2, scale_block_count}, x_scale); + test.AddInput("zero_point", {2, zero_point_block_count}, x_zero_point); + test.AddOutput("y", dims, y); + test.Run(so, OpTester::ExpectResult::kExpectFailure, "", {}, nullptr, &eps); +} + +// test negative block size fail +TEST(QuantizeLinearOp21BlockedTest, NagativeBlockSize_Int) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-1, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-1, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-2, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(-2, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-3, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-3, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-4, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-4, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-5, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-5, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-6, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-1, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-1, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(-1, 2, 2); +} + +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(QuantizeLinearOp21BlockedTest, NagativeBlockSize_Float8) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-1, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-2, 2, 2); + } + if (enable_cpu) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-3, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-4, 2, 2); + } + if (enable_cpu || enable_cuda) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-5, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-6, 2, 2); + } + if (enable_cpu) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-1, 2, 2); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(-1, 2, 2); + } +} +#endif + +// test block size incompatible with x_scale shape fail +TEST(QuantizeLinearOp21BlockedTest, IncompatibleBlockSizeWithX_Int) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); + 
QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 1, 1); +} + +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(QuantizeLinearOp21BlockedTest, IncompatibleBlockSizeWithX_Float8) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + } + if (enable_cpu) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 1, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 3, 3); + } +} +#endif + +// test x_scale vs. x_zero_point shape incompatible fail +TEST(QuantizeLinearOp21BlockedTest, ScaleShapeUnmatchZeroPoint_Int) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(3, 2, 1); +} + +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(QuantizeLinearOp21BlockedTest, ScaleShapeUnmatchZeroPoint_Float8) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + } + if (enable_cpu) { + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 1); + QuantizeLinearOp21BlockedTest_InvalidBlockSize_Float8(3, 2, 3); + } +} +#endif + +// test DQ with blocked quantization succeed +template +void QuantizeLinearOp21BlockedTest_Int4_Succeed(std::vector&& dims, + int64_t axis, + int64_t block_size, + std::vector& x_, + std::vector& scale_, + std::vector& zero_point_, + std::vector& y_) { + OpTester test("QuantizeLinear", 21); + std::vector scale_shape; + std::vector zero_point, y; + std::vector x, scale; + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + + int64_t non_neg_axis = axis < 0 ? 
axis + dims.size() : axis; + bool use_zero_point = !zero_point_.empty(); + + for (auto v : x_) x.push_back(Tin(v)); + for (auto v : scale_) scale.push_back(Tin(v)); + for (size_t i = 0, n = dims.size(); i < n; ++i) { + scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); + } + + size_t i = 0, n = y_.size(); + for (; i < n - 1; i += 2) y.push_back(Tout(y_[i], y_[i + 1])); + if (i < n) y.push_back(Tout(y_[i], 0xF)); + + if (use_zero_point) { + i = 0, n = zero_point_.size(); + for (; i < n - 1; i += 2) zero_point.push_back(Tout(zero_point_[i], zero_point_[i + 1])); + if (i < n) zero_point.push_back(Tout(zero_point_[i], 0xF)); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", axis); + test.AddAttribute("block_size", block_size); + test.AddInput("y_scale", scale_shape, scale); + if (use_zero_point) test.AddInput("y_zero_point", scale_shape, zero_point); + test.AddOutput("y", dims, y); + test.Run(BaseTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &eps); +} + +template +void QuantizeLinearOp21BlockedTest_Int_Succeed(std::vector&& dims, + int64_t axis, + int64_t block_size, + std::vector& x_, + std::vector& scale_, + std::vector& zero_point_, + std::vector& y_) { + OpTester test("QuantizeLinear", 21); + std::vector scale_shape; + std::vector zero_point, y; + std::vector x, scale; + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + + int64_t non_neg_axis = axis < 0 ? axis + dims.size() : axis; + bool use_zero_point = !zero_point_.empty(); + + for (auto v : x_) x.push_back(Tin(v)); + for (auto v : scale_) scale.push_back(Tin(v)); + for (size_t i = 0, n = dims.size(); i < n; ++i) { + scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); + } + for (auto v : y_) y.push_back(v); + if (use_zero_point) + for (auto v : zero_point_) zero_point.push_back(v); + + test.AddInput("x", dims, x); + test.AddAttribute("axis", axis); + test.AddAttribute("block_size", block_size); + test.AddInput("scale", scale_shape, scale); + if (use_zero_point) test.AddInput("zero_point", scale_shape, zero_point); + test.AddOutput("y", dims, y); + test.Run(BaseTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &eps); +} + +template +void QuantizeLinearOp21BlockedTest_Float8_Succeed(std::vector&& dims, + int64_t axis, + int64_t block_size, + std::vector& x_, + std::vector& scale_, + std::vector& zero_point_, + std::vector& y_) { + OpTester test("QuantizeLinear", 21); + std::vector scale_shape; + std::vector zero_point, y; + std::vector x, scale; + std::vector> eps; + eps.push_back(DefaultCpuExecutionProvider()); + + int64_t non_neg_axis = axis < 0 ? axis + dims.size() : axis; + bool use_zero_point = !zero_point_.empty(); + + for (auto v : x_) x.push_back(Tin(v)); + for (auto v : scale_) scale.push_back(Tin(v)); + for (size_t i = 0, n = dims.size(); i < n; ++i) { + scale_shape.push_back((int64_t)i == non_neg_axis ? 
(dims[i] + block_size - 1) / block_size : dims[i]); + } + + for (auto v : y_) y.push_back(Tout(static_cast(v))); + if (use_zero_point) { + for (auto v : zero_point_) zero_point.push_back(Tout(static_cast(v))); + } + + test.AddInput("x", dims, x); + test.AddAttribute("axis", axis); + test.AddAttribute("block_size", block_size); + test.AddInput("scale", scale_shape, scale); + if (use_zero_point) test.AddInput("zero_point", scale_shape, zero_point); + test.AddOutput("y", dims, y); + test.Run(BaseTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &eps); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt4_NoZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, + 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, + 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, + 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0}; + std::vector y_2{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + std::vector y_3{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + -1, -2, -3, -4, -1, -2, -3, -4, -1, -2, -3, -4, -1, -2, -3, -4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt_NoZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, + 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, + 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, + 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0}; + std::vector y_2{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + std::vector y_3{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + -1, -2, -3, -4, -1, -2, -3, -4, -1, -2, -3, -4, -1, -2, -3, -4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + + 
QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt4_UseZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{-6, -4, -3, -1, -6, -4, -3, -1, -6, -4, -3, -1, -6, -4, -3, -1, + 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, 7}; + std::vector x{2.0, 8.0, -7.0, -3, 2.0, 8.0, -7.0, -3, 2.0, 8.0, -7.0, -3, 2.0, 8.0, -7.0, -3, + -6.0, -8.0, 7.0, 1, -6.0, -8.0, 7.0, 1, -6.0, -8.0, 7.0, 1, -6.0, -8.0, 7.0, 1, + 2.0, 0, 3.5, 3.0, 2.0, 0, 3.5, 3.0, 2.0, 0, 3.5, 3.0, 2.0, 0, 3.5, 3.0, + 10.0, 16.0, -10.5, 15, 10.0, 16.0, -10.5, 15, 10.0, 16.0, -10.5, 15, 10.0, 16.0, -10.5, 15}; + std::vector y_2{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + std::vector y_3{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + -7, -4, -2, 2, -7, -4, -2, 2, -7, -4, -2, 2, -7, -4, -2, 2, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt_UseZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{-6, -4, -3, -1, -6, -4, -3, -1, -6, -4, -3, -1, -6, -4, -3, -1, + 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, 7}; + std::vector x{2.0, 8.0, -7.0, -3, 2.0, 8.0, -7.0, -3, 2.0, 8.0, -7.0, -3, 2.0, 8.0, -7.0, -3, + -6.0, -8.0, 7.0, 1, -6.0, -8.0, 7.0, 1, -6.0, -8.0, 7.0, 1, -6.0, -8.0, 7.0, 1, + 2.0, 0, 3.5, 3.0, 2.0, 0, 3.5, 3.0, 2.0, 0, 3.5, 3.0, 2.0, 0, 3.5, 3.0, + 10.0, 16.0, -10.5, 15, 10.0, 16.0, -10.5, 15, 10.0, 16.0, -10.5, 15, 10.0, 16.0, -10.5, 15}; + std::vector y_2{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + std::vector y_3{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + -7, -4, -2, 2, -7, 
-4, -2, 2, -7, -4, -2, 2, -7, -4, -2, 2, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt4_NoZeroPoint_MiddleAxis) { + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, -2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0}; + std::vector x{14, 24, 14, 24, 14, 24, 14, 24, 10, 16, 10, 16, 10, 16, 10, 16, + -10.5, -2, -10.5, -2, -10.5, -2, -10.5, -2, -3.5, 0, -3.5, 0, -3.5, 0, -3.5, 0, + 2, 8, 2, 8, 2, 8, 2, 8, 6, 16, 6, 16, 6, 16, 6, 16, + -17.5, -6, -17.5, -6, -17.5, -6, -17.5, -6, -24.5, 8, -24.5, 8, -24.5, 8, -24.5, 8}; + std::vector y_2{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + -3, -2, -3, -2, -3, -2, -3, -2, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + 5, 6, 5, 6, 5, 6, 5, 6, 7, -8, 7, -8, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + 5, 0, 5, 0, 5, 0, 5, 0, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + -8, -2, -8, -2, -8, -2, -8, -2, 7, -8, 7, -8, 7, -8, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt_NoZeroPoint_MiddleAxis) { + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, -2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0}; + std::vector x{14, 24, 14, 24, 14, 24, 14, 24, 10, 16, 10, 16, 10, 16, 10, 16, + -10.5, -2, -10.5, -2, -10.5, -2, -10.5, -2, -3.5, 0, -3.5, 0, -3.5, 0, -3.5, 0, + 2, 8, 2, 8, 2, 8, 2, 8, 6, 16, 6, 16, 6, 16, 6, 16, + -17.5, -6, -17.5, -6, -17.5, -6, -17.5, -6, -24.5, 8, -24.5, 8, -24.5, 8, -24.5, 8}; + std::vector y_2{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + -3, -2, -3, -2, -3, -2, -3, -2, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + 5, 6, 5, 6, 5, 6, 5, 6, 7, -8, 7, -8, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + 5, 0, 5, 0, 5, 0, 5, 0, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + -9, -2, 
-9, -2, -9, -2, -9, -2, 7, -8, 7, -8, 7, -8, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt4_UseZeroPoint_MiddleAxis) { + std::vector zero_point{-6, -4, -6, -4, -6, -4, -6, -4, -3, -1, -3, -1, -3, -1, -3, -1, + 0, 2, 0, 2, 0, 2, 0, 2, 4, 7, 4, 7, 4, 7, 4, 7}; + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, -2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0}; + std::vector x{2, 8, 2, 8, 2, 8, 2, 8, -2, 0, -2, 0, -2, 0, -2, 0, + 0, -1, 0, -1, 0, -1, 0, -1, 7, 1, 7, 1, 7, 1, 7, 1, + 2, 0, 2, 0, 2, 0, 2, 0, 6, 8, 6, 8, 6, 8, 6, 8, + -3.5, 1, -3.5, 1, -3.5, 1, -3.5, 1, -10.5, 15, -10.5, 15, -10.5, 15, -10.5, 15}; + std::vector y_2{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + -3, -2, -3, -2, -3, -2, -3, -2, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + 5, 6, 5, 6, 5, 6, 5, 6, 7, -8, 7, -8, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + -6, -4, -6, -4, -6, -4, -6, -4, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + -2, 2, -2, 2, -2, 2, -2, 2, 7, -8, 7, -8, 7, -8, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt_UseZeroPoint_MiddleAxis) { + std::vector zero_point{-6, -4, -6, -4, -6, -4, -6, -4, -3, -1, -3, -1, -3, -1, -3, -1, + 0, 2, 0, 2, 0, 2, 0, 2, 4, 7, 4, 7, 4, 7, 4, 7}; + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, -2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0, -3.5, -1.0}; + std::vector x{2, 8, 2, 8, 2, 8, 2, 8, -2, 0, -2, 0, -2, 0, -2, 0, + 0, -1, 0, -1, 0, -1, 0, -1, 7, 1, 7, 1, 7, 1, 7, 1, + 2, 0, 2, 0, 2, 0, 2, 0, 6, 8, 6, 8, 6, 8, 6, 8, + -3.5, 1, -3.5, 1, -3.5, 1, -3.5, 1, -10.5, 15, -10.5, 15, -10.5, 15, -10.5, 15}; + std::vector y_2{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + -3, -2, -3, -2, -3, -2, -3, -2, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + 5, 6, 5, 6, 5, 6, 5, 6, 7, -8, 7, -8, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -7, -6, -7, -6, -5, -4, -5, -4, -5, -4, -5, -4, + -6, -4, -6, -4, -6, -4, -6, -4, -1, 0, -1, 0, -1, 0, -1, 0, + 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4, + -2, 2, -2, 2, -2, 2, -2, 2, 7, -8, 7, -8, 7, -8, 7, -8}; + + 
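+  // Expected values follow y = saturate(round(x / scale) + zero_point), with scale and zero_point taken
+  // from the block (size 2 for y_2, size 3 for y_3) that each element falls into along axis 1.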
QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt4_NoZeroPoint_LastAxis) { + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector x{14, 12, 14, 12, 20, 16, 20, 16, 14, 12, 14, 12, 20, 16, 20, 16, + -10.5, -7, -10.5, -7, -1, 0, -1, 0, -10.5, -7, -10.5, -7, -1, 0, -1, 0, + 2, 4, 2, 4, 12, 16, 12, 16, 2, 4, 2, 4, 12, 16, 12, 16, + -17.5, -21, -17.5, -21, -7, 8, -7, 8, -17.5, -21, -17.5, -21, -7, 8, -7, 8}; + std::vector y_2{-7, -6, -7, -6, -5, -4, -5, -4, -7, -6, -7, -6, -5, -4, -5, -4, + -3, -2, -3, -2, -1, 0, -1, 0, -3, -2, -3, -2, -1, 0, -1, 0, + 1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4, + 5, 6, 5, 6, 7, -8, 7, -8, 5, 6, 5, 6, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -8, -4, -5, -4, -7, -6, -7, -6, -8, -4, -5, -4, + -3, -2, -3, -2, 0, 0, -1, 0, -3, -2, -3, -2, 0, 0, -1, 0, + 1, 2, 1, 2, 6, 4, 3, 4, 1, 2, 1, 2, 6, 4, 3, 4, + 5, 6, 5, 6, 2, -8, 7, -8, 5, 6, 5, 6, 2, -8, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt_NoZeroPoint_LastAxis) { + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector x{14, 12, 14, 12, 20, 16, 20, 16, 14, 12, 14, 12, 20, 16, 20, 16, + -10.5, -7, -10.5, -7, -1, 0, -1, 0, -10.5, -7, -10.5, -7, -1, 0, -1, 0, + 2, 4, 2, 4, 12, 16, 12, 16, 2, 4, 2, 4, 12, 16, 12, 16, + -17.5, -21, -17.5, -21, -7, 8, -7, 8, -17.5, -21, -17.5, -21, -7, 8, -7, 8}; + std::vector y_2{-7, -6, -7, -6, -5, -4, -5, -4, -7, -6, -7, -6, -5, -4, -5, -4, + -3, -2, -3, -2, -1, 0, -1, 0, -3, -2, -3, -2, -1, 0, -1, 0, + 1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4, + 5, 6, 5, 6, 7, -8, 7, -8, 5, 6, 5, 6, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -10, -4, -5, -4, -7, -6, -7, -6, -10, -4, -5, -4, + -3, -2, -3, -2, 0, 0, -1, 0, -3, -2, -3, -2, 0, 0, -1, 0, + 1, 2, 1, 2, 6, 4, 3, 4, 1, 2, 1, 2, 6, 4, 3, 4, + 5, 6, 5, 6, 2, -8, 7, -8, 5, 6, 5, 6, 2, -8, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, 
zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt4_UseZeroPoint_LastAxis) { + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector zero_point{-6, -4, -6, -4, -3, -1, -3, -1, + 0, 2, 0, 2, 4, 7, 4, 7}; + std::vector x{2, 0, 2, 0, 4, 0, 4, 0, 2, 0, 2, 0, 4, 0, 4, 0, + 0, 3.5, 0, 3.5, 0, 1, 0, 1, 0, 3.5, 0, 3.5, 0, 1, 0, 1, + 2, 4, 2, 4, 4, 8, 4, 8, 2, 4, 2, 4, 4, 8, 4, 8, + -3.5, -7, -3.5, -7, 0, 15, 0, 15, -3.5, -7, -3.5, -7, 0, 15, 0, 15}; + std::vector y_2{-7, -6, -7, -6, -5, -4, -5, -4, -7, -6, -7, -6, -5, -4, -5, -4, + -3, -2, -3, -2, -1, 0, -1, 0, -3, -2, -3, -2, -1, 0, -1, 0, + 1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4, + 5, 6, 5, 6, 7, -8, 7, -8, 5, 6, 5, 6, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -8, -4, -5, -4, -7, -6, -7, -6, -8, -4, -5, -4, + -3, -2, -3, -2, -3, 0, -1, 0, -3, -2, -3, -2, -3, 0, -1, 0, + 1, 2, 1, 2, 2, 4, 3, 4, 1, 2, 1, 2, 2, 4, 3, 4, + 5, 6, 5, 6, 4, -8, 7, -8, 5, 6, 5, 6, 4, -8, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, SignedInt_UseZeroPoint_LastAxis) { + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector zero_point{-6, -4, -6, -4, -3, -1, -3, -1, + 0, 2, 0, 2, 4, 7, 4, 7}; + std::vector x{2, 0, 2, 0, 4, 0, 4, 0, 2, 0, 2, 0, 4, 0, 4, 0, + 0, 3.5, 0, 3.5, 0, 1, 0, 1, 0, 3.5, 0, 3.5, 0, 1, 0, 1, + 2, 4, 2, 4, 4, 8, 4, 8, 2, 4, 2, 4, 4, 8, 4, 8, + -3.5, -7, -3.5, -7, 0, 15, 0, 15, -3.5, -7, -3.5, -7, 0, 15, 0, 15}; + std::vector y_2{-7, -6, -7, -6, -5, -4, -5, -4, -7, -6, -7, -6, -5, -4, -5, -4, + -3, -2, -3, -2, -1, 0, -1, 0, -3, -2, -3, -2, -1, 0, -1, 0, + 1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4, + 5, 6, 5, 6, 7, -8, 7, -8, 5, 6, 5, 6, 7, -8, 7, -8}; + std::vector y_3{-7, -6, -7, -6, -8, -4, -5, -4, -7, -6, -7, -6, -8, -4, -5, -4, + -3, -2, -3, -2, -3, 0, -1, 0, -3, -2, -3, -2, -3, 0, -1, 0, + 1, 2, 1, 2, 2, 4, 3, 4, 1, 2, 1, 2, 2, 4, 3, 4, + 5, 6, 5, 6, 4, -8, 7, -8, 5, 6, 5, 6, 4, -8, 7, -8}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, 
UnsignedInt4_NoZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{0, -4, 7, 3, 0, -4, 7, 3, 0, -4, 7, 3, 0, -4, 7, 3, + -8, -20, 21, 7, -8, -20, 21, 7, -8, -20, 21, 7, -8, -20, 21, 7, + 16, 36, -35, -11, 16, 36, -35, -11, 16, 36, -35, -11, 16, 36, -35, -11, + 24, 52, -49, -15, 24, 52, -49, -15, 24, 52, -49, -15, 24, 52, -49, -15}; + std::vector y_2{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt_NoZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{0, -4, 7, 3, 0, -4, 7, 3, 0, -4, 7, 3, 0, -4, 7, 3, + -8, -20, 21, 7, -8, -20, 21, 7, -8, -20, 21, 7, -8, -20, 21, 7, + 16, 36, -35, -11, 16, 36, -35, -11, 16, 36, -35, -11, 16, 36, -35, -11, + 24, 52, -49, -15, 24, 52, -49, -15, 24, 52, -49, -15, 24, 52, -49, -15}; + std::vector y_2{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt4_UseZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, 
-2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 1, 9, 2, 0, 1, 9, 2, 0, 1, 9, 2, 0, 1, 9, + 13, 5, 11, 6, 13, 5, 11, 6, 13, 5, 11, 6, 13, 5, 11, 6}; + std::vector x{4, -4, 3.5, -6, 4, -4, 3.5, -6, 4, -4, 3.5, -6, 4, -4, 3.5, -6, + -4, -20, 17.5, -2, -4, -20, 17.5, -2, -4, -20, 17.5, -2, -4, -20, 17.5, -2, + -10, 16, 3.5, -5, -10, 16, 3.5, -5, -10, 16, 3.5, -5, -10, 16, 3.5, -5, + -2, 32, -10.5, -9, -2, 32, -10.5, -9, -2, 32, -10.5, -9, -2, 32, -10.5, -9}; + std::vector y_2{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 7, 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt_UseZeroPoint_FirstAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 1, 9, 2, 0, 1, 9, 2, 0, 1, 9, 2, 0, 1, 9, + 13, 5, 11, 6, 13, 5, 11, 6, 13, 5, 11, 6, 13, 5, 11, 6}; + std::vector x{4, -4, 3.5, -6, 4, -4, 3.5, -6, 4, -4, 3.5, -6, 4, -4, 3.5, -6, + -4, -20, 17.5, -2, -4, -20, 17.5, -2, -4, -20, 17.5, -2, -4, -20, 17.5, -2, + -10, 16, 3.5, -5, -10, 16, 3.5, -5, -10, 16, 3.5, -5, -10, 16, 3.5, -5, + -2, 32, -10.5, -9, -2, 32, -10.5, -9, -2, 32, -10.5, -9, -2, 32, -10.5, -9}; + std::vector y_2{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, + 7, 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, 7, 0, 2, 4, + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt4_NoZeroPoint_MiddleAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, + -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, 
-3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15, + 0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15, + 0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15, + 0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15}; + std::vector y_2{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15, + 0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15, + 0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15, + 0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt_NoZeroPoint_MiddleAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, + -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15, + 0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15, + 0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15, + 0, -4, -4, -12, 14, 5, 21, 7, 16, 36, 20, 44, -42, -13, -49, -15}; + std::vector y_2{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15, + 0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15, + 0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15, + 0, 1, 2, 3, 0, 0, 6, 7, 8, 9, 10, 11, 0, 0, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt4_UseZeroPoint_MiddleAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, + -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 1, 9, 13, 5, 11, 6, 2, 0, 1, 9, 
13, 5, 11, 6, + 2, 0, 1, 9, 13, 5, 11, 6, 2, 0, 1, 9, 13, 5, 11, 6}; + std::vector x{4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9, + 4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9, + 4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9, + 4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9}; + std::vector y_2{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15, + 0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15, + 0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15, + 0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt_UseZeroPoint_MiddleAxis) { + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, + -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 1, 9, 13, 5, 11, 6, 2, 0, 1, 9, 13, 5, 11, 6, + 2, 0, 1, 9, 13, 5, 11, 6, 2, 0, 1, 9, 13, 5, 11, 6}; + std::vector x{4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9, + 4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9, + 4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9, + 4, -4, 0, -12, 10.5, -4, 17.5, -2, -10, 16, -6, 24, -3.5, -7, -10.5, -9}; + std::vector y_2{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + std::vector y_3{0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15, + 0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15, + 0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15, + 0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 11, 3, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt4_NoZeroPoint_LastAxis) { + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{0, -2, 0, -2, -8, -12, -8, -12, 0, -2, 0, -2, -8, -12, -8, -12, + 14, 17.5, 14, 17.5, 6, 7, 6, 7, 14, 
17.5, 14, 17.5, 6, 7, 6, 7, + 16, 18, 16, 18, 40, 44, 40, 44, 16, 18, 16, 18, 40, 44, 40, 44, + -42, -45.5, -42, -45.5, -14, -15, -14, -15, -42, -45.5, -42, -45.5, -14, -15, -14, -15}; + std::vector y_2{0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3, + 4, 5, 4, 5, 6, 7, 6, 7, 4, 5, 4, 5, 6, 7, 6, 7, + 8, 9, 8, 9, 10, 11, 10, 11, 8, 9, 8, 9, 10, 11, 10, 11, + 12, 13, 12, 13, 14, 15, 14, 15, 12, 13, 12, 13, 14, 15, 14, 15}; + std::vector y_3{0, 1, 0, 1, 4, 3, 2, 3, 0, 1, 0, 1, 4, 3, 2, 3, + 4, 5, 4, 5, 2, 7, 6, 7, 4, 5, 4, 5, 2, 7, 6, 7, + 8, 9, 8, 9, 15, 11, 10, 11, 8, 9, 8, 9, 15, 11, 10, 11, + 12, 13, 12, 13, 4, 15, 14, 15, 12, 13, 12, 13, 4, 15, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt_NoZeroPoint_LastAxis) { + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{0, -2, 0, -2, -8, -12, -8, -12, 0, -2, 0, -2, -8, -12, -8, -12, + 14, 17.5, 14, 17.5, 6, 7, 6, 7, 14, 17.5, 14, 17.5, 6, 7, 6, 7, + 16, 18, 16, 18, 40, 44, 40, 44, 16, 18, 16, 18, 40, 44, 40, 44, + -42, -45.5, -42, -45.5, -14, -15, -14, -15, -42, -45.5, -42, -45.5, -14, -15, -14, -15}; + std::vector y_2{0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3, + 4, 5, 4, 5, 6, 7, 6, 7, 4, 5, 4, 5, 6, 7, 6, 7, + 8, 9, 8, 9, 10, 11, 10, 11, 8, 9, 8, 9, 10, 11, 10, 11, + 12, 13, 12, 13, 14, 15, 14, 15, 12, 13, 12, 13, 14, 15, 14, 15}; + std::vector y_3{0, 1, 0, 1, 4, 3, 2, 3, 0, 1, 0, 1, 4, 3, 2, 3, + 4, 5, 4, 5, 2, 7, 6, 7, 4, 5, 4, 5, 2, 7, 6, 7, + 8, 9, 8, 9, 20, 11, 10, 11, 8, 9, 8, 9, 20, 11, 10, 11, + 12, 13, 12, 13, 4, 15, 14, 15, 12, 13, 12, 13, 4, 15, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt4_UseZeroPoint_LastAxis) { + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 2, 0, 1, 9, 1, 9, + 13, 5, 13, 5, 11, 6, 11, 6}; + std::vector x{4, 2, 4, 2, -8, -12, -8, -12, 4, 2, 4, 2, -8, -12, -8, -12, + 10.5, 14, 10.5, 14, -3, -2, -3, -2, 10.5, 14, 10.5, 14, -3, -2, -3, -2, + -10, -8, -10, -8, 20, 24, 20, 24, -10, -8, -10, -8, 20, 24, 20, 24, + -3.5, -7, -3.5, -7, -8, -9, -8, -9, -3.5, -7, -3.5, -7, -8, -9, -8, -9}; + std::vector y_2{0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3, + 4, 5, 4, 5, 6, 7, 6, 7, 4, 5, 4, 5, 6, 7, 6, 7, + 8, 9, 8, 9, 10, 11, 10, 11, 8, 9, 8, 9, 10, 11, 10, 
11, + 12, 13, 12, 13, 14, 15, 14, 15, 12, 13, 12, 13, 14, 15, 14, 15}; + std::vector y_3{0, 1, 0, 1, 6, 3, 2, 3, 0, 1, 0, 1, 6, 3, 2, 3, + 4, 5, 4, 5, 0, 7, 6, 7, 4, 5, 4, 5, 0, 7, 6, 7, + 8, 9, 8, 9, 15, 11, 10, 11, 8, 9, 8, 9, 15, 11, 10, 11, + 12, 13, 12, 13, 13, 15, 14, 15, 12, 13, 12, 13, 13, 15, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int4_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +TEST(QuantizeLinearOp21BlockedTest, UnsignedInt_UseZeroPoint_LastAxis) { + std::vector y_scale{-2.0, -4.0, -2.0, -4.0, 3.5, 1.0, 3.5, 1.0, + 2.0, 4.0, 2.0, 4.0, -3.5, -1.0, -3.5, -1.0}; + std::vector zero_point{2, 0, 2, 0, 1, 9, 1, 9, + 13, 5, 13, 5, 11, 6, 11, 6}; + std::vector x{4, 2, 4, 2, -8, -12, -8, -12, 4, 2, 4, 2, -8, -12, -8, -12, + 10.5, 14, 10.5, 14, -3, -2, -3, -2, 10.5, 14, 10.5, 14, -3, -2, -3, -2, + -10, -8, -10, -8, 20, 24, 20, 24, -10, -8, -10, -8, 20, 24, 20, 24, + -3.5, -7, -3.5, -7, -8, -9, -8, -9, -3.5, -7, -3.5, -7, -8, -9, -8, -9}; + std::vector y_2{0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3, + 4, 5, 4, 5, 6, 7, 6, 7, 4, 5, 4, 5, 6, 7, 6, 7, + 8, 9, 8, 9, 10, 11, 10, 11, 8, 9, 8, 9, 10, 11, 10, 11, + 12, 13, 12, 13, 14, 15, 14, 15, 12, 13, 12, 13, 14, 15, 14, 15}; + std::vector y_3{0, 1, 0, 1, 6, 3, 2, 3, 0, 1, 0, 1, 6, 3, 2, 3, + 4, 5, 4, 5, 0, 7, 6, 7, 4, 5, 4, 5, 0, 7, 6, 7, + 8, 9, 8, 9, 23, 11, 10, 11, 8, 9, 8, 9, 23, 11, 10, 11, + 12, 13, 12, 13, 13, 15, 14, 15, 12, 13, 12, 13, 13, 15, 14, 15}; + + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 4, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Int_Succeed({2, 4, 8}, 2, 5, x, y_scale, zero_point, y_3); +} + +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(QuantizeLinearOp21BlockedTest, Float8_NoZeroPoint_FirstAxis) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, -2.0, -4.0, 3.5, 1.0, + 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector x{14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, 14.0, 24.0, -17.5, -4, + 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, 6.0, 8.0, -3.5, 0.0, + 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, 2.0, 8.0, -10.5, -4.0, + 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0, 10.0, 24.0, -24.5, 8.0}; + std::vector y_2{-7, -6, -5, -4, 
-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + std::vector y_3{-7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, -7, -6, -5, -4, + -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, -3, -2, -1, 0, + -1, -2, -3, -4, -1, -2, -3, -4, -1, -2, -3, -4, -1, -2, -3, -4, + 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8, 5, 6, 7, -8}; + + if (enable_cpu || enable_cuda) { + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + } + if (enable_cpu) { + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 8, 2}, 0, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {4, 8, 2}, 0, 3, x, y_scale, zero_point, y_3); + } +} + +TEST(QuantizeLinearOp21BlockedTest, Float8_NoZeroPoint_MiddleAxis) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, + -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector x{14, 24, 10, 16, -10.5, -2, -3.5, 0, 2, 8, 6, 16, -17.5, -6, -24.5, 8, + 14, 24, 10, 16, -10.5, -2, -3.5, 0, 2, 8, 6, 16, -17.5, -6, -24.5, 8, + 14, 24, 10, 16, -10.5, -2, -3.5, 0, 2, 8, 6, 16, -17.5, -6, -24.5, 8, + 14, 24, 10, 16, -10.5, -2, -3.5, 0, 2, 8, 6, 16, -17.5, -6, -24.5, 8}; + std::vector y_2{-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, + -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, + -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, + -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_3{-7, -6, -5, -4, 5, 0.5, -1, 0, 1, 2, 3, 4, -9, -1.5, 7, -8, + -7, -6, -5, -4, 5, 0.5, -1, 0, 1, 2, 3, 4, -9, -1.5, 7, -8, + -7, -6, -5, -4, 5, 0.5, -1, 0, 1, 2, 3, 4, -9, -1.5, 7, -8, + -7, -6, -5, -4, 5, 0.5, -1, 0, 1, 2, 3, 4, -9, -1.5, 7, -8}; + + if (enable_cpu || enable_cuda) 
{ + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + } + if (enable_cpu) { + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 4, 2}, 1, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 4, 2}, 1, 3, x, y_scale, zero_point, y_3); + } +} + +TEST(QuantizeLinearOp21BlockedTest, Float8_NoZeroPoint_LastAxis) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCpuExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + std::vector zero_point{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector y_scale{-2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, + -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0, -2.0, -4.0, 3.5, 1.0, 2.0, 4.0, -3.5, -1.0}; + std::vector x{14, 12, 20, 16, -10.5, -7, -3.5, 0, 2, 4, 12, 16, -17.5, -21, -7, 8, + 14, 12, 20, 16, -10.5, -7, -3.5, 0, 2, 4, 12, 16, -17.5, -21, -7, 8, + 14, 12, 20, 16, -10.5, -7, -3.5, 0, 2, 4, 12, 16, -17.5, -21, -7, 8, + 14, 12, 20, 16, -10.5, -7, -3.5, 0, 2, 4, 12, 16, -17.5, -21, -7, 8}; + std::vector y_2{-7, -6, -5, -4, -3, -2, -3.5, 0, 1, 2, 3, 4, 5, 6, 7, -8, + -7, -6, -5, -4, -3, -2, -3.5, 0, 1, 2, 3, 4, 5, 6, 7, -8, + -7, -6, -5, -4, -3, -2, -3.5, 0, 1, 2, 3, 4, 5, 6, 7, -8, + -7, -6, -5, -4, -3, -2, -3.5, 0, 1, 2, 3, 4, 5, 6, 7, -8}; + std::vector y_3{-7, -6, -10, -4, -3, -2, -1, 0, 1, 2, 6, 4, 5, 6, 2, -8, + -7, -6, -10, -4, -3, -2, -1, 0, 1, 2, 6, 4, 5, 6, 2, -8, + -7, -6, -10, -4, -3, -2, -1, 0, 1, 2, 6, 4, 5, 6, 2, -8, + -7, -6, -10, -4, -3, -2, -1, 0, 1, 2, 6, 4, 5, 6, 2, -8}; + + if (enable_cpu || enable_cuda) { + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + 
{8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + } + if (enable_cpu) { + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 2, 4}, 2, 2, x, y_scale, zero_point, y_2); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed({8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + QuantizeLinearOp21BlockedTest_Float8_Succeed( + {8, 2, 4}, 2, 3, x, y_scale, zero_point, y_3); + } +} +#endif +} // namespace blocked_quantization } // namespace test } // namespace onnxruntime
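For reference, the expected outputs above (y_2 for block_size 2, y_3 for block_size 3) follow the opset-21 blocked quantization rule: each input element is divided by the scale of the block it falls into along the quantization axis, rounded half to even, offset by the matching zero point, and saturated to the target range (for example [-8, 7] for int4 and [0, 15] for uint4). The standalone sketch below illustrates that rule under those assumptions; QuantizeOne and BlockedQuantizeReference are hypothetical helper names used only here, not onnxruntime APIs.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Quantize one element: round half to even (default FP rounding mode),
// add the zero point, then saturate to [qmin, qmax].
static int64_t QuantizeOne(float x, float scale, int64_t zero_point,
                           int64_t qmin, int64_t qmax) {
  int64_t q = static_cast<int64_t>(std::llrint(x / scale)) + zero_point;
  return std::min(std::max(q, qmin), qmax);
}

// Reference blocked quantization for a 3-D tensor: scale/zero_point have the
// same shape as x except that dims[axis] is replaced by
// ceil(dims[axis] / block_size); every element uses the parameters of the
// block it falls into along `axis`.
static std::vector<int64_t> BlockedQuantizeReference(
    const std::vector<float>& x, const std::vector<float>& scale,
    const std::vector<int64_t>& zero_point,  // empty means all zeros
    const std::vector<int64_t>& dims, int axis, int64_t block_size,
    int64_t qmin, int64_t qmax) {
  std::vector<int64_t> sdims = dims;
  sdims[axis] = (dims[axis] + block_size - 1) / block_size;
  std::vector<int64_t> y(x.size());
  for (int64_t i = 0; i < dims[0]; ++i) {
    for (int64_t j = 0; j < dims[1]; ++j) {
      for (int64_t k = 0; k < dims[2]; ++k) {
        int64_t c[3] = {i, j, k};
        c[axis] /= block_size;  // block index along the blocked axis
        int64_t xi = (i * dims[1] + j) * dims[2] + k;
        int64_t si = (c[0] * sdims[1] + c[1]) * sdims[2] + c[2];
        int64_t zp = zero_point.empty() ? 0 : zero_point[si];
        y[xi] = QuantizeOne(x[xi], scale[si], zp, qmin, qmax);
      }
    }
  }
  return y;
}

int main() {
  // Tiny example in the spirit of the tests above: shape {2, 2, 2}, axis 0,
  // block_size 2, so both rows share one scale block of shape {1, 2, 2}.
  std::vector<float> x{14.0f, 24.0f, -17.5f, -4.0f, 6.0f, 8.0f, -3.5f, 0.0f};
  std::vector<float> scale{-2.0f, -4.0f, 3.5f, 1.0f};
  std::vector<int64_t> y = BlockedQuantizeReference(
      x, scale, {}, {2, 2, 2}, /*axis=*/0, /*block_size=*/2,
      /*qmin=*/-8, /*qmax=*/7);
  // Prints: -7 -6 -5 -4 -3 -2 -1 0
  for (int64_t v : y) std::printf("%lld ", static_cast<long long>(v));
  std::printf("\n");
  return 0;
}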