Skip to content

Commit

Permalink
Add int4b_t/uint4b_t support for mixed dtypes GEMM
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsamardzic committed Oct 27, 2023
1 parent 5f13dca commit f88a889
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 20 deletions.
82 changes: 81 additions & 1 deletion include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,86 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
(1 - odd_even_lane_id_) * kSelectBytesEvenThread;
}

/// Shuffles a warp fragment (loaded via ldmatrix) into the element order
/// expected by mma.sync, handling one 32-bit word per mma instruction.
CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {

WarpFragment result;

// View the source/destination warp fragments as arrays of per-instruction
// mma fragments so each instruction's data can be shuffled independently.
MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kNumMmaInstructions; n++) {

// Access each mma fragment as a single 32-bit word.
uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&mma_frag_src_ptr[n]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&mma_frag_dst_ptr[n]);

// Shuffle data within the warp, pull from other threads within the warp.
// delta_up_/delta_down_ select the neighboring source lanes; they are
// computed per lane in the constructor (not fully visible here).
uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_);
uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_);

// Reorder the data within the 32-bit word (4x8b) required for mma.sync,
// picking bytes from the two shuffled words per this lane's byte_selector_.
dst_ptr[0] = __byte_perm(tmp0, tmp1, byte_selector_);
}

return result;
}

};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 4b (S4/U4)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 4)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;

static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kB;

using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

static uint32_t const kSelectBytesEvenThread = 0x5140;
static uint32_t const kSelectBytesOddThread = 0x7362;

private:
int odd_even_lane_id_;
int delta_up_;
int delta_down_;
uint32_t byte_selector_;

public:
/// Precomputes per-lane shuffle deltas and the byte-permute selector from
/// the lane's parity within the warp.
CUTLASS_DEVICE
FragmentShuffler() {
  int const lane = cutlass::arch::LaneId();
  odd_even_lane_id_ = static_cast<int>(lane & 1);

  // Odd lanes pull from the lane below via shfl_up; even lanes pull from
  // the lane above via shfl_down.
  delta_up_ = odd_even_lane_id_;
  delta_down_ = 1 - odd_even_lane_id_;

  // Branch-free select between the two byte-permute patterns: exactly one
  // of the two products is non-zero since odd_even_lane_id_ is 0 or 1.
  byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread +
                   (1 - odd_even_lane_id_) * kSelectBytesEvenThread;
}

CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {

Expand Down Expand Up @@ -551,4 +631,4 @@ class MmaMixedInputTensorOp {
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
8 changes: 7 additions & 1 deletion include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,12 @@ class MmaTensorOpMultiplicandTileIterator<
"Shape of warp-level Mma must be divisible by operator shape.");

// Determine number of elements along outer dimension per individual LDSM op
static int const kLdsmOpOuter = Layout::kElementsPerAccess;
// FIXME: uncomment next line, remove what follows!
// static int const kLdsmOpOuter = Layout::kElementsPerAccess;
static int const kLdsmOpOuter = Layout::kElementsPerAccess >= InstructionShape::kContiguous ? InstructionShape::kContiguous : Layout::kElementsPerAccess;
// static_assert(!platform::is_same<Element_, cutlass::int4b_t>::value || InstructionShape::kContiguous == 16);
// static_assert(!platform::is_same<Element_, cutlass::int4b_t>::value || kLdsmOpOuter == 32);

static int const kLdsmOpInner = 8;

static_assert(!(Shape::kContiguous % kLdsmOpOuter),
Expand All @@ -1401,6 +1406,7 @@ class MmaTensorOpMultiplicandTileIterator<
/// Shape of one individual LDSM instruction
static int const LdsmShapeContiguous =
InstructionShape::kContiguous / kLdsmOpOuter;

static int const LdsmShapeStrided =
((4 / LdsmShapeContiguous * kLdsmOpInner) > Shape::kStrided)
? (Shape::kStrided / kLdsmOpInner)
Expand Down
65 changes: 64 additions & 1 deletion include/cutlass/numeric_conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -2454,13 +2454,14 @@ struct FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> {
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;

#if 0 // Scalar conversion (Please keep this code for reference for vectorized version below)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
int16_t tmp = source[i] + 26112 /* 0x6600 */;
result[i] = reinterpret_cast<cutlass::half_t const &>(tmp) - 1536.0_hf;
}
return result;
#endif

// Vectorized s8->f16 conversion using packed instructions
Expand Down Expand Up @@ -2541,6 +2542,33 @@ struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> {
}
};

/// Partial specialization for Array<cutlass::half_t, 8> <= Array<int4b_t, 8>
///
/// Converts eight signed 4-bit integers to eight F16 values exactly
/// (every S4 value is representable in half precision).
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, cutlass::int4b_t, 8, Round> {
  using result_type = Array<cutlass::half_t, 8>;
  using source_type = Array<cutlass::int4b_t, 8>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;

#if 1 // Scalar conversion (keep as reference if a vectorized version is added below)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      // Fixed-point trick: 0x5600 is the bit pattern of half(96.0), whose
      // mantissa ULP has weight 2^-4.  Adding the 4-bit integer scaled by
      // 16 (the left shift) into the low mantissa bits therefore produces
      // the half value (96 + source[i]) exactly: 512 + 16*x stays within
      // the 10-bit mantissa for x in [-8, 7], so no carry/rounding occurs.
      // NOTE(review): left-shifting a negative int is formally UB before
      // C++20 — works as intended under NVCC here; confirm toolchain policy.
      int16_t tmp = (int(source[i]) << 4) + 22016 /* 0x5600 */;
      // Reinterpret the bits as half and remove the 96.0 bias.
      result[i] = reinterpret_cast<cutlass::half_t const &>(tmp) - 96.0_hf;
    }
    return result;
#endif
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};

