Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GroupQueryAttention with KV-Cache #3425

Merged
merged 81 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 78 commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
194d5e9
Add rmsnorms, gqa
turneram May 24, 2024
157c576
Checkpoint
turneram Jul 9, 2024
06321e5
JIT impl running
turneram Aug 1, 2024
d3ac4c8
Merge
turneram Aug 1, 2024
dfcc73f
Split gpu op
turneram Aug 5, 2024
692a404
Checkpoint
turneram Aug 16, 2024
dcaba12
Improve rotary embedding perf; add scale factor to sln
turneram Aug 21, 2024
0d20b7f
Convert to float for reduce mean
turneram Aug 28, 2024
75788e1
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Aug 29, 2024
9040735
Merge
turneram Sep 5, 2024
875dd2e
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Sep 5, 2024
9e72f3b
Clean up
turneram Sep 5, 2024
5ece981
Clean up
turneram Sep 6, 2024
332fde6
Formatting
turneram Sep 6, 2024
c79b495
Undo changes to format.py
turneram Sep 6, 2024
695c4b9
Sync unchanged files
turneram Sep 6, 2024
79efe42
Sync unchanged files
turneram Sep 6, 2024
442d055
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Sep 6, 2024
4fd0902
Refactoring
turneram Sep 11, 2024
1e31f03
Formatting
turneram Sep 11, 2024
cd194b2
Merge branch 'develop' into gqa-jit
turneram Sep 11, 2024
829324e
Fix CI issues
turneram Sep 12, 2024
979f4e2
Formatting
turneram Sep 12, 2024
d6b60a2
Merge branch 'gqa-jit' of https://github.com/ROCm/AMDMIGraphX into gq…
turneram Sep 12, 2024
90a2375
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Sep 12, 2024
736c0b2
Remove unused onnx files
turneram Sep 12, 2024
a9d9f9d
Fix clang tidy
turneram Sep 13, 2024
1a21ee1
Tidy fixes
turneram Sep 13, 2024
e8933cd
Formatting
turneram Sep 13, 2024
c420f26
Update simplified_layernorm tests
turneram Sep 13, 2024
fb6b6de
Formatting
turneram Sep 13, 2024
4b54ac5
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Sep 13, 2024
541a406
Rename to instructions_tuple and add test
turneram Sep 18, 2024
496e213
Formatting
turneram Sep 18, 2024
bc4a240
Add parser tests
turneram Sep 18, 2024
eaa0a87
Formatting
turneram Sep 18, 2024
291ba66
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Sep 18, 2024
41622b0
Only use packed qkv format
turneram Sep 18, 2024
141a1bf
Formatting
turneram Sep 18, 2024
41d2af7
Formatting
turneram Sep 18, 2024
7617ae5
Formatting
turneram Sep 18, 2024
3a99379
Clang tidy and codecov
turneram Sep 24, 2024
20c0f15
Formatting
turneram Sep 24, 2024
54ff0e9
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Sep 24, 2024
d94c2e4
Review comments
turneram Oct 1, 2024
08e1d35
Formatting
turneram Oct 1, 2024
5991822
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 1, 2024
691c447
Remove instructions_tuple op
turneram Oct 2, 2024
56f7e96
Formatting
turneram Oct 2, 2024
eb734cb
Use tensor_views for ref gemms
turneram Oct 4, 2024
6092f21
Formatting
turneram Oct 4, 2024
f19f41d
Use constants for gqa_params
turneram Oct 4, 2024
37dc23e
Formatting
turneram Oct 4, 2024
facef36
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 4, 2024
024953a
Tidy
turneram Oct 4, 2024
124fbb8
Tidy + formatting
turneram Oct 7, 2024
370b71c
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 7, 2024
343c91c
Tidy + windows build
turneram Oct 7, 2024
c7b8590
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 7, 2024
ada9b6d
Merge branch 'develop' into gqa-jit
turneram Oct 9, 2024
bbf73a8
Use literals in verify tests
turneram Oct 9, 2024
7ac592f
Formatting
turneram Oct 9, 2024
a892630
Merge branch 'gqa-jit' of https://github.com/ROCm/AMDMIGraphX into gq…
turneram Oct 9, 2024
946b619
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 9, 2024
0ccf30c
Review comments
turneram Oct 10, 2024
54dfefb
Formatting
turneram Oct 10, 2024
3dd21b7
Formatting
turneram Oct 10, 2024
72d0985
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 10, 2024
a67c609
Removed transposed param
turneram Oct 10, 2024
6128100
Formatting
turneram Oct 10, 2024
b107fb4
Formatting
turneram Oct 10, 2024
9d4fcb2
Revert accidental deletion
turneram Oct 10, 2024
441b390
Formatting
turneram Oct 10, 2024
2a486f4
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 10, 2024
2903dcb
Remove transposed from gqa_parameters
turneram Oct 10, 2024
7c53d34
Formatting
turneram Oct 10, 2024
0a0551a
Clip verify test params instead of using lits
turneram Oct 10, 2024
aa8e18b
Formatting
turneram Oct 10, 2024
bd974c0
Review comments
turneram Oct 10, 2024
ad57d17
Formatting
turneram Oct 10, 2024
725f34f
Merge remote-tracking branch 'origin/develop' into gqa-jit
turneram Oct 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ register_migraphx_ops(
gathernd
get_tuple_elem
greater
group_query_attention
gru
identity
if_op
Expand Down
572 changes: 572 additions & 0 deletions src/include/migraphx/op/group_query_attention.hpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/include/migraphx/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/group_query_attention.hpp>
turneram marked this conversation as resolved.
Show resolved Hide resolved
#include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp>
Expand Down
101 changes: 101 additions & 0 deletions src/onnx/parse_group_query_attention.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_group_query_attention : op_parser<parse_group_query_attention>
{
    std::vector<op_desc> operators() const { return {{"GroupQueryAttention"}}; }

    // Parses the com.microsoft GroupQueryAttention contrib op into the
    // migraphx "group_query_attention" op and unpacks its tuple output into
    // (attention_output, present_key, present_value).
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        // Defaults follow the ONNX Runtime contrib-op attribute defaults.
        bool do_rotary           = false;
        std::size_t kv_num_heads = 0;
        int local_window_size    = -1;
        std::size_t num_heads    = 1;
        bool rotary_interleaved  = false;
        float scale              = 0.0;

        // Helper: parse a node attribute that is known to be present.
        auto attr_value = [&](const std::string& name) {
            return parser.parse_value(info.attributes.at(name));
        };

        if(contains(info.attributes, "do_rotary"))
            do_rotary = attr_value("do_rotary").at<bool>();
        if(contains(info.attributes, "kv_num_heads"))
            kv_num_heads = attr_value("kv_num_heads").at<std::size_t>();
        if(contains(info.attributes, "local_window_size"))
            local_window_size = attr_value("local_window_size").at<int>();
        if(contains(info.attributes, "num_heads"))
            num_heads = attr_value("num_heads").at<std::size_t>();
        if(contains(info.attributes, "rotary_interleaved"))
            rotary_interleaved = attr_value("rotary_interleaved").at<bool>();
        if(contains(info.attributes, "scale"))
            scale = attr_value("scale").at<float>();

        if(args.size() < 7 or args.size() > 9)
        {
            MIGRAPHX_THROW("GroupQueryAttention: Wrong number of inputs provided");
        }

        // The input six-from-the-end is the past_key tensor (input index 3 in
        // the full 9-input form); dim 2 of its shape is the present KV-cache
        // sequence length.
        auto present_kv_seqlen = args.at(args.size() - 6)->get_shape().lens()[2];
        auto gqa               = info.add_instruction(make_op("group_query_attention",
                                                {{"do_rotary", do_rotary},
                                                 {"kv_num_heads", kv_num_heads},
                                                 {"local_window_size", local_window_size},
                                                 {"num_heads", num_heads},
                                                 {"rotary_interleaved", rotary_interleaved},
                                                 {"scale", scale},
                                                 {"present_kv_seqlen", present_kv_seqlen}}),
                                        args);

        // The op returns a 3-tuple: (output, present_key, present_value).
        std::vector<instruction_ref> outputs;
        for(int idx : {0, 1, 2})
            outputs.push_back(
                info.add_instruction(make_op("get_tuple_elem", {{"index", idx}}), gqa));
        return outputs;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
7 changes: 6 additions & 1 deletion src/onnx/parse_simplified_layer_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ struct parse_simplified_layer_normalization : op_parser<parse_simplified_layer_n
MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input shape");
}

auto x_sq = info.add_common_op("mul", x, x);
// Convert to float before reduce_mean
// Fp16 reduce_mean on GPU causes loss of accuracy
auto float_x = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
auto x_sq = info.add_common_op("mul", float_x, float_x);
auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
rms = info.add_instruction(make_op("convert", {{"target_type", x_dtype}}), rms);
auto mean = rms;
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
Expand Down
7 changes: 6 additions & 1 deletion src/onnx/parse_skip_simplified_layer_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,13 @@ struct parse_skip_simplified_layer_normalization
}

