Commit
[Feature] Support padding for logits and unequal batch size for logits and bitmask (#220)

This PR supports:
1. Padding on the vocabulary dimension for logits. vLLM can introduce such padding, and it was not supported by the previous kernel.
2. Unequal batch sizes for logits and bitmask when indices are specified. When indices are not specified, we require the batch sizes of logits and bitmask to be the same. When indices are specified, we only require the indices larger than …
Ubospica authored Feb 26, 2025
1 parent 77837dc commit 6996ded
Showing 8 changed files with 351 additions and 132 deletions.
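
As context for the diffs below, here is a minimal Python sketch of the behavior this commit enables: logits padded on the vocabulary dimension (as vLLM does) and masking restricted to selected rows via indices. It assumes the repository's public helpers xgrammar.allocate_token_bitmask and xgrammar.apply_token_bitmask_inplace; the names, shapes, and numbers are illustrative rather than taken from this diff.

    import torch
    import xgrammar as xgr

    tokenizer_vocab_size = 128256   # real vocabulary size used to size the bitmask
    padded_vocab_size = 128512      # vLLM may pad the logits beyond the vocabulary
    batch_size = 4

    # Logits carry extra padding on the vocab dimension; the bitmask is sized from
    # the unpadded vocabulary: ceil(128256 / 32) = 4008 int32 words per row.
    logits = torch.randn(batch_size, padded_vocab_size, dtype=torch.float32)
    bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_vocab_size)

    # With indices given, only the listed rows are masked, and the batch sizes of
    # logits and bitmask no longer have to be equal.
    xgr.apply_token_bitmask_inplace(logits, bitmask, indices=[0, 2])
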
72 changes: 39 additions & 33 deletions cpp/grammar_matcher.cc
@@ -21,6 +21,7 @@
#include "support/int_set.h"
#include "support/logging.h"
#include "testing.h"

namespace xgrammar {

/******************* Tool functions for token mask *******************/
@@ -69,59 +70,64 @@ void _DebugGetMaskedTokensFromBitmask(
void ApplyTokenBitmaskInplaceCPU(
DLTensor* logits, const DLTensor& bitmask, std::optional<std::vector<int>> indices
) {
// Check device and dim
XGRAMMAR_CHECK(logits->device.device_type == kDLCPU)
<< "The provided logits's device is not valid: should be CPU";
XGRAMMAR_CHECK(bitmask.device.device_type == kDLCPU)
<< "The provided bitmask's device is not valid: should be CPU";
int batch_size;
int vocab_size;
if (logits->ndim == 2) {
batch_size = logits->shape[0];
vocab_size = logits->shape[1];
} else {
batch_size = 1;
vocab_size = logits->shape[0];
}
int bitmask_size = GetBitmaskSize(vocab_size);
if (bitmask.ndim == 2) {
XGRAMMAR_CHECK(bitmask.shape[0] == batch_size)
<< "The provided bitmask's batch size is not consistent with logits";
XGRAMMAR_CHECK(bitmask.shape[1] == bitmask_size)
<< "The provided bitmask's bitmask size is not consistent with logits";
} else {
XGRAMMAR_CHECK(bitmask.ndim == 1)
<< "The provided bitmask's shape is not valid: should be (batch_size, vocab_size)";
XGRAMMAR_CHECK(bitmask.shape[0] == bitmask_size)
<< "The provided bitmask's bitmask size is not consistent with logits";
}
XGRAMMAR_CHECK(logits->ndim == 2 || logits->ndim == 1)
<< "The provided logits's shape is not valid: should be 2D or 1D";
XGRAMMAR_CHECK(bitmask.ndim == 2 || bitmask.ndim == 1)
<< "The provided bitmask's shape is not valid: should be 2D or 1D";

// Check type
XGRAMMAR_CHECK(
logits->dtype.code == kDLFloat && logits->dtype.bits == 32 && logits->dtype.lanes == 1
) << "The provided logits's dtype is not valid: should be float32";
XGRAMMAR_CHECK(
bitmask.dtype.code == kDLInt && bitmask.dtype.bits == 32 && bitmask.dtype.lanes == 1
) << "The provided bitmask's dtype is not valid: should be int32";

// Check shape
std::pair<int, int> logits_shape =
logits->ndim == 2
? std::make_pair(static_cast<int>(logits->shape[0]), static_cast<int>(logits->shape[1]))
: std::make_pair(1, static_cast<int>(logits->shape[0]));
std::pair<int, int> bitmask_shape =
bitmask.ndim == 2
? std::make_pair(static_cast<int>(bitmask.shape[0]), static_cast<int>(bitmask.shape[1]))
: std::make_pair(1, static_cast<int>(bitmask.shape[0]));

// logits may have extra paddings (in vLLM) so its vocab size can be larger than the bitmask's
// vocab size. So we are using >= instead of == here
XGRAMMAR_CHECK(GetBitmaskSize(logits_shape.second) >= bitmask_shape.second)
<< "The provided logits's vocab size should be no less than the bitmask's vocab size "
"(converted from bitmask size). But got vocab size "
<< logits_shape.second << " vs bitmask size " << bitmask_shape.second;

int vocab_size =
std::min(logits_shape.second, bitmask_shape.second * DynamicBitset::BITS_PER_BLOCK);

// Sort and deduplicate indices
std::vector<int> indices_value;
if (indices.has_value()) {
indices_value = indices.value();
std::sort(indices_value.begin(), indices_value.end());
indices_value.erase(
std::unique(indices_value.begin(), indices_value.end()), indices_value.end()
);
XGRAMMAR_CHECK(indices_value.back() < batch_size)
<< "The provided indices is out of bounds: " << indices_value.back()
<< " >= " << batch_size;
} else {
indices_value.resize(batch_size);
for (int i = 0; i < batch_size; ++i) {
indices_value[i] = i;
XGRAMMAR_CHECK(logits_shape.first == bitmask_shape.first)
<< "When indices is not provided, the logits's batch size should be equal to the "
"bitmask's batch size, but got "
<< logits_shape.first << " vs " << bitmask_shape.first;
indices_value.reserve(logits_shape.first);
for (int i = 0; i < logits_shape.first; ++i) {
indices_value.push_back(i);
}
}

// Apply mask
for (auto idx : indices_value) {
uint32_t* data_ptr = reinterpret_cast<uint32_t*>(bitmask.data) + idx * bitmask_size;
uint32_t* data_ptr = reinterpret_cast<uint32_t*>(bitmask.data) + idx * bitmask_shape.second;
DynamicBitset bitset(vocab_size, data_ptr);
auto logits_ptr = reinterpret_cast<float*>(logits->data) + idx * vocab_size;
auto logits_ptr = reinterpret_cast<float*>(logits->data) + idx * logits_shape.second;
for (int i = bitset.FindFirstZero(); i != -1; i = bitset.FindNextZero(i)) {
logits_ptr[i] = -std::numeric_limits<float>::infinity();
}
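
For reference, a short Python rendering of the shape logic the rewritten CPU path applies, assuming GetBitmaskSize(v) is ceil(v / 32) with 32 = DynamicBitset::BITS_PER_BLOCK; the concrete numbers are illustrative.

    BITS_PER_BLOCK = 32

    def get_bitmask_size(vocab_size: int) -> int:
        # Mirrors GetBitmaskSize: number of int32 words holding one bit per token.
        return (vocab_size + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK

    logits_vocab = 128512                      # padded by the serving engine
    bitmask_words = get_bitmask_size(128256)   # 4008, sized from the real vocabulary

    # The check is ">=" rather than "==" because the logits may carry padding.
    assert get_bitmask_size(logits_vocab) >= bitmask_words   # 4016 >= 4008

    # Only the first min(logits_vocab, bitmask_words * 32) entries are considered;
    # the padded tail of each logits row is left untouched.
    effective_vocab = min(logits_vocab, bitmask_words * BITS_PER_BLOCK)  # 128256
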
3 changes: 2 additions & 1 deletion cpp/support/dynamic_bitset.h
@@ -168,6 +168,8 @@ class DynamicBitset {
return (data_[buffer_size_ - 1] & last_block_mask) == last_block_mask;
}

static constexpr int BITS_PER_BLOCK = 32;

private:
static int LowestBit(uint32_t value) {
#ifdef __GNUC__
@@ -203,7 +205,6 @@ class DynamicBitset {
return position * BITS_PER_BLOCK + LowestBit(~data_[position]);
}

static constexpr int BITS_PER_BLOCK = 32;
// The size of the bitset.
int size_;
// The size of the buffer.
142 changes: 84 additions & 58 deletions python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu
@@ -31,8 +31,8 @@
#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
#endif

constexpr int32_t kBitsPerMaskElement = 32;
constexpr int32_t kThreadsPerBlock = 256;
constexpr int32_t BITS_PER_BLOCK = 32;
constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;

template <typename T>
__device__ T NegativeInfinity() {
@@ -61,34 +61,35 @@ __device__ PackedT PackedNegativeInfinity() {
}

template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(kThreadsPerBlock) LogitsBitmaskKernel(
__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(
T* __restrict__ logits,
const int32_t* __restrict__ bitmask,
const int32_t* __restrict__ indices,
int32_t vocab_size,
int32_t bitmask_size
int32_t logits_stride,
int32_t bitmask_stride
) {
constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;

const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y];

const int block_offset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
T* logits_gmem_ptr = logits + batch_idx * vocab_size + block_offset;
const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;
T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset;
const int32_t* bitmask_gmem_ptr =
bitmask + batch_idx * bitmask_size + block_offset / kBitsPerMaskElement;
const int bitmask_inner_idx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;
const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);
T logits_reg[kAlignment];

#pragma unroll
for (int offset = threadIdx.x * kAlignment; offset < kThreadsPerBlock * kBitsPerThread;
offset += kThreadsPerBlock * kAlignment) {
for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;
offset += THREADS_PER_THREAD_BLOCK * kAlignment) {
if (block_offset + offset >= vocab_size) {
break;
}

const uint32_t bitmask_val =
(~bitmask_gmem_ptr[offset / kBitsPerMaskElement] >> (bitmask_inner_idx * kAlignment)) &
(~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) &
kPackedMask;

if (bitmask_val == 0) {
@@ -122,32 +123,38 @@ void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
const int32_t* __restrict__ bitmask,
const int32_t* __restrict__ indices,
int32_t vocab_size,
int32_t bitmask_size,
int32_t batch_size
int32_t logits_stride,
int32_t bitmask_stride,
int32_t num_rows
) {
constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
const int32_t num_blocks_per_row = CeilDiv(2048 / kThreadsPerBlock * 128, batch_size);
const int32_t num_bits_per_thread = CeilDiv(vocab_size, kThreadsPerBlock * num_blocks_per_row);
const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
const int32_t num_bits_per_thread =
CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);

const dim3 block(kThreadsPerBlock);
const dim3 block(THREADS_PER_THREAD_BLOCK);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

if (num_bits_per_thread <= 4 && kAlignment <= 4) {
const dim3 grid(CeilDiv(vocab_size, kThreadsPerBlock * 4), batch_size);
LogitsBitmaskKernel<T, PackedT, 4>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, bitmask_size);
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
LogitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride
);
} else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
const dim3 grid(CeilDiv(vocab_size, kThreadsPerBlock * 8), batch_size);
LogitsBitmaskKernel<T, PackedT, 8>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, bitmask_size);
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
LogitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride
);
} else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
const dim3 grid(CeilDiv(vocab_size, kThreadsPerBlock * 16), batch_size);
LogitsBitmaskKernel<T, PackedT, 16>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, bitmask_size);
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
LogitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride
);
} else {
const dim3 grid(CeilDiv(vocab_size, kThreadsPerBlock * 32), batch_size);
LogitsBitmaskKernel<T, PackedT, 32>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, bitmask_size);
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
LogitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride
);
}
}

@@ -157,16 +164,17 @@ void ApplyTokenBitmaskInplaceDispatchToPackedT(
const int32_t* __restrict__ bitmask,
const int32_t* __restrict__ indices,
int32_t vocab_size,
int32_t bitmask_size,
int32_t batch_size
int32_t logits_stride,
int32_t bitmask_stride,
int32_t num_rows
) {
if (vocab_size % (sizeof(float4) / sizeof(T)) == 0) {
if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {
ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, float4>(
logits, bitmask, indices, vocab_size, bitmask_size, batch_size
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows
);
} else {
ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, T>(
logits, bitmask, indices, vocab_size, bitmask_size, batch_size
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows
);
}
}
@@ -177,35 +185,50 @@ void ApplyTokenBitmaskInplace(
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor.");
int32_t batch_size = 1;
int32_t vocab_size = logits.size(0);
if (logits.dim() == 2) {
batch_size = logits.size(0);
vocab_size = logits.size(1);
}
std::pair<int32_t, int32_t> logits_shape =
logits.dim() == 2
? std::make_pair(
static_cast<int32_t>(logits.size(0)), static_cast<int32_t>(logits.size(1))
)
: std::make_pair(1, static_cast<int32_t>(logits.size(0)));

TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor.");
int32_t bitmask_batch_size = 1;
int32_t bitmask_size = bitmask.size(0);
if (bitmask.dim() == 2) {
bitmask_batch_size = bitmask.size(0);
bitmask_size = bitmask.size(1);
}
TORCH_CHECK(bitmask_batch_size == batch_size, "bitmask must have the batch size same to logits.");
std::pair<int32_t, int32_t> bitmask_shape =
bitmask.dim() == 2
? std::make_pair(
static_cast<int32_t>(bitmask.size(0)), static_cast<int32_t>(bitmask.size(1))
)
: std::make_pair(1, static_cast<int32_t>(bitmask.size(0)));

TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32.");

TORCH_CHECK(
bitmask_size == CeilDiv(vocab_size, kBitsPerMaskElement),
"bitmask must have the hidden size equal to CeilDiv(vocab_size, 32), but got vocab_size=",
vocab_size,
" and bitmask_size=",
bitmask_size
(logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second,
"The provided logits's vocab size should be no less than the bitmask's vocab size "
"(converted from bitmask size). But got vocab size ",
logits_shape.second,
" vs bitmask size ",
bitmask_shape.second
);

int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);

int32_t num_rows = logits_shape.first;
int32_t* indices_ptr = nullptr;
if (indices) {
batch_size = indices->size(0);
TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor.");
TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
num_rows = indices->size(0);
indices_ptr = indices->data_ptr<int32_t>();
} else {
TORCH_CHECK(
logits_shape.first == bitmask_shape.first,
"logits and bitmask must have the same batch size."
);
}

switch (logits.scalar_type()) {
Expand All @@ -215,8 +238,9 @@ void ApplyTokenBitmaskInplace(
bitmask.data_ptr<int32_t>(),
indices_ptr,
vocab_size,
bitmask_size,
batch_size
logits_shape.second,
bitmask_shape.second,
num_rows
);
break;
}
@@ -226,8 +250,9 @@
bitmask.data_ptr<int32_t>(),
indices_ptr,
vocab_size,
bitmask_size,
batch_size
logits_shape.second,
bitmask_shape.second,
num_rows
);
break;
}
@@ -237,8 +262,9 @@
bitmask.data_ptr<int32_t>(),
indices_ptr,
vocab_size,
bitmask_size,
batch_size
logits_shape.second,
bitmask_shape.second,
num_rows
);
break;
}
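
The dispatch above picks the per-thread bit count from a block budget of roughly 2048 / 256 * 128 = 1024 thread blocks split across the rows, rounded up to the next supported template (4, 8, 16, or 32). A rough Python sketch of that selection, with illustrative inputs, is:

    THREADS_PER_THREAD_BLOCK = 256

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def pick_bits_per_thread(vocab_size: int, num_rows: int, alignment: int) -> int:
        # Budget roughly 2048 / 256 * 128 = 1024 thread blocks, shared across rows.
        num_blocks_per_row = ceil_div(2048 // THREADS_PER_THREAD_BLOCK * 128, num_rows)
        num_bits_per_thread = ceil_div(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row)
        for candidate in (4, 8, 16):
            # Each branch also requires the packed alignment to fit the template.
            if num_bits_per_thread <= candidate and alignment <= candidate:
                return candidate
        return 32

    # e.g. a 128k vocabulary, 8 masked rows, fp16 packed as float4 (alignment 8)
    print(pick_bits_per_thread(128256, 8, alignment=8))   # -> 8
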
2 changes: 2 additions & 0 deletions python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py
@@ -72,4 +72,6 @@ def apply_token_bitmask_inplace_cuda(
) -> None:
if isinstance(indices, list):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
if indices is not None:
indices = indices.to(logits.device)
torch.ops.xgrammar.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
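
A small usage sketch of the wrapper above; the import path and keyword usage are assumptions inferred from this file's location, and the shapes are illustrative.

    import torch
    from xgrammar.kernels.apply_token_bitmask_inplace_cuda import (
        apply_token_bitmask_inplace_cuda,  # assumed import path for this module
    )

    logits = torch.randn(4, 128512, dtype=torch.float16, device="cuda")
    # All-ones bitmask (every bit set) leaves the logits unchanged; 4008 = ceil(128256 / 32).
    bitmask = torch.full((4, 4008), -1, dtype=torch.int32, device="cuda")

    # indices may be a Python list or a CPU tensor; the wrapper converts it to an
    # int32 tensor and moves it to logits.device before invoking the CUDA op.
    apply_token_bitmask_inplace_cuda(logits, bitmask, indices=[0, 2])
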