Skip to content

Commit

Permalink
Fix bug that caused ViT models to raise errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
liord committed Mar 4, 2025
1 parent aa2a182 commit 3254880
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def compute_graph_max_cut(memory_graph: MemoryGraph,
n_iter: int = 50,
astar_n_iter: int = 500,
astar_n_iter: int = 1000,
eps: float = 1e-2) -> Tuple[List[BaseNode], float, List[Cut]]:
"""
A wrapper function to compute max cut and schedule for a given model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transpose
matmul_name = f'{attention_node_name}_matmul1'
return FunctionalNode(name=matmul_name,
framework_attr={},
input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape[0])),
output_shape=tuple(matmul1_output_shape),
weights={},
layer_class=torch.matmul,
Expand Down

0 comments on commit 3254880

Please sign in to comment.