-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy path17_09_4.py
74 lines (69 loc) · 2.08 KB
/
17_09_4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def tarjan(nb, graph, num, low, stack, in_stack, visited, ans, ans_):
    """Collect the strongly connected components reachable from ``nb``.

    The search is iterative (explicit frame stack) so that long paths do not
    exhaust the interpreter's recursion limit. All results are written into
    the containers supplied by the caller; nothing is returned.

    Args:
        nb: Start vertex; the caller only passes vertices not in ``visited``.
        graph: ``graph[v][0]`` is the iterable of successors of vertex ``v``.
        num: Per-vertex discovery index, filled in as vertices are entered.
        low: Per-vertex low-link value, updated during the search.
        stack: List of vertices not yet assigned to a component.
        in_stack: Set mirror of ``stack`` for O(1) membership tests.
        visited: Set of vertices already entered; its size provides the next
            discovery index.
        ans: List of components; each finished component is appended as a set
            of vertices.
        ans_: ``ans_[v]`` is set to the index in ``ans`` of the component
            containing ``v``.
    """

    def enter(node):
        # Number the vertex, mark it open, and hand back its successor iterator.
        num[node] = low[node] = len(visited)
        stack.append(node)
        visited.add(node)
        in_stack.add(node)
        return iter(graph[node][0])

    # Each frame is (vertex, iterator over the successors still to examine).
    frames = [(nb, enter(nb))]
    while frames:
        node, successors = frames[-1]
        for dst in successors:
            if dst not in visited:
                # Descend; the shared iterator resumes here once dst is done.
                frames.append((dst, enter(dst)))
                break
            if dst in in_stack:
                # Edge back into the still-open part of the search.
                low[node] = min(low[node], num[dst])
        else:
            # Every successor handled: this vertex is finished.
            frames.pop()
            if num[node] == low[node]:
                # node is the root of a component: pop it off the vertex stack.
                component = set()
                ans.append(component)
                index = len(ans) - 1
                out = -1
                while out != node:
                    out = stack.pop()
                    in_stack.remove(out)
                    component.add(out)
                    ans_[out] = index
            if frames:
                # Propagate the low-link to the vertex that descended into node.
                parent = frames[-1][0]
                low[parent] = min(low[parent], low[node])
def dfs(nb, tree, tree_child, ans, visited):
    """Accumulate into ``tree_child`` everything reachable from ``nb``.

    For every node ``v`` visited, ``tree_child[v]`` is extended with
    ``ans[dst] | tree_child[dst]`` for each successor ``dst`` of ``v``, after
    ``dst`` itself has been processed. The traversal is iterative so that
    long chains do not hit the interpreter's recursion limit.

    Args:
        nb: Start node; the caller only passes nodes not in ``visited``.
        tree: ``tree[v]`` is the iterable of successors of node ``v``.
        tree_child: ``tree_child[v]`` is a set, updated in place.
        ans: ``ans[v]`` is the set contributed by node ``v`` itself.
        visited: Set of nodes already entered; updated in place.
    """
    visited.add(nb)
    # Each frame is (node, iterator over the successors still to examine).
    frames = [(nb, iter(tree[nb]))]
    while frames:
        node, successors = frames[-1]
        for dst in successors:
            if dst not in visited:
                # Descend first; dst is merged into node when it finishes.
                visited.add(dst)
                frames.append((dst, iter(tree[dst])))
                break
            # Already processed elsewhere: merge its result right away.
            tree_child[node] |= ans[dst] | tree_child[dst]
        else:
            frames.pop()
            if frames:
                # node is complete; fold it into the node that descended here.
                parent = frames[-1][0]
                tree_child[parent] |= ans[node] | tree_child[node]
# Read the vertex count n and edge count m, then m directed edges "src dst".
# str.split() without an argument tolerates repeated blanks, tabs and
# leading/trailing whitespace, which split(' ') would turn into int('') errors.
n, m = [int(it) for it in input().split()]

# graph[v][0] holds the successors of v, graph[v][1] its predecessors.
graph = [[[], []] for _ in range(n)]
for _ in range(m):
    # Vertices are shifted down by one so they index graph directly.
    src, dst = [int(it) - 1 for it in input().split()]
    graph[src][0].append(dst)
    graph[dst][1].append(src)

# Strongly connected components: ans[c] is the vertex set of component c and
# ans_[v] is the index of the component that contains vertex v.
stack = []
ans = []
ans_ = [-1] * n
num = [-1] * n
low = [-1] * n
in_stack = set()
visited = set()
for i in range(n):
    if i not in visited:
        tarjan(i, graph, num, low, stack, in_stack, visited, ans, ans_)

# Component-level graph: tree[c] are the components that c has an edge into,
# tree_[c] the components with an edge into c. Self-references are dropped.
tree = []
tree_ = []
for i, an in enumerate(ans):
    outgoing = {ans_[dst] for a in an for dst in graph[a][0]}
    incoming = {ans_[src] for a in an for src in graph[a][1]}
    outgoing.discard(i)
    incoming.discard(i)
    tree.append(outgoing)
    tree_.append(incoming)

# tree_child[c] gathers the vertices of every component reachable from c along
# tree; tree_child_[c] does the same along the reversed edges in tree_.
tree_child = [set() for _ in ans]
tree_child_ = [set() for _ in ans]
visited = set()
visited_ = set()
for i in range(len(ans)):
    if i not in visited:
        dfs(i, tree, tree_child, ans, visited)
    if i not in visited_:
        dfs(i, tree_, tree_child_, ans, visited_)

# A component counts when the vertices gathered forwards, those gathered
# backwards, and its own vertices together number n; all its vertices count.
count = 0
for members, forward, backward in zip(ans, tree_child, tree_child_):
    if len(forward) + len(backward) + len(members) == n:
        count += len(members)
print(count)