-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] LSTMSequence and LSTMCell optimization #26767
Changes from 160 commits
9ce143a
027f991
c191c58
d461e66
01fa2ac
1f017fd
5787c7d
837db22
d4ce531
19c268e
f5273bc
d50b3be
63a8dfd
f54ecc1
3748a11
fae772a
c00ff8a
1c08b14
6968881
31fcb79
4b16eef
dcad182
d6aeb54
9688f63
5003d47
5937b14
8b31a91
2ff5a7c
2d9e5c6
702e941
f4d3b71
0c7103c
f37482a
0058c57
1ac26d3
31040bf
83aa74f
0cce00c
0e37c8a
72b48d1
7a747c5
9b99f04
5052e26
aa5d906
4b524fd
c47c943
e486376
fe72cc8
d62f223
0b1fa3d
3e1fe20
cac921c
a165f30
8f74962
732eb52
165dd9b
1b9cc98
37ab01b
c99ddc0
60a0675
1b23648
7c1bf37
40abc31
56031d9
d954fe8
78cc4fc
81ca2ed
431d937
6b6800f
db8d75b
a9cd3cf
bfb80ba
57faed2
7f097ba
a79eca5
8d4e46b
00c6237
31b8ef0
07c1ac2
a78ef3a
5bcab62
d564228
b16bdac
173b5b2
c8eb682
43acd2b
0002e54
7741a46
ac352ea
d0fb8b4
a1497c4
f12aebd
6170710
01dc7dc
e9bf370
7158776
5e21106
daa83b5
6af1f3f
02942e5
f8dbec3
2abe8f8
00826ad
36b4853
c358ab3
19b1d93
4da2df6
303bf7d
bf9f13f
459e1ad
14e53f4
063ac02
892131b
b539a3f
dc8ac73
a5165a8
cc6b4b5
b168fbe
c38c321
73cde93
99a3ca5
b5ca43f
1d1deb7
f14ed32
fb12ef6
63bddcb
2b6ad65
8e1d36c
bd81512
e4fffa5
aea04bd
7af4e1e
75d0b67
acd6369
0430ee3
122ad96
a54f49a
79ef461
31313c9
d3ec29b
09b9283
bde6fb6
fcee4af
320c0d9
c7ccf49
421eb89
c6e160f
0adaa9b
fee5623
3482f33
e174f1a
06f394a
483e637
819e21f
72232ea
178f7fa
3307752
2d58b17
e1fcc01
9fa6fff
8bca96d
63ee3b4
c99a348
b551bda
d1bec7b
0a9756d
2a068c2
fcdaab0
6685209
debeb39
7ce51bc
26d0e50
68ae65b
e3b02c2
89bca48
33c9256
36e02cb
2c68353
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
#include "primitive.hpp" | ||
#include "activation.hpp" | ||
#include <vector> | ||
#include <algorithm> | ||
#include "intel_gpu/graph/serialization/activation_serializer.hpp" | ||
#include "rnn.hpp" | ||
|
||
|
||
namespace cldnn { | ||
|
||
/// @brief LSTM cell primitive: a single-timestep LSTM computation.
/// @details All inputs, weights, activations and ordering options are inherited
/// from RNNParams; this type only tags the primitive kind for the GPU plugin.
struct lstm_cell : public RNNParams<lstm_cell> {
    CLDNN_DECLARE_PRIMITIVE(lstm_cell)
    // Convenience aliases matching the activation-related ctor parameters.
    using vec_activation = std::vector<activation_func>;
    using vec_activation_param = std::vector<activation_additional_params>;
    // Inherit all RNNParams constructors.
    using RNNParams::RNNParams;
    lstm_cell(const lstm_cell&) = default;
    // Default ctor produces an empty primitive (used e.g. during deserialization).
    lstm_cell() : RNNParams() {}
};
} // namespace cldnn |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
#include "primitive.hpp" | ||
#include "activation.hpp" | ||
#include <vector> | ||
#include <algorithm> | ||
#include <string> | ||
#include "intel_gpu/graph/serialization/activation_serializer.hpp" | ||
|
||
namespace cldnn { | ||
|
||
/// @brief Weights orders | ||
/// @details Specifies the order in which the weights are concatenated. | ||
/// e.g. [i, o, f, z] : [input, output, forget, block] | ||
/// ONNX order: iofz | ||
/// Caffe order: ifoz | ||
/// pyTorch order: izof | ||
/// OV order: fizo | ||
enum class lstm_weights_order {
    iofz,  // input, output, forget, block — ONNX order
    ifoz,  // input, forget, output, block — Caffe order
    izof,  // input, block, output, forget — PyTorch order
    fizo   // forget, input, block, output — OpenVINO order
};
|
||
template <typename PType> | ||
struct RNNParams : public primitive_base<PType> { | ||
RNNParams() : primitive_base<PType>("", {}) {} | ||
RNNParams(const RNNParams&) = default; | ||
RNNParams(const primitive_id& id, | ||
p-durandin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const input_info& x, | ||
const input_info& initial_hidden_state, | ||
const input_info& initial_cell_state, | ||
const input_info& W, | ||
const input_info& R, | ||
const input_info& B, | ||
const input_info& seq_lenghts, | ||
const primitive_id& out1_prim_id = "", | ||
const primitive_id& out2_prim_id = "", | ||
const float clip = 0, | ||
bool input_forget = false, | ||
const std::vector<activation_func>& activations = {activation_func::logistic, | ||
activation_func::hyperbolic_tan, | ||
activation_func::hyperbolic_tan}, | ||
const std::vector<activation_additional_params>& activation_params = {}, | ||
const lstm_weights_order& offset_order = lstm_weights_order::iofz, | ||
const ov::op::RecurrentSequenceDirection direction = ov::op::RecurrentSequenceDirection::FORWARD, | ||
const padding& output_padding = padding(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think padding arg is not needed as it's always set as default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
const int num_outputs = 1) | ||
: primitive_base<PType>(id, {x}, num_outputs, {optional_data_type()}, {output_padding}), | ||
x(x), | ||
initial_hidden_state(initial_hidden_state), | ||
initial_cell_state(initial_cell_state), | ||
W(W), | ||
R(R), | ||
B(B), | ||
seq_lenghts(seq_lenghts), | ||
out1_prim_id(out1_prim_id), | ||
out2_prim_id(out2_prim_id), | ||
clip(clip), | ||
input_forget(input_forget), | ||
activations(activations), | ||
activation_params(activation_params), | ||
offset_order(offset_order), | ||
direction(direction) { | ||
std::vector<std::string> pids{initial_hidden_state.pid, initial_cell_state.pid, W.pid, R.pid, B.pid, seq_lenghts.pid, out1_prim_id, out2_prim_id}; | ||
for (auto pid : pids) { | ||
if (!pid.empty()) { | ||
primitive_base<PType>::input.push_back(pid); | ||
} | ||
} | ||
} | ||
|
||
input_info x; | ||
input_info initial_hidden_state; | ||
input_info initial_cell_state; | ||
input_info W; | ||
input_info R; | ||
input_info B; | ||
input_info seq_lenghts; | ||
primitive_id out1_prim_id; | ||
primitive_id out2_prim_id; | ||
/// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified. | ||
float clip; | ||
bool input_forget; | ||
/// @brief A list of 3 activation functions for the input, output, forget, cell, and hidden. | ||
std::vector<activation_func> activations; | ||
/// @brief Optional scaling values used by some activation functions. The values are consumed in the order of activation functions. | ||
std::vector<activation_additional_params> activation_params; | ||
/// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe | ||
lstm_weights_order offset_order; | ||
/// @brief direction of LSTMSequence - only FORWARD or REVERSE, currently BIDIRECTIONAL not supported | ||
ov::op::RecurrentSequenceDirection direction; | ||
|
||
int num_directions() const { | ||
return direction == ov::op::RecurrentSequenceDirection::BIDIRECTIONAL ? 2 : 1; | ||
} | ||
|
||
size_t hash() const override { | ||
size_t seed = primitive::hash(); | ||
seed = hash_combine(seed, x.pid); | ||
seed = hash_combine(seed, initial_hidden_state.pid); | ||
seed = hash_combine(seed, initial_cell_state.pid); | ||
seed = hash_combine(seed, seq_lenghts.pid); | ||
seed = hash_combine(seed, W.pid); | ||
seed = hash_combine(seed, R.pid); | ||
seed = hash_combine(seed, B.pid); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comparison and hashing of the primitive ids prevents primitive reuse if we have multiple instances of the same op. So you shall just hash/compare only presence flag to all inputs. As an example you can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
seed = hash_combine(seed, out1_prim_id); | ||
seed = hash_combine(seed, out2_prim_id); | ||
seed = hash_combine(seed, clip); | ||
seed = hash_range(seed, activations.begin(), activations.end()); | ||
for (auto& act_param : activation_params) { | ||
seed = hash_combine(seed, act_param.a); | ||
seed = hash_combine(seed, act_param.b); | ||
} | ||
seed = hash_combine(seed, offset_order); | ||
seed = hash_combine(seed, direction); | ||
return seed; | ||
} | ||
|
||
bool operator==(const primitive& rhs) const override { | ||
if (!primitive::compare_common_params(rhs)) | ||
return false; | ||
|
||
auto rhs_casted = downcast<const PType>(rhs); | ||
bool act_params_eq = activation_params.size() == rhs_casted.activation_params.size(); | ||
for (size_t i = 0; i < activation_params.size(); ++i) { | ||
act_params_eq &= activation_params[i].a == rhs_casted.activation_params[i].a && | ||
activation_params[i].b == rhs_casted.activation_params[i].b; | ||
} | ||
|
||
#define cmp_fields(name) name == rhs_casted.name | ||
return act_params_eq && | ||
cmp_fields(x) && | ||
cmp_fields(initial_hidden_state) && | ||
cmp_fields(initial_cell_state) && | ||
cmp_fields(seq_lenghts) && | ||
cmp_fields(W) && | ||
cmp_fields(R) && | ||
cmp_fields(B) && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here you also shouldn't compare string values, but rather check presence of inputs There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
cmp_fields(out1_prim_id) && | ||
cmp_fields(out2_prim_id) && | ||
cmp_fields(clip) && | ||
cmp_fields(activations) && | ||
cmp_fields(offset_order) && | ||
cmp_fields(direction); | ||
#undef cmp_fields | ||
} | ||
|
||
void save(BinaryOutputBuffer& ob) const override { | ||
primitive_base<PType>::save(ob); | ||
ob << x; | ||
ob << initial_hidden_state; | ||
ob << initial_cell_state; | ||
ob << W; | ||
ob << R; | ||
ob << B; | ||
ob << seq_lenghts; | ||
ob << out1_prim_id; | ||
ob << out2_prim_id; | ||
ob << clip; | ||
ob << activations; | ||
ob << activation_params; | ||
ob << make_data(&offset_order, sizeof(lstm_weights_order)); | ||
ob << make_data(&direction, sizeof(ov::op::RecurrentSequenceDirection)); | ||
} | ||
|
||
void load(BinaryInputBuffer& ib) override{ | ||
primitive_base<PType>::load(ib); | ||
ib >> x; | ||
ib >> initial_hidden_state; | ||
ib >> initial_cell_state; | ||
ib >> W; | ||
ib >> R; | ||
ib >> B; | ||
ib >> seq_lenghts; | ||
ib >> out1_prim_id; | ||
ib >> out2_prim_id; | ||
ib >> clip; | ||
ib >> activations; | ||
ib >> activation_params; | ||
ib >> make_data(&offset_order, sizeof(lstm_weights_order)); | ||
ib >> make_data(&direction, sizeof(ov::op::RecurrentSequenceDirection)); | ||
} | ||
}; | ||
|
||
/// @brief LSTM sequence primitive: runs an LSTM over a full input sequence.
/// @details Parameters are inherited from RNNParams.
struct lstm_seq : public RNNParams<lstm_seq> {
    CLDNN_DECLARE_PRIMITIVE(lstm_seq)
    // Convenience aliases matching the activation-related ctor parameters.
    using vec_activation = std::vector<activation_func>;
    using vec_activation_param = std::vector<activation_additional_params>;
    // Inherit all RNNParams constructors. NOTE(review): the inherited ctors do
    // not assign `input`/`weights` below — only the default ctor does, and at
    // that point W.pid and x.pid are still empty; confirm this is intentional.
    using RNNParams::RNNParams;
    lstm_seq() : RNNParams() {
        weights = W.pid;
        input = x.pid;
    }
    lstm_seq(const lstm_seq&) = default;
    // Ids of the input and weights primitives (see NOTE above about when they are set).
    primitive_id input;
    primitive_id weights;
};
} //namespace cldnn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you still keep (and use)
lstm_elt
primitive given that you introducelstm_cell
andlstm_seq
? I'd expect that it's not needed anymoreThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oneDNN performs poorly in the seq_len = 1 case, so I have not updated it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you try enabling ngraph pass for decomposition in such case? Ideally we need to get rid of this lstm_elt primitive and related decomposition code in program builder to switch to new shape inference completely
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The lstm_cell primitive, which would be used in that case, is too slow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the initial goals of this patch was removing this lstm decomposition on the plugin side to bunch of custom primitives (and thus removing lstm_elt primitive). And that's still needed.
Also, as I can see, lstm_cell primitive is not used at all currently, which means there's no sense to add it. So my suggestion is to continue perf tuning then.