This PR contains the changes required to support FP8 for the persistent + ping-pong schedule, with V pre-transposed in Python.
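The V pre-transpose can be sketched as below. This is a framework-free illustration of the layout change only (swapping the seqlen and headdim axes so the FP8 kernel sees V laid out along seqlen); in practice this would be a tensor transpose such as `v.transpose(1, 3).contiguous()` in PyTorch. The function name and toy shapes here are illustrative, not the PR's actual API.

```python
def pre_transpose_v(v):
    """Swap the seqlen (axis 1) and headdim (axis 3) of a nested-list
    tensor shaped [batch][seqlen][nheads][headdim], producing a
    [batch][headdim][nheads][seqlen] layout."""
    batch = len(v)
    seqlen = len(v[0])
    nheads = len(v[0][0])
    headdim = len(v[0][0][0])
    return [
        [
            [[v[b][s][h][d] for s in range(seqlen)] for h in range(nheads)]
            for d in range(headdim)
        ]
        for b in range(batch)
    ]

# Toy check: element V[b][s][h][d] ends up at V_t[b][d][h][s].
v = [[[[(b, s, h, d) for d in range(4)] for h in range(2)]
      for s in range(3)] for b in range(2)]
v_t = pre_transpose_v(v)
assert v_t[1][3][0][2] == v[1][2][0][3]
```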
Below are experimental results (forward pass) from a PCIe-based H100:
**causal=False, headdim=128**

| batch_size | seqlen | PyTorch fwd, TFLOPs/s (ms) | Flash3 fwd, TFLOPs/s (ms) | cuDNN fwd, TFLOPs/s (ms) |
|---|---|---|---|---|
| 32 | 512 | 52.28 (1.315) | 433.25 (0.159) | 539.92 (0.127) |
| 16 | 1024 | 66.04 (2.081) | 538.23 (0.255) | 627.66 (0.219) |
| 8 | 2048 | 73.80 (3.725) | 605.27 (0.454) | 674.59 (0.407) |
| 4 | 4224 | 68.44 (8.543) | 618.83 (0.945) | 612.95 (0.954) |
| 2 | 8448 | 78.47 (14.902) | 618.05 (1.892) | 668.55 (1.749) |
| 1 | 16896 | 82.08 (28.492) | 609.18 (3.839) | 653.12 (3.581) |
**causal=False, headdim=256**

| batch_size | seqlen | PyTorch fwd, TFLOPs/s (ms) | Flash3 fwd, TFLOPs/s (ms) | cuDNN fwd, TFLOPs/s (ms) |
|---|---|---|---|---|
| 32 | 512 | 67.70 (1.015) | 554.54 (0.124) | 527.28 (0.130) |
| 16 | 1024 | 95.30 (1.442) | 726.88 (0.189) | 670.95 (0.205) |
| 8 | 2048 | 111.12 (2.474) | 840.47 (0.327) | 747.60 (0.368) |
| 4 | 4224 | 107.88 (5.420) | 818.27 (0.714) | 771.92 (0.757) |
| 2 | 8448 | 119.90 (9.752) | 872.29 (1.341) | 790.64 (1.479) |
| 1 | 16896 | 131.30 (17.812) | 832.29 (2.810) | 765.44 (3.055) |
**causal=True, headdim=128**

| batch_size | seqlen | PyTorch fwd, TFLOPs/s (ms) | Flash3 fwd, TFLOPs/s (ms) | cuDNN fwd, TFLOPs/s (ms) |
|---|---|---|---|---|
| 32 | 512 | 18.93 (1.815) | 265.59 (0.129) | 208.80 (0.165) |
| 16 | 1024 | 22.39 (3.069) | 385.26 (0.178) | 302.32 (0.227) |
| 8 | 2048 | 23.35 (5.885) | 471.96 (0.291) | 401.16 (0.343) |
| 4 | 4224 | 21.16 (13.818) | 546.11 (0.535) | 500.68 (0.584) |
| 2 | 8448 | 23.33 (25.058) | 588.62 (0.993) | 637.57 (0.917) |
| 1 | 16896 | 23.21 (50.376) | 583.35 (2.004) | 629.06 (1.859) |
**causal=True, headdim=256**

| batch_size | seqlen | PyTorch fwd, TFLOPs/s (ms) | Flash3 fwd, TFLOPs/s (ms) | cuDNN fwd, TFLOPs/s (ms) |
|---|---|---|---|---|
| 32 | 512 | 26.92 (1.276) | 279.40 (0.123) | 240.85 (0.143) |
| 16 | 1024 | 35.24 (1.950) | 438.84 (0.157) | 341.00 (0.202) |
| 8 | 2048 | 38.49 (3.570) | 580.29 (0.237) | 560.02 (0.245) |
| 4 | 4224 | 36.22 (8.070) | 680.26 (0.430) | 686.43 (0.426) |
| 2 | 8448 | 38.51 (15.182) | 738.05 (0.792) | 740.54 (0.789) |
| 1 | 16896 | 38.41 (30.446) | 744.65 (1.570) | 753.08 (1.553) |
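As a sanity check, the TFLOPs/s figures are consistent with the standard attention FLOP count, 4 × batch × nheads × seqlen² × headdim (halved for causal), assuming nheads × headdim = 2048 (i.e. nheads = 16 at headdim 128 and nheads = 8 at headdim 256). The head count is inferred from the numbers, not stated in the benchmark output.

```python
def attn_tflops(batch, seqlen, nheads, headdim, ms, causal=False):
    """Attention forward FLOP count divided by the measured latency."""
    flops = 4 * batch * nheads * seqlen**2 * headdim
    if causal:
        flops /= 2  # roughly half the score matrix is masked out
    return flops / (ms * 1e-3) / 1e12

# First row of the causal=False, headdim=128 results, with nheads = 16 assumed:
print(f"{attn_tflops(32, 512, 16, 128, 1.3145):.2f} TFLOPs/s")  # matches the reported 52.28
```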