diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index cb34d2b289..752239ab45 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -346,7 +346,12 @@ cutlass_test_unit_add_executable( BATCH_SOURCES ON BATCH_SIZE 4 + + # Upcast on Operand A + gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + # Upcast on Operand B gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu ) diff --git a/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu new file mode 100644 index 0000000000..6dae3d4c5d --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+    Mixed-input case: int8_t A (row-major) is the upcast operand; B, C and the accumulator are half_t.
+*/
+
+#include <iostream>
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Suite name encodes element + layout per operand (t = row-major, n = column-major): A=s8t, B=f16n, C=f16t.
+TEST(SM80_Device_GemmUniversal_s8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
+
+  using ElementA = int8_t;
+  using ElementB = cutlass::half_t;
+  using ElementOutput = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+    ElementA,
+    cutlass::layout::RowMajor,
+    ElementB,
+    cutlass::layout::ColumnMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<128, 128, 64>,  // Threadblock tile
+    cutlass::gemm::GemmShape<64, 64, 64>,    // Warp tile
+    cutlass::gemm::GemmShape<16, 8, 16>,     // Tensor-op instruction shape
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator, ElementAccumulator>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4,   // Stages
+    16,  // AlignmentA
+    8,   // AlignmentB
+    cutlass::arch::OpMultiplyAddMixedInputUpcast,
+    cutlass::ComplexTransform::kNone,
+    cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif  // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
diff --git a/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
new file mode 100644
index 0000000000..51f6f4ab72
--- /dev/null
+++ b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
@@ -0,0 +1,97 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2.
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Tests for device-wide GEMM interface
+    Mixed-input case: uint8_t A (row-major) is the upcast operand; B, C and the accumulator are half_t.
+*/
+
+#include <iostream>
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Suite name encodes element + layout per operand (t = row-major, n = column-major): A=u8t, B=f16n, C=f16t.
+TEST(SM80_Device_GemmUniversal_u8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
+
+  using ElementA = uint8_t;
+  using ElementB = cutlass::half_t;
+  using ElementOutput = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+    ElementA,
+    cutlass::layout::RowMajor,
+    ElementB,
+    cutlass::layout::ColumnMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<128, 128, 64>,  // Threadblock tile
+    cutlass::gemm::GemmShape<64, 64, 64>,    // Warp tile
+    cutlass::gemm::GemmShape<16, 8, 16>,     // Tensor-op instruction shape
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator, ElementAccumulator>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4,   // Stages
+    16,  // AlignmentA
+    8,   // AlignmentB
+    cutlass::arch::OpMultiplyAddMixedInputUpcast,
+    cutlass::ComplexTransform::kNone,
+    cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif  // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 9d33810853..651cf44acf 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -2167,7 +2167,13 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version): (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), ] + # Upcast on Operand A math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ [16, 8, 16], \ DataType.s8, DataType.f16, DataType.f32, \ @@ -2210,6 +2216,50 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version): CreateMixedInputGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) + # Upcast on Operand B + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [[8, 16, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + 
CreateMixedInputGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + # def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): diff --git a/tools/library/src/reference/gemm_fp_mixed_input.cu b/tools/library/src/reference/gemm_fp_mixed_input.cu index bf7e9cd7cf..d27e259bf2 100644 --- a/tools/library/src/reference/gemm_fp_mixed_input.cu +++ b/tools/library/src/reference/gemm_fp_mixed_input.cu @@ -79,6 +79,22 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { float >(manifest); + make_gemm_real_canonical_layouts< + half_t, + uint8_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + int8_t, + float, + float, + float + >(manifest); + // bfloat16_t mixed with 8-bit integer input make_gemm_real_canonical_layouts< uint8_t, @@ -95,6 +111,22 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { float, float >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + uint8_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + int8_t, + float, + float, + float + >(manifest); } ///////////////////////////////////////////////////////////////////////////////////////////////////