From 889448b113f5a2538b76e54ff537c28b3b7da845 Mon Sep 17 00:00:00 2001
From: Nathanael See
Date: Thu, 20 Feb 2025 13:22:36 -0800
Subject: [PATCH] update SqueezeInt4LinearInputs to process relu/gelu inputs too (#8601)

Summary: Update/rename SqueezeInt4LinearInputs pass so it wraps gelu/relu with squeeze/unsqueeze view ops too

Differential Revision: D69673068
---
 backends/transforms/fuse_view_copy.py           | 17 +++++++++++++++++
 backends/vulkan/_passes/TARGETS                 |  7 ++++---
 backends/vulkan/_passes/__init__.py             |  6 +++---
 ...ar_inputs.py => squeeze_unsqueeze_inputs.py} | 17 ++++++++++++++---
 backends/vulkan/vulkan_preprocess.py            |  4 ++--
 5 files changed, 40 insertions(+), 11 deletions(-)
 rename backends/vulkan/_passes/{squeeze_int4_linear_inputs.py => squeeze_unsqueeze_inputs.py} (80%)

diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py
index bbc155dc45..22e20d1c88 100644
--- a/backends/transforms/fuse_view_copy.py
+++ b/backends/transforms/fuse_view_copy.py
@@ -40,7 +40,24 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
     return graph
 
 
+def remove_noop_view_copy(graph: torch.fx.Graph) -> torch.fx.Graph:
+    """
+    Remove view_copy nodes that are no-ops.
+    """
+    ops = exir_ops.edge
+    view_op = ops.aten.view_copy.default
+    for node in graph.nodes:
+        if node.op == "call_function" and node.target == view_op:
+            input_shape = list(node.args[0].meta["val"].shape)
+            target_shape = node.args[1]
+            if input_shape == target_shape:
+                node.replace_all_uses_with(node.args[0])
+    graph.eliminate_dead_code()
+    return graph
+
+
 class FuseViewCopyTransform(ExportPass):
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.graph = merge_view_copy_chains(graph_module.graph)
+        graph_module.graph = remove_noop_view_copy(graph_module.graph)
         return PassResult(graph_module, True)
diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 59658e58f2..5478ad0eab 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -31,14 +31,15 @@ runtime.python_library(
 )
 
 runtime.python_library(
-    name = "squeeze_int4_linear_inputs",
+    name = "squeeze_unsqueeze_inputs",
     srcs = [
-        "squeeze_int4_linear_inputs.py",
+        "squeeze_unsqueeze_inputs.py",
     ],
     visibility = [
         "//executorch/backends/...",
     ],
     deps = [
+        "//caffe2:torch",
         "//executorch/backends/vulkan:custom_ops_lib",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
@@ -114,7 +115,7 @@ runtime.python_library(
         ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",
-        ":squeeze_int4_linear_inputs",
+        ":squeeze_unsqueeze_inputs",
         ":tag_memory_meta_pass",
     ]
 )
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 2a4a2b4b5c..220afa6a35 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -20,8 +20,8 @@ from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
-from executorch.backends.vulkan._passes.squeeze_int4_linear_inputs import (
-    SqueezeInt4LinearInputs,
+from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
+    SqueezeUnsqueezeInputs,
 )
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 
@@ -32,6 +32,6 @@
     "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
-    "SqueezeInt4LinearInputs",
+    "SqueezeUnsqueezeInputs",
     "TagMemoryMetaPass",
 ]
diff --git a/backends/vulkan/_passes/squeeze_int4_linear_inputs.py b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py
similarity index 80%
rename from backends/vulkan/_passes/squeeze_int4_linear_inputs.py
rename to backends/vulkan/_passes/squeeze_unsqueeze_inputs.py
index 95fcef7f75..a0160efa90 100644
--- a/backends/vulkan/_passes/squeeze_int4_linear_inputs.py
+++ b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py
@@ -6,16 +6,27 @@
 
 # pyre-strict
 
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple, Union
 
 import executorch.backends.vulkan.custom_ops_lib  # noqa: needed to access vk op
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
 
+from torch._ops import OpOverload
+
 from torch.fx.node import Argument
 
+OpType = Union[str, OpOverload, EdgeOpOverload]
+
+
+class SqueezeUnsqueezeInputs(ExportPass):
+    _squeezable_ops: Set[OpType] = {
+        exir_ops.edge.et_vk.linear_weight_int4.default,
+        exir_ops.edge.aten.relu.default,
+        exir_ops.edge.aten.gelu.default,
+    }
 
-class SqueezeInt4LinearInputs(ExportPass):
     def call_operator(
         self,
         op,  # pyre-ignore
@@ -26,7 +37,7 @@ def call_operator(
         def _squeezable(shape: List[int]) -> bool:
             return len(shape) > 2 and 1 in shape
 
-        if op != exir_ops.edge.et_vk.linear_weight_int4.default:
+        if op not in self._squeezable_ops:
             return super().call_operator(op, args, kwargs, meta)
 
         # pyre-ignore[16]: `None` has no attribute `node`
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index c6b444e5de..3cfcac13a8 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -26,7 +26,7 @@
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
     RemoveRedundantOpsTransform,
-    SqueezeInt4LinearInputs,
+    SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
 
@@ -153,7 +153,7 @@ def preprocess(  # noqa: C901
                 RemoveRedundantOpsTransform(),
                 AddmmToLinearTransform(),
                 FuseDequantLinearPass(),
-                SqueezeInt4LinearInputs(),
+                SqueezeUnsqueezeInputs(),
                 FuseViewCopyTransform(),
                 ViewCopyToSqueezeUnsqueezePass(),
                 FuseBatchNormWithConvPass(program),