Skip to content

Commit

Permalink
Fix bug in list indexes that caused ViT models to raise errors. In ad…
Browse files Browse the repository at this point in the history
…dition, modified number of astar iterations from 500 to 1000 in compute_graph_max_cut to ensure ViT models function correctly
  • Loading branch information
liord committed Mar 4, 2025
1 parent aa2a182 commit 0608655
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def compute_graph_max_cut(memory_graph: MemoryGraph,
n_iter: int = 50,
astar_n_iter: int = 500,
astar_n_iter: int = 1000,
eps: float = 1e-2) -> Tuple[List[BaseNode], float, List[Cut]]:
"""
A wrapper function to compute max cut and schedule for a given model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transpose
matmul_name = f'{attention_node_name}_matmul1'
return FunctionalNode(name=matmul_name,
framework_attr={},
input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape[0])),
output_shape=tuple(matmul1_output_shape),
weights={},
layer_class=torch.matmul,
Expand Down

0 comments on commit 0608655

Please sign in to comment.