/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, uint8_t, 4, Round> {
Expand Down Expand Up @@ -2661,6 +2689,41 @@ struct FastNumericArrayConverter<T, S, N, Round,

};

/// Partial specialization for FastNumericArrayConverter to vectorize over 8 elements.
/// source `S` as 4b integers (S4 or U4) -> destination `T` as 16b floating-point (F16 or BF16)
template <typename T, typename S, int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<T, S, N, Round,
    typename platform::enable_if<(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value) &&
                                 (platform::is_same<S, cutlass::int4b_t>::value || platform::is_same<S, uint4b_t>::value)>::type> {
  static_assert(!(N % 8), "N must be multiple of 8.");

  using result_type = Array<T, N>;
  using source_type = Array<S, N>;
  static FloatRoundStyle const round_style = Round;

  /// Converts N elements by running the width-8 converter over each
  /// consecutive chunk of 8 elements.
  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type destination;

    // Reinterpret both arrays as sequences of width-8 sub-arrays.
    Array<S, 8> const *src_chunk =
        reinterpret_cast<Array<S, 8> const *>(&source);
    Array<T, 8> *dst_chunk =
        reinterpret_cast<Array<T, 8> *>(&destination);

    FastNumericArrayConverter<T, S, 8, Round> convert_chunk;

    CUTLASS_PRAGMA_UNROLL
    for (int chunk = 0; chunk < N / 8; ++chunk) {
      dst_chunk[chunk] = convert_chunk(src_chunk[chunk]);
    }

    return destination;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines preferred rounding mode for a pair of types
Expand Down
47 changes: 32 additions & 15 deletions test/unit/core/fast_numeric_conversion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ void run_test_integer_range_limited() {
}
}


template <typename Destination, typename Source, int Count>
void run_test_integer_range_all() {
const int kN = Count;
Expand All @@ -97,13 +96,24 @@ void run_test_integer_range_all() {
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});

int const kIntSourceMin = std::numeric_limits<Source>::min();
int const kIntSourceMax = std::numeric_limits<Source>::max();
int const kIntRange = kIntSourceMax - kIntSourceMin + 1;
constexpr auto is_source_arithmetic = std::is_arithmetic<Source>::value;

for (int i = 0; i < kN; ++i) {
source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange));
int kIntSourceMin;
int kIntSourceMax;
if constexpr (is_source_arithmetic) {
kIntSourceMin = std::numeric_limits<Source>::min();
kIntSourceMax = std::numeric_limits<Source>::max();
} else {
kIntSourceMin = cutlass::platform::numeric_limits<Source>::lowest();
kIntSourceMax = cutlass::platform::numeric_limits<Source>::max();
}
int kIntRange = kIntSourceMax - kIntSourceMin + 1;

using SourceArray = cutlass::Array<Source, kN>;
SourceArray& source_array = *reinterpret_cast<SourceArray*>(source.host_data());

for (int i = 0; i < kN; ++i) {
source_array[i] = Source(kIntSourceMin + (i % kIntRange));
}

source.sync_device();
Expand All @@ -114,25 +124,25 @@ void run_test_integer_range_all() {
);

destination.sync_host();

// Verify conversion
bool passed = true;
for (int i = 0; i < kN; ++i) {
if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) {
if(!(float(destination.host_data()[i]) == float(source_array[i]))) {
passed = false;
break;
}
}
EXPECT_TRUE(passed) << " FastNumericArrayConverter failed";
// Print out results for the failed conversion.
if (!passed) {

// Print out results for the failed conversion.
if (!passed) {
for (int i = 0; i < kN; ++i) {
std::cout << "source(" << float(source.host_data()[i]) << ") -> "
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl;
std::cout << "source(" << float(source_array[i]) << ") -> "
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl;
}
}
std::flush(std::cout);
}
std::flush(std::cout);
}

} // namespace kernel
Expand Down Expand Up @@ -174,3 +184,10 @@ TEST(FastNumericConversion, s8_to_bf16_array) {
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}

TEST(FastNumericConversion, s4_to_f16_array) {
  // Sweep the full S4 value range (16 values, filling all 16 elements)
  // through the fast S4 -> F16 array converter and compare element-wise.
  constexpr int kN = 16;
  using Source = cutlass::int4b_t;
  using Destination = cutlass::half_t;
  test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
28 changes: 26 additions & 2 deletions test/unit/gemm/warp/gemm_mixed_input_sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
.run();
.run(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential);
}


Expand Down Expand Up @@ -319,4 +319,28 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1
.run();
}

#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * I4 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4, 128x128x64_64x64x64_16x8x16) {
  // Operand types: A is F16, B is S4 (mixed-input upcast), accumulator is F32.
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::int4b_t;
  using ElementC = float;

  // Warp-level GEMM tile and the mma.sync instruction shape.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Crosswise operand layouts (crosswise size 64 for both operands).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run(cutlass::Distribution::Identity, cutlass::Distribution::Sequential);
}

#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

0 comments on commit f88a889

Please sign in to comment.