Skip to content

Commit

Permalink
[PT FE] Fix sym GPTQ pattern to have consistent graph (openvinotoolkit#27037)
Browse files Browse the repository at this point in the history

### Details:
 - *Fix sym GPTQ pattern to have consistent graph*

### Tickets:
 - *ticket-id*
  • Loading branch information
mvafin authored and CuriousPanCake committed Nov 6, 2024
1 parent a314fe0 commit 9f86920
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/bindings/python/src/openvino/frontend/pytorch/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def patched_forward_sym(self, *args, **kwargs):
unpacked_weights, 1, 2).contiguous().view(-1, self.group_size, self.width)

# all zp is 8 for symmetrical, will repack to i4 in pt fe transformation
unpacked_weights = unpacked_weights.to(dtype) * self.scales
unpacked_weights = (unpacked_weights.to(torch.int8) - torch.tensor(8, dtype=torch.int8))
unpacked_weights = unpacked_weights.to(dtype) * self.scales
unpacked_weights = unpacked_weights.view(-1, self.width)

out = x @ unpacked_weights
Expand Down
23 changes: 21 additions & 2 deletions src/frontends/pytorch/src/transforms/u4_block_repack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
Expand Down Expand Up @@ -53,6 +54,7 @@ U4BlockRepack::U4BlockRepack(bool is_symmetrical) {
auto reshape1 = pattern_to_output[m_reshape1].get_node_shared_ptr();
auto transpose = pattern_to_output[m_transpose].get_node_shared_ptr();
auto reshape2 = pattern_to_output[m_reshape2].get_node_shared_ptr();
auto pattern_root = reshape2;

if (constant->get_element_type() != element::u4)
return false;
Expand All @@ -76,9 +78,26 @@ U4BlockRepack::U4BlockRepack(bool is_symmetrical) {

auto get_number = get_u4;
auto constant_dtype = element::u4;
NodeVector copy_from{std::move(constant), std::move(reshape1), std::move(transpose), reshape2};
if (is_symmetrical) {
get_number = get_i4;
constant_dtype = element::i4;
// find pattern Convert(W, i8) -> Subtract(8)
auto reshape_targets = reshape2->output(0).get_target_inputs();
if (reshape_targets.size() != 1)
return false;
auto convert = reshape_targets.begin()->get_node()->shared_from_this();
if (!std::dynamic_pointer_cast<ov::op::v0::Convert>(convert))
return false;
auto convert_targets = convert->output(0).get_target_inputs();
if (convert_targets.size() != 1)
return false;
auto subtract = convert_targets.begin()->get_node()->shared_from_this();
if (!std::dynamic_pointer_cast<ov::op::v1::Subtract>(subtract))
return false;
pattern_root = subtract;
copy_from.push_back(std::move(convert));
copy_from.push_back(subtract);
}
auto new_const = std::make_shared<v0::Constant>(constant_dtype, destination_shape);
auto dst = const_cast<uint8_t*>( // const_cast?
Expand All @@ -96,8 +115,8 @@ U4BlockRepack::U4BlockRepack(bool is_symmetrical) {
}
}

copy_runtime_info({std::move(constant), std::move(reshape1), std::move(transpose), reshape2}, new_const);
replace_node(reshape2, new_const);
copy_runtime_info(copy_from, new_const);
replace_node(pattern_root, new_const);

return true;
});
Expand Down
2 changes: 1 addition & 1 deletion tests/model_hub_tests/pytorch/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def load_model(self, name, type):
example["past_key_values"] = pkv
example["attention_mask"] = torch.cat(
[example["attention_mask"], am], -1)
if atype not in ["opt", "falcon", "mbart_gptq", "mpt"]:
if atype not in ["opt", "falcon", "mbart", "mpt"]:
ids = torch.cumsum(example["attention_mask"] != 0, dim=1) - 1
example["position_ids"] = ids[:, -
example["input_ids"].shape[1]:]
Expand Down

0 comments on commit 9f86920

Please sign in to comment.