
Support VNNI pre-encoded input in AMX lowering. #210

Merged
merged 1 commit into from
Jan 13, 2025


ienkovich
Collaborator

This patch allows AMX lowering to be used with input that is already VNNI-encoded, so that RHS data can be loaded directly into AMX tile registers. To enable this scheme without changes in the front-end, I use a VNNI decoding sequence in the kernel right before tl.dot. The AMX lowering pass then detects this sequence and uses the encoded data instead of emitting code to VNNI-encode the RHS. For more convenient usage, the required sequence is available as a function in tl.extra.cpu.
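For readers unfamiliar with the layout, here is a minimal pure-Python sketch of VNNI packing and its inverse. It is an illustration only, not the code from this patch or the helper in tl.extra.cpu: the names `vnni_encode` and `vnni_decode` are hypothetical, and it assumes the usual convention that `step` consecutive K-rows are interleaved per column (`step = 2` for 16-bit types, `step = 4` for 8-bit types), turning a K x N matrix into a (K / step) x (N * step) one.

```python
def vnni_encode(b, step):
    """Pack a K x N matrix (list of rows) into a (K/step) x (N*step) VNNI layout.

    Hypothetical helper for illustration. For each group of `step` consecutive
    rows, the `step` values of every column are stored next to each other.
    """
    k, n = len(b), len(b[0])
    if k % step != 0:
        raise ValueError("K must be a multiple of the packing step")
    return [
        [b[kb * step + s][j] for j in range(n) for s in range(step)]
        for kb in range(k // step)
    ]


def vnni_decode(p, step):
    """Inverse of vnni_encode: recover the original K x N row-major matrix."""
    n = len(p[0]) // step
    return [
        [p[kb][j * step + s] for j in range(n)]
        for kb in range(len(p))
        for s in range(step)
    ]


# A 4 x 2 matrix packed with step = 2 (the 16-bit case).
b = [[0, 1],
     [2, 3],
     [4, 5],
     [6, 7]]
packed = vnni_encode(b, 2)
print(packed)                  # [[0, 2, 1, 3], [4, 6, 5, 7]]
print(vnni_decode(packed, 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

In the scheme described above, the kernel would receive the packed form and apply the decode step before tl.dot; the lowering pass recognizes that decode and feeds the packed data to the tiles directly, so the round trip is never actually executed.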

The blocked matmul tutorial was modified to support FP16, BF16, and INT8 inputs with optional VNNI packing during the layout change. Enabling VNNI encoding provides a nice performance improvement:

         M       N       K  triton-cpu-bb-tb-pb-prepack-bfloat16 triton-cpu-bb-tb-prepack-bfloat16 torch-cpu-native-bfloat16
0    256.0   256.0   256.0                           1030.115281                        997.606247                283.695229
1    384.0   384.0   384.0                           2408.509006                       2315.354053                771.912107
2    512.0   512.0   512.0                           3332.961791                       3275.987277               1517.463289
3    640.0   640.0   640.0                           4105.520000                       3981.941580               2247.206452
4    768.0   768.0   768.0                           5146.984510                       5017.549861               3346.319424
5    896.0   896.0   896.0                           6461.905640                       6180.681542               4062.503898
6   1024.0  1024.0  1024.0                           7049.696742                       6865.913180               6070.727605
7   1152.0  1152.0  1152.0                           8185.642292                       8024.265204               6754.843487
8   1280.0  1280.0  1280.0                           9066.353074                       8876.553991               7957.417071
9   1408.0  1408.0  1408.0                          10279.423034                       9859.907396               9265.297503
10  1536.0  1536.0  1536.0                          11846.624554                      11213.702977              11350.639719
11  1664.0  1664.0  1664.0                          12702.759698                      12124.802459              12606.923785
12  1792.0  1792.0  1792.0                          14220.336458                      13566.518273              13808.570932
13  1920.0  1920.0  1920.0                          15887.340273                      15075.636795              16285.160756
14  2048.0  2048.0  2048.0                          18620.652026                      16831.104536              18155.011342
15  2176.0  2176.0  2176.0                          20144.720598                      18091.425428              17340.249479
16  2304.0  2304.0  2304.0                          22151.868170                      19452.387288              19374.471299
17  2432.0  2432.0  2432.0                          23570.405937                      20818.248122              21151.911884
18  2560.0  2560.0  2560.0                          25702.017528                      22089.320415              22514.952314

@ienkovich ienkovich requested review from int3, minjang and Devjiu January 10, 2025 20:19
@ienkovich ienkovich requested a review from ptillet as a code owner January 10, 2025 20:19
@minjang
Collaborator

minjang commented Jan 13, 2025

Thanks for the work! Actually didn't have time to have a full review, but stamping to unblock.

QQ: Does this patch only work with AMX? Can it be used for non-AMX CPUs? Most recent AVX512 CPUs come with VNNI.

@ienkovich
Collaborator Author

At the moment, we don't generate any instructions that work with VNNI-encoded data other than AMX ones. But the common methods for detecting VNNI-encoded data can certainly be reused for non-AMX code generation.

@ienkovich ienkovich merged commit dc8dfb6 into triton-lang:main Jan 13, 2025
3 checks passed
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Feb 20, 2025
2 participants