diff --git a/demo/icebreaker/Makefile b/demo/icebreaker/Makefile
index e916ed4..f771e32 100644
--- a/demo/icebreaker/Makefile
+++ b/demo/icebreaker/Makefile
@@ -3,7 +3,7 @@ CROSS=riscv64-unknown-elf-
 CFLAGS=-g -Os -march=rv32im -mabi=ilp32
 
-all: out/icebreaker.json out/icebreaker.bin
+all: out/icebreaker.json out/icebreaker.bin out/icebreaker_fw.bin
 
 sim: icesim
 
@@ -55,7 +55,7 @@ iceprog_fw: icebreaker_fw.bin
 brd/icebreaker_sections.lds: src/sections.lds
 	$(CROSS)cpp -P -DICEBREAKER -o $@ $^
 
-out/icebreaker_fw.elf: brd/icebreaker_sections.lds src/start.s src/firmware.c
+out/icebreaker_fw.elf: brd/icebreaker_sections.lds src/start.s src/firmware.c src/rv32-custom.h
 	$(CROSS)gcc $(CFLAGS) -DICEBREAKER -Wl,-Bstatic,-T,brd/icebreaker_sections.lds,--strip-debug -ffreestanding -nostdlib -o out/icebreaker_fw.elf src/start.s src/firmware.c
 	$(CROSS)size --format berkley out/icebreaker_fw.elf
diff --git a/demo/icebreaker/README.md b/demo/icebreaker/README.md
index 4792aad..37bc4d4 100644
--- a/demo/icebreaker/README.md
+++ b/demo/icebreaker/README.md
@@ -36,60 +36,51 @@ The **VEC-8U8-16I8-2S32** block executes in **2 x CPU clock** cycles but picorv3
 Example of generated TIR representation for a MATMUL [64x64]*[64x64] inside TVM:
 
 ```
-primfn(X_1: handle, coeffW_1: handle, F.global_1: handle) -> ()
+primfn(X_1: handle, W_1: handle, F.wmma.accumulator_1: handle) -> ()
   attr = {"global_symbol": "main", "tir.noalias": True}
-  buffers = {F.global: Buffer(F.global_2: Pointer(int32), int32, [64, 64], []),
-             coeffW: Buffer(coeffW_2: Pointer(int8), int8, [2, 32, 8, 8], []),
+  buffers = {F.wmma.accumulator: Buffer(F.wmma.accumulator_2: Pointer(int32), int32, [64, 64], []),
+             W: Buffer(W_2: Pointer(int8), int8, [64, 64], []),
              X: Buffer(X_2: Pointer(uint8), uint8, [64, 64], [])}
-  buffer_map = {X_1: X, coeffW_1: coeffW, F.global_1: F.global} {
+  buffer_map = {X_1: X, W_1: W, F.wmma.accumulator_1: F.wmma.accumulator} {
   attr [F: Pointer(int32)] "storage_scope" = "global";
   allocate(F, int32, [4096]);
-  attr [coeffW.global: Pointer(int8)] "storage_scope" = "global";
-  allocate(coeffW.global, int8, [128]);
   for (i: int32, 0, 64) {
     for (j.outer: int32, 0, 32) {
-      for (ax0: int32, 0, 2) {
-        for (ax2: int32, 0, 8) {
-          for (ax3: int32, 0, 8) {
-            coeffW.global[(((ax0*64) + (ax2*8)) + ax3)] = (int8*)coeffW_2[((((ax0*2048) + (j.outer*64)) + (ax2*8)) + ax3)]
-          }
-        }
-      }
-      @tir.call_extern("MACZ_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+      @tir.call_extern("O_VEC_MACZ",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, (i*64), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 0, 128, 1, dtype=handle), 64, dtype=int32)
-      @tir.call_extern("MACC_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, (j.outer*128), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_MACC",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 8), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 8, 128, 1, dtype=handle), 64, dtype=int32)
-      @tir.call_extern("MACC_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 8), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_MACC",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 16), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 16, 128, 1, dtype=handle), 64, dtype=int32)
-      @tir.call_extern("MACC_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 16), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_MACC",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 24), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 24, 128, 1, dtype=handle), 64, dtype=int32)
-      @tir.call_extern("MACC_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 24), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_MACC",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 32), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 32, 128, 1, dtype=handle), 64, dtype=int32)
-      @tir.call_extern("MACC_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 32), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_MACC",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 40), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 40, 128, 1, dtype=handle), 64, dtype=int32)
-      @tir.call_extern("MACC_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 40), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_MACC",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 48), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 48, 128, 1, dtype=handle), 64, dtype=int32)
-      @tir.call_extern("MACC_olimp",
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 48), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_MACC",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
         @tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 56), 8, 1, dtype=handle),
-        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 56, 128, 1, dtype=handle), 64, dtype=int32)
-      for (j.inner: int32, 0, 2) {
-        F[(((i*64) + (j.outer*2)) + j.inner)] = (int32*)F.global_2[(((i*64) + (j.outer*2)) + j.inner)]
-      }
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 56), 128, 1, dtype=handle), 64, dtype=int32)
+      @tir.call_extern("O_VEC_STOR",
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
+        @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 1, dtype=handle), dtype=int32)
     }
   }
 }
@@ -125,10 +116,10 @@ TVM_DLL int32_t intrinsic(void* args, void* arg_type_ids, int32_t num_args, void
   void* arg0_shape = (((DLTensor*)arg0)[0].shape);
   void* arg0_strides = (((DLTensor*)arg0)[0].strides);
   int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
-  void* coeffW = (((DLTensor*)arg1)[0].data);
+  void* W = (((DLTensor*)arg1)[0].data);
   void* arg1_shape = (((DLTensor*)arg1)[0].shape);
   void* arg1_strides = (((DLTensor*)arg1)[0].strides);
-  void* F_global = (((DLTensor*)arg2)[0].data);
+  void* F_wmma_accumulator = (((DLTensor*)arg2)[0].data);
   void* arg2_shape = (((DLTensor*)arg2)[0].shape);
   void* arg2_strides = (((DLTensor*)arg2)[0].strides);
   if (!(arg0_strides == NULL)) {
@@ -141,53 +132,23 @@ TVM_DLL int32_t intrinsic(void* args, void* arg_type_ids, int32_t num_args, void
   if (F == NULL) {
     return -1;
   }
-  void* coeffW_global = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)128, 0, 8);
-  if (coeffW_global == NULL) {
-    return -1;
-  }
   for (int32_t i = 0; i < 64; ++i) {
     for (int32_t j_outer = 0; j_outer < 32; ++j_outer) {
-      for (int32_t ax0 = 0; ax0 < 2; ++ax0) {
-        for (int32_t ax2 = 0; ax2 < 8; ++ax2) {
-          for (int32_t ax3 = 0; ax3 < 8; ++ax3) {
-            ((int8_t*)coeffW_global)[((((ax0 * 64) + (ax2 * 8)) + ax3))] =
-              ((int8_t*)coeffW)[(((((ax0 * 2048) + (j_outer * 64)) + (ax2 * 8)) + ax3))];
-          }
-        }
-      }
-      (void)MACZ_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + ((i * 64))),
-                       ((int8_t *)coeffW_global + (0)), 64);
-      (void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + (((i * 64) + 8))),
-                       ((int8_t *)coeffW_global + (8)), 64);
-      (void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + (((i * 64) + 16))), ((int8_t *)coeffW_global + (16)), 64);
-      (void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + (((i * 64) + 24))), ((int8_t *)coeffW_global + (24)), 64);
-      (void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + (((i * 64) + 32))), ((int8_t *)coeffW_global + (32)), 64);
-      (void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + (((i * 64) + 40))), ((int8_t *)coeffW_global + (40)), 64);
-      (void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + (((i * 64) + 48))), ((int8_t *)coeffW_global + (48)), 64);
-      (void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
-                       ((uint8_t *)X + (((i * 64) + 56))), ((int8_t *)coeffW_global + (56)), 64);
-      for (int32_t j_inner = 0; j_inner < 2; ++j_inner) {
-        ((int32_t*)F)[((((i * 64) + (j_outer * 2)) + j_inner))] =
-          ((int32_t*)F_global)[((((i * 64) + (j_outer * 2)) + j_inner))];
-      }
+      (void)O_VEC_MACZ(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + ((i * 64))), ((int8_t *)W + ((j_outer * 128))), 64);
+      (void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 8))), ((int8_t *)W + (((j_outer * 128) + 8))), 64);
+      (void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 16))), ((int8_t *)W + (((j_outer * 128) + 16))), 64);
+      (void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 24))), ((int8_t *)W + (((j_outer * 128) + 24))), 64);
+      (void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 32))), ((int8_t *)W + (((j_outer * 128) + 32))), 64);
+      (void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 40))), ((int8_t *)W + (((j_outer * 128) + 40))), 64);
+      (void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 48))), ((int8_t *)W + (((j_outer * 128) + 48))), 64);
+      (void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 56))), ((int8_t *)W + (((j_outer * 128) + 56))), 64);
+      (void)O_VEC_STOR(((int32_t *)F + (((i * 64) + (j_outer * 2)))), ((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))));
     }
   }
-  if (TVMBackendFreeWorkspace(1, dev_id, coeffW_global) != 0) {
-    return -1;
-  }
   if (TVMBackendFreeWorkspace(1, dev_id, F) != 0) {
     return -1;
   }
   return 0;
-}
-
 ```
 
 *Note*: [MACC_olimp()](/demo/icebreaker/src/firmware.c#L340) calls are wrappers around **__asm__(".word 0xRV32custom")** RV32 ISA-extension instructions for the OLIMP hardware block.
diff --git a/srcs/sims/olimp-vec-gemm.py b/srcs/sims/olimp-vec-gemm.py
new file mode 100755
index 0000000..ed98ccf
--- /dev/null
+++ b/srcs/sims/olimp-vec-gemm.py
@@ -0,0 +1,344 @@
+#!/usr/bin/python3
+
+"""
+  OLIMP VEC schedule validation script.
+"""
+
+##
+## License: GPLv3
+## https://www.gnu.org/licenses/gpl-3.0.en.html
+##
+## Copyright 2021
+## Cristian Balint < cristian dot balint at gmail dot com >
+##
+
+import sys
+import tvm
+import random
+import string
+
+import tvm.testing
+from tvm import te
+
+import numpy as np
+
+
+debug = False
+
+# init with OLIMP 8U8-16I8-2S32 example
+INT8_MACS = 8    # int8 elements per int32 accumulator
+INT32_LANES = 2  # int32 accumulator lanes (ACC0 & ACC1)
+
+def OLIMP_VEC_MAC_impl():
+    cc_code = f"""
+        #include <stdint.h>
+        #ifdef __cplusplus
+        extern "C"
+        #endif
+        int32_t O_VEC_MACC(int32_t *output,
+                           const uint8_t *data,
+                           const int8_t *kernel,
+                           const int32_t stride) {{
+            for (int i = 0; i < {INT32_LANES}; ++i) {{
+                for (int j = 0; j < {INT8_MACS}; ++j) {{
+                    output[i] += data[j] * kernel[i * stride + j];
+                }}
+            }}
+            return 0;
+        }}
+        #ifdef __cplusplus
+        extern "C"
+        #endif
+        int32_t O_VEC_MACZ(int32_t *output,
+                           const uint8_t *data,
+                           const int8_t *kernel,
+                           const int32_t stride) {{
+            for (int i = 0; i < {INT32_LANES}; ++i) {{
+                output[i] = 0;
+                for (int j = 0; j < {INT8_MACS}; ++j) {{
+                    output[i] += data[j] * kernel[i * stride + j];
+                }}
+            }}
+            return 0;
+        }}
+        #ifdef __cplusplus
+        extern "C"
+        #endif
+        int32_t O_VEC_STOR(int32_t *output,
+                           const int32_t *acc) {{
+            for (int i = 0; i < {INT32_LANES}; ++i) {{
+                output[i] = acc[i];
+            }}
+            return 0;
+        }}
+    """
+
+    from tvm.contrib import utils, clang
+    temp = utils.tempdir()
+    ll_path = temp.relpath("temp.ll")
+    # llvm ir from c source code
+    ll_code = clang.create_llvm(cc_code, output=ll_path)
+    return ll_code
+
+def OLIMP_VEC_MAC():
+    """
+    Int8 dot product over every INT8_MACS elements using the OLIMP VEC
+    instructions. This function takes two arrays of uint8 and int8
+    datatype -- data[INT8_MACS] and coef[INT32_LANES][INT8_MACS] --
+    and computes the dot product of data[INT8_MACS] with every
+    INT8_MACS-element row of coef, resulting in acc[INT32_LANES]
+    accumulators of int32 datatype.
+    The pseudo code is as follows.
+    .. code-block:: c
+        void O_VEC_MAC{C,Z}(uint8 data[INT8_MACS],
+                            int8 coef[INT32_LANES][INT8_MACS],
+                            int32 acc[INT32_LANES]){
+            for (int i = 0; i < INT32_LANES; i++){
+                acc[i] = 0; // <- case of MACZ
+                for (int k = 0; k < INT8_MACS; k++){
+                    acc[i] += data[k] * coef[i][k]
+                }
+            }
+        }
+
+    Physically, the coef array is accessed in [INT32_LANES] by [INT8_MACS]
+    memory order as the innermost region.
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The OLIMP MAC{C,Z} int8 TensorIntrin that can be used in tensorizing schedule
+    """
+
+    data = te.placeholder((INT8_MACS,), dtype='uint8', name='data')
+    coef = te.placeholder((INT32_LANES, INT8_MACS), dtype='int8', name='coef')
+
+    k = te.reduce_axis((0, INT8_MACS), name='k')
+    C = te.compute((INT32_LANES,), lambda i:
+                   te.sum( (data[ k] *
+                            coef[i, k]).astype("int32"),
+                           axis=k),
+                   name="Co")
+
+    Aa = tvm.tir.decl_buffer(data.shape, data.dtype, name="data_buffer",
+                             scope="global",
+                             offset_factor=1, strides=[1])
+    Bb = tvm.tir.decl_buffer(coef.shape, coef.dtype, name="coef_buffer",
+                             scope="global",
+                             offset_factor=1, strides=[te.var("ldw"), 1])
+    Co = tvm.tir.decl_buffer(C.shape, C.dtype, name="Co",
+                             scope="global",
+                             offset_factor=1, strides=[1])
+
+    def intrin_func(ins, outs):
+        Aa, Bb = ins
+        Co = outs[0]
+        def _body():
+            ib = tvm.tir.ir_builder.create()
+            o_vec_macz = tvm.tir.call_extern(
+                "int32",
+                f"O_VEC_MACZ",
+                Co.access_ptr("w"),
+                Aa.access_ptr("r"),
+                Bb.access_ptr("r"),
+                Bb.strides[0])
+            ib.emit(o_vec_macz)
+            return ib.get()
+        def _reduce_reset():
+            return None
+        def _reduce_update():
+            ib = tvm.tir.ir_builder.create()
+            o_vec_macc = tvm.tir.call_extern(
+                "int32",
+                f"O_VEC_MACC",
+                Co.access_ptr("w"),
+                Aa.access_ptr("r"),
+                Bb.access_ptr("r"),
+                Bb.strides[0])
+            ib.emit(o_vec_macc)
+            return ib.get()
+        return _body(), _reduce_reset(), _reduce_update()
+
+    buffer_params = {"offset_factor" : 1}
+    intrin_decl = te.decl_tensor_intrin(
+        C.op, intrin_func, binds={data: Aa, coef: Bb, C: Co},
+        default_buffer_params=buffer_params)
+    return intrin_decl
+
+def OLIMP_VEC_STR():
+    """
+    Copy the -- acc[INT32_LANES] -- accumulator elements to main memory
+    as the reduction result of the computation. The acc[INT32_LANES]
+    registers are otherwise not visible to the main system.
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The OLIMP STOR int32 TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((INT32_LANES,), name="A", dtype="int32")
+
+    C = te.compute((INT32_LANES,), lambda i: data[i], name="Cf")
+
+    Aa = tvm.tir.decl_buffer(data.shape, data.dtype, name="Aa",
+                             scope="global",
+                             offset_factor=1)
+
+    Co = tvm.tir.decl_buffer(C.shape, C.dtype, name="Cf",
+                             scope="global",
+                             offset_factor=1)
+
+    def intrin_func(ins, outs):
+        ib = tvm.tir.ir_builder.create()
+        Aa = ins[0]
+        Co = outs[0]
+        o_vec_stor = tvm.tir.call_extern(
+            "int32",
+            f"O_VEC_STOR",
+            Co.access_ptr("w"),
+            Aa.access_ptr("r"))
+        ib.emit(o_vec_stor)
+        return ib.get()
+
+    return te.decl_tensor_intrin(C.op, intrin_func, binds={data: Aa, C: Co})
+
+##
+## C(m, n) = X(m, k) * W(n, k)
+## X is data
+## W is coef
+##
+
+# dummy arbitrary data matrix
+M = 64 #64
+N = 64 #64 # must be multiple of INT32_LANES
+K = 64 #64 # must be multiple of INT8_MACS (the common axis)
+
+
+def compute(target="llvm"):
+
+    device = tvm.cpu(0)
+
+    # inputs
+    X = te.placeholder((M, K), name='X', dtype="uint8")
+    W = te.placeholder((N, K), name='W', dtype="int8")
+
+    ##
+    ## weights ordering for vector computation
+    ##
+
+    # WEIGHT [N, K ] ->
+    # WEIGHT [N/I32LANES, K/I8MACS, I32_LANES, I8MACS]
+    wshape = (N // INT32_LANES, K // INT8_MACS, INT32_LANES, INT8_MACS)
+    coefW = te.compute(
+        wshape,
+        lambda r_idx, s_idx, l_idx, t_idx:
+            W[r_idx * INT32_LANES + l_idx][s_idx * INT8_MACS + t_idx],
+        name="Wcoef")
+
+    ##
+    ## matmul vector computation
+    ##
+
+    idxd = tvm.tir.indexdiv
+    idxm = tvm.tir.indexmod
+    ak = te.reduce_axis((0, K), name='k')
+    C = te.compute((M, N), lambda i, j:
+                   te.sum( (X[i, ak] *
+                            coefW[idxd(j, INT32_LANES),
+                                  idxd(ak, INT8_MACS),
+                                  idxm(j, INT32_LANES),
+                                  idxm(ak, INT8_MACS)]).astype("int32")
+                           , axis=ak),
+                   name="F")
+
+    ##
+    ## matmul vector scheduling
+    ##
+
+    # create schedule
+    s = te.create_schedule(C.op)
+
+    # reorganize coef inline
+    s[coefW].compute_inline()
+
+    # schedule write cache
+    CF = s.cache_write(C, "wmma.accumulator")
+
+    # schedule flush write
+    b_x, b_y = s[C].op.axis
+    b_yo, b_yi = s[C].split(b_y, factor=INT32_LANES)
+
+    # schedule compute
+    a_x, a_y = s[CF].op.axis
+    a_k, = s[CF].op.reduce_axis
+    a_yo, a_yi = s[CF].split(a_y, factor=INT32_LANES)
+    a_ko, a_ki = s[CF].split(a_k, factor=INT8_MACS)
+    # (lanes, macs) as inner most
+    s[CF].reorder(a_yo, a_x, a_ko, a_yi, a_ki)
+    # fuse all outer
+    fuse = s[CF].fuse(a_yo, a_x)
+    # flush accumulators to end
+    s[CF].compute_at(s[C], b_yo)
+
+    # unroll
+    s[CF].unroll(a_ko)
+    # tensorize vectors
+    s[CF].tensorize(a_yi, OLIMP_VEC_MAC())
+    # tensorize accumulators
+    s[C].tensorize(b_yi, OLIMP_VEC_STR())
+
+
+    ##
+    ## graph and code generation
+    ##
+
+    # print lowered TIR computation graph
+    print(tvm.lower(s, [X, W, CF], simple_mode=True))
+
+    if (debug):
+        # visual graph debug
+        from tvm.contrib import tedd
+        tedd.viz_dataflow_graph(s, dot_file_path = 'tvm-dfg.dot')
+        tedd.viz_schedule_tree(s, dot_file_path = 'tvm-scheduletree.dot')
+        t_s = s.normalize()
+        tedd.viz_schedule_tree(t_s, dot_file_path = 'tvm-scheduletree_norm.dot')
+        tedd.viz_itervar_relationship_graph(t_s, dot_file_path = 'tvm-itervar.dot')
+
+    # imprint vector machine function calls
+    s[CF].pragma(fuse, "import_llvm", OLIMP_VEC_MAC_impl())
+
+    # compile whole computation graph
+    t_func = tvm.build(s, [X, W, CF], target="llvm", name="intrinsic")
+
+    if (debug):
+        print(t_func.get_source())
+
+    ##
+    ## evaluate
+    ##
+    t_evaluator = t_func.time_evaluator(t_func.entry_name, device, number=0)
+
+    # generate plain data
+    a_ = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+    b_ = np.random.uniform(1, 10, size=(N, K)).astype("int8")
+
+    print("A shape =", a_.shape)
+    print("B shape =", b_.shape)
+
+    x = tvm.nd.array(a_, device)
+    w = tvm.nd.array(b_, device)
+    y = tvm.nd.array(np.zeros((M, N), dtype="int32"), device)
+    result = t_evaluator(x, w, y)
+
+    print("\nA x B :\n", y)
+
+    # verify the correctness
+    tvm.testing.assert_allclose(y.asnumpy(), np.dot(a_, b_.T), rtol=0)
+    t_func.export_library("tensorize_acc32.o")
+
+compute()
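As a reading aid for the `coefW` stage of the script above, the following plain-NumPy sketch shows the same [N, K] -> [N/INT32_LANES, K/INT8_MACS, INT32_LANES, INT8_MACS] weight reordering; it is an illustration added here (all names are local to the snippet), not part of the patch:

```python
# Hypothetical NumPy equivalent of the coefW layout computed by te.compute(wshape, ...).
import numpy as np

N, K, INT32_LANES, INT8_MACS = 64, 64, 2, 8
W = np.random.randint(-9, 10, size=(N, K), dtype=np.int8)

# split rows into INT32_LANES groups and columns into INT8_MACS groups,
# keeping (lanes, macs) as the two innermost axes
coefW = W.reshape(N // INT32_LANES, INT32_LANES,
                  K // INT8_MACS, INT8_MACS).transpose(0, 2, 1, 3)

# coefW[r, s, l, t] == W[r*INT32_LANES + l, s*INT8_MACS + t]
assert coefW[3, 5, 1, 2] == W[3 * INT32_LANES + 1, 5 * INT8_MACS + 2]
```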
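For readers cross-checking the generated TIR/C shown in the README diff, here is a minimal plain-NumPy model of what one `(i, j.outer)` iteration of the `O_VEC_MACZ` / `O_VEC_MACC` / `O_VEC_STOR` sequence computes, assuming the OLIMP 8U8-16I8-2S32 configuration (`INT8_MACS = 8`, `INT32_LANES = 2`). It is an illustration only, not part of the patch:

```python
# Hypothetical cross-check of one (i, j_outer) step: one MACZ followed by seven MACCs,
# each producing 8 int8 MACs per accumulator lane, then a STOR flush of both lanes.
import numpy as np

INT8_MACS, INT32_LANES, K = 8, 2, 64
X = np.random.randint(1, 10, size=(64, K), dtype=np.uint8)
W = np.random.randint(-9, 10, size=(64, K), dtype=np.int8)

i, j_outer = 5, 3                        # an arbitrary output row / column pair
acc = np.zeros(INT32_LANES, np.int32)    # models the hidden ACC0/ACC1 registers
for ko in range(K // INT8_MACS):
    data = X[i, ko * INT8_MACS:(ko + 1) * INT8_MACS].astype(np.int32)
    for lane in range(INT32_LANES):
        coef = W[j_outer * INT32_LANES + lane,
                 ko * INT8_MACS:(ko + 1) * INT8_MACS].astype(np.int32)
        if ko == 0:
            acc[lane] = 0                # MACZ clears the accumulator lane first
        acc[lane] += np.dot(data, coef)  # MACC accumulates 8 products per lane

# O_VEC_STOR would now flush acc into F[i, j_outer*2 : j_outer*2 + 2]
ref = X[i].astype(np.int32) @ W[j_outer * INT32_LANES:
                                (j_outer + 1) * INT32_LANES].astype(np.int32).T
assert np.array_equal(acc, ref)
```

The final assertion mirrors the script's own `assert_allclose` check against `np.dot(a_, b_.T)`, restricted to the two output columns handled by one accumulator pair.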