Skip to content

Commit

Permalink
Add template simulation for VEC types: GEMM operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 committed Jul 6, 2021
1 parent fd43802 commit b5301da
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 83 deletions.
4 changes: 2 additions & 2 deletions demo/icebreaker/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ CROSS=riscv64-unknown-elf-
CFLAGS=-g -Os -march=rv32im -mabi=ilp32


all: out/icebreaker.json out/icebreaker.bin
all: out/icebreaker.json out/icebreaker.bin out/icebreaker_fw.bin

sim: icesim

Expand Down Expand Up @@ -55,7 +55,7 @@ iceprog_fw: icebreaker_fw.bin
brd/icebreaker_sections.lds: src/sections.lds
$(CROSS)cpp -P -DICEBREAKER -o $@ $^

out/icebreaker_fw.elf: brd/icebreaker_sections.lds src/start.s src/firmware.c
out/icebreaker_fw.elf: brd/icebreaker_sections.lds src/start.s src/firmware.c src/rv32-custom.h
$(CROSS)gcc $(CFLAGS) -DICEBREAKER -Wl,-Bstatic,-T,brd/icebreaker_sections.lds,--strip-debug -ffreestanding -nostdlib -o out/icebreaker_fw.elf src/start.s src/firmware.c
$(CROSS)size --format berkley out/icebreaker_fw.elf

Expand Down
123 changes: 42 additions & 81 deletions demo/icebreaker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,60 +36,51 @@ The **VEC-8U8-16I8-2S32** block executes in **2 x CPU clock** cycles but picorv3
Example of generated TIR representation for a MATMUL [64x64]*[64x64] inside TVM:

