From 7b9134d722d7019a76a935ec46e70daff3a31b12 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 30 Oct 2024 15:06:13 +0000 Subject: [PATCH] Support multi-dimensional tensor prints in CPU runtime. Signed-off-by: Ilya Enkovich --- third_party/cpu/runtime/cpu_runtime.cpp | 375 +++++++++--------------- 1 file changed, 145 insertions(+), 230 deletions(-) diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 06889b732364..537441903212 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -34,70 +34,99 @@ struct FormatInfo { bool isHex; }; +template struct RawMemRefDescriptor { + const T *allocated; + const T *aligned; + intptr_t offset; + intptr_t sizesAndStrides[]; +}; + +template class MemRefDescriptor { +private: + const T *data_; + std::vector sizes_; + std::vector strides_; + + MemRefDescriptor(const T *data, std::vector sizes, + std::vector strides) + : data_(data), sizes_(std::move(sizes)), strides_(std::move(strides)) {} + +public: + MemRefDescriptor(int32_t rank, void *rawDescriptor) { + auto *rawDesc = static_cast *>(rawDescriptor); + data_ = rawDesc->aligned + rawDesc->offset; + sizes_.insert(sizes_.begin(), rawDesc->sizesAndStrides, + rawDesc->sizesAndStrides + rank); + strides_.insert(strides_.begin(), rawDesc->sizesAndStrides + rank, + rawDesc->sizesAndStrides + rank * 2); + } + + const T *data() const { return data_; } + + int64_t rank() const { return static_cast(sizes_.size()); } + + int64_t size(int64_t dim) const { return sizes_[dim]; } + + int64_t stride(int64_t dim) const { return strides_[dim]; } + + MemRefDescriptor subView(int64_t idx) const { + assert(rank() > 1); + return {data_ + idx * stride(0), + {sizes_.begin() + 1, sizes_.end()}, + {strides_.begin() + 1, strides_.end()}}; + } +}; + +struct UnrankedMemRefType { + int64_t rank; + void *descriptor; +}; + template -std::pair -computeDigitInfoHelper(const void *array, size_t index) { - T elem = static_cast(array)[index]; - if (elem == 0) +std::pair computeDigitInfo(T val) { + if (val == 0) return {1, false}; - return {static_cast(std::log10(elem >= 0 ? elem : -elem)) + 1, elem < 0}; + int digits = + std::max(static_cast(std::log10(val >= 0 ? val : -val)), 0) + 1; + return {digits, val < 0}; } -std::pair computeDigitInfo(void *vec, bool isInt, bool isSigned, - int32_t bitWidth, size_t index) { - if (isInt == 0) { - if (bitWidth == 32) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 64) - return computeDigitInfoHelper(vec, index); - else - llvm_unreachable("Unsupported bitWidth"); +template +std::tuple computeDigitStats(const MemRefDescriptor &desc) { + int maxIntDigits = 0; + int minIntDigits = std::numeric_limits::max(); + bool hasNegative = false; + + if (desc.rank() == 1) { + const T *data = desc.data(); + int64_t stride = desc.stride(0); + for (int64_t i = 0; i < desc.size(0); ++i) { + auto [digits, negative] = computeDigitInfo(data[i * stride]); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, digits); + minIntDigits = std::min(minIntDigits, digits); + } } else { - if (isSigned) { - if (bitWidth == 64) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 32) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 16) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 8) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 1) - return computeDigitInfoHelper(vec, index); - } else { - if (bitWidth == 64) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 32) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 16) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 8) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 1) - return computeDigitInfoHelper(vec, index); + for (int64_t i = 0; i < desc.size(0); ++i) { + auto [maxDigits, minDigits, negative] = + computeDigitStats(desc.subView(i)); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, maxDigits); + minIntDigits = std::min(minIntDigits, minDigits); } - printf("bitWidth: %d\n", bitWidth); - llvm_unreachable("Unsupported bitWidth"); } + + return std::make_tuple(maxIntDigits, minIntDigits, hasNegative); } -FormatInfo getFormatInfo(void *vec, bool isInt, bool isSigned, int32_t bitWidth, - int64_t numElem, bool isHex) { +template +FormatInfo getFormatInfo(const MemRefDescriptor &desc, bool isInt, + bool isSigned, int32_t bitWidth, bool isHex) { if (isHex) { assert(bitWidth >= 8 && bitWidth <= 64 && bitWidth % 8 == 0); return {isInt, isSigned, bitWidth, bitWidth / 4, false, false, true}; } - // Compute the max/min widths for pretty printing. - int maxIntDigits = 0; - int minIntDigits = std::numeric_limits::max(); - bool hasNegative = false; - for (int64_t i = 0; i < numElem; ++i) { - auto [digits, negative] = - computeDigitInfo(vec, isInt, isSigned, bitWidth, i); - hasNegative |= negative; - maxIntDigits = std::max(maxIntDigits, digits); - minIntDigits = std::min(minIntDigits, digits); - } + auto [maxIntDigits, minIntDigits, hasNegative] = computeDigitStats(desc); // Fallback to the scientific format for certain cases. bool scientific; if (isInt) { @@ -111,178 +140,98 @@ FormatInfo getFormatInfo(void *vec, bool isInt, bool isSigned, int32_t bitWidth, } template -void printElementHelper(std::stringstream &ss, const void *array, - size_t index) { - ss << static_cast(array)[index]; -} - -void printElement(std::stringstream &ss, const void *vec, size_t index, - const FormatInfo &formatInfo) { - if (!formatInfo.isInt) { - switch (formatInfo.bitWidth) { - case 32: - printElementHelper(ss, vec, index); - return; - case 64: - printElementHelper(ss, vec, index); - return; - default: - llvm_unreachable("Unsupported bitWidth"); - } - } - - if (formatInfo.isSigned) { - switch (formatInfo.bitWidth) { - case 64: - printElementHelper(ss, vec, index); - return; - case 32: - printElementHelper(ss, vec, index); - return; - case 16: - printElementHelper(ss, vec, index); - return; - case 8: - // int8_t is printed as char. - ss << static_cast(static_cast(vec)[index]); - return; - case 1: - printElementHelper(ss, vec, index); - return; - default: - llvm_unreachable("Unsupported bitWidth"); - } - } - - switch (formatInfo.bitWidth) { - case 64: - printElementHelper(ss, vec, index); - return; - case 32: - printElementHelper(ss, vec, index); - return; - case 16: - printElementHelper(ss, vec, index); - return; - case 8: - ss << static_cast(static_cast(vec)[index]); - return; - case 1: - printElementHelper(ss, vec, index); - return; - default: - llvm_unreachable("Unsupported bitWidth"); - } -} - -void printFormattedElement(std::stringstream &ss, void *vec, size_t index, +void printFormattedElement(std::stringstream &ss, T val, const FormatInfo &formatInfo) { // Right now, the GPU's hex float doesn't work correctly. C++ has std:: // hexfloat, but let's consider only hex integers for now. if (formatInfo.isHex && formatInfo.isInt) { ss << "0x" << std::hex << std::setw(formatInfo.maxIntDigits) - << std::setfill('0'); - printElement(ss, vec, index, formatInfo); + << std::setfill('0') << val; return; } int padding = 0; - auto [digits, negative] = computeDigitInfo( - vec, formatInfo.isInt, formatInfo.isSigned, formatInfo.bitWidth, index); + auto [digits, negative] = computeDigitInfo(val); if (!negative && formatInfo.hasNegative) padding++; if (formatInfo.scientific) { ss << std::scientific << std::setw(MAX_FLOAT_WIDTH) - << std::setprecision(FLOAT_PREC) << std::string(padding, ' '); - printElement(ss, vec, index, formatInfo); + << std::setprecision(FLOAT_PREC) << std::string(padding, ' ') << val; } else { padding += formatInfo.maxIntDigits - digits; ss << std::fixed << std::setprecision(FLOAT_PREC) - << std::string(padding, ' '); - printElement(ss, vec, index, formatInfo); + << std::string(padding, ' ') << val; } } -template struct RawMemRefDescriptor { - T *allocated; - T *aligned; - intptr_t offset; - intptr_t sizesAndStrides[]; -}; - -template struct MemRefDescriptor { - T *allocated; - T *aligned; - intptr_t offset; - std::vector sizes; - std::vector strides; - int32_t rank; - - MemRefDescriptor(int32_t rank, void *rawDescriptor) : rank(rank) { - auto *rawDesc = static_cast *>(rawDescriptor); - allocated = rawDesc->allocated; - aligned = rawDesc->aligned; - offset = rawDesc->offset; - sizes.resize(rank); - strides.resize(rank); - for (int32_t i = 0; i < rank; i++) { - sizes[i] = rawDesc->sizesAndStrides[i]; - strides[i] = rawDesc->sizesAndStrides[i + rank]; - } - } -}; +// int8_t is printed as char, so use int16_t instead. +template <> +void printFormattedElement(std::stringstream &ss, int8_t val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val, formatInfo); +} -struct UnrankedMemRefType { - int64_t rank; - void *descriptor; -}; +template <> +void printFormattedElement(std::stringstream &ss, uint8_t val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val, formatInfo); +} template -void printToStream(MemRefDescriptor &&desc, std::stringstream &ss, - FormatInfo &partialFormatInfo) { - - if (desc.rank > 1) { - ss << "<>\n"; +void printToStreamRecursive(const MemRefDescriptor &desc, + std::stringstream &ss, const FormatInfo &formatInfo, + const std::string &linePrefix) { + if (desc.rank() > 1) { + ss << "["; + for (int64_t i = 0; i < desc.size(0); ++i) { + printToStreamRecursive(desc.subView(i), ss, formatInfo, linePrefix + " "); + if (i != desc.size(0) - 1) + ss << ",\n" << linePrefix << " "; + } + ss << "]"; return; } - if (desc.sizes.size() == 0) { - ss << "<>\n"; - } - - T *vec = desc.aligned; - int32_t numElems = desc.sizes[0]; - FormatInfo formatInfo = getFormatInfo( - vec, partialFormatInfo.isInt, partialFormatInfo.isSigned, - partialFormatInfo.bitWidth, numElems, partialFormatInfo.isHex); - - const size_t header = ss.str().size(); + const T *data = desc.data(); + int64_t stride = desc.stride(0); + int64_t numElems = desc.size(0); + ss << "["; if (numElems <= ELEMS_PER_LINE) { for (int i = 0; i < numElems; i++) { - printFormattedElement(ss, vec, i, formatInfo); + printFormattedElement(ss, data[i * stride], formatInfo); if (i != numElems - 1) ss << ", "; } } else { // TODO: Too many lines? Omit the middle lines. for (int i = 0; i < numElems; i++) { - printFormattedElement(ss, vec, i, formatInfo); + printFormattedElement(ss, data[i * stride], formatInfo); if (i == numElems - 1) break; if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { - ss << ",\n" << std::string(header, ' '); + ss << ",\n" << linePrefix << " "; } else { ss << ", "; } } } - ss << "]\n"; + ss << "]"; +} + +template +void printToStream(const MemRefDescriptor &desc, std::stringstream &ss, + const FormatInfo &partialFormatInfo, + const std::string &linePrefix) { + FormatInfo formatInfo = getFormatInfo( + desc, partialFormatInfo.isInt, partialFormatInfo.isSigned, + partialFormatInfo.bitWidth, partialFormatInfo.isHex); + printToStreamRecursive(desc, ss, formatInfo, linePrefix); } void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, - int32_t btw, bool isInteger, bool isSignedInteger, - bool asHex) { + int32_t btw, bool isInteger, bool isSignedInteger, bool asHex, + const std::string &linePrefix) { FormatInfo partialFormat{.isInt = isInteger, .isSigned = isSignedInteger, @@ -292,11 +241,11 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, switch (btw) { case 64: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 32: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; default: llvm_unreachable("Unsupported bitWidth"); @@ -306,23 +255,23 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, switch (btw) { case 64: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 32: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 16: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 8: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 1: - printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat, + linePrefix); return; default: llvm_unreachable("Unsupported bitWidth"); @@ -331,22 +280,23 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, switch (btw) { case 64: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 32: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 16: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 8: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 1: - printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat); + printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat, + linePrefix); return; default: llvm_unreachable("Unsupported bitWidth"); @@ -367,48 +317,11 @@ EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond, abort(); } -// Print the pid prefix like the GPU ad interpreter. And vectors are printed +// Print the pid prefix like the GPU and interpreter. And vectors are printed // similar to Torch's printing like the following: // (1, 0, 0) x: [ -0.4963, -1.7682, 2.0885, 3.1320, -4.3074, 5.6341, // -6.4901, 7.8964, -8.4556, -9.6323, -10.3489, -11.4017, // -12.0223, 13.1689, 14.2939, -15.5185] -// -// TODO: Implement for higher dimension vectors. -EXPORT void triton_vector_print(int32_t pid0, int32_t pid1, int32_t pid2, - const char *prefix, void *vec, bool isInt, - bool isSigned, int32_t bitWidth, - int64_t numElem, bool isHex) { - - FormatInfo formatInfo = - getFormatInfo(vec, isInt != 0, isSigned != 0, bitWidth, numElem, isHex); - - std::stringstream ss; - ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix << "["; - const size_t header = ss.str().size(); - - if (numElem <= ELEMS_PER_LINE) { - for (int i = 0; i < numElem; i++) { - printFormattedElement(ss, vec, i, formatInfo); - if (i != numElem - 1) - ss << ", "; - } - } else { - // TODO: Too many lines? Omit the middle lines. - for (int i = 0; i < numElem; i++) { - printFormattedElement(ss, vec, i, formatInfo); - if (i == numElem - 1) - break; - if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { - ss << ",\n" << std::string(header, ' '); - } else { - ss << ", "; - } - } - } - ss << "]\n"; - std::cout << ss.str() << std::flush; -} - EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1, int32_t pid2, const char *prefix, UnrankedMemRefType memref, int32_t btw, @@ -416,9 +329,11 @@ EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1, bool asHex) { std::stringstream ss; ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix; + std::string linePrefix(ss.str().size(), ' '); printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger, - isSignedInteger, asHex); + isSignedInteger, asHex, linePrefix); + ss << "\n"; std::cout << ss.str() << std::flush; }