fusion_shape.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# This file is modified from https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_shape.py
# We will remove this file once the change is available in a nightly or release package.
from logging import getLogger
from typing import Dict, List, Union

from numpy import ndarray
from onnx import NodeProto, TensorProto
from onnxruntime.transformers.fusion_base import Fusion
from onnxruntime.transformers.fusion_utils import FusionUtils

from onnx_model import OnnxModel

logger = getLogger(__name__)
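

# FusionShape folds the per-dimension Shape/Gather/Unsqueeze/Concat pattern described
# in fuse() below back into the single Shape node that feeds it.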
class FusionShape(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "Shape", "Concat")
        self.utils = FusionUtils(model)
        self.shape_infer = None
        self.shape_infer_done = False
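
    # Rank (number of dimensions) recorded in the tensor type of the given proto, or None if it has no shape.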
    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None
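
    # Rank of the named value: use the shape stored on the model when available,
    # otherwise fall back to symbolic shape inference.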
    def get_dimensions(self, input_name: str) -> Union[int, None]:
        shape = self.model.get_shape(input_name)
        if shape is not None:
            return len(shape)
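
        # Run symbolic shape inference at most once and cache the result for later lookups.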
        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])

        return None

    def fuse(
        self,
        concat_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        #
        # Simplify a subgraph like
        #
        #               (2d_input)
        #                /      \
        #            Shape       Shape
        #              |           |
        #     Gather(indices=0)  Gather(indices=1)
        #              |           |
        #     Unsqueeze(axes=0)  Unsqueeze(axes=0)
        #               \         /
        #                 Concat
        #                    |
        #
        # into:  (2d_input) --> Shape -->
        #
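
        # Verify that every Concat input is produced by an Unsqueeze(Gather(Shape(root), i))
        # chain over the same root tensor; only then is the Concat output identical to the
        # Shape output and safe to bypass.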
        opset_version = self.model.get_opset_version()

        inputs = len(concat_node.input)
        root = None
        shape_output = None
        for i in range(inputs):
            path = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Gather", "Shape"],
                [i, 0, 0],
                output_name_to_node,
            )
            if path is None:
                return

            unsqueeze, gather, shape = path
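
            # All branches must read the shape of the same root tensor, and that tensor's
            # rank must equal the number of Concat inputs.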
            if i == 0:
                shape_output = shape.output[0]

            if root is None:
                root = shape.input[0]
                if self.get_dimensions(root) != inputs:
                    return
            elif shape.input[0] != root:
                return
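
            # The Unsqueeze must insert its new axis at position 0; before opset 13 the axes
            # are an attribute, from opset 13 onward they are the second input.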
            if not FusionUtils.check_node_attribute(unsqueeze, "axis", 0, default_value=0):
                return

            if opset_version < 13:
                if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
                    return
            else:
                if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                    return
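
            # The Gather feeding Concat input i must select exactly dimension i of the shape.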
            value = self.model.get_constant_value(gather.input[1])
            if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
                return
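
        # Every Concat input reproduces one entry of the root's shape, so consumers can read
        # the Shape output directly (unless the Concat output is itself a graph output).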
        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
            self.increase_counter("Reshape")
            self.prune_graph = True
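

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the upstream file). It assumes
# an input file named "model.onnx" and relies on Fusion.apply() and
# OnnxModel.save_model_to_file() behaving as in onnxruntime.transformers; adapt
# the paths and loading code to your own pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import onnx

    # Load the ONNX graph and wrap it in the helper class used by the fusion passes.
    model_proto = onnx.load("model.onnx")  # hypothetical input path
    onnx_model = OnnxModel(model_proto)

    # Apply the Shape fusion to every Concat node in the graph.
    FusionShape(onnx_model).apply()

    # Drop any nodes left dangling by the rewrite and save the result.
    onnx_model.prune_graph()
    onnx_model.save_model_to_file("model_fused.onnx")  # hypothetical output path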