From 58e9ea0b42bef40b0669e1a6eba1032777300811 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 2 Dec 2021 14:31:16 -0800 Subject: [PATCH] fix: Fix fuse addmm pass Signed-off-by: Dheeraj Peri --- core/lowering/passes/fuse_addmm_branches.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/lowering/passes/fuse_addmm_branches.cpp b/core/lowering/passes/fuse_addmm_branches.cpp index 3a3e640d09..19f34b9b9e 100644 --- a/core/lowering/passes/fuse_addmm_branches.cpp +++ b/core/lowering/passes/fuse_addmm_branches.cpp @@ -49,7 +49,7 @@ struct AddMMBranchFusion { if ((*arm1_start)->kind().toQualString() == std::string("aten::addmm") && (*(++arm1_start))->kind() == prim::Return && (*arm2_start)->kind().toQualString() == std::string("aten::matmul") && - (*(++arm2_start))->kind().toQualString() != std::string("aten::add") && + (*(++arm2_start))->kind().toQualString() == std::string("aten::add") && (*(++arm2_start))->kind() == prim::Return) { // Make sure that block0 is solely just the aten::addmm op and block1 is matmul + add return true;