Matrix multiplication tutorial block pointer variant (triton-lang#1)
Adds a `USE_BLOCK_POINTERS` flag to `matmul_kernel` so we can get IR for pointers-to-tensors instead of tensors-of-pointers.
rolfmorel authored and Devjiu committed Jan 20, 2025
1 parent dc8dfb6 commit 7d75054
Showing 1 changed file with 5 additions and 0 deletions.
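For context, the two addressing styles this flag selects between look roughly as follows inside a @triton.jit kernel (an illustrative fragment, not part of this diff; the a_ptr, offs_*, and stride_a* names follow the tutorial's conventions):

    # Tensors-of-pointers: a BLOCK_SIZE_M x BLOCK_SIZE_K tensor in which
    # every element is its own scalar address into A.
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak

    # Pointer-to-tensor (block pointer): a single structured pointer that
    # describes a whole BLOCK_SIZE_M x BLOCK_SIZE_K tile of A.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                                    strides=(stride_am, stride_ak),
                                    offsets=(block_offset_m, 0),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))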
python/tutorials/03-matrix-multiplication-cpu.py: 5 additions & 0 deletions

@@ -167,6 +167,7 @@
 GROUP_SIZE_M = 8
 USE_GPU = False
+USE_BLOCK_POINTERS = False
 
 
 @triton.jit
@@ -216,6 +217,9 @@ def matmul_kernel(
     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
     pid_m = first_pid_m + (pid % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m
+    if USE_BLOCK_POINTERS:
+        block_offset_m = pid_m * BLOCK_SIZE_M
+        block_offset_n = pid_n * BLOCK_SIZE_N
 
     # ----------------------------------------------------------
     # Create pointers for the first blocks of A and B.
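In the block-pointer variant, the K-loop that follows (collapsed in this view) would then step these structured pointers forward with tl.advance instead of re-deriving a tensor of addresses each iteration; a sketch under that assumption, mirroring the tutorial's usual loop shape:

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        accumulator += tl.dot(a, b)
        # Slide the A tile right and the B tile down by one K-block.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))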
@@ -329,6 +333,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K:
         GROUP_SIZE_M=GROUP_SIZE_M, #
+        USE_BLOCK_POINTERS=USE_BLOCK_POINTERS, #
         num_threads=num_threads, #
     )
     return c
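Because the flag is forwarded as a launch keyword here, the kernel itself presumably declares it as a compile-time constant, so the if USE_BLOCK_POINTERS: branch in the kernel is resolved at compile time rather than at run time. A sketch of the assumed signature (parameter list reconstructed from the tutorial's conventions, not from this diff):

    @triton.jit
    def matmul_kernel(
            a_ptr, b_ptr, c_ptr,
            M, N, K,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
            BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
            USE_BLOCK_POINTERS: tl.constexpr,
    ):
        ...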

