support packed qkv in MultiHeadAttention op
tianleiwu committed Feb 13, 2023
1 parent 89d0dd0 commit aea7176
Showing 9 changed files with 269 additions and 98 deletions.
86 changes: 56 additions & 30 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -22,53 +22,79 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
int max_threads_per_block) {
// key_padding_mask (K/V) : (B) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// When no packing for q/k/v:
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D_v)
// bias (Q/K/V) : (D + D + D_v)
// key_padding_mask (K/V) : (B) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// When packed kv is used:
// query (Q) : (B, S, D)
// key (K) : (B, L, N, 2, H)
// value (V) : None
// bias (Q/K/V) : None
// When packed qkv is used:
// query (Q) : (B, L, N, 3, H)
// key (K) : None
// value (V) : None
// bias (Q/K/V) : None

const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
if (query_dims.size() != 3 && query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ",
query_dims.size());
}

const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int hidden_size = static_cast<int>(query_dims[2]);
int hidden_size = query_dims.size() == 3 ? static_cast<int>(query_dims[2]) : (num_heads * static_cast<int>(query_dims[4]));
int head_size = static_cast<int>(hidden_size) / num_heads;
int kv_sequence_length = static_cast<int>(key_dims[1]);
int kv_sequence_length = sequence_length;

if (key != nullptr) {
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ",
query_dims.size());
}

if (key_dims.size() == 3) {
if (key_dims[2] != query_dims[2]) {
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}
} else // if (key_dims.size() == 5)
{
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {

if (key_dims.size() == 3) {
if (key_dims[2] != query_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
}
} else // if (key_dims.size() == 5)
{
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
}
}

kv_sequence_length = static_cast<int>(key_dims[1]);
} else { // packed QKV
if (query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 5 dimensions when key is empty, got ",
query_dims.size());
}
if (static_cast<int>(query_dims[2]) != num_heads || static_cast<int>(query_dims[3]) != 3) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv");
}
}

@@ -82,17 +108,17 @@ Status CheckInputs(const T* query,
// Currently, bias is not allowed for packed KV. This constraint can be removed later.
// Here we assume that fusion tool will not include bias for packed KV.
if (value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. ");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed qkv or kv. ");
}
}

AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
const auto& mask_dims = key_padding_mask->Shape().GetDims();
if (mask_dims.size() == 1 && mask_dims[0] == key_dims[0]) {
if (mask_dims.size() == 1 && mask_dims[0] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
} else if (mask_dims.size() == 2 && mask_dims[0] == key_dims[0] && mask_dims[1] == key_dims[1]) {
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
}

@@ -115,7 +141,7 @@ Status CheckInputs(const T* query,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}

if (key_dims[1] != value_dims[1]) {
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)");
}
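The comment block at the top of CheckInputs enumerates the three input layouts the operator now accepts. As a rough illustration only (not part of this commit), the following NumPy sketch builds tensors in each layout; the dimension names follow the comments above (B = batch, S = sequence_length, L = kv_sequence_length, N = num_heads, H = head_size, D = hidden_size).

```python
import numpy as np

B, S, L, N, H = 2, 4, 6, 8, 64
D = N * H  # hidden_size

# 1) No packing: 3-D query, key and value.
q = np.zeros((B, S, D), dtype=np.float16)
k = np.zeros((B, L, D), dtype=np.float16)
v = np.zeros((B, L, D), dtype=np.float16)

# 2) Packed KV: key carries K and V as (B, L, N, 2, H); value must be absent.
packed_kv = np.stack([k.reshape(B, L, N, H), v.reshape(B, L, N, H)], axis=3)
assert packed_kv.shape == (B, L, N, 2, H)

# 3) Packed QKV: query carries Q, K and V as (B, S, N, 3, H); key and value
#    must be absent, and kv_sequence_length is taken to equal sequence_length.
k2 = np.zeros((B, S, D), dtype=np.float16)
v2 = np.zeros((B, S, D), dtype=np.float16)
packed_qkv = np.stack(
    [q.reshape(B, S, N, H), k2.reshape(B, S, N, H), v2.reshape(B, S, N, H)],
    axis=3)
assert packed_qkv.shape == (B, S, N, 3, H)
```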
30 changes: 28 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -264,14 +264,15 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,

T* qkv = data.workspace;

bool use_fused_kernel = (nullptr != fused_runner && data.bias != nullptr && !parameters.is_unidirectional);
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);

// Default format for memory efficient attention.
// When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past.
DUMP_TENSOR_INIT();
if (nullptr != data.gemm_buffer) {
if (data.bias == nullptr) {
assert(nullptr == fused_runner);
// For quantized attention, bias has been added so only need transpose here.
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
assert(qk_head_size == v_head_size);
@@ -303,6 +304,31 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
data.gemm_buffer, data.bias, qkv,
true, v_head_size, qkv_add_bias, 3);
}
} else if (data.key == nullptr) { // gemm_buffer == nullptr and packed qkv
assert(data.bias == nullptr);
assert(qk_head_size == v_head_size);

DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);

if (use_memory_efficient_attention) {
// unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
constexpr int format = 4;
T* qkv_add_bias = nullptr;
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, qkv,
true, v_head_size, qkv_add_bias, 3);
DUMP_TENSOR_D("k(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else {
if (!use_fused_kernel) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
}

qkv_format = AttentionQkvFormat::QKV_BSN3H;
}
} else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv
// TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
// CheckInputs verified this constraint.
@@ -330,7 +356,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,

qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
} else { // gemm_buffer == nullptr and not packed kv
} else { // gemm_buffer == nullptr and not packed
assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr);

DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size);
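In the new packed-QKV branch of PrepareQkv, the memory-efficient-attention path calls LaunchAddBiasTranspose (format 4) to split the 5-D query into contiguous Q/K/V buffers in BSNH order, while the fused runner consumes the data directly in QKV_BSN3H layout. A rough NumPy equivalent of that unpacking step (illustrative only; the CUDA kernel also handles bias and writes into a single workspace):

```python
import numpy as np

def unpack_qkv_bsn3h(packed_qkv: np.ndarray):
    """Split packed QKV of shape (B, S, N, 3, H) into Q, K, V of shape (B, S, N, H)."""
    q = packed_qkv[:, :, :, 0, :]
    k = packed_qkv[:, :, :, 1, :]
    v = packed_qkv[:, :, :, 2, :]
    return q, k, v  # corresponds to AttentionQkvFormat::Q_K_V_BSNH
```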
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -52,7 +52,8 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
disable_memory_efficient_attention_ = true;
#endif

disable_fused_cross_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
disable_fused_cross_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
}

template <typename T>
@@ -97,6 +98,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
nullptr == relative_position_bias &&
key != nullptr &&
(value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV
parameters.hidden_size == parameters.v_hidden_size &&
has_fused_cross_attention_kernel(sm, parameters.head_size,
@@ -116,7 +118,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool use_fused_runner = !disable_fused_runner_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == relative_position_bias &&
value != nullptr && // fused runner requires packed qkv instead of packed kv
key == nullptr && // fused runner requires packed qkv
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
Expand Down Expand Up @@ -171,7 +173,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.gemm_buffer = nullptr;
data.bias = (nullptr == bias) ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
data.key = (nullptr == key) ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = (nullptr == value) ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data<int>();
data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span<const int64_t>() : key_padding_mask->Shape().GetDims();
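The dispatch changes above mean the fused cross-attention kernel now requires a non-empty key (so it is skipped for packed QKV), while the fused runner is selected only when key is empty, i.e. for packed QKV. A simplified, illustrative restatement of those conditions (the real code additionally checks SM version, head size, mask shape, data type and environment flags):

```python
def pick_cuda_kernel(key, value, bias, key_padding_mask, relative_position_bias):
    """Simplified sketch of the kernel-selection logic; not the actual implementation."""
    if (key is not None                        # packed QKV cannot take this path
            and key_padding_mask is None
            and relative_position_bias is None
            and (value is not None or bias is None)):
        return "fused_cross_attention"         # separate K/V or packed KV
    if key is None and relative_position_bias is None:
        return "fused_runner"                  # requires packed QKV
    return "unfused_or_memory_efficient_attention"
```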
40 changes: 32 additions & 8 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -126,20 +126,42 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx)
}

void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
// Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
// Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, kv_sequence_length, num_heads, 2, head_size)
// Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or nullptr
// Output 0 has shape (batch_size, sequence_length, v_hidden_size)

// Q, K and V without packing:
// Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
// Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
// Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)

// Packed KV:
// Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
// Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
// Input 2 nullptr

// Packed QKV:
// Input 0 (batch_size, sequence_length, num_heads, 3, head_size)
// Input 1 nullptr
// Input 2 nullptr

// Type inference
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);

// Shape inference
if (hasInputShape(ctx, 0)) {
auto& query_shape = getInputShape(ctx, 0);
auto& query_dims = query_shape.dim();
if (query_dims.size() != 3) {
fail_shape_inference("Inputs 0 (query) shall be 3 dimensions");

if (query_dims.size() != 3 && query_dims.size() != 5) {
fail_shape_inference("Inputs 0 (query) shall be 3 or 5 dimensions");
}

if (query_dims.size() == 5) { // packed QKV
ONNX_NAMESPACE::TensorShapeProto output_shape;
*output_shape.add_dim() = query_dims[0];
*output_shape.add_dim() = query_dims[1];
*output_shape.add_dim() = query_dims[2] * query_dims[4];
updateOutputShape(ctx, 0, output_shape);
return;
}

if (hasInputShape(ctx, 2)) {
@@ -154,11 +176,12 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
*output_shape.add_dim() = query_dims[1];
*output_shape.add_dim() = value_dims[2];
updateOutputShape(ctx, 0, output_shape);
return;
}

if (hasInputShape(ctx, 1)) {
auto& key_shape = getInputShape(ctx, 1);
if (key_shape.dim().size() == 5) {
if (key_shape.dim().size() == 5) { // packed KV
ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx);
}
}
@@ -292,12 +315,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size)",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)",
"T")
.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)",
"T")
"T",
OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, v_hidden_size)",
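With 'key' and 'value' now optional in the schema, a graph can feed a single packed QKV tensor as input 0 and leave the other inputs empty. A hypothetical usage sketch (node and tensor names are made up for illustration):

```python
from onnx import helper

node = helper.make_node(
    "MultiHeadAttention",
    inputs=["packed_qkv", "", ""],  # query = (B, S, N, 3, H); key/value omitted
    outputs=["attention_out"],      # shape inferred as (B, S, N * H)
    name="mha_packed_qkv",
    domain="com.microsoft",
    num_heads=8,
)
```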
51 changes: 33 additions & 18 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -2025,28 +2025,43 @@ def _infer_BiasGelu(self, node):
self._propagate_shape_and_type(node)

def _infer_MultiHeadAttention(self, node):
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
# Without packed KV:
# Output 0 has shape (batch_size, sequence_length, v_hidden_size)
# Q, K and V without packing:
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
# Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
# Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
# With packed KV:
# Input 1 (key) has shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
# Input 2 (value) is nullptr
# Output 0 has shape (batch_size, sequence_length, v_hidden_size)
query_shape = self._get_shape(node, 0)
key_shape = self._get_shape(node, 1)
if query_shape is not None and len(query_shape) == 3:
# Packed KV:
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
# Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
# Input 2 nullptr
# Packed QKV:
# Input 0 (batch_size, sequence_length, num_heads, 3, head_size)
# Input 1 nullptr
# Input 2 nullptr

# By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
output_shape = query_shape
if key_shape and len(key_shape) == 3:
value_shape = self._get_shape(node, 2)
if value_shape and len(value_shape) == 3:
output_shape[2] = value_shape[2]
query_shape = self._get_shape(node, 0)
if query_shape is not None:
if len(query_shape) == 3:
key_shape = self._get_shape(node, 1)
# By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
output_shape = query_shape
if key_shape and len(key_shape) == 3:
value_shape = self._get_shape(node, 2)
if value_shape and len(value_shape) == 3:
output_shape[2] = value_shape[2]

output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
elif len(query_shape) == 5:
if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
else:
output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]

output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))

def _infer_FastGelu(self, node):
self._propagate_shape_and_type(node)
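The new 5-D branch computes the output's last dimension as query_shape[2] * query_shape[4], falling back to a symbolic product when either dimension is not a concrete integer. A minimal sketch of that computation (not the actual helper):

```python
def packed_qkv_output_shape(query_shape):
    """query_shape: [batch, seq, num_heads, 3, head_size]; dims may be int or str."""
    num_heads, head_size = query_shape[2], query_shape[4]
    if isinstance(num_heads, int) and isinstance(head_size, int):
        hidden = num_heads * head_size
    else:
        hidden = f"{num_heads}*{head_size}"  # keep the dimension symbolic
    return [query_shape[0], query_shape[1], hidden]

# Example: packed_qkv_output_shape(["batch", "seq", 8, 3, 64]) -> ["batch", "seq", 512]
```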