```
primfn(X_1: handle, coeffW_1: handle, F.global_1: handle) -> ()
primfn(X_1: handle, W_1: handle, F.wmma.accumulator_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {F.global: Buffer(F.global_2: Pointer(int32), int32, [64, 64], []),
coeffW: Buffer(coeffW_2: Pointer(int8), int8, [2, 32, 8, 8], []),
buffers = {F.wmma.accumulator: Buffer(F.wmma.accumulator_2: Pointer(int32), int32, [64, 64], []),
W: Buffer(W_2: Pointer(int8), int8, [64, 64], []),
X: Buffer(X_2: Pointer(uint8), uint8, [64, 64], [])}
buffer_map = {X_1: X, coeffW_1: coeffW, F.global_1: F.global} {
buffer_map = {X_1: X, W_1: W, F.wmma.accumulator_1: F.wmma.accumulator} {
attr [F: Pointer(int32)] "storage_scope" = "global";
allocate(F, int32, [4096]);
attr [coeffW.global: Pointer(int8)] "storage_scope" = "global";
allocate(coeffW.global, int8, [128]);
for (i: int32, 0, 64) {
for (j.outer: int32, 0, 32) {
for (ax0: int32, 0, 2) {
for (ax2: int32, 0, 8) {
for (ax3: int32, 0, 8) {
coeffW.global[(((ax0*64) + (ax2*8)) + ax3)] = (int8*)coeffW_2[((((ax0*2048) + (j.outer*64)) + (ax2*8)) + ax3)]
}
}
}
@tir.call_extern("MACZ_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.call_extern("O_VEC_MACZ",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, (i*64), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 0, 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("MACC_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, (j.outer*128), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_MACC",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 8), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 8, 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("MACC_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 8), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_MACC",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 16), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 16, 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("MACC_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 16), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_MACC",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 24), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 24, 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("MACC_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 24), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_MACC",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 32), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 32, 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("MACC_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 32), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_MACC",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 40), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 40, 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("MACC_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 40), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_MACC",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 48), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 48, 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("MACC_olimp",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.global_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 48), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_MACC",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=uint8), X_2, ((i*64) + 56), 8, 1, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), coeffW.global, 56, 128, 1, dtype=handle), 64, dtype=int32)
for (j.inner: int32, 0, 2) {
F[(((i*64) + (j.outer*2)) + j.inner)] = (int32*)F.global_2[(((i*64) + (j.outer*2)) + j.inner)]
}
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), W_2, ((j.outer*128) + 56), 128, 1, dtype=handle), 64, dtype=int32)
@tir.call_extern("O_VEC_STOR",
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F, ((i*64) + (j.outer*2)), 2, 2, dtype=handle),
@tir.tvm_access_ptr(@tir.type_annotation(, dtype=int32), F.wmma.accumulator_2, ((i*64) + (j.outer*2)), 2, 1, dtype=handle), dtype=int32)
}
}
}
Expand Down Expand Up @@ -125,10 +116,10 @@ TVM_DLL int32_t intrinsic(void* args, void* arg_type_ids, int32_t num_args, void
void* arg0_shape = (((DLTensor*)arg0)[0].shape);
void* arg0_strides = (((DLTensor*)arg0)[0].strides);
int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
void* coeffW = (((DLTensor*)arg1)[0].data);
void* W = (((DLTensor*)arg1)[0].data);
void* arg1_shape = (((DLTensor*)arg1)[0].shape);
void* arg1_strides = (((DLTensor*)arg1)[0].strides);
void* F_global = (((DLTensor*)arg2)[0].data);
void* F_wmma_accumulator = (((DLTensor*)arg2)[0].data);
void* arg2_shape = (((DLTensor*)arg2)[0].shape);
void* arg2_strides = (((DLTensor*)arg2)[0].strides);
if (!(arg0_strides == NULL)) {
Expand All @@ -141,53 +132,23 @@ TVM_DLL int32_t intrinsic(void* args, void* arg_type_ids, int32_t num_args, void
if (F == NULL) {
return -1;
}
void* coeffW_global = TVMBackendAllocWorkspace(1, dev_id, (uint64_t)128, 0, 8);
if (coeffW_global == NULL) {
return -1;
}
for (int32_t i = 0; i < 64; ++i) {
for (int32_t j_outer = 0; j_outer < 32; ++j_outer) {
for (int32_t ax0 = 0; ax0 < 2; ++ax0) {
for (int32_t ax2 = 0; ax2 < 8; ++ax2) {
for (int32_t ax3 = 0; ax3 < 8; ++ax3) {
((int8_t*)coeffW_global)[((((ax0 * 64) + (ax2 * 8)) + ax3))] =
((int8_t*)coeffW)[(((((ax0 * 2048) + (j_outer * 64)) + (ax2 * 8)) + ax3))];
}
}
}
(void)MACZ_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + ((i * 64))),
((int8_t *)coeffW_global + (0)), 64);
(void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + (((i * 64) + 8))),
((int8_t *)coeffW_global + (8)), 64);
(void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + (((i * 64) + 16))), ((int8_t *)coeffW_global + (16)), 64);
(void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + (((i * 64) + 24))), ((int8_t *)coeffW_global + (24)), 64);
(void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + (((i * 64) + 32))), ((int8_t *)coeffW_global + (32)), 64);
(void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + (((i * 64) + 40))), ((int8_t *)coeffW_global + (40)), 64);
(void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + (((i * 64) + 48))), ((int8_t *)coeffW_global + (48)), 64);
(void)MACC_olimp(((int32_t *)F_global + (((i * 64) + (j_outer * 2)))),
((uint8_t *)X + (((i * 64) + 56))), ((int8_t *)coeffW_global + (56)), 64);
for (int32_t j_inner = 0; j_inner < 2; ++j_inner) {
((int32_t*)F)[((((i * 64) + (j_outer * 2)) + j_inner))] =
((int32_t*)F_global)[((((i * 64) + (j_outer * 2)) + j_inner))];
}
(void)O_VEC_MACZ(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + ((i * 64))), ((int8_t *)W + ((j_outer * 128))), 64);
(void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 8))), ((int8_t *)W + (((j_outer * 128) + 8))), 64);
(void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 16))), ((int8_t *)W + (((j_outer * 128) + 16))), 64);
(void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 24))), ((int8_t *)W + (((j_outer * 128) + 24))), 64);
(void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 32))), ((int8_t *)W + (((j_outer * 128) + 32))), 64);
(void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 40))), ((int8_t *)W + (((j_outer * 128) + 40))), 64);
(void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 48))), ((int8_t *)W + (((j_outer * 128) + 48))), 64);
(void)O_VEC_MACC(((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))), ((uint8_t *)X + (((i * 64) + 56))), ((int8_t *)W + (((j_outer * 128) + 56))), 64);
(void)O_VEC_STOR(((int32_t *)F + (((i * 64) + (j_outer * 2)))), ((int32_t *)F_wmma_accumulator + (((i * 64) + (j_outer * 2)))));
}
}
if (TVMBackendFreeWorkspace(1, dev_id, coeffW_global) != 0) {
return -1;
}
if (TVMBackendFreeWorkspace(1, dev_id, F) != 0) {
return -1;
}
return 0;
}
```

*Note*: [MACC_olimp()](/demo/icebreaker/src/firmware.c#L340) are wrapped **__asm__( ".word 0xRV32custom")** RV32 ISA extension for OLIMP hardware block.
Expand Down
Loading

0 comments on commit b5301da

Please sign in to comment.