x = info.add_common_op("add", x, skip);
auto x_sq = info.add_common_op("mul", x, x);
// Convert to float before reduce_mean
// Fp16 reduce_mean on GPU causes loss of accuracy
auto float_x = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
auto x_sq = info.add_common_op("mul", float_x, float_x);
auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
rms = info.add_instruction(make_op("convert", {{"target_type", x_dtype}}), rms);
auto mean = rms;
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
Expand Down
3 changes: 2 additions & 1 deletion src/targets/gpu/compile_hip_code_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.global,
options.local,
options.inputs,
options.output};
options.output,
options.output_arg};
}

} // namespace gpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ struct hip_compile_options
std::vector<std::string> params = {};
std::vector<shape> virtual_inputs = {};
std::vector<src_file> additional_src_files = {};
std::int64_t output_arg = -1;

/**
* @brief Set the launch parameters but allow v to override the values
Expand Down
135 changes: 135 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_GROUP_QUERY_ATTENTION_HPP
#define MIGRAPHX_GUARD_GPU_GROUP_QUERY_ATTENTION_HPP

#include <migraphx/stringutils.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/value.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct gqa_parameters
{
    float scale;
    std::uint32_t batch_size;          // Batch size used by input
    std::uint32_t sequence_length;     // Sequence length used by input
    std::uint32_t hidden_size;         // Hidden size used by input
    std::uint32_t head_size;           // Head size
    std::uint32_t rotary_embedding_dim; // Rotary embedding dimension.
    std::uint32_t num_heads;           // num_heads = hidden_size / head_size
    std::uint32_t max_sequence_length; // Sequence length used by cos/sin cache
    std::uint32_t head_stride;         // Head stride
    std::uint32_t seq_stride;          // Sequence stride
    std::uint32_t batch_stride;        // Batch stride
    std::uint32_t position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size,
                                       // sequence_length)
    std::uint32_t seqlen_present_kv_cache; // Sequence length of present kv-cache (4096 when using
                                           // shared buffer)
    bool do_rotary;             // Whether to use rotary position embedding. Default value is 0.
    std::uint32_t kv_num_heads; // Number of attention heads for k and v
    int local_window_size;      // left_window_size for local attention. Default value is -1 meaning
                                // unused.
    bool rotary_interleaved;    // Rotate using interleaved pattern. Default value is 0 (False).
    bool past_present_share_buffer; // Whether to use same buffer for KV-cache inputs and outputs

    // Renders every field as a comma-separated list of
    // MIGRAPHX_MAKE_CONSTANT(<type>{<value>}) entries, in declaration order,
    // for injection into the generated kernel source.
    std::string make_init_str() const
    {
        // One MIGRAPHX_MAKE_CONSTANT(...) entry for a single field.
        auto constant = [](const std::string& type, const std::string& val) {
            return "MIGRAPHX_MAKE_CONSTANT(" + type + "{" + val + "})";
        };
        // bool fields are emitted as 0/1.
        auto as_bool = [](bool b) { return std::to_string(static_cast<int>(b)); };
        const std::string sep = ", ";
        return constant("float", std::to_string(scale)) + sep +
               constant("uint32_t", std::to_string(batch_size)) + sep +
               constant("uint32_t", std::to_string(sequence_length)) + sep +
               constant("uint32_t", std::to_string(hidden_size)) + sep +
               constant("uint32_t", std::to_string(head_size)) + sep +
               constant("uint32_t", std::to_string(rotary_embedding_dim)) + sep +
               constant("uint32_t", std::to_string(num_heads)) + sep +
               constant("uint32_t", std::to_string(max_sequence_length)) + sep +
               constant("uint32_t", std::to_string(head_stride)) + sep +
               constant("uint32_t", std::to_string(seq_stride)) + sep +
               constant("uint32_t", std::to_string(batch_stride)) + sep +
               constant("uint32_t", std::to_string(position_ids_format)) + sep +
               constant("uint32_t", std::to_string(seqlen_present_kv_cache)) + sep +
               constant("bool", as_bool(do_rotary)) + sep +
               constant("uint32_t", std::to_string(kv_num_heads)) + sep +
               constant("int32_t", std::to_string(local_window_size)) + sep +
               constant("bool", as_bool(rotary_interleaved)) + sep +
               constant("bool", as_bool(past_present_share_buffer));
    }
};

// Builds the kernel-side gqa_parameters from the op's input shapes and its
// attribute value map.
static inline gqa_parameters init_params(const std::vector<shape>& inputs, const value& v)
{
    // Operator attributes forwarded from the parser.
    auto num_heads    = v.at("num_heads").to<std::uint32_t>();
    auto kv_num_heads = v.at("kv_num_heads").to<std::uint32_t>();
    auto do_rotary    = v.at("do_rotary").to<std::uint32_t>();
    // Parse directly as int: the default is -1, which would be narrowed
    // through unsigned by a round-trip via uint32_t.
    auto local_window_size  = v.at("local_window_size").to<int>();
    auto rotary_interleaved = v.at("rotary_interleaved").to<std::uint32_t>();
    auto scale              = v.at("scale").to<float>();
    auto present_kv_seqlen  = v.at("present_kv_seqlen").to<std::size_t>();

    // Query input: dims 0, 2 and 3 are read as (batch, ..., seq_len, head_size).
    const auto& q_shape               = inputs[0];
    auto q_lens                       = q_shape.lens();
    const std::size_t batch_size      = q_lens[0];
    const std::size_t sequence_length = q_lens[2];
    std::size_t head_size             = q_lens[3];
    // NOTE(review): hidden size is derived from kv_num_heads, not num_heads;
    // assumed intentional for the packed-QKV layout — confirm.
    auto q_hidden_size = kv_num_heads * head_size;

    // inputs[3] is the cos cache; its dim 1 is rotary_dim / 2.
    std::size_t rotary_dim = inputs[3].lens()[1] * 2;

    // Strides for the packed (num_heads + 2 * kv_num_heads) QKV buffer.
    // Computed once here and reused below (previously recomputed inline).
    auto seq_stride                = head_size;
    auto head_stride               = sequence_length * seq_stride;
    auto batch_stride              = (num_heads + 2 * kv_num_heads) * head_stride;
    auto position_ids_format       = sequence_length == 1 ? 1 : 0;
    bool past_present_share_buffer = true;

    gqa_parameters gqa_params;
    gqa_params.scale                     = scale;
    gqa_params.batch_size                = batch_size;
    gqa_params.sequence_length           = sequence_length;
    gqa_params.hidden_size               = q_hidden_size;
    gqa_params.head_size                 = head_size;
    gqa_params.rotary_embedding_dim      = rotary_dim;
    gqa_params.num_heads                 = num_heads;
    gqa_params.max_sequence_length       = sequence_length;
    gqa_params.head_stride               = head_stride;
    gqa_params.seq_stride                = seq_stride;
    gqa_params.batch_stride              = batch_stride;
    gqa_params.position_ids_format       = position_ids_format;
    gqa_params.seqlen_present_kv_cache   = present_kv_seqlen;
    gqa_params.do_rotary                 = static_cast<bool>(do_rotary);
    gqa_params.kv_num_heads              = kv_num_heads;
    gqa_params.local_window_size         = local_window_size;
    gqa_params.rotary_interleaved        = static_cast<bool>(rotary_interleaved);
    gqa_params.past_present_share_buffer = past_present_share_buffer;

    return gqa_params;
}

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_GROUP_QUERY_ATTENTION_HPP
Loading
Loading