two outputs from prim
michal-miotk committed Aug 1, 2024
1 parent bc775be commit 240fe4a
Showing 5 changed files with 38 additions and 34 deletions.
11 changes: 7 additions & 4 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp
@@ -183,7 +183,8 @@ struct lstm_seq : public primitive_base<lstm_seq> {
const lstm_weights_order offset_order = lstm_weights_order::iofz,
const uint32_t direction = 0,
const padding& output_padding = padding())
: primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B, out2_prim_id}, {output_padding}, {}, 1),
: primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B, out1_prim_id, out2_prim_id}, {output_padding}, {}, 1),
out1_prim_id(out1_prim_id),
out2_prim_id(out2_prim_id),
cell(cell),
clip(clip),
@@ -194,7 +195,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {
direction(direction) {}

/// @brief Primitive id containing the initial value of the cell state data.
//primitive_id out1_prim_id;
primitive_id out1_prim_id;
primitive_id out2_prim_id;
primitive_id cell;
/// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
@@ -239,6 +240,8 @@ struct lstm_seq : public primitive_base<lstm_seq> {

#define cmp_fields(name) name == rhs_casted.name
return act_params_eq &&
cmp_fields(out1_prim_id) &&
cmp_fields(out2_prim_id) &&
cmp_fields(clip) &&
cmp_fields(input_forget) &&
cmp_fields(activations) &&
@@ -250,7 +253,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {

void save(BinaryOutputBuffer& ob) const override {
primitive_base<lstm_seq>::save(ob);
//ob << out1_prim_id;
ob << out1_prim_id;
ob << out2_prim_id;
ob << cell;
ob << clip;
@@ -263,7 +266,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {

void load(BinaryInputBuffer& ib) override {
primitive_base<lstm_seq>::load(ib);
//ib >> out1_prim_id;
ib >> out1_prim_id;
ib >> out2_prim_id;
ib >> cell;
ib >> clip;
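
To make the widened constructor concrete, here is a minimal caller-side sketch (the id strings and the in-scope cldnn::input_info variables are illustrative assumptions, not names from this commit; the argument order mirrors the call in rnn.cpp further down). Both output ids join the primitive's input list, so the buffers that back outputs 1 and 2 stay dependencies of the lstm_seq node:

// Sketch only: assumes x, initial_hidden_state, initial_cell_state, seq_lengths,
// W, R, B are cldnn::input_info objects in scope, and that "lstm:hidden_out" /
// "lstm:cell_out" are ids of mutable_data primitives created by the caller.
cldnn::lstm_seq prim("lstm:out0",
                     x, initial_hidden_state, initial_cell_state, seq_lengths,
                     W, R, B,
                     "lstm:hidden_out",   // out1_prim_id -> final hidden state buffer
                     "lstm:cell_out",     // out2_prim_id -> final cell state buffer
                     "",                  // cell id, unused in this path
                     clip,
                     0,                   // input_forget
                     activations, activation_params,
                     cldnn::lstm_weights_order::fizo,
                     direction);
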
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp
@@ -26,6 +26,7 @@ struct lstm_seq_impl : typed_primitive_impl_ocl<lstm_seq> {
protected:
kernel_arguments_data get_arguments(const typed_primitive_inst<lstm_seq>& instance) const override {
kernel_arguments_data args = parent::get_arguments(instance);
args.outputs.push_back(instance.dep_memory_ptr(instance.desc()->input_size() - 2));
args.outputs.push_back(instance.dep_memory_ptr(instance.desc()->input_size() - 1));
return args;
}
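
One way to read this change (an assumption about intent, not stated in the commit message): with out1_prim_id now in the input list, the last two dependencies of the instance are the mutable_data buffers for the extra outputs, and their memories are appended after the primitive's own output. A commented sketch of the resulting list:

// Sketch with descriptive comments; indices assume out1_prim_id and out2_prim_id
// are the last two dependencies of the lstm_seq node.
kernel_arguments_data args = parent::get_arguments(instance);  // assumed to already hold the
                                                               // primitive's own output (hidden_history)
args.outputs.push_back(instance.dep_memory_ptr(
    instance.desc()->input_size() - 2));                       // outputs[1]: hidden_state buffer
args.outputs.push_back(instance.dep_memory_ptr(
    instance.desc()->input_size() - 1));                       // outputs[2]: cell_state buffer
return args;
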
@@ -13,26 +13,24 @@ KERNEL(lstm_seq)(
const __global INPUT5_TYPE* R,
const __global INPUT6_TYPE* B,
__global OUTPUT_TYPE* hidden_history,
__global OUTPUT1_TYPE* cell_state
__global OUTPUT1_TYPE* hidden_state,
__global OUTPUT2_TYPE* cell_state
)
{
const uint hidden_idx = get_global_id(0);
const uint b = get_global_id(1);
const int weight_offsets[4] = {GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O};
const int gate_num = 4;
printf("b %d MAX_SEQ_LENGTH %d sequence_lengths %d\n", b, MAX_SEQ_LENGTH, sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)]);
//printf("b %d MAX_SEQ_LENGTH %d sequence_lengths %d\n", b, MAX_SEQ_LENGTH, sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)]);
ACCUMULATOR_TYPE hidden_result[gate_num];
ACCUMULATOR_TYPE input_result[gate_num];
ACCUMULATOR_TYPE gate_output[gate_num];

for(int k=0;k<gate_num;k++){
gate_output[k] = 0;
}
printf("DIRECTION %d \n", DIRECTION);
//printf("initial hidden state is %f for b %d hidden idx %d\n", initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], b, hidden_idx);
//printf("offsets are %d %d %d %d \n", weight_offsets[0], weight_offsets[1], weight_offsets[2], weight_offsets[3]);
//printf("W is %d R is %d B is %d\n", INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[0], 0, 0), INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[0], 0, 0), INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[0], 0, 0));
const int real_seq_length = sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)];// == MAX_SEQ_LENGTH ? MAX_SEQ_LENGTH: sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)]+1;
//printf("DIRECTION %d \n", DIRECTION);
const int real_seq_length = sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)];
for(int i=0;i<real_seq_length;i++){
for(int k=0;k<gate_num;k++){
hidden_result[k] = 0;
@@ -85,8 +83,13 @@ KERNEL(lstm_seq)(
if(DIRECTION){ //reverse
cur_history_idx = real_seq_length - 1 - i ;
}
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, cur_history_idx, hidden_idx)] = (OUTPUT_TYPE)(gate_output[3]*ACTIVATION_H(cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_H));
}
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = (OUTPUT_TYPE)(gate_output[3]*ACTIVATION_H(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_H));
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, cur_history_idx, hidden_idx)] = hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
} barrier(CLK_LOCAL_MEM_FENCE);

//printf("R is %p B is %p ; hidden history %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &cell_state[0], b);
printf("result is %f %f fb %d\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 0, hidden_idx)], hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 1, hidden_idx)], b);
for(int i=0;i<real_seq_length;i++){
//hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = i;
//printf("result is %f for hididx %d b %d\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)], hidden_idx, b);
}
}
@@ -82,7 +82,7 @@ KernelsData LSTMSeqKernelBase::GetCommonKernelsData(const Params& params) const
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 6});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 1});
//kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 2});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 2});
auto cldnnJit = GetJitConstants(orgParams);
auto entryPoint = GetEntryPoint(kernelName, orgParams.layerID, params);
auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
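
The three OUTPUT descriptors are presumably meant to line up positionally with the kernel's output pointers from the .cl change above; a sketch of that assumed mapping, with comments naming the kernel-side parameters:

// Assumed descriptor-to-parameter mapping (see KERNEL(lstm_seq) above):
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // __global OUTPUT_TYPE*  hidden_history
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 1}); // __global OUTPUT1_TYPE* hidden_state
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 2}); // __global OUTPUT2_TYPE* cell_state
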
35 changes: 16 additions & 19 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
@@ -244,34 +244,31 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
std::vector<cldnn::activation_additional_params> activation_params;
GetLSTMActivationParams(op, activations, activation_params);
float clip = op->get_clip();
cldnn::primitive_id lstm_seq_id = layerName;// + "_lstm_seq";

auto mutable_precision_second = op->get_output_element_type(2);
cldnn::layout out2Layout = cldnn::layout(
cldnn::element_type_to_data_type(mutable_precision_second),
cldnn::primitive_id lstm_seq_id = layerName;
auto mutable_precision_firstsecond = op->get_output_element_type(1);
cldnn::layout out12Layout = cldnn::layout(
cldnn::element_type_to_data_type(mutable_precision_firstsecond),
cldnn::format::bfyx,
tensor_from_dims(op->get_output_shape(2)));
cldnn::memory::ptr shared_memory1 = p.get_engine().allocate_memory(out2Layout);
tensor_from_dims(op->get_output_shape(1)));

cldnn::memory::ptr shared_memory1 = p.get_engine().allocate_memory(out12Layout);
const cldnn::primitive_id mutable_id_1 = layerName + "_md_write1";
const cldnn::mutable_data mutable_prim_1{mutable_id_1, shared_memory1};
p.add_primitive(*op, mutable_prim_1);


cldnn::memory::ptr shared_memory2 = p.get_engine().allocate_memory(out12Layout);
const cldnn::primitive_id mutable_id_2 = layerName + "_md_write2";
const cldnn::mutable_data mutable_prim_2{mutable_id_2, shared_memory2};
p.add_primitive(*op, mutable_prim_2);
int direction = op->get_direction() == ov::op::RecurrentSequenceDirection::REVERSE ? 1 : 0;
cldnn::lstm_seq prim(lstm_seq_id + ".out0", inputs[0], inputs[1], \
inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), "", mutable_id_1, \
inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), mutable_id_1, mutable_id_2, \
"", clip, 0, activations, activation_params, cldnn::lstm_weights_order::fizo, direction);
//prim.out1_prim_id = f_id;
p.add_primitive(*op, prim);
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out2", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory1));
int b = op->get_input_shape(0)[0];
int seqlen = op->get_input_shape(0)[1];
int hidden_size = op->get_input_shape(1)[2];
if (direction) {
p.add_primitive(*op, cldnn::crop(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, \
cldnn::tensor{ b, 1, hidden_size, 1}, cldnn::tensor{ 0, 0, 0, 0}));
} else {
p.add_primitive(*op, cldnn::crop(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, \
cldnn::tensor{ b, 1, hidden_size, 1}, cldnn::tensor{ 0, 0, 0, seqlen-1}));
}
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory1));
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out2", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory2));
}

REGISTER_FACTORY_IMPL(v4, LSTMCell);
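
As a usage note, a condensed sketch of the new output wiring in CreateLSTMSequenceOp (same calls as the diff above, with comments added; the Ho/Co naming is an interpretation, not from the commit). The kernel now writes the final hidden and cell state straight into the two shared buffers, so the previous crop over the hidden history is no longer needed:

// Outputs 1 (Ho) and 2 (Co) of the op share one layout, built from output 1's shape.
cldnn::layout out12Layout(cldnn::element_type_to_data_type(op->get_output_element_type(1)),
                          cldnn::format::bfyx,
                          tensor_from_dims(op->get_output_shape(1)));

cldnn::memory::ptr shared_memory1 = p.get_engine().allocate_memory(out12Layout); // backs Ho
cldnn::memory::ptr shared_memory2 = p.get_engine().allocate_memory(out12Layout); // backs Co

// ".out1"/".out2" re-expose the buffers the kernel already wrote into as graph outputs,
// depending on ".out0" so they are produced after the lstm_seq primitive runs.
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out1",
                                         {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory1));
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out2",
                                         {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory2));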