-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_graph_viz.py
112 lines (97 loc) · 3.85 KB
/
_graph_viz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
def _shape_notation(int_shape):
X = ['S','B','C','H','W']
return [X[i] for i in int_shape]
def _import_pydot():
    """Return the first importable pydot-compatible module, or None.

    Candidates are tried in order of preference: pydot-ng (a fork of pydot
    that is better maintained), pydotplus (an improved version of pydot),
    and finally plain pydot.
    """
    import importlib

    for module_name in ('pydot_ng', 'pydotplus', 'pydot'):
        # Catch Exception rather than only ImportError: plotting is
        # best-effort, and a broken install may fail with other errors at
        # import time. Unlike a bare ``except:``, this does not swallow
        # KeyboardInterrupt / SystemExit.
        try:
            return importlib.import_module(module_name)
        except Exception:
            continue
    return None


def _format_shape(tensor_name, graph, show_coreml_mapped_shapes):
    """Return the printable shape of ``tensor_name`` in ``graph``, or 'NA'.

    :param tensor_name: name of the tensor to look up.
    :param graph: graph exposing ``onnx_coreml_shape_mapping`` and
        ``shape_dict``, both looked up by tensor name.
    :param show_coreml_mapped_shapes: if True, read
        ``graph.onnx_coreml_shape_mapping`` and render it with
        ``_shape_notation``; otherwise read ``graph.shape_dict``.
    :return: ``str`` of the shape tuple, or ``'NA'`` if the tensor is absent
        from the selected mapping.
    """
    if show_coreml_mapped_shapes:
        mapping = graph.onnx_coreml_shape_mapping
        if tensor_name in mapping:
            return str(tuple(_shape_notation(mapping[tensor_name])))
    elif tensor_name in graph.shape_dict:
        return str(tuple(graph.shape_dict[tensor_name]))
    return 'NA'


def plot_graph(graph, graph_img_path='graph.png', show_coreml_mapped_shapes=False):
    """
    Plot graph using pydot.

    It works in two steps:
      1. add one pydot record node per graph input and per graph node;
      2. connect the pydot nodes with edges (node -> each child, and
         graph input -> each node that consumes it).

    :param graph: graph to plot. Attributes read: ``inputs`` (sequence of
        tuples whose item 0 is the input name and item 2 its shape),
        ``nodes`` (objects with ``name``, ``op_type``, ``inputs``,
        ``outputs`` and ``children``), ``shape_dict`` and
        ``onnx_coreml_shape_mapping`` (both looked up by tensor name).
    :param graph_img_path: output file path. Its extension (without the dot,
        lower-cased) is passed as the output format to ``dot.write``; if the
        path has no extension the format is 'pdf' and the path is used as-is.
    :param show_coreml_mapped_shapes: if True, label tensors with their CoreML
        axis mapping rendered as letters (S, B, C, H, W) instead of the shapes
        from ``graph.shape_dict`` / the graph input tuples.
    :return: None. If no pydot flavour can be imported, a RuntimeWarning is
        issued and nothing is written.
    """
    import warnings

    pydot = _import_pydot()
    if pydot is None:
        # Plotting is optional; tell the caller why no file appeared
        # instead of returning silently.
        warnings.warn(
            'plot_graph: none of pydot_ng, pydotplus or pydot could be '
            'imported; no graph image was written.',
            RuntimeWarning,
        )
        return None

    dot = pydot.Dot()
    dot.set('rankdir', 'TB')
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    # Step 1a: one record node per graph input. A set is used because it is
    # probed once per node input when the edges are added below.
    graph_inputs = set()
    for input_ in graph.inputs:
        input_name = input_[0]
        if show_coreml_mapped_shapes:
            shape = _format_shape(input_name, graph, True)
        else:
            shape = str(tuple(input_[2]))
        label = '%s\n|{|%s}|{{%s}|{%s}}' % ('Input', input_name, '', shape)
        dot.add_node(pydot.Node(input_name, label=label))
        graph_inputs.add(input_name)

    # Step 1b: one record node per graph node, with its tensor names and
    # shapes in the record fields.
    for node in graph.nodes:
        input_names = ', '.join(node.inputs)
        output_names = ', '.join(node.outputs)
        inputlabels = ', '.join(
            _format_shape(name, graph, show_coreml_mapped_shapes)
            for name in node.inputs
        )
        outputlabels = ', '.join(
            _format_shape(name, graph, show_coreml_mapped_shapes)
            for name in node.outputs
        )
        label = '%s\n|{{%s}|{%s}}|{{%s}|{%s}}' % (node.op_type,
                                                 input_names,
                                                 output_names,
                                                 inputlabels,
                                                 outputlabels)
        dot.add_node(pydot.Node(node.name, label=label))

    # Step 2: edges between the pydot nodes.
    for node in graph.nodes:
        for child in node.children:
            dot.add_edge(pydot.Edge(node.name, child.name))
        for input_name in node.inputs:
            if input_name in graph_inputs:
                dot.add_edge(pydot.Edge(input_name, node.name))

    # Write out the image file. splitext keeps the leading '.', so strip it,
    # and lower-case it so 'graph.PNG' selects the same format as 'graph.png'.
    _, extension = os.path.splitext(graph_img_path)
    extension = extension[1:].lower() if extension else 'pdf'
    dot.write(graph_img_path, format=extension)