
[Feature] Support padding for logits and unequal batch size for logits and bitmask #220

Merged: 8 commits merged into mlc-ai:main on Feb 26, 2025

Conversation

@Ubospica (Collaborator) commented on Feb 26, 2025

This PR supports:

  1. Padding on the vocabulary dimension for logits. vLLM can introduce such padding, which the previous kernel did not support.
  2. Unequal batch sizes for logits and bitmask when indices are specified. When indices are not specified, the batch sizes of logits and bitmask must be equal. When indices are specified, we only require the batch sizes of logits and bitmask to be larger than every specified index (see the sketch below).
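
To make these semantics concrete, here is a minimal pure-PyTorch sketch of the masking the kernel performs. The function name `apply_token_bitmask_reference` and the convention that an index selects the same row in both logits and bitmask are illustrative assumptions, not the repository's actual implementation:

```python
import torch

def apply_token_bitmask_reference(
    logits: torch.Tensor,   # (logits_batch, padded_vocab), modified in place
    bitmask: torch.Tensor,  # (bitmask_batch, ceil(vocab_size / 32)), int32-packed
    vocab_size: int,
    indices: list[int] | None = None,
) -> None:
    # Unpack the 32-bit-packed bitmask: bit (v % 32) of word (v // 32)
    # decides whether vocab id v is allowed (1) or masked (0).
    bit_pos = torch.arange(32, dtype=torch.int32, device=bitmask.device)
    keep = ((bitmask.unsqueeze(-1) >> bit_pos) & 1).flatten(-2)[:, :vocab_size].bool()

    rows = range(logits.shape[0]) if indices is None else indices
    for row in rows:
        # Only the first vocab_size columns are masked; any padded tail
        # columns (e.g. vLLM's vocabulary padding) are left untouched.
        logits[row, :vocab_size].masked_fill_(~keep[row], float("-inf"))
```

Under this convention, logits and bitmask may have different batch sizes, as long as every specified index is a valid row in both.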

@Ubospica (Collaborator, Author) commented on Feb 26, 2025

Benchmark results are as follows; each table row below corresponds to one `test_apply_token_bitmask_inplace_large[impl-batch-vocab-masked-stride-dtype]` parametrization.

Hardware: AMD EPYC 7R13, NVIDIA H100
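
As a rough illustration of how per-call kernel times like those below can be measured, here is a hypothetical CUDA-event timing helper (`time_kernel_us` is an assumed name; this is not the benchmark's actual code):

```python
import torch

def time_kernel_us(fn, warmup: int = 10, iters: int = 100) -> float:
    """Average per-call time of a CUDA-side callable, in microseconds."""
    for _ in range(warmup):  # warm up caches and any JIT/autotune paths
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for all recorded work to finish
    return start.elapsed_time(end) * 1000.0 / iters  # total ms -> μs per call
```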

Before:

| Impl   | Batch | Vocab  | Masked | Stride | DType    | Time (μs)  |
|--------|------:|-------:|-------:|-------:|----------|-----------:|
| cpu    |     1 | 128000 |   1024 |      1 | float32  |    49.1820 |
| cpu    |     1 | 128000 | 120000 |      1 | float32  |   557.4110 |
| cpu    |     1 | 128001 | 120000 |      1 | float32  |   575.6520 |
| cpu    |     1 | 128010 | 120000 |      1 | float32  |   566.5520 |
| cpu    |    64 | 128000 |   1024 |      1 | float32  |  1032.8570 |
| cpu    |    64 | 128000 | 120000 |      1 | float32  | 33565.5370 |
| cpu    |    64 | 128000 |   1024 |      4 | float32  |   296.0160 |
| cpu    |    64 | 128000 | 120000 |      4 | float32  |  8442.0470 |
| cpu    |    64 | 128001 | 120000 |      1 | float32  | 32963.9440 |
| cpu    |    64 | 128010 | 120000 |      1 | float32  | 32991.4650 |
| cuda   |     1 | 128000 |   1024 |      1 | float32  |     6.7063 |
| cuda   |     1 | 128000 | 120000 |      1 | float32  |     6.8728 |
| cuda   |     1 | 128001 | 120000 |      1 | float32  |     7.0879 |
| cuda   |     1 | 128010 | 120000 |      1 | float32  |     7.0633 |
| cuda   |    64 | 128000 |   1024 |      1 | float32  |    14.2442 |
| cuda   |    64 | 128000 | 120000 |      1 | float32  |    27.4595 |
| cuda   |    64 | 128000 |   1024 |      4 | float32  |     8.6667 |
| cuda   |    64 | 128000 | 120000 |      4 | float32  |    12.7050 |
| cuda   |    64 | 128001 | 120000 |      1 | float32  |    37.6241 |
| cuda   |    64 | 128010 | 120000 |      1 | float32  |    37.7105 |
| cuda   |    64 | 128000 |   1024 |      1 | float16  |    12.3578 |
| cuda   |    64 | 128000 |   1024 |      1 | bfloat16 |    12.3725 |
| triton |     1 | 128000 |   1024 |      1 | float32  |     6.3329 |
| triton |     1 | 128000 | 120000 |      1 | float32  |     7.0381 |
| triton |     1 | 128001 | 120000 |      1 | float32  |     6.4713 |
| triton |     1 | 128010 | 120000 |      1 | float32  |     6.4394 |
| triton |    64 | 128000 |   1024 |      1 | float32  |    15.0053 |
| triton |    64 | 128000 | 120000 |      1 | float32  |    42.8678 |
| triton |    64 | 128000 |   1024 |      4 | float32  |     8.7802 |
| triton |    64 | 128000 | 120000 |      4 | float32  |    15.4738 |
| triton |    64 | 128001 | 120000 |      1 | float32  |    34.3639 |
| triton |    64 | 128010 | 120000 |      1 | float32  |    34.2326 |
| triton |    64 | 128000 |   1024 |      1 | float16  |    14.4927 |
| triton |    64 | 128000 |   1024 |      1 | bfloat16 |    14.5345 |

After:

| Impl   | Batch | Vocab  | Masked | Stride | DType    | Time (μs)  |
|--------|------:|-------:|-------:|-------:|----------|-----------:|
| cpu    |     1 | 128000 |   1024 |      1 | float32  |    49.8630 |
| cpu    |     1 | 128000 | 120000 |      1 | float32  |   565.0610 |
| cpu    |     1 | 128001 | 120000 |      1 | float32  |   539.4300 |
| cpu    |     1 | 128010 | 120000 |      1 | float32  |   568.3110 |
| cpu    |    64 | 128000 |   1024 |      1 | float32  |   914.4110 |
| cpu    |    64 | 128000 | 120000 |      1 | float32  | 33995.6310 |
| cpu    |    64 | 128000 |   1024 |      4 | float32  |   333.4490 |
| cpu    |    64 | 128000 | 120000 |      4 | float32  |  8533.5650 |
| cpu    |    64 | 128001 | 120000 |      1 | float32  | 33668.2130 |
| cpu    |    64 | 128010 | 120000 |      1 | float32  | 33715.2950 |
| cuda   |     1 | 128000 |   1024 |      1 | float32  |     6.7593 |
| cuda   |     1 | 128000 | 120000 |      1 | float32  |     6.8451 |
| cuda   |     1 | 128001 | 120000 |      1 | float32  |     7.0706 |
| cuda   |     1 | 128010 | 120000 |      1 | float32  |     7.1020 |
| cuda   |    64 | 128000 |   1024 |      1 | float32  |    14.3245 |
| cuda   |    64 | 128000 | 120000 |      1 | float32  |    27.4618 |
| cuda   |    64 | 128000 |   1024 |      4 | float32  |     8.6505 |
| cuda   |    64 | 128000 | 120000 |      4 | float32  |    12.6526 |
| cuda   |    64 | 128001 | 120000 |      1 | float32  |    37.6556 |
| cuda   |    64 | 128010 | 120000 |      1 | float32  |    37.6465 |
| cuda   |    64 | 128000 |   1024 |      1 | float16  |    12.3504 |
| cuda   |    64 | 128000 |   1024 |      1 | bfloat16 |    12.3582 |
| triton |     1 | 128000 |   1024 |      1 | float32  |     6.2745 |
| triton |     1 | 128000 | 120000 |      1 | float32  |     6.9885 |
| triton |     1 | 128001 | 120000 |      1 | float32  |     6.4483 |
| triton |     1 | 128010 | 120000 |      1 | float32  |     6.4184 |
| triton |    64 | 128000 |   1024 |      1 | float32  |    14.9453 |
| triton |    64 | 128000 | 120000 |      1 | float32  |    42.8455 |
| triton |    64 | 128000 |   1024 |      4 | float32  |    23.5396 |
| triton |    64 | 128000 | 120000 |      4 | float32  |    22.5945 |
| triton |    64 | 128001 | 120000 |      1 | float32  |    34.3060 |
| triton |    64 | 128010 | 120000 |      1 | float32  |    34.1925 |
| triton |    64 | 128000 |   1024 |      1 | float16  |    14.5101 |
| triton |    64 | 128000 |   1024 |      1 | bfloat16 |    14.4740 |

@Ubospica merged commit 6996ded into mlc-ai:main on Feb 26, 2025 (8 checks passed).