From 6ed77866697a4ed069af063bf0927a97964c67fa Mon Sep 17 00:00:00 2001
From: Lior Dikstein <78903511+lior-dikstein@users.noreply.github.com>
Date: Wed, 5 Mar 2025 12:18:46 +0200
Subject: [PATCH] Fix bug in list indexes that caused ViT models to raise
 errors. In addition, modified number of astar iterations from 500 to 1000 in
 compute_graph_max_cut to ensure ViT models function correctly (#1374)

Co-authored-by: liord
---
 .../core/common/graph/memory_graph/compute_graph_max_cut.py | 2 +-
 .../substitutions/scaled_dot_product_attention.py           | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py b/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py
index 324a3fdf0..855184adb 100644
--- a/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py
+++ b/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py
@@ -27,7 +27,7 @@
 
 def compute_graph_max_cut(memory_graph: MemoryGraph,
                           n_iter: int = 50,
-                          astar_n_iter: int = 500,
+                          astar_n_iter: int = 1000,
                           eps: float = 1e-2) -> Tuple[List[BaseNode], float, List[Cut]]:
     """
     A wrapper function to compute max cut and schedule for a given model.
diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
index 0e64120cf..948113a1d 100644
--- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
+++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
@@ -103,7 +103,7 @@ def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transpose
         matmul_name = f'{attention_node_name}_matmul1'
         return FunctionalNode(name=matmul_name,
                               framework_attr={},
-                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
+                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape[0])),
                               output_shape=tuple(matmul1_output_shape),
                               weights={},
                               layer_class=torch.matmul,
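
For review context: a minimal sketch, outside the patch itself, of the list-index bug
the second hunk fixes. It assumes, as the q_node.output_shape[0] access on the same
line suggests, that a node's output_shape holds a list of per-output shapes; FakeNode
and the shape values below are hypothetical stand-ins for the toolkit's BaseNode, not
code from the repository.

    # Sketch of the indexing fix; FakeNode is a hypothetical stand-in for
    # BaseNode, and the shape values are made up for illustration.
    class FakeNode:
        def __init__(self, output_shape):
            # Assumption: a list of shapes, one entry per node output.
            self.output_shape = output_shape

    transposed_k_node = FakeNode(output_shape=[[8, 197, 64, 197]])

    # Before the fix: tuple() wraps the list of shapes, yielding a one-element
    # tuple containing a list instead of a flat shape tuple.
    before = tuple(transposed_k_node.output_shape)    # ([8, 197, 64, 197],)

    # After the fix: index the first output's shape, mirroring the existing
    # q_node.output_shape[0] access on the same line.
    after = tuple(transposed_k_node.output_shape[0])  # (8, 197, 64, 197)

    assert before == ([8, 197, 64, 197],)
    assert after == (8, 197, 64, 197)

The first hunk needs no example: it only raises the default A* search budget
(astar_n_iter: 500 -> 1000), which per the commit message gives the max-cut
computation enough iterations for ViT models to function correctly.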